Skip to content

Commit

Permalink
code opt
Browse files Browse the repository at this point in the history
  • Loading branch information
HuiDBK committed Dec 19, 2024
1 parent 30b305d commit e3bf81a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
21 changes: 12 additions & 9 deletions metagpt/rag/factories/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
MilvusRetrieverConfig,
Neo4jPGRetrieverConfig,
)
from metagpt.utils.async_helper import NestAsyncio


def get_or_build_index(build_index_func):
Expand Down Expand Up @@ -127,13 +126,7 @@ def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -
return ElasticsearchRetriever(**config.model_dump())

def _create_neo4j_pg_retriever(self, config: Neo4jPGRetrieverConfig, **kwargs) -> PGRetriever:
NestAsyncio.apply_once()
graph_store = Neo4jPropertyGraphStore(**config.model_dump(exclude={"index", "similarity_top_k"}))
graph_index = PropertyGraphIndex(
nodes=self._extract_nodes(**kwargs),
property_graph_store=graph_store,
embed_model=self._extract_embed_model(**kwargs),
)
graph_index = self._build_neo4j_pg_index(config, **kwargs)
return graph_index.as_retriever(**config.model_dump())

def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
Expand All @@ -154,7 +147,6 @@ def _build_default_vector_index(self, **kwargs) -> VectorStoreIndex:

def _build_default_pg_index(self, **kwargs):
# build default PropertyGraphIndex
NestAsyncio.apply_once()
pg_index = PropertyGraphIndex(
nodes=self._extract_nodes(**kwargs),
embed_model=self._extract_embed_model(**kwargs),
Expand Down Expand Up @@ -189,6 +181,17 @@ def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> Vec

return self._build_index_from_vector_store(config, vector_store, **kwargs)

@get_or_build_index
def _build_neo4j_pg_index(self, config: Neo4jPGRetrieverConfig, **kwargs) -> PropertyGraphIndex:
graph_store = Neo4jPropertyGraphStore(**config.store_config.model_dump())
graph_index = PropertyGraphIndex(
nodes=self._extract_nodes(**kwargs),
property_graph_store=graph_store,
embed_model=self._extract_embed_model(**kwargs),
**config.model_dump(),
)
return graph_index

def _build_index_from_vector_store(
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
Expand Down
11 changes: 9 additions & 2 deletions metagpt/rag/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.schema import TextNode, TransformComponent
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator

Expand Down Expand Up @@ -102,13 +102,20 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
)


class Neo4jPGRetrieverConfig(IndexRetrieverConfig):
class Neo4jPGRetrieverStoreConfig(BaseModel):
username: str = Field(default="neo4j", description="The username for neo4j.")
password: str = Field(default="<password>", description="The password for neo4j.")
url: str = Field(default="bolt://localhost:7687", description="The neo4j server to save data.")
database: str = Field(default="neo4j", description="The database to save data.")


class Neo4jPGRetrieverConfig(IndexRetrieverConfig):
store_config: Neo4jPGRetrieverStoreConfig = Field(
default=Neo4jPGRetrieverStoreConfig(), description="Neo4jPGRetrieverStoreConfig"
)
kg_extractors: Optional[List[TransformComponent]] = Field(default=None, description="property graph extractors.")


class ElasticsearchStoreConfig(BaseModel):
index_name: str = Field(default="metagpt", description="Name of the Elasticsearch index.")
es_url: str = Field(default=None, description="Elasticsearch URL.")
Expand Down

0 comments on commit e3bf81a

Please sign in to comment.