diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fd86de..40b3465 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # CHANGELOG +### **0.8.0** (May 14 2024) + +> Support for MultiDiGraphs. + +#### Features + +- Support for MultiDiGraphs (#42, thanks @jackboyla!) + ### **0.7.0** (May 4 2024) > Support for `ORDER BY` and `DISTINCT` diff --git a/grandcypher/__init__.py b/grandcypher/__init__.py index d127d3a..3818684 100644 --- a/grandcypher/__init__.py +++ b/grandcypher/__init__.py @@ -154,7 +154,7 @@ start="start", ) -__version__ = "0.7.0" +__version__ = "0.8.0" _ALPHABET = string.ascii_lowercase + string.digits @@ -200,12 +200,12 @@ def _is_edge_attr_match( motif_edge_id: Tuple[str, str, Union[int, str]], host_edge_id: Tuple[str, str, Union[int, str]], motif: Union[nx.Graph, nx.MultiDiGraph], - host: Union[nx.Graph, nx.MultiDiGraph] + host: Union[nx.Graph, nx.MultiDiGraph], ) -> bool: """ - Check if an edge in the host graph matches the attributes in the motif, - including the special '__labels__' set attribute. - This function formats edges into + Check if an edge in the host graph matches the attributes in the motif, + including the special '__labels__' set attribute. + This function formats edges into nx.MultiDiGraph format i.e {0: first_relation, 1: ...}. Arguments: @@ -235,7 +235,7 @@ def _is_edge_attr_match( continue if host_edges.get(attr) != val: return False - + return True @@ -247,6 +247,7 @@ def _get_edge_attributes(graph: Union[nx.Graph, nx.MultiDiGraph], u, v) -> Dict: return graph[u][v] return {0: graph[u][v]} # Mock single edge for DiGraph + def _aggregate_edge_labels(edges: Dict) -> Dict: """ Aggregate '__labels__' attributes from edges into a single set. @@ -259,7 +260,10 @@ def _aggregate_edge_labels(edges: Dict) -> Dict: aggregated[edge_id] = attrs return aggregated -def _get_entity_from_host(host: Union[nx.DiGraph, nx.MultiDiGraph], entity_name, entity_attribute=None): + +def _get_entity_from_host( + host: Union[nx.DiGraph, nx.MultiDiGraph], entity_name, entity_attribute=None +): if entity_name in host.nodes(): # We are looking for a node mapping in the target graph: if entity_attribute: @@ -310,7 +314,9 @@ def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool: def cond_(should_be, entity_id, operator, value) -> CONDITION: - def inner(match: dict, host: Union[nx.DiGraph, nx.MultiDiGraph], return_endges: list) -> bool: + def inner( + match: dict, host: Union[nx.DiGraph, nx.MultiDiGraph], return_endges: list + ) -> bool: host_entity_id = entity_id.split(".") if host_entity_id[0] in match: host_entity_id[0] = match[host_entity_id[0]] @@ -380,7 +386,11 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: for data_path in data_paths: entity_name, _ = _data_path_to_entity_name_attribute(data_path) - if entity_name not in motif_nodes and entity_name not in self._return_edges and entity_name not in self._paths: + if ( + entity_name not in motif_nodes + and entity_name not in self._return_edges + and entity_name not in self._paths + ): raise NotImplementedError(f"Unknown entity name: {data_path}") result = {} @@ -412,7 +422,9 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: for x, node in enumerate(nodes): # Edge if x > 0: - path.append(self._target_graph.get_edge_data(nodes[x - 1], node)) + path.append( + self._target_graph.get_edge_data(nodes[x - 1], node) + ) # Node path.append(node) @@ -420,7 +432,7 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: ret.append(path) else: - mapping_u, mapping_v = self._return_edges[data_path.split('.')[0]] + mapping_u, mapping_v = self._return_edges[data_path.split(".")[0]] # We are looking for an edge mapping in the target graph: is_hop = self._motif.edges[(mapping_u, mapping_v, 0)]["__is_hop__"] ret = ( @@ -435,18 +447,28 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: # Get all edge labels from the motif -- this is used to filter the relations for multigraphs motif_edge_labels = set() for edge in self._motif.get_edge_data(mapping_u, mapping_v).values(): - if edge.get('__labels__', None): - motif_edge_labels.update(edge['__labels__']) + if edge.get("__labels__", None): + motif_edge_labels.update(edge["__labels__"]) if entity_attribute: # Get the correct entity from the target host graph, # and then return the attribute: - if isinstance(self._motif, nx.MultiDiGraph) and len(motif_edge_labels) > 0: + if ( + isinstance(self._motif, nx.MultiDiGraph) + and len(motif_edge_labels) > 0 + ): # filter the retrieved edge(s) based on the motif edge labels filtered_ret = [] for r in ret: - if any([i.get('__labels__', None).issubset(motif_edge_labels) for i in r.values()]): + if any( + [ + i.get("__labels__", None).issubset( + motif_edge_labels + ) + for i in r.values() + ] + ): filtered_ret.append(r) ret = filtered_ret @@ -463,10 +485,9 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]: result[data_path] = list(ret)[offset_limit] - return result - - def return_clause(self, clause): + + def return_clause(self, clause): # collect all entity identifiers to be returned for item in clause: if item: @@ -474,18 +495,17 @@ def return_clause(self, clause): item = str(item.value) self._return_requests.append(item) - def order_clause(self, order_clause): self._order_by = [] for item in order_clause[0].children: field = str(item.children[0]) # assuming the field name is the first child # Default to 'ASC' if not specified - if len(item.children) > 1 and str(item.children[1].data).lower() != 'desc': - direction = 'ASC' + if len(item.children) > 1 and str(item.children[1].data).lower() != "desc": + direction = "ASC" else: - direction = 'DESC' - - self._order_by.append((field, direction)) # [('n.age', 'DESC'), ...] + direction = "DESC" + + self._order_by.append((field, direction)) # [('n.age', 'DESC'), ...] self._order_by_attributes.add(field) def distinct_return(self, distinct): @@ -502,8 +522,8 @@ def skip_clause(self, skip): def returns(self, ignore_limit=False): results = self._lookup( - self._return_requests + list(self._order_by_attributes), - offset_limit=slice(0, None) + self._return_requests + list(self._order_by_attributes), + offset_limit=slice(0, None), ) if self._order_by: results = self._apply_order_by(results) @@ -511,53 +531,74 @@ def returns(self, ignore_limit=False): results = self._apply_distinct(results) results = self._apply_pagination(results, ignore_limit) - # Exclude order-by-only attributes from the final results results = { - key: values for key, values in results.items() if key in self._return_requests + key: values + for key, values in results.items() + if key in self._return_requests } return results - + def _apply_order_by(self, results): if self._order_by: - sort_lists = [(results[field], direction) for field, direction in self._order_by if field in results] + sort_lists = [ + (results[field], direction) + for field, direction in self._order_by + if field in results + ] if sort_lists: # Generate a list of indices sorted by the specified fields - indices = range(len(next(iter(results.values())))) # Safe because all lists are assumed to be of the same length - for sort_list, direction in reversed(sort_lists): # reverse to ensure the first sort key is primary - indices = sorted(indices, key=lambda i: sort_list[i], reverse=(direction == 'DESC')) + indices = range( + len(next(iter(results.values()))) + ) # Safe because all lists are assumed to be of the same length + for sort_list, direction in reversed( + sort_lists + ): # reverse to ensure the first sort key is primary + indices = sorted( + indices, + key=lambda i: sort_list[i], + reverse=(direction == "DESC"), + ) # Reorder all lists in results using sorted indices for key in results: results[key] = [results[key][i] for i in indices] - + return results - + def _apply_distinct(self, results): if self._order_by: - assert self._order_by_attributes.issubset(self._return_requests), "In a WITH/RETURN with DISTINCT or an aggregation, it is not possible to access variables declared before the WITH/RETURN" + assert self._order_by_attributes.issubset( + self._return_requests + ), "In a WITH/RETURN with DISTINCT or an aggregation, it is not possible to access variables declared before the WITH/RETURN" # ordered dict to maintain the first occurrence of each unique tuple based on return requests unique_rows = OrderedDict() - + # Iterate over each 'row' by index - for i in range(len(next(iter(results.values())))): # assume all columns are of the same length + for i in range( + len(next(iter(results.values()))) + ): # assume all columns are of the same length # create a tuple key of all the values from return requests for this row - row_key = tuple(results[key][i] for key in self._return_requests if key in results) - + row_key = tuple( + results[key][i] for key in self._return_requests if key in results + ) + if row_key not in unique_rows: - unique_rows[row_key] = i # store the index of the first occurrence of this unique row - + unique_rows[row_key] = ( + i # store the index of the first occurrence of this unique row + ) + # construct the results based on unique indices collected distinct_results = {key: [] for key in self._return_requests} for row_key, index in unique_rows.items(): for _, key in enumerate(self._return_requests): distinct_results[key].append(results[key][index]) - + return distinct_results - + def _apply_pagination(self, results, ignore_limit): # apply LIMIT and SKIP (if set) after ordering if self._limit is not None and not ignore_limit: @@ -570,7 +611,7 @@ def _apply_pagination(self, results, ignore_limit): for key in results.keys(): start_index = self._skip results[key] = results[key][start_index:] - + return results def _get_true_matches(self): @@ -685,7 +726,7 @@ def _edge_hop_motifs(self, motif: nx.MultiDiGraph) -> List[Tuple[nx.Graph, dict] if motif.out_degree(n) == 0 and motif.in_degree(n) == 0: new_motif.add_node(n, **motif.nodes[n]) motifs: List[Tuple[nx.DiGraph, dict]] = [(new_motif, {})] - for u, v, k in motif.edges: # OutMultiEdgeView([('a', 'b', 0)]) + for u, v, k in motif.edges: # OutMultiEdgeView([('a', 'b', 0)]) new_motifs = [] min_hop = motif.edges[u, v, k]["__min_hop__"] max_hop = motif.edges[u, v, k]["__max_hop__"] diff --git a/setup.py b/setup.py index 84d153b..857140b 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="grand-cypher", - version="0.7.0", + version="0.8.0", author="Jordan Matelsky", author_email="opensource@matelsky.com", description="Query Grand graphs using Cypher",