import asyncio
from itertools import batched

from beartype import beartype
from temporalio import activity

# NOTE(review): the diff hunks this text was reconstructed from skip the
# module's remaining import lines (the ones that bring `litellm`, `cozo`,
# `embed_snippets_query` and `EmbedDocsPayload` into scope). They are not
# touched by this change and must be kept in place here.


@beartype
async def embed_docs(
    payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
) -> None:
    """Embed the snippets of a document in concurrent batches.

    The snippets in ``payload.content`` are split into batches of at most
    ``max_batch_size`` items. Each batch is sent to ``litellm.aembedding`` and
    the returned embeddings are handed to ``embed_snippets_query`` together
    with the original positions of the snippets in ``payload.content``.

    Args:
        payload: Supplies ``content``, ``title``, ``embed_instruction``,
            ``developer_id`` and ``doc_id``.
        cozo_client: Client forwarded to ``embed_snippets_query``; when falsy,
            ``cozo.get_cozo_client()`` is used instead.
        max_batch_size: Upper bound on the number of snippets per embedding
            request.

    Raises:
        ValueError: If ``max_batch_size`` is smaller than one (raised by
            ``itertools.batched``).
        Exception: The first error raised by any batch is propagated.
    """
    indexed_snippets = list(enumerate(payload.content))
    if not indexed_snippets:
        # Nothing to embed; `zip(*...)` below cannot be unpacked when empty.
        return None

    # Loop-invariant, so resolved once rather than once per batch.
    embed_instruction: str = payload.embed_instruction or ""
    title: str = payload.title or ""

    async def embed_batch(batch) -> None:
        # Each batch is a tuple of (index, snippet) pairs; batching the pairs
        # together keeps every snippet attached to its own index.
        batch_indices, batch_snippets = zip(*batch)

        embeddings = await litellm.aembedding(
            inputs=[
                # The conditional is parenthesised on purpose: without the
                # parentheses it binds looser than `+`, which silently drops
                # the instruction whenever the title is empty.
                (
                    embed_instruction
                    + ((title + "\n\n" + snippet) if title else snippet)
                ).strip()
                for snippet in batch_snippets
            ]
        )

        embed_snippets_query(
            developer_id=payload.developer_id,
            doc_id=payload.doc_id,
            snippet_indices=batch_indices,
            embeddings=embeddings,
            client=cozo_client or cozo.get_cozo_client(),
        )

    # `asyncio.wait` rejects bare coroutines on Python 3.11+ and never
    # re-raises task errors; `gather` schedules the coroutines itself and
    # propagates the first failure to the caller.
    await asyncio.gather(
        *(
            embed_batch(batch)
            for batch in batched(indexed_snippets, max_batch_size)
        )
    )
    return None


async def mock_embed_docs(
    payload: EmbedDocsPayload, cozo_client=None, max_batch_size: int = 100
) -> None:
    """Do nothing; accepts the same arguments as ``embed_docs``."""
    return None