diff --git a/libs/community/langchain_community/chat_models/sparkllm.py b/libs/community/langchain_community/chat_models/sparkllm.py
index b0e207b126218..d8b7c7dc0a8d7 100644
--- a/libs/community/langchain_community/chat_models/sparkllm.py
+++ b/libs/community/langchain_community/chat_models/sparkllm.py
@@ -92,34 +92,116 @@ def _convert_delta_to_message_chunk(
 
 
 class ChatSparkLLM(BaseChatModel):
-    """iFlyTek Spark large language model.
+    """IFlyTek Spark chat model integration.
+
+    Setup:
+        To use, you should have the environment variables ``IFLYTEK_SPARK_API_KEY``,
+        ``IFLYTEK_SPARK_API_SECRET`` and ``IFLYTEK_SPARK_APP_ID`` set.
+
+    Key init args — completion params:
+        model: Optional[str]
+            Name of IFLYTEK SPARK model to use.
+        temperature: Optional[float]
+            Sampling temperature.
+        top_k: Optional[int]
+            Number of candidate tokens to sample from (top-k sampling).
+        streaming: Optional[bool]
+            Whether to stream the results or not.
+
+    Key init args — client params:
+        api_key: Optional[str]
+            IFLYTEK SPARK API KEY. If not passed in, will be read from env var IFLYTEK_SPARK_API_KEY.
+        api_secret: Optional[str]
+            IFLYTEK SPARK API SECRET. If not passed in, will be read from env var IFLYTEK_SPARK_API_SECRET.
+        api_url: Optional[str]
+            Base URL for API requests.
+        timeout: Optional[int]
+            Timeout for requests.
+
+    See full list of supported init args and their descriptions in the params section.
+
+    Instantiate:
+        .. code-block:: python
+
+            from langchain_community.chat_models import ChatSparkLLM
 
-    To use, you should pass `app_id`, `api_key`, `api_secret`
-    as a named parameter to the constructor OR set environment
-    variables ``IFLYTEK_SPARK_APP_ID``, ``IFLYTEK_SPARK_API_KEY`` and
-    ``IFLYTEK_SPARK_API_SECRET``
+            chat = ChatSparkLLM(
+                app_id=app_id,
+                api_key=api_key,
+                api_secret=api_secret,
+                model='Spark4.0 Ultra',
+                # temperature=...,
+                # other params...
+            )
 
-    Example:
+    Invoke:
         .. code-block:: python
 
-        client = ChatSparkLLM(
-            spark_app_id="",
-            spark_api_key="",
-            spark_api_secret=""
-        )
+            messages = [
+                # "You are a professional translator; translate the
+                # user's Chinese into English."
+                ("system", "你是一名专业的翻译家,可以将用户的中文翻译为英文。"),
+                # "I like programming."
+                ("human", "我喜欢编程。"),
+            ]
+            chat.invoke(messages)
 
-    Extra infos:
-    1. Get app_id, api_key, api_secret from the iFlyTek Open Platform Console:
-        https://console.xfyun.cn/services/bm35
-    2. By default, iFlyTek Spark LLM V3.5 is invoked.
-        If you need to invoke other versions, please configure the corresponding
-        parameters(spark_api_url and spark_llm_domain) according to the document:
-        https://www.xfyun.cn/doc/spark/Web.html
-    3. It is necessary to ensure that the app_id used has a license for
-        the corresponding model version.
-    4. If you encounter problems during use, try getting help at:
-        https://console.xfyun.cn/workorder/commit
-    """
+        .. code-block:: python
+
+            AIMessage(
+                content='I like programming.',
+                response_metadata={
+                    'token_usage': {
+                        'question_tokens': 3,
+                        'prompt_tokens': 16,
+                        'completion_tokens': 4,
+                        'total_tokens': 20
+                    }
+                },
+                id='run-af8b3531-7bf7-47f0-bfe8-9262cb2a9d47-0'
+            )
+
+    Stream:
+        .. code-block:: python
+
+            for chunk in chat.stream(messages):
+                print(chunk)
+
+        .. code-block:: python
+
+            content='I' id='run-fdbb57c2-2d32-4516-b894-6c5a67605d83'
+            content=' like programming' id='run-fdbb57c2-2d32-4516-b894-6c5a67605d83'
+            content='.' id='run-fdbb57c2-2d32-4516-b894-6c5a67605d83'
+
+        .. code-block:: python
+
+            stream = chat.stream(messages)
+            full = next(stream)
+            for chunk in stream:
+                full += chunk
+            full
+
+        .. code-block:: python
+
+            AIMessageChunk(
+                content='I like programming.',
+                id='run-aca2fa82-c2e4-4835-b7e2-865ddd3c46cb'
+            )
+
+    Response metadata:
+        .. code-block:: python
+
+            ai_msg = chat.invoke(messages)
+            ai_msg.response_metadata
+
+        .. code-block:: python
+
+            {
+                'token_usage': {
+                    'question_tokens': 3,
+                    'prompt_tokens': 16,
+                    'completion_tokens': 4,
+                    'total_tokens': 20
+                }
+            }
+
+    """  # noqa: E501
 
     @classmethod
     def is_lc_serializable(cls) -> bool:
@@ -257,7 +339,7 @@ def _stream(
             [_convert_message_to_dict(m) for m in messages],
             self.spark_user_id,
             self.model_kwargs,
-            self.streaming,
+            streaming=True,
         )
         for content in self.client.subscribe(timeout=self.request_timeout):
             if "data" not in content:
@@ -274,9 +356,10 @@ def _generate(
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        if self.streaming:
+        if stream or self.streaming:
             stream_iter = self._stream(
                 messages=messages, stop=stop, run_manager=run_manager, **kwargs
             )
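The two hunks above are the behavioral core of this change: `_stream()` now always opens the underlying client in streaming mode instead of forwarding `self.streaming`, and `_generate()` gains a per-call `stream` override that takes the streaming branch even on a non-streaming instance. A minimal sketch of what this enables, assuming `IFLYTEK_SPARK_APP_ID`, `IFLYTEK_SPARK_API_KEY` and `IFLYTEK_SPARK_API_SECRET` are set in the environment:

.. code-block:: python

    from langchain_community.chat_models import ChatSparkLLM

    # streaming defaults to False; before this change _stream() passed
    # that flag through to the client, so .stream() on a non-streaming
    # instance did not actually stream.
    chat = ChatSparkLLM()

    # With streaming=True hard-coded in _stream(), chunked output now
    # arrives regardless of how the instance was constructed.
    for chunk in chat.stream("Hello!"):
        print(chunk.content, end="", flush=True)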
diff --git a/libs/community/tests/integration_tests/chat_models/test_sparkllm.py b/libs/community/tests/integration_tests/chat_models/test_sparkllm.py
index 848dc487bb8d9..ae94a8a3e600a 100644
--- a/libs/community/tests/integration_tests/chat_models/test_sparkllm.py
+++ b/libs/community/tests/integration_tests/chat_models/test_sparkllm.py
@@ -53,3 +53,10 @@ def test_chat_spark_llm_with_temperature() -> None:
     print(response)  # noqa: T201
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
+
+
+def test_chat_spark_llm_streaming_with_stream_method() -> None:
+    chat = ChatSparkLLM()  # type: ignore[call-arg]
+    for chunk in chat.stream("Hello!"):
+        assert isinstance(chunk, AIMessageChunk)
+        assert isinstance(chunk.content, str)
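The new integration test covers the `.stream()` path; the `stream` parameter on `_generate()` can also be exercised directly. A hedged sketch, assuming the rest of `_generate()` follows the usual langchain-core pattern of folding streamed chunks into a single `ChatResult`; `_generate` is private API, called here only to illustrate the new `stream or self.streaming` branch:

.. code-block:: python

    from langchain_core.messages import HumanMessage

    from langchain_community.chat_models import ChatSparkLLM

    chat = ChatSparkLLM(streaming=False)

    # stream=True forces the streaming branch for this call only; the
    # chunks are merged back into one ChatResult before returning.
    result = chat._generate([HumanMessage(content="Hello!")], stream=True)
    print(result.generations[0].message.content)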