formatting
davidx33 committed Dec 3, 2024
1 parent 1f4b9d7 commit d9f67d5
Showing 1 changed file with 28 additions and 14 deletions.
backend/retrieval_graph/graph.py (42 changes: 28 additions & 14 deletions)
@@ -7,15 +7,17 @@
 """
 
 from typing import Any, Literal, TypedDict, cast
+
 from langchain_core.messages import BaseMessage
 from langchain_core.runnables import RunnableConfig
 from langgraph.graph import END, START, StateGraph
+from pydantic import BaseModel, Field
+
 from backend.retrieval_graph.configuration import AgentConfiguration
 from backend.retrieval_graph.researcher_graph.graph import graph as researcher_graph
 from backend.retrieval_graph.state import AgentState, InputState, Router
 from backend.utils import format_docs, load_chat_model
-from pydantic import BaseModel, Field
 
 
 async def analyze_and_route_query(
     state: AgentState, *, config: RunnableConfig
@@ -214,7 +216,12 @@ async def respond(
     prompt = configuration.response_system_prompt.format(context=context)
     messages = [{"role": "system", "content": prompt}] + state.messages
     response = await model.ainvoke(messages)
-    return {"messages": [response], "answer": response.content, "num_response_attempts": num_response_attempts + 1}
+    return {
+        "messages": [response],
+        "answer": response.content,
+        "num_response_attempts": num_response_attempts + 1,
+    }
+
 
 def check_hallucination(state: AgentState) -> Literal["respond", "end"]:
     """Check if the answer is hallucinated."""
@@ -223,27 +230,35 @@ def check_hallucination(state: AgentState) -> Literal["respond", "end"]:
     answer = state.answer
     num_response_attempts = state.num_response_attempts
     context = format_docs(state.documents[:top_k])
+
     class GradeHallucinations(BaseModel):
         """Binary score for hallucination present in generation answer."""
+
         binary_score: str = Field(
-          description="Answer is grounded in the facts, 'yes' or 'no'"
-        )
+            description="Answer is grounded in the facts, 'yes' or 'no'"
+        )
 
     grade_hallucinations_llm = model.with_structured_output(GradeHallucinations)
     grade_hallucinations_system_prompt = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
     Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
-    grade_hallucinations_prompt = "Set of facts: \n\n {context} \n\n LLM generation: {answer}"
+    grade_hallucinations_prompt = (
+        "Set of facts: \n\n {context} \n\n LLM generation: {answer}"
+    )
     grade_hallucinations_prompt_formatted = grade_hallucinations_prompt.format(
-        context=context,
-        answer=answer
+        context=context, answer=answer
     )
-    result = grade_hallucinations_llm.invoke([{"role": "system", "content": grade_hallucinations_system_prompt}, {"role": "human", "content": grade_hallucinations_prompt_formatted}])
+    result = grade_hallucinations_llm.invoke(
+        [
+            {"role": "system", "content": grade_hallucinations_system_prompt},
+            {"role": "human", "content": grade_hallucinations_prompt_formatted},
+        ]
+    )
     if result.binary_score == "yes" or num_response_attempts >= 2:
         return "end"
     else:
         return "respond"
 
 
 # Define the graph
 builder = StateGraph(AgentState, input=InputState, config_schema=AgentConfiguration)
 builder.add_node(create_research_plan)
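
Note: the grader reformatted in this hunk relies on LangChain's with_structured_output, which coerces the model's reply into the GradeHallucinations schema. Below is a minimal, self-contained sketch of that pattern; the ChatOpenAI model name and the prompt strings are illustrative stand-ins (the repo itself loads its model via load_chat_model and its own prompts):

    # Hypothetical standalone version of the grading pattern above.
    from langchain_openai import ChatOpenAI
    from pydantic import BaseModel, Field


    class GradeHallucinations(BaseModel):
        """Binary score for hallucination present in generation answer."""

        binary_score: str = Field(
            description="Answer is grounded in the facts, 'yes' or 'no'"
        )


    # Assumption: any chat model that supports structured output works here.
    model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
    grader = model.with_structured_output(GradeHallucinations)

    result = grader.invoke(
        [
            {"role": "system", "content": "Grade whether the generation is grounded in the facts. Answer 'yes' or 'no'."},
            {"role": "human", "content": "Set of facts: The sky is blue.\n\nLLM generation: The sky is green."},
        ]
    )
    print(result.binary_score)  # expect "no" for this ungrounded generation
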
@@ -253,10 +268,9 @@ class GradeHallucinations(BaseModel):
 builder.add_edge(START, "create_research_plan")
 builder.add_edge("create_research_plan", "conduct_research")
 builder.add_conditional_edges("conduct_research", check_finished)
-builder.add_conditional_edges("respond", check_hallucination, {
-    "end": END,
-    "respond": "respond"
-})
+builder.add_conditional_edges(
+    "respond", check_hallucination, {"end": END, "respond": "respond"}
+)
 # Compile into a graph object that you can invoke and deploy.
 graph = builder.compile()
-graph.name = "RetrievalGraph"
\ No newline at end of file
+graph.name = "RetrievalGraph"
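
Note: the reformatted conditional edge makes the retry loop easier to see: the {"end": END, "respond": "respond"} mapping routes a "respond" verdict back into the respond node, so the graph regenerates until check_hallucination returns "end" (grounded answer, or two attempts used). A toy graph showing the same self-loop routing follows; the State fields and node bodies are made up for illustration, only the wiring mirrors this file:

    # Illustrative self-loop via add_conditional_edges, mirroring the
    # respond/check_hallucination wiring above.
    from typing import Literal, TypedDict

    from langgraph.graph import END, START, StateGraph


    class State(TypedDict):
        attempts: int


    def respond(state: State) -> State:
        # Stand-in for the real respond node; just counts attempts.
        return {"attempts": state["attempts"] + 1}


    def check(state: State) -> Literal["respond", "end"]:
        # Stand-in for check_hallucination: cap retries at two attempts.
        return "end" if state["attempts"] >= 2 else "respond"


    builder = StateGraph(State)
    builder.add_node(respond)
    builder.add_edge(START, "respond")
    builder.add_conditional_edges("respond", check, {"end": END, "respond": "respond"})
    graph = builder.compile()

    print(graph.invoke({"attempts": 0}))  # {'attempts': 2}
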
