Skip to content

Commit

Permalink
Community workflow (#1495)
Browse files Browse the repository at this point in the history
* Create separate communities workflow

* Add test for new workflow

* Rename workflows

* Collapse subflows into parents

* Rename flows, reuse variables

* Semver

* Fix integration test

* Fix smoke tests

* Fix megapipeline format

* Rename missed files

---------

Co-authored-by: Alonso Guevara <[email protected]>
  • Loading branch information
natoverse and AlonsoGuevara authored Dec 11, 2024
1 parent de12521 commit 1d68af3
Show file tree
Hide file tree
Showing 36 changed files with 783 additions and 735 deletions.
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241210232215730615.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Create separate community workflow, collapse subflows."
}
11 changes: 9 additions & 2 deletions graphrag/index/create_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
PipelineWorkflowReference,
)
from graphrag.index.workflows.default_workflows import (
create_base_entity_graph,
compute_communities,
create_base_text_units,
create_final_communities,
create_final_community_reports,
Expand All @@ -62,6 +62,7 @@
create_final_nodes,
create_final_relationships,
create_final_text_units,
extract_graph,
generate_text_embeddings,
)

Expand Down Expand Up @@ -216,7 +217,7 @@ def _get_embedding_settings(
def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]:
return [
PipelineWorkflowReference(
name=create_base_entity_graph,
name=extract_graph,
config={
"snapshot_graphml": settings.snapshots.graphml,
"snapshot_transient": settings.snapshots.transient,
Expand All @@ -235,9 +236,15 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
settings.root_dir,
),
},
},
),
PipelineWorkflowReference(
name=compute_communities,
config={
"cluster_graph": {
"strategy": settings.cluster_graph.resolved_strategy()
},
"snapshot_transient": settings.snapshots.transient,
},
),
PipelineWorkflowReference(
Expand Down
43 changes: 43 additions & 0 deletions graphrag/index/flows/compute_communities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""All the steps to compute the base communities from the relationship edges."""

from typing import Any

import pandas as pd

from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage


async def compute_communities(
    base_relationship_edges: pd.DataFrame,
    storage: PipelineStorage,
    clustering_strategy: dict[str, Any],
    snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
    """Cluster the relationship graph and return the base communities table.

    A graph is built from ``base_relationship_edges`` and clustered with
    ``clustering_strategy``. The clustering output is loaded into a frame
    with the columns ``level``, ``community``, ``parent`` and ``title``; the
    ``title`` column is exploded so each of its entries gets its own row, and
    ``community`` is cast to ``int``.

    When ``snapshot_transient_enabled`` is true, the resulting frame is also
    written to ``storage`` as a parquet snapshot named ``base_communities``.
    """
    clustered = cluster_graph(
        create_graph(base_relationship_edges),
        strategy=clustering_strategy,
    )

    column_names = pd.Index(["level", "community", "parent", "title"])
    community_rows = pd.DataFrame(clustered, columns=column_names).explode("title")
    community_rows["community"] = community_rows["community"].astype(int)

    # Guard clause: nothing more to do unless transient snapshots are requested.
    if not snapshot_transient_enabled:
        return community_rows

    await snapshot(
        community_rows,
        name="base_communities",
        storage=storage,
        formats=["parquet"],
    )
    return community_rows
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
)

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.snapshot import snapshot
Expand All @@ -25,13 +24,11 @@
from graphrag.storage.pipeline_storage import PipelineStorage


async def create_base_entity_graph(
async def extract_graph(
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
runtime_storage: PipelineStorage,
clustering_strategy: dict[str, Any],
extraction_strategy: dict[str, Any] | None = None,
extraction_num_threads: int = 4,
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
Expand All @@ -40,7 +37,7 @@ async def create_base_entity_graph(
summarization_num_threads: int = 4,
snapshot_graphml_enabled: bool = False,
snapshot_transient_enabled: bool = False,
) -> None:
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later
entity_dfs, relationship_dfs = await extract_entities(
Expand Down Expand Up @@ -73,17 +70,6 @@ async def create_base_entity_graph(

base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)

communities = cluster_graph(
graph,
strategy=clustering_strategy,
)

base_communities = _prep_communities(communities)

await runtime_storage.set("base_entity_nodes", base_entity_nodes)
await runtime_storage.set("base_relationship_edges", base_relationship_edges)
await runtime_storage.set("base_communities", base_communities)

if snapshot_graphml_enabled:
# todo: extract graphs at each level, and add in meta like descriptions
await snapshot_graphml(
Expand All @@ -105,12 +91,8 @@ async def create_base_entity_graph(
storage=storage,
formats=["parquet"],
)
await snapshot(
base_communities,
name="base_communities",
storage=storage,
formats=["parquet"],
)

return (base_entity_nodes, base_relationship_edges)


def _merge_entities(entity_dfs) -> pd.DataFrame:
Expand Down Expand Up @@ -158,13 +140,6 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
return edges


def _prep_communities(communities) -> pd.DataFrame:
# Convert the input into a DataFrame and explode the title column
return pd.DataFrame(
communities, columns=pd.Index(["level", "community", "parent", "title"])
).explode("title")


def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
return pd.DataFrame([
{"name": node, "degree": int(degree)}
Expand Down
2 changes: 1 addition & 1 deletion graphrag/index/update/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def _run_entity_summarization(
The updated entities dataframe with summarized descriptions.
"""
summarize_config = _find_workflow_config(
config, "create_base_entity_graph", "summarize_descriptions"
config, "extract_graph", "summarize_descriptions"
)
strategy = summarize_config.get("strategy", {})

Expand Down
21 changes: 12 additions & 9 deletions graphrag/index/workflows/default_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@

"""A package containing default workflows definitions."""

# load and register all subflows
from graphrag.index.workflows.v1.subflows import * # noqa

from graphrag.index.workflows.typing import WorkflowDefinitions
from graphrag.index.workflows.v1.create_base_entity_graph import (
build_steps as build_create_base_entity_graph_steps,
from graphrag.index.workflows.v1.compute_communities import (
build_steps as build_compute_communities_steps,
)
from graphrag.index.workflows.v1.create_base_entity_graph import (
workflow_name as create_base_entity_graph,
from graphrag.index.workflows.v1.compute_communities import (
workflow_name as compute_communities,
)
from graphrag.index.workflows.v1.create_base_text_units import (
build_steps as build_create_base_text_units_steps,
Expand Down Expand Up @@ -67,16 +64,22 @@
from graphrag.index.workflows.v1.create_final_text_units import (
workflow_name as create_final_text_units,
)
from graphrag.index.workflows.v1.extract_graph import (
build_steps as build_extract_graph_steps,
)
from graphrag.index.workflows.v1.extract_graph import (
workflow_name as extract_graph,
)
from graphrag.index.workflows.v1.generate_text_embeddings import (
build_steps as build_generate_text_embeddings_steps,
)

from graphrag.index.workflows.v1.generate_text_embeddings import (
workflow_name as generate_text_embeddings,
)

default_workflows: WorkflowDefinitions = {
create_base_entity_graph: build_create_base_entity_graph_steps,
extract_graph: build_extract_graph_steps,
compute_communities: build_compute_communities_steps,
create_base_text_units: build_create_base_text_units_steps,
create_final_text_units: build_create_final_text_units,
create_final_community_reports: build_create_final_community_reports_steps,
Expand Down
74 changes: 74 additions & 0 deletions graphrag/index/workflows/v1/compute_communities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""A module containing build_steps method definition."""

from typing import Any, cast

import pandas as pd
from datashaper import (
Table,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result

from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
from graphrag.index.flows.compute_communities import compute_communities
from graphrag.storage.pipeline_storage import PipelineStorage

# Identifier shared by the pipeline step ("verb" entry in build_steps) and the
# @verb registration below, so the step resolves to this module's verb.
workflow_name = "compute_communities"


def build_steps(
    config: PipelineWorkflowConfig,
) -> list[PipelineWorkflowStep]:
    """
    Build the single step that computes base communities from the graph edges.

    The clustering strategy is read from ``config["cluster_graph"]["strategy"]``
    (the whole ``cluster_graph`` entry defaults to a leiden strategy when
    absent), and ``config["snapshot_transient"]`` toggles transient snapshots.

    ## Dependencies
    * `workflow:extract_graph`
    """
    default_clustering = {"strategy": {"type": "leiden"}}
    strategy = config.get("cluster_graph", default_clustering).get("strategy")

    # A missing or falsy setting (e.g. None) is normalised to False.
    transient_snapshots = config.get("snapshot_transient", False) or False

    step = {
        "verb": workflow_name,
        "args": {
            "clustering_strategy": strategy,
            "snapshot_transient_enabled": transient_snapshots,
        },
        "input": {"source": "workflow:extract_graph"},
    }
    return [step]


@verb(name=workflow_name, treats_input_tables_as_immutable=True)
async def workflow(
    storage: PipelineStorage,
    runtime_storage: PipelineStorage,
    clustering_strategy: dict[str, Any],
    snapshot_transient_enabled: bool = False,
    **_kwargs: dict,
) -> VerbResult:
    """Compute the base communities and publish them to runtime storage.

    Reads ``base_relationship_edges`` from ``runtime_storage``, passes it to
    ``compute_communities`` together with ``storage``, the clustering
    strategy and the snapshot flag, then stores the result in
    ``runtime_storage`` under ``base_communities``. The verb result itself
    wraps an empty table.
    """
    edges = await runtime_storage.get("base_relationship_edges")

    communities = await compute_communities(
        edges,
        storage,
        clustering_strategy=clustering_strategy,
        snapshot_transient_enabled=snapshot_transient_enabled,
    )
    await runtime_storage.set("base_communities", communities)

    # The communities are handed over via runtime storage, so the verb's own
    # output is just an empty placeholder table.
    empty_table = cast("Table", pd.DataFrame())
    return create_verb_result(empty_table)
59 changes: 0 additions & 59 deletions graphrag/index/workflows/v1/create_base_entity_graph.py

This file was deleted.

Loading

0 comments on commit 1d68af3

Please sign in to comment.