From cca47296a9f25776151f6e7b9d680006a67146fc Mon Sep 17 00:00:00 2001 From: Vadym Barda Date: Mon, 9 Sep 2024 13:54:12 -0400 Subject: [PATCH] close weaviate connection correctly (#371) --- backend/graph.py | 51 +++++++++++--------- backend/ingest.py | 120 +++++++++++++++++++++++----------------------- 2 files changed, 89 insertions(+), 82 deletions(-) diff --git a/backend/graph.py b/backend/graph.py index 7ef3a6c7d..5f98ebfa0 100644 --- a/backend/graph.py +++ b/backend/graph.py @@ -1,6 +1,7 @@ +import contextlib import os from collections import defaultdict -from typing import Annotated, Literal, Optional, Sequence, TypedDict +from typing import Annotated, Iterator, Literal, Optional, Sequence, TypedDict import weaviate from langchain_anthropic import ChatAnthropic @@ -203,23 +204,24 @@ class AgentState(TypedDict): ) -def get_retriever(k: Optional[int] = None) -> BaseRetriever: - weaviate_client = weaviate.connect_to_wcs( +@contextlib.contextmanager +def get_retriever(k: Optional[int] = None) -> Iterator[BaseRetriever]: + with weaviate.connect_to_weaviate_cloud( cluster_url=os.environ["WEAVIATE_URL"], auth_credentials=weaviate.classes.init.Auth.api_key( os.environ.get("WEAVIATE_API_KEY", "not_provided") ), skip_init_checks=True, - ) - weaviate_client = WeaviateVectorStore( - client=weaviate_client, - index_name=WEAVIATE_DOCS_INDEX_NAME, - text_key="text", - embedding=get_embeddings_model(), - attributes=["source", "title"], - ) - k = k or 6 - return weaviate_client.as_retriever(search_kwargs=dict(k=k)) + ) as weaviate_client: + store = WeaviateVectorStore( + client=weaviate_client, + index_name=WEAVIATE_DOCS_INDEX_NAME, + text_key="text", + embedding=get_embeddings_model(), + attributes=["source", "title"], + ) + k = k or 6 + yield store.as_retriever(search_kwargs=dict(k=k)) def format_docs(docs: Sequence[Document]) -> str: @@ -234,15 +236,17 @@ def retrieve_documents( state: AgentState, *, config: Optional[RunnableConfig] = None ) -> AgentState: config = ensure_config(config) - retriever = get_retriever(k=config["configurable"].get("k")) messages = convert_to_messages(state["messages"]) query = messages[-1].content - relevant_documents = retriever.invoke(query) + with get_retriever(k=config["configurable"].get("k")) as retriever: + relevant_documents = retriever.invoke(query) return {"query": query, "documents": relevant_documents} -def retrieve_documents_with_chat_history(state: AgentState) -> AgentState: - retriever = get_retriever() +def retrieve_documents_with_chat_history( + state: AgentState, *, config: Optional[RunnableConfig] = None +) -> AgentState: + config = ensure_config(config) model = llm.with_config(tags=["nostream"]) CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(REPHRASE_TEMPLATE) @@ -254,12 +258,13 @@ def retrieve_documents_with_chat_history(state: AgentState) -> AgentState: messages = convert_to_messages(state["messages"]) query = messages[-1].content - retriever_with_condensed_question = condense_question_chain | retriever - # NOTE: we're ignoring the last message here, as it's going to contain the most recent - # query and we don't want that to be included in the chat history - relevant_documents = retriever_with_condensed_question.invoke( - {"question": query, "chat_history": get_chat_history(messages[:-1])} - ) + with get_retriever(k=config["configurable"].get("k")) as retriever: + retriever_with_condensed_question = condense_question_chain | retriever + # NOTE: we're ignoring the last message here, as it's going to contain the most recent + # query and we don't want that to be included in the chat history + relevant_documents = retriever_with_condensed_question.invoke( + {"question": query, "chat_history": get_chat_history(messages[:-1])} + ) return {"query": query, "documents": relevant_documents} diff --git a/backend/ingest.py b/backend/ingest.py index b79a34021..5f2f61008 100644 --- a/backend/ingest.py +++ b/backend/ingest.py @@ -131,68 +131,70 @@ def ingest_docs(): text_splitter = RecursiveCharacterTextSplitter(chunk_size=4000, chunk_overlap=200) embedding = get_embeddings_model() - client = weaviate.connect_to_wcs( + with weaviate.connect_to_weaviate_cloud( cluster_url=WEAVIATE_URL, auth_credentials=weaviate.classes.init.Auth.api_key(WEAVIATE_API_KEY), skip_init_checks=True, - ) - vectorstore = WeaviateVectorStore( - client=client, - index_name=WEAVIATE_DOCS_INDEX_NAME, - text_key="text", - embedding=embedding, - attributes=["source", "title"], - ) - - record_manager = SQLRecordManager( - f"weaviate/{WEAVIATE_DOCS_INDEX_NAME}", db_url=RECORD_MANAGER_DB_URL - ) - record_manager.create_schema() - - docs_from_documentation = load_langchain_docs() - logger.info(f"Loaded {len(docs_from_documentation)} docs from documentation") - docs_from_api = load_api_docs() - logger.info(f"Loaded {len(docs_from_api)} docs from API") - docs_from_langsmith = load_langsmith_docs() - logger.info(f"Loaded {len(docs_from_langsmith)} docs from LangSmith") - docs_from_langgraph = load_langgraph_docs() - logger.info(f"Loaded {len(docs_from_langgraph)} docs from LangGraph") - - docs_transformed = text_splitter.split_documents( - docs_from_documentation - + docs_from_api - + docs_from_langsmith - + docs_from_langgraph - ) - docs_transformed = [doc for doc in docs_transformed if len(doc.page_content) > 10] - - # We try to return 'source' and 'title' metadata when querying vector store and - # Weaviate will error at query time if one of the attributes is missing from a - # retrieved document. - for doc in docs_transformed: - if "source" not in doc.metadata: - doc.metadata["source"] = "" - if "title" not in doc.metadata: - doc.metadata["title"] = "" - - indexing_stats = index( - docs_transformed, - record_manager, - vectorstore, - cleanup="full", - source_id_key="source", - force_update=(os.environ.get("FORCE_UPDATE") or "false").lower() == "true", - ) - - logger.info(f"Indexing stats: {indexing_stats}") - num_vecs = ( - client.collections.get(WEAVIATE_DOCS_INDEX_NAME) - .aggregate.over_all() - .total_count - ) - logger.info( - f"LangChain now has this many vectors: {num_vecs}", - ) + ) as weaviate_client: + vectorstore = WeaviateVectorStore( + client=weaviate_client, + index_name=WEAVIATE_DOCS_INDEX_NAME, + text_key="text", + embedding=embedding, + attributes=["source", "title"], + ) + + record_manager = SQLRecordManager( + f"weaviate/{WEAVIATE_DOCS_INDEX_NAME}", db_url=RECORD_MANAGER_DB_URL + ) + record_manager.create_schema() + + docs_from_documentation = load_langchain_docs() + logger.info(f"Loaded {len(docs_from_documentation)} docs from documentation") + docs_from_api = load_api_docs() + logger.info(f"Loaded {len(docs_from_api)} docs from API") + docs_from_langsmith = load_langsmith_docs() + logger.info(f"Loaded {len(docs_from_langsmith)} docs from LangSmith") + docs_from_langgraph = load_langgraph_docs() + logger.info(f"Loaded {len(docs_from_langgraph)} docs from LangGraph") + + docs_transformed = text_splitter.split_documents( + docs_from_documentation + + docs_from_api + + docs_from_langsmith + + docs_from_langgraph + ) + docs_transformed = [ + doc for doc in docs_transformed if len(doc.page_content) > 10 + ] + + # We try to return 'source' and 'title' metadata when querying vector store and + # Weaviate will error at query time if one of the attributes is missing from a + # retrieved document. + for doc in docs_transformed: + if "source" not in doc.metadata: + doc.metadata["source"] = "" + if "title" not in doc.metadata: + doc.metadata["title"] = "" + + indexing_stats = index( + docs_transformed, + record_manager, + vectorstore, + cleanup="full", + source_id_key="source", + force_update=(os.environ.get("FORCE_UPDATE") or "false").lower() == "true", + ) + + logger.info(f"Indexing stats: {indexing_stats}") + num_vecs = ( + weaviate_client.collections.get(WEAVIATE_DOCS_INDEX_NAME) + .aggregate.over_all() + .total_count + ) + logger.info( + f"LangChain now has this many vectors: {num_vecs}", + ) if __name__ == "__main__":