From 4efc25dd2e4ab83eb132af6aab785041c8b0da78 Mon Sep 17 00:00:00 2001 From: jai Date: Thu, 19 Oct 2023 11:35:39 -0400 Subject: [PATCH] Fix fallback for keyword-based search --- docker-compose.yaml | 8 +++++--- rest/clients/openai.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 27d11b9..e6ff471 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -17,13 +17,13 @@ services: - "5173:5173" environment: REST_API_URL: http://localhost:8000 - volumes: - - ./frontend:/app + # volumes: + # - ./frontend:/app rest: image: ghcr.io/catalystneuro/dandiset-search-service:latest container_name: dandi-search-rest - command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"] + command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4", "--reload"] ports: - "8000:8000" environment: @@ -34,3 +34,5 @@ services: QDRANT_API_KEY: ${QDRANT_API_KEY} OPENAI_API_KEY: ${OPENAI_API_KEY} DANDI_API_KEY: ${DANDI_API_KEY} + volumes: + - ./rest:/app diff --git a/rest/clients/openai.py b/rest/clients/openai.py index 49ca274..a92c293 100644 --- a/rest/clients/openai.py +++ b/rest/clients/openai.py @@ -92,15 +92,24 @@ def keywords_extraction(self, user_input: str, model: str = "gpt-3.5-turbo"): ) chain = create_extraction_chain(schema, llm) keywords_extracted = list(chain.run(user_input)) + + # Temporary fallback to extracting nouns (if no schema-related keywords found) if any(isinstance(item, str) for item in keywords_extracted): - keywords_extracted = [{key: ""} for key in schema["properties"]] + # import nltk + # nltk.download(['punkt', 'averaged_perceptron_tagger']) + # words = nltk.word_tokenize(user_input) + # pos_tags = nltk.pos_tag(words) + # keywords_extracted = [noun for noun, tag in pos_tags if tag.startswith("N")] + # return keywords_extracted + return list(set(user_input.split())) + return self.prepare_keywords_for_semantic_search(keywords_extracted) def prepare_keywords_for_semantic_search(self, keywords_list: list) -> list: keywords_set = set() for obj in keywords_list: - for k, v in obj.items(): + for _, v in obj.items(): if v: keywords_set.add(v.lower()) return list(keywords_set)