diff --git a/.semversioner/next-release/patch-20241224192900934104.json b/.semversioner/next-release/patch-20241224192900934104.json new file mode 100644 index 0000000000..0c60a93626 --- /dev/null +++ b/.semversioner/next-release/patch-20241224192900934104.json @@ -0,0 +1,4 @@ +{ + "type": "patch", + "description": "Simplify and streamline internal config." +} diff --git a/graphrag/config/create_graphrag_config.py b/graphrag/config/create_graphrag_config.py index 1155cee46f..6358fcc788 100644 --- a/graphrag/config/create_graphrag_config.py +++ b/graphrag/config/create_graphrag_config.py @@ -31,7 +31,7 @@ from graphrag.config.input_models.graphrag_config_input import GraphRagConfigInput from graphrag.config.input_models.llm_config_input import LLMConfigInput from graphrag.config.models.cache_config import CacheConfig -from graphrag.config.models.chunking_config import ChunkingConfig +from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType from graphrag.config.models.claim_extraction_config import ClaimExtractionConfig from graphrag.config.models.cluster_graph_config import ClusterGraphConfig from graphrag.config.models.community_reports_config import CommunityReportsConfig @@ -318,13 +318,16 @@ def hydrate_parallelization_params( reader.envvar_prefix(Section.node2vec), reader.use(values.get("embed_graph")), ): + use_lcc = reader.bool("use_lcc") embed_graph_model = EmbedGraphConfig( enabled=reader.bool(Fragment.enabled) or defs.NODE2VEC_ENABLED, + dimensions=reader.int("dimensions") or defs.NODE2VEC_DIMENSIONS, num_walks=reader.int("num_walks") or defs.NODE2VEC_NUM_WALKS, walk_length=reader.int("walk_length") or defs.NODE2VEC_WALK_LENGTH, window_size=reader.int("window_size") or defs.NODE2VEC_WINDOW_SIZE, iterations=reader.int("iterations") or defs.NODE2VEC_ITERATIONS, random_seed=reader.int("random_seed") or defs.NODE2VEC_RANDOM_SEED, + use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC, ) with reader.envvar_prefix(Section.input), reader.use(values.get("input")): input_type = reader.str("type") @@ -412,12 +415,15 @@ def hydrate_parallelization_params( encoding_model = ( reader.str(Fragment.encoding_model) or global_encoding_model ) - + strategy = reader.str("strategy") chunks_model = ChunkingConfig( size=reader.int("size") or defs.CHUNK_SIZE, overlap=reader.int("overlap") or defs.CHUNK_OVERLAP, group_by_columns=group_by_columns, encoding_model=encoding_model, + strategy=ChunkStrategyType(strategy) + if strategy + else ChunkStrategyType.tokens, ) with ( reader.envvar_prefix(Section.snapshot), @@ -522,8 +528,13 @@ def hydrate_parallelization_params( ) with reader.use(values.get("cluster_graph")): + use_lcc = reader.bool("use_lcc") + seed = reader.int("seed") cluster_graph_model = ClusterGraphConfig( - max_cluster_size=reader.int("max_cluster_size") or defs.MAX_CLUSTER_SIZE + max_cluster_size=reader.int("max_cluster_size") + or defs.MAX_CLUSTER_SIZE, + use_lcc=use_lcc if use_lcc is not None else defs.USE_LCC, + seed=seed if seed is not None else defs.CLUSTER_GRAPH_SEED, ) with ( diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index 48a6f5f3e6..9da336cca9 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -60,6 +60,8 @@ CLAIM_MAX_GLEANINGS = 1 CLAIM_EXTRACTION_ENABLED = False MAX_CLUSTER_SIZE = 10 +USE_LCC = True +CLUSTER_GRAPH_SEED = 0xDEADBEEF COMMUNITY_REPORT_MAX_LENGTH = 2000 COMMUNITY_REPORT_MAX_INPUT_LENGTH = 8000 ENTITY_EXTRACTION_ENTITY_TYPES = ["organization", "person", "geo", "event"] @@ -74,6 +76,7 @@ PARALLELIZATION_STAGGER = 0.3 PARALLELIZATION_NUM_THREADS = 50 NODE2VEC_ENABLED = False +NODE2VEC_DIMENSIONS = 1536 NODE2VEC_NUM_WALKS = 10 NODE2VEC_WALK_LENGTH = 40 NODE2VEC_WINDOW_SIZE = 2 diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py index 8e3000ab3b..7506eb3a7b 100644 --- a/graphrag/config/init_content.py +++ b/graphrag/config/init_content.py @@ -12,7 +12,7 @@ ### LLM settings ### ## There are a number of settings to tune the threading and token limits for LLM calls - check the docs. -encoding_model: cl100k_base # this needs to be matched to your model! +encoding_model: {defs.ENCODING_MODEL} # this needs to be matched to your model! llm: api_key: ${{GRAPHRAG_API_KEY}} # set this in the generated .env file @@ -111,7 +111,7 @@ enabled: false # if true, will generate node2vec embeddings for nodes umap: - enabled: false # if true, will generate UMAP embeddings for nodes + enabled: false # if true, will generate UMAP embeddings for nodes (embed_graph must also be enabled) snapshots: graphml: false diff --git a/graphrag/config/models/chunking_config.py b/graphrag/config/models/chunking_config.py index f2a39bf1bb..84d69d36bf 100644 --- a/graphrag/config/models/chunking_config.py +++ b/graphrag/config/models/chunking_config.py @@ -3,11 +3,24 @@ """Parameterization settings for the default configuration.""" +from enum import Enum + from pydantic import BaseModel, Field import graphrag.config.defaults as defs +class ChunkStrategyType(str, Enum): + """ChunkStrategy class definition.""" + + tokens = "tokens" + sentence = "sentence" + + def __repr__(self): + """Get a string representation.""" + return f'"{self.value}"' + + class ChunkingConfig(BaseModel): """Configuration section for chunking.""" @@ -19,22 +32,9 @@ class ChunkingConfig(BaseModel): description="The chunk by columns to use.", default=defs.CHUNK_GROUP_BY_COLUMNS, ) - strategy: dict | None = Field( - description="The chunk strategy to use, overriding the default tokenization strategy", - default=None, + strategy: ChunkStrategyType = Field( + description="The chunking strategy to use.", default=ChunkStrategyType.tokens ) - encoding_model: str | None = Field( - default=None, description="The encoding model to use." + encoding_model: str = Field( + description="The encoding model to use.", default=defs.ENCODING_MODEL ) - - def resolved_strategy(self, encoding_model: str | None) -> dict: - """Get the resolved chunking strategy.""" - from graphrag.index.operations.chunk_text import ChunkStrategyType - - return self.strategy or { - "type": ChunkStrategyType.tokens, - "chunk_size": self.size, - "chunk_overlap": self.overlap, - "group_by_columns": self.group_by_columns, - "encoding_name": encoding_model or self.encoding_model, - } diff --git a/graphrag/config/models/cluster_graph_config.py b/graphrag/config/models/cluster_graph_config.py index 805e5a184b..7e91c7f0df 100644 --- a/graphrag/config/models/cluster_graph_config.py +++ b/graphrag/config/models/cluster_graph_config.py @@ -14,15 +14,11 @@ class ClusterGraphConfig(BaseModel): max_cluster_size: int = Field( description="The maximum cluster size to use.", default=defs.MAX_CLUSTER_SIZE ) - strategy: dict | None = Field( - description="The cluster strategy to use.", default=None + use_lcc: bool = Field( + description="Whether to use the largest connected component.", + default=defs.USE_LCC, + ) + seed: int | None = Field( + description="The seed to use for the clustering.", + default=defs.CLUSTER_GRAPH_SEED, ) - - def resolved_strategy(self) -> dict: - """Get the resolved cluster strategy.""" - from graphrag.index.operations.cluster_graph import GraphCommunityStrategyType - - return self.strategy or { - "type": GraphCommunityStrategyType.leiden, - "max_cluster_size": self.max_cluster_size, - } diff --git a/graphrag/config/models/embed_graph_config.py b/graphrag/config/models/embed_graph_config.py index c4597e03ca..9a28b80090 100644 --- a/graphrag/config/models/embed_graph_config.py +++ b/graphrag/config/models/embed_graph_config.py @@ -15,6 +15,9 @@ class EmbedGraphConfig(BaseModel): description="A flag indicating whether to enable node2vec.", default=defs.NODE2VEC_ENABLED, ) + dimensions: int = Field( + description="The node2vec vector dimensions.", default=defs.NODE2VEC_DIMENSIONS + ) num_walks: int = Field( description="The node2vec number of walks.", default=defs.NODE2VEC_NUM_WALKS ) @@ -30,21 +33,7 @@ class EmbedGraphConfig(BaseModel): random_seed: int = Field( description="The node2vec random seed.", default=defs.NODE2VEC_RANDOM_SEED ) - strategy: dict | None = Field( - description="The graph embedding strategy override.", default=None + use_lcc: bool = Field( + description="Whether to use the largest connected component.", + default=defs.USE_LCC, ) - - def resolved_strategy(self) -> dict: - """Get the resolved node2vec strategy.""" - from graphrag.index.operations.embed_graph.typing import ( - EmbedGraphStrategyType, - ) - - return self.strategy or { - "type": EmbedGraphStrategyType.node2vec, - "num_walks": self.num_walks, - "walk_length": self.walk_length, - "window_size": self.window_size, - "iterations": self.iterations, - "random_seed": self.iterations, - } diff --git a/graphrag/config/models/entity_extraction_config.py b/graphrag/config/models/entity_extraction_config.py index 9a7b078295..5cb7c6e50c 100644 --- a/graphrag/config/models/entity_extraction_config.py +++ b/graphrag/config/models/entity_extraction_config.py @@ -48,7 +48,5 @@ def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict: if self.prompt else None, "max_gleanings": self.max_gleanings, - # It's prechunked in create_base_text_units "encoding_name": encoding_model or self.encoding_model, - "prechunked": True, } diff --git a/graphrag/index/create_pipeline_config.py b/graphrag/index/create_pipeline_config.py index 4d88c5ad57..4ec4342222 100644 --- a/graphrag/index/create_pipeline_config.py +++ b/graphrag/index/create_pipeline_config.py @@ -176,13 +176,8 @@ def _text_unit_workflows( PipelineWorkflowReference( name=create_base_text_units, config={ + "chunks": settings.chunks, "snapshot_transient": settings.snapshots.transient, - "chunk_by": settings.chunks.group_by_columns, - "text_chunk": { - "strategy": settings.chunks.resolved_strategy( - settings.encoding_model - ) - }, }, ), PipelineWorkflowReference( @@ -243,9 +238,7 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference PipelineWorkflowReference( name=compute_communities, config={ - "cluster_graph": { - "strategy": settings.cluster_graph.resolved_strategy() - }, + "cluster_graph": settings.cluster_graph, "snapshot_transient": settings.snapshots.transient, }, ), @@ -260,9 +253,8 @@ def _graph_workflows(settings: GraphRagConfig) -> list[PipelineWorkflowReference PipelineWorkflowReference( name=create_final_nodes, config={ - "layout_graph_enabled": settings.umap.enabled, - "embed_graph_enabled": settings.embed_graph.enabled, - "embed_graph": {"strategy": settings.embed_graph.resolved_strategy()}, + "layout_enabled": settings.umap.enabled, + "embed_graph": settings.embed_graph, }, ), ] diff --git a/graphrag/index/flows/compute_communities.py b/graphrag/index/flows/compute_communities.py index 6ca74ded4a..0b7ab5a5fb 100644 --- a/graphrag/index/flows/compute_communities.py +++ b/graphrag/index/flows/compute_communities.py @@ -3,8 +3,6 @@ """All the steps to create the base entity graph.""" -from typing import Any - import pandas as pd from graphrag.index.operations.cluster_graph import cluster_graph @@ -13,14 +11,18 @@ def compute_communities( base_relationship_edges: pd.DataFrame, - clustering_strategy: dict[str, Any], + max_cluster_size: int, + use_lcc: bool, + seed: int | None = None, ) -> pd.DataFrame: """All the steps to create the base entity graph.""" graph = create_graph(base_relationship_edges) communities = cluster_graph( graph, - strategy=clustering_strategy, + max_cluster_size, + use_lcc, + seed=seed, ) base_communities = pd.DataFrame( diff --git a/graphrag/index/flows/create_base_text_units.py b/graphrag/index/flows/create_base_text_units.py index 3204425d11..63f8f62b6e 100644 --- a/graphrag/index/flows/create_base_text_units.py +++ b/graphrag/index/flows/create_base_text_units.py @@ -14,15 +14,19 @@ aggregate_operation_mapping, ) -from graphrag.index.operations.chunk_text import chunk_text +from graphrag.config.models.chunking_config import ChunkStrategyType +from graphrag.index.operations.chunk_text.chunk_text import chunk_text from graphrag.index.utils.hashing import gen_sha512_hash def create_base_text_units( documents: pd.DataFrame, callbacks: VerbCallbacks, - chunk_by_columns: list[str], - chunk_strategy: dict[str, Any] | None = None, + group_by_columns: list[str], + size: int, + overlap: int, + encoding_model: str, + strategy: ChunkStrategyType, ) -> pd.DataFrame: """All the steps to transform base text_units.""" sort = documents.sort_values(by=["id"], ascending=[True]) @@ -35,7 +39,7 @@ def create_base_text_units( aggregated = _aggregate_df( sort, - groupby=[*chunk_by_columns] if len(chunk_by_columns) > 0 else None, + groupby=[*group_by_columns] if len(group_by_columns) > 0 else None, aggregations=[ { "column": "text_with_ids", @@ -47,30 +51,36 @@ def create_base_text_units( callbacks.progress(Progress(percent=1)) - chunked = chunk_text( + aggregated["chunks"] = chunk_text( aggregated, column="texts", - to="chunks", + size=size, + overlap=overlap, + encoding_model=encoding_model, + strategy=strategy, callbacks=callbacks, - strategy=chunk_strategy, ) - chunked = cast("pd.DataFrame", chunked[[*chunk_by_columns, "chunks"]]) - chunked = chunked.explode("chunks") - chunked.rename( + aggregated = cast("pd.DataFrame", aggregated[[*group_by_columns, "chunks"]]) + aggregated = aggregated.explode("chunks") + aggregated.rename( columns={ "chunks": "chunk", }, inplace=True, ) - chunked["id"] = chunked.apply(lambda row: gen_sha512_hash(row, ["chunk"]), axis=1) - chunked[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame( - chunked["chunk"].tolist(), index=chunked.index + aggregated["id"] = aggregated.apply( + lambda row: gen_sha512_hash(row, ["chunk"]), axis=1 + ) + aggregated[["document_ids", "chunk", "n_tokens"]] = pd.DataFrame( + aggregated["chunk"].tolist(), index=aggregated.index ) # rename for downstream consumption - chunked.rename(columns={"chunk": "text"}, inplace=True) + aggregated.rename(columns={"chunk": "text"}, inplace=True) - return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)) + return cast( + "pd.DataFrame", aggregated[aggregated["text"].notna()].reset_index(drop=True) + ) # TODO: would be nice to inline this completely in the main method with pandas diff --git a/graphrag/index/flows/create_final_nodes.py b/graphrag/index/flows/create_final_nodes.py index 0b6932e405..511ff429e7 100644 --- a/graphrag/index/flows/create_final_nodes.py +++ b/graphrag/index/flows/create_final_nodes.py @@ -3,13 +3,12 @@ """All the steps to transform final nodes.""" -from typing import Any - import pandas as pd from datashaper import ( VerbCallbacks, ) +from graphrag.config.models.embed_graph_config import EmbedGraphConfig from graphrag.index.operations.compute_degree import compute_degree from graphrag.index.operations.create_graph import create_graph from graphrag.index.operations.embed_graph.embed_graph import embed_graph @@ -21,21 +20,21 @@ def create_final_nodes( base_relationship_edges: pd.DataFrame, base_communities: pd.DataFrame, callbacks: VerbCallbacks, - layout_strategy: dict[str, Any], - embedding_strategy: dict[str, Any] | None = None, + embed_config: EmbedGraphConfig, + layout_enabled: bool, ) -> pd.DataFrame: """All the steps to transform final nodes.""" graph = create_graph(base_relationship_edges) graph_embeddings = None - if embedding_strategy: + if embed_config.enabled: graph_embeddings = embed_graph( graph, - embedding_strategy, + embed_config, ) layout = layout_graph( graph, callbacks, - layout_strategy, + layout_enabled, embeddings=graph_embeddings, ) diff --git a/graphrag/index/operations/chunk_text/__init__.py b/graphrag/index/operations/chunk_text/__init__.py index d84b4c0c38..1e000e6aa7 100644 --- a/graphrag/index/operations/chunk_text/__init__.py +++ b/graphrag/index/operations/chunk_text/__init__.py @@ -2,11 +2,3 @@ # Licensed under the MIT License """The Indexing Engine text chunk package root.""" - -from graphrag.index.operations.chunk_text.chunk_text import ( - ChunkStrategy, - ChunkStrategyType, - chunk_text, -) - -__all__ = ["ChunkStrategy", "ChunkStrategyType", "chunk_text"] diff --git a/graphrag/index/operations/chunk_text/chunk_text.py b/graphrag/index/operations/chunk_text/chunk_text.py index cea796e9b0..554cfbda35 100644 --- a/graphrag/index/operations/chunk_text/chunk_text.py +++ b/graphrag/index/operations/chunk_text/chunk_text.py @@ -12,20 +12,22 @@ progress_ticker, ) +from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType from graphrag.index.operations.chunk_text.typing import ( ChunkInput, ChunkStrategy, - ChunkStrategyType, ) def chunk_text( input: pd.DataFrame, column: str, - to: str, + size: int, + overlap: int, + encoding_model: str, + strategy: ChunkStrategyType, callbacks: VerbCallbacks, - strategy: dict[str, Any] | None = None, -) -> pd.DataFrame: +) -> pd.Series: """ Chunk a piece of text into smaller pieces. @@ -33,7 +35,6 @@ def chunk_text( ```yaml args: column: # The name of the column containing the text to chunk, this can either be a column with text, or a column with a list[tuple[doc_id, str]] - to: # The name of the column to output the chunks to strategy: # The strategy to use to chunk the text, see below for more details ``` @@ -43,52 +44,46 @@ def chunk_text( ### tokens This strategy uses the [tokens] library to chunk a piece of text. The strategy config is as follows: - > Note: In the future, this will likely be renamed to something more generic, like "openai_tokens". - ```yaml - strategy: - type: tokens - chunk_size: 1200 # Optional, The chunk size to use, default: 1200 - chunk_overlap: 100 # Optional, The chunk overlap to use, default: 100 + strategy: tokens + size: 1200 # Optional, The chunk size to use, default: 1200 + overlap: 100 # Optional, The chunk overlap to use, default: 100 ``` ### sentence This strategy uses the nltk library to chunk a piece of text into sentences. The strategy config is as follows: ```yaml - strategy: - type: sentence + strategy: sentence ``` """ - output = input - if strategy is None: - strategy = {} - strategy_name = strategy.get("type", ChunkStrategyType.tokens) - strategy_config = {**strategy} - strategy_exec = load_strategy(strategy_name) - - num_total = _get_num_total(output, column) - tick = progress_ticker(callbacks.progress, num_total) + strategy_exec = load_strategy(strategy) - output[to] = output.apply( - cast( - "Any", - lambda x: run_strategy(strategy_exec, x[column], strategy_config, tick), + num_total = _get_num_total(input, column) + tick = progress_ticker(callbacks.progress, num_total) + # collapse the config back to a single object to support "polymorphic" function call + config = ChunkingConfig(size=size, overlap=overlap, encoding_model=encoding_model) + return cast( + "pd.Series", + input.apply( + cast( + "Any", + lambda x: run_strategy(strategy_exec, x[column], config, tick), + ), + axis=1, ), - axis=1, ) - return output def run_strategy( - strategy: ChunkStrategy, + strategy_exec: ChunkStrategy, input: ChunkInput, - strategy_args: dict[str, Any], + config: ChunkingConfig, tick: ProgressTicker, ) -> list[str | tuple[list[str] | None, str, int]]: """Run strategy method definition.""" if isinstance(input, str): - return [item.text_chunk for item in strategy([input], {**strategy_args}, tick)] + return [item.text_chunk for item in strategy_exec([input], config, tick)] # We can work with both just a list of text content # or a list of tuples of (document_id, text content) @@ -100,7 +95,7 @@ def run_strategy( else: texts.append(item[1]) - strategy_results = strategy(texts, {**strategy_args}, tick) + strategy_results = strategy_exec(texts, config, tick) results = [] for strategy_result in strategy_results: diff --git a/graphrag/index/operations/chunk_text/strategies.py b/graphrag/index/operations/chunk_text/strategies.py index 35c32585c0..1468028537 100644 --- a/graphrag/index/operations/chunk_text/strategies.py +++ b/graphrag/index/operations/chunk_text/strategies.py @@ -4,24 +4,23 @@ """A module containing chunk strategies.""" from collections.abc import Iterable -from typing import Any import nltk import tiktoken from datashaper import ProgressTicker -import graphrag.config.defaults as defs +from graphrag.config.models.chunking_config import ChunkingConfig from graphrag.index.operations.chunk_text.typing import TextChunk from graphrag.index.text_splitting.text_splitting import Tokenizer def run_tokens( - input: list[str], args: dict[str, Any], tick: ProgressTicker + input: list[str], config: ChunkingConfig, tick: ProgressTicker ) -> Iterable[TextChunk]: """Chunks text into chunks based on encoding tokens.""" - tokens_per_chunk = args.get("chunk_size", defs.CHUNK_SIZE) - chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP) - encoding_name = args.get("encoding_name", defs.ENCODING_MODEL) + tokens_per_chunk = config.size + chunk_overlap = config.overlap + encoding_name = config.encoding_model enc = tiktoken.get_encoding(encoding_name) def encode(text: str) -> list[int]: @@ -83,7 +82,7 @@ def _split_text_on_tokens( def run_sentences( - input: list[str], _args: dict[str, Any], tick: ProgressTicker + input: list[str], _config: ChunkingConfig, tick: ProgressTicker ) -> Iterable[TextChunk]: """Chunks text into multiple parts by sentence.""" for doc_idx, text in enumerate(input): diff --git a/graphrag/index/operations/chunk_text/typing.py b/graphrag/index/operations/chunk_text/typing.py index ebfa4db963..5f0994ec05 100644 --- a/graphrag/index/operations/chunk_text/typing.py +++ b/graphrag/index/operations/chunk_text/typing.py @@ -5,11 +5,11 @@ from collections.abc import Callable, Iterable from dataclasses import dataclass -from enum import Enum -from typing import Any from datashaper import ProgressTicker +from graphrag.config.models.chunking_config import ChunkingConfig + @dataclass class TextChunk: @@ -24,16 +24,5 @@ class TextChunk: """Input to a chunking strategy. Can be a string, a list of strings, or a list of tuples of (id, text).""" ChunkStrategy = Callable[ - [list[str], dict[str, Any], ProgressTicker], Iterable[TextChunk] + [list[str], ChunkingConfig, ProgressTicker], Iterable[TextChunk] ] - - -class ChunkStrategyType(str, Enum): - """ChunkStrategy class definition.""" - - tokens = "tokens" - sentence = "sentence" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' diff --git a/graphrag/index/operations/cluster_graph.py b/graphrag/index/operations/cluster_graph.py index 07f5998f9f..28ee3507dc 100644 --- a/graphrag/index/operations/cluster_graph.py +++ b/graphrag/index/operations/cluster_graph.py @@ -4,8 +4,6 @@ """A module containing cluster_graph, apply_clustering and run_layout methods definition.""" import logging -from enum import Enum -from typing import Any import networkx as nx @@ -14,82 +12,44 @@ Communities = list[tuple[int, int, int, list[str]]] -class GraphCommunityStrategyType(str, Enum): - """GraphCommunityStrategyType class definition.""" - - leiden = "leiden" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' - - log = logging.getLogger(__name__) def cluster_graph( - input: nx.Graph, - strategy: dict[str, Any], + graph: nx.Graph, + max_cluster_size: int, + use_lcc: bool, + seed: int | None = None, ) -> Communities: """Apply a hierarchical clustering algorithm to a graph.""" - return run_layout(strategy, input) - - -def run_layout(strategy: dict[str, Any], graph: nx.Graph) -> Communities: - """Run layout method definition.""" if len(graph.nodes) == 0: log.warning("Graph has no nodes") return [] - clusters: dict[int, dict[int, list[str]]] = {} - strategy_type = strategy.get("type", GraphCommunityStrategyType.leiden) - match strategy_type: - case GraphCommunityStrategyType.leiden: - clusters, parent_mapping = run_leiden(graph, strategy) - case _: - msg = f"Unknown clustering strategy {strategy_type}" - raise ValueError(msg) - - results: Communities = [] - for level in clusters: - for cluster_id, nodes in clusters[level].items(): - results.append((level, cluster_id, parent_mapping[cluster_id], nodes)) - return results - - -def run_leiden( - graph: nx.Graph, args: dict[str, Any] -) -> tuple[dict[int, dict[int, list[str]]], dict[int, int]]: - """Run method definition.""" - max_cluster_size = args.get("max_cluster_size", 10) - use_lcc = args.get("use_lcc", True) - if args.get("verbose", False): - log.info( - "Running leiden with max_cluster_size=%s, lcc=%s", max_cluster_size, use_lcc - ) - - node_id_to_community_map, community_hierarchy_map = _compute_leiden_communities( + node_id_to_community_map, parent_mapping = _compute_leiden_communities( graph=graph, max_cluster_size=max_cluster_size, use_lcc=use_lcc, - seed=args.get("seed", 0xDEADBEEF), + seed=seed, ) - levels = args.get("levels") - # If they don't pass in levels, use them all - if levels is None: - levels = sorted(node_id_to_community_map.keys()) + levels = sorted(node_id_to_community_map.keys()) - results_by_level: dict[int, dict[int, list[str]]] = {} + clusters: dict[int, dict[int, list[str]]] = {} for level in levels: result = {} - results_by_level[level] = result + clusters[level] = result for node_id, raw_community_id in node_id_to_community_map[level].items(): community_id = raw_community_id if community_id not in result: result[community_id] = [] result[community_id].append(node_id) - return results_by_level, community_hierarchy_map + + results: Communities = [] + for level in clusters: + for cluster_id, nodes in clusters[level].items(): + results.append((level, cluster_id, parent_mapping[cluster_id], nodes)) + return results # Taken from graph_intelligence & adapted @@ -97,7 +57,7 @@ def _compute_leiden_communities( graph: nx.Graph | nx.DiGraph, max_cluster_size: int, use_lcc: bool, - seed=0xDEADBEEF, + seed: int | None = None, ) -> tuple[dict[int, dict[str, int]], dict[int, int]]: """Return Leiden root communities and their hierarchy mapping.""" # NOTE: This import is done here to reduce the initial import time of the graphrag package diff --git a/graphrag/index/operations/embed_graph/embed_graph.py b/graphrag/index/operations/embed_graph/embed_graph.py index 6e161db7dc..0328402db2 100644 --- a/graphrag/index/operations/embed_graph/embed_graph.py +++ b/graphrag/index/operations/embed_graph/embed_graph.py @@ -3,84 +3,45 @@ """A module containing embed_graph and run_embeddings methods definition.""" -from typing import Any - import networkx as nx +from graphrag.config.models.embed_graph_config import EmbedGraphConfig from graphrag.index.operations.embed_graph.embed_node2vec import embed_node2vec from graphrag.index.operations.embed_graph.typing import ( - EmbedGraphStrategyType, NodeEmbeddings, ) -from graphrag.index.utils.load_graph import load_graph from graphrag.index.utils.stable_lcc import stable_largest_connected_component def embed_graph( graph: nx.Graph, - strategy: dict[str, Any], + config: EmbedGraphConfig, ) -> NodeEmbeddings: """ - Embed a graph into a vector space. The graph is expected to be in nx.Graph format. The operation outputs a mapping between node name and vector. + Embed a graph into a vector space using node2vec. The graph is expected to be in nx.Graph format. The operation outputs a mapping between node name and vector. ## Usage ```yaml - args: - strategy: # See strategies section below - ``` - - ## Strategies - The embed_graph operation uses a strategy to embed the graph. The strategy is an object which defines the strategy to use. The following strategies are available: - - ### node2vec - This strategy uses the node2vec algorithm to embed a graph. The strategy config is as follows: - - ```yaml - strategy: - type: node2vec - dimensions: 1536 # Optional, The number of dimensions to use for the embedding, default: 1536 - num_walks: 10 # Optional, The number of walks to use for the embedding, default: 10 - walk_length: 40 # Optional, The walk length to use for the embedding, default: 40 - window_size: 2 # Optional, The window size to use for the embedding, default: 2 - iterations: 3 # Optional, The number of iterations to use for the embedding, default: 3 - random_seed: 86 # Optional, The random seed to use for the embedding, default: 86 + dimensions: 1536 # Optional, The number of dimensions to use for the embedding, default: 1536 + num_walks: 10 # Optional, The number of walks to use for the embedding, default: 10 + walk_length: 40 # Optional, The walk length to use for the embedding, default: 40 + window_size: 2 # Optional, The window size to use for the embedding, default: 2 + iterations: 3 # Optional, The number of iterations to use for the embedding, default: 3 + random_seed: 86 # Optional, The random seed to use for the embedding, default: 86 ``` """ - strategy_type = strategy.get("type", EmbedGraphStrategyType.node2vec) - strategy_args = {**strategy} - - return run_embeddings(strategy_type, graph, strategy_args) - - -def run_embeddings( - strategy: EmbedGraphStrategyType, - graphml_or_graph: str | nx.Graph, - args: dict[str, Any], -) -> NodeEmbeddings: - """Run embeddings method definition.""" - graph = load_graph(graphml_or_graph) - match strategy: - case EmbedGraphStrategyType.node2vec: - return run_node_2_vec(graph, args) - case _: - msg = f"Unknown strategy {strategy}" - raise ValueError(msg) - - -def run_node_2_vec(graph: nx.Graph, args: dict[str, Any]) -> NodeEmbeddings: - """Run method definition.""" - if args.get("use_lcc", True): + if config.use_lcc: graph = stable_largest_connected_component(graph) # create graph embedding using node2vec embeddings = embed_node2vec( graph=graph, - dimensions=args.get("dimensions", 1536), - num_walks=args.get("num_walks", 10), - walk_length=args.get("walk_length", 40), - window_size=args.get("window_size", 2), - iterations=args.get("iterations", 3), - random_seed=args.get("random_seed", 86), + dimensions=config.dimensions, + num_walks=config.num_walks, + walk_length=config.walk_length, + window_size=config.window_size, + iterations=config.iterations, + random_seed=config.random_seed, ) pairs = zip(embeddings.nodes, embeddings.embeddings.tolist(), strict=True) diff --git a/graphrag/index/operations/embed_graph/typing.py b/graphrag/index/operations/embed_graph/typing.py index 618806eaed..fea792c9b1 100644 --- a/graphrag/index/operations/embed_graph/typing.py +++ b/graphrag/index/operations/embed_graph/typing.py @@ -4,20 +4,8 @@ """A module containing different lists and dictionaries.""" # Use this for now instead of a wrapper -from enum import Enum from typing import Any - -class EmbedGraphStrategyType(str, Enum): - """EmbedGraphStrategyType class definition.""" - - node2vec = "node2vec" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' - - NodeList = list[str] EmbeddingList = list[Any] NodeEmbeddings = dict[str, list[float]] diff --git a/graphrag/index/operations/extract_covariates/claim_extractor.py b/graphrag/index/operations/extract_covariates/claim_extractor.py index 66162f8f12..e5fb6c3b40 100644 --- a/graphrag/index/operations/extract_covariates/claim_extractor.py +++ b/graphrag/index/operations/extract_covariates/claim_extractor.py @@ -87,7 +87,7 @@ def __init__( self._on_error = on_error or (lambda _e, _s, _d: None) # Construct the looping arguments - encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") + encoding = tiktoken.get_encoding(encoding_model or defs.ENCODING_MODEL) yes = f"{encoding.encode('YES')[0]}" no = f"{encoding.encode('NO')[0]}" self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1} diff --git a/graphrag/index/operations/extract_entities/extract_entities.py b/graphrag/index/operations/extract_entities/extract_entities.py index 3245c2481c..e3b7410d06 100644 --- a/graphrag/index/operations/extract_entities/extract_entities.py +++ b/graphrag/index/operations/extract_entities/extract_entities.py @@ -72,10 +72,7 @@ async def extract_entities( tuple_delimiter: "<|>" # Optional, the delimiter to use for the LLM to mark a tuple record_delimiter: "##" # Optional, the delimiter to use for the LLM to mark a record - prechunked: true | false # Optional, If the document is already chunked beforehand, otherwise this will chunk the document into smaller bits. default: false - encoding_name: cl100k_base # Optional, The encoding to use for the LLM, if not already prechunked, default: cl100k_base - chunk_size: 1000 # Optional ,The chunk size to use for the LLM, if not already prechunked, default: 1200 - chunk_overlap: 100 # Optional, The chunk overlap to use for the LLM, if not already prechunked, default: 100 + encoding_name: cl100k_base # Optional, The encoding to use for the LLM with gleanings llm: # The configuration for the LLM type: openai # the type of llm to use, available options are: openai, azure, openai_chat, azure_openai_chat. The last two being chat based LLMs. diff --git a/graphrag/index/operations/extract_entities/graph_extractor.py b/graphrag/index/operations/extract_entities/graph_extractor.py index 26df56a945..890a06d083 100644 --- a/graphrag/index/operations/extract_entities/graph_extractor.py +++ b/graphrag/index/operations/extract_entities/graph_extractor.py @@ -91,7 +91,7 @@ def __init__( self._on_error = on_error or (lambda _e, _s, _d: None) # Construct the looping arguments - encoding = tiktoken.get_encoding(encoding_model or "cl100k_base") + encoding = tiktoken.get_encoding(encoding_model or defs.ENCODING_MODEL) yes = f"{encoding.encode('YES')[0]}" no = f"{encoding.encode('NO')[0]}" self._loop_args = {"logit_bias": {yes: 100, no: 100}, "max_tokens": 1} diff --git a/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py index a91e0748d5..9084321621 100644 --- a/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py +++ b/graphrag/index/operations/extract_entities/graph_intelligence_strategy.py @@ -17,11 +17,6 @@ EntityTypes, StrategyConfig, ) -from graphrag.index.text_splitting.text_splitting import ( - NoopTextSplitter, - TextSplitter, - TokenTextSplitter, -) async def run_graph_intelligence( @@ -45,14 +40,6 @@ async def run_extract_entities( args: StrategyConfig, ) -> EntityExtractionResult: """Run the entity extraction chain.""" - encoding_name = args.get("encoding_name", "cl100k_base") - - # Chunking Arguments - prechunked = args.get("prechunked", False) - chunk_size = args.get("chunk_size", defs.CHUNK_SIZE) - chunk_overlap = args.get("chunk_overlap", defs.CHUNK_OVERLAP) - - # Extraction Arguments tuple_delimiter = args.get("tuple_delimiter", None) record_delimiter = args.get("record_delimiter", None) completion_delimiter = args.get("completion_delimiter", None) @@ -60,12 +47,6 @@ async def run_extract_entities( encoding_model = args.get("encoding_name", None) max_gleanings = args.get("max_gleanings", defs.ENTITY_EXTRACTION_MAX_GLEANINGS) - # note: We're not using UnipartiteGraphChain.from_params - # because we want to pass "timeout" to the llm_kwargs - text_splitter = _create_text_splitter( - prechunked, chunk_size, chunk_overlap, encoding_name - ) - extractor = GraphExtractor( llm_invoker=llm, prompt=extraction_prompt, @@ -77,10 +58,6 @@ async def run_extract_entities( ) text_list = [doc.text.strip() for doc in docs] - # If it's not pre-chunked, then re-chunk the input - if not prechunked: - text_list = text_splitter.split_text("\n".join(text_list)) - results = await extractor( list(text_list), { @@ -114,26 +91,3 @@ async def run_extract_entities( relationships = nx.to_pandas_edgelist(graph) return EntityExtractionResult(entities, relationships, graph) - - -def _create_text_splitter( - prechunked: bool, chunk_size: int, chunk_overlap: int, encoding_name: str -) -> TextSplitter: - """Create a text splitter for the extraction chain. - - Args: - - prechunked - Whether the text is already chunked - - chunk_size - The size of each chunk - - chunk_overlap - The overlap between chunks - - encoding_name - The name of the encoding to use - Returns: - - output - A text splitter - """ - if prechunked: - return NoopTextSplitter() - - return TokenTextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - encoding_name=encoding_name, - ) diff --git a/graphrag/index/operations/layout_graph/layout_graph.py b/graphrag/index/operations/layout_graph/layout_graph.py index 756fb4ff24..a4c7471292 100644 --- a/graphrag/index/operations/layout_graph/layout_graph.py +++ b/graphrag/index/operations/layout_graph/layout_graph.py @@ -3,9 +3,6 @@ """A module containing layout_graph, _run_layout and _apply_layout_to_graph methods definition.""" -from enum import Enum -from typing import Any - import networkx as nx import pandas as pd from datashaper import VerbCallbacks @@ -14,21 +11,10 @@ from graphrag.index.operations.layout_graph.typing import GraphLayout -class LayoutGraphStrategyType(str, Enum): - """LayoutGraphStrategyType class definition.""" - - umap = "umap" - zero = "zero" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' - - def layout_graph( graph: nx.Graph, callbacks: VerbCallbacks, - strategy: dict[str, Any], + enabled: bool, embeddings: NodeEmbeddings | None, ): """ @@ -54,14 +40,10 @@ def layout_graph( min_dist: 0.75 # Optional, The min distance to use for the umap algorithm, default: 0.75 ``` """ - strategy_type = strategy.get("type", LayoutGraphStrategyType.umap) - strategy_args = {**strategy} - layout = _run_layout( - strategy_type, graph, + enabled, embeddings if embeddings is not None else {}, - strategy_args, callbacks, ) @@ -73,34 +55,26 @@ def layout_graph( def _run_layout( - strategy: LayoutGraphStrategyType, graph: nx.Graph, + enabled: bool, embeddings: NodeEmbeddings, - args: dict[str, Any], callbacks: VerbCallbacks, ) -> GraphLayout: - match strategy: - case LayoutGraphStrategyType.umap: - from graphrag.index.operations.layout_graph.umap import ( - run as run_umap, - ) - - return run_umap( - graph, - embeddings, - args, - lambda e, stack, d: callbacks.error("Error in Umap", e, stack, d), - ) - case LayoutGraphStrategyType.zero: - from graphrag.index.operations.layout_graph.zero import ( - run as run_zero, - ) + if enabled: + from graphrag.index.operations.layout_graph.umap import ( + run as run_umap, + ) + + return run_umap( + graph, + embeddings, + lambda e, stack, d: callbacks.error("Error in Umap", e, stack, d), + ) + from graphrag.index.operations.layout_graph.zero import ( + run as run_zero, + ) - return run_zero( - graph, - args, - lambda e, stack, d: callbacks.error("Error in Zero", e, stack, d), - ) - case _: - msg = f"Unknown strategy {strategy}" - raise ValueError(msg) + return run_zero( + graph, + lambda e, stack, d: callbacks.error("Error in Zero", e, stack, d), + ) diff --git a/graphrag/index/operations/layout_graph/umap.py b/graphrag/index/operations/layout_graph/umap.py index e5ab1668ca..ffe25c6b55 100644 --- a/graphrag/index/operations/layout_graph/umap.py +++ b/graphrag/index/operations/layout_graph/umap.py @@ -5,7 +5,6 @@ import logging import traceback -from typing import Any import networkx as nx import numpy as np @@ -27,7 +26,6 @@ def run( graph: nx.Graph, embeddings: NodeEmbeddings, - args: dict[str, Any], on_error: ErrorHandlerFn, ) -> GraphLayout: """Run method definition.""" @@ -56,8 +54,6 @@ def run( embedding_vectors=np.array(embedding_vectors), node_labels=nodes, **additional_args, - min_dist=args.get("min_dist", 0.75), - n_neighbors=args.get("n_neighbors", 5), ) except Exception as e: log.exception("Error running UMAP") @@ -87,7 +83,7 @@ def compute_umap_positions( node_categories: list[int] | None = None, node_sizes: list[int] | None = None, min_dist: float = 0.75, - n_neighbors: int = 25, + n_neighbors: int = 5, spread: int = 1, metric: str = "euclidean", n_components: int = 2, diff --git a/graphrag/index/operations/layout_graph/zero.py b/graphrag/index/operations/layout_graph/zero.py index 4bb7d39b00..519fa59ee3 100644 --- a/graphrag/index/operations/layout_graph/zero.py +++ b/graphrag/index/operations/layout_graph/zero.py @@ -5,7 +5,6 @@ import logging import traceback -from typing import Any import networkx as nx @@ -24,7 +23,6 @@ def run( graph: nx.Graph, - _args: dict[str, Any], on_error: ErrorHandlerFn, ) -> GraphLayout: """Run method definition.""" diff --git a/graphrag/index/operations/summarize_communities/community_reports_extractor/community_reports_extractor.py b/graphrag/index/operations/summarize_communities/community_reports_extractor/community_reports_extractor.py index 7fa0b684fd..ae550b520c 100644 --- a/graphrag/index/operations/summarize_communities/community_reports_extractor/community_reports_extractor.py +++ b/graphrag/index/operations/summarize_communities/community_reports_extractor/community_reports_extractor.py @@ -35,10 +35,6 @@ class CommunityReportResponse(BaseModel): rating: float = Field(description="The rating of the report.") rating_explanation: str = Field(description="An explanation of the rating.") - extra_attributes: dict[str, Any] = Field( - default_factory=dict, description="Extra attributes." - ) - @dataclass class CommunityReportsResult: diff --git a/graphrag/index/operations/summarize_communities/strategies.py b/graphrag/index/operations/summarize_communities/strategies.py index 43adb07148..9003e777bf 100644 --- a/graphrag/index/operations/summarize_communities/strategies.py +++ b/graphrag/index/operations/summarize_communities/strategies.py @@ -21,8 +21,6 @@ ) from graphrag.index.utils.rate_limiter import RateLimiter -DEFAULT_CHUNK_SIZE = 3000 - log = logging.getLogger(__name__) diff --git a/graphrag/index/text_splitting/text_splitting.py b/graphrag/index/text_splitting/text_splitting.py index 32b430bb69..2f6201cab7 100644 --- a/graphrag/index/text_splitting/text_splitting.py +++ b/graphrag/index/text_splitting/text_splitting.py @@ -14,6 +14,7 @@ import pandas as pd import tiktoken +import graphrag.config.defaults as defs from graphrag.index.utils.tokens import num_tokens_from_string EncodedText = list[int] @@ -88,7 +89,7 @@ class TokenTextSplitter(TextSplitter): def __init__( self, - encoding_name: str = "cl100k_base", + encoding_name: str = defs.ENCODING_MODEL, model_name: str | None = None, allowed_special: Literal["all"] | set[str] | None = None, disallowed_special: Literal["all"] | Collection[str] = "all", diff --git a/graphrag/index/utils/tokens.py b/graphrag/index/utils/tokens.py index 4a189b9b22..fcd996840f 100644 --- a/graphrag/index/utils/tokens.py +++ b/graphrag/index/utils/tokens.py @@ -7,7 +7,10 @@ import tiktoken -DEFAULT_ENCODING_NAME = "cl100k_base" +import graphrag.config.defaults as defs + +DEFAULT_ENCODING_NAME = defs.ENCODING_MODEL + log = logging.getLogger(__name__) diff --git a/graphrag/index/workflows/v1/compute_communities.py b/graphrag/index/workflows/v1/compute_communities.py index 543a6ece6a..3e70725c32 100644 --- a/graphrag/index/workflows/v1/compute_communities.py +++ b/graphrag/index/workflows/v1/compute_communities.py @@ -3,7 +3,7 @@ """A module containing build_steps method definition.""" -from typing import Any, cast +from typing import TYPE_CHECKING, cast import pandas as pd from datashaper import ( @@ -17,6 +17,9 @@ from graphrag.index.operations.snapshot import snapshot from graphrag.storage.pipeline_storage import PipelineStorage +if TYPE_CHECKING: + from graphrag.config.models.cluster_graph_config import ClusterGraphConfig + workflow_name = "compute_communities" @@ -29,11 +32,10 @@ def build_steps( ## Dependencies * `workflow:extract_graph` """ - clustering_config = config.get( - "cluster_graph", - {"strategy": {"type": "leiden"}}, - ) - clustering_strategy = clustering_config.get("strategy") + clustering_config = cast("ClusterGraphConfig", config.get("cluster_graph")) + max_cluster_size = clustering_config.max_cluster_size + use_lcc = clustering_config.use_lcc + seed = clustering_config.seed snapshot_transient = config.get("snapshot_transient", False) or False @@ -41,7 +43,9 @@ def build_steps( { "verb": workflow_name, "args": { - "clustering_strategy": clustering_strategy, + "max_cluster_size": max_cluster_size, + "use_lcc": use_lcc, + "seed": seed, "snapshot_transient_enabled": snapshot_transient, }, "input": ({"source": "workflow:extract_graph"}), @@ -56,7 +60,9 @@ def build_steps( async def workflow( storage: PipelineStorage, runtime_storage: PipelineStorage, - clustering_strategy: dict[str, Any], + max_cluster_size: int, + use_lcc: bool, + seed: int | None, snapshot_transient_enabled: bool = False, **_kwargs: dict, ) -> VerbResult: @@ -65,7 +71,9 @@ async def workflow( base_communities = compute_communities( base_relationship_edges, - clustering_strategy=clustering_strategy, + max_cluster_size=max_cluster_size, + use_lcc=use_lcc, + seed=seed, ) await runtime_storage.set("base_communities", base_communities) diff --git a/graphrag/index/workflows/v1/create_base_text_units.py b/graphrag/index/workflows/v1/create_base_text_units.py index d045f0d6e4..84f5366df8 100644 --- a/graphrag/index/workflows/v1/create_base_text_units.py +++ b/graphrag/index/workflows/v1/create_base_text_units.py @@ -3,7 +3,7 @@ """A module containing build_steps method definition.""" -from typing import Any, cast +from typing import TYPE_CHECKING, cast import pandas as pd from datashaper import ( @@ -15,6 +15,7 @@ ) from datashaper.table_store.types import VerbResult, create_verb_result +from graphrag.config.models.chunking_config import ChunkStrategyType from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep from graphrag.index.flows.create_base_text_units import ( create_base_text_units, @@ -22,6 +23,9 @@ from graphrag.index.operations.snapshot import snapshot from graphrag.storage.pipeline_storage import PipelineStorage +if TYPE_CHECKING: + from graphrag.config.models.chunking_config import ChunkingConfig + workflow_name = "create_base_text_units" @@ -34,17 +38,23 @@ def build_steps( ## Dependencies (input dataframe) """ - chunk_by_columns = config.get("chunk_by", []) or [] - text_chunk_config = config.get("text_chunk", {}) - chunk_strategy = text_chunk_config.get("strategy") + chunks = cast("ChunkingConfig", config.get("chunks")) + group_by_columns = chunks.group_by_columns + size = chunks.size + overlap = chunks.overlap + encoding_model = chunks.encoding_model + strategy = chunks.strategy snapshot_transient = config.get("snapshot_transient", False) or False return [ { "verb": workflow_name, "args": { - "chunk_by_columns": chunk_by_columns, - "chunk_strategy": chunk_strategy, + "group_by_columns": group_by_columns, + "size": size, + "overlap": overlap, + "encoding_model": encoding_model, + "strategy": strategy, "snapshot_transient_enabled": snapshot_transient, }, "input": {"source": DEFAULT_INPUT_NAME}, @@ -58,8 +68,11 @@ async def workflow( callbacks: VerbCallbacks, storage: PipelineStorage, runtime_storage: PipelineStorage, - chunk_by_columns: list[str], - chunk_strategy: dict[str, Any] | None = None, + group_by_columns: list[str], + size: int, + overlap: int, + encoding_model: str, + strategy: ChunkStrategyType, snapshot_transient_enabled: bool = False, **_kwargs: dict, ) -> VerbResult: @@ -69,8 +82,11 @@ async def workflow( output = create_base_text_units( source, callbacks, - chunk_by_columns, - chunk_strategy=chunk_strategy, + group_by_columns, + size, + overlap, + encoding_model, + strategy=strategy, ) await runtime_storage.set("base_text_units", output) diff --git a/graphrag/index/workflows/v1/create_final_nodes.py b/graphrag/index/workflows/v1/create_final_nodes.py index 60b4cfe17e..bdbfab084e 100644 --- a/graphrag/index/workflows/v1/create_final_nodes.py +++ b/graphrag/index/workflows/v1/create_final_nodes.py @@ -3,7 +3,7 @@ """A module containing build_steps method definition.""" -from typing import Any, cast +from typing import cast from datashaper import ( Table, @@ -12,6 +12,7 @@ ) from datashaper.table_store.types import VerbResult, create_verb_result +from graphrag.config.models.embed_graph_config import EmbedGraphConfig from graphrag.index.config.workflow import PipelineWorkflowConfig, PipelineWorkflowStep from graphrag.index.flows.create_final_nodes import ( create_final_nodes, @@ -30,42 +31,13 @@ def build_steps( ## Dependencies * `workflow:extract_graph` """ - layout_graph_enabled = config.get("layout_graph_enabled", True) - layout_graph_config = config.get( - "layout_graph", - { - "strategy": { - "type": "umap" if layout_graph_enabled else "zero", - }, - }, - ) - layout_strategy = layout_graph_config.get("strategy") - - embed_graph_config = config.get( - "embed_graph", - { - "strategy": { - "type": "node2vec", - "num_walks": config.get("embed_num_walks", 10), - "walk_length": config.get("embed_walk_length", 40), - "window_size": config.get("embed_window_size", 2), - "iterations": config.get("embed_iterations", 3), - "random_seed": config.get("embed_random_seed", 86), - } - }, - ) - embedding_strategy = embed_graph_config.get("strategy") - embed_graph_enabled = config.get("embed_graph_enabled", False) or False + layout_enabled = config["layout_enabled"] + embed_config = cast("EmbedGraphConfig", config["embed_graph"]) return [ { "verb": workflow_name, - "args": { - "layout_strategy": layout_strategy, - "embedding_strategy": embedding_strategy - if embed_graph_enabled - else None, - }, + "args": {"layout_enabled": layout_enabled, "embed_config": embed_config}, "input": { "source": "workflow:extract_graph", "communities": "workflow:compute_communities", @@ -78,8 +50,8 @@ def build_steps( async def workflow( callbacks: VerbCallbacks, runtime_storage: PipelineStorage, - layout_strategy: dict[str, Any], - embedding_strategy: dict[str, Any] | None = None, + embed_config: EmbedGraphConfig, + layout_enabled: bool, **_kwargs: dict, ) -> VerbResult: """All the steps to transform final nodes.""" @@ -92,8 +64,8 @@ async def workflow( base_relationship_edges, base_communities, callbacks, - layout_strategy, - embedding_strategy=embedding_strategy, + embed_config=embed_config, + layout_enabled=layout_enabled, ) return create_verb_result( diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index 0a1bdbe7ae..e2b6c49b56 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -14,7 +14,7 @@ from graphrag.config.models.llm_parameters import LLMParameters from graphrag.index.input.factory import create_input from graphrag.index.llm.load_llm import load_llm_embeddings -from graphrag.index.operations.chunk_text import chunk_text +from graphrag.index.operations.chunk_text.chunk_text import chunk_text from graphrag.logger.base import ProgressLogger from graphrag.prompt_tune.defaults import ( MIN_CHUNK_OVERLAP, @@ -67,22 +67,21 @@ async def load_docs_in_chunks( dataset = await create_input(config.input, logger, root) # covert to text units - chunk_strategy = config.chunks.resolved_strategy(defs.ENCODING_MODEL) + chunk_config = config.chunks # Use smaller chunks, to avoid huge prompts - chunk_strategy["chunk_size"] = chunk_size - chunk_strategy["chunk_overlap"] = MIN_CHUNK_OVERLAP - - dataset_chunks = chunk_text( + dataset["chunks"] = chunk_text( dataset, column="text", - to="chunks", + size=chunk_size, + overlap=MIN_CHUNK_OVERLAP, + encoding_model=defs.ENCODING_MODEL, + strategy=chunk_config.strategy, callbacks=NoopVerbCallbacks(), - strategy=chunk_strategy, ) # Select chunks into a new df and explode it - chunks_df = pd.DataFrame(dataset_chunks["chunks"].explode()) # type: ignore + chunks_df = pd.DataFrame(dataset["chunks"].explode()) # type: ignore # Depending on the select method, build the dataset if limit <= 0 or limit > len(chunks_df): diff --git a/graphrag/query/llm/oai/embedding.py b/graphrag/query/llm/oai/embedding.py index b8a4dfcd38..b54c97b3a4 100644 --- a/graphrag/query/llm/oai/embedding.py +++ b/graphrag/query/llm/oai/embedding.py @@ -18,6 +18,7 @@ wait_exponential_jitter, ) +import graphrag.config.defaults as defs from graphrag.logger.base import StatusLogger from graphrag.query.llm.base import BaseTextEmbedding from graphrag.query.llm.oai.base import OpenAILLMImpl @@ -41,7 +42,7 @@ def __init__( api_version: str | None = None, api_type: OpenaiApiType = OpenaiApiType.OpenAI, organization: str | None = None, - encoding_name: str = "cl100k_base", + encoding_name: str = defs.ENCODING_MODEL, max_tokens: int = 8191, max_retries: int = 10, request_timeout: float = 180.0, diff --git a/graphrag/query/llm/text_utils.py b/graphrag/query/llm/text_utils.py index 041c9e9572..71740fafdb 100644 --- a/graphrag/query/llm/text_utils.py +++ b/graphrag/query/llm/text_utils.py @@ -12,13 +12,15 @@ import tiktoken from json_repair import repair_json +import graphrag.config.defaults as defs + log = logging.getLogger(__name__) def num_tokens(text: str, token_encoder: tiktoken.Encoding | None = None) -> int: """Return the number of tokens in the given text.""" if token_encoder is None: - token_encoder = tiktoken.get_encoding("cl100k_base") + token_encoder = tiktoken.get_encoding(defs.ENCODING_MODEL) return len(token_encoder.encode(text)) # type: ignore @@ -42,7 +44,7 @@ def chunk_text( ): """Chunk text by token length.""" if token_encoder is None: - token_encoder = tiktoken.get_encoding("cl100k_base") + token_encoder = tiktoken.get_encoding(defs.ENCODING_MODEL) tokens = token_encoder.encode(text) # type: ignore chunk_iterator = batched(iter(tokens), max_tokens) yield from (token_encoder.decode(list(chunk)) for chunk in chunk_iterator) diff --git a/tests/integration/_pipeline/__init__.py b/tests/integration/_pipeline/__init__.py deleted file mode 100644 index 0a3e38adfb..0000000000 --- a/tests/integration/_pipeline/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License diff --git a/tests/integration/_pipeline/megapipeline.yml b/tests/integration/_pipeline/megapipeline.yml deleted file mode 100644 index 363f8b37e4..0000000000 --- a/tests/integration/_pipeline/megapipeline.yml +++ /dev/null @@ -1,81 +0,0 @@ -input: - file_type: text - base_dir: ../../fixtures/min-csv - file_pattern: .*\.txt$ - -storage: - type: memory - -cache: - type: memory - -workflows: - - name: create_base_text_units - config: - text_chunk: - strategy: - type: sentence - - # Just lump everything together - chunk_by: [] - - - name: extract_graph - config: - snapshot_graphml_enabled: True - entity_extract: - strategy: - type: graph_intelligence - llm: - type: static_response - responses: - - '("entity"<|>COMPANY_A<|>COMPANY<|>Company_A is a test company) - ## - ("entity"<|>COMPANY_B<|>COMPANY<|>Company_B owns Company_A and also shares an address with Company_A) - ## - ("entity"<|>PERSON_C<|>PERSON<|>Person_C is director of Company_A) - ## - ("relationship"<|>COMPANY_A<|>COMPANY_B<|>Company_A and Company_B are related because Company_A is 100% owned by Company_B and the two companies also share the same address)<|>2) - ## - ("relationship"<|>COMPANY_A<|>PERSON_C<|>Company_A and Person_C are related because Person_C is director of Company_A<|>1))' - summarize_descriptions: - strategy: - type: graph_intelligence - llm: - type: static_response - responses: - - This is a MOCK response for the LLM. It is summarized! - - - name: compute_communities - config: - cluster_graph: - strategy: - type: leiden - verbose: True - - - name: create_final_nodes - config: - embed_graph_enabled: True - - - name: create_final_communities - - name: create_final_text_units - config: - text_embed: - strategy: - type: mock - - - name: create_final_entities - config: - text_embed: - strategy: - type: mock - - - name: create_final_documents - config: - text_embed: - strategy: - type: mock - - name: create_final_relationships - config: - text_embed: - strategy: - type: mock diff --git a/tests/integration/_pipeline/test_run.py b/tests/integration/_pipeline/test_run.py deleted file mode 100644 index dad314daeb..0000000000 --- a/tests/integration/_pipeline/test_run.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. -# Licensed under the MIT License -import logging -import os -import unittest - -from graphrag.index.run import run_pipeline_with_config -from graphrag.index.typing import PipelineRunResult - -log = logging.getLogger(__name__) - - -class TestRun(unittest.IsolatedAsyncioTestCase): - async def test_megapipeline(self): - pipeline_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "./megapipeline.yml", - ) - pipeline_result = [gen async for gen in run_pipeline_with_config(pipeline_path)] - - errors = [] - for result in pipeline_result: - if result.errors is not None and len(result.errors) > 0: - errors.extend(result.errors) - - if len(errors) > 0: - print("Errors: ", errors) - assert len(errors) == 0, "received errors\n!" + "\n".join(errors) - - self._assert_text_units_and_entities_reference_each_other(pipeline_result) - - def _assert_text_units_and_entities_reference_each_other( - self, pipeline_result: list[PipelineRunResult] - ): - text_unit_df = next( - filter(lambda x: x.workflow == "create_final_text_units", pipeline_result) - ).result - entity_df = next( - filter(lambda x: x.workflow == "create_final_entities", pipeline_result) - ).result - - assert text_unit_df is not None, "Text unit dataframe should not be None" - assert entity_df is not None, "Entity dataframe should not be None" - - # Get around typing issues - if text_unit_df is None or entity_df is None: - return - - assert len(text_unit_df) > 0, "Text unit dataframe should not be empty" - assert len(entity_df) > 0, "Entity dataframe should not be empty" - - text_unit_entity_map = {} - log.info("text_unit_df %s", text_unit_df.columns) - - for _, row in text_unit_df.iterrows(): - values = row.get("entity_ids", []) - text_unit_entity_map[row["id"]] = set([] if values is None else values) - - entity_text_unit_map = {} - for _, row in entity_df.iterrows(): - # ALL entities should have text units - values = row.get("text_unit_ids", []) - entity_text_unit_map[row["id"]] = set([] if values is None else values) - - text_unit_ids = set(text_unit_entity_map.keys()) - entity_ids = set(entity_text_unit_map.keys()) - - for text_unit_id, text_unit_entities in text_unit_entity_map.items(): - assert text_unit_entities.issubset(entity_ids), ( - f"Text unit {text_unit_id} has entities {text_unit_entities} that are not in the entity set" - ) - for entity_id, entity_text_units in entity_text_unit_map.items(): - assert entity_text_units.issubset(text_unit_ids), ( - f"Entity {entity_id} has text units {entity_text_units} that are not in the text unit set" - ) diff --git a/tests/unit/indexing/config/test_load.py b/tests/unit/indexing/config/test_load.py index 636525b320..c458081ced 100644 --- a/tests/unit/indexing/config/test_load.py +++ b/tests/unit/indexing/config/test_load.py @@ -35,7 +35,9 @@ def test_loading_default_config_with_input_overridden(self): # Check that the config is merged # but skip checking the input - self.assert_is_default_config(config, check_input=False) + self.assert_is_default_config( + config, check_input=False, ignore_workflows=["create_base_text_units"] + ) if config.input is None: msg = "Input should not be none" @@ -72,7 +74,10 @@ def assert_is_default_config( check_reporting=True, check_cache=True, check_workflows=True, + ignore_workflows=None, ): + if ignore_workflows is None: + ignore_workflows = [] assert config is not None assert isinstance(config, PipelineConfig) @@ -111,7 +116,14 @@ def assert_is_default_config( checked_config.pop(prop, None) actual_default_config.pop(prop, None) - assert actual_default_config == actual_default_config | checked_config + for prop in actual_default_config: + if prop == "workflows": + assert len(checked_config[prop]) == len(actual_default_config[prop]) + for i, workflow in enumerate(actual_default_config[prop]): + if workflow["name"] not in ignore_workflows: + assert workflow == actual_default_config[prop][i] + else: + assert checked_config[prop] == actual_default_config[prop] def setUp(self) -> None: os.environ["GRAPHRAG_OPENAI_API_KEY"] = "test" diff --git a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py index 0403bee4d1..366fcc15de 100644 --- a/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py +++ b/tests/unit/indexing/verbs/entities/extraction/strategies/graph_intelligence/test_gi_entity_extraction.py @@ -18,7 +18,6 @@ async def test_run_extract_entities_single_document_correct_entities_returned(se entity_types=["person"], callbacks=None, args={ - "prechunked": True, "max_gleanings": 0, "summarize_descriptions": False, }, @@ -53,7 +52,6 @@ async def test_run_extract_entities_multiple_documents_correct_entities_returned entity_types=["person"], callbacks=None, args={ - "prechunked": True, "max_gleanings": 0, "summarize_descriptions": False, }, @@ -90,7 +88,6 @@ async def test_run_extract_entities_multiple_documents_correct_edges_returned(se entity_types=["person"], callbacks=None, args={ - "prechunked": True, "max_gleanings": 0, "summarize_descriptions": False, }, @@ -135,7 +132,6 @@ async def test_run_extract_entities_multiple_documents_correct_entity_source_ids entity_types=["person"], callbacks=None, args={ - "prechunked": True, "max_gleanings": 0, "summarize_descriptions": False, }, @@ -185,7 +181,6 @@ async def test_run_extract_entities_multiple_documents_correct_edge_source_ids_m entity_types=["person"], callbacks=None, args={ - "prechunked": True, "max_gleanings": 0, "summarize_descriptions": False, }, diff --git a/tests/verbs/test_compute_communities.py b/tests/verbs/test_compute_communities.py index 5c91fc46b7..1b23ef97b9 100644 --- a/tests/verbs/test_compute_communities.py +++ b/tests/verbs/test_compute_communities.py @@ -20,9 +20,14 @@ def test_compute_communities(): expected = load_test_table("base_communities") config = get_config_for_workflow(workflow_name) - clustering_strategy = config["cluster_graph"]["strategy"] + cluster_config = config["cluster_graph"] - actual = compute_communities(edges, clustering_strategy=clustering_strategy) + actual = compute_communities( + edges, + cluster_config.max_cluster_size, + cluster_config.use_lcc, + cluster_config.seed, + ) columns = list(expected.columns.values) compare_outputs(actual, expected, columns) diff --git a/tests/verbs/test_create_base_text_units.py b/tests/verbs/test_create_base_text_units.py index 1485bedc83..cf1d267aa3 100644 --- a/tests/verbs/test_create_base_text_units.py +++ b/tests/verbs/test_create_base_text_units.py @@ -24,7 +24,7 @@ async def test_create_base_text_units(): config = get_config_for_workflow(workflow_name) # test data was created with 4o, so we need to match the encoding for chunks to be identical - config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base" + config["chunks"].encoding_model = "o200k_base" steps = build_steps(config) @@ -47,7 +47,7 @@ async def test_create_base_text_units_with_snapshot(): config = get_config_for_workflow(workflow_name) # test data was created with 4o, so we need to match the encoding for chunks to be identical - config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base" + config["chunks"].encoding_model = "o200k_base" config["snapshot_transient"] = True steps = build_steps(config) diff --git a/tests/verbs/test_create_final_nodes.py b/tests/verbs/test_create_final_nodes.py index ff9b1b8c45..db3b6ec57f 100644 --- a/tests/verbs/test_create_final_nodes.py +++ b/tests/verbs/test_create_final_nodes.py @@ -3,6 +3,7 @@ from datashaper import NoopVerbCallbacks +from graphrag.config.models.embed_graph_config import EmbedGraphConfig from graphrag.index.flows.create_final_nodes import ( create_final_nodes, ) @@ -23,13 +24,14 @@ def test_create_final_nodes(): expected = load_test_table(workflow_name) + embed_config = EmbedGraphConfig(enabled=False) actual = create_final_nodes( base_entity_nodes=base_entity_nodes, base_relationship_edges=base_relationship_edges, base_communities=base_communities, callbacks=NoopVerbCallbacks(), - layout_strategy={"type": "zero"}, - embedding_strategy=None, + embed_config=embed_config, + layout_enabled=False, ) assert "id" in expected.columns