Skip to content

Commit

Permalink
Add validation to postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Aug 11, 2023
1 parent 5dd364c commit fd6afd8
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions agentmemory/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@
from agentmemory.check_model import check_model, infer_embeddings


def parse_metadata(where):
metadata = {}
for key, value in where.items():
if key[0] != "$":
metadata[key] = value
if isinstance(value, dict):
metadata.update(parse_metadata(value))
if isinstance(value, list):
for item in value:
if isinstance(item, dict):
metadata.update(parse_metadata(item))
return metadata


def handle_and_condition(and_conditions):
conditions = []
params = []
Expand Down Expand Up @@ -76,7 +90,6 @@ def get(
table_name = self.client._table_name(category)
conditions = []
params = []

if where_document is not None:
if where_document.get("$contains", None) is not None:
where_document = where_document["$contains"]
Expand All @@ -100,6 +113,8 @@ def get(
conditions.append(f"{key}=%s")
params.append(str(value))

self.client._ensure_metadata_columns_exist(category, parse_metadata(where))

if ids:
if not all(isinstance(i, str) or isinstance(i, int) for i in ids):
raise Exception(
Expand Down Expand Up @@ -304,7 +319,7 @@ def get_or_create_collection(self, category):

def insert_memory(self, category, document, metadata={}, embedding=None, id=None):
self.ensure_table_exists(category)
self._ensure_metadata_columns_exist(category, metadata)
self._ensure_metadata_columns_exist(category, parse_metadata(metadata))
table_name = self._table_name(category)

if embedding is None:
Expand Down Expand Up @@ -336,7 +351,7 @@ def add(self, category, documents, metadatas, ids):
table_name = self._table_name(category)
with self.connection.cursor() as cur:
for document, metadata, id_ in zip(documents, metadatas, ids):
self._ensure_metadata_columns_exist(category, metadata)
self._ensure_metadata_columns_exist(category, parse_metadata(metadata))

columns = ["id", "document", "embedding"] + list(metadata.keys())
placeholders = ["%s"] * len(columns)
Expand Down Expand Up @@ -431,7 +446,9 @@ def update(self, category, id_, document=None, metadata=None, embedding=None):
if embedding is None:
embedding = self.create_embedding(document)
if metadata:
self._ensure_metadata_columns_exist(category, metadata)
self._ensure_metadata_columns_exist(
category, parse_metadata(metadata)
)
columns = ["document=%s", "embedding=%s"] + [
f"{key}=%s" for key in metadata.keys()
]
Expand All @@ -447,7 +464,7 @@ def update(self, category, id_, document=None, metadata=None, embedding=None):
"""
cur.execute(query, tuple(values) + (id_,))
elif metadata:
self._ensure_metadata_columns_exist(category, metadata)
self._ensure_metadata_columns_exist(category, parse_metadata(metadata))
columns = [f"{key}=%s" for key in metadata.keys()]
values = list(metadata.values())
query = f"""
Expand Down

0 comments on commit fd6afd8

Please sign in to comment.