Skip to content

Commit

Permalink
close weaviate connection correctly (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Sep 9, 2024
1 parent da7f267 commit cca4729
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 82 deletions.
51 changes: 28 additions & 23 deletions backend/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import os
from collections import defaultdict
from typing import Annotated, Literal, Optional, Sequence, TypedDict
from typing import Annotated, Iterator, Literal, Optional, Sequence, TypedDict

import weaviate
from langchain_anthropic import ChatAnthropic
Expand Down Expand Up @@ -203,23 +204,24 @@ class AgentState(TypedDict):
)


def get_retriever(k: Optional[int] = None) -> BaseRetriever:
weaviate_client = weaviate.connect_to_wcs(
@contextlib.contextmanager
def get_retriever(k: Optional[int] = None) -> Iterator[BaseRetriever]:
with weaviate.connect_to_weaviate_cloud(
cluster_url=os.environ["WEAVIATE_URL"],
auth_credentials=weaviate.classes.init.Auth.api_key(
os.environ.get("WEAVIATE_API_KEY", "not_provided")
),
skip_init_checks=True,
)
weaviate_client = WeaviateVectorStore(
client=weaviate_client,
index_name=WEAVIATE_DOCS_INDEX_NAME,
text_key="text",
embedding=get_embeddings_model(),
attributes=["source", "title"],
)
k = k or 6
return weaviate_client.as_retriever(search_kwargs=dict(k=k))
) as weaviate_client:
store = WeaviateVectorStore(
client=weaviate_client,
index_name=WEAVIATE_DOCS_INDEX_NAME,
text_key="text",
embedding=get_embeddings_model(),
attributes=["source", "title"],
)
k = k or 6
yield store.as_retriever(search_kwargs=dict(k=k))


def format_docs(docs: Sequence[Document]) -> str:
Expand All @@ -234,15 +236,17 @@ def retrieve_documents(
state: AgentState, *, config: Optional[RunnableConfig] = None
) -> AgentState:
config = ensure_config(config)
retriever = get_retriever(k=config["configurable"].get("k"))
messages = convert_to_messages(state["messages"])
query = messages[-1].content
relevant_documents = retriever.invoke(query)
with get_retriever(k=config["configurable"].get("k")) as retriever:
relevant_documents = retriever.invoke(query)
return {"query": query, "documents": relevant_documents}


def retrieve_documents_with_chat_history(state: AgentState) -> AgentState:
retriever = get_retriever()
def retrieve_documents_with_chat_history(
state: AgentState, *, config: Optional[RunnableConfig] = None
) -> AgentState:
config = ensure_config(config)
model = llm.with_config(tags=["nostream"])

CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE)
Expand All @@ -254,12 +258,13 @@ def retrieve_documents_with_chat_history(state: AgentState) -> AgentState:

messages = convert_to_messages(state["messages"])
query = messages[-1].content
retriever_with_condensed_question = condense_question_chain | retriever
# NOTE: we're ignoring the last message here, as it's going to contain the most recent
# query and we don't want that to be included in the chat history
relevant_documents = retriever_with_condensed_question.invoke(
{"question": query, "chat_history": get_chat_history(messages[:-1])}
)
with get_retriever(k=config["configurable"].get("k")) as retriever:
retriever_with_condensed_question = condense_question_chain | retriever
# NOTE: we're ignoring the last message here, as it's going to contain the most recent
# query and we don't want that to be included in the chat history
relevant_documents = retriever_with_condensed_question.invoke(
{"question": query, "chat_history": get_chat_history(messages[:-1])}
)
return {"query": query, "documents": relevant_documents}


Expand Down
120 changes: 61 additions & 59 deletions backend/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,68 +131,70 @@ def ingest_docs():
text_splitter = RecursiveCharacterTextSplitter(chunk_size=4000, chunk_overlap=200)
embedding = get_embeddings_model()

client = weaviate.connect_to_wcs(
with weaviate.connect_to_weaviate_cloud(
cluster_url=WEAVIATE_URL,
auth_credentials=weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY),
skip_init_checks=True,
)
vectorstore = WeaviateVectorStore(
client=client,
index_name=WEAVIATE_DOCS_INDEX_NAME,
text_key="text",
embedding=embedding,
attributes=["source", "title"],
)

