diff --git a/libs/community/langchain_community/chat_models/perplexity.py b/libs/community/langchain_community/chat_models/perplexity.py index f5a58def0e9d2..d2decb38df93b 100644 --- a/libs/community/langchain_community/chat_models/perplexity.py +++ b/libs/community/langchain_community/chat_models/perplexity.py @@ -223,15 +223,21 @@ def _stream( stream_resp = self.client.chat.completions.create( messages=message_dicts, stream=True, **params ) + first_chunk = True for chunk in stream_resp: if not isinstance(chunk, dict): chunk = chunk.dict() if len(chunk["choices"]) == 0: continue choice = chunk["choices"][0] + citations = chunk.get("citations", []) + chunk = self._convert_delta_to_message_chunk( choice["delta"], default_chunk_class ) + if first_chunk: + chunk.additional_kwargs |= {"citations": citations} + first_chunk = False finish_reason = choice.get("finish_reason") generation_info = ( dict(finish_reason=finish_reason) if finish_reason is not None else None diff --git a/libs/community/tests/unit_tests/chat_models/test_perplexity.py b/libs/community/tests/unit_tests/chat_models/test_perplexity.py index 024274f7e84bc..d4991de422a36 100644 --- a/libs/community/tests/unit_tests/chat_models/test_perplexity.py +++ b/libs/community/tests/unit_tests/chat_models/test_perplexity.py @@ -1,8 +1,12 @@ """Test Perplexity Chat API wrapper.""" import os +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock import pytest +from langchain_core.messages import AIMessageChunk, BaseMessageChunk +from pytest_mock import MockerFixture from langchain_community.chat_models import ChatPerplexity @@ -40,3 +44,58 @@ def test_perplexity_initialization() -> None: ]: assert model.request_timeout == 1 assert model.pplx_api_key == "test" + + +@pytest.mark.requires("openai") +def test_perplexity_stream_includes_citations(mocker: MockerFixture) -> None: + """Test that the stream method includes citations in the additional_kwargs.""" + llm = ChatPerplexity( + model="test", + timeout=30, + verbose=True, + ) + mock_chunk_0 = { + "choices": [ + { + "delta": { + "content": "Hello ", + }, + "finish_reason": None, + } + ], + "citations": ["example.com", "example2.com"], + } + mock_chunk_1 = { + "choices": [ + { + "delta": { + "content": "Perplexity", + }, + "finish_reason": None, + } + ], + "citations": ["example.com", "example2.com"], + } + mock_chunks: List[Dict[str, Any]] = [mock_chunk_0, mock_chunk_1] + mock_stream = MagicMock() + mock_stream.__iter__.return_value = mock_chunks + patcher = mocker.patch.object( + llm.client.chat.completions, "create", return_value=mock_stream + ) + stream = llm.stream("Hello langchain") + full: Optional[BaseMessageChunk] = None + for i, chunk in enumerate(stream): + full = chunk if full is None else full + chunk + assert chunk.content == mock_chunks[i]["choices"][0]["delta"]["content"] + if i == 0: + assert chunk.additional_kwargs["citations"] == [ + "example.com", + "example2.com", + ] + else: + assert "citations" not in chunk.additional_kwargs + assert isinstance(full, AIMessageChunk) + assert full.content == "Hello Perplexity" + assert full.additional_kwargs == {"citations": ["example.com", "example2.com"]} + + patcher.assert_called_once()