Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor matching_engine (migration from lc monorepo) #17

Merged
merged 18 commits into from
Feb 26, 2024
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Union

from google.cloud import storage

if TYPE_CHECKING:
from google.cloud import datastore


class DocumentStorage(ABC):
    """Abstract interface of a key, text storage for retrieving documents.

    Concrete subclasses must implement `get_by_id` and `store_by_id`. The batch
    variants have loop-based default implementations that subclasses may
    override when the backend offers a faster bulk operation.
    """

    @abstractmethod
    def get_by_id(self, document_id: str) -> Union[str, None]:
        """Gets the text of a document by its id. If not found, returns None.

        Args:
            document_id: Id of the document to get from the storage.

        Returns:
            Text of the document if found, otherwise None.
        """
        raise NotImplementedError()

    @abstractmethod
    def store_by_id(self, document_id: str, text: str) -> None:
        """Stores a document text associated to a document_id.

        Args:
            document_id: Id of the document to be stored.
            text: Text of the document to be stored.
        """
        raise NotImplementedError()

    def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None:
        """Stores a list of ids and documents in batch.

        The default implementation only loops over the individual `store_by_id`.
        Subclasses that have faster ways to store data via batch uploading
        should override this method.

        Args:
            ids: List of ids for the text.
            texts: List of texts, paired positionally with `ids`.

        Raises:
            ValueError: If `ids` and `texts` do not have the same length.
        """
        # Materialize first so that any iterable keeps working with the length
        # check below.
        ids = list(ids)
        texts = list(texts)

        # `zip` stops at the shorter input, which would silently drop the
        # trailing documents. Fail before anything is written instead.
        if len(ids) != len(texts):
            raise ValueError(
                "ids and texts must have the same length, "
                f"got {len(ids)} ids and {len(texts)} texts."
            )

        for id_, text in zip(ids, texts):
            self.store_by_id(id_, text)

    def batch_get_by_id(self, ids: List[str]) -> List[Union[str, None]]:
        """Gets a batch of documents by id.

        The default implementation only loops over `get_by_id`. Subclasses that
        have faster ways to retrieve data by batch should override this method.

        Args:
            ids: List of ids for the text.

        Returns:
            List of texts in the same order as `ids`. Positions whose id is not
            found hold None.
        """
        return [self.get_by_id(id_) for id_ in ids]


class GCSDocumentStorage(DocumentStorage):
    """Stores documents in Google Cloud Storage.

    For each pair id, document_text the name of the blob will be {prefix}/{id}
    (or just {id} when no prefix is configured), stored in plain text format.
    """

    def __init__(
        self, bucket: "storage.Bucket", prefix: Optional[str] = "documents"
    ) -> None:
        """Constructor.

        Args:
            bucket: Bucket where the documents will be stored.
            prefix: Prefix that is prepended to all document names. When None,
                documents are stored under their bare id.
        """
        super().__init__()
        self._bucket = bucket
        self._prefix = prefix

    def get_by_id(self, document_id: str) -> Union[str, None]:
        """Gets the text of a document by its id. If not found, returns None.

        Args:
            document_id: Id of the document to get from the storage.

        Returns:
            Text of the document if found, otherwise None.
        """
        blob_name = self._get_blob_name(document_id)
        existing_blob = self._bucket.get_blob(blob_name)

        # `get_blob` yields None when no blob exists under that name.
        if existing_blob is None:
            return None

        return existing_blob.download_as_text()

    def store_by_id(self, document_id: str, text: str) -> None:
        """Stores a document text associated to a document_id.

        Args:
            document_id: Id of the document to be stored.
            text: Text of the document to be stored.
        """
        blob_name = self._get_blob_name(document_id)
        new_blob = self._bucket.blob(blob_name)
        new_blob.upload_from_string(text)

    def _get_blob_name(self, document_id: str) -> str:
        """Builds a blob name using the prefix and the document_id.

        Args:
            document_id: Id of the document.

        Returns:
            Name of the blob that the document will be/is stored in.
        """
        # The prefix is Optional: without this guard a None prefix would be
        # formatted into the name as the literal string "None".
        if self._prefix is None:
            return document_id
        return f"{self._prefix}/{document_id}"


