diff --git a/.semversioner/next-release/patch-20241210232215730615.json b/.semversioner/next-release/patch-20241210232215730615.json
new file mode 100644
index 0000000000..81dbe42390
--- /dev/null
+++ b/.semversioner/next-release/patch-20241210232215730615.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Create separate community workflow, collapse subflows."
+}
diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py
index fd10f685db..2c8fe4fc05 100644
--- a/graphrag/index/create_pipeline_config.py
+++ b/graphrag/index/create_pipeline_config.py
@@ -52,7 +52,7 @@
     PipelineWorkflowReference,
 )
 from graphrag.index.workflows.default_workflows import (
-    create_base_entity_graph,
+    compute_communities,
     create_base_text_units,
     create_final_communities,
     create_final_community_reports,
@@ -62,6 +62,7 @@
     create_final_nodes,
     create_final_relationships,
     create_final_text_units,
+    extract_graph,
     generate_text_embeddings,
 )
@@ -216,7 +217,7 @@ def _get_embedding_settings(
 def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference]:
     return [
         PipelineWorkflowReference(
-            name=create_base_entity_graph,
+            name=extract_graph,
             config={
                 "snapshot_graphml": settings.snapshots.graphml,
                 "snapshot_transient": settings.snapshots.transient,
@@ -235,9 +236,15 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference
                         settings.root_dir,
                     ),
                 },
+            },
+        ),
+        PipelineWorkflowReference(
+            name=compute_communities,
+            config={
                 "cluster_graph": {
                     "strategy": settings.cluster_graph.resolved_strategy()
                 },
+                "snapshot_transient": settings.snapshots.transient,
             },
         ),
         PipelineWorkflowReference(
diff --git a/graphrag/index/flows/compute_communities.py b/graphrag/index/flows/compute_communities.py
new file mode 100644
index 0000000000..09ec084ac6
--- /dev/null
+++ b/graphrag/index/flows/compute_communities.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""All the steps to create the base communities."""
+
+from typing import Any
+
+import pandas as pd
+
+from graphrag.index.operations.cluster_graph import cluster_graph
+from graphrag.index.operations.create_graph import create_graph
+from graphrag.index.operations.snapshot import snapshot
+from graphrag.storage.pipeline_storage import PipelineStorage
+
+
+async def compute_communities(
+    base_relationship_edges: pd.DataFrame,
+    storage: PipelineStorage,
+    clustering_strategy: dict[str, Any],
+    snapshot_transient_enabled: bool = False,
+) -> pd.DataFrame:
+    """All the steps to create the base communities."""
+    graph = create_graph(base_relationship_edges)
+
+    communities = cluster_graph(
+        graph,
+        strategy=clustering_strategy,
+    )
+
+    base_communities = pd.DataFrame(
+        communities, columns=pd.Index(["level", "community", "parent", "title"])
+    ).explode("title")
+    base_communities["community"] = base_communities["community"].astype(int)
+
+    if snapshot_transient_enabled:
+        await snapshot(
+            base_communities,
+            name="base_communities",
+            storage=storage,
+            formats=["parquet"],
+        )
+
+    return base_communities
diff --git a/graphrag/index/flows/create_base_entity_graph.py b/graphrag/index/flows/extract_graph.py
similarity index 82%
rename from graphrag/index/flows/create_base_entity_graph.py
rename to graphrag/index/flows/extract_graph.py
index a1abf070e8..f274d55f64 100644
--- a/graphrag/index/flows/create_base_entity_graph.py
+++ b/graphrag/index/flows/extract_graph.py
@@ -14,7 +14,6 @@
 )
 
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.operations.cluster_graph import cluster_graph
 from graphrag.index.operations.create_graph import create_graph
 from graphrag.index.operations.extract_entities import extract_entities
 from graphrag.index.operations.snapshot import snapshot
@@ -25,13 +24,11 @@
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 
-async def create_base_entity_graph(
+async def extract_graph(
     text_units: pd.DataFrame,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
     storage: PipelineStorage,
-    runtime_storage: PipelineStorage,
-    clustering_strategy: dict[str, Any],
     extraction_strategy: dict[str, Any] | None = None,
     extraction_num_threads: int = 4,
     extraction_async_mode: AsyncType = AsyncType.AsyncIO,
@@ -40,7 +37,7 @@ async def create_base_entity_graph(
     summarization_num_threads: int = 4,
     snapshot_graphml_enabled: bool = False,
     snapshot_transient_enabled: bool = False,
-) -> None:
+) -> tuple[pd.DataFrame, pd.DataFrame]:
     """All the steps to create the base entity graph."""
     # this returns a graph for each text unit, to be merged later
     entity_dfs, relationship_dfs = await extract_entities(
@@ -73,17 +70,6 @@ async def create_base_entity_graph(
 
     base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)
 
-    communities = cluster_graph(
-        graph,
-        strategy=clustering_strategy,
-    )
-
-    base_communities = _prep_communities(communities)
-
-    await runtime_storage.set("base_entity_nodes", base_entity_nodes)
-    await runtime_storage.set("base_relationship_edges", base_relationship_edges)
-    await runtime_storage.set("base_communities", base_communities)
-
     if snapshot_graphml_enabled:
         # todo: extract graphs at each level, and add in meta like descriptions
         await snapshot_graphml(
@@ -105,12 +91,8 @@ async def create_base_entity_graph(
             storage=storage,
             formats=["parquet"],
         )
-        await snapshot(
-            base_communities,
-            name="base_communities",
-            storage=storage,
-            formats=["parquet"],
-        )
+
+    return (base_entity_nodes, base_relationship_edges)
 
 
 def _merge_entities(entity_dfs) -> pd.DataFrame:
@@ -158,13 +140,6 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
     return edges
 
 
-def _prep_communities(communities) -> pd.DataFrame:
-    # Convert the input into a DataFrame and explode the title column
-    return pd.DataFrame(
-        communities, columns=pd.Index(["level", "community", "parent", "title"])
-    ).explode("title")
-
-
 def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
     return pd.DataFrame([
         {"name": node, "degree": int(degree)}
diff --git a/graphrag/index/update/entities.py b/graphrag/index/update/entities.py
index 538e39e9c6..12792076f3 100644
--- a/graphrag/index/update/entities.py
+++ b/graphrag/index/update/entities.py
@@ -112,7 +112,7 @@ async def _run_entity_summarization(
     The updated entities dataframe with summarized descriptions.
     """
     summarize_config = _find_workflow_config(
-        config, "create_base_entity_graph", "summarize_descriptions"
+        config, "extract_graph", "summarize_descriptions"
    )
     strategy = summarize_config.get("strategy", {})
diff --git a/graphrag/index/workflows/default_workflows.py b/graphrag/index/workflows/default_workflows.py
index 536423c4e3..009f9fa8ce 100644
--- a/graphrag/index/workflows/default_workflows.py
+++ b/graphrag/index/workflows/default_workflows.py
@@ -3,15 +3,12 @@
 
 """A package containing default workflows definitions."""
 
-# load and register all subflows
-from graphrag.index.workflows.v1.subflows import *  # noqa
-
 from graphrag.index.workflows.typing import WorkflowDefinitions
-from graphrag.index.workflows.v1.create_base_entity_graph import (
-    build_steps as build_create_base_entity_graph_steps,
+from graphrag.index.workflows.v1.compute_communities import (
+    build_steps as build_compute_communities_steps,
 )
-from graphrag.index.workflows.v1.create_base_entity_graph import (
-    workflow_name as create_base_entity_graph,
+from graphrag.index.workflows.v1.compute_communities import (
+    workflow_name as compute_communities,
 )
 from graphrag.index.workflows.v1.create_base_text_units import (
     build_steps as build_create_base_text_units_steps,
@@ -67,16 +64,22 @@
 from graphrag.index.workflows.v1.create_final_text_units import (
     workflow_name as create_final_text_units,
 )
+from graphrag.index.workflows.v1.extract_graph import (
+    build_steps as build_extract_graph_steps,
+)
+from graphrag.index.workflows.v1.extract_graph import (
+    workflow_name as extract_graph,
+)
 from graphrag.index.workflows.v1.generate_text_embeddings import (
     build_steps as build_generate_text_embeddings_steps,
 )
-
 from graphrag.index.workflows.v1.generate_text_embeddings import (
     workflow_name as generate_text_embeddings,
 )
 
 default_workflows: WorkflowDefinitions = {
-    create_base_entity_graph: build_create_base_entity_graph_steps,
+    extract_graph: build_extract_graph_steps,
+    compute_communities: build_compute_communities_steps,
     create_base_text_units: build_create_base_text_units_steps,
     create_final_text_units: build_create_final_text_units,
     create_final_community_reports: build_create_final_community_reports_steps,
diff --git a/graphrag/index/workflows/v1/compute_communities.py b/graphrag/index/workflows/v1/compute_communities.py
new file mode 100644
index 0000000000..46a739d392
--- /dev/null
+++ b/graphrag/index/workflows/v1/compute_communities.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing build_steps method definition."""
+
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    Table,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.compute_communities import compute_communities
+from graphrag.storage.pipeline_storage import PipelineStorage
+
+workflow_name = "compute_communities"
+
+
+def build_steps(
+    config: PipelineWorkflowConfig,
+) -> list[PipelineWorkflowStep]:
+    """
+    Create the base communities from the graph edges.
+
+    ## Dependencies
+    * `workflow:extract_graph`
+    """
+    clustering_config = config.get(
+        "cluster_graph",
+        {"strategy": {"type": "leiden"}},
+    )
+    clustering_strategy = clustering_config.get("strategy")
+
+    snapshot_transient = config.get("snapshot_transient", False) or False
+
+    return [
+        {
+            "verb": workflow_name,
+            "args": {
+                "clustering_strategy": clustering_strategy,
+                "snapshot_transient_enabled": snapshot_transient,
+            },
+            "input": ({"source": "workflow:extract_graph"}),
+        },
+    ]
+
+
+@verb(
+    name=workflow_name,
+    treats_input_tables_as_immutable=True,
+)
+async def workflow(
+    storage: PipelineStorage,
+    runtime_storage: PipelineStorage,
+    clustering_strategy: dict[str, Any],
+    snapshot_transient_enabled: bool = False,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to compute the base communities."""
+    base_relationship_edges = await runtime_storage.get("base_relationship_edges")
+
+    base_communities = await compute_communities(
+        base_relationship_edges,
+        storage,
+        clustering_strategy=clustering_strategy,
+        snapshot_transient_enabled=snapshot_transient_enabled,
+    )
+
+    await runtime_storage.set("base_communities", base_communities)
+
+    return create_verb_result(cast("Table", pd.DataFrame()))
diff --git a/graphrag/index/workflows/v1/create_base_entity_graph.py b/graphrag/index/workflows/v1/create_base_entity_graph.py
deleted file mode 100644
index 833f8733c5..0000000000
--- a/graphrag/index/workflows/v1/create_base_entity_graph.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""A module containing build_steps method definition."""
-
-from datashaper import (
-    AsyncType,
-)
-
-from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
-
-workflow_name = "create_base_entity_graph"
-
-
-def build_steps(
-    config: PipelineWorkflowConfig,
-) -> list[PipelineWorkflowStep]:
-    """
-    Create the base table for the entity graph.
-
-    ## Dependencies
-    * `workflow:create_base_text_units`
-    """
-    entity_extraction_config = config.get("entity_extract", {})
-    async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
-    extraction_strategy = entity_extraction_config.get("strategy")
-    extraction_num_threads = entity_extraction_config.get("num_threads", 4)
-    entity_types = entity_extraction_config.get("entity_types")
-
-    summarize_descriptions_config = config.get("summarize_descriptions", {})
-    summarization_strategy = summarize_descriptions_config.get("strategy")
-    summarization_num_threads = summarize_descriptions_config.get("num_threads", 4)
-
-    clustering_config = config.get(
-        "cluster_graph",
-        {"strategy": {"type": "leiden"}},
-    )
-    clustering_strategy = clustering_config.get("strategy")
-
-    snapshot_graphml = config.get("snapshot_graphml", False) or False
-    snapshot_transient = config.get("snapshot_transient", False) or False
-
-    return [
-        {
-            "verb": "create_base_entity_graph",
-            "args": {
-                "extraction_strategy": extraction_strategy,
-                "extraction_num_threads": extraction_num_threads,
-                "extraction_async_mode": async_mode,
-                "entity_types": entity_types,
-                "summarization_strategy": summarization_strategy,
-                "summarization_num_threads": summarization_num_threads,
-                "clustering_strategy": clustering_strategy,
-                "snapshot_graphml_enabled": snapshot_graphml,
-                "snapshot_transient_enabled": snapshot_transient,
-            },
-            "input": ({"source": "workflow:create_base_text_units"}),
-        },
-    ]
diff --git a/graphrag/index/workflows/v1/create_base_text_units.py b/graphrag/index/workflows/v1/create_base_text_units.py
index 40250b62d2..37e849b8e2 100644
--- a/graphrag/index/workflows/v1/create_base_text_units.py
+++ b/graphrag/index/workflows/v1/create_base_text_units.py
@@ -3,9 +3,23 @@
 
 """A module containing build_steps method definition."""
 
-from datashaper import DEFAULT_INPUT_NAME
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    DEFAULT_INPUT_NAME,
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
 
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.create_base_text_units import (
+    create_base_text_units,
+)
+from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "create_base_text_units"
@@ -26,7 +40,7 @@ def build_steps(
     snapshot_transient = config.get("snapshot_transient", False) or False
     return [
         {
-            "verb": "create_base_text_units",
+            "verb": workflow_name,
             "args": {
                 "chunk_by_columns": chunk_by_columns,
                 "chunk_strategy": chunk_strategy,
@@ -35,3 +49,36 @@ def build_steps(
             "input": {"source": DEFAULT_INPUT_NAME},
         },
     ]
+
+
+@verb(name=workflow_name, treats_input_tables_as_immutable=True)
+async def workflow(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    storage: PipelineStorage,
+    runtime_storage: PipelineStorage,
+    chunk_by_columns: list[str],
+    chunk_strategy: dict[str, Any] | None = None,
+    snapshot_transient_enabled: bool = False,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform base text_units."""
+    source = cast("pd.DataFrame", input.get_input())
+
+    output = await create_base_text_units(
+        source,
+        callbacks,
+        storage,
+        chunk_by_columns,
+        chunk_strategy=chunk_strategy,
+        snapshot_transient_enabled=snapshot_transient_enabled,
+    )
+
+    await runtime_storage.set("base_text_units", output)
+
+    return create_verb_result(
+        cast(
+            "Table",
+            pd.DataFrame(),
+        )
+    )
diff --git a/graphrag/index/workflows/v1/create_final_communities.py b/graphrag/index/workflows/v1/create_final_communities.py
index b5296b4bfc..c9683991c5 100644
--- a/graphrag/index/workflows/v1/create_final_communities.py
+++ b/graphrag/index/workflows/v1/create_final_communities.py
@@ -3,7 +3,19 @@
 
 """A module containing build_steps method definition."""
 
+from typing import cast
+
+from datashaper import (
+    Table,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.create_final_communities import (
+    create_final_communities,
+)
+from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "create_final_communities"
@@ -15,11 +27,34 @@ def build_steps(
     Create the final communities table.
 
     ## Dependencies
-    * `workflow:create_base_entity_graph`
+    * `workflow:extract_graph`
     """
     return [
         {
-            "verb": "create_final_communities",
-            "input": {"source": "workflow:create_base_entity_graph"},
+            "verb": workflow_name,
+            "input": {"source": "workflow:extract_graph"},
         },
     ]
+
+
+@verb(name=workflow_name, treats_input_tables_as_immutable=True)
+async def workflow(
+    runtime_storage: PipelineStorage,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform final communities."""
+    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
+    base_relationship_edges = await runtime_storage.get("base_relationship_edges")
+    base_communities = await runtime_storage.get("base_communities")
+    output = create_final_communities(
+        base_entity_nodes,
+        base_relationship_edges,
+        base_communities,
+    )
+
+    return create_verb_result(
+        cast(
+            "Table",
+            output,
+        )
+    )
diff --git a/graphrag/index/workflows/v1/create_final_community_reports.py b/graphrag/index/workflows/v1/create_final_community_reports.py
index 6b8d110fe1..401a4bffab 100644
--- a/graphrag/index/workflows/v1/create_final_community_reports.py
+++ b/graphrag/index/workflows/v1/create_final_community_reports.py
@@ -3,7 +3,26 @@
 
 """A module containing build_steps method definition."""
 
+from typing import TYPE_CHECKING, cast
+
+from datashaper import (
+    AsyncType,
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.create_final_community_reports import (
+    create_final_community_reports,
+)
+from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table
+
+if TYPE_CHECKING:
+    import pandas as pd
 
 workflow_name = "create_final_community_reports"
@@ -15,7 +34,7 @@ def build_steps(
     Create the final community reports table.
 
     ## Dependencies
-    * `workflow:create_base_entity_graph`
+    * `workflow:extract_graph`
     """
     covariates_enabled = config.get("covariates_enabled", False)
     create_community_reports_config = config.get("create_community_reports", {})
@@ -34,7 +53,7 @@ def build_steps(
 
     return [
         {
-            "verb": "create_final_community_reports",
+            "verb": workflow_name,
             "args": {
                 "summarization_strategy": summarization_strategy,
                 "async_mode": async_mode,
@@ -43,3 +62,46 @@ def build_steps(
             "input": input,
         },
     ]
+
+
+@verb(name=workflow_name, treats_input_tables_as_immutable=True)
+async def workflow(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    summarization_strategy: dict,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    num_threads: int = 4,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform community reports."""
+    nodes = cast("pd.DataFrame", input.get_input())
+    edges = cast("pd.DataFrame", get_required_input_table(input, "relationships").table)
+    entities = cast("pd.DataFrame", get_required_input_table(input, "entities").table)
+    communities = cast(
+        "pd.DataFrame", get_required_input_table(input, "communities").table
+    )
+
+    claims = get_named_input_table(input, "covariates")
+    if claims:
+        claims = cast("pd.DataFrame", claims.table)
+
+    output = await create_final_community_reports(
+        nodes,
+        edges,
+        entities,
+        communities,
+        claims,
+        callbacks,
+        cache,
+        summarization_strategy,
+        async_mode=async_mode,
+        num_threads=num_threads,
+    )
+
+    return create_verb_result(
+        cast(
+            "Table",
+            output,
+        )
+    )
diff --git a/graphrag/index/workflows/v1/create_final_covariates.py b/graphrag/index/workflows/v1/create_final_covariates.py
index b730a1737d..2804e389f3 100644
--- a/graphrag/index/workflows/v1/create_final_covariates.py
+++ b/graphrag/index/workflows/v1/create_final_covariates.py
@@ -3,11 +3,22 @@
 
 """A module containing build_steps method definition."""
 
+from typing import Any, cast
+
 from datashaper import (
     AsyncType,
+    Table,
+    VerbCallbacks,
+    verb,
 )
+from datashaper.table_store.types import VerbResult, create_verb_result
 
+from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.create_final_covariates import (
+    create_final_covariates,
+)
+from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "create_final_covariates"
@@ -28,7 +39,7 @@ def build_steps(
 
     return [
         {
-            "verb": "create_final_covariates",
+            "verb": workflow_name,
             "args": {
                 "covariate_type": "claim",
                 "extraction_strategy": extraction_strategy,
@@ -38,3 +49,32 @@ def build_steps(
             "input": {"source": "workflow:create_base_text_units"},
         },
     ]
+
+
+@verb(name=workflow_name, treats_input_tables_as_immutable=True)
+async def workflow(
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    runtime_storage: PipelineStorage,
+    covariate_type: str,
+    extraction_strategy: dict[str, Any] | None,
+    async_mode: AsyncType = AsyncType.AsyncIO,
+    entity_types: list[str] | None = None,
+    num_threads: int = 4,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to extract and format covariates."""
+    text_units = await runtime_storage.get("base_text_units")
+
+    output = await create_final_covariates(
+        text_units,
+        callbacks,
+        cache,
+        covariate_type,
+        extraction_strategy,
+        async_mode=async_mode,
+        entity_types=entity_types,
+        num_threads=num_threads,
+    )
+
+    return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/create_final_documents.py b/graphrag/index/workflows/v1/create_final_documents.py
index ad0a1f036e..a9b5af67fd 100644
--- a/graphrag/index/workflows/v1/create_final_documents.py
+++ b/graphrag/index/workflows/v1/create_final_documents.py
@@ -3,9 +3,25 @@
 
 """A module containing build_steps method definition."""
 
-from datashaper import DEFAULT_INPUT_NAME
+from typing import TYPE_CHECKING, cast
+
+from datashaper import (
+    DEFAULT_INPUT_NAME,
+    Table,
+    VerbInput,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
 
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.create_final_documents import (
+    create_final_documents,
+)
+from graphrag.storage.pipeline_storage import PipelineStorage
+
+if TYPE_CHECKING:
+    import pandas as pd
+
 
 workflow_name = "create_final_documents"
@@ -22,7 +38,7 @@ def build_steps(
     document_attribute_columns = config.get("document_attribute_columns", None)
     return [
         {
-            "verb": "create_final_documents",
+            "verb": workflow_name,
             "args": {"document_attribute_columns": document_attribute_columns},
             "input": {
                 "source": DEFAULT_INPUT_NAME,
@@ -30,3 +46,22 @@ def build_steps(
             },
         },
     ]
+
+
+@verb(
+    name=workflow_name,
+    treats_input_tables_as_immutable=True,
+)
+async def workflow(
+    input: VerbInput,
+    runtime_storage: PipelineStorage,
+    document_attribute_columns: list[str] | None = None,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform final documents."""
+    source = cast("pd.DataFrame", input.get_input())
+    text_units = await runtime_storage.get("base_text_units")
+
+    output = create_final_documents(source, text_units, document_attribute_columns)
+
+    return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/create_final_entities.py b/graphrag/index/workflows/v1/create_final_entities.py
index d36d5bb331..35a86bbdff 100644
--- a/graphrag/index/workflows/v1/create_final_entities.py
+++ b/graphrag/index/workflows/v1/create_final_entities.py
@@ -4,8 +4,19 @@
 """A module containing build_steps method definition."""
 
 import logging
+from typing import cast
+
+from datashaper import (
+    Table,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
 
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.create_final_entities import (
+    create_final_entities,
+)
+from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "create_final_entities"
 log = logging.getLogger(__name__)
@@ -18,12 +29,28 @@ def build_steps(
     Create the final entities table.
 
     ## Dependencies
-    * `workflow:create_base_entity_graph`
+    * `workflow:extract_graph`
     """
     return [
         {
-            "verb": "create_final_entities",
+            "verb": workflow_name,
             "args": {},
-            "input": {"source": "workflow:create_base_entity_graph"},
+            "input": {"source": "workflow:extract_graph"},
         },
     ]
+
+
+@verb(
+    name=workflow_name,
+    treats_input_tables_as_immutable=True,
+)
+async def workflow(
+    runtime_storage: PipelineStorage,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform final entities."""
+    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
+
+    output = create_final_entities(base_entity_nodes)
+
+    return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/create_final_nodes.py b/graphrag/index/workflows/v1/create_final_nodes.py
index 4385853f77..60b4cfe17e 100644
--- a/graphrag/index/workflows/v1/create_final_nodes.py
+++ b/graphrag/index/workflows/v1/create_final_nodes.py
@@ -3,7 +3,20 @@
 
 """A module containing build_steps method definition."""
 
+from typing import Any, cast
+
+from datashaper import (
+    Table,
+    VerbCallbacks,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.create_final_nodes import (
+    create_final_nodes,
+)
+from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "create_final_nodes"
@@ -15,7 +28,7 @@ def build_steps(
     Create the base table for the document graph.
 
     ## Dependencies
-    * `workflow:create_base_entity_graph`
+    * `workflow:extract_graph`
     """
     layout_graph_enabled = config.get("layout_graph_enabled", True)
     layout_graph_config = config.get(
@@ -46,13 +59,46 @@ def build_steps(
 
     return [
         {
-            "verb": "create_final_nodes",
+            "verb": workflow_name,
             "args": {
                 "layout_strategy": layout_strategy,
                 "embedding_strategy": embedding_strategy
                 if embed_graph_enabled
                 else None,
             },
-            "input": {"source": "workflow:create_base_entity_graph"},
+            "input": {
+                "source": "workflow:extract_graph",
+                "communities": "workflow:compute_communities",
+            },
         },
     ]
+
+
+@verb(name=workflow_name, treats_input_tables_as_immutable=True)
+async def workflow(
+    callbacks: VerbCallbacks,
+    runtime_storage: PipelineStorage,
+    layout_strategy: dict[str, Any],
+    embedding_strategy: dict[str, Any] | None = None,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform final nodes."""
+    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
+    base_relationship_edges = await runtime_storage.get("base_relationship_edges")
+    base_communities = await runtime_storage.get("base_communities")
+
+    output = create_final_nodes(
+        base_entity_nodes,
+        base_relationship_edges,
+        base_communities,
+        callbacks,
+        layout_strategy,
+        embedding_strategy=embedding_strategy,
+    )
+
+    return create_verb_result(
+        cast(
+            "Table",
+            output,
+        )
+    )
diff --git a/graphrag/index/workflows/v1/create_final_relationships.py b/graphrag/index/workflows/v1/create_final_relationships.py
index d951d19ab7..a278607e1e 100644
--- a/graphrag/index/workflows/v1/create_final_relationships.py
+++ b/graphrag/index/workflows/v1/create_final_relationships.py
@@ -4,8 +4,19 @@
 """A module containing build_steps method definition."""
 
 import logging
+from typing import cast
+
+from datashaper import (
+    Table,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
 
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.create_final_relationships import (
+    create_final_relationships,
+)
+from graphrag.storage.pipeline_storage import PipelineStorage
 
 workflow_name = "create_final_relationships"
@@ -19,14 +30,31 @@ def build_steps(
     Create the final relationships table.
 
     ## Dependencies
-    * `workflow:create_base_entity_graph`
+    * `workflow:extract_graph`
     """
     return [
         {
-            "verb": "create_final_relationships",
+            "verb": workflow_name,
             "args": {},
             "input": {
-                "source": "workflow:create_base_entity_graph",
+                "source": "workflow:extract_graph",
             },
         },
     ]
+
+
+@verb(
+    name=workflow_name,
+    treats_input_tables_as_immutable=True,
+)
+async def workflow(
+    runtime_storage: PipelineStorage,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform final relationships."""
+    base_relationship_edges = await runtime_storage.get("base_relationship_edges")
+    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
+
+    output = create_final_relationships(base_relationship_edges, base_entity_nodes)
+
+    return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/create_final_text_units.py b/graphrag/index/workflows/v1/create_final_text_units.py
index a39e22d2e2..887477c593 100644
--- a/graphrag/index/workflows/v1/create_final_text_units.py
+++ b/graphrag/index/workflows/v1/create_final_text_units.py
@@ -3,7 +3,25 @@
 
 """A module containing build_steps method definition."""
 
+from typing import TYPE_CHECKING, cast
+
+from datashaper import (
+    Table,
+    VerbInput,
+    VerbResult,
+    create_verb_result,
+    verb,
+)
+
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.create_final_text_units import (
+    create_final_text_units,
+)
+from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table
+from graphrag.storage.pipeline_storage import PipelineStorage
+
+if TYPE_CHECKING:
+    import pandas as pd
 
 workflow_name = "create_final_text_units"
@@ -32,8 +50,37 @@ def build_steps(
 
     return [
         {
-            "verb": "create_final_text_units",
+            "verb": workflow_name,
             "args": {},
             "input": input,
         },
     ]
+
+
+@verb(name=workflow_name, treats_input_tables_as_immutable=True)
+async def workflow(
+    input: VerbInput,
+    runtime_storage: PipelineStorage,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to transform the text units."""
+    text_units = await runtime_storage.get("base_text_units")
+    final_entities = cast(
+        "pd.DataFrame", get_required_input_table(input, "entities").table
+    )
+    final_relationships = cast(
+        "pd.DataFrame", get_required_input_table(input, "relationships").table
+    )
+    final_covariates = get_named_input_table(input, "covariates")
+
+    if final_covariates:
+        final_covariates = cast("pd.DataFrame", final_covariates.table)
+
+    output = create_final_text_units(
+        text_units,
+        final_entities,
+        final_relationships,
+        final_covariates,
+    )
+
+    return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/extract_graph.py b/graphrag/index/workflows/v1/extract_graph.py
new file mode 100644
index 0000000000..86af232cfe
--- /dev/null
+++ b/graphrag/index/workflows/v1/extract_graph.py
@@ -0,0 +1,107 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing build_steps method definition."""
+
+from typing import Any, cast
+
+import pandas as pd
+from datashaper import (
+    AsyncType,
+    Table,
+    VerbCallbacks,
+    verb,
+)
+from datashaper.table_store.types import VerbResult, create_verb_result
+
+from graphrag.cache.pipeline_cache import PipelineCache
+from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.extract_graph import (
+    extract_graph,
+)
+from graphrag.storage.pipeline_storage import PipelineStorage
+
+workflow_name = "extract_graph"
+
+
+def build_steps(
+    config: PipelineWorkflowConfig,
+) -> list[PipelineWorkflowStep]:
+    """
+    Create the base table for the entity graph.
+
+    ## Dependencies
+    * `workflow:create_base_text_units`
+    """
+    entity_extraction_config = config.get("entity_extract", {})
+    async_mode = entity_extraction_config.get("async_mode", AsyncType.AsyncIO)
+    extraction_strategy = entity_extraction_config.get("strategy")
+    extraction_num_threads = entity_extraction_config.get("num_threads", 4)
+    entity_types = entity_extraction_config.get("entity_types")
+
+    summarize_descriptions_config = config.get("summarize_descriptions", {})
+    summarization_strategy = summarize_descriptions_config.get("strategy")
+    summarization_num_threads = summarize_descriptions_config.get("num_threads", 4)
+
+    snapshot_graphml = config.get("snapshot_graphml", False) or False
+    snapshot_transient = config.get("snapshot_transient", False) or False
+
+    return [
+        {
+            "verb": workflow_name,
+            "args": {
+                "extraction_strategy": extraction_strategy,
+                "extraction_num_threads": extraction_num_threads,
+                "extraction_async_mode": async_mode,
+                "entity_types": entity_types,
+                "summarization_strategy": summarization_strategy,
+                "summarization_num_threads": summarization_num_threads,
+                "snapshot_graphml_enabled": snapshot_graphml,
+                "snapshot_transient_enabled": snapshot_transient,
+            },
+            "input": ({"source": "workflow:create_base_text_units"}),
+        },
+    ]
+
+
+@verb(
+    name=workflow_name,
+    treats_input_tables_as_immutable=True,
+)
+async def workflow(
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    storage: PipelineStorage,
+    runtime_storage: PipelineStorage,
+    extraction_strategy: dict[str, Any] | None,
+    extraction_num_threads: int = 4,
+    extraction_async_mode: AsyncType = AsyncType.AsyncIO,
+    entity_types: list[str] | None = None,
+    summarization_strategy: dict[str, Any] | None = None,
+    summarization_num_threads: int = 4,
+    snapshot_graphml_enabled: bool = False,
+    snapshot_transient_enabled: bool = False,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to create the base entity graph."""
+    text_units = await runtime_storage.get("base_text_units")
+
+    base_entity_nodes, base_relationship_edges = await extract_graph(
+        text_units,
+        callbacks,
+        cache,
+        storage,
+        extraction_strategy=extraction_strategy,
+        extraction_num_threads=extraction_num_threads,
+        extraction_async_mode=extraction_async_mode,
+        entity_types=entity_types,
+        summarization_strategy=summarization_strategy,
+        summarization_num_threads=summarization_num_threads,
+        snapshot_graphml_enabled=snapshot_graphml_enabled,
+        snapshot_transient_enabled=snapshot_transient_enabled,
+    )
+
+    await runtime_storage.set("base_entity_nodes", base_entity_nodes)
+    await runtime_storage.set("base_relationship_edges", base_relationship_edges)
+
+    return create_verb_result(cast("Table", pd.DataFrame()))
diff --git a/graphrag/index/workflows/v1/generate_text_embeddings.py b/graphrag/index/workflows/v1/generate_text_embeddings.py
index 58464b33a8..5af6f354ea 100644
--- a/graphrag/index/workflows/v1/generate_text_embeddings.py
+++ b/graphrag/index/workflows/v1/generate_text_embeddings.py
@@ -4,8 +4,25 @@
 """A module containing build_steps method definition."""
 
 import logging
+from typing import cast
 
+import pandas as pd
+from datashaper import (
+    Table,
+    VerbCallbacks,
+    VerbInput,
+    VerbResult,
+    create_verb_result,
+    verb,
+)
+
+from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep
+from graphrag.index.flows.generate_text_embeddings import (
+    generate_text_embeddings,
+)
+from graphrag.index.utils.ds_util import get_required_input_table
+from graphrag.storage.pipeline_storage import PipelineStorage
 
 log = logging.getLogger(__name__)
@@ -38,7 +55,7 @@ def build_steps(
     snapshot_embeddings = config.get("snapshot_embeddings", False)
     return [
         {
-            "verb": "generate_text_embeddings",
+            "verb": workflow_name,
             "args": {
                 "text_embed": text_embed,
                 "embedded_fields": embedded_fields,
@@ -47,3 +64,47 @@ def build_steps(
             "input": input,
         },
     ]
+
+
+@verb(name=workflow_name, treats_input_tables_as_immutable=True)
+async def workflow(
+    input: VerbInput,
+    callbacks: VerbCallbacks,
+    cache: PipelineCache,
+    storage: PipelineStorage,
+    text_embed: dict,
+    embedded_fields: set[str],
+    snapshot_embeddings_enabled: bool = False,
+    **_kwargs: dict,
+) -> VerbResult:
+    """All the steps to generate embeddings."""
+    source = cast("pd.DataFrame", input.get_input())
+    final_relationships = cast(
+        "pd.DataFrame", get_required_input_table(input, "relationships").table
+    )
+    final_text_units = cast(
+        "pd.DataFrame", get_required_input_table(input, "text_units").table
+    )
+    final_entities = cast(
+        "pd.DataFrame", get_required_input_table(input, "entities").table
+    )
+
+    final_community_reports = cast(
+        "pd.DataFrame", get_required_input_table(input, "community_reports").table
+    )
+
+    await generate_text_embeddings(
+        final_documents=source,
+        final_relationships=final_relationships,
+        final_text_units=final_text_units,
+        final_entities=final_entities,
+        final_community_reports=final_community_reports,
+        callbacks=callbacks,
+        cache=cache,
+        storage=storage,
+        text_embed_config=text_embed,
+        embedded_fields=embedded_fields,
+        snapshot_embeddings_enabled=snapshot_embeddings_enabled,
+    )
+
+    return create_verb_result(cast("Table", pd.DataFrame()))
diff --git a/graphrag/index/workflows/v1/subflows/__init__.py b/graphrag/index/workflows/v1/subflows/__init__.py
deleted file mode 100644
index 1002a31af9..0000000000
--- a/graphrag/index/workflows/v1/subflows/__init__.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""The Indexing Engine workflows -> subflows package root."""
-
-from graphrag.index.workflows.v1.subflows.create_base_entity_graph import (
-    create_base_entity_graph,
-)
-from graphrag.index.workflows.v1.subflows.create_base_text_units import (
-    create_base_text_units,
-)
-from graphrag.index.workflows.v1.subflows.create_final_communities import (
-    create_final_communities,
-)
-from graphrag.index.workflows.v1.subflows.create_final_community_reports import (
-    create_final_community_reports,
-)
-from graphrag.index.workflows.v1.subflows.create_final_covariates import (
-    create_final_covariates,
-)
-from graphrag.index.workflows.v1.subflows.create_final_documents import (
-    create_final_documents,
-)
-from graphrag.index.workflows.v1.subflows.create_final_entities import (
-    create_final_entities,
-)
-from graphrag.index.workflows.v1.subflows.create_final_nodes import create_final_nodes
-from graphrag.index.workflows.v1.subflows.create_final_relationships import (
-    create_final_relationships,
-)
-from graphrag.index.workflows.v1.subflows.create_final_text_units import (
-    create_final_text_units,
-)
-from graphrag.index.workflows.v1.subflows.generate_text_embeddings import (
-    generate_text_embeddings,
-)
-
-__all__ = [
-    "create_base_entity_graph",
-    "create_base_text_units",
-    "create_final_communities",
-    "create_final_community_reports",
-    "create_final_covariates",
-    "create_final_documents",
-    "create_final_entities",
-    "create_final_nodes",
-    "create_final_relationships",
-    "create_final_text_units",
-    "generate_text_embeddings",
-]
diff --git a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py
deleted file mode 100644
index 194e56b0d8..0000000000
--- a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to create the base entity graph."""
-
-from typing import Any, cast
-
-import pandas as pd
-from datashaper import (
-    AsyncType,
-    Table,
-    VerbCallbacks,
-    verb,
-)
-from datashaper.table_store.types import VerbResult, create_verb_result
-
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.flows.create_base_entity_graph import (
-    create_base_entity_graph as create_base_entity_graph_flow,
-)
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-
-@verb(
-    name="create_base_entity_graph",
-    treats_input_tables_as_immutable=True,
-)
-async def create_base_entity_graph(
-    callbacks: VerbCallbacks,
-    cache: PipelineCache,
-    storage: PipelineStorage,
-    runtime_storage: PipelineStorage,
-    clustering_strategy: dict[str, Any],
-    extraction_strategy: dict[str, Any] | None,
-    extraction_num_threads: int = 4,
-    extraction_async_mode: AsyncType = AsyncType.AsyncIO,
-    entity_types: list[str] | None = None,
-    summarization_strategy: dict[str, Any] | None = None,
-    summarization_num_threads: int = 4,
-    snapshot_graphml_enabled: bool = False,
-    snapshot_transient_enabled: bool = False,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to create the base entity graph."""
-    text_units = await runtime_storage.get("base_text_units")
-
-    await create_base_entity_graph_flow(
-        text_units,
-        callbacks,
-        cache,
-        storage,
-        runtime_storage,
-        clustering_strategy=clustering_strategy,
-        extraction_strategy=extraction_strategy,
-        extraction_num_threads=extraction_num_threads,
-        extraction_async_mode=extraction_async_mode,
-        entity_types=entity_types,
-        summarization_strategy=summarization_strategy,
-        summarization_num_threads=summarization_num_threads,
-        snapshot_graphml_enabled=snapshot_graphml_enabled,
-        snapshot_transient_enabled=snapshot_transient_enabled,
-    )
-
-    return create_verb_result(cast("Table", pd.DataFrame()))
diff --git a/graphrag/index/workflows/v1/subflows/create_base_text_units.py b/graphrag/index/workflows/v1/subflows/create_base_text_units.py
deleted file mode 100644
index 36a33d9044..0000000000
--- a/graphrag/index/workflows/v1/subflows/create_base_text_units.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to transform base text_units."""
-
-from typing import Any, cast
-
-import pandas as pd
-from datashaper import (
-    Table,
-    VerbCallbacks,
-    VerbInput,
-    verb,
-)
-from datashaper.table_store.types import VerbResult, create_verb_result
-
-from graphrag.index.flows.create_base_text_units import (
-    create_base_text_units as create_base_text_units_flow,
-)
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-
-@verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
-async def create_base_text_units(
-    input: VerbInput,
-    callbacks: VerbCallbacks,
-    storage: PipelineStorage,
-    runtime_storage: PipelineStorage,
-    chunk_by_columns: list[str],
-    chunk_strategy: dict[str, Any] | None = None,
-    snapshot_transient_enabled: bool = False,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to transform base text_units."""
-    source = cast("pd.DataFrame", input.get_input())
-
-    output = await create_base_text_units_flow(
-        source,
-        callbacks,
-        storage,
-        chunk_by_columns,
-        chunk_strategy=chunk_strategy,
-        snapshot_transient_enabled=snapshot_transient_enabled,
-    )
-
-    await runtime_storage.set("base_text_units", output)
-
-    return create_verb_result(
-        cast(
-            "Table",
-            pd.DataFrame(),
-        )
-    )
diff --git a/graphrag/index/workflows/v1/subflows/create_final_communities.py b/graphrag/index/workflows/v1/subflows/create_final_communities.py
deleted file mode 100644
index de200a1393..0000000000
--- a/graphrag/index/workflows/v1/subflows/create_final_communities.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to transform final communities."""
-
-from typing import cast
-
-from datashaper import (
-    Table,
-    verb,
-)
-from datashaper.table_store.types import VerbResult, create_verb_result
-
-from graphrag.index.flows.create_final_communities import (
-    create_final_communities as create_final_communities_flow,
-)
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-
-@verb(name="create_final_communities", treats_input_tables_as_immutable=True)
-async def create_final_communities(
-    runtime_storage: PipelineStorage,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to transform final communities."""
-    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
-    base_relationship_edges = await runtime_storage.get("base_relationship_edges")
-    base_communities = await runtime_storage.get("base_communities")
-    output = create_final_communities_flow(
-        base_entity_nodes,
-        base_relationship_edges,
-        base_communities,
-    )
-
-    return create_verb_result(
-        cast(
-            "Table",
-            output,
-        )
-    )
diff --git a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py b/graphrag/index/workflows/v1/subflows/create_final_community_reports.py
deleted file mode 100644
index 88a3e6dd8a..0000000000
--- a/graphrag/index/workflows/v1/subflows/create_final_community_reports.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to transform community reports."""
-
-from typing import TYPE_CHECKING, cast
-
-from datashaper import (
-    AsyncType,
-    Table,
-    VerbCallbacks,
-    VerbInput,
-    verb,
-)
-from datashaper.table_store.types import VerbResult, create_verb_result
-
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.flows.create_final_community_reports import (
-    create_final_community_reports as create_final_community_reports_flow,
-)
-from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table
-
-if TYPE_CHECKING:
-    import pandas as pd
-
-
-@verb(name="create_final_community_reports", treats_input_tables_as_immutable=True)
-async def create_final_community_reports(
-    input: VerbInput,
-    callbacks: VerbCallbacks,
-    cache: PipelineCache,
-    summarization_strategy: dict,
-    async_mode: AsyncType = AsyncType.AsyncIO,
-    num_threads: int = 4,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to transform community reports."""
-    nodes = cast("pd.DataFrame", input.get_input())
-    edges = cast("pd.DataFrame", get_required_input_table(input, "relationships").table)
-    entities = cast("pd.DataFrame", get_required_input_table(input, "entities").table)
-    communities = cast(
-        "pd.DataFrame", get_required_input_table(input, "communities").table
-    )
-
-    claims = get_named_input_table(input, "covariates")
-    if claims:
-        claims = cast("pd.DataFrame", claims.table)
-
-    output = await create_final_community_reports_flow(
-        nodes,
-        edges,
-        entities,
-        communities,
-        claims,
-        callbacks,
-        cache,
-        summarization_strategy,
-        async_mode=async_mode,
-        num_threads=num_threads,
-    )
-
-    return create_verb_result(
-        cast(
-            "Table",
-            output,
-        )
-    )
diff --git a/graphrag/index/workflows/v1/subflows/create_final_covariates.py b/graphrag/index/workflows/v1/subflows/create_final_covariates.py
deleted file mode 100644
index b8ff11d65f..0000000000
--- a/graphrag/index/workflows/v1/subflows/create_final_covariates.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to extract and format covariates."""
-
-from typing import Any, cast
-
-from datashaper import (
-    AsyncType,
-    Table,
-    VerbCallbacks,
-    verb,
-)
-from datashaper.table_store.types import VerbResult, create_verb_result
-
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.flows.create_final_covariates import (
-    create_final_covariates as create_final_covariates_flow,
-)
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-
-@verb(name="create_final_covariates", treats_input_tables_as_immutable=True)
-async def create_final_covariates(
-    callbacks: VerbCallbacks,
-    cache: PipelineCache,
-    runtime_storage: PipelineStorage,
-    covariate_type: str,
-    extraction_strategy: dict[str, Any] | None,
-    async_mode: AsyncType = AsyncType.AsyncIO,
-    entity_types: list[str] | None = None,
-    num_threads: int = 4,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to extract and format covariates."""
-    text_units = await runtime_storage.get("base_text_units")
-
-    output = await create_final_covariates_flow(
-        text_units,
-        callbacks,
-        cache,
-        covariate_type,
-        extraction_strategy,
-        async_mode=async_mode,
-        entity_types=entity_types,
-        num_threads=num_threads,
-    )
-
-    return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/subflows/create_final_documents.py b/graphrag/index/workflows/v1/subflows/create_final_documents.py
deleted file mode 100644
index 94d6d692b3..0000000000
--- a/graphrag/index/workflows/v1/subflows/create_final_documents.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to transform final documents."""
-
-from typing import TYPE_CHECKING, cast
-
-from datashaper import (
-    Table,
-    VerbInput,
-    verb,
-)
-from datashaper.table_store.types import VerbResult, create_verb_result
-
-from graphrag.index.flows.create_final_documents import (
-    create_final_documents as create_final_documents_flow,
-)
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-if TYPE_CHECKING:
-    import pandas as pd
-
-
-@verb(
-    name="create_final_documents",
-    treats_input_tables_as_immutable=True,
-)
-async def create_final_documents(
-    input: VerbInput,
-    runtime_storage: PipelineStorage,
-    document_attribute_columns: list[str] | None = None,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to transform final documents."""
-    source = cast("pd.DataFrame", input.get_input())
-    text_units = await runtime_storage.get("base_text_units")
-
-    output = create_final_documents_flow(source, text_units, document_attribute_columns)
-
-    return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/subflows/create_final_entities.py b/graphrag/index/workflows/v1/subflows/create_final_entities.py
deleted file mode 100644
index 1a8ff57410..0000000000
--- a/graphrag/index/workflows/v1/subflows/create_final_entities.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to transform final entities."""
-
-from typing import cast
-
-from datashaper import (
-    Table,
-    verb,
-)
-from datashaper.table_store.types import VerbResult, create_verb_result
-
-from graphrag.index.flows.create_final_entities import (
-    create_final_entities as create_final_entities_flow,
-)
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-
-@verb(
-    name="create_final_entities",
-    treats_input_tables_as_immutable=True,
-)
-async def create_final_entities(
-    runtime_storage: PipelineStorage,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to transform final entities."""
-    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
-
-    output = create_final_entities_flow(base_entity_nodes)
-
-    return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/subflows/create_final_nodes.py b/graphrag/index/workflows/v1/subflows/create_final_nodes.py
deleted file mode 100644
index 92bd6e34fc..0000000000
--- a/graphrag/index/workflows/v1/subflows/create_final_nodes.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to transform final nodes."""
-
-from typing import Any, cast
-
-from datashaper import (
-    Table,
-    VerbCallbacks,
-    verb,
-)
-from datashaper.table_store.types import VerbResult, create_verb_result
-
-from graphrag.index.flows.create_final_nodes import (
-    create_final_nodes as create_final_nodes_flow,
-)
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-
-@verb(name="create_final_nodes", treats_input_tables_as_immutable=True)
-async def create_final_nodes(
-    callbacks: VerbCallbacks,
-    runtime_storage: PipelineStorage,
-    layout_strategy: dict[str, Any],
-    embedding_strategy: dict[str, Any] | None = None,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to transform final nodes."""
-    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
-    base_relationship_edges = await runtime_storage.get("base_relationship_edges")
-    base_communities = await runtime_storage.get("base_communities")
-
-    output = create_final_nodes_flow(
-        base_entity_nodes,
-        base_relationship_edges,
-        base_communities,
-        callbacks,
-        layout_strategy,
-        embedding_strategy=embedding_strategy,
-    )
-
-    return create_verb_result(
-        cast(
-            "Table",
-            output,
-        )
-    )
diff --git a/graphrag/index/workflows/v1/subflows/create_final_relationships.py b/graphrag/index/workflows/v1/subflows/create_final_relationships.py
deleted file mode 100644
index 919954dc7c..0000000000
--- a/graphrag/index/workflows/v1/subflows/create_final_relationships.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to transform final relationships."""
-
-from typing import cast
-
-from datashaper import (
-    Table,
-    verb,
-)
-from datashaper.table_store.types import VerbResult, create_verb_result
-
-from graphrag.index.flows.create_final_relationships import (
-    create_final_relationships as create_final_relationships_flow,
-)
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-
-@verb(
-    name="create_final_relationships",
-    treats_input_tables_as_immutable=True,
-)
-async def create_final_relationships(
-    runtime_storage: PipelineStorage,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to transform final relationships."""
-    base_relationship_edges = await runtime_storage.get("base_relationship_edges")
-    base_entity_nodes = await runtime_storage.get("base_entity_nodes")
-
-    output = create_final_relationships_flow(base_relationship_edges, base_entity_nodes)
-
-    return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/subflows/create_final_text_units.py b/graphrag/index/workflows/v1/subflows/create_final_text_units.py
deleted file mode 100644
index 2db0e79c9b..0000000000
--- a/graphrag/index/workflows/v1/subflows/create_final_text_units.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to transform the text units."""
-
-from typing import TYPE_CHECKING, cast
-
-from datashaper import (
-    Table,
-    VerbInput,
-    VerbResult,
-    create_verb_result,
-    verb,
-)
-
-from graphrag.index.flows.create_final_text_units import (
-    create_final_text_units as create_final_text_units_flow,
-)
-from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-if TYPE_CHECKING:
-    import pandas as pd
-
-
-@verb(name="create_final_text_units", treats_input_tables_as_immutable=True)
-async def create_final_text_units(
-    input: VerbInput,
-    runtime_storage: PipelineStorage,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to transform the text units."""
-    text_units = await runtime_storage.get("base_text_units")
-    final_entities = cast(
-        "pd.DataFrame", get_required_input_table(input, "entities").table
-    )
-    final_relationships = cast(
-        "pd.DataFrame", get_required_input_table(input, "relationships").table
-    )
-    final_covariates = get_named_input_table(input, "covariates")
-
-    if final_covariates:
-        final_covariates = cast("pd.DataFrame", final_covariates.table)
-
-    output = create_final_text_units_flow(
-        text_units,
-        final_entities,
-        final_relationships,
-        final_covariates,
-    )
-
-    return create_verb_result(cast("Table", output))
diff --git a/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py b/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py
deleted file mode 100644
index 5e50bd1978..0000000000
--- a/graphrag/index/workflows/v1/subflows/generate_text_embeddings.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""All the steps to transform the text units."""
-
-import logging
-from typing import cast
-
-import pandas as pd
-from datashaper import (
-    Table,
-    VerbCallbacks,
-    VerbInput,
-    VerbResult,
-    create_verb_result,
-    verb,
-)
-
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.index.flows.generate_text_embeddings import (
-    generate_text_embeddings as generate_text_embeddings_flow,
-)
-from graphrag.index.utils.ds_util import get_required_input_table
-from graphrag.storage.pipeline_storage import PipelineStorage
-
-log = logging.getLogger(__name__)
-
-
-@verb(name="generate_text_embeddings", treats_input_tables_as_immutable=True)
-async def generate_text_embeddings(
-    input: VerbInput,
-    callbacks: VerbCallbacks,
-    cache: PipelineCache,
-    storage: PipelineStorage,
-    text_embed: dict,
-    embedded_fields: set[str],
-    snapshot_embeddings_enabled: bool = False,
-    **_kwargs: dict,
-) -> VerbResult:
-    """All the steps to generate embeddings."""
-    source = cast("pd.DataFrame", input.get_input())
-    final_relationships = cast(
-        "pd.DataFrame", get_required_input_table(input, "relationships").table
-    )
-    final_text_units = cast(
-        "pd.DataFrame", get_required_input_table(input, "text_units").table
-    )
-    final_entities = cast(
-        "pd.DataFrame", get_required_input_table(input, "entities").table
-    )
-
-    final_community_reports = cast(
-        "pd.DataFrame", get_required_input_table(input, "community_reports").table
-    )
-
-    await generate_text_embeddings_flow(
-        final_documents=source,
-        final_relationships=final_relationships,
-        final_text_units=final_text_units,
-        final_entities=final_entities,
-        final_community_reports=final_community_reports,
-        callbacks=callbacks,
-        cache=cache,
-        storage=storage,
-        text_embed_config=text_embed,
-        embedded_fields=embedded_fields,
-        snapshot_embeddings_enabled=snapshot_embeddings_enabled,
-    )
-
-    return create_verb_result(cast("Table", pd.DataFrame()))
diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json
index 11b2dceb51..3fb865dce3 100644
--- a/tests/fixtures/min-csv/config.json
+++ b/tests/fixtures/min-csv/config.json
@@ -2,6 +2,15 @@
     "input_path": "./tests/fixtures/min-csv",
     "input_file_type": "text",
     "workflow_config": {
+        "compute_communities": {
+            "row_range": [
+                1,
+                2500
+            ],
+            "subworkflows": 1,
+            "max_runtime": 150,
+            "expected_artifacts": 0
+        },
         "create_base_text_units": {
             "row_range": [
                 1,
@@ -11,7 +20,7 @@
             "max_runtime": 150,
             "expected_artifacts": 0
         },
-        "create_base_entity_graph": {
+        "extract_graph": {
             "row_range": [
                 1,
                 2500
diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json
index 0b5ca26581..812603dd83 100644
--- a/tests/fixtures/text/config.json
+++ b/tests/fixtures/text/config.json
@@ -2,6 +2,15 @@
     "input_path": "./tests/fixtures/text",
     "input_file_type": "text",
     "workflow_config": {
+        "compute_communities": {
+            "row_range": [
+                1,
+                2500
+            ],
+            "subworkflows": 1,
+            "max_runtime": 150,
+            "expected_artifacts": 0
+        },
         "create_base_text_units": {
             "row_range": [
                 1,
@@ -11,7 +20,7 @@
             "max_runtime": 150,
             "expected_artifacts": 0
         },
-        "create_base_entity_graph": {
+        "extract_graph": {
             "row_range": [
                 1,
                 2500
diff --git a/tests/integration/_pipeline/megapipeline.yml b/tests/integration/_pipeline/megapipeline.yml
index 13ce43683e..363f8b37e4 100644
--- a/tests/integration/_pipeline/megapipeline.yml
+++ b/tests/integration/_pipeline/megapipeline.yml
@@ -19,7 +19,7 @@ workflows:
       # Just lump everything together
       chunk_by: []
 
-  - name: create_base_entity_graph
+  - name: extract_graph
     config:
       snapshot_graphml_enabled: True
      entity_extract:
@@ -44,6 +44,9 @@
         type: static_response
         responses:
           - This is a MOCK response for the LLM. It is summarized!
+
+  - name: compute_communities
+    config:
       cluster_graph:
         strategy:
           type: leiden
diff --git a/tests/verbs/test_compute_communities.py b/tests/verbs/test_compute_communities.py
new file mode 100644
index 0000000000..07db7d42c3
--- /dev/null
+++ b/tests/verbs/test_compute_communities.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from graphrag.index.flows.compute_communities import (
+    compute_communities,
+)
+from graphrag.index.run.utils import create_run_context
+from graphrag.index.workflows.v1.compute_communities import (
+    workflow_name,
+)
+
+from .util import (
+    compare_outputs,
+    get_config_for_workflow,
+    load_test_table,
+)
+
+
+async def test_compute_communities():
+    edges = load_test_table("base_relationship_edges")
+    expected = load_test_table("base_communities")
+
+    context = create_run_context(None, None, None)
+    config = get_config_for_workflow(workflow_name)
+    clustering_strategy = config["cluster_graph"]["strategy"]
+
+    actual = await compute_communities(
+        edges, storage=context.storage, clustering_strategy=clustering_strategy
+    )
+
+    columns = list(expected.columns.values)
+    compare_outputs(actual, expected, columns)
+    assert len(actual.columns) == len(expected.columns)
+
+
+async def test_compute_communities_with_snapshots():
+    edges = load_test_table("base_relationship_edges")
+
+    context = create_run_context(None, None, None)
+    config = get_config_for_workflow(workflow_name)
+    clustering_strategy = config["cluster_graph"]["strategy"]
+
+    await compute_communities(
+        edges,
+        storage=context.storage,
+        clustering_strategy=clustering_strategy,
+        snapshot_transient_enabled=True,
+    )
+
+    assert context.storage.keys() == [
+        "base_communities.parquet",
+    ], "Community snapshot keys differ"
diff --git a/tests/verbs/test_create_base_entity_graph.py b/tests/verbs/test_extract_graph.py
similarity index 90%
rename from tests/verbs/test_create_base_entity_graph.py
rename to tests/verbs/test_extract_graph.py
index 7fc4ec356d..9e4d0a4280 100644
--- a/tests/verbs/test_create_base_entity_graph.py
+++ b/tests/verbs/test_extract_graph.py
@@ -5,7 +5,7 @@
 
 from graphrag.config.enums import LLMType
 from graphrag.index.run.utils import create_run_context
-from graphrag.index.workflows.v1.create_base_entity_graph import (
+from graphrag.index.workflows.v1.extract_graph import (
     build_steps,
     workflow_name,
 )
@@ -48,14 +48,13 @@
 }
 
 
-async def test_create_base_entity_graph():
+async def test_extract_graph():
     input_tables = load_input_tables([
         "workflow:create_base_text_units",
     ])
 
     nodes_expected = load_test_table("base_entity_nodes")
     edges_expected = load_test_table("base_relationship_edges")
-    communities_expected = load_test_table("base_communities")
 
     context = create_run_context(None, None, None)
     await context.runtime_storage.set(
@@ -79,7 +78,6 @@ async def test_create_base_entity_graph():
     # graph construction creates transient tables for nodes, edges, and communities
     nodes_actual = await context.runtime_storage.get("base_entity_nodes")
     edges_actual = await context.runtime_storage.get("base_relationship_edges")
-    communities_actual = await context.runtime_storage.get("base_communities")
 
     assert len(nodes_actual.columns) == len(nodes_expected.columns), (
         "Nodes dataframe columns differ"
@@ -89,10 +87,6 @@ async def test_create_base_entity_graph():
         "Edges dataframe columns differ"
     )
 
-    assert len(communities_actual.columns) == len(communities_expected.columns), (
-        "Edges dataframe columns differ"
-    )
-
     # TODO: with the combined verb we can't force summarization
     # this is because the mock responses always result in a single description, which is returned verbatim rather than summarized
     # we need to update the mocking to provide somewhat unique graphs so a true merge happens
@@ -103,7 +97,7 @@ async def test_create_base_entity_graph():
     assert len(context.storage.keys()) == 0, "Storage should be empty"
 
 
-async def test_create_base_entity_graph_with_snapshots():
+async def test_extract_graph_with_snapshots():
     input_tables = load_input_tables([
         "workflow:create_base_text_units",
     ])
@@ -135,11 +129,10 @@ async def test_create_base_entity_graph_with_snapshots():
         "graph.graphml",
         "base_entity_nodes.parquet",
         "base_relationship_edges.parquet",
-        "base_communities.parquet",
     ], "Graph snapshot keys differ"
 
 
-async def test_create_base_entity_graph_missing_llm_throws():
+async def test_extract_graph_missing_llm_throws():
     input_tables = load_input_tables([
         "workflow:create_base_text_units",
    ])
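
A minimal sketch of how the two split flows compose end to end, mirroring the call patterns in the new tests. It assumes the run context returned by create_run_context(None, None, None) exposes storage and cache as the test utilities use it; the callbacks object and the three strategy dicts are illustrative placeholders, not part of this change.

from graphrag.index.flows.compute_communities import compute_communities
from graphrag.index.flows.extract_graph import extract_graph
from graphrag.index.run.utils import create_run_context


async def run_graph_stages(
    text_units,  # pd.DataFrame of base text units (e.g. read from runtime storage)
    callbacks,  # a VerbCallbacks implementation supplied by the pipeline (placeholder here)
    extraction_strategy,
    summarization_strategy,
    clustering_strategy,
):
    """Sketch: extract the entity graph, then cluster it, as two separate steps."""
    context = create_run_context(None, None, None)  # in-memory storage, as in the tests

    # Stage 1: extract_graph no longer clusters; it returns the transient
    # node and edge tables instead of writing them to runtime storage itself.
    nodes, edges = await extract_graph(
        text_units,
        callbacks,
        context.cache,
        context.storage,
        extraction_strategy=extraction_strategy,
        summarization_strategy=summarization_strategy,
    )

    # Stage 2: community detection is now isolated and consumes only the edge list,
    # which is what lets it run as its own workflow with its own config.
    communities = await compute_communities(
        edges,
        storage=context.storage,
        clustering_strategy=clustering_strategy,
    )
    return nodes, edges, communities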