From 4ca37fedf8eb8dcbe5d2a76918211cbbafd7c0e8 Mon Sep 17 00:00:00 2001
From: TheSongg <145535169+TheSongg@users.noreply.github.com>
Date: Fri, 17 Jan 2025 10:58:56 +0800
Subject: [PATCH] [langchain_community.llms.xinference]: Rewrite _stream()
 method and support stream() method in xinference.py

---
 .../langchain_community/llms/xinference.py | 91 ++++++++++++++++++-
 1 file changed, 90 insertions(+), 1 deletion(-)

diff --git a/libs/community/langchain_community/llms/xinference.py b/libs/community/langchain_community/llms/xinference.py
index ada9f8b1f1084..06bc1c9b30e62 100644
--- a/libs/community/langchain_community/llms/xinference.py
+++ b/libs/community/langchain_community/llms/xinference.py
@@ -1,7 +1,20 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union
+from __future__ import annotations
+
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
 
 if TYPE_CHECKING:
     from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
@@ -73,6 +86,26 @@ class Xinference(LLM):
             generate_config={"max_tokens": 1024, "stream": True},
         )
 
+    Example:
+
+    .. code-block:: python
+
+        from langchain_community.llms import Xinference
+        from langchain.prompts import PromptTemplate
+
+        llm = Xinference(
+            server_url="http://0.0.0.0:9997",
+            model_uid={model_uid},  # replace model_uid with the model UID returned from launching the model
+            stream=True
+        )
+        prompt = PromptTemplate(
+            input_variables=['country'],
+            template="Q: where can we visit in the capital of {country}? A:"
+        )
+        chain = prompt | llm
+        chain.stream(input={'country': 'France'})
+
+
     To view all the supported builtin models, run:
 
     .. code-block:: bash
@@ -216,3 +249,59 @@ def _stream_generate(
                                 token=token, verbose=self.verbose, log_probs=log_probs
                             )
                         yield token
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        generate_config = kwargs.get("generate_config", {})
+        generate_config = {**self.model_kwargs, **generate_config}
+        if stop:
+            generate_config["stop"] = stop
+        for stream_resp in self._create_generate_stream(prompt, generate_config):
+            if stream_resp:
+                chunk = self._stream_response_to_generation_chunk(stream_resp)
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        chunk.text,
+                        verbose=self.verbose,
+                    )
+                yield chunk
+
+    def _create_generate_stream(
+        self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
+    ) -> Iterator[str]:
+        if self.client is None:
+            raise ValueError("Client is not initialized!")
+        model = self.client.get_model(self.model_uid)
+        yield from model.generate(prompt=prompt, generate_config=generate_config)
+
+    @staticmethod
+    def _stream_response_to_generation_chunk(
+        stream_response: str,
+    ) -> GenerationChunk:
+        """Convert a stream response to a generation chunk."""
+        token = ""
+        if isinstance(stream_response, dict):
+            choices = stream_response.get("choices", [])
+            if choices:
+                choice = choices[0]
+                if isinstance(choice, dict):
+                    token = choice.get("text", "")
+
+                    return GenerationChunk(
+                        text=token,
+                        generation_info=dict(
+                            finish_reason=choice.get("finish_reason", None),
+                            logprobs=choice.get("logprobs", None),
+                        ),
+                    )
+                else:
+                    raise TypeError("choice type error!")
+            else:
+                return GenerationChunk(text=token)
+        else:
+            raise TypeError("stream_response type error!")
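
Below is a minimal end-to-end sketch of the streaming path this patch adds, for reviewers who want to try it locally. It assumes a running Xinference server at http://127.0.0.1:9997 and an already-launched model; MODEL_UID is a hypothetical placeholder, not a value taken from this patch.

.. code-block:: python

    from langchain_community.llms import Xinference
    from langchain_core.prompts import PromptTemplate

    MODEL_UID = "replace-with-your-launched-model-uid"  # hypothetical placeholder

    llm = Xinference(
        server_url="http://127.0.0.1:9997",
        model_uid=MODEL_UID,
    )

    prompt = PromptTemplate(
        input_variables=["country"],
        template="Q: where can we visit in the capital of {country}? A:",
    )
    chain = prompt | llm

    # LLM.stream() delegates to the _stream() method added in this patch, so
    # the output arrives as incremental text chunks rather than one completed string.
    for chunk in chain.stream({"country": "France"}):
        print(chunk, end="", flush=True)

Because _stream() forwards each chunk's text to run_manager.on_llm_new_token(), callback handlers attached to the LLM (for example StreamingStdOutCallbackHandler) also receive tokens as they are generated.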