Skip to content

Commit

Permalink
deserialize
Browse files Browse the repository at this point in the history
  • Loading branch information
hanselke committed Jul 17, 2024
1 parent 4a11c78 commit 787ad45
Showing 1 changed file with 59 additions and 4 deletions.
63 changes: 59 additions & 4 deletions src/v1/chat_completion.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use serde::ser::SerializeMap;
use serde::{Deserialize, Serialize, Serializer};
use serde::{Deserialize, Serialize, Serializer, Deserializer};
use serde_json::Value;
use std::collections::HashMap;

use serde::de::{self, MapAccess, SeqAccess, Visitor};
use crate::impl_builder_methods;
use crate::v1::common;

use std::fmt;
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub enum ToolChoiceType {
None,
Expand Down Expand Up @@ -104,7 +104,7 @@ pub enum MessageRole {
tool,
}

#[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Content {
Text(String),
ImageUrl(Vec<ImageUrl>),
Expand All @@ -128,6 +128,61 @@ impl serde::Serialize for Content {
}
}

impl<'de> Deserialize<'de> for Content {
fn deserialize<D>(deserializer: D) -> Result<Content, D::Error>
where
D: Deserializer<'de>,
{
struct ContentVisitor;

impl<'de> Visitor<'de> for ContentVisitor {
type Value = Content;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a valid content type")
}

fn visit_str<E>(self, value: &str) -> Result<Content, E>
where
E: de::Error,
{
Ok(Content::Text(value.to_string()))
}

fn visit_seq<A>(self, seq: A) -> Result<Content, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let image_urls: Vec<ImageUrl> = Deserialize::deserialize(de::value::SeqAccessDeserializer::new(seq))?;
Ok(Content::ImageUrl(image_urls))
}

fn visit_map<M>(self, map: M) -> Result<Content, M::Error>
where
M: serde::de::MapAccess<'de>,
{
let image_urls: Vec<ImageUrl> = Deserialize::deserialize(de::value::MapAccessDeserializer::new(map))?;
Ok(Content::ImageUrl(image_urls))
}

fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Content::Text(String::new()))
}

fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(Content::Text(String::new()))
}
}

deserializer.deserialize_any(ContentVisitor)
}
}
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum ContentType {
Expand Down

0 comments on commit 787ad45

Please sign in to comment.