diff --git a/src/v1/chat_completion.rs b/src/v1/chat_completion.rs index f66c1e6..8ae275a 100644 --- a/src/v1/chat_completion.rs +++ b/src/v1/chat_completion.rs @@ -13,14 +13,29 @@ pub enum FunctionCallType { Function { name: String }, } +#[derive(Debug, Serialize, Clone)] +pub enum ToolChoiceType { + None, + Auto, + ToolChoice { tool: Tool }, +} + #[derive(Debug, Serialize, Clone)] pub struct ChatCompletionRequest { pub model: String, pub messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] + #[deprecated( + since = "2.1.5", + note = "This field is deprecated. Use `tools` instead." + )] pub functions: Option>, #[serde(skip_serializing_if = "Option::is_none")] #[serde(serialize_with = "serialize_function_call")] + #[deprecated( + since = "2.1.5", + note = "This field is deprecated. Use `tool_choice` instead." + )] pub function_call: Option, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, @@ -46,6 +61,11 @@ pub struct ChatCompletionRequest { pub user: Option, #[serde(skip_serializing_if = "Option::is_none")] pub seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(serialize_with = "serialize_tool_choice")] + pub tool_choice: Option, } impl ChatCompletionRequest { @@ -67,6 +87,8 @@ impl ChatCompletionRequest { logit_bias: None, user: None, seed: None, + tools: None, + tool_choice: None, } } } @@ -86,7 +108,9 @@ impl_builder_methods!( frequency_penalty: f64, logit_bias: HashMap, user: String, - seed: i64 + seed: i64, + tools: Vec, + tool_choice: ToolChoiceType ); #[derive(Debug, Serialize, Deserialize, Clone)] @@ -225,3 +249,30 @@ where None => serializer.serialize_none(), } } + +fn serialize_tool_choice( + value: &Option, + serializer: S, +) -> Result +where + S: Serializer, +{ + match value { + Some(ToolChoiceType::None) => serializer.serialize_str("none"), + Some(ToolChoiceType::Auto) => serializer.serialize_str("auto"), + Some(ToolChoiceType::ToolChoice { tool }) => { + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("type", &tool.tool_type)?; + map.serialize_entry("function", &tool.function)?; + map.end() + } + None => serializer.serialize_none(), + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct Tool { + #[serde(rename = "type")] + tool_type: String, + function: Function, +}