diff --git a/serde_avro_fast/src/de/deserializer/mod.rs b/serde_avro_fast/src/de/deserializer/mod.rs index 15fa47f..f7a83c3 100644 --- a/serde_avro_fast/src/de/deserializer/mod.rs +++ b/serde_avro_fast/src/de/deserializer/mod.rs @@ -38,11 +38,11 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R> SchemaNode::String => read_length_delimited(self.state, StringVisitor(visitor)), SchemaNode::Array(elements_schema) => visitor.visit_seq(ArraySeqAccess { elements_schema: elements_schema.as_ref(), - block_reader: BlockReader::new(self.state, self.allowed_depth.dec()?), + block_reader: BlockReader::new(self.state, false, self.allowed_depth.dec()?), }), SchemaNode::Map(elements_schema) => visitor.visit_map(MapMapAccess { elements_schema: elements_schema.as_ref(), - block_reader: BlockReader::new(self.state, self.allowed_depth.dec()?), + block_reader: BlockReader::new(self.state, false, self.allowed_depth.dec()?), }), SchemaNode::Union(ref union) => Self { schema_node: read_union_discriminant(self.state, union)?, @@ -283,7 +283,7 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R> match *self.schema_node { SchemaNode::Array(elements_schema) => visitor.visit_seq(ArraySeqAccess { elements_schema: elements_schema.as_ref(), - block_reader: BlockReader::new(self.state, self.allowed_depth.dec()?), + block_reader: BlockReader::new(self.state, false, self.allowed_depth.dec()?), }), SchemaNode::Duration => visitor.visit_seq(DurationMapAndSeqAccess { duration_buf: &self.state.read_const_size_buf::<12>()?, @@ -300,7 +300,7 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R> match *self.schema_node { SchemaNode::Array(elements_schema) => visitor.visit_seq(ArraySeqAccess { elements_schema: elements_schema.as_ref(), - block_reader: BlockReader::new(self.state, self.allowed_depth.dec()?), + block_reader: BlockReader::new(self.state, false, self.allowed_depth.dec()?), }), SchemaNode::Duration if len == 3 => visitor.visit_seq(DurationMapAndSeqAccess { duration_buf: &self.state.read_const_size_buf::<12>()?, @@ -417,18 +417,20 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for DatumDeserializer<'_, '_, R> where V: Visitor<'de>, { - // The main thing we can skip here for performance is utf8 decoding of strings. - // However we still need to drive the deserializer mostly normally to properly - // advance the reader. - - // TODO skip more efficiently using blocks size hints - // https://stackoverflow.com/a/42247224/3799609 - - // Ideally this would also specialize if we have Seek on our generic reader but - // we don't have specialization + // We can skip here for performance: + // - utf8 decoding of strings + // - block reads when serialized data provides serialized block size in bytes match *self.schema_node { SchemaNode::String => read_length_delimited(self.state, BytesVisitor(visitor)), + SchemaNode::Array(elements_schema) => visitor.visit_seq(ArraySeqAccess { + elements_schema: elements_schema.as_ref(), + block_reader: BlockReader::new(self.state, true, self.allowed_depth.dec()?), + }), + SchemaNode::Map(elements_schema) => visitor.visit_map(MapMapAccess { + elements_schema: elements_schema.as_ref(), + block_reader: BlockReader::new(self.state, true, self.allowed_depth.dec()?), + }), _ => self.deserialize_any(visitor), } } diff --git a/serde_avro_fast/src/de/deserializer/types/blocks.rs b/serde_avro_fast/src/de/deserializer/types/blocks.rs index bf8243e..d94e613 100644 --- a/serde_avro_fast/src/de/deserializer/types/blocks.rs +++ b/serde_avro_fast/src/de/deserializer/types/blocks.rs @@ -2,26 +2,42 @@ use super::*; use std::num::NonZeroUsize; -fn read_block_len<'de, R>(state: &mut DeserializerState) -> Result, DeError> +fn read_block_len<'de, R>( + state: &mut DeserializerState, + ignored: bool, +) -> Result, DeError> where R: ReadSlice<'de>, { - let len: i64 = state.read_varint()?; - let res; - if len < 0 { - // res = -len, properly handling i64::MIN - res = u64::from_ne_bytes(len.to_ne_bytes()).wrapping_neg(); - // Drop the number of bytes in the block to properly advance the reader - // Since we don't use that value, decode as u64 instead of i64 (skip zigzag - // decoding) TODO enable fast skipping when encountering - // `deserialize_ignored_any` - let _: u64 = state.read_varint()?; - } else { - res = len as u64; + loop { + let len: i64 = state.read_varint()?; + let res; + if len < 0 { + if ignored { + // We have block length hint in the data, and we are ignoring the data, so we + // can skip the block efficiently + let block_len_in_bytes: i64 = state.read_varint()?; + let block_len_in_bytes: u64 = block_len_in_bytes.try_into().map_err(|e| { + DeError::custom(format_args!("Invalid block length in stream: {e}")) + })?; + state.skip_bytes(block_len_in_bytes)?; + continue; // Also discard next blocks if any + } else { + // res = -len, properly handling i64::MIN + res = u64::from_ne_bytes(len.to_ne_bytes()).wrapping_neg(); + // Drop the number of bytes in the block to properly advance the reader + // Since we don't use that value, decode as u64 instead of i64 (skip zigzag + // decoding) + let _: u64 = state.read_varint()?; + } + } else { + res = len as u64; + } + break res + .try_into() + .map_err(|e| DeError::custom(format_args!("Invalid array length in stream: {e}"))) + .map(NonZeroUsize::new); } - res.try_into() - .map_err(|e| DeError::custom(format_args!("Invalid array length in stream: {e}"))) - .map(NonZeroUsize::new) } pub(in super::super) struct BlockReader<'r, 's, R> { @@ -29,10 +45,14 @@ pub(in super::super) struct BlockReader<'r, 's, R> { n_read: usize, reader: &'r mut DeserializerState<'s, R>, allowed_depth: AllowedDepth, + /// Represents whether we were hinted deserialize_ignored_any. If yes, we + /// can use the block length to skip the block. + ignored: bool, } impl<'r, 's, R> BlockReader<'r, 's, R> { pub(in super::super) fn new( reader: &'r mut DeserializerState<'s, R>, + hinted_ignored: bool, allowed_depth: AllowedDepth, ) -> Self { Self { @@ -40,6 +60,7 @@ impl<'r, 's, R> BlockReader<'r, 's, R> { current_block_len: 0, n_read: 0, allowed_depth, + ignored: hinted_ignored, } } fn has_more<'de>(&mut self) -> Result @@ -48,7 +69,7 @@ impl<'r, 's, R> BlockReader<'r, 's, R> { { self.current_block_len = match self.current_block_len.checked_sub(1) { None => { - let new_len = read_block_len(self.reader)?; + let new_len = read_block_len(self.reader, self.ignored)?; match new_len { None => return Ok(false), Some(new_len) => { @@ -137,9 +158,17 @@ impl<'de, R: ReadSlice<'de>> Deserializer<'de> for StringDeserializer<'_, '_, R> read_length_delimited(self.reader, StringVisitor(visitor)) } + fn deserialize_ignored_any(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + // Skip utf-8 validation + read_length_delimited(self.reader, BytesVisitor(visitor)) + } + serde::forward_to_deserialize_any! { bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string bytes byte_buf option unit unit_struct newtype_struct seq tuple - tuple_struct map struct enum identifier ignored_any + tuple_struct map struct enum identifier } } diff --git a/serde_avro_fast/src/de/read/mod.rs b/serde_avro_fast/src/de/read/mod.rs index 03cd773..3881c6d 100644 --- a/serde_avro_fast/src/de/read/mod.rs +++ b/serde_avro_fast/src/de/read/mod.rs @@ -30,6 +30,22 @@ pub trait Read: std::io::Read + Sized + private::Sealed { self.read_exact(&mut buf).map_err(DeError::io)?; Ok(buf) } + /// Skip `n_bytes` bytes from the underlying buffer + fn skip_bytes(&mut self, n_bytes: u64) -> Result<(), DeError> { + let written = std::io::copy( + &mut <&mut Self as std::io::Read>::take(self, n_bytes), + &mut std::io::sink(), + ) + .map_err(DeError::io)?; + if written == n_bytes { + Ok(()) + } else { + Err(DeError::custom(format_args!( + "Expected to skip {} bytes, but only skipped {}", + n_bytes, written + ))) + } + } } /// Abstracts reading from slices (propagating lifetime) or any other `impl @@ -74,6 +90,18 @@ impl<'de> Read for SliceRead<'de> { } } } + fn skip_bytes(&mut self, n_bytes: u64) -> Result<(), DeError> { + let n_bytes: usize = n_bytes + .try_into() + .map_err(|_| DeError::custom("Invalid number of bytes to skip"))?; + match self.slice.get(n_bytes..) { + Some(rest) => { + self.slice = rest; + Ok(()) + } + None => Err(DeError::unexpected_eof()), + } + } } impl<'de> ReadSlice<'de> for SliceRead<'de> { fn read_slice(&mut self, n: usize, visitor: V) -> Result diff --git a/serde_avro_fast/tests/deserialize_ignored.rs b/serde_avro_fast/tests/deserialize_ignored.rs new file mode 100644 index 0000000..eb1be89 --- /dev/null +++ b/serde_avro_fast/tests/deserialize_ignored.rs @@ -0,0 +1,47 @@ +use serde_avro_fast::Schema; + +const SCHEMA: &str = r#" +{ + "fields": [ + { + "type": {"type": "array", "items": "int"}, + "name": "a" + }, + { + "type": {"type": "array", "items": "int"}, + "name": "b" + }, + { + "type": {"type": "array", "items": "int"}, + "name": "cd" + } + ], + "type": "record", + "name": "test_skip" +} +"#; + +#[derive(Debug, PartialEq, Eq, serde::Deserialize)] +struct TestSkip { + a: Vec, + cd: Vec, +} + +#[test] +fn skip_block() { + let schema: Schema = SCHEMA.parse().unwrap(); + let input: &[u8] = &[1, 2, 20, 0, 1, 2, 30, 1, 4, 31, 32, 0, 4, 40, 50, 0, 0xFF]; + let expected = TestSkip { + a: vec![10], + cd: vec![20, 25], + }; + + let deserialized: TestSkip = serde_avro_fast::from_datum_slice(input, &schema).unwrap(); + assert_eq!(deserialized, expected); + + let mut reader = &input[..]; + let deserialized: TestSkip = serde_avro_fast::from_datum_reader(&mut reader, &schema).unwrap(); + assert_eq!(deserialized, expected); + // Also make sure that the reader stopped at the end of the block + assert_eq!(reader, &[0xFF]); +}