Skip to content

Commit

Permalink
Add BigDecimal support
Browse files Browse the repository at this point in the history
  • Loading branch information
Ten0 committed Oct 20, 2024
1 parent d2456a0 commit 98aaa86
Show file tree
Hide file tree
Showing 11 changed files with 326 additions and 38 deletions.
66 changes: 54 additions & 12 deletions serde_avro_fast/src/de/deserializer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,14 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
SchemaNode::Fixed(ref fixed) => {
self.state.read_slice(fixed.size, BytesVisitor(visitor))
}
SchemaNode::Decimal(ref decimal) => {
read_decimal(self.state, decimal, VisitorHint::Str, visitor)
SchemaNode::Decimal(ref decimal) => read_decimal(
self.state,
DecimalMode::Regular(decimal),
VisitorHint::Str,
visitor,
),
SchemaNode::BigDecimal => {
read_decimal(self.state, DecimalMode::Big, VisitorHint::Str, visitor)
}
SchemaNode::Uuid => read_length_delimited(self.state, StringVisitor(visitor)),
SchemaNode::Date => visitor.visit_i32(self.state.read_varint()?),
Expand Down Expand Up @@ -95,8 +101,14 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
DeError::custom(format_args!("Got negative enum discriminant: {e}"))
})?)
}
SchemaNode::Decimal(ref decimal) => {
read_decimal(self.state, decimal, VisitorHint::U64, visitor)
SchemaNode::Decimal(ref decimal) => read_decimal(
self.state,
DecimalMode::Regular(decimal),
VisitorHint::U64,
visitor,
),
SchemaNode::BigDecimal => {
read_decimal(self.state, DecimalMode::Big, VisitorHint::U64, visitor)
}
_ => self.deserialize_any(visitor),
}
Expand All @@ -108,8 +120,14 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
{
match *self.schema_node {
SchemaNode::Long => visitor.visit_i64(self.state.read_varint()?),
SchemaNode::Decimal(ref decimal) => {
read_decimal(self.state, decimal, VisitorHint::I64, visitor)
SchemaNode::Decimal(ref decimal) => read_decimal(
self.state,
DecimalMode::Regular(decimal),
VisitorHint::I64,
visitor,
),
SchemaNode::BigDecimal => {
read_decimal(self.state, DecimalMode::Big, VisitorHint::I64, visitor)
}
_ => self.deserialize_any(visitor),
}
Expand All @@ -120,8 +138,14 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
V: Visitor<'de>,
{
match *self.schema_node {
SchemaNode::Decimal(ref decimal) => {
read_decimal(self.state, decimal, VisitorHint::U128, visitor)
SchemaNode::Decimal(ref decimal) => read_decimal(
self.state,
DecimalMode::Regular(decimal),
VisitorHint::U128,
visitor,
),
SchemaNode::BigDecimal => {
read_decimal(self.state, DecimalMode::Big, VisitorHint::U128, visitor)
}
_ => self.deserialize_any(visitor),
}
Expand All @@ -132,8 +156,14 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
V: Visitor<'de>,
{
match *self.schema_node {
SchemaNode::Decimal(ref decimal) => {
read_decimal(self.state, decimal, VisitorHint::I128, visitor)
SchemaNode::Decimal(ref decimal) => read_decimal(
self.state,
DecimalMode::Regular(decimal),
VisitorHint::I128,
visitor,
),
SchemaNode::BigDecimal => {
read_decimal(self.state, DecimalMode::Big, VisitorHint::I128, visitor)
}
_ => self.deserialize_any(visitor),
}
Expand All @@ -147,8 +177,14 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
SchemaNode::Double => {
visitor.visit_f64(f64::from_le_bytes(self.state.read_const_size_buf()?))
}
SchemaNode::Decimal(ref decimal) => {
read_decimal(self.state, decimal, VisitorHint::F64, visitor)
SchemaNode::Decimal(ref decimal) => read_decimal(
self.state,
DecimalMode::Regular(decimal),
VisitorHint::F64,
visitor,
),
SchemaNode::BigDecimal => {
read_decimal(self.state, DecimalMode::Big, VisitorHint::F64, visitor)
}
_ => self.deserialize_any(visitor),
}
Expand Down Expand Up @@ -314,6 +350,11 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
where
V: Visitor<'de>,
{
// Depending on the schema node type, the node may directly represent the
// identifier of the enum variant (e.g. a String may be used to carry the enum
// variant name). If that's the case, we need to propose it as the variant name
// when deserializing; otherwise we should propose the type as the variant name
// and propagate deserialization of the current node to the variant's inner value.
match *self.schema_node {
SchemaNode::Union(ref union) => visitor.visit_enum(SchemaTypeNameEnumAccess {
variant_schema: read_union_discriminant(self.state, union)?,
Expand All @@ -338,6 +379,7 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R>
| SchemaNode::Map(_)
| SchemaNode::Record(_)
| SchemaNode::Decimal(_)
| SchemaNode::BigDecimal
| SchemaNode::Uuid
| SchemaNode::Date
| SchemaNode::TimeMillis
Expand Down
87 changes: 80 additions & 7 deletions serde_avro_fast/src/de/deserializer/types/decimal.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,65 @@
use super::*;

use {rust_decimal::prelude::ToPrimitive as _, std::marker::PhantomData};
use {
rust_decimal::prelude::ToPrimitive as _,
std::{io::Read, marker::PhantomData},
};

/// Selects which decimal encoding `read_decimal` has to decode.
pub(in super::super) enum DecimalMode<'a> {
/// Big decimal: the size of the unscaled value and the scale are both read
/// from the data itself rather than taken from the schema.
Big,
/// Regular decimal: the representation (bytes or fixed) and the scale are
/// taken from the referenced `Decimal` schema definition.
Regular(&'a Decimal),
}

pub(in super::super) fn read_decimal<'de, R, V>(
state: &mut DeserializerState<R>,
decimal: &Decimal,
decimal_mode: DecimalMode<'_>,
hint: VisitorHint,
visitor: V,
) -> Result<V::Value, DeError>
where
R: ReadSlice<'de>,
V: Visitor<'de>,
{
let size = match decimal.repr {
DecimalRepr::Bytes => read_len(state)?,
DecimalRepr::Fixed(ref fixed) => fixed.size,
let (size, mut reader) = match decimal_mode {
DecimalMode::Big => {
// A BigDecimal is represented as bytes: an outer length, and inside those
// bytes a length marker followed by the actual bytes of the value, followed
// by a Long that represents the scale.

let bytes_len = state.read_varint::<i64>()?.try_into().map_err(|e| {
DeError::custom(format_args!(
"Invalid BigDecimal bytes length in stream: {e}"
))
})?;

let mut reader = (&mut state.reader).take(bytes_len);

// Read the unsized repr len
let unsized_len = integer_encoding::VarIntReader::read_varint::<i64>(&mut reader)
.map_err(DeError::io)?
.try_into()
.map_err(|e| {
DeError::custom(format_args!("Invalid BigDecimal length in bytes: {e}"))
})?;

(unsized_len, ReaderEither::Take(reader))
}
DecimalMode::Regular(Decimal {
repr: DecimalRepr::Bytes,
..
}) => (read_len(state)?, ReaderEither::Reader(&mut state.reader)),
DecimalMode::Regular(Decimal {
repr: DecimalRepr::Fixed(fixed),
..
}) => (fixed.size, ReaderEither::Reader(&mut state.reader)),
};
let mut buf = [0u8; 16];
let start = buf.len().checked_sub(size).ok_or_else(|| {
DeError::custom(format_args!(
"Decimals of size larger than 16 are not supported (got size {size})"
))
})?;
state.read_exact(&mut buf[start..]).map_err(DeError::io)?;
reader.read_exact(&mut buf[start..]).map_err(DeError::io)?;
if buf.get(start).map_or(false, |&v| v & 0x80 != 0) {
// This is a negative number in two's complement representation, so we need to
// preserve the sign when widening it to the larger number
Expand All @@ -31,7 +68,30 @@ where
}
}
let unscaled = i128::from_be_bytes(buf);
let scale = decimal.scale;
let scale = match decimal_mode {
DecimalMode::Big => integer_encoding::VarIntReader::read_varint::<i64>(&mut reader)
.map_err(DeError::io)?
.try_into()
.map_err(|e| {
DeError::custom(format_args!("Invalid BigDecimal scale in stream: {e}"))
})?,
DecimalMode::Regular(Decimal { scale, .. }) => *scale,
};
match reader {
ReaderEither::Take(take) => {
if take.limit() > 0 {
// This would be incorrect if we don't skip the extra bytes
// in the original reader.
// Arguably we could just ignore the extra bytes, but until proven
// that this is a real use-case we'll just do the conservative thing
// and encourage people to use the appropriate number of bytes.
return Err(DeError::new(
"BigDecimal scale is not at the end of the bytes",
));
}
}
ReaderEither::Reader(_) => {}
}
if scale == 0 {
match hint {
VisitorHint::U64 => {
Expand Down Expand Up @@ -111,3 +171,16 @@ impl<'de, V: Visitor<'de>> serde::Serializer for SerializeToVisitorStr<'de, V> {
struct_variant i128 u128
}
}

/// A reader that is either a plain mutable borrow of the underlying reader or
/// that same borrow wrapped in a [`std::io::Take`].
///
/// This lets a single code path read from both kinds of source without boxing
/// or dynamic dispatch.
enum ReaderEither<'a, R> {
    /// Reads straight from the underlying reader, without any byte limit.
    Reader(&'a mut R),
    /// Reads from the underlying reader through a `Take`, which caps how many
    /// bytes may be consumed.
    Take(std::io::Take<&'a mut R>),
}

impl<R: Read> Read for ReaderEither<'_, R> {
    /// Forwards to the `read` of whichever reader is wrapped.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        match self {
            ReaderEither::Reader(reader) => reader.read(buf),
            ReaderEither::Take(reader) => reader.read(buf),
        }
    }

    /// Forwards to the `read_exact` of whichever reader is wrapped.
    ///
    /// Without this override, `read_exact` would use the trait's default
    /// implementation, which loops over `read` and never reaches a wrapped
    /// reader's own (possibly specialized) `read_exact`.
    fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        match self {
            ReaderEither::Reader(reader) => reader.read_exact(buf),
            ReaderEither::Take(reader) => reader.read_exact(buf),
        }
    }
}
1 change: 1 addition & 0 deletions serde_avro_fast/src/de/deserializer/types/union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ impl<'de> Deserializer<'de> for SchemaTypeNameDeserializer<'_> {
repr: DecimalRepr::Bytes,
..
}) => "Decimal",
SchemaNode::BigDecimal => "BigDecimal",
SchemaNode::Uuid => "Uuid",
SchemaNode::Date => "Date",
SchemaNode::TimeMillis => "TimeMillis",
Expand Down
4 changes: 4 additions & 0 deletions serde_avro_fast/src/schema/safe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,9 @@ pub enum LogicalType {
/// tuple, or to its raw representation [as defined by the specification](https://avro.apache.org/docs/current/specification/#duration)
/// if the deserializer is hinted this way ([`serde_bytes`](https://docs.rs/serde_bytes/latest/serde_bytes/)).
Duration,
/// Logical type which represents `Decimal` values without a predefined scale.
/// The underlying type is serialized and deserialized as `Schema::Bytes`.
BigDecimal,
/// A logical type that is not known or not handled in any particular way
/// by this library.
///
Expand Down Expand Up @@ -549,6 +552,7 @@ impl LogicalType {
LogicalType::TimestampMillis => "timestamp-millis",
LogicalType::TimestampMicros => "timestamp-micros",
LogicalType::Duration => "duration",
LogicalType::BigDecimal => "big-decimal",
LogicalType::Unknown(unknown_logical_type) => &unknown_logical_type.logical_type_name,
}
}
Expand Down
1 change: 1 addition & 0 deletions serde_avro_fast/src/schema/safe/parsing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ impl<'a> SchemaConstructionState<'a> {
"timestamp-millis" => LogicalType::TimestampMillis,
"timestamp-micros" => LogicalType::TimestampMicros,
"duration" => LogicalType::Duration,
"big-decimal" => LogicalType::BigDecimal,
unknown => LogicalType::Unknown(UnknownLogicalType::new(unknown)),
}
}),
Expand Down
3 changes: 2 additions & 1 deletion serde_avro_fast/src/schema/safe/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ impl Serialize for SerializeSchema<'_, SchemaKey> {
| LogicalType::TimeMicros
| LogicalType::TimestampMillis
| LogicalType::TimestampMicros
| LogicalType::Duration => {}
| LogicalType::Duration
| LogicalType::BigDecimal => {}
LogicalType::Unknown(_) => {}
}
} else {
Expand Down
6 changes: 6 additions & 0 deletions serde_avro_fast/src/schema/self_referential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ pub(crate) enum SchemaNode<'a> {
Enum(Enum),
Fixed(Fixed),
Decimal(Decimal),
BigDecimal,
Uuid,
Date,
TimeMillis,
Expand Down Expand Up @@ -339,6 +340,10 @@ impl TryFrom<super::safe::SchemaMut> for Schema {
logical_type: Some(LogicalType::Duration),
type_: SafeSchemaType::Fixed(fixed),
} if fixed.size == 12 => SchemaNode::Duration,
SafeSchemaNode {
logical_type: Some(LogicalType::BigDecimal),
type_: SafeSchemaType::Bytes,
} => SchemaNode::BigDecimal,
_ => match safe_node.type_ {
SafeSchemaType::Null => SchemaNode::Null,
SafeSchemaType::Boolean => SchemaNode::Boolean,
Expand Down Expand Up @@ -527,6 +532,7 @@ impl<'a> std::fmt::Debug for SchemaNode<'a> {
}
d.finish()
}
SchemaNode::BigDecimal => f.debug_tuple("BigDecimal").finish(),
SchemaNode::Uuid => f.debug_tuple("Uuid").finish(),
SchemaNode::Date => f.debug_tuple("Date").finish(),
SchemaNode::TimeMillis => f.debug_tuple("TimeMillis").finish(),
Expand Down
8 changes: 8 additions & 0 deletions serde_avro_fast/src/schema/union_variants_per_type_lookup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,14 @@ impl<'a> PerTypeLookup<'a> {
register(UnionVariantLookupKey::Float8, 2);
register(UnionVariantLookupKey::Str, 20);
}
SchemaNode::BigDecimal => {
register_type_name("BigDecimal");
register(UnionVariantLookupKey::Integer, 5);
register(UnionVariantLookupKey::Integer4, 5);
register(UnionVariantLookupKey::Integer8, 5);
register(UnionVariantLookupKey::Float8, 2);
register(UnionVariantLookupKey::Str, 20);
}
SchemaNode::Uuid => {
register_type_name("Uuid");
// A user may assume that uuid::Uuid will serialize to Uuid by default,
Expand Down
Loading

0 comments on commit 98aaa86

Please sign in to comment.