From 1e21a3f7ed79c8adcaaeae6ee7da8cc35eee93f7 Mon Sep 17 00:00:00 2001 From: William FH <13333726+hinthornw@users.noreply.github.com> Date: Wed, 13 Dec 2023 17:05:31 -0800 Subject: [PATCH] [Partner] Gemini Embeddings (#14690) Add support for Gemini embeddings in the langchain-google-genai package --- .../text_embedding/google_generative_ai.ipynb | 220 ++++++++++++++++++ libs/partners/google-genai/README.md | 13 ++ .../langchain_google_genai/__init__.py | 45 +++- .../langchain_google_genai/_common.py | 4 + .../langchain_google_genai/chat_models.py | 47 ++-- .../langchain_google_genai/embeddings.py | 99 ++++++++ .../langchain_google_genai/py.typed | 0 libs/partners/google-genai/poetry.lock | 58 ++++- libs/partners/google-genai/pyproject.toml | 23 +- .../integration_tests/test_embeddings.py | 98 ++++++++ .../tests/unit_tests/test_chat_models.py | 16 +- .../tests/unit_tests/test_embeddings.py | 37 +++ .../tests/unit_tests/test_imports.py | 1 + 13 files changed, 606 insertions(+), 55 deletions(-) create mode 100644 docs/docs/integrations/text_embedding/google_generative_ai.ipynb create mode 100644 libs/partners/google-genai/langchain_google_genai/_common.py create mode 100644 libs/partners/google-genai/langchain_google_genai/embeddings.py create mode 100644 libs/partners/google-genai/langchain_google_genai/py.typed create mode 100644 libs/partners/google-genai/tests/integration_tests/test_embeddings.py create mode 100644 libs/partners/google-genai/tests/unit_tests/test_embeddings.py diff --git a/docs/docs/integrations/text_embedding/google_generative_ai.ipynb b/docs/docs/integrations/text_embedding/google_generative_ai.ipynb new file mode 100644 index 0000000000000..4105b0955d434 --- /dev/null +++ b/docs/docs/integrations/text_embedding/google_generative_ai.ipynb @@ -0,0 +1,220 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "afab8b36-10bb-4795-bc98-75ab2d2081bb", + "metadata": {}, + "source": [ + "# Google Generative AI Embeddings\n", + "\n", + "Connect to Google's generative AI embeddings service using the `GoogleGenerativeAIEmbeddings` class, found in the [langchain-google-genai](https://pypi.org/project/langchain-google-genai/) package." + ] + }, + { + "cell_type": "markdown", + "id": "63545b38-9d56-4312-8f61-8d4f1e7a3b1b", + "metadata": {}, + "source": [ + "## Installation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2f6a3cd-379f-4dff-a449-d3a9f3196f2a", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -U langchain-google-genai" + ] + }, + { + "cell_type": "markdown", + "id": "25f3f88e-164e-400d-b371-9fa488baba19", + "metadata": {}, + "source": [ + "## Credentials" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec89153f-8999-4aab-a21b-0bfba1cc3893", + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if \"GOOGLE_API_KEY\" not in os.environ:\n", + " os.environ[\"GOOGLE_API_KEY\"] = getpass(\"Provide your Google API key here\")" + ] + }, + { + "cell_type": "markdown", + "id": "f2437b22-e364-418a-8c13-490a026cb7b5", + "metadata": {}, + "source": [ + "## Usage" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "eedc551e-a1f3-4fd8-8d65-4e0784c4441b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[0.05636945, 0.0048285457, -0.0762591, -0.023642512, 0.05329321]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_google_genai import GoogleGenerativeAIEmbeddings\n", + "\n", + "embeddings = GoogleGenerativeAIEmbeddings(model=\"models/embedding-001\")\n", + "vector = embeddings.embed_query(\"hello, world!\")\n", + "vector[:5]" + ] + }, + { + "cell_type": "markdown", + "id": "2b2bed60-e7bd-4e48-83d6-1c87001f98bd", + "metadata": {}, + "source": [ + "## Batch\n", + "\n", + "You can also embed multiple strings at once for a processing speedup:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6ec53aba-404f-4778-acd9-5d6664e79ed2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(3, 768)" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vectors = embeddings.embed_documents(\n", + " [\n", + " \"Today is Monday\",\n", + " \"Today is Tuesday\",\n", + " \"Today is April Fools day\",\n", + " ]\n", + ")\n", + "len(vectors), len(vectors[0])" + ] + }, + { + "cell_type": "markdown", + "id": "1482486f-5617-498a-8a44-1974d3212dda", + "metadata": {}, + "source": [ + "## Task type\n", + "`GoogleGenerativeAIEmbeddings` optionally support a `task_type`, which currently must be one of:\n", + "\n", + "- task_type_unspecified\n", + "- retrieval_query\n", + "- retrieval_document\n", + "- semantic_similarity\n", + "- classification\n", + "- clustering\n", + "\n", + "By default, we use `retrieval_document` in the `embed_documents` method and `retrieval_query` in the `embed_query` method. If you provide a task type, we will use that for all methods." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a223bb25-2b1b-418e-a570-2f543083132e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install --quiet matplotlib scikit-learn" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "f1f077db-8eb4-49f7-8866-471a8528dcdb", + "metadata": {}, + "outputs": [], + "source": [ + "query_embeddings = GoogleGenerativeAIEmbeddings(\n", + " model=\"models/embedding-001\", task_type=\"retrieval_query\"\n", + ")\n", + "doc_embeddings = GoogleGenerativeAIEmbeddings(\n", + " model=\"models/embedding-001\", task_type=\"retrieval_document\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "79bd4a5e-75ba-413c-befa-86167c938caf", + "metadata": {}, + "source": [ + "All of these will be embedded with the 'retrieval_query' task set\n", + "```python\n", + "query_vecs = [query_embeddings.embed_query(q) for q in [query, query_2, answer_1]]\n", + "```\n", + "All of these will be embedded with the 'retrieval_document' task set\n", + "```python\n", + "doc_vecs = [doc_embeddings.embed_query(q) for q in [query, query_2, answer_1]]\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "9e1fae5e-0f84-4812-89f5-7d4d71affbc1", + "metadata": {}, + "source": [ + "In retrieval, relative distance matters. In the image above, you can see the difference in similarity scores between the \"relevant doc\" and \"simil stronger delta between the similar query and relevant doc on the latter case." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/partners/google-genai/README.md b/libs/partners/google-genai/README.md index 578645a9d7cea..32d9393cc0b9b 100644 --- a/libs/partners/google-genai/README.md +++ b/libs/partners/google-genai/README.md @@ -56,3 +56,16 @@ The value of `image_url` can be any of the following: - A local file path - A base64 encoded image (e.g., ``) - A PIL image + + + +## Embeddings + +This package also adds support for google's embeddings models. + +``` +from langchain_google_genai import GoogleGenerativeAIEmbeddings + +embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001") +embeddings.embed_query("hello, world!") +``` \ No newline at end of file diff --git a/libs/partners/google-genai/langchain_google_genai/__init__.py b/libs/partners/google-genai/langchain_google_genai/__init__.py index 5f3d136929571..fce9280d68f6f 100644 --- a/libs/partners/google-genai/langchain_google_genai/__init__.py +++ b/libs/partners/google-genai/langchain_google_genai/__init__.py @@ -1,3 +1,46 @@ +"""**LangChain Google Generative AI Integration** + +This module integrates Google's Generative AI models, specifically the Gemini series, with the LangChain framework. It provides classes for interacting with chat models and generating embeddings, leveraging Google's advanced AI capabilities. + +**Chat Models** + +The `ChatGoogleGenerativeAI` class is the primary interface for interacting with Google's Gemini chat models. It allows users to send and receive messages using a specified Gemini model, suitable for various conversational AI applications. + +**Embeddings** + +The `GoogleGenerativeAIEmbeddings` class provides functionalities to generate embeddings using Google's models. +These embeddings can be used for a range of NLP tasks, including semantic analysis, similarity comparisons, and more. + +**Installation** + +To install the package, use pip: + +```python +pip install -U langchain-google-genai +``` +## Using Chat Models + +After setting up your environment with the required API key, you can interact with the Google Gemini models. + +```python +from langchain_google_genai import ChatGoogleGenerativeAI + +llm = ChatGoogleGenerativeAI(model="gemini-pro") +llm.invoke("Sing a ballad of LangChain.") +``` + +## Embedding Generation + +The package also supports creating embeddings with Google's models, useful for textual similarity and other NLP applications. + +```python +from langchain_google_genai import GoogleGenerativeAIEmbeddings + +embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001") +embeddings.embed_query("hello, world!") +``` +""" # noqa: E501 from langchain_google_genai.chat_models import ChatGoogleGenerativeAI +from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings -__all__ = ["ChatGoogleGenerativeAI"] +__all__ = ["ChatGoogleGenerativeAI", "GoogleGenerativeAIEmbeddings"] diff --git a/libs/partners/google-genai/langchain_google_genai/_common.py b/libs/partners/google-genai/langchain_google_genai/_common.py new file mode 100644 index 0000000000000..d7bae390b23a4 --- /dev/null +++ b/libs/partners/google-genai/langchain_google_genai/_common.py @@ -0,0 +1,4 @@ +class GoogleGenerativeAIError(Exception): + """ + Custom exception class for errors associated with the `Google GenAI` API. + """ diff --git a/libs/partners/google-genai/langchain_google_genai/chat_models.py b/libs/partners/google-genai/langchain_google_genai/chat_models.py index 9d22411a6e27a..d4724f9d892a4 100644 --- a/libs/partners/google-genai/langchain_google_genai/chat_models.py +++ b/libs/partners/google-genai/langchain_google_genai/chat_models.py @@ -5,7 +5,6 @@ import os from io import BytesIO from typing import ( - TYPE_CHECKING, Any, AsyncIterator, Callable, @@ -22,6 +21,8 @@ ) from urllib.parse import urlparse +# TODO: remove ignore once the google package is published with types +import google.generativeai as genai # type: ignore[import] import requests from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -38,7 +39,7 @@ HumanMessageChunk, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import Field, SecretStr, root_validator from langchain_core.utils import get_from_dict_or_env from tenacity import ( before_sleep_log, @@ -48,11 +49,8 @@ wait_exponential, ) -logger = logging.getLogger(__name__) +from langchain_google_genai._common import GoogleGenerativeAIError -if TYPE_CHECKING: - # TODO: remove ignore once the google package is published with types - import google.generativeai as genai # type: ignore[import] IMAGE_TYPES: Tuple = () try: import PIL @@ -63,8 +61,10 @@ PIL = None # type: ignore Image = None # type: ignore +logger = logging.getLogger(__name__) + -class ChatGoogleGenerativeAIError(Exception): +class ChatGoogleGenerativeAIError(GoogleGenerativeAIError): """ Custom exception class for errors associated with the `Google GenAI` API. @@ -106,7 +106,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]: ) -def chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any: +def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any: """ Executes a chat generation method with retry logic using tenacity. @@ -139,7 +139,7 @@ def _chat_with_retry(**kwargs: Any) -> Any: return _chat_with_retry(**kwargs) -async def achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any: +async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any: """ Executes a chat generation method with retry logic using tenacity. @@ -269,8 +269,6 @@ def _convert_to_parts( content: Sequence[Union[str, dict]], ) -> List[genai.types.PartType]: """Converts a list of LangChain messages into a google parts.""" - import google.generativeai as genai - parts = [] for part in content: if isinstance(part, str): @@ -410,8 +408,7 @@ def _response_to_result( class ChatGoogleGenerativeAI(BaseChatModel): """`Google Generative AI` Chat models API. - To use you must have the google.generativeai Python package installed and - either: + To use, you must have either: 1. The ``GOOGLE_API_KEY``` environment variable set with your API key, or 2. Pass your API key using the google_api_key kwarg to the ChatGoogle @@ -435,7 +432,7 @@ class ChatGoogleGenerativeAI(BaseChatModel): max_output_tokens: int = Field(default=None, description="Max output tokens") client: Any #: :meta private: - google_api_key: Optional[str] = None + google_api_key: Optional[SecretStr] = None temperature: Optional[float] = None """Run inference with this temperature. Must by in the closed interval [0.0, 1.0].""" @@ -487,17 +484,9 @@ def validate_environment(cls, values: Dict) -> Dict: google_api_key = get_from_dict_or_env( values, "google_api_key", "GOOGLE_API_KEY" ) - try: - import google.generativeai as genai - - genai.configure(api_key=google_api_key) - except ImportError: - raise ChatGoogleGenerativeAIError( - "Could not import google.generativeai python package. " - "Please install it with `pip install google-generativeai`" - ) - - values["client"] = genai + if isinstance(google_api_key, SecretStr): + google_api_key = google_api_key.get_secret_value() + genai.configure(api_key=google_api_key) if ( values.get("temperature") is not None and not 0 <= values["temperature"] <= 1 @@ -560,7 +549,7 @@ def _generate( **kwargs: Any, ) -> ChatResult: params = self._prepare_params(messages, stop, **kwargs) - response: genai.types.GenerateContentResponse = chat_with_retry( + response: genai.types.GenerateContentResponse = _chat_with_retry( **params, generation_method=self._generation_method, ) @@ -574,7 +563,7 @@ async def _agenerate( **kwargs: Any, ) -> ChatResult: params = self._prepare_params(messages, stop, **kwargs) - response: genai.types.GenerateContentResponse = await achat_with_retry( + response: genai.types.GenerateContentResponse = await _achat_with_retry( **params, generation_method=self._async_generation_method, ) @@ -588,7 +577,7 @@ def _stream( **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: params = self._prepare_params(messages, stop, **kwargs) - response: genai.types.GenerateContentResponse = chat_with_retry( + response: genai.types.GenerateContentResponse = _chat_with_retry( **params, generation_method=self._generation_method, stream=True, @@ -614,7 +603,7 @@ async def _astream( **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: params = self._prepare_params(messages, stop, **kwargs) - async for chunk in await achat_with_retry( + async for chunk in await _achat_with_retry( **params, generation_method=self._async_generation_method, stream=True, diff --git a/libs/partners/google-genai/langchain_google_genai/embeddings.py b/libs/partners/google-genai/langchain_google_genai/embeddings.py new file mode 100644 index 0000000000000..0b581265fef07 --- /dev/null +++ b/libs/partners/google-genai/langchain_google_genai/embeddings.py @@ -0,0 +1,99 @@ +from typing import Dict, List, Optional + +# TODO: remove ignore once the google package is published with types +import google.generativeai as genai # type: ignore[import] +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator +from langchain_core.utils import get_from_dict_or_env + +from langchain_google_genai._common import GoogleGenerativeAIError + + +class GoogleGenerativeAIEmbeddings(BaseModel, Embeddings): + """`Google Generative AI Embeddings`. + + To use, you must have either: + + 1. The ``GOOGLE_API_KEY``` environment variable set with your API key, or + 2. Pass your API key using the google_api_key kwarg to the ChatGoogle + constructor. + + Example: + .. code-block:: python + + from langchain_google_genai import GoogleGenerativeAIEmbeddings + + embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001") + embeddings.embed_query("What's our Q1 revenue?") + """ + + model: str = Field( + ..., + description="The name of the embedding model to use. " + "Example: models/embedding-001", + ) + task_type: Optional[str] = Field( + None, + description="The task type. Valid options include: " + "task_type_unspecified, retrieval_query, retrieval_document, " + "semantic_similarity, classification, and clustering", + ) + google_api_key: Optional[SecretStr] = Field( + None, + description="The Google API key to use. If not provided, " + "the GOOGLE_API_KEY environment variable will be used.", + ) + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validates that the python package exists in environment.""" + google_api_key = get_from_dict_or_env( + values, "google_api_key", "GOOGLE_API_KEY" + ) + if isinstance(google_api_key, SecretStr): + google_api_key = google_api_key.get_secret_value() + genai.configure(api_key=google_api_key) + return values + + def _embed( + self, texts: List[str], task_type: str, title: Optional[str] = None + ) -> List[List[float]]: + task_type = self.task_type or "retrieval_document" + try: + result = genai.embed_content( + model=self.model, + content=texts, + task_type=task_type, + title=title, + ) + except Exception as e: + raise GoogleGenerativeAIError(f"Error embedding content: {e}") from e + return result["embedding"] + + def embed_documents( + self, texts: List[str], batch_size: int = 5 + ) -> List[List[float]]: + """Embed a list of strings. Vertex AI currently + sets a max batch size of 5 strings. + + Args: + texts: List[str] The list of strings to embed. + batch_size: [int] The batch size of embeddings to send to the model + + Returns: + List of embeddings, one for each text. + """ + task_type = self.task_type or "retrieval_document" + return self._embed(texts, task_type=task_type) + + def embed_query(self, text: str) -> List[float]: + """Embed a text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + task_type = self.task_type or "retrieval_query" + return self._embed([text], task_type=task_type)[0] diff --git a/libs/partners/google-genai/langchain_google_genai/py.typed b/libs/partners/google-genai/langchain_google_genai/py.typed new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/libs/partners/google-genai/poetry.lock b/libs/partners/google-genai/poetry.lock index f642dd02e1dec..e017eb6980878 100644 --- a/libs/partners/google-genai/poetry.lock +++ b/libs/partners/google-genai/poetry.lock @@ -441,7 +441,6 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"}, - {file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"}, ] [[package]] @@ -546,6 +545,51 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "numpy" +version = "1.26.2" +description = "Fundamental package for array computing in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "numpy-1.26.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3703fc9258a4a122d17043e57b35e5ef1c5a5837c3db8be396c82e04c1cf9b0f"}, + {file = "numpy-1.26.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cc392fdcbd21d4be6ae1bb4475a03ce3b025cd49a9be5345d76d7585aea69440"}, + {file = "numpy-1.26.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36340109af8da8805d8851ef1d74761b3b88e81a9bd80b290bbfed61bd2b4f75"}, + {file = "numpy-1.26.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcc008217145b3d77abd3e4d5ef586e3bdfba8fe17940769f8aa09b99e856c00"}, + {file = "numpy-1.26.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:3ced40d4e9e18242f70dd02d739e44698df3dcb010d31f495ff00a31ef6014fe"}, + {file = "numpy-1.26.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b272d4cecc32c9e19911891446b72e986157e6a1809b7b56518b4f3755267523"}, + {file = "numpy-1.26.2-cp310-cp310-win32.whl", hash = "sha256:22f8fc02fdbc829e7a8c578dd8d2e15a9074b630d4da29cda483337e300e3ee9"}, + {file = "numpy-1.26.2-cp310-cp310-win_amd64.whl", hash = "sha256:26c9d33f8e8b846d5a65dd068c14e04018d05533b348d9eaeef6c1bd787f9919"}, + {file = "numpy-1.26.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b96e7b9c624ef3ae2ae0e04fa9b460f6b9f17ad8b4bec6d7756510f1f6c0c841"}, + {file = "numpy-1.26.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aa18428111fb9a591d7a9cc1b48150097ba6a7e8299fb56bdf574df650e7d1f1"}, + {file = "numpy-1.26.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06fa1ed84aa60ea6ef9f91ba57b5ed963c3729534e6e54055fc151fad0423f0a"}, + {file = "numpy-1.26.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96ca5482c3dbdd051bcd1fce8034603d6ebfc125a7bd59f55b40d8f5d246832b"}, + {file = "numpy-1.26.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:854ab91a2906ef29dc3925a064fcd365c7b4da743f84b123002f6139bcb3f8a7"}, + {file = "numpy-1.26.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f43740ab089277d403aa07567be138fc2a89d4d9892d113b76153e0e412409f8"}, + {file = "numpy-1.26.2-cp311-cp311-win32.whl", hash = "sha256:a2bbc29fcb1771cd7b7425f98b05307776a6baf43035d3b80c4b0f29e9545186"}, + {file = "numpy-1.26.2-cp311-cp311-win_amd64.whl", hash = "sha256:2b3fca8a5b00184828d12b073af4d0fc5fdd94b1632c2477526f6bd7842d700d"}, + {file = "numpy-1.26.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a4cd6ed4a339c21f1d1b0fdf13426cb3b284555c27ac2f156dfdaaa7e16bfab0"}, + {file = "numpy-1.26.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5d5244aabd6ed7f312268b9247be47343a654ebea52a60f002dc70c769048e75"}, + {file = "numpy-1.26.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a3cdb4d9c70e6b8c0814239ead47da00934666f668426fc6e94cce869e13fd7"}, + {file = "numpy-1.26.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa317b2325f7aa0a9471663e6093c210cb2ae9c0ad824732b307d2c51983d5b6"}, + {file = "numpy-1.26.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:174a8880739c16c925799c018f3f55b8130c1f7c8e75ab0a6fa9d41cab092fd6"}, + {file = "numpy-1.26.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f79b231bf5c16b1f39c7f4875e1ded36abee1591e98742b05d8a0fb55d8a3eec"}, + {file = "numpy-1.26.2-cp312-cp312-win32.whl", hash = "sha256:4a06263321dfd3598cacb252f51e521a8cb4b6df471bb12a7ee5cbab20ea9167"}, + {file = "numpy-1.26.2-cp312-cp312-win_amd64.whl", hash = "sha256:b04f5dc6b3efdaab541f7857351aac359e6ae3c126e2edb376929bd3b7f92d7e"}, + {file = "numpy-1.26.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4eb8df4bf8d3d90d091e0146f6c28492b0be84da3e409ebef54349f71ed271ef"}, + {file = "numpy-1.26.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1a13860fdcd95de7cf58bd6f8bc5a5ef81c0b0625eb2c9a783948847abbef2c2"}, + {file = "numpy-1.26.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:64308ebc366a8ed63fd0bf426b6a9468060962f1a4339ab1074c228fa6ade8e3"}, + {file = "numpy-1.26.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baf8aab04a2c0e859da118f0b38617e5ee65d75b83795055fb66c0d5e9e9b818"}, + {file = "numpy-1.26.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d73a3abcac238250091b11caef9ad12413dab01669511779bc9b29261dd50210"}, + {file = "numpy-1.26.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b361d369fc7e5e1714cf827b731ca32bff8d411212fccd29ad98ad622449cc36"}, + {file = "numpy-1.26.2-cp39-cp39-win32.whl", hash = "sha256:bd3f0091e845164a20bd5a326860c840fe2af79fa12e0469a12768a3ec578d80"}, + {file = "numpy-1.26.2-cp39-cp39-win_amd64.whl", hash = "sha256:2beef57fb031dcc0dc8fa4fe297a742027b954949cabb52a2a376c144e5e6060"}, + {file = "numpy-1.26.2-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:1cc3d5029a30fb5f06704ad6b23b35e11309491c999838c31f124fee32107c79"}, + {file = "numpy-1.26.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94cc3c222bb9fb5a12e334d0479b97bb2df446fbe622b470928f5284ffca3f8d"}, + {file = "numpy-1.26.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fe6b44fb8fcdf7eda4ef4461b97b3f63c466b27ab151bec2366db8b197387841"}, + {file = "numpy-1.26.2.tar.gz", hash = "sha256:f65738447676ab5777f11e6bbbdb8ce11b785e105f690bc45966574816b6d3ea"}, +] + [[package]] name = "packaging" version = "23.2" @@ -935,7 +979,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -943,15 +986,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -968,7 +1004,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -976,7 +1011,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -1229,4 +1263,4 @@ watchmedo = ["PyYAML (>=3.10)"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "7753b9e2cb62c5b4dac124f0ff43027232c45138dbf07fdacc3c320b82367dad" +content-hash = "ec0b5e3da951c44178eac11414611121ed2783d04b8957de8f6a189b5a6bcc2b" diff --git a/libs/partners/google-genai/pyproject.toml b/libs/partners/google-genai/pyproject.toml index fd620d81f15d1..bfd42291753b4 100644 --- a/libs/partners/google-genai/pyproject.toml +++ b/libs/partners/google-genai/pyproject.toml @@ -1,9 +1,10 @@ [tool.poetry] name = "langchain-google-genai" -version = "0.0.2" +version = "0.0.3" description = "An integration package connecting Google's genai package and LangChain" authors = [] readme = "README.md" +repository = "https://github.com/langchain-ai/langchain/blob/master/libs/partners/google-genai" [tool.poetry.dependencies] python = ">=3.9,<4.0" @@ -16,11 +17,12 @@ optional = true [tool.poetry.group.test.dependencies] pytest = "^7.3.0" freezegun = "^1.2.2" -pytest-mock = "^3.10.0" +pytest-mock = "^3.10.0" syrupy = "^4.0.2" pytest-watcher = "^0.3.4" pytest-asyncio = "^0.21.1" -langchain-core = {path = "../../core", develop = true} +langchain-core = { path = "../../core", develop = true } +numpy = "^1.26.2" [tool.poetry.group.codespell] optional = true @@ -41,7 +43,7 @@ ruff = "^0.1.5" [tool.poetry.group.typing.dependencies] mypy = "^0.991" -langchain-core = {path = "../../core", develop = true} +langchain-core = { path = "../../core", develop = true } types-requests = "^2.28.11.5" types-google-cloud-ndb = "^2.2.0.1" types-pillow = "^10.1.0.2" @@ -50,7 +52,7 @@ types-pillow = "^10.1.0.2" optional = true [tool.poetry.group.dev.dependencies] -langchain-core = {path = "../../core", develop = true} +langchain-core = { path = "../../core", develop = true } pillow = "^10.1.0" types-requests = "^2.31.0.10" types-pillow = "^10.1.0.2" @@ -58,19 +60,16 @@ types-google-cloud-ndb = "^2.2.0.1" [tool.ruff] select = [ - "E", # pycodestyle - "F", # pyflakes - "I", # isort + "E", # pycodestyle + "F", # pyflakes + "I", # isort ] [tool.mypy] disallow_untyped_defs = "True" -exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"] [tool.coverage.run] -omit = [ - "tests/*", -] +omit = ["tests/*"] [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/libs/partners/google-genai/tests/integration_tests/test_embeddings.py b/libs/partners/google-genai/tests/integration_tests/test_embeddings.py new file mode 100644 index 0000000000000..13fbe9a632c85 --- /dev/null +++ b/libs/partners/google-genai/tests/integration_tests/test_embeddings.py @@ -0,0 +1,98 @@ +import numpy as np +import pytest + +from langchain_google_genai._common import GoogleGenerativeAIError +from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings + +_MODEL = "models/embedding-001" + + +@pytest.mark.parametrize( + "query", + [ + "Hi", + "This is a longer query string to test the embedding functionality of the" + " model against the pickle rick?", + ], +) +def test_embed_query_different_lengths(query: str) -> None: + """Test embedding queries of different lengths.""" + model = GoogleGenerativeAIEmbeddings(model=_MODEL) + result = model.embed_query(query) + assert len(result) == 768 + + +@pytest.mark.parametrize( + "query", + [ + "Hi", + "This is a longer query string to test the embedding functionality of the" + " model against the pickle rick?", + ], +) +async def test_aembed_query_different_lengths(query: str) -> None: + """Test embedding queries of different lengths.""" + model = GoogleGenerativeAIEmbeddings(model=_MODEL) + result = await model.aembed_query(query) + assert len(result) == 768 + + +def test_embed_documents() -> None: + """Test embedding a query.""" + model = GoogleGenerativeAIEmbeddings( + model=_MODEL, + ) + result = model.embed_documents(["Hello world", "Good day, world"]) + assert len(result) == 2 + assert len(result[0]) == 768 + assert len(result[1]) == 768 + + +async def test_aembed_documents() -> None: + """Test embedding a query.""" + model = GoogleGenerativeAIEmbeddings( + model=_MODEL, + ) + result = await model.aembed_documents(["Hello world", "Good day, world"]) + assert len(result) == 2 + assert len(result[0]) == 768 + assert len(result[1]) == 768 + + +def test_invalid_model_error_handling() -> None: + """Test error handling with an invalid model name.""" + with pytest.raises(GoogleGenerativeAIError): + GoogleGenerativeAIEmbeddings(model="invalid_model").embed_query("Hello world") + + +def test_invalid_api_key_error_handling() -> None: + """Test error handling with an invalid API key.""" + with pytest.raises(GoogleGenerativeAIError): + GoogleGenerativeAIEmbeddings( + model=_MODEL, google_api_key="invalid_key" + ).embed_query("Hello world") + + +def test_embed_documents_consistency() -> None: + """Test embedding consistency for the same document.""" + model = GoogleGenerativeAIEmbeddings(model=_MODEL) + doc = "Consistent document for testing" + result1 = model.embed_documents([doc]) + result2 = model.embed_documents([doc]) + assert result1 == result2 + + +def test_embed_documents_quality() -> None: + """Smoke test embedding quality by comparing similar and dissimilar documents.""" + model = GoogleGenerativeAIEmbeddings(model=_MODEL) + similar_docs = ["Document A", "Similar Document A"] + dissimilar_docs = ["Document A", "Completely Different Zebra"] + similar_embeddings = model.embed_documents(similar_docs) + dissimilar_embeddings = model.embed_documents(dissimilar_docs) + similar_distance = np.linalg.norm( + np.array(similar_embeddings[0]) - np.array(similar_embeddings[1]) + ) + dissimilar_distance = np.linalg.norm( + np.array(dissimilar_embeddings[0]) - np.array(dissimilar_embeddings[1]) + ) + assert similar_distance < dissimilar_distance diff --git a/libs/partners/google-genai/tests/unit_tests/test_chat_models.py b/libs/partners/google-genai/tests/unit_tests/test_chat_models.py index afa26de4d1941..651d64508f963 100644 --- a/libs/partners/google-genai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/google-genai/tests/unit_tests/test_chat_models.py @@ -1,5 +1,6 @@ """Test chat model integration.""" - +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture from langchain_google_genai.chat_models import ChatGoogleGenerativeAI @@ -22,3 +23,16 @@ def test_integration_initialization() -> None: temperature=0.7, candidate_count=2, ) + + +def test_api_key_is_string() -> None: + chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key") + assert isinstance(chat.google_api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None: + chat = ChatGoogleGenerativeAI(model="gemini-nano", google_api_key="secret-api-key") + print(chat.google_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" diff --git a/libs/partners/google-genai/tests/unit_tests/test_embeddings.py b/libs/partners/google-genai/tests/unit_tests/test_embeddings.py new file mode 100644 index 0000000000000..45acffb33b42e --- /dev/null +++ b/libs/partners/google-genai/tests/unit_tests/test_embeddings.py @@ -0,0 +1,37 @@ +"""Test embeddings model integration.""" +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture + +from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings + + +def test_integration_initialization() -> None: + """Test chat model initialization.""" + GoogleGenerativeAIEmbeddings( + model="models/embedding-001", + google_api_key="...", + ) + GoogleGenerativeAIEmbeddings( + model="models/embedding-001", + google_api_key="...", + task_type="retrieval_document", + ) + + +def test_api_key_is_string() -> None: + embeddings = GoogleGenerativeAIEmbeddings( + model="models/embedding-001", + google_api_key="secret-api-key", + ) + assert isinstance(embeddings.google_api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None: + embeddings = GoogleGenerativeAIEmbeddings( + model="models/embedding-001", + google_api_key="secret-api-key", + ) + print(embeddings.google_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" diff --git a/libs/partners/google-genai/tests/unit_tests/test_imports.py b/libs/partners/google-genai/tests/unit_tests/test_imports.py index 9cbd73df816bd..8a2bc789f567b 100644 --- a/libs/partners/google-genai/tests/unit_tests/test_imports.py +++ b/libs/partners/google-genai/tests/unit_tests/test_imports.py @@ -2,6 +2,7 @@ EXPECTED_ALL = [ "ChatGoogleGenerativeAI", + "GoogleGenerativeAIEmbeddings", ]