diff --git a/libs/community/langchain_community/document_loaders/needle.py b/libs/community/langchain_community/document_loaders/needle.py index 4bad8f67685ba..94023826bb8b3 100644 --- a/libs/community/langchain_community/document_loaders/needle.py +++ b/libs/community/langchain_community/document_loaders/needle.py @@ -93,6 +93,7 @@ def add_files(self, files: dict) -> None: ValueError: If the collection is not properly initialized. """ self._get_collection() + assert self.client is not None, "NeedleClient must be initialized." files_to_add = [] for name, url in files.items(): @@ -113,6 +114,7 @@ def _fetch_documents(self) -> List[Document]: ValueError: If the collection is not properly initialized. """ self._get_collection() + assert self.client is not None, "NeedleClient must be initialized." files = self.client.collections.files.list(self.collection_id) docs = [] @@ -129,6 +131,7 @@ def _fetch_documents(self) -> List[Document]: docs.append(doc) return docs + def load(self) -> List[Document]: """ Loads all documents from the Needle collection. diff --git a/libs/community/langchain_community/retrievers/needle.py b/libs/community/langchain_community/retrievers/needle.py index 190305b88f876..6c31cad4ab3de 100644 --- a/libs/community/langchain_community/retrievers/needle.py +++ b/libs/community/langchain_community/retrievers/needle.py @@ -66,6 +66,9 @@ def _search_collection(self, query: str) -> List[Document]: List[Document]: A list of documents matching the search query. """ self._initialize_client() + if self.client is None: + raise ValueError("Provide a valid API key.") + results = self.client.collections.search( collection_id=self.collection_id, text=query ) diff --git a/libs/community/tests/unit_tests/document_loaders/test_needle.py b/libs/community/tests/unit_tests/document_loaders/test_needle.py index e6bd56d0f1f0b..ac3888d6fd00e 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_needle.py +++ b/libs/community/tests/unit_tests/document_loaders/test_needle.py @@ -36,9 +36,9 @@ def list(self, collection_id: str): ] -# Need to pass real API key and collection ID to test this function, otherwise fails -@pytest.mark.usefixtures("socket_enabled") +@pytest.mark.requires("needle-python") def test_add_and_fetch_files(mocker: MockerFixture): + """Test adding and fetching files using the NeedleLoader with a mock NeedleClient.""" # Mock the NeedleClient to use the mock implementation mocker.patch("needle.v1.NeedleClient", new=MockNeedleClient) diff --git a/libs/community/tests/unit_tests/retrievers/test_needle.py b/libs/community/tests/unit_tests/retrievers/test_needle.py index 1fefd66a67402..2d505f1f5ee24 100644 --- a/libs/community/tests/unit_tests/retrievers/test_needle.py +++ b/libs/community/tests/unit_tests/retrievers/test_needle.py @@ -24,6 +24,7 @@ def search(self, collection_id: str, text: str): ] +@pytest.mark.requires("needle-python") def test_needle_retriever_initialization() -> None: """Test that the NeedleRetriever is initialized correctly.""" retriever = NeedleRetriever( @@ -35,7 +36,7 @@ def test_needle_retriever_initialization() -> None: assert retriever.collection_id == "mock_collection_id" -@pytest.mark.usefixtures("socket_enabled") +@pytest.mark.requires("needle-python") def test_get_relevant_documents(mocker: MockerFixture) -> None: """Test that the retriever correctly fetches documents.""" # Patch NeedleClient with the mock