diff --git a/backend/retrieval_graph/graph.py b/backend/retrieval_graph/graph.py
index 3a861810..ffa787fb 100644
--- a/backend/retrieval_graph/graph.py
+++ b/backend/retrieval_graph/graph.py
@@ -11,7 +11,6 @@ from langchain_core.messages import BaseMessage
 from langchain_core.runnables import RunnableConfig
 from langgraph.graph import END, START, StateGraph
-from pydantic import BaseModel, Field
 
 from backend.retrieval_graph.configuration import AgentConfiguration
 from backend.retrieval_graph.researcher_graph.graph import graph as researcher_graph
@@ -150,7 +149,6 @@ class Plan(TypedDict):
         "steps": response["steps"],
         "documents": "delete",
         "query": state.messages[-1].content,
-        "num_response_attempts": 0,
     }
 
 
@@ -209,57 +207,18 @@ async def respond(
     """
     configuration = AgentConfiguration.from_runnable_config(config)
     model = load_chat_model(configuration.response_model)
-    num_response_attempts = state.num_response_attempts
     # TODO: add a re-ranker here
     top_k = 20
     context = format_docs(state.documents[:top_k])
     prompt = configuration.response_system_prompt.format(context=context)
     messages = [{"role": "system", "content": prompt}] + state.messages
     response = await model.ainvoke(messages)
-    return {
-        "messages": [response],
-        "answer": response.content,
-        "num_response_attempts": num_response_attempts + 1,
-    }
-
-
-def check_hallucination(state: AgentState) -> Literal["respond", "end"]:
-    """Check if the answer is hallucinated."""
-    model = load_chat_model("openai/gpt-4o-mini")
-    top_k = 20
-    answer = state.answer
-    num_response_attempts = state.num_response_attempts
-    context = format_docs(state.documents[:top_k])
+    return {"messages": [response], "answer": response.content}
 
-    class GradeHallucinations(BaseModel):
-        """Binary score for hallucination present in generation answer."""
-
-        binary_score: str = Field(
-            description="Answer is grounded in the facts, 'yes' or 'no'"
-        )
-
-    grade_hallucinations_llm = model.with_structured_output(GradeHallucinations)
-    grade_hallucinations_system_prompt = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
-    Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
-    grade_hallucinations_prompt = (
-        "Set of facts: \n\n {context} \n\n LLM generation: {answer}"
-    )
-    grade_hallucinations_prompt_formatted = grade_hallucinations_prompt.format(
-        context=context, answer=answer
-    )
-    result = grade_hallucinations_llm.invoke(
-        [
-            {"role": "system", "content": grade_hallucinations_system_prompt},
-            {"role": "human", "content": grade_hallucinations_prompt_formatted},
-        ]
-    )
-    if result.binary_score == "yes" or num_response_attempts >= 2:
-        return "end"
-    else:
-        return "respond"
+# Define the graph
 
-# Define the graph
 builder = StateGraph(AgentState, input=InputState, config_schema=AgentConfiguration)
 builder.add_node(create_research_plan)
 builder.add_node(conduct_research)
@@ -268,9 +227,8 @@ class GradeHallucinations(BaseModel):
 builder.add_edge(START, "create_research_plan")
 builder.add_edge("create_research_plan", "conduct_research")
 builder.add_conditional_edges("conduct_research", check_finished)
-builder.add_conditional_edges(
-    "respond", check_hallucination, {"end": END, "respond": "respond"}
-)
+builder.add_edge("respond", END)
+
 
 # Compile into a graph object that you can invoke and deploy.
 graph = builder.compile()
 graph.name = "RetrievalGraph"
diff --git a/backend/retrieval_graph/state.py b/backend/retrieval_graph/state.py
index fcd154ac..72ddbf39 100644
--- a/backend/retrieval_graph/state.py
+++ b/backend/retrieval_graph/state.py
@@ -80,8 +80,5 @@ class AgentState(InputState):
     documents: Annotated[list[Document], reduce_docs] = field(default_factory=list)
     """Populated by the retriever. This is a list of documents that the agent can reference."""
     answer: str = field(default="")
-    """Final answer. Useful for evaluations."""
+    """Final answer. Useful for evaluations"""
    query: str = field(default="")
-    """The user's query."""
-    num_response_attempts: int = field(default=0)
-    """The number of times the agent has tried to respond."""
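
The topology this patch leaves behind is easier to see outside the diff format. Below is a minimal, self-contained sketch of the simplified graph: a linear plan, a research loop gated by check_finished, and an unconditional respond -> END edge where the check_hallucination cycle used to be. Only the node names and edge wiring are taken from the diff; the stub state and node bodies are illustrative stand-ins, not the project's real AgentState or node implementations.

```python
# Sketch of the post-patch topology (assumes `pip install langgraph`).
from typing import Literal

from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict


class State(TypedDict):
    steps: list[str]  # stand-in for the research plan
    answer: str       # stand-in for the final answer


def create_research_plan(state: State) -> dict:
    return {"steps": ["step-1", "step-2"]}


def conduct_research(state: State) -> dict:
    # Consume one planned step per pass.
    return {"steps": state["steps"][1:]}


def check_finished(state: State) -> Literal["conduct_research", "respond"]:
    # Loop back until every planned step has been researched.
    return "respond" if not state["steps"] else "conduct_research"


def respond(state: State) -> dict:
    return {"answer": "final answer"}


builder = StateGraph(State)
builder.add_node("create_research_plan", create_research_plan)
builder.add_node("conduct_research", conduct_research)
builder.add_node("respond", respond)
builder.add_edge(START, "create_research_plan")
builder.add_edge("create_research_plan", "conduct_research")
builder.add_conditional_edges("conduct_research", check_finished)
# With check_hallucination removed, "respond" always terminates the run.
builder.add_edge("respond", END)
graph = builder.compile()

print(graph.invoke({"steps": [], "answer": ""}))
```

Because respond now runs exactly once per invocation, the retry bookkeeping becomes dead state, which is why the second hunk drops num_response_attempts from AgentState and the planner no longer resets it to 0.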