update lib versions and do some renaming
gustavz committed Feb 9, 2024
1 parent 17ef0b5 commit 5b6f9e9
Showing 11 changed files with 85 additions and 51 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
+stores
 data
 models
 __pycache__
55 changes: 32 additions & 23 deletions datachad/backend/chain.py
@@ -10,7 +10,7 @@
 from langchain.schema import BaseChatMessageHistory, BasePromptTemplate, BaseRetriever, Document
 from langchain.schema.language_model import BaseLanguageModel
 from langchain.schema.vectorstore import VectorStore
-from pydantic import Extra
+from datachad.backend.constants import VERBOSE

 from datachad.backend.deeplake import get_or_create_deeplake_vector_store_display_name
 from datachad.backend.logging import logger
@@ -39,13 +39,6 @@ class MultiRetrieverFAQChain(Chain):
     smart_faq_chain: BaseCombineDocumentsChain
     smart_faq_retriever: BaseRetriever | None

-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-        arbitrary_types_allowed = True
-        allow_population_by_field_name = True
-
     @property
     def input_keys(self) -> list[str]:
         """Will be whatever keys the prompt expects."""
@@ -215,24 +208,33 @@ def from_llm(
        )


-def get_knowledge_base_search_kwargs(options: dict) -> dict:
+def get_knowledge_base_search_kwargs(options: dict) -> tuple[dict, str]:
     k = int(options["max_tokens"] // options["chunk_size"])
     fetch_k = k * options["k_fetch_k_ratio"]
-    search_kwargs = {
-        "maximal_marginal_relevance": options["maximal_marginal_relevance"],
-        "distance_metric": options["distance_metric"],
-        "fetch_k": fetch_k,
-        "k": k,
-    }
-    return search_kwargs
+    if options["maximal_marginal_relevance"]:
+        search_kwargs = {
+            "distance_metric": options["distance_metric"],
+            "fetch_k": fetch_k,
+            "k": k,
+        }
+        search_type = "mmr"
+    else:
+        search_kwargs = {
+            "k": k,
+            "distance_metric": options["distance_metric"],
+        }
+        search_type = "similarity"
+
+    return search_kwargs, search_type


-def get_smart_faq_search_kwargs(options: dict) -> dict:
+def get_smart_faq_search_kwargs(options: dict) -> tuple[dict, str]:
     search_kwargs = {
         "k": 20,
         "distance_metric": options["distance_metric"],
     }
-    return search_kwargs
+    search_type = "similarity"
+    return search_kwargs, search_type


 def get_multi_chain(
@@ -243,10 +245,17 @@ def get_multi_chain(
     options: dict,
     credentials: dict,
 ) -> MultiRetrieverFAQChain:
-    kb_search_kwargs = get_knowledge_base_search_kwargs(options)
-    kb_retrievers = [kb.as_retriever(search_kwargs=kb_search_kwargs) for kb in knowledge_bases]
-    faq_search_kwargs = get_smart_faq_search_kwargs(options)
-    faq_retriever = smart_faq.as_retriever(search_kwargs=faq_search_kwargs) if smart_faq else None
+    kb_search_kwargs, search_type = get_knowledge_base_search_kwargs(options)
+    kb_retrievers = [
+        kb.as_retriever(search_type=search_type, search_kwargs=kb_search_kwargs)
+        for kb in knowledge_bases
+    ]
+    faq_search_kwargs, search_type = get_smart_faq_search_kwargs(options)
+    faq_retriever = (
+        smart_faq.as_retriever(search_type=search_type, search_kwargs=faq_search_kwargs)
+        if smart_faq
+        else None
+    )
     model = get_model(options, credentials)
     memory = ConversationBufferMemory(
         memory_key="chat_history", chat_memory=chat_history, return_messages=True
@@ -262,7 +271,7 @@
         max_tokens_limit=options["max_tokens"],
         use_vanilla_llm=use_vanilla_llm,
         memory=memory,
-        verbose=True,
+        verbose=VERBOSE,
     )
     logger.info(f"Multi chain with settings {options} built!")
     return chain
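
Note on the refactor above: the search-kwargs helpers now return both the kwargs and a search type, and MMR-specific keys such as `fetch_k` are only included when MMR is enabled. A minimal sketch of the new call pattern — the `options` values and the `vector_store` instance are assumed for illustration, not taken from the commit:

```python
from datachad.backend.chain import get_knowledge_base_search_kwargs

# Assumed option values, shaped like the options dict the app builds
options = {
    "max_tokens": 4096,
    "chunk_size": 512,
    "k_fetch_k_ratio": 3,
    "maximal_marginal_relevance": True,
    "distance_metric": "cos",
}

search_kwargs, search_type = get_knowledge_base_search_kwargs(options)
# k = 4096 // 512 = 8, fetch_k = 8 * 3 = 24
# -> ({"distance_metric": "cos", "fetch_k": 24, "k": 8}, "mmr")

# `vector_store` stands in for any LangChain VectorStore instance
# (e.g. the DeepLake stores used here); the retriever now gets an
# explicit search type instead of an MMR flag buried in the kwargs:
retriever = vector_store.as_retriever(
    search_type=search_type, search_kwargs=search_kwargs
)
```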
3 changes: 2 additions & 1 deletion datachad/backend/constants.py
@@ -2,6 +2,7 @@
 MODEL_PATH = Path("models")
 DATA_PATH = Path("data")
+VECTOR_STORE_PATH = Path("stores")

 DEFAULT_USER = "admin"
 DEFAULT_SMART_FAQ = None
@@ -21,4 +22,4 @@
 LOCAL_DEEPLAKE = False
 LOCAL_EMBEDDINGS = False

-VERBOSE = True
+VERBOSE = False
10 changes: 5 additions & 5 deletions datachad/backend/deeplake.py
@@ -6,13 +6,14 @@
 from deeplake.client.client import DeepLakeBackendClient
 from deeplake.util.bugout_reporter import deeplake_reporter
 from langchain.schema import Document
-from langchain.vectorstores import DeepLake, VectorStore
+from langchain.vectorstores import VectorStore
+from langchain_community.vectorstores.deeplake import DeepLake

 from datachad.backend.constants import (
-    DATA_PATH,
     DEFAULT_USER,
     LOCAL_DEEPLAKE,
     STORE_DOCS_EXTRA,
+    VECTOR_STORE_PATH,
     VERBOSE,
 )
 from datachad.backend.io import clean_string_for_storing
@@ -65,7 +66,7 @@ def get_datasets(self, workspace: str):

 def get_deeplake_dataset_path(dataset_name: str, credentials: dict) -> str:
     if LOCAL_DEEPLAKE:
-        dataset_path = str(DATA_PATH / dataset_name)
+        dataset_path = str(VECTOR_STORE_PATH / dataset_name)
     else:
         dataset_path = f"hub://{credentials['activeloop_id']}/{dataset_name}"
     return dataset_path
@@ -81,7 +82,7 @@ def delete_all_deeplake_datasets(credentials: dict) -> None:

 def get_existing_deeplake_vector_store_paths(credentials: dict) -> list[str]:
     if LOCAL_DEEPLAKE:
-        return glob(str(DATA_PATH / "*"), recursive=False)
+        return glob(str(VECTOR_STORE_PATH / "*"), recursive=False)
     else:
         dataset_names = list_deeplake_datasets(
             credentials["activeloop_id"], credentials["activeloop_token"]
@@ -224,6 +225,5 @@ def get_or_create_deeplake_vector_store(
         token=credentials["activeloop_token"],
         verbose=VERBOSE,
     )
-
     logger.info(f"Vector Store {vector_store_path} loaded in {round(time.time() - t_start)}s!")
     return vector_store
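
With this rename, local vector stores move out of the raw-data directory: `DATA_PATH` ("data") keeps source files, while `VECTOR_STORE_PATH` ("stores") holds the DeepLake datasets. A rough sketch of how the path now resolves — the credentials dict below is a placeholder:

```python
from datachad.backend.deeplake import get_deeplake_dataset_path

credentials = {"activeloop_id": "my-org"}  # placeholder id, for illustration

path = get_deeplake_dataset_path("my-knowledge-base", credentials)
# LOCAL_DEEPLAKE = True  -> "stores/my-knowledge-base" (previously "data/...")
# LOCAL_DEEPLAKE = False -> "hub://my-org/my-knowledge-base"
print(path)
```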
8 changes: 4 additions & 4 deletions datachad/backend/loader.py
@@ -3,7 +3,10 @@
 import shutil
 from pathlib import Path

-from langchain.document_loaders import (
+from langchain.document_loaders.base import BaseLoader
+from langchain.schema import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import (
     CSVLoader,
     EverNoteLoader,
     GitLoader,
@@ -21,9 +24,6 @@
     UnstructuredWordDocumentLoader,
     WebBaseLoader,
 )
-from langchain.document_loaders.base import BaseLoader
-from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 from tqdm import tqdm

 from datachad.backend.constants import DATA_PATH
2 changes: 1 addition & 1 deletion datachad/backend/logging.py
@@ -9,7 +9,7 @@ def create_logger(level: str = "DEBUG"):
     # if no streamhandler present, add one
     if not any(isinstance(handler, logging.StreamHandler) for handler in logger.handlers):
         stream_handler = logging.StreamHandler(stream=sys.stdout)
-        formatter = logging.Formatter("%(name)s :: %(levelname)s :: %(message)s")
+        formatter = logging.Formatter("%(asctime)s :: %(name)s :: %(levelname)s :: %(message)s")
         stream_handler.setFormatter(formatter)
         logger.addHandler(stream_handler)
     return logger
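
The extended formatter simply prepends a timestamp to every record. A quick sketch of the resulting output — the logger name and timestamp shown are illustrative:

```python
from datachad.backend.logging import create_logger

logger = create_logger(level="INFO")
logger.info("vector store loaded")
# Before: datachad :: INFO :: vector store loaded
# After:  2024-02-09 12:34:56,789 :: datachad :: INFO :: vector store loaded
```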
13 changes: 9 additions & 4 deletions datachad/backend/models.py
@@ -4,9 +4,9 @@
 import streamlit as st
 import tiktoken
 from langchain.base_language import BaseLanguageModel
-from langchain.chat_models import ChatOpenAI
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.embeddings.openai import Embeddings, OpenAIEmbeddings
+from langchain_community.chat_models import ChatOpenAI
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings.openai import Embeddings, OpenAIEmbeddings
 from transformers import AutoTokenizer

 from datachad.backend.constants import LOCAL_EMBEDDINGS, MODEL_PATH
@@ -36,7 +36,7 @@ class STORES(Enum):

 class EMBEDDINGS(Enum):
     # Add more embeddings as needed
-    OPENAI = "text-embedding-ada-002"
+    OPENAI = "text-embedding-3-small"
     HUGGINGFACE = "sentence-transformers/all-MiniLM-L6-v2"

@@ -57,6 +57,11 @@ class MODELS(Enum):
         embedding=EMBEDDINGS.OPENAI,
         context=8192,
     )
+    GPT4TURBO = Model(
+        name="gpt-4-turbo-preview",
+        embedding=EMBEDDINGS.OPENAI,
+        context=128000,
+    )


 def get_model(options: dict, credentials: dict) -> BaseLanguageModel:
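
These import moves track the langchain 0.1 split, in which provider integrations were moved into `langchain_community`. A minimal sketch of the migrated imports together with the newly added model and embedding names — the constructor parameters are illustrative, and an `OPENAI_API_KEY` in the environment is assumed:

```python
from langchain_community.chat_models import ChatOpenAI
from langchain_community.embeddings.openai import OpenAIEmbeddings

# New 128k-context chat model entry and the updated embedding model
llm = ChatOpenAI(model_name="gpt-4-turbo-preview", temperature=0)
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
```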
20 changes: 19 additions & 1 deletion datachad/streamlit/constants.py
@@ -45,5 +45,23 @@
 They need to be in the format of numbers with periods followed by arbitrary text.
 The next FAQ is identified by two new lines `\\n\\n` followed by the next number.
 You can check if your documents are correctly formatted by using the following regex pattern:\n
-`r"(?=\\n\\n\d+\.)"`
+`r"(?=\\n\\n\d+\.)"`. Here is an example of a correctly formatted FAQ:\n
+1. First item
+Some description here.
+1. some numbered list
+2. belonging to the first item
+
+2. Second item
+Another description.
+a) another list
+b) but with characters
+
+3. Third item
+And another one.
+- a list with dashes
+- more items
 """
10 changes: 5 additions & 5 deletions datachad/streamlit/helper.py
@@ -5,8 +5,8 @@
 import streamlit as st
 from dotenv import load_dotenv
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.callbacks.openai_info import get_openai_token_cost_for_model
-from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
+from langchain_community.callbacks.openai_info import get_openai_token_cost_for_model
+from langchain_community.chat_message_histories import StreamlitChatMessageHistory

 from datachad.backend.constants import (
     CHUNK_OVERLAP_PCT,
@@ -132,9 +132,9 @@ def get_options() -> dict:
     }


-def update_vector_store() -> None:
+def upload_data() -> None:
     try:
-        with st.session_state["info_container"], st.spinner("Updating Vector Stores..."):
+        with st.session_state["info_container"], st.spinner("Uploading Data..."):
             options = get_options()
             create_vector_store(
                 data_source=st.session_state["data_source"],
@@ -161,7 +161,7 @@ def update_vector_store() -> None:

 def update_chain() -> None:
     try:
-        with st.session_state["info_container"], st.spinner("Updating Knowledge Base..."):
+        with st.session_state["info_container"], st.spinner("Applying data selection..."):
             st.session_state["chat_history"].clear()
             options = get_options()
             st.session_state["chain"] = create_chain(
4 changes: 2 additions & 2 deletions datachad/streamlit/widgets.py
@@ -29,7 +29,7 @@
     get_existing_knowledge_bases,
     get_existing_smart_faqs_and_default_index,
     update_chain,
-    update_vector_store,
+    upload_data,
 )


@@ -173,7 +173,7 @@ def data_upload_widget() -> None:
     if (
         st.session_state["uploaded_files"] or st.session_state["data_source"]
     ) and st.session_state["data_name"]:
-        update_vector_store()
+        upload_data()
     else:
         st.session_state["info_container"].error(
             "Missing required files and name!", icon=PAGE_ICON
10 changes: 5 additions & 5 deletions requirements.txt
@@ -1,8 +1,8 @@
-streamlit==1.28.1
-deeplake==3.8.0
-openai==1.3.5
-langchain==0.0.340
-tiktoken==0.4.0
+streamlit==1.31.0
+deeplake==3.8.19
+openai==1.12.0
+langchain==0.1.6
+tiktoken==0.6.0
 unstructured==0.6.5
 pdf2image==1.16.3
 pytesseract==0.3.10
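
The langchain jump from 0.0.340 to 0.1.6 is a breaking one (integrations moved into `langchain_community`), so a small sanity check after reinstalling can confirm the pins took effect; a sketch:

```python
from importlib.metadata import version

# Print the installed versions of the pinned core dependencies
for pkg in ("streamlit", "deeplake", "openai", "langchain", "tiktoken"):
    print(f"{pkg}=={version(pkg)}")
```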