class DataStoreDocumentStorage(DocumentStorage):
    """Stores documents in Google Cloud DataStore.

    Each document is one entity of kind `kind`, keyed by the document id, whose
    text lives in the property named `text_property_name`.
    """

    def __init__(
        self,
        datastore_client: "datastore.Client",
        kind: str = "document_id",
        text_property_name: str = "text",
    ) -> None:
        """Constructor.

        Args:
            datastore_client: Datastore client used for all reads and writes.
            kind: Datastore kind under which the document entities are stored.
            text_property_name: Name of the entity property holding the text.
        """
        super().__init__()
        self._client = datastore_client
        self._text_property_name = text_property_name
        self._kind = kind

    def get_by_id(self, document_id: str) -> Union[str, None]:
        """Gets the text of a document by its id. If not found, returns None.

        Args:
            document_id: Id of the document to get from the storage.

        Returns:
            Text of the document if found, otherwise None.
        """
        key = self._client.key(self._kind, document_id)
        entity = self._client.get(key)

        # The client returns None for a missing key; subscripting it would
        # raise TypeError instead of honoring the documented None result.
        if entity is None:
            return None

        return entity[self._text_property_name]

    def store_by_id(self, document_id: str, text: str) -> None:
        """Stores a document text associated to a document_id.

        Args:
            document_id: Id of the document to be stored.
            text: Text of the document to be stored.
        """
        with self._client.transaction():
            key = self._client.key(self._kind, document_id)
            entity = self._client.entity(key=key)
            entity[self._text_property_name] = text
            self._client.put(entity)

    def batch_get_by_id(self, ids: List[str]) -> List[Union[str, None]]:
        """Gets a batch of documents by id.

        Args:
            ids: List of ids for the text.

        Returns:
            List of texts in the same order as `ids`. Positions whose id is not
            found hold None.
        """
        keys = [self._client.key(self._kind, id_) for id_ in ids]
        entities = self._client.get_multi(keys)

        # `get_multi` leaves out keys that were not found, so its result cannot
        # be paired positionally with `ids`. Index what came back by its key
        # identifier and rebuild the output in request order.
        text_by_id = {
            entity.key.id_or_name: entity[self._text_property_name]
            for entity in entities
        }

        return [text_by_id.get(id_) for id_ in ids]

    def batch_store_by_id(self, ids: List[str], texts: List[str]) -> None:
        """Stores a list of ids and documents in batch.

        Args:
            ids: List of ids for the text.
            texts: List of texts, paired positionally with `ids`.

        Raises:
            ValueError: If `ids` and `texts` do not have the same length.
        """
        ids = list(ids)
        texts = list(texts)

        # `zip` would silently drop the unpaired tail; reject the call before
        # opening a transaction.
        if len(ids) != len(texts):
            raise ValueError(
                "ids and texts must have the same length, "
                f"got {len(ids)} ids and {len(texts)} texts."
            )

        with self._client.transaction():
            keys = [self._client.key(self._kind, id_) for id_ in ids]

            entities = []
            for key, text in zip(keys, texts):
                entity = self._client.entity(key=key)
                entity[self._text_property_name] = text
                entities.append(entity)

            self._client.put_multi(entities)
116 changes: 116 additions & 0 deletions libs/vertexai/langchain_google_vertexai/vectorstores/_sdk_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from typing import TYPE_CHECKING, Any, Union

from google.cloud import aiplatform, storage
from google.cloud.aiplatform.matching_engine import (
MatchingEngineIndex,
MatchingEngineIndexEndpoint,
)
from google.oauth2.service_account import Credentials

if TYPE_CHECKING:
from google.cloud import datastore


class VectorSearchSDKManager:
    """Factory for the Google Cloud SDK objects needed to build VectorStores.

    Given a project id, a region and (optionally) credentials, it hands out
    storage, datastore and Vector Search (Matching Engine) objects, so callers
    do not deal with the authentication layer themselves.
    """

    def __init__(
        self,
        *,
        project_id: str,
        region: str,
        credentials: Union[Credentials, None] = None,
        credentials_path: Union[str, None] = None,
    ) -> None:
        """Constructor.

        Credentials are resolved in this order: an explicit `credentials`
        object, then a service account file at `credentials_path`, and finally
        None, which leaves the choice to the SDK defaults.

        Args:
            project_id: Id of the project.
            region: Region of the project. E.g. 'us-central1'
            credentials: Google cloud Credentials object.
            credentials_path: Google Cloud Credentials json file path.
        """
        resolved_credentials = credentials
        if resolved_credentials is None and credentials_path is not None:
            resolved_credentials = Credentials.from_service_account_file(
                credentials_path
            )

        self._project_id = project_id
        self._region = region
        self._credentials = resolved_credentials

        self.initialize_aiplatform()

    def initialize_aiplatform(self) -> None:
        """Initializes aiplatform with this manager's project, region and credentials."""
        aiplatform.init(
            project=self._project_id,
            location=self._region,
            credentials=self._credentials,
        )

    def get_gcs_client(self) -> storage.Client:
        """Builds a Google Cloud Storage client.

        Returns:
            Google Cloud Storage client for the configured project.
        """
        gcs_client = storage.Client(
            project=self._project_id, credentials=self._credentials
        )
        return gcs_client

    def get_gcs_bucket(self, bucket_name: str) -> storage.Bucket:
        """Retrieves a Google Cloud Bucket by bucket name.

        Args:
            bucket_name: Name of the bucket to be retrieved.

        Returns:
            Google Cloud Bucket.
        """
        return self.get_gcs_client().get_bucket(bucket_name)

    def get_index(self, index_id: str) -> MatchingEngineIndex:
        """Retrieves a MatchingEngineIndex (VectorSearchIndex) by id.

        Args:
            index_id: Id of the index to be retrieved.

        Returns:
            MatchingEngineIndex instance.
        """
        index = MatchingEngineIndex(
            index_name=index_id,
            project=self._project_id,
            location=self._region,
            credentials=self._credentials,
        )
        return index

    def get_endpoint(self, endpoint_id: str) -> MatchingEngineIndexEndpoint:
        """Retrieves a MatchingEngineIndexEndpoint (VectorSearchIndexEndpoint) by id.

        Args:
            endpoint_id: Id of the endpoint to be retrieved.

        Returns:
            MatchingEngineIndexEndpoint instance.
        """
        endpoint = MatchingEngineIndexEndpoint(
            index_endpoint_name=endpoint_id,
            project=self._project_id,
            location=self._region,
            credentials=self._credentials,
        )
        return endpoint

    def get_datastore_client(self, **kwargs: Any) -> "datastore.Client":
        """Gets a datastore Client.

        Args:
            **kwargs: Keyword arguments to pass to datastore.Client constructor.

        Returns:
            datastore Client.
        """
        # Imported here rather than at module level, so the datastore package
        # is only needed when this method is actually called.
        from google.cloud import datastore

        return datastore.Client(
            project=self._project_id, credentials=self._credentials, **kwargs
        )
Loading
Loading