Skip to content

Commit

Permalink
0.4.5 - add where and where_document clause handling to postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Aug 11, 2023
1 parent b32d41d commit 5dd364c
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 55 deletions.
4 changes: 2 additions & 2 deletions agentmemory/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
CLIENT_TYPE = os.environ.get("CLIENT_TYPE", DEFAULT_CLIENT_TYPE)
STORAGE_PATH = os.environ.get("STORAGE_PATH", "./memory")
POSTGRES_CONNECTION_STRING = os.environ.get("POSTGRES_CONNECTION_STRING")

POSTGRES_MODEL_NAME = os.environ.get("POSTGRES_MODEL_NAME", "all-MiniLM-L6-v2")
client = None


Expand All @@ -28,7 +28,7 @@ def get_client(client_type=None, *args, **kwargs):
raise EnvironmentError(
"Postgres connection string not set in environment variables!"
)
client = PostgresClient(POSTGRES_CONNECTION_STRING)
client = PostgresClient(POSTGRES_CONNECTION_STRING, model_name=POSTGRES_MODEL_NAME)
else:
client = chromadb.PersistentClient(path=STORAGE_PATH, *args, **kwargs)

Expand Down
1 change: 0 additions & 1 deletion agentmemory/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def cluster(epsilon, min_samples, category, filter_metadata=None, novel=False):
cluster_id = 0
for memory in memories:
memory_id = memory["id"]
print("Memory ID: ", memory_id)
if visited[memory_id]:
continue
visited[memory_id] = True
Expand Down
6 changes: 3 additions & 3 deletions agentmemory/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import datetime
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from agentmemory.helpers import (
chroma_collection_to_list,
Expand All @@ -10,7 +13,6 @@

from agentmemory.client import get_client


def create_memory(category, text, metadata={}, embedding=None, id=None):
"""
Create a new memory in a collection.
Expand Down Expand Up @@ -344,8 +346,6 @@ def update_memory(category, id, text=None, metadata=None, embedding=None):
documents = [text] if text is not None else None
metadatas = [metadata] if metadata is not None else None
embeddings = [embedding] if embedding is not None else None
print('********************** UPDATE')
print(id, documents, metadatas, embeddings)
# Update the memory with the new text and/or metadata
memories.update(
ids=[str(id)], documents=documents, metadatas=metadatas, embeddings=embeddings
Expand Down
191 changes: 144 additions & 47 deletions agentmemory/postgres.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,44 @@
import json
from pathlib import Path
import psycopg2

from agentmemory.check_model import check_model, infer_embeddings


def handle_and_condition(and_conditions):
conditions = []
params = []
for condition in and_conditions:
for key, value in condition.items():
for operator, operand in value.items():
sql_operator = get_sql_operator(operator)
conditions.append(f"{key} {sql_operator} %s")
params.append(operand)
return conditions, params


def handle_or_condition(or_conditions):
or_groups = []
params = []
for condition in or_conditions:
conditions, new_params = handle_and_condition([condition])
or_groups.append(" AND ".join(conditions))
params.extend(new_params)
return f"({') OR ('.join(or_groups)})", params


def get_sql_operator(operator):
if operator == "$eq":
return "="
elif operator == "$ne":
return "!="
elif operator == "$gt":
return ">"
elif operator == "$lt":
return "<"
else:
raise ValueError(f"Operator {operator} not supported")


class PostgresCollection:
def __init__(self, category, client):
self.category = category
Expand Down Expand Up @@ -40,37 +74,61 @@ def get(
):
category = self.category
table_name = self.client._table_name(category)
conditions = []
params = []

if not ids:
if limit is None:
limit = 100 # or another default value
if offset is None:
offset = 0
if where_document is not None:
if where_document.get("$contains", None) is not None:
where_document = where_document["$contains"]
conditions.append("document LIKE %s")
params.append(f"%{where_document}%")

query = f"SELECT * FROM {table_name} LIMIT %s OFFSET %s"
params = (limit, offset)
if where:
for key, value in where.items():
if key == "$and":
new_conditions, new_params = handle_and_condition(value)
conditions.extend(new_conditions)
params.extend(new_params)
elif key == "$or":
or_condition, new_params = handle_or_condition(value)
conditions.append(or_condition)
params.extend(new_params)
elif key == "$contains":
conditions.append(f"document LIKE %s")
params.append(f"%{value}%")
else:
conditions.append(f"{key}=%s")
params.append(str(value))

else:
if ids:
if not all(isinstance(i, str) or isinstance(i, int) for i in ids):
raise Exception(
"ids must be a list of integers or strings representing integers"
)
ids = [int(i) for i in ids]
conditions.append("id=ANY(%s)")
params.append(ids)

if limit is None:
limit = len(ids)
if offset is None:
offset = 0
if limit is None:
limit = 100 # or another default value
if offset is None:
offset = 0

ids = [int(i) for i in ids]
query = f"SELECT * FROM {table_name} WHERE id=ANY(%s) LIMIT %s OFFSET %s"
params = (ids, limit, offset)
query = f"SELECT * FROM {table_name}"
if conditions:
query += " WHERE " + " AND ".join(conditions)
query += " LIMIT %s OFFSET %s"
params.extend([limit, offset])

self.client.cur.execute(query, tuple(params))

self.client.cur.execute(query, params)
rows = self.client.cur.fetchall()

# Convert rows to list of dictionaries
columns = [desc[0] for desc in self.client.cur.description]
metadata_columns = [col for col in columns if col not in ["id", "document", "embedding"]]
metadata_columns = [
col for col in columns if col not in ["id", "document", "embedding"]
]

result = []
for row in rows:
Expand All @@ -85,7 +143,6 @@ def get(
"metadatas": [row["metadata"] for row in result],
}


def peek(self, limit=10):
return self.get(limit=limit)

Expand All @@ -98,7 +155,9 @@ def query(
where_document=None,
include=["metadatas", "documents", "distances"],
):
return self.client.query(self.category, query_texts, n_results)
return self.client.query(
self.category, query_texts, n_results, where, where_document
)

def update(self, ids, documents=None, metadatas=None, embeddings=None):
self.client.ensure_table_exists(self.category)
Expand All @@ -107,8 +166,6 @@ def update(self, ids, documents=None, metadatas=None, embeddings=None):
if documents is None:
documents = [None] * len(ids)
for id_, document, metadata in zip(ids, documents, metadatas):
print("updating")
print(id_, document, metadata)
self.client.update(self.category, id_, document, metadata)
else:
for id_, document, metadata, emb in zip(
Expand All @@ -121,37 +178,43 @@ def upsert(self, ids, documents=None, metadatas=None, embeddings=None):

def delete(self, ids=None, where=None, where_document=None):
table_name = self.client._table_name(self.category)
# check if table exists
self.client.ensure_table_exists(self.category)

# Base of the query
query = f"DELETE FROM {table_name}"
conditions = []
params = []

conditions = []
if where_document is not None:
if where_document.get("$contains", None) is not None:
where_document = where_document["$contains"]
conditions.append("document LIKE %s")
params.append(f"%{where_document}%")

if ids is not None:
if not all(isinstance(i, (int, str)) and str(i).isdigit() for i in ids):
if ids:
if not all(isinstance(i, str) or isinstance(i, int) for i in ids):
raise Exception(
"ids must be a list of integers or strings representing integers"
)
ids = [int(i) for i in ids]
conditions.append("id=ANY(%s::int[])")
conditions.append("id=ANY(%s::int[])") # Added explicit type casting
params.append(ids)

if where_document is not None:
if "$contains" in where_document:
conditions.append("document LIKE %s")
params.append(f"%{where_document['$contains']}%")
# You can add more operators for 'where_document' here if needed

if where is not None:
if where:
for key, value in where.items():
conditions.append(f"{key}=%s")
params.append(value)
if key == "$and":
new_conditions, new_params = handle_and_condition(value)
conditions.extend(new_conditions)
params.extend(new_params)
elif key == "$or":
or_condition, new_params = handle_or_condition(value)
conditions.append(or_condition)
params.extend(new_params)
elif key == "$contains":
conditions.append(f"document LIKE %s")
params.append(f"%{value}%")
else:
conditions.append(f"{key}=%s")
params.append(str(value))

if conditions:
query += " WHERE " + " AND ".join(conditions)
query = f"DELETE FROM {table_name} WHERE " + " AND ".join(conditions)
else:
raise Exception("No valid conditions provided for deletion.")

Expand Down Expand Up @@ -286,9 +349,40 @@ def add(self, category, documents, metadatas, ids):
cur.execute(query, tuple(values))
self.connection.commit()

def query(self, category, query_texts, n_results=5):
def query(
self, category, query_texts, n_results=5, where=None, where_document=None
):
self.ensure_table_exists(category)
table_name = self._table_name(category)
conditions = []
params = []

# Check if where_document is given
if where_document:
if where_document.get("$contains", None) is not None:
where_document = where_document["$contains"]
conditions.append("document LIKE %s")
params.append(f"%{where_document}%")

if where:
for key, value in where.items():
if key == "$and":
new_conditions, new_params = handle_and_condition(value)
conditions.extend(new_conditions)
params.extend(new_params)
elif key == "$or":
or_condition, new_params = handle_or_condition(value)
conditions.append(or_condition)
params.extend(new_params)
elif key == "$contains":
conditions.append(f"document LIKE %s")
params.append(f"%{value}%")
else:
conditions.append(f"{key}=%s")
params.append(str(value))

where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""

results = {
"ids": [],
"documents": [],
Expand All @@ -299,14 +393,17 @@ def query(self, category, query_texts, n_results=5):
with self.connection.cursor() as cur:
for emb in query_texts:
query_emb = self.create_embedding(emb)
cur.execute(
f"""
params_with_emb = [query_emb] + params + [query_emb, n_results]
string = f"""
SELECT id, document, embedding, embedding <-> %s AS distance, *
FROM {table_name}
{where_clause}
ORDER BY embedding <-> %s
LIMIT %s
""",
(query_emb, query_emb, n_results),
"""
cur.execute(
string,
tuple(params_with_emb),
)
rows = cur.fetchall()
columns = [desc[0] for desc in cur.description]
Expand Down Expand Up @@ -363,4 +460,4 @@ def update(self, category, id_, document=None, metadata=None, embedding=None):

def close(self):
self.cur.close()
self.connection.close()
self.connection.close()
1 change: 0 additions & 1 deletion agentmemory/tests/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def test_create_event():
event = get_events()[0]
assert event["document"] == "test event"
assert event["metadata"]["test"] == "test"
print(event["metadata"])
assert int(event["metadata"]["epoch"]) == 1
wipe_category("events")
wipe_category("epoch")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

setup(
name='agentmemory',
version='0.4.4',
version='0.4.5',
description='Easy-to-use memory for agents, document search, knowledge graphing and more.',
long_description=long_description, # added this line
long_description_content_type="text/markdown", # and this line
Expand Down

0 comments on commit 5dd364c

Please sign in to comment.