record_manager = SQLRecordManager(
f"weaviate/{WEAVIATE_DOCS_INDEX_NAME}", db_url=RECORD_MANAGER_DB_URL
)
record_manager.create_schema()

docs_from_documentation = load_langchain_docs()
logger.info(f"Loaded {len(docs_from_documentation)} docs from documentation")
docs_from_api = load_api_docs()
logger.info(f"Loaded {len(docs_from_api)} docs from API")
docs_from_langsmith = load_langsmith_docs()
logger.info(f"Loaded {len(docs_from_langsmith)} docs from LangSmith")
docs_from_langgraph = load_langgraph_docs()
logger.info(f"Loaded {len(docs_from_langgraph)} docs from LangGraph")

docs_transformed = text_splitter.split_documents(
docs_from_documentation
+ docs_from_api
+ docs_from_langsmith
+ docs_from_langgraph
)
docs_transformed = [doc for doc in docs_transformed if len(doc.page_content) > 10]

# We try to return 'source' and 'title' metadata when querying vector store and
# Weaviate will error at query time if one of the attributes is missing from a
# retrieved document.
for doc in docs_transformed:
if "source" not in doc.metadata:
doc.metadata["source"] = ""
if "title" not in doc.metadata:
doc.metadata["title"] = ""

indexing_stats = index(
docs_transformed,
record_manager,
vectorstore,
cleanup="full",
source_id_key="source",
force_update=(os.environ.get("FORCE_UPDATE") or "false").lower() == "true",
)

logger.info(f"Indexing stats: {indexing_stats}")
num_vecs = (
client.collections.get(WEAVIATE_DOCS_INDEX_NAME)
.aggregate.over_all()
.total_count
)
logger.info(
f"LangChain now has this many vectors: {num_vecs}",
)
) as weaviate_client:
vectorstore = WeaviateVectorStore(
client=weaviate_client,
index_name=WEAVIATE_DOCS_INDEX_NAME,
text_key="text",
embedding=embedding,
attributes=["source", "title"],
)

record_manager = SQLRecordManager(
f"weaviate/{WEAVIATE_DOCS_INDEX_NAME}", db_url=RECORD_MANAGER_DB_URL
)
record_manager.create_schema()

docs_from_documentation = load_langchain_docs()
logger.info(f"Loaded {len(docs_from_documentation)} docs from documentation")
docs_from_api = load_api_docs()
logger.info(f"Loaded {len(docs_from_api)} docs from API")
docs_from_langsmith = load_langsmith_docs()
logger.info(f"Loaded {len(docs_from_langsmith)} docs from LangSmith")
docs_from_langgraph = load_langgraph_docs()
logger.info(f"Loaded {len(docs_from_langgraph)} docs from LangGraph")

docs_transformed = text_splitter.split_documents(
docs_from_documentation
+ docs_from_api
+ docs_from_langsmith
+ docs_from_langgraph
)
docs_transformed = [
doc for doc in docs_transformed if len(doc.page_content) > 10
]

# We try to return 'source' and 'title' metadata when querying vector store and
# Weaviate will error at query time if one of the attributes is missing from a
# retrieved document.
for doc in docs_transformed:
if "source" not in doc.metadata:
doc.metadata["source"] = ""
if "title" not in doc.metadata:
doc.metadata["title"] = ""

indexing_stats = index(
docs_transformed,
record_manager,
vectorstore,
cleanup="full",
source_id_key="source",
force_update=(os.environ.get("FORCE_UPDATE") or "false").lower() == "true",
)

logger.info(f"Indexing stats: {indexing_stats}")
num_vecs = (
weaviate_client.collections.get(WEAVIATE_DOCS_INDEX_NAME)
.aggregate.over_all()
.total_count
)
logger.info(
f"LangChain now has this many vectors: {num_vecs}",
)


if __name__ == "__main__":
Expand Down

0 comments on commit cca4729

Please sign in to comment.