Skip to content

Commit

Permalink
Prepare for v0.8.0
Browse files Browse the repository at this point in the history
  • Loading branch information
j6k4m8 committed May 14, 2024
1 parent 78173a3 commit 4bedf30
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 47 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# CHANGELOG

### **0.8.0** (May 14 2024)

> Support for MultiDiGraphs.
#### Features

- Support for MultiDiGraphs (#42, thanks @jackboyla!)

### **0.7.0** (May 4 2024)

> Support for `ORDER BY` and `DISTINCT`
Expand Down
133 changes: 87 additions & 46 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
start="start",
)

__version__ = "0.7.0"
__version__ = "0.8.0"


_ALPHABET = string.ascii_lowercase + string.digits
Expand Down Expand Up @@ -200,12 +200,12 @@ def _is_edge_attr_match(
motif_edge_id: Tuple[str, str, Union[int, str]],
host_edge_id: Tuple[str, str, Union[int, str]],
motif: Union[nx.Graph, nx.MultiDiGraph],
host: Union[nx.Graph, nx.MultiDiGraph]
host: Union[nx.Graph, nx.MultiDiGraph],
) -> bool:
"""
Check if an edge in the host graph matches the attributes in the motif,
including the special '__labels__' set attribute.
This function formats edges into
Check if an edge in the host graph matches the attributes in the motif,
including the special '__labels__' set attribute.
This function formats edges into
nx.MultiDiGraph format i.e {0: first_relation, 1: ...}.
Arguments:
Expand Down Expand Up @@ -235,7 +235,7 @@ def _is_edge_attr_match(
continue
if host_edges.get(attr) != val:
return False

return True


Expand All @@ -247,6 +247,7 @@ def _get_edge_attributes(graph: Union[nx.Graph, nx.MultiDiGraph], u, v) -> Dict:
return graph[u][v]
return {0: graph[u][v]} # Mock single edge for DiGraph


def _aggregate_edge_labels(edges: Dict) -> Dict:
"""
Aggregate '__labels__' attributes from edges into a single set.
Expand All @@ -259,7 +260,10 @@ def _aggregate_edge_labels(edges: Dict) -> Dict:
aggregated[edge_id] = attrs
return aggregated

def _get_entity_from_host(host: Union[nx.DiGraph, nx.MultiDiGraph], entity_name, entity_attribute=None):

def _get_entity_from_host(
host: Union[nx.DiGraph, nx.MultiDiGraph], entity_name, entity_attribute=None
):
if entity_name in host.nodes():
# We are looking for a node mapping in the target graph:
if entity_attribute:
Expand Down Expand Up @@ -310,7 +314,9 @@ def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool:


def cond_(should_be, entity_id, operator, value) -> CONDITION:
def inner(match: dict, host: Union[nx.DiGraph, nx.MultiDiGraph], return_endges: list) -> bool:
def inner(
match: dict, host: Union[nx.DiGraph, nx.MultiDiGraph], return_endges: list
) -> bool:
host_entity_id = entity_id.split(".")
if host_entity_id[0] in match:
host_entity_id[0] = match[host_entity_id[0]]
Expand Down Expand Up @@ -380,7 +386,11 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:

for data_path in data_paths:
entity_name, _ = _data_path_to_entity_name_attribute(data_path)
if entity_name not in motif_nodes and entity_name not in self._return_edges and entity_name not in self._paths:
if (
entity_name not in motif_nodes
and entity_name not in self._return_edges
and entity_name not in self._paths
):
raise NotImplementedError(f"Unknown entity name: {data_path}")

result = {}
Expand Down Expand Up @@ -412,15 +422,17 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
for x, node in enumerate(nodes):
# Edge
if x > 0:
path.append(self._target_graph.get_edge_data(nodes[x - 1], node))
path.append(
self._target_graph.get_edge_data(nodes[x - 1], node)
)

# Node
path.append(node)

ret.append(path)

