diff --git a/graphrag/index/__main__.py b/graphrag/index/__main__.py index 9a652ae07a..a501f829df 100644 --- a/graphrag/index/__main__.py +++ b/graphrag/index/__main__.py @@ -53,13 +53,13 @@ type=str, ) parser.add_argument( - "--context-id", + "--context_id", required=False, help="Context id to activate or deactivate.", type=str ) parser.add_argument( - "--context-operation", + "--context_operation", help="Context operation activate or deactivate.", required=False, # Only required if contextId is provided @@ -77,7 +77,7 @@ action="store_true", ) parser.add_argument( - "--overlay-defaults", + "--overlay_defaults", help="Overlay default configuration values on a provided configuration file (--config).", action="store_true", ) @@ -87,6 +87,11 @@ type=int, default=2, ) + parser.add_argument( + "--use_kusto_community_reports", + help="If enabled community reports are loaded into Kusto during activation", + action="store_true", + ) args = parser.parse_args() @@ -108,5 +113,6 @@ cli=True, context_id=args.context_id, context_operation=args.context_operation, - community_level=args.community_level + community_level=args.community_level, + use_kusto_community_reports=args.use_kusto_community_reports ) diff --git a/graphrag/index/cli.py b/graphrag/index/cli.py index 1e33e61e4c..8fd3a80dd6 100644 --- a/graphrag/index/cli.py +++ b/graphrag/index/cli.py @@ -86,6 +86,7 @@ def index_cli( dryrun: bool, overlay_defaults: bool, cli: bool = False, + use_kusto_community_reports: bool = False, ): """Run the pipeline with the given config.""" run_id = resume or time.strftime("%Y%m%d-%H%M%S") @@ -107,8 +108,15 @@ def index_cli( ValueError("ContextId is invalid: It should be a valid Guid") if (context_operation != ContextSwitchType.Activate and context_operation != ContextSwitchType.Deactivate): ValueError("ContextOperation is invalid: It should be Active or DeActive") - #graphrag_config = _read_config_parameters(root, config, progress_reporter) - _switch_context(config,root,context_operation,context_id,progress_reporter,community_level) + _switch_context( + config, + root, + context_operation, + context_id, + progress_reporter, + community_level, + use_kusto_community_reports, + ) sys.exit(0) cache = NoopPipelineCache() if nocache else None pipeline_emit = emit.split(",") if emit else None @@ -185,11 +193,12 @@ async def execute(): sys.exit(1 if encountered_errors else 0) def _switch_context(config: GraphRagConfig | str, root: str , context_operation: str | None, - context_id: str, reporter: ProgressReporter,community_level: int) -> None: + context_id: str, reporter: ProgressReporter,community_level: int, + use_kusto_community_reports: bool) -> None: """Switch the context to the given context.""" reporter.info(f"Switching context to {context_id} using operation {context_operation}") from graphrag.index.context_switch.contextSwitcher import ContextSwitcher - context_switcher = ContextSwitcher(root,config,reporter,context_id,community_level) + context_switcher = ContextSwitcher(root,config, reporter,context_id,community_level,use_kusto_community_reports) if context_operation == ContextSwitchType.Activate: context_switcher.activate() elif context_operation == ContextSwitchType.Deactivate: diff --git a/graphrag/index/context_switch/contextSwitcher.py b/graphrag/index/context_switch/contextSwitcher.py index ccf1c0f90b..4b902d42d2 100644 --- a/graphrag/index/context_switch/contextSwitcher.py +++ b/graphrag/index/context_switch/contextSwitcher.py @@ -1,31 +1,29 @@ -from graphrag.common.progress import ProgressReporter -from graphrag.config import GraphRagConfig -from graphrag.config.enums import StorageType -from graphrag.common.storage import PipelineStorage, BlobPipelineStorage, FilePipelineStorage -from graphrag.common.utils.context_utils import get_files_by_contextid -import pandas as pd -from typing import cast -from azure.core.exceptions import ResourceNotFoundError import asyncio +import os from io import BytesIO from pathlib import Path +from typing import cast + +import pandas as pd + +from common.graph_db_client import GraphDBClient +from graphrag.common.progress import ProgressReporter +from graphrag.common.storage import ( + BlobPipelineStorage, + FilePipelineStorage, + PipelineStorage, +) +from graphrag.common.utils.context_utils import get_files_by_contextid from graphrag.config import ( - create_graphrag_config, GraphRagConfig, + create_graphrag_config, ) -from common.graph_db_client import GraphDBClient -import os -from graphrag.vector_stores import VectorStoreFactory, VectorStoreType -from graphrag.vector_stores.base import BaseVectorStore -from graphrag.vector_stores.lancedb import LanceDBVectorStore -from graphrag.vector_stores.kusto import KustoVectorStore +from graphrag.config.enums import StorageType +from graphrag.model.community_report import CommunityReport +from graphrag.model.entity import Entity from graphrag.query.indexer_adapters import ( - read_indexer_covariates, read_indexer_entities, - read_indexer_relationships, read_indexer_reports, - kt_read_indexer_reports, - read_indexer_text_units, ) from graphrag.model.entity import Entity from azure.cosmos import CosmosClient, PartitionKey @@ -36,7 +34,8 @@ class ContextSwitcher: def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter, context_id:str, community_level:int , data_dir: str = None, - optimized_search: bool= False): + optimized_search: bool= False, + use_kusto_community_reports: bool = False,): self.root_dir=root_dir self.config_dir=config_dir @@ -45,11 +44,13 @@ def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter, self.context_id=context_id self.optimized_search=optimized_search self.community_level = community_level + self.use_kusto_community_reports = use_kusto_community_reports def set_ctx_activation( self, activate: int, entities: list[Entity]=[], + reports: list[CommunityReport]=[], config_args: dict | None = None, ): if not config_args: @@ -65,6 +66,7 @@ def set_ctx_activation( "vector_search_column", "description_embedding" ) config_args.update({"vector_name": vector_name}) + config_args.update({"reports_name": f"reports_{self.context_id}"}) description_embedding_store = VectorStoreFactory.get_vector_store( vector_store_type=VectorStoreType.Kusto, kwargs=config_args @@ -73,8 +75,11 @@ def set_ctx_activation( if activate: description_embedding_store.load_entities(entities) + if self.use_kusto_community_reports: + description_embedding_store.load_reports(reports) else: description_embedding_store.unload_entities() + # I don't think it is necessary to unload anything as the retention policy will take care of it. return 0 @@ -246,10 +251,13 @@ def _read_config_parameters(root: str, config: str | None): ValueError("Context switching is only supporeted for vectore_store.type=kusto ") entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it. + reports = read_indexer_reports(final_community_reports, final_nodes, community_level) self.set_ctx_activation( entities=entities, - activate=1, config_args=vector_store_args, + reports=reports, + activate=1, + config_args=vector_store_args, ) def deactivate(self): diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py index 1d1cd25b4f..e0a5e9f583 100644 --- a/graphrag/query/__main__.py +++ b/graphrag/query/__main__.py @@ -16,8 +16,6 @@ class SearchType(Enum): LOCAL = "local" GLOBAL = "global" - KUSTO_LOCAL = "kusto_local" - KUSTO_GLOBAL = "kusto_global" def __str__(self): """Return the string representation of the enum value.""" @@ -85,6 +83,12 @@ def __str__(self): default=False, ) + parser.add_argument( + "--use_kusto_community_reports", + help="If enabled community reports are attempted to be used in Kusto during query", + action="store_true", + ) + parser.add_argument( "query", nargs=1, @@ -92,8 +96,6 @@ def __str__(self): type=str, ) - - args = parser.parse_args() match args.method: @@ -106,7 +108,8 @@ def __str__(self): args.response_type, args.context_id, args.query[0], - optimized_search=args.optimized_search + optimized_search=args.optimized_search, + use_kusto_community_reports=args.use_kusto_community_reports, ) case SearchType.GLOBAL: run_global_search( diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 44d6d1bb55..8fab982b55 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -45,9 +45,10 @@ reporter = PrintProgressReporter("") def __get_embedding_description_store( - entities: list[Entity], + entities: list[Entity] = [], vector_store_type: str = VectorStoreType.LanceDB, config_args: dict | None = None, + context_id: str = "", ): """Get the embedding description store.""" if not config_args: @@ -56,11 +57,12 @@ def __get_embedding_description_store( collection_name = config_args.get( "query_collection_name", "entity_description_embeddings" ) - config_args.update({"collection_name": collection_name}) + config_args.update({"collection_name": f"{collection_name}_{context_id}" if context_id else collection_name}) vector_name = config_args.get( "vector_search_column", "description_embedding" ) config_args.update({"vector_name": vector_name}) + config_args.update({"reports_name": f"reports_{context_id}" if context_id else "reports"}) description_embedding_store = VectorStoreFactory.get_vector_store( vector_store_type=vector_store_type, kwargs=config_args @@ -155,6 +157,7 @@ def run_local_search( context_id: str, query: str, optimized_search: bool = False, + use_kusto_community_reports: bool = False, ): """Run a local search with the given query.""" data_dir, root_dir, config = _configure_paths_and_settings( @@ -213,12 +216,14 @@ def run_local_search( vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it. - - + reports=read_indexer_reports( + final_community_reports, final_nodes, community_level + ) description_embedding_store = __get_embedding_description_store( entities=entities, vector_store_type=vector_store_type, config_args=vector_store_args, + context_id=context_id, ) covariates = ( @@ -229,12 +234,13 @@ def run_local_search( if(isinstance(description_embedding_store, KustoVectorStore)): entities = [] + description_embedding_store.load_reports(reports) + if use_kusto_community_reports: + reports = [] search_engine = get_local_search_engine( config, - reports=read_indexer_reports( - final_community_reports, final_nodes, community_level - ), + reports=reports, text_units=read_indexer_text_units(final_text_units), entities=entities, relationships=read_indexer_relationships(final_relationships), @@ -242,6 +248,7 @@ def run_local_search( description_embedding_store=description_embedding_store, response_type=response_type, is_optimized_search=optimized_search, + use_kusto_community_reports=use_kusto_community_reports, ) if optimized_search: diff --git a/graphrag/query/context_builder/entity_extraction.py b/graphrag/query/context_builder/entity_extraction.py index fe8253b612..037da80932 100644 --- a/graphrag/query/context_builder/entity_extraction.py +++ b/graphrag/query/context_builder/entity_extraction.py @@ -46,6 +46,9 @@ def map_query_to_entities_in_place( text_embedder=lambda t: text_embedder.embed(t), k=k * oversample_scaler, ) + import ast + for result in search_results: + result.community_ids = ast.literal_eval(result.community_ids) return search_results def map_query_to_entities( diff --git a/graphrag/query/factories.py b/graphrag/query/factories.py index 3dfe104230..e9775f2bf2 100644 --- a/graphrag/query/factories.py +++ b/graphrag/query/factories.py @@ -108,7 +108,8 @@ def get_local_search_engine( covariates: dict[str, list[Covariate]], response_type: str, description_embedding_store: BaseVectorStore, - is_optimized_search: bool = False + is_optimized_search: bool = False, + use_kusto_community_reports: bool = False, ) -> LocalSearch: """Create a local search engine based on data + configuration.""" llm = get_llm(config) @@ -130,6 +131,7 @@ def get_local_search_engine( text_embedder=text_embedder, token_encoder=token_encoder, is_optimized_search= is_optimized_search, + use_kusto_community_reports=use_kusto_community_reports, ), token_encoder=token_encoder, llm_params={ diff --git a/graphrag/query/indexer_adapters.py b/graphrag/query/indexer_adapters.py index c3b5f6e1ab..101fc16f9c 100644 --- a/graphrag/query/indexer_adapters.py +++ b/graphrag/query/indexer_adapters.py @@ -94,6 +94,7 @@ def read_indexer_reports( report_df = _filter_under_community_level(report_df, community_level) report_df = report_df.merge(filtered_community_df, on="community", how="inner") + report_df = report_df.drop_duplicates(subset=["community"]) return read_community_reports( df=report_df, diff --git a/graphrag/query/structured_search/local_search/mixed_context.py b/graphrag/query/structured_search/local_search/mixed_context.py index b26d0581ee..988d290848 100644 --- a/graphrag/query/structured_search/local_search/mixed_context.py +++ b/graphrag/query/structured_search/local_search/mixed_context.py @@ -43,6 +43,7 @@ from graphrag.query.llm.text_utils import num_tokens from graphrag.query.structured_search.base import LocalContextBuilder from graphrag.vector_stores import BaseVectorStore +from graphrag.vector_stores.kusto import KustoVectorStore log = logging.getLogger(__name__) @@ -62,6 +63,7 @@ def __init__( token_encoder: tiktoken.Encoding | None = None, embedding_vectorstore_key: str = EntityVectorStoreKey.ID, is_optimized_search: bool = False, + use_kusto_community_reports: bool = False, ): if community_reports is None: community_reports = [] @@ -85,6 +87,7 @@ def __init__( self.token_encoder = token_encoder self.embedding_vectorstore_key = embedding_vectorstore_key self.is_optimized_search = is_optimized_search + self.use_kusto_community_reports = use_kusto_community_reports def filter_by_entity_keys(self, entity_keys: list[int] | list[str]): """Filter entity text embeddings by entity keys.""" @@ -238,7 +241,7 @@ def _build_community_context( is_optimized_search: bool = False, ) -> tuple[str, dict[str, pd.DataFrame]]: """Add community data to the context window until it hits the max_tokens limit.""" - if len(selected_entities) == 0 or len(self.community_reports) == 0: + if len(selected_entities) == 0 or (len(self.community_reports) == 0 and not self.use_kusto_community_reports): return ("", {context_name.lower(): pd.DataFrame()}) community_matches = {} @@ -250,12 +253,19 @@ def _build_community_context( community_matches.get(community_id, 0) + 1 ) + selected_communities = [] + if self.use_kusto_community_reports: + selected_communities = self.entity_text_embeddings.get_extracted_reports( + community_ids=list(community_matches.keys()) + ) + else: + selected_communities = [ + self.community_reports[community_id] + for community_id in community_matches + if community_id in self.community_reports + ] + # sort communities by number of matched entities and rank - selected_communities = [ - self.community_reports[community_id] - for community_id in community_matches - if community_id in self.community_reports - ] for community in selected_communities: if community.attributes is None: community.attributes = {} @@ -450,7 +460,7 @@ def _build_local_context( relationship_context, self.token_encoder ) - + # build covariate context for covariate in self.covariates: covariate_context, covariate_context_data = build_covariates_context( diff --git a/graphrag/vector_stores/azure_ai_search.py b/graphrag/vector_stores/azure_ai_search.py index efab79d70f..afb0c18393 100644 --- a/graphrag/vector_stores/azure_ai_search.py +++ b/graphrag/vector_stores/azure_ai_search.py @@ -24,6 +24,7 @@ ) from azure.search.documents.models import VectorizedQuery +from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder @@ -194,9 +195,15 @@ def similarity_search_by_text( ) return [] + def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: + raise NotImplementedError("Loading entities is not supported for Azure AI Search") + def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any ) -> list[Entity]: raise NotImplementedError("Extracting entities is not supported for Azure AI Search") - def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: - raise NotImplementedError("Loading entities is not supported for Azure AI Search") \ No newline at end of file + def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: + raise NotImplementedError("Loading reports is not supported for Azure AI Search") + + def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]: + raise NotImplementedError("Extracting reports is not supported for Azure AI Search") \ No newline at end of file diff --git a/graphrag/vector_stores/base.py b/graphrag/vector_stores/base.py index 104c9790c3..67121a04b9 100644 --- a/graphrag/vector_stores/base.py +++ b/graphrag/vector_stores/base.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field from typing import Any +from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder @@ -45,6 +46,7 @@ def __init__( self, collection_name: str, vector_name: str, + reports_name: str, db_connection: Any | None = None, document_collection: Any | None = None, query_filter: Any | None = None, @@ -52,6 +54,7 @@ def __init__( ): self.collection_name = collection_name self.vector_name = vector_name + self.reports_name = reports_name self.db_connection = db_connection self.document_collection = document_collection self.query_filter = query_filter @@ -91,4 +94,14 @@ def get_extracted_entities( @abstractmethod def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: - """Load entities into the vector-store.""" \ No newline at end of file + """Load entities into the vector-store.""" + + @abstractmethod + def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: + """Load reports into the vector-store.""" + + @abstractmethod + def get_extracted_reports( + self, community_ids: list[int], **kwargs: Any + ) -> list[CommunityReport]: + """Get reports for a given list of community ids.""" \ No newline at end of file diff --git a/graphrag/vector_stores/kusto.py b/graphrag/vector_stores/kusto.py index e12c872239..ec969e1f5f 100644 --- a/graphrag/vector_stores/kusto.py +++ b/graphrag/vector_stores/kusto.py @@ -6,6 +6,7 @@ import typing from azure.kusto.data import KustoClient, KustoConnectionStringBuilder from azure.kusto.data.helpers import dataframe_from_result_table +from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder @@ -208,7 +209,7 @@ def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int attributes=row["attributes"], ) for _, row in df.iterrows() ] - + def unload_entities(self) -> None: self.client.execute(self.database,f".drop table {self.collection_name} ifexists") @@ -230,3 +231,49 @@ def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: # Ingest data ingestion_command = f".ingest inline into table {self.collection_name} <| {df.to_csv(index=False, header=False)}" self.client.execute(self.database, ingestion_command) + + def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: + # Convert data to DataFrame + df = pd.DataFrame(reports) + + # Create or replace table + if overwrite: + command = f".drop table {self.reports_name} ifexists" + self.client.execute(self.database, command) + command = f".create table {self.reports_name} (id: string, short_id: string, title: string, community_id: string, summary: string, full_content: string, rank: real, summary_embedding: dynamic, full_content_embedding: dynamic, attributes: dynamic)" + self.client.execute(self.database, command) + command = f".alter column {self.reports_name}.summary_embedding policy encoding type = 'Vector16'" + self.client.execute(self.database, command) + command = f".alter column {self.reports_name}.full_content_embedding policy encoding type = 'Vector16'" + self.client.execute(self.database, command) + + # Ingest data + ingestion_command = f".ingest inline into table {self.reports_name} <| {df.to_csv(index=False, header=False)}" + self.client.execute(self.database, ingestion_command) + + + def get_extracted_reports( + self, community_ids: list[int], **kwargs: Any + ) -> list[CommunityReport]: + community_ids = ", ".join([str(id) for id in community_ids]) + query = f""" + reports + | where community_id in ({community_ids}) + """ + response = self.client.execute(self.database, query) + df = dataframe_from_result_table(response.primary_results[0]) + + return [ + CommunityReport( + id=row["id"], + short_id=row["short_id"], + title=row["title"], + community_id=row["community_id"], + summary=row["summary"], + full_content=row["full_content"], + rank=row["rank"], + summary_embedding=row["summary_embedding"], + full_content_embedding=row["full_content_embedding"], + attributes=row["attributes"], + ) for _, row in df.iterrows() + ] diff --git a/graphrag/vector_stores/lancedb.py b/graphrag/vector_stores/lancedb.py index 095b6bae42..a32b8abb51 100644 --- a/graphrag/vector_stores/lancedb.py +++ b/graphrag/vector_stores/lancedb.py @@ -4,6 +4,7 @@ """The LanceDB vector storage implementation package.""" import lancedb as lancedb # noqa: I001 (Ruff was breaking on this file imports, even tho they were sorted and passed local tests) +from graphrag.model.community_report import CommunityReport from graphrag.model.entity import Entity from graphrag.model.types import TextEmbedder @@ -121,9 +122,15 @@ def similarity_search_by_text( return self.similarity_search_by_vector(query_embedding, k) return [] + def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: + raise NotImplementedError("Loading entities is not supported for LanceDB") + def get_extracted_entities(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any ) -> list[Entity]: raise NotImplementedError("Extracting entities is not supported for LanceDB") - def load_entities(self, entities: list[Entity], overwrite: bool = True) -> None: - raise NotImplementedError("Loading entities is not supported for LanceDB") + def load_reports(self, reports: list[CommunityReport], overwrite: bool = True) -> None: + raise NotImplementedError("Loading reports is not supported for LanceDB") + + def get_extracted_reports(self, community_ids: list[int], **kwargs: Any) -> list[CommunityReport]: + raise NotImplementedError("Extracting community reports is not supported for LanceDB")