Skip to content

Commit

Permalink
Merge pull request microsoft#26 from prateejain-linked/COMMUNITY_REPORTS
Browse files (browse the repository at this point in the history)
Adding community reports to Kusto
  • Loading branch information
sirus-ms authored Sep 3, 2024
2 parents 0e5a4b9 + aa6095b commit 98ff552
Show file tree
Hide file tree
Showing 13 changed files with 178 additions and 55 deletions.
14 changes: 10 additions & 4 deletions graphrag/index/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@
type=str,
)
parser.add_argument(
"--context-id",
"--context_id",
required=False,
help="Context id to activate or deactivate.",
type=str
)
parser.add_argument(
"--context-operation",
"--context_operation",
help="Context operation activate or deactivate.",
required=False,
# Only required if contextId is provided
Expand All @@ -77,7 +77,7 @@
action="store_true",
)
parser.add_argument(
"--overlay-defaults",
"--overlay_defaults",
help="Overlay default configuration values on a provided configuration file (--config).",
action="store_true",
)
Expand All @@ -87,6 +87,11 @@
type=int,
default=2,
)
parser.add_argument(
"--use_kusto_community_reports",
help="If enabled community reports are loaded into Kusto during activation",
action="store_true",
)

args = parser.parse_args()

Expand All @@ -108,5 +113,6 @@
cli=True,
context_id=args.context_id,
context_operation=args.context_operation,
community_level=args.community_level
community_level=args.community_level,
use_kusto_community_reports=args.use_kusto_community_reports
)
17 changes: 13 additions & 4 deletions graphrag/index/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def index_cli(
dryrun: bool,
overlay_defaults: bool,
cli: bool = False,
use_kusto_community_reports: bool = False,
):
"""Run the pipeline with the given config."""
run_id = resume or time.strftime("%Y%m%d-%H%M%S")
Expand All @@ -107,8 +108,15 @@ def index_cli(
ValueError("ContextId is invalid: It should be a valid Guid")
if (context_operation != ContextSwitchType.Activate and context_operation != ContextSwitchType.Deactivate):
ValueError("ContextOperation is invalid: It should be Active or DeActive")
#graphrag_config = _read_config_parameters(root, config, progress_reporter)
_switch_context(config,root,context_operation,context_id,progress_reporter,community_level)
_switch_context(
config,
root,
context_operation,
context_id,
progress_reporter,
community_level,
use_kusto_community_reports,
)
sys.exit(0)
cache = NoopPipelineCache() if nocache else None
pipeline_emit = emit.split(",") if emit else None
Expand Down Expand Up @@ -185,11 +193,12 @@ async def execute():
sys.exit(1 if encountered_errors else 0)

def _switch_context(config: GraphRagConfig | str, root: str , context_operation: str | None,
context_id: str, reporter: ProgressReporter,community_level: int) -> None:
context_id: str, reporter: ProgressReporter,community_level: int,
use_kusto_community_reports: bool) -> None:
"""Switch the context to the given context."""
reporter.info(f"Switching context to {context_id} using operation {context_operation}")
from graphrag.index.context_switch.contextSwitcher import ContextSwitcher
context_switcher = ContextSwitcher(root,config,reporter,context_id,community_level)
context_switcher = ContextSwitcher(root,config, reporter,context_id,community_level,use_kusto_community_reports)
if context_operation == ContextSwitchType.Activate:
context_switcher.activate()
elif context_operation == ContextSwitchType.Deactivate:
Expand Down
50 changes: 29 additions & 21 deletions graphrag/index/context_switch/contextSwitcher.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
from graphrag.common.progress import ProgressReporter
from graphrag.config import GraphRagConfig
from graphrag.config.enums import StorageType
from graphrag.common.storage import PipelineStorage, BlobPipelineStorage, FilePipelineStorage
from graphrag.common.utils.context_utils import get_files_by_contextid
import pandas as pd
from typing import cast
from azure.core.exceptions import ResourceNotFoundError
import asyncio
import os
from io import BytesIO
from pathlib import Path
from typing import cast

