import json
import logging

import gradio as gr
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import Tool
from langchain_ollama.chat_models import ChatOllama
from langgraph.graph import END, START, MessagesState, StateGraph
# --------------------------------------
# 1. Logger
# --------------------------------------
# Root logging configuration: INFO level, each line prefixed with time and level.
_LOG_FORMAT = '[%(asctime)s] [%(levelname)s] %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Module logger used by every section of this script.
log = logging.getLogger(__name__)
# --------------------------------------
# 2. Enriched system prompt
# --------------------------------------
# System prompt placed at the head of the conversation history by ask_agent.
# ask_agent extracts the text after "Final Answer:" from the model reply, which
# is why the example below ends with that marker.
# NOTE(review): the lines `tool_names = [...]` and `tools:` look like template
# placeholders pasted verbatim; the model receives them as literal text.
# NOTE(review): the example teaches a text ReAct format (Thought / Action /
# Action Input), while the LLM is given the tools through native tool calling
# (`bind_tools`); confirm which protocol the model is expected to follow.
PROMPT_SYSTEM = """
Tu es un assistant expert en recertification d’applications.
Tu disposes des outils suivants :
tool_names = ["smart_get_info", "smart_audit_access", "smart_recertification_status"]
tools:
- smart_get_info : Donne l'équipe responsable et la classification d'une application.
- smart_audit_access : Liste les derniers accès à une application.
- smart_recertification_status : Informe si une application doit être recertifiée.
Règles :
- Si la question est générale (ex. « c’est quoi une recertification ? »), réponds directement sans appeler d’outil.
- Si une information précise est demandée sur une application, appelle un outil si besoin.
- Ne répète jamais un outil si tu as déjà reçu l’Observation.
Réponds toujours de façon claire, concise et utile.
Exemple :
User: Qui gère l'application XYZ ?
Assistant:
Thought: J'ai besoin d'information.
Action: smart_get_info
Action Input: XYZ
[Observation] L'application XYZ est gérée par l'équipe Sécurité.
Assistant:
Final Answer: L'application XYZ est gérée par l'équipe Sécurité.
"""
# --------------------------------------
# 3. Simulated tools
# --------------------------------------
def smart_get_info(app: str) -> str:
    """Simulated tool: name the team responsible for *app*.

    The answer is canned and always names the "Développement" team. Unlike the
    tool description, no classification is included.
    """
    log.info(f"[TOOL] smart_get_info: {app}")
    answer = f"L'application {app} est gérée par l'équipe Développement."
    return answer
def smart_audit_access(app: str) -> str:
    """Simulated tool: list the most recent users who accessed *app*.

    The answer is canned and always reports UserX and UserY.
    """
    log.info(f"[TOOL] smart_audit_access: {app}")
    answer = f"Derniers accès à {app} : UserX, UserY."
    return answer
def smart_recertification_status(app: str) -> str:
    """Simulated tool: say whether *app* needs recertification.

    The answer is canned and always states that recertification is due before
    the end of the month.
    """
    log.info(f"[TOOL] smart_recertification_status: {app}")
    answer = f"{app} doit être recertifiée avant la fin du mois."
    return answer
