From fd6afd81ec27f552f102875b48e923384c2a4c7e Mon Sep 17 00:00:00 2001 From: moon Date: Fri, 11 Aug 2023 02:29:39 -0700 Subject: [PATCH] Add validation to postgres --- agentmemory/postgres.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/agentmemory/postgres.py b/agentmemory/postgres.py index f302f68..2f45d28 100644 --- a/agentmemory/postgres.py +++ b/agentmemory/postgres.py @@ -4,6 +4,20 @@ from agentmemory.check_model import check_model, infer_embeddings +def parse_metadata(where): + metadata = {} + for key, value in where.items(): + if key[0] != "$": + metadata[key] = value + if isinstance(value, dict): + metadata.update(parse_metadata(value)) + if isinstance(value, list): + for item in value: + if isinstance(item, dict): + metadata.update(parse_metadata(item)) + return metadata + + def handle_and_condition(and_conditions): conditions = [] params = [] @@ -76,7 +90,6 @@ def get( table_name = self.client._table_name(category) conditions = [] params = [] - if where_document is not None: if where_document.get("$contains", None) is not None: where_document = where_document["$contains"] @@ -100,6 +113,8 @@ def get( conditions.append(f"{key}=%s") params.append(str(value)) + self.client._ensure_metadata_columns_exist(category, parse_metadata(where)) + if ids: if not all(isinstance(i, str) or isinstance(i, int) for i in ids): raise Exception( @@ -304,7 +319,7 @@ def get_or_create_collection(self, category): def insert_memory(self, category, document, metadata={}, embedding=None, id=None): self.ensure_table_exists(category) - self._ensure_metadata_columns_exist(category, metadata) + self._ensure_metadata_columns_exist(category, parse_metadata(metadata)) table_name = self._table_name(category) if embedding is None: @@ -336,7 +351,7 @@ def add(self, category, documents, metadatas, ids): table_name = self._table_name(category) with self.connection.cursor() as cur: for document, metadata, id_ in zip(documents, metadatas, ids): - self._ensure_metadata_columns_exist(category, metadata) + self._ensure_metadata_columns_exist(category, parse_metadata(metadata)) columns = ["id", "document", "embedding"] + list(metadata.keys()) placeholders = ["%s"] * len(columns) @@ -431,7 +446,9 @@ def update(self, category, id_, document=None, metadata=None, embedding=None): if embedding is None: embedding = self.create_embedding(document) if metadata: - self._ensure_metadata_columns_exist(category, metadata) + self._ensure_metadata_columns_exist( + category, parse_metadata(metadata) + ) columns = ["document=%s", "embedding=%s"] + [ f"{key}=%s" for key in metadata.keys() ] @@ -447,7 +464,7 @@ def update(self, category, id_, document=None, metadata=None, embedding=None): """ cur.execute(query, tuple(values) + (id_,)) elif metadata: - self._ensure_metadata_columns_exist(category, metadata) + self._ensure_metadata_columns_exist(category, parse_metadata(metadata)) columns = [f"{key}=%s" for key in metadata.keys()] values = list(metadata.values()) query = f"""