Commit e87f32a

Progress
jacoblee93 committed Aug 13, 2024
1 parent 8bf9baf commit e87f32a
Showing 1 changed file with 58 additions and 64 deletions.
122 changes: 58 additions & 64 deletions libs/langchain-mistralai/src/chat_models.ts
@@ -11,6 +11,17 @@ import {
   // Message as MistralAIMessage,
   // TokenUsage as MistralAITokenUsage,
 } from "@mistralai/mistralai";
+import {
+  ChatCompletionRequest as MistralAIChatCompletionRequest,
+  ToolChoice as MistralAIToolChoice,
+  Messages as MistralAIMessage,
+} from "@mistralai/mistralai/models/components/chatcompletionrequest.js";
+import { Tool as MistralAITool } from "@mistralai/mistralai/models/components/tool.js";
+import { ToolCall as MistralAIToolCall } from "@mistralai/mistralai/models/components/toolcall.js";
+import { ChatCompletionStreamRequest as MistralChatCompletionStreamRequest } from "@mistralai/mistralai/models/components/chatcompletionstreamrequest.js";
+import { UsageInfo as MistralAITokenUsage } from "@mistralai/mistralai/models/components/usageinfo.js";
+import { CompletionEvent as MistralAIChatCompletionEvent } from "@mistralai/mistralai/models/components/completionevent.js";
+import { ChatCompletionResponse as MistralChatCompletionResponse } from "@mistralai/mistralai/models/components/chatcompletionresponse.js";
 import {
   MessageType,
   type BaseMessage,
@@ -74,26 +85,7 @@ interface TokenUsage {
   totalTokens?: number;
 }
 
-export type MistralAIToolChoice = "auto" | "any" | "none";
-
-export type MistralChatCompletionParams = Parameters<MistralClient["chat"]["complete"]>[0];
-export type MistralAIToolCall = NonNullable<MistralChatCompletionParams["tools"]>[number]
-
-export type MistralChatCompletionResponse = Awaited<ReturnType<MistralClient["chat"]["complete"]>>;
-export type MistralChatCompletionResponseChunk = Awaited<ReturnType<MistralClient["chat"]["stream"]>>;
-export type MistralChatCompletionChoice = NonNullable<MistralChatCompletionResponse["choices"]>[number];
-
-export type MistralAITool = NonNullable<MistralChatCompletionParams["tools"]>[number];
-export type MistralAIFunction = MistralAITool["function"];
-
-export type MistralAIMessage = NonNullable<MistralChatCompletionChoice["message"]>;
-
-type MistralAIToolInput = { type: string; function: MistralAIFunction };
-
-type ChatMistralAIToolType =
-  | MistralAIToolInput
-  | MistralAITool
-  | BindToolsInput;
+type ChatMistralAIToolType = MistralAIToolCall | MistralAITool | BindToolsInput;
 
 export interface ChatMistralAICallOptions
   extends Omit<BaseLanguageModelCallOptions, "stop"> {
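
Note: with the v1 SDK, tool definitions come straight from the SDK's component types instead of being derived from the client's method signatures. A minimal sketch of a literal satisfying the new ChatMistralAIToolType union (the weather tool and its schema are invented for illustration):

// Hedged sketch: Tool from the v1 SDK components, aliased in this file as
// MistralAITool; the function name and parameters below are hypothetical.
import { Tool as MistralAITool } from "@mistralai/mistralai/models/components/tool.js";

const weatherTool: MistralAITool = {
  type: "function",
  function: {
    name: "get_weather",
    description: "Look up the current weather for a location",
    parameters: {
      type: "object",
      properties: { location: { type: "string" } },
      required: ["location"],
    },
  },
};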
@@ -257,10 +249,13 @@ function convertMessagesToMistralMessages(
 }
 
 function mistralAIResponseToChatMessage(
-  choice: MistralChatCompletionResponse["choices"][0],
+  choice: NonNullable<MistralChatCompletionResponse["choices"]>[0],
   usage?: MistralAITokenUsage
 ): BaseMessage {
   const { message } = choice;
+  if (message === undefined) {
+    throw new Error("No message found in response");
+  }
   // MistralAI SDK does not include tool_calls in the non
   // streaming return type, so we need to extract it like this
   // to satisfy typescript.
Expand Down Expand Up @@ -288,19 +283,12 @@ function mistralAIResponseToChatMessage(
content: message.content ?? "",
tool_calls: toolCalls,
invalid_tool_calls: invalidToolCalls,
additional_kwargs: {
tool_calls: rawToolCalls.length
? rawToolCalls.map((toolCall) => ({
...toolCall,
type: "function",
}))
: undefined,
},
additional_kwargs: {},
usage_metadata: usage
? {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
input_tokens: usage.promptTokens,
output_tokens: usage.completionTokens,
total_tokens: usage.totalTokens,
}
: undefined,
});
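
The usage mapping above targets the v1 SDK's UsageInfo component, whose fields are camelCase rather than the old snake_case keys. A tiny sketch (the token counts are invented):

// Hedged sketch: UsageInfo exposes promptTokens/completionTokens/totalTokens.
import { UsageInfo } from "@mistralai/mistralai/models/components/usageinfo.js";

const usage: UsageInfo = { promptTokens: 12, completionTokens: 34, totalTokens: 46 };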
@@ -324,9 +312,9 @@ function _convertDeltaToMessageChunk(
     content: "",
     usage_metadata: usage
       ? {
-          input_tokens: usage.prompt_tokens,
-          output_tokens: usage.completion_tokens,
-          total_tokens: usage.total_tokens,
+          input_tokens: usage.promptTokens,
+          output_tokens: usage.completionTokens,
+          total_tokens: usage.totalTokens,
         }
       : undefined,
   });
@@ -338,7 +326,7 @@ function _convertDeltaToMessageChunk(
   // need to insert it here.
   const rawToolCallChunksWithIndex = delta.tool_calls?.length
     ? delta.tool_calls?.map(
-        (toolCall, index): OpenAIToolCall => ({
+        (toolCall, index): MistralAIToolCall & { index: number } => ({
           ...toolCall,
           index,
           id: toolCall.id ?? uuidv4().replace(/-/g, ""),
@@ -355,13 +343,15 @@ function _convertDeltaToMessageChunk(
   let additional_kwargs;
   const toolCallChunks: ToolCallChunk[] = [];
   if (rawToolCallChunksWithIndex !== undefined) {
-    additional_kwargs = {
-      tool_calls: rawToolCallChunksWithIndex,
-    };
     for (const rawToolCallChunk of rawToolCallChunksWithIndex) {
+      const rawArgs = rawToolCallChunk.function?.arguments;
+      const args =
+        rawArgs === undefined || typeof rawArgs === "string"
+          ? rawArgs
+          : JSON.stringify(rawArgs);
       toolCallChunks.push({
         name: rawToolCallChunk.function?.name,
-        args: rawToolCallChunk.function?.arguments,
+        args,
         id: rawToolCallChunk.id,
         index: rawToolCallChunk.index,
         type: "tool_call_chunk",
@@ -380,9 +370,9 @@ function _convertDeltaToMessageChunk(
     additional_kwargs,
     usage_metadata: usage
       ? {
-          input_tokens: usage.prompt_tokens,
-          output_tokens: usage.completion_tokens,
-          total_tokens: usage.total_tokens,
+          input_tokens: usage.promptTokens,
+          output_tokens: usage.completionTokens,
+          total_tokens: usage.totalTokens,
         }
       : undefined,
   });
@@ -832,22 +822,21 @@ export class ChatMistralAI<
    */
   invocationParams(
     options?: this["ParsedCallOptions"]
-  ): Omit<MistralChatCompletionParams, "messages"> {
+  ): Omit<MistralAIChatCompletionRequest, "messages"> {
     const { response_format, tools, tool_choice } = options ?? {};
     const mistralAITools: Array<MistralAITool> | undefined = tools?.length
       ? _convertToolToMistralTool(tools)
       : undefined;
-    const params: Omit<MistralChatCompletionParams, "messages"> = {
+    const params: Omit<MistralAIChatCompletionRequest, "messages"> = {
       model: this.model,
       tools: mistralAITools,
       temperature: this.temperature,
       maxTokens: this.maxTokens,
       topP: this.topP,
       randomSeed: this.seed,
-      safeMode: this.safeMode,
       safePrompt: this.safePrompt,
       toolChoice: tool_choice,
-      responseFormat: response_format as ResponseFormat,
+      responseFormat: response_format,
     };
     return params;
   }
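
Beyond the type rename, this hunk drops the safeMode flag (the v1 request type appears to expose only safePrompt) and removes the ResponseFormat cast. A hedged sketch of a full v1 request assembled with the same camelCase parameters (the model name and values are invented):

// Hedged sketch, not the library's code: a complete ChatCompletionRequest.
import { ChatCompletionRequest } from "@mistralai/mistralai/models/components/chatcompletionrequest.js";

const request: ChatCompletionRequest = {
  model: "mistral-small-latest",
  temperature: 0.7,
  maxTokens: 256,
  topP: 1,
  randomSeed: 42,
  safePrompt: false,
  messages: [{ role: "user", content: "Hello!" }],
};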
@@ -868,37 +857,45 @@ export class ChatMistralAI<
    * @returns {Promise<MistralAIChatCompletionResult | AsyncGenerator<MistralAIChatCompletionResult>>} The response from the MistralAI API.
    */
   async completionWithRetry(
-    input: ChatRequest,
+    input: MistralChatCompletionStreamRequest,
     streaming: true
-  ): Promise<AsyncGenerator<MistralChatCompletionResponseChunk>>;
+  ): Promise<AsyncIterable<MistralAIChatCompletionEvent>>;
 
   async completionWithRetry(
-    input: ChatRequest,
+    input: MistralAIChatCompletionRequest,
     streaming: false
   ): Promise<MistralChatCompletionResponse>;
 
   async completionWithRetry(
-    input: ChatRequest,
+    input: MistralAIChatCompletionRequest | MistralChatCompletionStreamRequest,
     streaming: boolean
   ): Promise<
-    MistralChatCompletionResponse | AsyncGenerator<MistralChatCompletionResponseChunk>
+    MistralChatCompletionResponse | AsyncIterable<MistralAIChatCompletionEvent>
   > {
-    const client = new MistralClient({ apiKey: this.apiKey, serverURL: this.endpoint });
+    const client = new MistralClient({
+      apiKey: this.apiKey,
+      serverURL: this.endpoint,
+    });
 
     return this.caller.call(async () => {
       try {
         let res:
           | MistralChatCompletionResponse
-          | AsyncGenerator<MistralChatCompletionResponseChunk>;
+          | AsyncIterable<MistralAIChatCompletionEvent>;
         if (streaming) {
-          res = client.chat.stream(input);
+          res = await client.chat.stream(input);
         } else {
           res = await client.chat.complete(input);
         }
         return res;
         // eslint-disable-next-line @typescript-eslint/no-explicit-any
       } catch (e: any) {
-        if (e.message?.includes("status: 400")) {
+        console.log(e, e.status, e.code, e.statusCode, e.message);
+        if (
+          e.message?.includes("status: 400") ||
+          e.message?.toLowerCase().includes("status 400") ||
+          e.message?.includes("validation failed")
+        ) {
           e.status = 400;
         }
         throw e;
@@ -947,11 +944,8 @@ export class ChatMistralAI<
     // Not streaming, so we can just call the API once.
     const response = await this.completionWithRetry(input, false);
 
-    const {
-      completion_tokens: completionTokens,
-      prompt_tokens: promptTokens,
-      total_tokens: totalTokens,
-    } = response?.usage ?? {};
+    const { completionTokens, promptTokens, totalTokens } =
+      response?.usage ?? {};
 
     if (completionTokens) {
       tokenUsage.completionTokens =
Expand Down Expand Up @@ -979,8 +973,8 @@ export class ChatMistralAI<
text,
message: mistralAIResponseToChatMessage(part, response?.usage),
};
if (part.finish_reason) {
generation.generationInfo = { finish_reason: part.finish_reason };
if (part.finishReason) {
generation.generationInfo = { finishReason: part.finishReason };
}
generations.push(generation);
}
@@ -1003,7 +997,7 @@ export class ChatMistralAI<
     };
 
     const streamIterable = await this.completionWithRetry(input, true);
-    for await (const data of streamIterable) {
+    for await (const { data } of streamIterable) {
       if (options.signal?.aborted) {
         throw new Error("AbortError");
       }
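
In the v1 SDK, chat.stream() must be awaited and the resulting iterable yields CompletionEvent wrappers, which is why the loop above destructures { data } from each item. A minimal end-to-end sketch of consuming the stream directly (assumes MISTRAL_API_KEY is set; the model name is invented):

// Hedged sketch of consuming the v1 streaming API outside of LangChain.
import { Mistral } from "@mistralai/mistralai";

const client = new Mistral({ apiKey: process.env.MISTRAL_API_KEY });
const stream = await client.chat.stream({
  model: "mistral-small-latest",
  messages: [{ role: "user", content: "Tell me a joke." }],
});
for await (const { data } of stream) {
  // Each event wraps a chunk in `data`; delta.content may be a string or
  // structured content, so only write plain strings here.
  const content = data.choices?.[0]?.delta?.content;
  if (typeof content === "string") process.stdout.write(content);
}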
