diff --git a/metagpt/rag/factories/retriever.py b/metagpt/rag/factories/retriever.py index dd1a5726a..9582a589d 100644 --- a/metagpt/rag/factories/retriever.py +++ b/metagpt/rag/factories/retriever.py @@ -33,7 +33,6 @@ MilvusRetrieverConfig, Neo4jPGRetrieverConfig, ) -from metagpt.utils.async_helper import NestAsyncio def get_or_build_index(build_index_func): @@ -127,13 +126,7 @@ def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) - return ElasticsearchRetriever(**config.model_dump()) def _create_neo4j_pg_retriever(self, config: Neo4jPGRetrieverConfig, **kwargs) -> PGRetriever: - NestAsyncio.apply_once() - graph_store = Neo4jPropertyGraphStore(**config.model_dump(exclude={"index", "similarity_top_k"})) - graph_index = PropertyGraphIndex( - nodes=self._extract_nodes(**kwargs), - property_graph_store=graph_store, - embed_model=self._extract_embed_model(**kwargs), - ) + graph_index = self._build_neo4j_pg_index(config, **kwargs) return graph_index.as_retriever(**config.model_dump()) def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex: @@ -154,7 +147,6 @@ def _build_default_vector_index(self, **kwargs) -> VectorStoreIndex: def _build_default_pg_index(self, **kwargs): # build default PropertyGraphIndex - NestAsyncio.apply_once() pg_index = PropertyGraphIndex( nodes=self._extract_nodes(**kwargs), embed_model=self._extract_embed_model(**kwargs), @@ -189,6 +181,17 @@ def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> Vec return self._build_index_from_vector_store(config, vector_store, **kwargs) + @get_or_build_index + def _build_neo4j_pg_index(self, config: Neo4jPGRetrieverConfig, **kwargs) -> PropertyGraphIndex: + graph_store = Neo4jPropertyGraphStore(**config.store_config.model_dump()) + graph_index = PropertyGraphIndex( + nodes=self._extract_nodes(**kwargs), + property_graph_store=graph_store, + embed_model=self._extract_embed_model(**kwargs), + **config.model_dump(), + ) + return graph_index + def _build_index_from_vector_store( self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs ) -> VectorStoreIndex: diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index a87ceecd5..a1f93bc02 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -6,7 +6,7 @@ from chromadb.api.types import CollectionMetadata from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex -from llama_index.core.schema import TextNode +from llama_index.core.schema import TextNode, TransformComponent from llama_index.core.vector_stores.types import VectorStoreQueryMode from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator @@ -102,13 +102,20 @@ class ChromaRetrieverConfig(IndexRetrieverConfig): ) -class Neo4jPGRetrieverConfig(IndexRetrieverConfig): +class Neo4jPGRetrieverStoreConfig(BaseModel): username: str = Field(default="neo4j", description="The username for neo4j.") password: str = Field(default="", description="The password for neo4j.") url: str = Field(default="bolt://localhost:7687", description="The neo4j server to save data.") database: str = Field(default="neo4j", description="The database to save data.") +class Neo4jPGRetrieverConfig(IndexRetrieverConfig): + store_config: Neo4jPGRetrieverStoreConfig = Field( + default=Neo4jPGRetrieverStoreConfig(), description="Neo4jPGRetrieverStoreConfig" + ) + kg_extractors: Optional[List[TransformComponent]] = Field(default=None, description="property graph extractors.") + + class ElasticsearchStoreConfig(BaseModel): index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.") es_url: str = Field(default=None, description="Elasticsearch URL.")