diff --git a/docs/docs/integrations/chat/symblai_nebula.ipynb b/docs/docs/integrations/chat/symblai_nebula.ipynb
new file mode 100644
index 0000000000000..beb16374cda32
--- /dev/null
+++ b/docs/docs/integrations/chat/symblai_nebula.ipynb
@@ -0,0 +1,297 @@
+{
+ "cells": [
+  {
+   "cell_type": "raw",
+   "id": "53fbf15f",
+   "metadata": {
+    "vscode": {
+     "languageId": "raw"
+    }
+   },
+   "source": [
+    "---\n",
+    "sidebar_label: Nebula (Symbl.ai)\n",
+    "---"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bf733a38-db84-4363-89e2-de6735c37230",
+   "metadata": {},
+   "source": [
+    "# Nebula (Symbl.ai)\n",
+    "\n",
+    "## Overview\n",
+    "This notebook covers how to get started with [Nebula](https://docs.symbl.ai/docs/nebula-llm), Symbl.ai's chat model.\n",
+    "\n",
+    "### Integration details\n",
+    "Head to the [API reference](https://docs.symbl.ai/reference/nebula-chat) for detailed documentation.\n",
+    "\n",
+    "### Model features: TODO"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3607d67e-e56c-4102-bbba-df2edc0e109e",
+   "metadata": {},
+   "source": [
+    "## Setup\n",
+    "\n",
+    "### Credentials\n",
+    "To get started, request a [Nebula API key](https://platform.symbl.ai/#/login) and set the `NEBULA_API_KEY` environment variable:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "2108b517-1e8d-473d-92fa-4f930e8072a7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import getpass\n",
+    "import os\n",
+    "\n",
+    "os.environ[\"NEBULA_API_KEY\"] = getpass.getpass()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "68b44357",
+   "metadata": {},
+   "source": [
+    "### Installation\n",
+    "The integration lives in the `langchain-community` package, which can be installed with:"
+   ]
+  },
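+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a3a1f6cf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install -qU langchain-community"
+   ]
+  },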
+  {
+   "cell_type": "markdown",
+   "id": "4c26754b-b3c9-4d93-8f36-43049bd943bf",
+   "metadata": {},
+   "source": [
+    "## Instantiation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0fdd26e7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_community.chat_models.symblai_nebula import ChatNebula\n",
+    "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "70cf04e8-423a-4ff6-8b09-f11fb711c817",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "chat = ChatNebula(max_new_tokens=1024, temperature=0.5)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2a915547",
+   "metadata": {},
+   "source": [
+    "## Invocation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "8199ef8f-eb8b-4253-9ea0-6c24a013ca4c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "messages = [\n",
+    "    SystemMessage(\n",
+    "        content=\"You are a helpful assistant that answers general knowledge questions.\"\n",
+    "    ),\n",
+    "    HumanMessage(content=\"What is the capital of France?\"),\n",
+    "]\n",
+    "chat.invoke(messages)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9723913f",
+   "metadata": {},
+   "source": [
+    "### Async"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "c5fac0e9-05a4-4fc1-a3b3-e5bbb24b971b",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "await chat.ainvoke(messages)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e0a1d3b4",
+   "metadata": {},
+   "source": [
+    "### Streaming"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "025be980-e50d-4a68-93dc-c9c7b500ce34",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " The capital of France is Paris."
+     ]
+    }
+   ],
+   "source": [
+    "for chunk in chat.stream(messages):\n",
+    "    print(chunk.content, end=\"\", flush=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9f91b7c7",
+   "metadata": {},
+   "source": [
+    "### Batch"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "054dc648",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[AIMessage(content=[{'role': 'human', 'text': 'What is the capital of France?'}, {'role': 'assistant', 'text': 'The capital of France is Paris.'}])]"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chat.batch([messages])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e59a5519",
+   "metadata": {},
+   "source": [
+    "## Chaining"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "6455f67b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain_core.prompts import ChatPromptTemplate\n",
+    "\n",
+    "prompt = ChatPromptTemplate.from_template(\"Tell me a joke about {topic}\")\n",
+    "chain = prompt | chat"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "deb1e2a1",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content=[{'role': 'human', 'text': 'Tell me a joke about cows'}, {'role': 'assistant', 'text': \"Sure, here's a joke about cows:\\n\\nWhy did the cow cross the road?\\n\\nTo get to the udder side!\"}])"
+      ]
+     },
+     "execution_count": 19,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chain.invoke({\"topic\": \"cows\"})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bb9d4755",
+   "metadata": {},
+   "source": [
+    "## API reference\n",
+    "\n",
+    "Check out the [API reference](https://python.langchain.com/v0.2/api_reference/community/chat_models/langchain_community.chat_models.symblai_nebula.ChatNebula.html) for more detail."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py
index 6e0344f6c122d..9658c707e6782 100644
--- a/libs/community/langchain_community/chat_models/__init__.py
+++ b/libs/community/langchain_community/chat_models/__init__.py
@@ -153,6 +153,7 @@
     from langchain_community.chat_models.sparkllm import (
         ChatSparkLLM,
     )
+    from langchain_community.chat_models.symblai_nebula import ChatNebula
     from langchain_community.chat_models.tongyi import (
         ChatTongyi,
     )
@@ -201,6 +202,7 @@
     "ChatMLflowAIGateway",
     "ChatMaritalk",
     "ChatMlflow",
+    "ChatNebula",
     "ChatOCIGenAI",
     "ChatOllama",
     "ChatOpenAI",
@@ -257,6 +259,7 @@
     "ChatMLX": "langchain_community.chat_models.mlx",
     "ChatMaritalk": "langchain_community.chat_models.maritalk",
     "ChatMlflow": "langchain_community.chat_models.mlflow",
+    "ChatNebula": "langchain_community.chat_models.symblai_nebula",
     "ChatOctoAI": "langchain_community.chat_models.octoai",
     "ChatOCIGenAI": "langchain_community.chat_models.oci_generative_ai",
     "ChatOllama": "langchain_community.chat_models.ollama",
diff --git a/libs/community/langchain_community/chat_models/symblai_nebula.py b/libs/community/langchain_community/chat_models/symblai_nebula.py
new file mode 100644
index 0000000000000..e3638bada8f24
--- /dev/null
+++ b/libs/community/langchain_community/chat_models/symblai_nebula.py
@@ -0,0 +1,271 @@
+import json
+import os
+from json import JSONDecodeError
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
+
+import requests
+from aiohttp import ClientSession
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.chat_models import (
+    BaseChatModel,
+    agenerate_from_stream,
+    generate_from_stream,
+)
+from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.pydantic_v1 import Field, SecretStr
+from langchain_core.utils import convert_to_secret_str
+
+
+def _convert_role(role: str) -> str:
+    role_map = {"ai": "assistant", "human": "human", "chat": "human"}
+    if role in role_map:
+        return role_map[role]
+    else:
+        raise ValueError(f"Unknown role type: {role}")
+
+
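+# Nebula's chat endpoints expect a JSON body with a "system_prompt" string and a
+# "messages" list of {"role": "human" | "assistant", "text": ...} entries
+# (https://docs.symbl.ai/reference/nebula-chat); the helper below maps LangChain
+# messages onto that structure.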
+def _format_nebula_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
+    system = ""
+    formatted_messages = []
+    for message in messages[:-1]:
+        if message.type == "system":
+            if isinstance(message.content, str):
+                system = message.content
+            else:
+                raise ValueError("System prompt must be a string")
+        else:
+            formatted_messages.append(
+                {
+                    "role": _convert_role(message.type),
+                    "text": message.content,
+                }
+            )
+
+    text = messages[-1].content
+    formatted_messages.append({"role": "human", "text": text})
+    return {"system_prompt": system, "messages": formatted_messages}
+
+
+class ChatNebula(BaseChatModel):
+    """`Nebula` chat large language model - https://docs.symbl.ai/docs/nebula-llm
+
+    API Reference: https://docs.symbl.ai/reference/nebula-chat
+
+    To use, set the environment variable ``NEBULA_API_KEY``,
+    or pass it as a named parameter to the constructor.
+    To request an API key, visit https://platform.symbl.ai/#/login
+    Example:
+        .. code-block:: python
+
+            from langchain_community.chat_models import ChatNebula
+            from langchain_core.messages import SystemMessage, HumanMessage
+
+            chat = ChatNebula(max_new_tokens=1024, temperature=0.5)
+
+            messages = [
+                SystemMessage(
+                    content="You are a helpful assistant."
+                ),
+                HumanMessage(
+                    "Answer the following question. How can I help save the world?"
+                ),
+            ]
+            chat.invoke(messages)
+    """
+
+    max_new_tokens: int = 1024
+    """Denotes the number of tokens to predict per generation."""
+
+    temperature: Optional[float] = 0
+    """A non-negative float that tunes the degree of randomness in generation."""
+
+    streaming: bool = False
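+    """Whether to stream the results or not."""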
+
+    nebula_api_url: str = "https://api-nebula.symbl.ai"
+
+    nebula_api_key: Optional[SecretStr] = Field(None, description="Nebula API Token")
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        allow_population_by_field_name = True
+        arbitrary_types_allowed = True
+
+    def __init__(self, **kwargs: Any) -> None:
+        if "nebula_api_key" in kwargs:
+            api_key = convert_to_secret_str(kwargs.pop("nebula_api_key"))
+        elif "NEBULA_API_KEY" in os.environ:
+            api_key = convert_to_secret_str(os.environ["NEBULA_API_KEY"])
+        else:
+            api_key = None
+        super().__init__(nebula_api_key=api_key, **kwargs)  # type: ignore[call-arg]
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "nebula-chat"
+
+    @property
+    def _api_key(self) -> str:
+        if self.nebula_api_key:
+            return self.nebula_api_key.get_secret_value()
+        return ""
+
+    def _stream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[ChatGenerationChunk]:
+        """Call out to Nebula's chat endpoint."""
+        url = f"{self.nebula_api_url}/v1/model/chat/streaming"
+        headers = {
+            "ApiKey": self._api_key,
+            "Content-Type": "application/json",
+        }
+        formatted_data = _format_nebula_messages(messages=messages)
+        payload: Dict[str, Any] = {
+            "max_new_tokens": self.max_new_tokens,
+            "temperature": self.temperature,
+            **formatted_data,
+            **kwargs,
+        }
+
+        payload = {k: v for k, v in payload.items() if v is not None}
+        json_payload = json.dumps(payload)
+
+        response = requests.request(
+            "POST", url, headers=headers, data=json_payload, stream=True
+        )
+        response.raise_for_status()
+
+        for chunk_response in response.iter_lines():
+            chunk_decoded = chunk_response.decode()[6:]
+            try:
+                chunk = json.loads(chunk_decoded)
+            except JSONDecodeError:
+                continue
+            token = chunk["delta"]
+            cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
+            if run_manager:
+                run_manager.on_llm_new_token(token, chunk=cg_chunk)
+            yield cg_chunk
+
+    async def _astream(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[ChatGenerationChunk]:
+        url = f"{self.nebula_api_url}/v1/model/chat/streaming"
+        headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
+        formatted_data = _format_nebula_messages(messages=messages)
+        payload: Dict[str, Any] = {
+            "max_new_tokens": self.max_new_tokens,
+            "temperature": self.temperature,
+            **formatted_data,
+            **kwargs,
+        }
+
+        payload = {k: v for k, v in payload.items() if v is not None}
+        json_payload = json.dumps(payload)
+
+        async with ClientSession() as session:
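+            # The streaming endpoint emits "data: {json}" lines; the slice below
+            # strips the 6-character "data: " prefix, and anything that is not
+            # valid JSON (e.g. keep-alive or blank lines) is skipped.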
+            async with session.post(
+                url, data=json_payload, headers=headers
+            ) as response:
+                response.raise_for_status()
+                async for chunk_response in response.content:
+                    chunk_decoded = chunk_response.decode()[6:]
+                    try:
+                        chunk = json.loads(chunk_decoded)
+                    except JSONDecodeError:
+                        continue
+                    token = chunk["delta"]
+                    cg_chunk = ChatGenerationChunk(
+                        message=AIMessageChunk(content=token)
+                    )
+                    if run_manager:
+                        await run_manager.on_llm_new_token(token, chunk=cg_chunk)
+                    yield cg_chunk
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return generate_from_stream(stream_iter)
+
+        url = f"{self.nebula_api_url}/v1/model/chat"
+        headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
+        formatted_data = _format_nebula_messages(messages=messages)
+        payload: Dict[str, Any] = {
+            "max_new_tokens": self.max_new_tokens,
+            "temperature": self.temperature,
+            **formatted_data,
+            **kwargs,
+        }
+
+        payload = {k: v for k, v in payload.items() if v is not None}
+        json_payload = json.dumps(payload)
+
+        response = requests.request("POST", url, headers=headers, data=json_payload)
+        response.raise_for_status()
+        data = response.json()
+
+        return ChatResult(
+            generations=[ChatGeneration(message=AIMessage(content=data["messages"]))],
+            llm_output=data,
+        )
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        if self.streaming:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await agenerate_from_stream(stream_iter)
+
+        url = f"{self.nebula_api_url}/v1/model/chat"
+        headers = {"ApiKey": self._api_key, "Content-Type": "application/json"}
+        formatted_data = _format_nebula_messages(messages=messages)
+        payload: Dict[str, Any] = {
+            "max_new_tokens": self.max_new_tokens,
+            "temperature": self.temperature,
+            **formatted_data,
+            **kwargs,
+        }
+
+        payload = {k: v for k, v in payload.items() if v is not None}
+        json_payload = json.dumps(payload)
+
+        async with ClientSession() as session:
+            async with session.post(
+                url, data=json_payload, headers=headers
+            ) as response:
+                response.raise_for_status()
+                data = await response.json()
+
+                return ChatResult(
+                    generations=[
+                        ChatGeneration(message=AIMessage(content=data["messages"]))
+                    ],
+                    llm_output=data,
+                )
diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py
index 24e0b6b707caf..9e97444484ba5 100644
--- a/libs/community/tests/unit_tests/chat_models/test_imports.py
+++ b/libs/community/tests/unit_tests/chat_models/test_imports.py
@@ -28,6 +28,7 @@
     "ChatMlflow",
     "ChatMLflowAIGateway",
     "ChatMLX",
+    "ChatNebula",
     "ChatOCIGenAI",
     "ChatOllama",
     "ChatOpenAI",