# (name, callable, description) for each simulated tool, in registration order.
_TOOL_SPECS = (
    (
        "smart_get_info",
        smart_get_info,
        "Retourne l'équipe en charge et la classification d'une application.",
    ),
    (
        "smart_audit_access",
        smart_audit_access,
        "Retourne les derniers accès à une application.",
    ),
    (
        "smart_recertification_status",
        smart_recertification_status,
        "Indique si une application nécessite une recertification prochaine.",
    ),
)
# LangChain Tool wrappers, bound to the LLM and dispatched by ToolExecutor.
tools = [
    Tool(name=tool_name, func=tool_func, description=tool_description)
    for tool_name, tool_func, tool_description in _TOOL_SPECS
]
# --------------------------------------
# 4. Local ToolExecutor
# --------------------------------------
class ToolExecutor:
    """Minimal dispatcher that runs a tool looked up by name.

    Each tool object must expose a ``name`` attribute and a ``run(arguments)``
    method.
    """

    def __init__(self, tools):
        # Tool name -> tool object, for direct lookup in invoke().
        self.tool_map = {tool.name: tool for tool in tools}

    def invoke(self, tool_call):
        """Run the tool described by *tool_call* and return its output.

        Args:
            tool_call: dict with a ``"name"`` key. Arguments are read from
                ``"arguments"`` (a dict or a JSON string, OpenAI style) or, when
                that key is absent, from ``"args"`` (LangChain style). Missing
                arguments default to ``{}``.

        Returns:
            Whatever the tool's ``run`` method returns.

        Raises:
            ValueError: if the tool is unknown or the JSON arguments are invalid.
        """
        tool_name = tool_call.get("name")
        if tool_name not in self.tool_map:
            raise ValueError(f"Outil non trouvé : {tool_name}")
        tool = self.tool_map[tool_name]
        if "arguments" in tool_call:
            arguments = tool_call["arguments"]
        else:
            arguments = tool_call.get("args", {})
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError as e:
                # Keep the original parse error as the cause for debugging.
                raise ValueError(f"Arguments JSON invalides : {e}") from e
        return tool.run(arguments)
# Shared executor used by the tools node to dispatch calls by tool name.
tool_executor = ToolExecutor(tools)
# --------------------------------------
# 5. Ollama LLM with tools
# --------------------------------------
# Local "mistral" model served by Ollama, run at temperature 0.
llm = ChatOllama(temperature=0, model="mistral")
# Same model with the simulated tools bound so it can request tool calls.
llm_with_tools = llm.bind_tools(tools=tools)
# --------------------------------------
# 6. Assistant node
# --------------------------------------
def assistant(state: MessagesState):
    """Graph node: run the tool-bound LLM over the conversation so far.

    Returns a partial state update whose ``"messages"`` list holds the single
    model reply.
    """
    log.info("[GRAPH] Assistant node triggered.")
    history = state["messages"]
    reply = llm_with_tools.invoke(history)
    log.info("[GRAPH] Assistant responded.")
    return {"messages": [reply]}
# --------------------------------------
# 7. Tools node (manual execution, results fed back as ToolMessages)
# --------------------------------------
def _extract_tool_calls(message) -> list:
    """Return the tool calls requested by *message* as ``{name, args, id}`` dicts.

    LangChain chat models expose parsed calls on ``message.tool_calls`` (this is
    where ChatOllama puts them). The OpenAI-style
    ``additional_kwargs["tool_calls"]`` payload is kept as a fallback. A message
    without tool calls yields an empty list.
    """
    parsed = getattr(message, "tool_calls", None) or []
    if parsed:
        return [
            {"name": call["name"], "args": call.get("args", {}), "id": call.get("id")}
            for call in parsed
        ]
    raw = (getattr(message, "additional_kwargs", None) or {}).get("tool_calls") or []
    return [
        {
            "name": call["function"]["name"],
            # OpenAI-style arguments are a JSON string; ToolExecutor parses it.
            "args": call["function"].get("arguments", "{}"),
            "id": call.get("id"),
        }
        for call in raw
    ]


def tools_node(state: MessagesState):
    """Graph node: execute every tool call requested by the last message.

    Each result, or an error notice if the tool fails, is returned as a
    ``ToolMessage`` linked to the originating call id. The model therefore
    receives it as the answer to its tool call rather than as its own reply.
    """
    last_message = state["messages"][-1]
    new_messages = []
    for call in _extract_tool_calls(last_message):
        tool_name = call["name"]
        tool_args = call["args"]
        log.info(f"[TOOL EXECUTOR] Calling {tool_name} with args {tool_args}")
        try:
            output = tool_executor.invoke({"name": tool_name, "arguments": tool_args})
            log.info(f"[TOOL EXECUTOR] Output: {output}")
            content = f"[Observation] {output}"
        except Exception as e:
            # Best effort: report the failure to the model instead of aborting the run.
            log.error(f"[TOOL ERROR] {tool_name} failed: {e}")
            content = f"[Observation] Erreur lors de l’appel à {tool_name}."
        new_messages.append(
            ToolMessage(
                content=content,
                # ToolMessage requires an id; fall back to the tool name if absent.
                tool_call_id=call["id"] or tool_name,
                name=tool_name,
            )
        )
    return {"messages": new_messages}
