Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "conditional edge that checks for hallucinations (#401)" #402

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 4 additions & 46 deletions backend/retrieval_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from langchain_core.messages import BaseMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field

from backend.retrieval_graph.configuration import AgentConfiguration
from backend.retrieval_graph.researcher_graph.graph import graph as researcher_graph
Expand Down Expand Up @@ -150,7 +149,6 @@ class Plan(TypedDict):
"steps": response["steps"],
"documents": "delete",
"query": state.messages[-1].content,
"num_response_attempts": 0,
}


Expand Down Expand Up @@ -209,57 +207,18 @@ async def respond(
"""
configuration = AgentConfiguration.from_runnable_config(config)
model = load_chat_model(configuration.response_model)
num_response_attempts = state.num_response_attempts
# TODO: add a re-ranker here
top_k = 20
context = format_docs(state.documents[:top_k])
prompt = configuration.response_system_prompt.format(context=context)
messages = [{"role": "system", "content": prompt}] + state.messages
response = await model.ainvoke(messages)
return {
"messages": [response],
"answer": response.content,
"num_response_attempts": num_response_attempts + 1,
}


def check_hallucination(state: AgentState) -> Literal["respond", "end"]:
"""Check if the answer is hallucinated."""
model = load_chat_model("openai/gpt-4o-mini")
top_k = 20
answer = state.answer
num_response_attempts = state.num_response_attempts
context = format_docs(state.documents[:top_k])
return {"messages": [response], "answer": response.content}

class GradeHallucinations(BaseModel):
"""Binary score for hallucination present in generation answer."""

binary_score: str = Field(
description="Answer is grounded in the facts, 'yes' or 'no'"
)

grade_hallucinations_llm = model.with_structured_output(GradeHallucinations)
grade_hallucinations_system_prompt = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
grade_hallucinations_prompt = (
"Set of facts: \n\n {context} \n\n LLM generation: {answer}"
)
grade_hallucinations_prompt_formatted = grade_hallucinations_prompt.format(
context=context, answer=answer
)
result = grade_hallucinations_llm.invoke(
[
{"role": "system", "content": grade_hallucinations_system_prompt},
{"role": "human", "content": grade_hallucinations_prompt_formatted},
]
)
if result.binary_score == "yes" or num_response_attempts >= 2:
return "end"
else:
return "respond"
# Define the graph


# Define the graph
builder = StateGraph(AgentState, input=InputState, config_schema=AgentConfiguration)
builder.add_node(create_research_plan)
builder.add_node(conduct_research)
Expand All @@ -268,9 +227,8 @@ class GradeHallucinations(BaseModel):
builder.add_edge(START, "create_research_plan")
builder.add_edge("create_research_plan", "conduct_research")
builder.add_conditional_edges("conduct_research", check_finished)
builder.add_conditional_edges(
"respond", check_hallucination, {"end": END, "respond": "respond"}
)
builder.add_edge("respond", END)

# Compile into a graph object that you can invoke and deploy.
graph = builder.compile()
graph.name = "RetrievalGraph"
5 changes: 1 addition & 4 deletions backend/retrieval_graph/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,5 @@ class AgentState(InputState):
documents: Annotated[list[Document], reduce_docs] = field(default_factory=list)
"""Populated by the retriever. This is a list of documents that the agent can reference."""
answer: str = field(default="")
"""Final answer. Useful for evaluations."""
"""Final answer. Useful for evaluations"""
query: str = field(default="")
"""The user's query."""
num_response_attempts: int = field(default=0)
"""The number of times the agent has tried to respond."""
Loading