community[patch]: Native RAG Support in Prem AI langchain (#22238)
This PR adds native RAG support to the langchain premai package. The same
has been added to the docs as well.
Anindyadeep authored Jun 4, 2024
1 parent 77ad857 commit 7a19753
Showing 3 changed files with 174 additions and 46 deletions.
63 changes: 61 additions & 2 deletions docs/docs/integrations/chat/premai.ipynb
@@ -179,10 +179,69 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"> If you are going to place system prompt here, then it will override your system prompt that was fixed while deploying the application from the platform. \n",
"> If you are going to place system prompt here, then it will override your system prompt that was fixed while deploying the application from the platform. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Native RAG Support with Prem Repositories\n",
"\n",
"Prem Repositories which allows users to upload documents (.txt, .pdf etc) and connect those repositories to the LLMs. You can think Prem repositories as native RAG, where each repository can be considered as a vector database. You can connect multiple repositories. You can learn more about repositories [here](https://docs.premai.io/get-started/repositories).\n",
"\n",
"Repositories are also supported in langchain premai. Here is how you can do it. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"query = \"what is the diameter of individual Galaxy\"\n",
"repository_ids = [\n",
" 1991,\n",
"]\n",
"repositories = dict(ids=repository_ids, similarity_threshold=0.3, limit=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we start by defining our repository with some repository ids. Make sure that the ids are valid repository ids. You can learn more about how to get the repository id [here](https://docs.premai.io/get-started/repositories). \n",
"\n",
"> Please note that the current version of ChatPremAI does not support parameters: [n](https://platform.openai.com/docs/api-reference/chat/create#chat-create-n) and [stop](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). \n",
"> Please note: Similar like `model_name` when you invoke the argument `repositories`, then you are potentially overriding the repositories connected in the launchpad. \n",
"\n",
"Now, we connect the repository with our chat object to invoke RAG based generations. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"response = chat.invoke(query, max_tokens=100, repositories=repositories)\n",
"\n",
"print(response.content)\n",
"print(json.dumps(response.response_metadata, indent=4))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Ideally, you do not need to connect Repository IDs here to get Retrieval Augmented Generations. You can still get the same result if you have connected the repositories in prem platform. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Streaming\n",
"\n",
"In this section, let's see how we can stream tokens using langchain and PremAI. Here's how you do it. "
73 changes: 72 additions & 1 deletion docs/docs/integrations/providers/premai.md
@@ -73,7 +73,76 @@ chat.invoke(

> If you place a system prompt here, it will override the system prompt that was set while deploying the application from the platform.
> Please note that the current version of ChatPremAI does not support the parameters [n](https://platform.openai.com/docs/api-reference/chat/create#chat-create-n) and [stop](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop).
> You can find all the optional parameters [here](https://docs.premai.io/get-started/sdk#optional-parameters). Any parameters other than [these supported parameters](https://docs.premai.io/get-started/sdk#optional-parameters) will be automatically removed before calling the model.
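
As a quick illustration, here is a minimal sketch (assuming `chat` is a `ChatPremAI` instance created as shown above) of how supported parameters are forwarded while unsupported ones are dropped with a warning:

```python
# Minimal sketch: `chat` is assumed to be an existing ChatPremAI instance.
response = chat.invoke(
    "Tell me something about the universe",
    max_tokens=64,          # supported: forwarded to the model
    temperature=0.7,        # supported: forwarded to the model
    frequency_penalty=0.5,  # unsupported: removed with a warning before the call
)
print(response.content)
```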

### Native RAG Support with Prem Repositories

Prem Repositories allow users to upload documents (.txt, .pdf, etc.) and connect those repositories to the LLMs. You can think of Prem repositories as native RAG, where each repository acts as a vector database. You can connect multiple repositories. You can learn more about repositories [here](https://docs.premai.io/get-started/repositories).

Repositories are also supported in langchain premai. Here is how you can use them.

```python
query = "what is the diameter of individual Galaxy"
repository_ids = [1991]
repositories = dict(ids=repository_ids, similarity_threshold=0.3, limit=3)
```

First, we define our repositories with some repository IDs. Make sure the IDs are valid repository IDs. You can learn more about how to get a repository ID [here](https://docs.premai.io/get-started/repositories).

> Please note: similar to `model_name`, when you pass the `repositories` argument, you are potentially overriding the repositories connected in the launchpad.

Now, we connect the repositories to our chat object to invoke RAG-based generations.

```python
import json

response = chat.invoke(query, max_tokens=100, repositories=repositories)

print(response.content)
print(json.dumps(response.response_metadata, indent=4))
```

This is what the output looks like:

```bash
The diameters of individual galaxies range from 80,000-150,000 light-years.
{
"document_chunks": [
{
"repository_id": 1991,
"document_id": 1307,
"chunk_id": 173926,
"document_name": "Kegy 202 Chapter 2",
"similarity_score": 0.586126983165741,
"content": "n thousands\n of light-years. The diameters of individual\n galaxies range from 80,000-150,000 light\n "
},
{
"repository_id": 1991,
"document_id": 1307,
"chunk_id": 173925,
"document_name": "Kegy 202 Chapter 2",
"similarity_score": 0.4815782308578491,
"content": " for development of galaxies. A galaxy contains\n a large number of stars. Galaxies spread over\n vast distances that are measured in thousands\n "
},
{
"repository_id": 1991,
"document_id": 1307,
"chunk_id": 173916,
"document_name": "Kegy 202 Chapter 2",
"similarity_score": 0.38112708926200867,
"content": " was separated from the from each other as the balloon expands.\n solar surface. As the passing star moved away, Similarly, the distance between the galaxies is\n the material separated from the solar surface\n continued to revolve around the sun and it\n slowly condensed into planets. Sir James Jeans\n and later Sir Harold Jeffrey supported thisnot to be republishedalso found to be increasing and thereby, the\n universe is"
}
]
}
```

So, this also means that you do not need to build your own RAG pipeline when using the Prem Platform. Prem uses its own RAG technology to deliver best-in-class performance for retrieval-augmented generation.
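
If you want to work with the retrieved chunks programmatically, they are available under the `document_chunks` key of `response_metadata`. Here is a small sketch based on the output shown above:

```python
# Sketch: iterate over the chunks that were retrieved for this generation.
chunks = response.response_metadata.get("document_chunks") or []
for chunk in chunks:
    print(f'{chunk["document_name"]} (score: {chunk["similarity_score"]:.3f})')
    print(chunk["content"])
```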

> Ideally, you do not need to connect repository IDs here to get retrieval-augmented generations. You can still get the same result if you have connected the repositories in the Prem platform.

### Streaming

@@ -102,6 +171,8 @@ for chunk in chat.stream(

This will stream tokens one after the other.

> Please note: as of now, RAG with streaming is not supported in langchain. However, we still support it with our API. You can learn more about that [here](https://docs.premai.io/get-started/chat-completion-sse).

## PremEmbeddings

In this section, we are going to discuss how we can access different embedding models using `PremEmbeddings` with LangChain. Let's start by importing our modules and setting our API key.
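
A minimal sketch of what that setup might look like (the import path, `project_id`, and model name below are assumptions; check the Prem docs and your project settings for the correct values):

```python
import getpass
import os

# Assumed import path and class name for the Prem embeddings integration.
from langchain_community.embeddings import PremEmbeddings

if os.environ.get("PREMAI_API_KEY") is None:
    os.environ["PREMAI_API_KEY"] = getpass.getpass("PremAI API Key:")

# Hypothetical project_id and model; replace with values from your Prem project.
embedder = PremEmbeddings(project_id=8, model="text-embedding-3-large")
query_result = embedder.embed_query("Hello, this is a test query")
print(len(query_result))
```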
84 changes: 41 additions & 43 deletions libs/community/langchain_community/chat_models/premai.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import warnings
from typing import (
TYPE_CHECKING,
Any,
@@ -104,7 +105,18 @@ def _response_to_result(
text=content, message=ChatMessage(role=role, content=content)
)
)
return ChatResult(generations=generations)

if response.document_chunks is not None:
return ChatResult(
generations=generations,
llm_output={
"document_chunks": [
chunk.to_dict() for chunk in response.document_chunks
]
},
)
else:
return ChatResult(generations=generations, llm_output={"document_chunks": None})


def _convert_delta_response_to_message_chunk(
@@ -118,10 +130,6 @@ def _convert_delta_response_to_message_chunk(
role = _delta.get("role", "") # type: ignore
content = _delta.get("content", "") # type: ignore
additional_kwargs: Dict = {}

if role is None or role == "":
raise ChatPremAPIError("Role can not be None. Please check the response")

finish_reasons: Optional[str] = response.choices[0].finish_reason

if role == "user" or default_class == HumanMessageChunk:
@@ -185,17 +193,9 @@ class ChatPremAI(BaseChatModel, BaseModel):
If model name is other than default model then it will override the calls
from the model deployed from launchpad."""

session_id: Optional[str] = None
"""The ID of the session to use. It helps to track the chat history."""

temperature: Optional[float] = None
"""Model temperature. Value should be >= 0 and <= 1.0"""

top_p: Optional[float] = None
"""top_p adjusts the number of choices for each predicted tokens based on
cumulative probabilities. Value should be ranging between 0.0 and 1.0.
"""

max_tokens: Optional[int] = None
"""The maximum number of tokens to generate"""

@@ -209,30 +209,14 @@
Changing the system prompt would override the default system prompt.
"""

repositories: Optional[dict] = None
"""Add valid repository ids. This will be overriding existing connected
repositories (if any) and will use RAG with the connected repos.
"""

streaming: Optional[bool] = False
"""Whether to stream the responses or not."""

tools: Optional[Dict[str, Any]] = None
"""A list of tools the model may call. Currently, only functions are
supported as a tool"""

frequency_penalty: Optional[float] = None
"""Number between -2.0 and 2.0. Positive values penalize new tokens based"""

presence_penalty: Optional[float] = None
"""Number between -2.0 and 2.0. Positive values penalize new tokens based
on whether they appear in the text so far."""

logit_bias: Optional[dict] = None
"""JSON object that maps tokens to an associated bias value from -100 to 100."""

stop: Optional[Union[str, List[str]]] = None
"""Up to 4 sequences where the API will stop generating further tokens."""

seed: Optional[int] = None
"""This feature is in Beta. If specified, our system will make a best effort
to sample deterministically."""

client: Any

class Config:
@@ -268,21 +252,34 @@ def _llm_type(self) -> str:

@property
def _default_params(self) -> Dict[str, Any]:
# FIXME: n and stop are not supported, so hardcoding to the current default value
return {
"model": self.model,
"system_prompt": self.system_prompt,
"top_p": self.top_p,
"temperature": self.temperature,
"logit_bias": self.logit_bias,
"max_tokens": self.max_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"seed": self.seed,
"stop": None,
"repositories": self.repositories,
}

def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
kwargs_to_ignore = [
"top_p",
"tools",
"frequency_penalty",
"presence_penalty",
"logit_bias",
"stop",
"seed",
]
keys_to_remove = []

for key in kwargs:
if key in kwargs_to_ignore:
warnings.warn(f"WARNING: Parameter {key} is not supported in kwargs.")
keys_to_remove.append(key)

for key in keys_to_remove:
kwargs.pop(key)

all_kwargs = {**self._default_params, **kwargs}
for key in list(self._default_params.keys()):
if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
@@ -298,7 +295,6 @@ def _generate(
) -> ChatResult:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore

kwargs["stop"] = stop
if system_prompt is not None and system_prompt != "":
kwargs["system_prompt"] = system_prompt

@@ -322,7 +318,9 @@ def _stream(
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
kwargs["stop"] = stop

if stop is not None:
logger.warning("stop is not supported in langchain streaming")

if "system_prompt" not in kwargs:
if system_prompt is not None and system_prompt != "":