import pandas as pd

from common.graph_db_client import GraphDBClient
from graphrag.common.progress import ProgressReporter
from graphrag.common.storage import (
BlobPipelineStorage,
FilePipelineStorage,
PipelineStorage,
)
from graphrag.common.utils.context_utils import get_files_by_contextid
from graphrag.config import (
create_graphrag_config,
GraphRagConfig,
create_graphrag_config,
)
from common.graph_db_client import GraphDBClient
import os
from graphrag.vector_stores import VectorStoreFactory, VectorStoreType
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.kusto import KustoVectorStore
from graphrag.config.enums import StorageType
from graphrag.model.community_report import CommunityReport
from graphrag.model.entity import Entity
from graphrag.query.indexer_adapters import (
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
kt_read_indexer_reports,
read_indexer_text_units,
)
from graphrag.model.entity import Entity
from azure.cosmos import CosmosClient, PartitionKey
Expand All @@ -36,7 +34,8 @@ class ContextSwitcher:
def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter,
context_id:str, community_level:int ,
data_dir: str = None,
optimized_search: bool= False):
optimized_search: bool= False,
use_kusto_community_reports: bool = False,):

self.root_dir=root_dir
self.config_dir=config_dir
Expand All @@ -45,11 +44,13 @@ def __init__(self, root_dir:str , config_dir:str,reporter: ProgressReporter,
self.context_id=context_id
self.optimized_search=optimized_search
self.community_level = community_level
self.use_kusto_community_reports = use_kusto_community_reports

def set_ctx_activation(
self,
activate: int,
entities: list[Entity]=[],
reports: list[CommunityReport]=[],
config_args: dict | None = None,
):
if not config_args:
Expand All @@ -65,6 +66,7 @@ def set_ctx_activation(
"vector_search_column", "description_embedding"
)
config_args.update({"vector_name": vector_name})
config_args.update({"reports_name": f"reports_{self.context_id}"})

description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=VectorStoreType.Kusto, kwargs=config_args
Expand All @@ -73,8 +75,11 @@ def set_ctx_activation(

if activate:
description_embedding_store.load_entities(entities)
if self.use_kusto_community_reports:
description_embedding_store.load_reports(reports)
else:
description_embedding_store.unload_entities()
# I don't think it is necessary to unload anything as the retention policy will take care of it.

return 0

Expand Down Expand Up @@ -246,10 +251,13 @@ def _read_config_parameters(root: str, config: str | None):
ValueError("Context switching is only supporeted for vectore_store.type=kusto ")

entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it.
reports = read_indexer_reports(final_community_reports, final_nodes, community_level)

self.set_ctx_activation(
entities=entities,
activate=1, config_args=vector_store_args,
reports=reports,
activate=1,
config_args=vector_store_args,
)

def deactivate(self):
Expand Down
13 changes: 8 additions & 5 deletions graphrag/query/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ class SearchType(Enum):

LOCAL = "local"
GLOBAL = "global"
KUSTO_LOCAL = "kusto_local"
KUSTO_GLOBAL = "kusto_global"

def __str__(self):
"""Return the string representation of the enum value."""
Expand Down Expand Up @@ -85,15 +83,19 @@ def __str__(self):
default=False,
)

parser.add_argument(
"--use_kusto_community_reports",
help="If enabled community reports are attempted to be used in Kusto during query",
action="store_true",
)

parser.add_argument(
"query",
nargs=1,
help="The query to run",
type=str,
)



args = parser.parse_args()

match args.method:
Expand All @@ -106,7 +108,8 @@ def __str__(self):
args.response_type,
args.context_id,
args.query[0],
optimized_search=args.optimized_search
optimized_search=args.optimized_search,
use_kusto_community_reports=args.use_kusto_community_reports,
)
case SearchType.GLOBAL:
run_global_search(
Expand Down
21 changes: 14 additions & 7 deletions graphrag/query/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@
reporter = PrintProgressReporter("")

def __get_embedding_description_store(
entities: list[Entity],
entities: list[Entity] = [],
vector_store_type: str = VectorStoreType.LanceDB,
config_args: dict | None = None,
context_id: str = "",
):
"""Get the embedding description store."""
if not config_args:
Expand All @@ -56,11 +57,12 @@ def __get_embedding_description_store(
collection_name = config_args.get(
"query_collection_name", "entity_description_embeddings"
)
config_args.update({"collection_name": collection_name})
config_args.update({"collection_name": f"{collection_name}_{context_id}" if context_id else collection_name})
vector_name = config_args.get(
"vector_search_column", "description_embedding"
)
config_args.update({"vector_name": vector_name})
config_args.update({"reports_name": f"reports_{context_id}" if context_id else "reports"})

description_embedding_store = VectorStoreFactory.get_vector_store(
vector_store_type=vector_store_type, kwargs=config_args
Expand Down Expand Up @@ -155,6 +157,7 @@ def run_local_search(
context_id: str,
query: str,
optimized_search: bool = False,
use_kusto_community_reports: bool = False,
):
"""Run a local search with the given query."""
data_dir, root_dir, config = _configure_paths_and_settings(
Expand Down Expand Up @@ -213,12 +216,14 @@ def run_local_search(
vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB)

entities = read_indexer_entities(final_nodes, final_entities, community_level) # KustoDB: read Final nodes data and entities data and merge it.


reports=read_indexer_reports(
final_community_reports, final_nodes, community_level
)
description_embedding_store = __get_embedding_description_store(
entities=entities,
vector_store_type=vector_store_type,
config_args=vector_store_args,
context_id=context_id,
)

covariates = (
Expand All @@ -229,19 +234,21 @@ def run_local_search(

if(isinstance(description_embedding_store, KustoVectorStore)):
entities = []
description_embedding_store.load_reports(reports)
if use_kusto_community_reports:
reports = []

search_engine = get_local_search_engine(
config,
reports=read_indexer_reports(
final_community_reports, final_nodes, community_level
),
reports=reports,
text_units=read_indexer_text_units(final_text_units),
entities=entities,
relationships=read_indexer_relationships(final_relationships),
covariates={"claims": covariates},
description_embedding_store=description_embedding_store,
response_type=response_type,
is_optimized_search=optimized_search,
use_kusto_community_reports=use_kusto_community_reports,
)

if optimized_search:
Expand Down
3 changes: 3 additions & 0 deletions graphrag/query/context_builder/entity_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def map_query_to_entities_in_place(
text_embedder=lambda t: text_embedder.embed(t),
k=k * oversample_scaler,
)
import ast
for result in search_results:
result.community_ids = ast.literal_eval(result.community_ids)
return search_results

def map_query_to_entities(
Expand Down
4 changes: 3 additions & 1 deletion graphrag/query/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def get_local_search_engine(
covariates: dict[str, list[Covariate]],
response_type: str,
description_embedding_store: BaseVectorStore,
is_optimized_search: bool = False
is_optimized_search: bool = False,
use_kusto_community_reports: bool = False,
) -> LocalSearch:
"""Create a local search engine based on data + configuration."""
llm = get_llm(config)
Expand All @@ -130,6 +131,7 @@ def get_local_search_engine(
text_embedder=text_embedder,
token_encoder=token_encoder,
is_optimized_search= is_optimized_search,
use_kusto_community_reports=use_kusto_community_reports,
),
token_encoder=token_encoder,
llm_params={
Expand Down
1 change: 1 addition & 0 deletions graphrag/query/indexer_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def read_indexer_reports(

report_df = _filter_under_community_level(report_df, community_level)
report_df = report_df.merge(filtered_community_df, on="community", how="inner")
report_df = report_df.drop_duplicates(subset=["community"])

return read_community_reports(
df=report_df,
Expand Down
Loading

0 comments on commit 98ff552

Please sign in to comment.