Skip to content

Commit

Permalink
core[patch]: add standard tracing params for retrievers (#25240)
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme authored Aug 12, 2024
1 parent 9927a48 commit e77eeee
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 3 deletions.
10 changes: 10 additions & 0 deletions libs/community/tests/unit_tests/retrievers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ async def test_fake_retriever_v1_upgrade_async(
assert callbacks.retriever_errors == 0


def test_fake_retriever_v1_standard_params(fake_retriever_v1: BaseRetriever) -> None:
ls_params = fake_retriever_v1._get_ls_params()
assert ls_params == {"ls_retriever_name": "fakeretrieverv1"}


@pytest.fixture
def fake_retriever_v1_with_kwargs() -> BaseRetriever:
# Test for things like the Weaviate V1 Retriever.
Expand Down Expand Up @@ -213,3 +218,8 @@ async def test_fake_retriever_v2_async(
await fake_erroring_retriever_v2.ainvoke(
"Foo", config={"callbacks": [callbacks]}
)


def test_fake_retriever_v2_standard_params(fake_retriever_v2: BaseRetriever) -> None:
ls_params = fake_retriever_v2._get_ls_params()
assert ls_params == {"ls_retriever_name": "fakeretrieverv2"}
5 changes: 5 additions & 0 deletions libs/community/tests/unit_tests/retrievers/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def test_create_client(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:
amazon_retriever.create_client({})


def test_standard_params(amazon_retriever: AmazonKnowledgeBasesRetriever) -> None:
ls_params = amazon_retriever._get_ls_params()
assert ls_params == {"ls_retriever_name": "amazonknowledgebases"}


def test_get_relevant_documents(
amazon_retriever: AmazonKnowledgeBasesRetriever, mock_client: MagicMock
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,28 @@ def test_similarity_score_threshold(index_details: dict, threshold: float) -> No
assert len(search_result) == 0


@pytest.mark.requires("databricks", "databricks.vector_search")
def test_standard_params() -> None:
index = mock_index(DIRECT_ACCESS_INDEX)
vectorstore = default_databricks_vector_search(index)
retriever = vectorstore.as_retriever()
ls_params = retriever._get_ls_params()
assert ls_params == {
"ls_retriever_name": "vectorstore",
"ls_vector_store_provider": "DatabricksVectorSearch",
"ls_embedding_provider": "FakeEmbeddingsWithDimension",
}

index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS)
vectorstore = default_databricks_vector_search(index)
retriever = vectorstore.as_retriever()
ls_params = retriever._get_ls_params()
assert ls_params == {
"ls_retriever_name": "vectorstore",
"ls_vector_store_provider": "DatabricksVectorSearch",
}


@pytest.mark.requires("databricks", "databricks.vector_search")
@pytest.mark.parametrize(
"index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX]
Expand Down
9 changes: 9 additions & 0 deletions libs/community/tests/unit_tests/vectorstores/test_faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ def test_faiss() -> None:
output = docsearch.similarity_search("foo", k=1)
assert output == [Document(page_content="foo")]

# Retriever standard params
retriever = docsearch.as_retriever()
ls_params = retriever._get_ls_params()
assert ls_params == {
"ls_retriever_name": "vectorstore",
"ls_vector_store_provider": "FAISS",
"ls_embedding_provider": "FakeEmbeddings",
}


@pytest.mark.requires("faiss")
async def test_faiss_afrom_texts() -> None:
Expand Down
40 changes: 38 additions & 2 deletions libs/core/langchain_core/retrievers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from typing_extensions import TypedDict

from langchain_core._api import deprecated
from langchain_core.documents import Document
from langchain_core.load.dump import dumpd
Expand All @@ -50,6 +52,19 @@
RetrieverOutputLike = Runnable[Any, RetrieverOutput]


class LangSmithRetrieverParams(TypedDict, total=False):
"""LangSmith parameters for tracing."""

ls_retriever_name: str
"""Retriever name."""
ls_vector_store_provider: Optional[str]
"""Vector store provider."""
ls_embedding_provider: Optional[str]
"""Embedding provider."""
ls_embedding_model: Optional[str]
"""Embedding model."""


class BaseRetriever(RunnableSerializable[RetrieverInput, RetrieverOutput], ABC):
"""Abstract base class for a Document retrieval system.
Expand Down Expand Up @@ -167,6 +182,19 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
)

def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing."""

default_retriever_name = self.get_name()
if default_retriever_name.startswith("Retriever"):
default_retriever_name = default_retriever_name[9:]
elif default_retriever_name.endswith("Retriever"):
default_retriever_name = default_retriever_name[:-9]
default_retriever_name = default_retriever_name.lower()

ls_params = LangSmithRetrieverParams(ls_retriever_name=default_retriever_name)
return ls_params

def invoke(
self, input: str, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Document]:
Expand All @@ -191,13 +219,17 @@ def invoke(
from langchain_core.callbacks.manager import CallbackManager

config = ensure_config(config)
inheritable_metadata = {
**(config.get("metadata") or {}),
**self._get_ls_params(**kwargs),
}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=config.get("tags"),
local_tags=self.tags,
inheritable_metadata=config.get("metadata"),
inheritable_metadata=inheritable_metadata,
local_metadata=self.metadata,
)
run_manager = callback_manager.on_retriever_start(
Expand Down Expand Up @@ -250,13 +282,17 @@ async def ainvoke(
from langchain_core.callbacks.manager import AsyncCallbackManager

config = ensure_config(config)
inheritable_metadata = {
**(config.get("metadata") or {}),
**self._get_ls_params(**kwargs),
}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
None,
verbose=kwargs.get("verbose", False),
inheritable_tags=config.get("tags"),
local_tags=self.tags,
inheritable_metadata=config.get("metadata"),
inheritable_metadata=inheritable_metadata,
local_metadata=self.metadata,
)
run_manager = await callback_manager.on_retriever_start(
Expand Down
21 changes: 20 additions & 1 deletion libs/core/langchain_core/vectorstores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.retrievers import BaseRetriever
from langchain_core.retrievers import BaseRetriever, LangSmithRetrieverParams
from langchain_core.runnables.config import run_in_executor

if TYPE_CHECKING:
Expand Down Expand Up @@ -1014,6 +1014,25 @@ def validate_search_type(cls, values: Dict) -> Dict:
)
return values

def _get_ls_params(self, **kwargs: Any) -> LangSmithRetrieverParams:
"""Get standard params for tracing."""

ls_params = super()._get_ls_params(**kwargs)
ls_params["ls_vector_store_provider"] = self.vectorstore.__class__.__name__

if self.vectorstore.embeddings:
ls_params["ls_embedding_provider"] = (
self.vectorstore.embeddings.__class__.__name__
)
elif hasattr(self.vectorstore, "embedding") and isinstance(
self.vectorstore.embedding, Embeddings
):
ls_params["ls_embedding_provider"] = (
self.vectorstore.embedding.__class__.__name__
)

return ls_params

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
Expand Down

0 comments on commit e77eeee

Please sign in to comment.