diff --git a/userInput/processUserInput.py b/userInput/processUserInput.py index 66f3540..4f420c5 100644 --- a/userInput/processUserInput.py +++ b/userInput/processUserInput.py @@ -34,7 +34,7 @@ def embedUserInput(): embedding = response.json()['data'][0]['embedding'] return embedding - def performKNNSearch(embedding, k=5): + def performKNNSearch(embedding, k=5, merge_threshold=0.5): # Initialize the Qdrant client client = QdrantClient(host='localhost', port=6333) @@ -61,10 +61,13 @@ def performKNNSearch(embedding, k=5): payload = result.payload score = payload['score'] if 'score' in payload else 0 total_score += score - average_score = total_score / k + + # Determine if the function should be merged based on the threshold + should_merge = "🎉 Merge the function 🎉" if average_score >= merge_threshold else "🙅 Do not merge the function 🙅" - return average_score + return should_merge, average_score embed = embedUserInput() - print('Average Score:', performKNNSearch(embed)) \ No newline at end of file + print('Merge or no Merge?:', performKNNSearch(embed)[0]) + print('Average Score:', performKNNSearch(embed)[1]) \ No newline at end of file