From 30b26be14d8fb4cd503459450aca4a8b6a910955 Mon Sep 17 00:00:00 2001
From: Dmitry Paramonov
Date: Tue, 15 Oct 2024 11:45:57 +0300
Subject: [PATCH] fix: Split payload content by smaller batches for embedding

---
Review revision: return early when payload.content is empty (the
enumerate/zip unpacking raises ValueError on empty input), rename the
inner embed_batch parameter so it no longer shadows the outer
"snippets", annotate max_batch_size on mock_embed_docs, and document
embed_docs. The post-image blob id on the "index" line below predates
these changes.

 .../agents_api/activities/embed_docs.py       | 46 +++++++++++++++----
 1 file changed, 37 insertions(+), 9 deletions(-)

diff --git a/agents-api/agents_api/activities/embed_docs.py b/agents-api/agents_api/activities/embed_docs.py
index 924424881..bbdc66043 100644
--- a/agents-api/agents_api/activities/embed_docs.py
+++ b/agents-api/agents_api/activities/embed_docs.py
@@ -1,3 +1,8 @@
+import asyncio
+import operator
+from functools import reduce
+from itertools import batched
+
 from beartype import beartype
 from temporalio import activity
 
@@ -8,18 +13,39 @@
 
 
 @beartype
-async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
+async def embed_docs(
+    payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
+) -> None:
+    """Embed ``payload.content`` in batches of at most ``max_batch_size``.
+
+    Batches are embedded concurrently and the per-batch results are joined
+    in batch order with ``operator.add``.
+    """
+    # Nothing to embed: return before the unpacking below, which raises
+    # ValueError when ``payload.content`` is empty.
+    if not payload.content:
+        return None
+
     indices, snippets = list(zip(*enumerate(payload.content)))
+    batched_snippets = batched(snippets, max_batch_size)
     embed_instruction: str = payload.embed_instruction or ""
     title: str = payload.title or ""
 
-    embeddings = await litellm.aembedding(
-        inputs=[
-            (
-                embed_instruction + (title + "\n\n" + snippet) if title else snippet
-            ).strip()
-            for snippet in snippets
-        ]
+    async def embed_batch(batch):
+        return await litellm.aembedding(
+            inputs=[
+                # NOTE(review): the conditional binds last, so embed_instruction
+                # is dropped when title is empty -- confirm this is intended.
+                (
+                    embed_instruction + (title + "\n\n" + snippet) if title else snippet
+                ).strip()
+                for snippet in batch
+            ]
+        )
+
+    embeddings = reduce(
+        operator.add,
+        await asyncio.gather(*[embed_batch(batch) for batch in batched_snippets]),
     )
 
     embed_snippets_query(
@@ -31,6 +57,8 @@ async def embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
     )
 
 
-async def mock_embed_docs(payload: EmbedDocsPayload, cozo_client=None) -> None:
+async def mock_embed_docs(
+    payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
+) -> None:
     # Does nothing
     return None