[langchain_community.llms.xinference]: Rewrite _stream() method and support stream() method in xinference.py (#29259)

- [ ] **PR title**: [langchain_community.llms.xinference]: Rewrite _stream() method and support stream() method in xinference.py

- [ ] **PR message**: Rewrite the _stream() method so that chain.stream() can be used to return data streams.

       chain = prompt | llm
       chain.stream(input=user_input)


- [ ] **tests**:

      from langchain_community.llms import Xinference
      from langchain.prompts import PromptTemplate

      llm = Xinference(
          server_url="http://0.0.0.0:9997",  # replace with your xinference server URL
          model_uid={model_uid},  # replace model_uid with the model UID returned from launching the model
          stream=True,
      )
      prompt = PromptTemplate(
          input=['country'],
          template="Q: where can we visit in the capital of {country}? A:"
      )
      chain = prompt | llm
      chain.stream(input={'country': 'France'})
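
Note that chain.stream() returns a generator, so chunks are only produced as the caller iterates over them. A minimal sketch of consuming the stream, reusing the chain above (server URL and model UID are placeholders):

      for chunk in chain.stream(input={'country': 'France'}):
          print(chunk, end="", flush=True)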
TheSongg authored Jan 18, 2025
1 parent d4b9404 commit 1cd4d8d
Showing 1 changed file with 90 additions and 1 deletion.
91 changes: 90 additions & 1 deletion libs/community/langchain_community/llms/xinference.py
@@ -1,7 +1,20 @@
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Mapping, Optional, Union
from __future__ import annotations

from typing import (
TYPE_CHECKING,
Any,
Dict,
Generator,
Iterator,
List,
Mapping,
Optional,
Union,
)

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk

if TYPE_CHECKING:
from xinference.client import RESTfulChatModelHandle, RESTfulGenerateModelHandle
@@ -73,6 +86,26 @@ class Xinference(LLM):
            generate_config={"max_tokens": 1024, "stream": True},
        )

    Example:

    .. code-block:: python

        from langchain_community.llms import Xinference
        from langchain.prompts import PromptTemplate

        llm = Xinference(
            server_url="http://0.0.0.0:9997",
            model_uid={model_uid},  # replace model_uid with the model UID returned from launching the model
            stream=True
        )
        prompt = PromptTemplate(
            input=['country'],
            template="Q: where can we visit in the capital of {country}? A:"
        )
        chain = prompt | llm
        chain.stream(input={'country': 'France'})

    To view all the supported builtin models, run:

    .. code-block:: bash
@@ -216,3 +249,59 @@ def _stream_generate(
                                token=token, verbose=self.verbose, log_probs=log_probs
                            )
                        yield token

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        generate_config = kwargs.get("generate_config", {})
        generate_config = {**self.model_kwargs, **generate_config}
        if stop:
            generate_config["stop"] = stop
        for stream_resp in self._create_generate_stream(prompt, generate_config):
            if stream_resp:
                chunk = self._stream_response_to_generation_chunk(stream_resp)
                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                    )
                yield chunk

    def _create_generate_stream(
        self, prompt: str, generate_config: Optional[Dict[str, List[str]]] = None
    ) -> Iterator[str]:
        if self.client is None:
            raise ValueError("Client is not initialized!")
        model = self.client.get_model(self.model_uid)
        yield from model.generate(prompt=prompt, generate_config=generate_config)

    @staticmethod
    def _stream_response_to_generation_chunk(
        stream_response: str,
    ) -> GenerationChunk:
        """Convert a stream response to a generation chunk."""
        token = ""
        if isinstance(stream_response, dict):
            choices = stream_response.get("choices", [])
            if choices:
                choice = choices[0]
                if isinstance(choice, dict):
                    token = choice.get("text", "")

                    return GenerationChunk(
                        text=token,
                        generation_info=dict(
                            finish_reason=choice.get("finish_reason", None),
                            logprobs=choice.get("logprobs", None),
                        ),
                    )
                else:
                    raise TypeError("choice type error!")
            else:
                return GenerationChunk(text=token)
        else:
            raise TypeError("stream_response type error!")
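
As an illustration of the conversion the new helper performs, a rough sketch that is not part of the commit: it feeds a hypothetical completion-style chunk shaped like the dicts the code above reads (a "choices" list whose first entry carries "text", "finish_reason", and "logprobs"):

      from langchain_community.llms import Xinference

      # Hypothetical streamed chunk; real chunks come from the model's
      # generate(..., stream=True) call against a running Xinference server.
      stream_resp = {
          "choices": [
              {"text": "Paris has the Louvre", "finish_reason": None, "logprobs": None}
          ]
      }

      chunk = Xinference._stream_response_to_generation_chunk(stream_resp)
      print(chunk.text)             # "Paris has the Louvre"
      print(chunk.generation_info)  # {'finish_reason': None, 'logprobs': None}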
