diff --git a/.semversioner/0.3.0.json b/.semversioner/0.3.0.json new file mode 100644 index 0000000000..711aeb80f8 --- /dev/null +++ b/.semversioner/0.3.0.json @@ -0,0 +1,30 @@ +{ + "changes": [ + { + "description": "Implement auto templating API.", + "type": "minor" + }, + { + "description": "Implement query engine API.", + "type": "minor" + }, + { + "description": "Fix file dumps using json for non ASCII chars", + "type": "patch" + }, + { + "description": "Stabilize smoke tests for query context building", + "type": "patch" + }, + { + "description": "fix query embedding", + "type": "patch" + }, + { + "description": "fix sort_context & max_tokens params in verb", + "type": "patch" + } + ], + "created_at": "2024-08-12T23:51:49+00:00", + "version": "0.3.0" +} \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ef6e2f345..c4d68bd464 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,21 @@ # Changelog - Note: version releases in the 0.x.y range may introduce breaking changes. +## 0.3.0 + +- minor: Implement auto templating API. +- minor: Implement query engine API. +- patch: Fix file dumps using json for non ASCII chars +- patch: Stabilize smoke tests for query context building +- patch: fix query embedding +- patch: fix sort_context & max_tokens params in verb + ## 0.2.2 - patch: Add a check if there is no community record added in local search context - patch: Add sepparate workflow for Python Tests - patch: Docs updates +- patch: Run smoke tests on 4o ## 0.2.1 diff --git a/CODEOWNERS b/CODEOWNERS index 47b118f4d8..ebfb11b8d4 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -2,5 +2,4 @@ # the repo. Unless a later match takes precedence, # @global-owner1 and @global-owner2 will be requested for # review when someone opens a pull request. -* @microsoft/societal-resilience -* @microsoft/graphrag-core-team +* @microsoft/societal-resilience @microsoft/graphrag-core-team diff --git a/graphrag/index/graph/extractors/community_reports/sort_context.py b/graphrag/index/graph/extractors/community_reports/sort_context.py index 811cb7e95c..c62710e1c8 100644 --- a/graphrag/index/graph/extractors/community_reports/sort_context.py +++ b/graphrag/index/graph/extractors/community_reports/sort_context.py @@ -144,7 +144,7 @@ def _get_context_string( new_context_string = _get_context_string( sorted_nodes, sorted_edges, sorted_claims, sub_community_reports ) - if num_tokens(context_string) > max_tokens: + if num_tokens(new_context_string) > max_tokens: break context_string = new_context_string diff --git a/graphrag/index/graph/extractors/summarize/prompts.py b/graphrag/index/graph/extractors/summarize/prompts.py index 90e4434ee8..8e544999ad 100644 --- a/graphrag/index/graph/extractors/summarize/prompts.py +++ b/graphrag/index/graph/extractors/summarize/prompts.py @@ -8,7 +8,7 @@ Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. -Make sure it is written in third person, and include the entity names so we the have full context. +Make sure it is written in third person, and include the entity names so we have the full context. 
####### -Data- diff --git a/graphrag/index/workflows/v1/create_final_community_reports.py b/graphrag/index/workflows/v1/create_final_community_reports.py index 164c70e0dd..b653595880 100644 --- a/graphrag/index/workflows/v1/create_final_community_reports.py +++ b/graphrag/index/workflows/v1/create_final_community_reports.py @@ -19,6 +19,10 @@ def build_steps( """ covariates_enabled = config.get("covariates_enabled", False) create_community_reports_config = config.get("create_community_reports", {}) + community_report_strategy = create_community_reports_config.get("strategy", {}) + community_report_max_input_length = community_report_strategy.get( + "max_input_length", 16_000 + ) base_text_embed = config.get("text_embed", {}) community_report_full_content_embed_config = config.get( "community_report_full_content_embed", base_text_embed @@ -77,6 +81,7 @@ def build_steps( { "id": "local_contexts", "verb": "prepare_community_reports", + "args": {"max_tokens": community_report_max_input_length}, "input": { "source": "nodes", "nodes": "nodes", diff --git a/graphrag/prompt_tune/__main__.py b/graphrag/prompt_tune/__main__.py index e752b05a8f..cbf8dd66c4 100644 --- a/graphrag/prompt_tune/__main__.py +++ b/graphrag/prompt_tune/__main__.py @@ -1,37 +1,32 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -"""The Prompt auto templating package root.""" +"""The auto templating package root.""" import argparse import asyncio -from enum import Enum - -from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT -from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE +from .api import DocSelectionType from .cli import prompt_tune - - -class DocSelectionType(Enum): - """The type of document selection to use.""" - - ALL = "all" - RANDOM = "random" - TOP = "top" - AUTO = "auto" - - def __str__(self): - """Return the string representation of the enum value.""" - return self.value - +from .generator import MAX_TOKEN_COUNT +from .loader import MIN_CHUNK_SIZE if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + prog="python -m graphrag.prompt_tune", + description="The graphrag auto templating module.", + ) + + parser.add_argument( + "--config", + help="Configuration yaml file to use when generating prompts", + required=True, + type=str, + ) parser.add_argument( "--root", - help="The data project root. Including the config yml, json or .env", + help="Data project root. Default: current directory", required=False, type=str, default=".", @@ -39,15 +34,15 @@ def __str__(self): parser.add_argument( "--domain", - help="The domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If left empty, the domain will be inferred from the input data.", + help="Domain your input data is related to. For example 'space science', 'microbiology', 'environmental news'. If not defined, the domain will be inferred from the input data.", required=False, default="", type=str, ) parser.add_argument( - "--method", - help="The method to select documents, one of: all, random, top or auto", + "--selection-method", + help=f"Chunk selection method. Default: {DocSelectionType.RANDOM}", required=False, type=DocSelectionType, choices=list(DocSelectionType), @@ -56,7 +51,7 @@ def __str__(self): parser.add_argument( "--n_subset_max", - help="The number of text chunks to embed when using auto selection method", + help="Number of text chunks to embed when using auto selection method. 
Default: 300", required=False, type=int, default=300, @@ -64,7 +59,7 @@ def __str__(self): parser.add_argument( "--k", - help="The maximum number of documents to select from each centroid when using auto selection method", + help="Maximum number of documents to select from each centroid when using auto selection method. Default: 15", required=False, type=int, default=15, @@ -72,7 +67,7 @@ def __str__(self): parser.add_argument( "--limit", - help="The limit of files to load when doing random or top selection", + help="Number of documents to load when doing random or top selection. Default: 15", type=int, required=False, default=15, @@ -80,7 +75,7 @@ def __str__(self): parser.add_argument( "--max-tokens", - help="Max token count for prompt generation", + help=f"Max token count for prompt generation. Default: {MAX_TOKEN_COUNT}", type=int, required=False, default=MAX_TOKEN_COUNT, @@ -88,7 +83,7 @@ def __str__(self): parser.add_argument( "--min-examples-required", - help="The minimum number of examples required in entity extraction prompt", + help="Minimum number of examples required in the entity extraction prompt. Default: 2", type=int, required=False, default=2, @@ -96,7 +91,7 @@ def __str__(self): parser.add_argument( "--chunk-size", - help="Max token count for prompt generation", + help=f"Max token count for prompt generation. Default: {MIN_CHUNK_SIZE}", type=int, required=False, default=MIN_CHUNK_SIZE, @@ -120,7 +115,7 @@ def __str__(self): parser.add_argument( "--output", - help="Folder to save the generated prompts to", + help="Directory to save generated prompts to. Default: 'prompts'", type=str, required=False, default="prompts", @@ -132,17 +127,18 @@ def __str__(self): loop.run_until_complete( prompt_tune( - args.root, - args.domain, - str(args.method), - args.limit, - args.max_tokens, - args.chunk_size, - args.language, - args.no_entity_types, - args.output, - args.n_subset_max, - args.k, - args.min_examples_required, + config=args.config, + root=args.root, + domain=args.domain, + selection_method=args.selection_method, + limit=args.limit, + max_tokens=args.max_tokens, + chunk_size=args.chunk_size, + language=args.language, + skip_entity_types=args.no_entity_types, + output=args.output, + n_subset_max=args.n_subset_max, + k=args.k, + min_examples_required=args.min_examples_required, ) ) diff --git a/graphrag/prompt_tune/api.py b/graphrag/prompt_tune/api.py new file mode 100644 index 0000000000..4bbcb5d7dc --- /dev/null +++ b/graphrag/prompt_tune/api.py @@ -0,0 +1,173 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +""" +Auto Templating API. + +This API provides access to the auto templating feature of graphrag, allowing external applications +to hook into graphrag and generate prompts from private data. + +WARNING: This API is under development and may undergo changes in future releases. +Backwards compatibility is not guaranteed at this time. 
+""" + +from datashaper import NoopVerbCallbacks +from pydantic import PositiveInt, validate_call + +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.llm import load_llm +from graphrag.index.progress import PrintProgressReporter + +from .cli import DocSelectionType +from .generator import ( + MAX_TOKEN_COUNT, + create_community_summarization_prompt, + create_entity_extraction_prompt, + create_entity_summarization_prompt, + detect_language, + generate_community_report_rating, + generate_community_reporter_role, + generate_domain, + generate_entity_relationship_examples, + generate_entity_types, + generate_persona, +) +from .loader import ( + MIN_CHUNK_SIZE, + load_docs_in_chunks, +) + + +@validate_call +async def generate_indexing_prompts( + config: GraphRagConfig, + root: str, + chunk_size: PositiveInt = MIN_CHUNK_SIZE, + limit: PositiveInt = 15, + selection_method: DocSelectionType = DocSelectionType.RANDOM, + domain: str | None = None, + language: str | None = None, + max_tokens: int = MAX_TOKEN_COUNT, + skip_entity_types: bool = False, + min_examples_required: PositiveInt = 2, + n_subset_max: PositiveInt = 300, + k: PositiveInt = 15, +) -> tuple[str, str, str]: + """Generate indexing prompts. + + Parameters + ---------- + - config: The GraphRag configuration. + - output_path: The path to store the prompts. + - chunk_size: The chunk token size to use for input text units. + - limit: The limit of chunks to load. + - selection_method: The chunk selection method. + - domain: The domain to map the input documents to. + - language: The language to use for the prompts. + - max_tokens: The maximum number of tokens to use on entity extraction prompts + - skip_entity_types: Skip generating entity types. + - min_examples_required: The minimum number of examples required for entity extraction prompts. + - n_subset_max: The number of text chunks to embed when using auto selection method. + - k: The number of documents to select when using auto selection method. 
+ + Returns + ------- + tuple[str, str, str]: entity extraction prompt, entity summarization prompt, community summarization prompt + """ + reporter = PrintProgressReporter("") + + # Retrieve documents + doc_list = await load_docs_in_chunks( + root=root, + config=config, + limit=limit, + select_method=selection_method, + reporter=reporter, + chunk_size=chunk_size, + n_subset_max=n_subset_max, + k=k, + ) + + # Create LLM from config + llm = load_llm( + "prompt_tuning", + config.llm.type, + NoopVerbCallbacks(), + None, + config.llm.model_dump(), + ) + + if not domain: + reporter.info("Generating domain...") + domain = await generate_domain(llm, doc_list) + reporter.info(f"Generated domain: {domain}") + + if not language: + reporter.info("Detecting language...") + language = await detect_language(llm, doc_list) + + reporter.info("Generating persona...") + persona = await generate_persona(llm, domain) + + reporter.info("Generating community report ranking description...") + community_report_ranking = await generate_community_report_rating( + llm, domain=domain, persona=persona, docs=doc_list + ) + + entity_types = None + if not skip_entity_types: + reporter.info("Generating entity types...") + entity_types = await generate_entity_types( + llm, + domain=domain, + persona=persona, + docs=doc_list, + json_mode=config.llm.model_supports_json or False, + ) + + reporter.info("Generating entity relationship examples...") + examples = await generate_entity_relationship_examples( + llm, + persona=persona, + entity_types=entity_types, + docs=doc_list, + language=language, + json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json by the index engine + ) + + reporter.info("Generating entity extraction prompt...") + entity_extraction_prompt = create_entity_extraction_prompt( + entity_types=entity_types, + docs=doc_list, + examples=examples, + language=language, + json_mode=False, # config.llm.model_supports_json should be used, but these prompts are used in non-json by the index engine + encoding_model=config.encoding_model, + max_token_count=max_tokens, + min_examples_required=min_examples_required, + ) + + reporter.info("Generating entity summarization prompt...") + entity_summarization_prompt = create_entity_summarization_prompt( + persona=persona, + language=language, + ) + + reporter.info("Generating community reporter role...") + community_reporter_role = await generate_community_reporter_role( + llm, domain=domain, persona=persona, docs=doc_list + ) + + reporter.info("Generating community summarization prompt...") + community_summarization_prompt = create_community_summarization_prompt( + persona=persona, + role=community_reporter_role, + report_rating_description=community_report_ranking, + language=language, + ) + + return ( + entity_extraction_prompt, + entity_summarization_prompt, + community_summarization_prompt, + ) diff --git a/graphrag/prompt_tune/cli.py b/graphrag/prompt_tune/cli.py index 5979a4a6ee..eb8ff6f49f 100644 --- a/graphrag/prompt_tune/cli.py +++ b/graphrag/prompt_tune/cli.py @@ -5,37 +5,25 @@ from pathlib import Path -from datashaper import NoopVerbCallbacks - -from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.llm import load_llm from graphrag.index.progress import PrintProgressReporter -from graphrag.index.progress.types import ProgressReporter -from graphrag.llm.types.llm_types import CompletionLLM -from graphrag.prompt_tune.generator import ( - MAX_TOKEN_COUNT, - 
create_community_summarization_prompt, - create_entity_extraction_prompt, - create_entity_summarization_prompt, - detect_language, - generate_community_report_rating, - generate_community_reporter_role, - generate_domain, - generate_entity_relationship_examples, - generate_entity_types, - generate_persona, -) +from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT from graphrag.prompt_tune.loader import ( MIN_CHUNK_SIZE, - load_docs_in_chunks, read_config_parameters, ) +from . import api +from .generator.community_report_summarization import COMMUNITY_SUMMARIZATION_FILENAME +from .generator.entity_extraction_prompt import ENTITY_EXTRACTION_FILENAME +from .generator.entity_summarization_prompt import ENTITY_SUMMARIZATION_FILENAME +from .types import DocSelectionType + async def prompt_tune( + config: str, root: str, domain: str, - select: str = "random", + selection_method: DocSelectionType = DocSelectionType.RANDOM, limit: int = 15, max_tokens: int = MAX_TOKEN_COUNT, chunk_size: int = MIN_CHUNK_SIZE, @@ -50,223 +38,51 @@ async def prompt_tune( Parameters ---------- + - config: The configuration file. - root: The root directory. - domain: The domain to map the input documents to. - - select: The chunk selection method. + - selection_method: The chunk selection method. - limit: The limit of chunks to load. - max_tokens: The maximum number of tokens to use on entity extraction prompts. - chunk_size: The chunk token size to use. + - language: The language to use for the prompts. - skip_entity_types: Skip generating entity types. - output: The output folder to store the prompts. - n_subset_max: The number of text chunks to embed when using auto selection method. - k: The number of documents to select when using auto selection method. + - min_examples_required: The minimum number of examples required for entity extraction prompts. """ reporter = PrintProgressReporter("") - config = read_config_parameters(root, reporter) - - await prompt_tune_with_config( - root, - config, - domain, - select, - limit, - max_tokens, - chunk_size, - language, - skip_entity_types, - output, - reporter, - n_subset_max, - k, - min_examples_required, - ) - - -async def prompt_tune_with_config( - root: str, - config: GraphRagConfig, - domain: str, - select: str = "random", - limit: int = 15, - max_tokens: int = MAX_TOKEN_COUNT, - chunk_size: int = MIN_CHUNK_SIZE, - language: str | None = None, - skip_entity_types: bool = False, - output: str = "prompts", - reporter: ProgressReporter | None = None, - n_subset_max: int = 300, - k: int = 15, - min_examples_required: int = 2, -): - """Prompt tune the model with a configuration. + graph_config = read_config_parameters(root, reporter, config) - Parameters - ---------- - - root: The root directory. - - config: The GraphRag configuration. - - domain: The domain to map the input documents to. - - select: The chunk selection method. - - limit: The limit of chunks to load. - - max_tokens: The maximum number of tokens to use on entity extraction prompts. - - chunk_size: The chunk token size to use for input text units. - - skip_entity_types: Skip generating entity types. - - output: The output folder to store the prompts. - - reporter: The progress reporter. - - n_subset_max: The number of text chunks to embed when using auto selection method. - - k: The number of documents to select when using auto selection method. 
- - Returns - ------- - - None - """ - if not reporter: - reporter = PrintProgressReporter("") - - output_path = Path(config.root_dir) / output - - doc_list = await load_docs_in_chunks( + prompts = await api.generate_indexing_prompts( + config=graph_config, root=root, - config=config, - limit=limit, - select_method=select, - reporter=reporter, chunk_size=chunk_size, + limit=limit, + selection_method=selection_method, + domain=domain, + language=language, + max_tokens=max_tokens, + skip_entity_types=skip_entity_types, + min_examples_required=min_examples_required, n_subset_max=n_subset_max, k=k, ) - # Create LLM from config - llm = load_llm( - "prompt_tuning", - config.llm.type, - NoopVerbCallbacks(), - None, - config.llm.model_dump(), - ) - - await generate_indexing_prompts( - llm, - config, - doc_list, - output_path, - reporter, - domain, - language, - max_tokens, - skip_entity_types, - min_examples_required, - ) - - -async def generate_indexing_prompts( - llm: CompletionLLM, - config: GraphRagConfig, - doc_list: list[str], - output_path: Path, - reporter: ProgressReporter, - domain: str | None = None, - language: str | None = None, - max_tokens: int = MAX_TOKEN_COUNT, - skip_entity_types: bool = False, - min_examples_required: int = 2, -): - """Generate indexing prompts. - - Parameters - ---------- - - llm: The LLM model to use. - - config: The GraphRag configuration. - - doc_list: The list of documents to use. - - output_path: The path to store the prompts. - - reporter: The progress reporter. - - domain: The domain to map the input documents to. - - max_tokens: The maximum number of tokens to use on entity extraction prompts - - skip_entity_types: Skip generating entity types. - - min_examples_required: The minimum number of examples required for entity extraction prompts. 
- """ - if not domain: - reporter.info("Generating domain...") - domain = await generate_domain(llm, doc_list) - reporter.info(f"Generated domain: {domain}") - - if not language: - reporter.info("Detecting language...") - language = await detect_language(llm, doc_list) - reporter.info(f"Detected language: {language}") - - reporter.info("Generating persona...") - persona = await generate_persona(llm, domain) - reporter.info(f"Generated persona: {persona}") - - reporter.info("Generating community report ranking description...") - community_report_ranking = await generate_community_report_rating( - llm, domain=domain, persona=persona, docs=doc_list - ) - reporter.info( - f"Generated community report ranking description: {community_report_ranking}" - ) - - entity_types = None - if not skip_entity_types: - reporter.info("Generating entity types") - entity_types = await generate_entity_types( - llm, - domain=domain, - persona=persona, - docs=doc_list, - json_mode=config.llm.model_supports_json or False, + output_path = Path(output) + if output_path: + reporter.info(f"Writing prompts to {output_path}") + output_path.mkdir(parents=True, exist_ok=True) + entity_extraction_prompt_path = output_path / ENTITY_EXTRACTION_FILENAME + entity_summarization_prompt_path = output_path / ENTITY_SUMMARIZATION_FILENAME + community_summarization_prompt_path = ( + output_path / COMMUNITY_SUMMARIZATION_FILENAME ) - reporter.info(f"Generated entity types: {entity_types}") - - reporter.info("Generating entity relationship examples...") - examples = await generate_entity_relationship_examples( - llm, - persona=persona, - entity_types=entity_types, - docs=doc_list, - language=language, - json_mode=False, # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine - ) - reporter.info("Done generating entity relationship examples") - - reporter.info("Generating entity extraction prompt...") - create_entity_extraction_prompt( - entity_types=entity_types, - docs=doc_list, - examples=examples, - language=language, - json_mode=False, # config.llm.model_supports_json should be used, but this prompts are used in non-json by the index engine - output_path=output_path, - encoding_model=config.encoding_model, - max_token_count=max_tokens, - min_examples_required=min_examples_required, - ) - reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}") - - reporter.info("Generating entity summarization prompt...") - create_entity_summarization_prompt( - persona=persona, - language=language, - output_path=output_path, - ) - reporter.info( - f"Generated entity summarization prompt, stored in folder {output_path}" - ) - - reporter.info("Generating community reporter role...") - community_reporter_role = await generate_community_reporter_role( - llm, domain=domain, persona=persona, docs=doc_list - ) - reporter.info(f"Generated community reporter role: {community_reporter_role}") - - reporter.info("Generating community summarization prompt...") - create_community_summarization_prompt( - persona=persona, - role=community_reporter_role, - report_rating_description=community_report_ranking, - language=language, - output_path=output_path, - ) - reporter.info( - f"Generated community summarization prompt, stored in folder {output_path}" - ) + # Write files to output path + with entity_extraction_prompt_path.open("wb") as file: + file.write(prompts[0].encode(encoding="utf-8", errors="strict")) + with entity_summarization_prompt_path.open("wb") as file: + 
file.write(prompts[1].encode(encoding="utf-8", errors="strict")) + with community_summarization_prompt_path.open("wb") as file: + file.write(prompts[2].encode(encoding="utf-8", errors="strict")) diff --git a/graphrag/prompt_tune/generator/entity_extraction_prompt.py b/graphrag/prompt_tune/generator/entity_extraction_prompt.py index 3b17dbab5d..b2192c0705 100644 --- a/graphrag/prompt_tune/generator/entity_extraction_prompt.py +++ b/graphrag/prompt_tune/generator/entity_extraction_prompt.py @@ -41,7 +41,7 @@ def create_entity_extraction_prompt( - encoding_model (str): The name of the model to use for token counting - max_token_count (int): The maximum number of tokens to use for the prompt - json_mode (bool): Whether to use JSON mode for the prompt. Default is False - - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None. + - output_path (Path | None): The path to write the prompt to. Default is None. - min_examples_required (int): The minimum number of examples required. Default is 2. Returns @@ -58,8 +58,8 @@ def create_entity_extraction_prompt( tokens_left = ( max_token_count - - num_tokens_from_string(prompt, model=encoding_model) - - num_tokens_from_string(entity_types, model=encoding_model) + - num_tokens_from_string(prompt, encoding_name=encoding_model) + - num_tokens_from_string(entity_types, encoding_name=encoding_model) if entity_types else 0 ) @@ -79,7 +79,9 @@ def create_entity_extraction_prompt( ) ) - example_tokens = num_tokens_from_string(example_formatted, model=encoding_model) + example_tokens = num_tokens_from_string( + example_formatted, encoding_name=encoding_model + ) # Ensure at least three examples are included if i >= min_examples_required and example_tokens > tokens_left: diff --git a/graphrag/prompt_tune/generator/entity_summarization_prompt.py b/graphrag/prompt_tune/generator/entity_summarization_prompt.py index 4ae5af77ec..736df830d6 100644 --- a/graphrag/prompt_tune/generator/entity_summarization_prompt.py +++ b/graphrag/prompt_tune/generator/entity_summarization_prompt.py @@ -15,13 +15,14 @@ def create_entity_summarization_prompt( language: str, output_path: Path | None = None, ) -> str: - """Create a prompt for entity summarization. If output_path is provided, write the prompt to a file. + """ + Create a prompt for entity summarization. Parameters ---------- - persona (str): The persona to use for the entity summarization prompt - language (str): The language to use for the entity summarization prompt - - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None. + - output_path (Path | None): The path to write the prompt to. Default is None. """ prompt = ENTITY_SUMMARIZATION_PROMPT.format(persona=persona, language=language) diff --git a/graphrag/prompt_tune/loader/config.py b/graphrag/prompt_tune/loader/config.py index 8994604f92..350feacd79 100644 --- a/graphrag/prompt_tune/loader/config.py +++ b/graphrag/prompt_tune/loader/config.py @@ -9,20 +9,38 @@ from graphrag.index.progress.types import ProgressReporter -def read_config_parameters(root: str, reporter: ProgressReporter): +def read_config_parameters( + root: str, reporter: ProgressReporter, config: str | None = None +): """Read the configuration parameters from the settings file or environment variables. Parameters ---------- - root: The root directory where the parameters are. - reporter: The progress reporter. 
+ - config: The path to the settings file. """ _root = Path(root) - settings_yaml = _root / "settings.yaml" + settings_yaml = ( + Path(config) + if config and Path(config).suffix in [".yaml", ".yml"] + else _root / "settings.yaml" + ) if not settings_yaml.exists(): settings_yaml = _root / "settings.yml" - settings_json = _root / "settings.json" + if settings_yaml.exists(): + reporter.info(f"Reading settings from {settings_yaml}") + with settings_yaml.open("rb") as file: + import yaml + + data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict")) + return create_graphrag_config(data, root) + settings_json = ( + Path(config) + if config and Path(config).suffix == ".json" + else _root / "settings.json" + ) if settings_yaml.exists(): reporter.info(f"Reading settings from {settings_yaml}") with settings_yaml.open("rb") as file: diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index 86c4a76040..0679990541 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -16,6 +16,7 @@ from graphrag.index.progress.types import ProgressReporter from graphrag.index.verbs import chunk from graphrag.llm.types.llm_types import EmbeddingLLM +from graphrag.prompt_tune.types import DocSelectionType MIN_CHUNK_OVERLAP = 0 MIN_CHUNK_SIZE = 200 @@ -50,7 +51,7 @@ def _sample_chunks_from_embeddings( async def load_docs_in_chunks( root: str, config: GraphRagConfig, - select_method: str, + select_method: DocSelectionType, limit: int, reporter: ProgressReporter, chunk_size: int = MIN_CHUNK_SIZE, @@ -85,11 +86,11 @@ async def load_docs_in_chunks( if limit <= 0 or limit > len(chunks_df): limit = len(chunks_df) - if select_method == "top": + if select_method == DocSelectionType.TOP: chunks_df = chunks_df[:limit] - elif select_method == "random": + elif select_method == DocSelectionType.RANDOM: chunks_df = chunks_df.sample(n=limit) - elif select_method == "auto": + elif select_method == DocSelectionType.AUTO: if k is None or k <= 0: msg = "k must be an integer > 0" raise ValueError(msg) diff --git a/graphrag/prompt_tune/template/entity_summarization.py b/graphrag/prompt_tune/template/entity_summarization.py index 60294a291b..7710ba0bdb 100644 --- a/graphrag/prompt_tune/template/entity_summarization.py +++ b/graphrag/prompt_tune/template/entity_summarization.py @@ -9,7 +9,7 @@ Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. Please concatenate all of these into a single, concise description in {language}. Make sure to include information collected from all the descriptions. If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary. -Make sure it is written in third person, and include the entity names so we the have full context. +Make sure it is written in third person, and include the entity names so we have the full context. Enrich it as much as you can with relevant information from the nearby text, this is very important. diff --git a/graphrag/prompt_tune/types.py b/graphrag/prompt_tune/types.py new file mode 100644 index 0000000000..1207d18767 --- /dev/null +++ b/graphrag/prompt_tune/types.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Types for prompt tuning.""" + +from enum import Enum + + +class DocSelectionType(Enum): + """The type of document selection to use.""" + + ALL = "all" + RANDOM = "random" + TOP = "top" + AUTO = "auto" + + def __str__(self): + """Return the string representation of the enum value.""" + return self.value diff --git a/graphrag/query/__main__.py b/graphrag/query/__main__.py index edf678fa44..19ad00d5c9 100644 --- a/graphrag/query/__main__.py +++ b/graphrag/query/__main__.py @@ -23,7 +23,10 @@ def __str__(self): if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + prog="python -m graphrag.query", + description="The graphrag query engine", + ) parser.add_argument( "--config", @@ -49,7 +52,7 @@ def __str__(self): parser.add_argument( "--method", - help="The method to run, one of: local or global", + help="The method to run", required=True, type=SearchType, choices=list(SearchType), ) parser.add_argument( "--community_level", - help="Community level in the Leiden community hierarchy from which we will load the community reports higher value means we use reports on smaller communities", + help="Community level in the Leiden community hierarchy from which we will load the community reports. Higher value means we use reports on smaller communities. Default: 2", type=int, default=2, ) parser.add_argument( "--response_type", - help="Free form text describing the response type and format, can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report", + help="Free form text describing the response type and format. It can be anything, e.g. Multiple Paragraphs, Single Paragraph, Single Sentence, List of 3-7 Points, Single Page, Multi-Page Report. Default: Multiple Paragraphs", type=str, default="Multiple Paragraphs", ) diff --git a/graphrag/query/api.py b/graphrag/query/api.py new file mode 100644 index 0000000000..8f6f82470a --- /dev/null +++ b/graphrag/query/api.py @@ -0,0 +1,192 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +""" +Query Engine API. + +This API provides access to the query engine of graphrag, allowing external applications +to hook into graphrag and run queries over a knowledge graph generated by graphrag. + +WARNING: This API is under development and may undergo changes in future releases. +Backwards compatibility is not guaranteed at this time. 
+""" + +from typing import Any + +import pandas as pd +from pydantic import validate_call + +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.index.progress.types import PrintProgressReporter +from graphrag.model.entity import Entity +from graphrag.vector_stores.lancedb import LanceDBVectorStore +from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType + +from .factories import get_global_search_engine, get_local_search_engine +from .indexer_adapters import ( + read_indexer_covariates, + read_indexer_entities, + read_indexer_relationships, + read_indexer_reports, + read_indexer_text_units, +) +from .input.loaders.dfs import store_entity_semantic_embeddings + +reporter = PrintProgressReporter("") + + +def __get_embedding_description_store( + entities: list[Entity], + vector_store_type: str = VectorStoreType.LanceDB, + config_args: dict | None = None, +): + """Get the embedding description store.""" + if not config_args: + config_args = {} + + collection_name = config_args.get( + "query_collection_name", "entity_description_embeddings" + ) + config_args.update({"collection_name": collection_name}) + description_embedding_store = VectorStoreFactory.get_vector_store( + vector_store_type=vector_store_type, kwargs=config_args + ) + + description_embedding_store.connect(**config_args) + + if config_args.get("overwrite", True): + # this step assumes the embeddings were originally stored in a file rather + # than a vector database + + # dump embeddings from the entities list to the description_embedding_store + store_entity_semantic_embeddings( + entities=entities, vectorstore=description_embedding_store + ) + else: + # load description embeddings to an in-memory lancedb vectorstore + # and connect to a remote db, specify url and port values. + description_embedding_store = LanceDBVectorStore( + collection_name=collection_name + ) + description_embedding_store.connect( + db_uri=config_args.get("db_uri", "./lancedb") + ) + + # load data from an existing table + description_embedding_store.document_collection = ( + description_embedding_store.db_connection.open_table( + description_embedding_store.collection_name + ) + ) + + return description_embedding_store + + +@validate_call(config={"arbitrary_types_allowed": True}) +async def global_search( + config: GraphRagConfig, + nodes: pd.DataFrame, + entities: pd.DataFrame, + community_reports: pd.DataFrame, + community_level: int, + response_type: str, + query: str, +) -> str | dict[str, Any] | list[dict[str, Any]]: + """Perform a global search. + + Parameters + ---------- + - config (GraphRagConfig): A graphrag configuration (from settings.yaml) + - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet) + - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet) + - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet) + - community_level (int): The community level to search at. + - response_type (str): The type of response to return. + - query (str): The user query to search for. + + Returns + ------- + TODO: Document the search response type and format. + + Raises + ------ + TODO: Document any exceptions to expect. 
+ """ + reports = read_indexer_reports(community_reports, nodes, community_level) + _entities = read_indexer_entities(nodes, entities, community_level) + search_engine = get_global_search_engine( + config, + reports=reports, + entities=_entities, + response_type=response_type, + ) + result = await search_engine.asearch(query=query) + reporter.success(f"Global Search Response: {result.response}") + return result.response + + +@validate_call(config={"arbitrary_types_allowed": True}) +async def local_search( + config: GraphRagConfig, + nodes: pd.DataFrame, + entities: pd.DataFrame, + community_reports: pd.DataFrame, + text_units: pd.DataFrame, + relationships: pd.DataFrame, + covariates: pd.DataFrame | None, + community_level: int, + response_type: str, + query: str, +) -> str | dict[str, Any] | list[dict[str, Any]]: + """Perform a local search. + + Parameters + ---------- + - config (GraphRagConfig): A graphrag configuration (from settings.yaml) + - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet) + - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet) + - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet) + - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet) + - relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet) + - covariates (pd.DataFrame): A DataFrame containing the final covariates (from create_final_covariates.parquet) + - community_level (int): The community level to search at. + - response_type (str): The response type to return. + - query (str): The user query to search for. + + Returns + ------- + TODO: Document the search response type and format. + + Raises + ------ + TODO: Document any exceptions to expect. 
+ """ + vector_store_args = ( + config.embeddings.vector_store if config.embeddings.vector_store else {} + ) + + reporter.info(f"Vector Store Args: {vector_store_args}") + vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) + + _entities = read_indexer_entities(nodes, entities, community_level) + description_embedding_store = __get_embedding_description_store( + entities=_entities, + vector_store_type=vector_store_type, + config_args=vector_store_args, + ) + _covariates = read_indexer_covariates(covariates) if covariates is not None else [] + + search_engine = get_local_search_engine( + config=config, + reports=read_indexer_reports(community_reports, nodes, community_level), + text_units=read_indexer_text_units(text_units), + entities=_entities, + relationships=read_indexer_relationships(relationships), + covariates={"claims": _covariates}, + description_embedding_store=description_embedding_store, + response_type=response_type, + ) + + result = await search_engine.asearch(query=query) + reporter.success(f"Local Search Response: {result.response}") + return result.response diff --git a/graphrag/query/cli.py b/graphrag/query/cli.py index 81efbb550b..16cfe0c97a 100644 --- a/graphrag/query/cli.py +++ b/graphrag/query/cli.py @@ -3,6 +3,7 @@ """Command line interface for the query module.""" +import asyncio import os from pathlib import Path from typing import cast @@ -14,72 +15,12 @@ create_graphrag_config, ) from graphrag.index.progress import PrintProgressReporter -from graphrag.model.entity import Entity -from graphrag.query.input.loaders.dfs import ( - store_entity_semantic_embeddings, -) -from graphrag.vector_stores import VectorStoreFactory, VectorStoreType -from graphrag.vector_stores.lancedb import LanceDBVectorStore -from .factories import get_global_search_engine, get_local_search_engine -from .indexer_adapters import ( - read_indexer_covariates, - read_indexer_entities, - read_indexer_relationships, - read_indexer_reports, - read_indexer_text_units, -) +from . import api reporter = PrintProgressReporter("") -def __get_embedding_description_store( - entities: list[Entity], - vector_store_type: str = VectorStoreType.LanceDB, - config_args: dict | None = None, -): - """Get the embedding description store.""" - if not config_args: - config_args = {} - - collection_name = config_args.get( - "query_collection_name", "entity_description_embeddings" - ) - config_args.update({"collection_name": collection_name}) - description_embedding_store = VectorStoreFactory.get_vector_store( - vector_store_type=vector_store_type, kwargs=config_args - ) - - description_embedding_store.connect(**config_args) - - if config_args.get("overwrite", True): - # this step assumps the embeddings where originally stored in a file rather - # than a vector database - - # dump embeddings from the entities list to the description_embedding_store - store_entity_semantic_embeddings( - entities=entities, vectorstore=description_embedding_store - ) - else: - # load description embeddings to an in-memory lancedb vectorstore - # to connect to a remote db, specify url and port values. 
- description_embedding_store = LanceDBVectorStore( - collection_name=collection_name - ) - description_embedding_store.connect( - db_uri=config_args.get("db_uri", "./lancedb") - ) - - # load data from an existing table - description_embedding_store.document_collection = ( - description_embedding_store.db_connection.open_table( - description_embedding_store.collection_name - ) - ) - - return description_embedding_store - - def run_global_search( config_dir: str | None, data_dir: str | None, @@ -88,7 +29,10 @@ def run_global_search( response_type: str, query: str, ): - """Run a global search with the given query.""" + """Perform a global search with a given query. + + Loads index files required for global search and calls the Query API. + """ data_dir, root_dir, config = _configure_paths_and_settings( data_dir, root_dir, config_dir ) @@ -104,22 +48,18 @@ def run_global_search( data_path / "create_final_community_reports.parquet" ) - reports = read_indexer_reports( - final_community_reports, final_nodes, community_level - ) - entities = read_indexer_entities(final_nodes, final_entities, community_level) - search_engine = get_global_search_engine( - config, - reports=reports, - entities=entities, - response_type=response_type, + return asyncio.run( + api.global_search( + config=config, + nodes=final_nodes, + entities=final_entities, + community_reports=final_community_reports, + community_level=community_level, + response_type=response_type, + query=query, + ) ) - result = search_engine.search(query=query) - - reporter.success(f"Global Search Response: {result.response}") - return result.response - def run_local_search( config_dir: str | None, @@ -129,7 +69,10 @@ def run_local_search( response_type: str, query: str, ): - """Run a local search with the given query.""" + """Perform a local search with a given query. + + Loads index files required for local search and calls the Query API. 
+ """ data_dir, root_dir, config = _configure_paths_and_settings( data_dir, root_dir, config_dir ) @@ -151,42 +94,22 @@ def run_local_search( else None ) - vector_store_args = ( - config.embeddings.vector_store if config.embeddings.vector_store else {} - ) - - reporter.info(f"Vector Store Args: {vector_store_args}") - vector_store_type = vector_store_args.get("type", VectorStoreType.LanceDB) - - entities = read_indexer_entities(final_nodes, final_entities, community_level) - description_embedding_store = __get_embedding_description_store( - entities=entities, - vector_store_type=vector_store_type, - config_args=vector_store_args, - ) - covariates = ( - read_indexer_covariates(final_covariates) - if final_covariates is not None - else [] - ) - - search_engine = get_local_search_engine( - config, - reports=read_indexer_reports( - final_community_reports, final_nodes, community_level - ), - text_units=read_indexer_text_units(final_text_units), - entities=entities, - relationships=read_indexer_relationships(final_relationships), - covariates={"claims": covariates}, - description_embedding_store=description_embedding_store, - response_type=response_type, + # call the Query API + return asyncio.run( + api.local_search( + config=config, + nodes=final_nodes, + entities=final_entities, + community_reports=final_community_reports, + text_units=final_text_units, + relationships=final_relationships, + covariates=final_covariates, + community_level=community_level, + response_type=response_type, + query=query, + ) ) - result = search_engine.search(query=query) - reporter.success(f"Local Search Response: {result.response}") - return result.response - def _configure_paths_and_settings( data_dir: str | None, diff --git a/graphrag/query/context_builder/community_context.py b/graphrag/query/context_builder/community_context.py index 398f8ac422..d344e2c06e 100644 --- a/graphrag/query/context_builder/community_context.py +++ b/graphrag/query/context_builder/community_context.py @@ -15,6 +15,10 @@ log = logging.getLogger(__name__) +NO_COMMUNITY_RECORDS_WARNING: str = ( + "Warning: No community records added when building community context." +) + def build_community_context( community_reports: list[CommunityReport], @@ -128,9 +132,9 @@ def _cut_batch() -> None: record_df = _convert_report_context_to_df( context_records=batch_records, header=header, - weight_column=community_weight_name - if entities and include_community_weight - else None, + weight_column=( + community_weight_name if entities and include_community_weight else None + ), rank_column=community_rank_name if include_community_rank else None, ) if len(record_df) == 0: @@ -163,9 +167,7 @@ def _cut_batch() -> None: _cut_batch() if len(all_context_records) == 0: - log.warning( - "Warning: No community records added when building community context." 
- ) + log.warning(NO_COMMUNITY_RECORDS_WARNING) return ([], {}) return all_context_text, { diff --git a/pyproject.toml b/pyproject.toml index 2c8d27db54..24ad11c9c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "graphrag" # Maintainers: do not change the version here manually, use ./scripts/release.sh -version = "0.2.2" +version = "0.3.0" description = "" authors = [ "Alonso Guevara Fernández ", diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py index b5118c6a50..c5aff3d977 100644 --- a/tests/smoke/test_fixtures.py +++ b/tests/smoke/test_fixtures.py @@ -16,6 +16,9 @@ import pytest from graphrag.index.storage.blob_pipeline_storage import BlobPipelineStorage +from graphrag.query.context_builder.community_context import ( + NO_COMMUNITY_RECORDS_WARNING, +) log = logging.getLogger(__name__) @@ -25,6 +28,8 @@ # cspell:disable-next-line well-known-key WELL_KNOWN_AZURITE_CONNECTION_STRING = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1" +KNOWN_WARNINGS = [NO_COMMUNITY_RECORDS_WARNING] + def _load_fixtures(): """Load all fixtures from the tests/data folder.""" @@ -294,6 +299,8 @@ def test_fixture( result.stderr if "No existing dataset at" not in result.stderr else "" ) - assert stderror == "", f"Query failed with error: {stderror}" + assert ( + stderror == "" or stderror.replace("\n", "") in KNOWN_WARNINGS + ), f"Query failed with error: {stderror}" assert result.stdout is not None, "Query returned no output" assert len(result.stdout) > 0, "Query returned empty output" diff --git a/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py b/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py index 9b66478876..4d6e36c306 100644 --- a/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py +++ b/tests/unit/indexing/graph/extractors/community_reports/test_sort_context.py @@ -1,204 +1,213 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License import math +import platform from graphrag.index.graph.extractors.community_reports import sort_context +from graphrag.query.llm.text_utils import num_tokens nan = math.nan - -def test_sort_context(): - context: list[dict] = [ - { +context: list[dict] = [ + { + "title": "ALI BABA", + "degree": 1, + "node_details": { + "human_readable_id": 26, "title": "ALI BABA", + "description": "A character from Scrooge's reading, representing a memory of his childhood imagination", "degree": 1, - "node_details": { - "human_readable_id": 26, - "title": "ALI BABA", - "description": "A character from Scrooge's reading, representing a memory of his childhood imagination", - "degree": 1, - }, - "edge_details": [ - nan, - { - "human_readable_id": 28, - "source": "SCROOGE", - "target": "ALI BABA", - "description": "Scrooge recalls Ali Baba as a fond memory from his childhood readings", - "rank": 32, - }, - ], - "claim_details": [nan], }, - { + "edge_details": [ + nan, + { + "human_readable_id": 28, + "source": "SCROOGE", + "target": "ALI BABA", + "description": "Scrooge recalls Ali Baba as a fond memory from his childhood readings", + "rank": 32, + }, + ], + "claim_details": [nan], + }, + { + "title": "BELLE", + "degree": 1, + "node_details": { + "human_readable_id": 31, "title": "BELLE", + "description": "A woman from Scrooge's past, reflecting on how Scrooge's pursuit of wealth changed him and led to the end of their relationship", "degree": 1, - "node_details": { - "human_readable_id": 31, - "title": "BELLE", - "description": "A woman from Scrooge's past, reflecting on how Scrooge's pursuit of wealth changed him and led to the end of their relationship", - "degree": 1, - }, - "edge_details": [ - nan, - { - "human_readable_id": 32, - "source": "SCROOGE", - "target": "BELLE", - "description": "Belle and Scrooge were once engaged, but their relationship ended due to Scrooge's growing obsession with wealth", - "rank": 32, - }, - ], - "claim_details": [nan], }, - { + "edge_details": [ + nan, + { + "human_readable_id": 32, + "source": "SCROOGE", + "target": "BELLE", + "description": "Belle and Scrooge were once engaged, but their relationship ended due to Scrooge's growing obsession with wealth", + "rank": 32, + }, + ], + "claim_details": [nan], + }, + { + "title": "CHRISTMAS", + "degree": 1, + "node_details": { + "human_readable_id": 17, "title": "CHRISTMAS", + "description": "A festive season that highlights the contrast between abundance and want, joy and misery in the story", "degree": 1, - "node_details": { - "human_readable_id": 17, - "title": "CHRISTMAS", - "description": "A festive season that highlights the contrast between abundance and want, joy and misery in the story", - "degree": 1, - }, - "edge_details": [ - nan, - { - "human_readable_id": 23, - "source": "SCROOGE", - "target": "CHRISTMAS", - "description": "Scrooge's disdain for Christmas is a central theme, highlighting his miserliness and lack of compassion", - "rank": 32, - }, - ], - "claim_details": [nan], }, - { + "edge_details": [ + nan, + { + "human_readable_id": 23, + "source": "SCROOGE", + "target": "CHRISTMAS", + "description": "Scrooge's disdain for Christmas is a central theme, highlighting his miserliness and lack of compassion", + "rank": 32, + }, + ], + "claim_details": [nan], + }, + { + "title": "CHRISTMAS DAY", + "degree": 1, + "node_details": { + "human_readable_id": 57, "title": "CHRISTMAS DAY", + "description": "The day Scrooge realizes he hasn't missed the opportunity to celebrate and spread 
joy", "degree": 1, - "node_details": { - "human_readable_id": 57, - "title": "CHRISTMAS DAY", - "description": "The day Scrooge realizes he hasn't missed the opportunity to celebrate and spread joy", - "degree": 1, - }, - "edge_details": [ - nan, - { - "human_readable_id": 46, - "source": "SCROOGE", - "target": "CHRISTMAS DAY", - "description": "Scrooge wakes up on Christmas Day with a changed heart, ready to celebrate and spread happiness", - "rank": 32, - }, - ], - "claim_details": [nan], }, - { + "edge_details": [ + nan, + { + "human_readable_id": 46, + "source": "SCROOGE", + "target": "CHRISTMAS DAY", + "description": "Scrooge wakes up on Christmas Day with a changed heart, ready to celebrate and spread happiness", + "rank": 32, + }, + ], + "claim_details": [nan], + }, + { + "title": "DUTCH MERCHANT", + "degree": 1, + "node_details": { + "human_readable_id": 19, "title": "DUTCH MERCHANT", + "description": "A historical figure mentioned as having built the fireplace in Scrooge's home, adorned with tiles illustrating the Scriptures", "degree": 1, - "node_details": { - "human_readable_id": 19, - "title": "DUTCH MERCHANT", - "description": "A historical figure mentioned as having built the fireplace in Scrooge's home, adorned with tiles illustrating the Scriptures", - "degree": 1, - }, - "edge_details": [ - nan, - { - "human_readable_id": 25, - "source": "SCROOGE", - "target": "DUTCH MERCHANT", - "description": "Scrooge's fireplace, built by the Dutch Merchant, serves as a focal point in his room where he encounters Marley's Ghost", - "rank": 32, - }, - ], - "claim_details": [nan], }, - { + "edge_details": [ + nan, + { + "human_readable_id": 25, + "source": "SCROOGE", + "target": "DUTCH MERCHANT", + "description": "Scrooge's fireplace, built by the Dutch Merchant, serves as a focal point in his room where he encounters Marley's Ghost", + "rank": 32, + }, + ], + "claim_details": [nan], + }, + { + "title": "FAN", + "degree": 1, + "node_details": { + "human_readable_id": 27, "title": "FAN", + "description": "Scrooge's sister, who comes to bring him home from school for Christmas, showing a loving family relationship", "degree": 1, - "node_details": { - "human_readable_id": 27, - "title": "FAN", - "description": "Scrooge's sister, who comes to bring him home from school for Christmas, showing a loving family relationship", - "degree": 1, - }, - "edge_details": [ - nan, - { - "human_readable_id": 29, - "source": "SCROOGE", - "target": "FAN", - "description": "Fan is Scrooge's sister, who shows love and care by bringing him home for Christmas", - "rank": 32, - }, - ], - "claim_details": [nan], }, - { + "edge_details": [ + nan, + { + "human_readable_id": 29, + "source": "SCROOGE", + "target": "FAN", + "description": "Fan is Scrooge's sister, who shows love and care by bringing him home for Christmas", + "rank": 32, + }, + ], + "claim_details": [nan], + }, + { + "title": "FRED", + "degree": 1, + "node_details": { + "human_readable_id": 58, "title": "FRED", + "description": "Scrooge's nephew, who invites Scrooge to Christmas dinner, symbolizing family reconciliation", "degree": 1, - "node_details": { - "human_readable_id": 58, - "title": "FRED", - "description": "Scrooge's nephew, who invites Scrooge to Christmas dinner, symbolizing family reconciliation", - "degree": 1, - }, - "edge_details": [ - nan, - { - "human_readable_id": 47, - "source": "SCROOGE", - "target": "FRED", - "description": "Scrooge accepts Fred's invitation to Christmas dinner, marking a significant step in repairing their 
relationship", - "rank": 32, - }, - ], - "claim_details": [nan], }, - { + "edge_details": [ + nan, + { + "human_readable_id": 47, + "source": "SCROOGE", + "target": "FRED", + "description": "Scrooge accepts Fred's invitation to Christmas dinner, marking a significant step in repairing their relationship", + "rank": 32, + }, + ], + "claim_details": [nan], + }, + { + "title": "GENTLEMAN", + "degree": 1, + "node_details": { + "human_readable_id": 15, "title": "GENTLEMAN", + "description": "Represents charitable efforts to provide for the poor during the Christmas season", "degree": 1, - "node_details": { - "human_readable_id": 15, - "title": "GENTLEMAN", - "description": "Represents charitable efforts to provide for the poor during the Christmas season", - "degree": 1, - }, - "edge_details": [ - nan, - { - "human_readable_id": 21, - "source": "SCROOGE", - "target": "GENTLEMAN", - "description": "The gentleman approaches Scrooge to solicit donations for the poor, which Scrooge rebuffs", - "rank": 32, - }, - ], - "claim_details": [nan], }, - { + "edge_details": [ + nan, + { + "human_readable_id": 21, + "source": "SCROOGE", + "target": "GENTLEMAN", + "description": "The gentleman approaches Scrooge to solicit donations for the poor, which Scrooge rebuffs", + "rank": 32, + }, + ], + "claim_details": [nan], + }, + { + "title": "GHOST", + "degree": 1, + "node_details": { + "human_readable_id": 25, "title": "GHOST", + "description": "The Ghost is a spectral entity that plays a crucial role in guiding Scrooge through an introspective journey in Charles Dickens' classic tale. This spirit, likely one of the Christmas spirits, takes Scrooge on a transformative voyage through his past memories, the realities of his present, and the potential outcomes of his future. The purpose of this journey is to make Scrooge reflect deeply on his life, encouraging a profound understanding of the joy and meaning of Christmas. By showing Scrooge scenes from his life, including the potential fate of Tiny Tim, the Ghost rebukes Scrooge for his lack of compassion, ultimately aiming to instill in him a sense of responsibility and empathy towards others. Through this experience, the Ghost seeks to enlighten Scrooge, urging him to change his ways for the better.", "degree": 1, - "node_details": { - "human_readable_id": 25, - "title": "GHOST", - "description": "The Ghost is a spectral entity that plays a crucial role in guiding Scrooge through an introspective journey in Charles Dickens' classic tale. This spirit, likely one of the Christmas spirits, takes Scrooge on a transformative voyage through his past memories, the realities of his present, and the potential outcomes of his future. The purpose of this journey is to make Scrooge reflect deeply on his life, encouraging a profound understanding of the joy and meaning of Christmas. By showing Scrooge scenes from his life, including the potential fate of Tiny Tim, the Ghost rebukes Scrooge for his lack of compassion, ultimately aiming to instill in him a sense of responsibility and empathy towards others. Through this experience, the Ghost seeks to enlighten Scrooge, urging him to change his ways for the better.", - "degree": 1, - }, - "edge_details": [ - nan, - { - "human_readable_id": 27, - "source": "SCROOGE", - "target": "GHOST", - "description": "The Ghost is taking Scrooge on a transformative journey by showing him scenes from his past, aiming to make him reflect on his life choices and their consequences. 
This spectral guide is not only focusing on Scrooge's personal history but also emphasizing the importance of Christmas and the need for a change in perspective. Through these vivid reenactments, the Ghost highlights the error of Scrooge's ways and the significant impact his actions have on others, including Tiny Tim. This experience is designed to enlighten Scrooge, encouraging him to reconsider his approach to life and the people around him.", - "rank": 32, - }, - ], - "claim_details": [nan], - }, - ] + "edge_details": [ + nan, + { + "human_readable_id": 27, + "source": "SCROOGE", + "target": "GHOST", + "description": "The Ghost is taking Scrooge on a transformative journey by showing him scenes from his past, aiming to make him reflect on his life choices and their consequences. This spectral guide is not only focusing on Scrooge's personal history but also emphasizing the importance of Christmas and the need for a change in perspective. Through these vivid reenactments, the Ghost highlights the error of Scrooge's ways and the significant impact his actions have on others, including Tiny Tim. This experience is designed to enlighten Scrooge, encouraging him to reconsider his approach to life and the people around him.", + "rank": 32, + }, + ], + "claim_details": [nan], + }, +] + +def test_sort_context(): ctx = sort_context(context) + assert num_tokens(ctx) == (827 if platform.system() == "Windows" else 826) + assert ctx is not None + + +def test_sort_context_max_tokens(): + ctx = sort_context(context, max_tokens=800) assert ctx is not None + assert num_tokens(ctx) <= 800
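The new auto templating API introduced above can be driven end to end from Python. A minimal sketch, assuming a project root whose settings can be loaded with read_config_parameters; the "./my_project" path is hypothetical:

```python
# Minimal usage sketch of graphrag.prompt_tune.api.generate_indexing_prompts.
# "./my_project" (and its settings.yaml) is a hypothetical project root.
import asyncio

from graphrag.index.progress import PrintProgressReporter
from graphrag.prompt_tune import api
from graphrag.prompt_tune.loader import read_config_parameters
from graphrag.prompt_tune.types import DocSelectionType


async def main() -> None:
    reporter = PrintProgressReporter("")
    config = read_config_parameters("./my_project", reporter)

    # Returns a tuple: (entity extraction, entity summarization,
    # community summarization) prompts.
    prompts = await api.generate_indexing_prompts(
        config=config,
        root="./my_project",
        selection_method=DocSelectionType.RANDOM,
        limit=15,
    )
    for prompt in prompts:
        print(prompt)


if __name__ == "__main__":
    asyncio.run(main())
```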
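The query engine API can be called the same way once index artifacts exist. A minimal sketch of global_search, assuming the parquet files named in its docstring live under a hypothetical "output/" directory and that a GraphRagConfig has been built elsewhere (e.g. via create_graphrag_config); local_search additionally takes text_units, relationships, and covariates:

```python
# Minimal usage sketch of graphrag.query.api.global_search. The "output/"
# parquet paths are hypothetical; the filenames mirror those loaded by
# run_global_search in graphrag/query/cli.py.
import asyncio

import pandas as pd

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.query import api


async def ask(config: GraphRagConfig, query: str):
    # Load the index outputs required by global search.
    nodes = pd.read_parquet("output/create_final_nodes.parquet")
    entities = pd.read_parquet("output/create_final_entities.parquet")
    reports = pd.read_parquet("output/create_final_community_reports.parquet")

    return await api.global_search(
        config=config,
        nodes=nodes,
        entities=entities,
        community_reports=reports,
        community_level=2,
        response_type="Multiple Paragraphs",
        query=query,
    )


def run(config: GraphRagConfig) -> None:
    print(asyncio.run(ask(config, "What are the top themes in this dataset?")))
```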