else:
mapping_u, mapping_v = self._return_edges[data_path.split('.')[0]]
mapping_u, mapping_v = self._return_edges[data_path.split(".")[0]]
# We are looking for an edge mapping in the target graph:
is_hop = self._motif.edges[(mapping_u, mapping_v, 0)]["__is_hop__"]
ret = (
Expand All @@ -435,18 +447,28 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
# Get all edge labels from the motif -- this is used to filter the relations for multigraphs
motif_edge_labels = set()
for edge in self._motif.get_edge_data(mapping_u, mapping_v).values():
if edge.get('__labels__', None):
motif_edge_labels.update(edge['__labels__'])
if edge.get("__labels__", None):
motif_edge_labels.update(edge["__labels__"])

if entity_attribute:
# Get the correct entity from the target host graph,
# and then return the attribute:
if isinstance(self._motif, nx.MultiDiGraph) and len(motif_edge_labels) > 0:
if (
isinstance(self._motif, nx.MultiDiGraph)
and len(motif_edge_labels) > 0
):
# filter the retrieved edge(s) based on the motif edge labels
filtered_ret = []
for r in ret:

if any([i.get('__labels__', None).issubset(motif_edge_labels) for i in r.values()]):
if any(
[
i.get("__labels__", None).issubset(
motif_edge_labels
)
for i in r.values()
]
):
filtered_ret.append(r)

ret = filtered_ret
Expand All @@ -463,29 +485,27 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:

result[data_path] = list(ret)[offset_limit]


return result
def return_clause(self, clause):

def return_clause(self, clause):
# collect all entity identifiers to be returned
for item in clause:
if item:
if not isinstance(item, str):
item = str(item.value)
self._return_requests.append(item)


def order_clause(self, order_clause):
self._order_by = []
for item in order_clause[0].children:
field = str(item.children[0]) # assuming the field name is the first child
# Default to 'ASC' if not specified
if len(item.children) > 1 and str(item.children[1].data).lower() != 'desc':
direction = 'ASC'
if len(item.children) > 1 and str(item.children[1].data).lower() != "desc":
direction = "ASC"
else:
direction = 'DESC'
self._order_by.append((field, direction)) # [('n.age', 'DESC'), ...]
direction = "DESC"

self._order_by.append((field, direction)) # [('n.age', 'DESC'), ...]
self._order_by_attributes.add(field)

def distinct_return(self, distinct):
Expand All @@ -502,62 +522,83 @@ def skip_clause(self, skip):
def returns(self, ignore_limit=False):

results = self._lookup(
self._return_requests + list(self._order_by_attributes),
offset_limit=slice(0, None)
self._return_requests + list(self._order_by_attributes),
offset_limit=slice(0, None),
)
if self._order_by:
results = self._apply_order_by(results)
if self._distinct:
results = self._apply_distinct(results)
results = self._apply_pagination(results, ignore_limit)


# Exclude order-by-only attributes from the final results
results = {
key: values for key, values in results.items() if key in self._return_requests
key: values
for key, values in results.items()
if key in self._return_requests
}

return results

def _apply_order_by(self, results):
if self._order_by:
sort_lists = [(results[field], direction) for field, direction in self._order_by if field in results]
sort_lists = [
(results[field], direction)
for field, direction in self._order_by
if field in results
]

if sort_lists:
# Generate a list of indices sorted by the specified fields
indices = range(len(next(iter(results.values())))) # Safe because all lists are assumed to be of the same length
for sort_list, direction in reversed(sort_lists): # reverse to ensure the first sort key is primary
indices = sorted(indices, key=lambda i: sort_list[i], reverse=(direction == 'DESC'))
indices = range(
len(next(iter(results.values())))
) # Safe because all lists are assumed to be of the same length
for sort_list, direction in reversed(
sort_lists
): # reverse to ensure the first sort key is primary
indices = sorted(
indices,
key=lambda i: sort_list[i],
reverse=(direction == "DESC"),
)

# Reorder all lists in results using sorted indices
for key in results:
results[key] = [results[key][i] for i in indices]

return results

def _apply_distinct(self, results):
if self._order_by:
assert self._order_by_attributes.issubset(self._return_requests), "In a WITH/RETURN with DISTINCT or an aggregation, it is not possible to access variables declared before the WITH/RETURN"
assert self._order_by_attributes.issubset(
self._return_requests
), "In a WITH/RETURN with DISTINCT or an aggregation, it is not possible to access variables declared before the WITH/RETURN"

# ordered dict to maintain the first occurrence of each unique tuple based on return requests
unique_rows = OrderedDict()

# Iterate over each 'row' by index
for i in range(len(next(iter(results.values())))): # assume all columns are of the same length
for i in range(
len(next(iter(results.values())))
): # assume all columns are of the same length
# create a tuple key of all the values from return requests for this row
row_key = tuple(results[key][i] for key in self._return_requests if key in results)

row_key = tuple(
results[key][i] for key in self._return_requests if key in results
)

if row_key not in unique_rows:
unique_rows[row_key] = i # store the index of the first occurrence of this unique row

unique_rows[row_key] = (
i # store the index of the first occurrence of this unique row
)

# construct the results based on unique indices collected
distinct_results = {key: [] for key in self._return_requests}
for row_key, index in unique_rows.items():
for _, key in enumerate(self._return_requests):
distinct_results[key].append(results[key][index])

return distinct_results

def _apply_pagination(self, results, ignore_limit):
# apply LIMIT and SKIP (if set) after ordering
if self._limit is not None and not ignore_limit:
Expand All @@ -570,7 +611,7 @@ def _apply_pagination(self, results, ignore_limit):
for key in results.keys():
start_index = self._skip
results[key] = results[key][start_index:]

return results

def _get_true_matches(self):
Expand Down Expand Up @@ -685,7 +726,7 @@ def _edge_hop_motifs(self, motif: nx.MultiDiGraph) -> List[Tuple[nx.Graph, dict]
if motif.out_degree(n) == 0 and motif.in_degree(n) == 0:
new_motif.add_node(n, **motif.nodes[n])
motifs: List[Tuple[nx.DiGraph, dict]] = [(new_motif, {})]
for u, v, k in motif.edges: # OutMultiEdgeView([('a', 'b', 0)])
for u, v, k in motif.edges: # OutMultiEdgeView([('a', 'b', 0)])
new_motifs = []
min_hop = motif.edges[u, v, k]["__min_hop__"]
max_hop = motif.edges[u, v, k]["__max_hop__"]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="grand-cypher",
version="0.7.0",
version="0.8.0",
author="Jordan Matelsky",
author_email="[email protected]",
description="Query Grand graphs using Cypher",
Expand Down

0 comments on commit 4bedf30

Please sign in to comment.