# --------------------------------------
# 8. Routing condition
# --------------------------------------
def stop_condition(state: MessagesState) -> str:
    """Decide where to go after the assistant node.

    Returns ``"continue"`` when the last message is an AI message requesting at
    least one tool call, and ``"end"`` otherwise.
    """
    last = state["messages"][-1]
    if not isinstance(last, AIMessage):
        return "end"
    has_tool_call = bool(_extract_tool_calls(last))
    log.info(f"[STOP_CONDITION] tool_call detected: {has_tool_call}")
    return "continue" if has_tool_call else "end"
# --------------------------------------
# 9. LangGraph graph
# --------------------------------------
# START -> assistant. Assistant -> tools while it requests tool calls, else END.
# Tools -> assistant, so the model sees the observations before answering.
builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", tools_node)
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
    "assistant",
    stop_condition,
    path_map={"continue": "tools", "end": END},
)
builder.add_edge("tools", "assistant")
graph = builder.compile()
# --------------------------------------
# 10. Conversation history
# --------------------------------------
# Module-level history shared by every call to ask_agent in this process.
chat_history = []
# --------------------------------------
# 11. Main Gradio function
# --------------------------------------
def ask_agent(message: str) -> str:
    """Send *message* to the agent graph and return the agent's textual reply.

    On the first call the system prompt is inserted at the head of the history.
    The user message is appended before the graph runs, and the final AI message
    is appended afterwards. If the graph raises, or does not end on an AI
    message, the user message is removed again. The history therefore never
    keeps an unanswered user turn.

    Returns:
        The text after the last ``"Final Answer:"`` marker when present, the
        whole AI message content otherwise, or an error string when no AI
        message was produced.

    Raises:
        Whatever ``graph.invoke`` raises (after rolling back the user message).
    """
    log.info(f"[USER INPUT] {message}")
    # Add the system prompt once, at the very beginning of the conversation.
    if not chat_history:
        system_message = SystemMessage(
            content=PROMPT_SYSTEM,
            # NOTE(review): these kwargs are logged below; whether the chat
            # model forwards them to Ollama is not visible here — confirm.
            additional_kwargs={
                "tool_names": [tool.name for tool in tools],
                "tools": [tool.model_dump() for tool in tools],
            },
        )
        chat_history.insert(0, system_message)
    chat_history.append(HumanMessage(content=message))
    log.info("[MESSAGES LOG BEFORE INVOKE]")
    for i, msg in enumerate(chat_history):
        msg_type = type(msg).__name__
        log.info(f" [{i}] {msg_type}: {msg.content}")
        if hasattr(msg, "additional_kwargs") and msg.additional_kwargs:
            log.info(f" ↳ kwargs: {msg.additional_kwargs}")
    try:
        result = graph.invoke({"messages": chat_history})
    except Exception:
        log.exception("[GRAPH] Invocation failed; removing the unanswered user message.")
        chat_history.pop()
        raise
    ai_msg = result["messages"][-1]
    if not isinstance(ai_msg, AIMessage):
        # Keep the history consistent: drop the user turn that got no AI answer.
        chat_history.pop()
        return "[Erreur] Aucune réponse générée."
    chat_history.append(ai_msg)
    log.info("[AI RESPONSE]")
    log.info(f" Content: {ai_msg.content}")
    if ai_msg.additional_kwargs:
        log.info(f" ↳ kwargs: {ai_msg.additional_kwargs}")
    content = ai_msg.content
    # The system prompt asks for a "Final Answer:" marker; return only what follows it.
    if "Final Answer:" in content:
        return content.split("Final Answer:")[-1].strip()
    return content
# --------------------------------------
# 12. Gradio interface
# --------------------------------------
# Single text-in / text-out UI backed by ask_agent.
demo = gr.Interface(
    fn=ask_agent,
    inputs=gr.Textbox(label="Posez votre question"),
    outputs=gr.Textbox(label="Réponse de l'agent"),
    title="Assistant Recertification - Mistral (Ollama)",
    theme="default",
)

# Start the local server only when run as a script, so importing this module
# does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch(server_name="127.0.0.1", share=False)