Skip to content

Commit

Permalink
Merge pull request #54 from sharifhsn/toolchoice
Browse files Browse the repository at this point in the history
Deprecate `functions` and add `tools`
  • Loading branch information
Dongri Jin authored Dec 28, 2023
2 parents a061135 + 4afe1ff commit cab9136
Showing 1 changed file with 52 additions and 1 deletion.
53 changes: 52 additions & 1 deletion src/v1/chat_completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,29 @@ pub enum FunctionCallType {
Function { name: String },
}

#[derive(Debug, Serialize, Clone)]
pub enum ToolChoiceType {
None,
Auto,
ToolChoice { tool: Tool },
}

#[derive(Debug, Serialize, Clone)]
pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatCompletionMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
#[deprecated(
since = "2.1.5",
note = "This field is deprecated. Use `tools` instead."
)]
pub functions: Option<Vec<Function>>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(serialize_with = "serialize_function_call")]
#[deprecated(
since = "2.1.5",
note = "This field is deprecated. Use `tool_choice` instead."
)]
pub function_call: Option<FunctionCallType>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
Expand All @@ -46,6 +61,11 @@ pub struct ChatCompletionRequest {
pub user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(serialize_with = "serialize_tool_choice")]
pub tool_choice: Option<ToolChoiceType>,
}

impl ChatCompletionRequest {
Expand All @@ -67,6 +87,8 @@ impl ChatCompletionRequest {
logit_bias: None,
user: None,
seed: None,
tools: None,
tool_choice: None,
}
}
}
Expand All @@ -86,7 +108,9 @@ impl_builder_methods!(
frequency_penalty: f64,
logit_bias: HashMap<String, i32>,
user: String,
seed: i64
seed: i64,
tools: Vec<Tool>,
tool_choice: ToolChoiceType
);

#[derive(Debug, Serialize, Deserialize, Clone)]
Expand Down Expand Up @@ -225,3 +249,30 @@ where
None => serializer.serialize_none(),
}
}

fn serialize_tool_choice<S>(
value: &Option<ToolChoiceType>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(ToolChoiceType::None) => serializer.serialize_str("none"),
Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"),
Some(ToolChoiceType::ToolChoice { tool }) => {
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_entry("type", &tool.tool_type)?;
map.serialize_entry("function", &tool.function)?;
map.end()
}
None => serializer.serialize_none(),
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Tool {
#[serde(rename = "type")]
tool_type: String,
function: Function,
}

0 comments on commit cab9136

Please sign in to comment.