Skip to content

Commit

Permalink
Fix alignment issue of TensorData bytes (#2416)
Browse files Browse the repository at this point in the history
* implement memory-safe bytes that can be serialized and cloned

* change serialization to only serialize the bytes

introduce max alignment (which depends on platform anyway) and dont serialize that part
fixes Clone, Debug, and Eq impls to work on the bytes, not the pointers.

* make bytes no-std compatible

* enforce Send and Sync for Bytes

* avoid a copy during deserialization if data is already aligned

this already improves readability a bit by separating out alloc/dealloc logic
and adding a bunch of safety comments and better error messages

* revert back to using Vec as deserialization intermediate

borrowing from the deserializer will not save a copy, and is moreover
inefficient when we could take ownership of an existing byte buffer

* add serialization and conversion tests

* make Bytes tests run under miri

both changes only target miri's borrowing semantics, oprationally
the pointers are the same, but they obey different borrow-stack rules.

* let the Bytes buffer grow

* Clean the code by separation of concerns

The new Allocation struct keeps the raw allocation and its layout,
the Bytes struct wraps an Allocation and asserts that len bytes of it are initialized

* nit: change typo and improve internal naming

* use Bytes in jit ops
  • Loading branch information
WorldSEnder authored Dec 17, 2024
1 parent 8a89293 commit 28f99d1
Show file tree
Hide file tree
Showing 13 changed files with 641 additions and 80 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 10 additions & 2 deletions crates/burn-core/src/record/serde/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::record::{PrecisionSettings, Record};
use crate::tensor::backend::Backend;

use alloc::fmt;
use burn_tensor::Bytes;
use num_traits::cast::ToPrimitive;
use regex::Regex;
use serde::Deserialize;
Expand Down Expand Up @@ -66,7 +67,11 @@ pub enum NestedValue {

/// A vector of 32-bit floating point values.
F32s(Vec<f32>),

/// An opaque vector of bytes, with alignment.
Bytes(Bytes),
}

impl NestedValue {
/// Get the nested value as a map.
pub fn as_map(self) -> Option<HashMap<String, NestedValue>> {
Expand Down Expand Up @@ -184,9 +189,10 @@ impl NestedValue {
}

/// Get the nested value as a vector of bytes.
pub fn as_bytes(self) -> Option<Vec<u8>> {
pub fn as_bytes(self) -> Option<Bytes> {
match self {
NestedValue::U8s(u) => Some(u),
NestedValue::Bytes(u) => Some(u),
NestedValue::U8s(u) => Some(Bytes::from_elems(u)),
_ => None,
}
}
Expand Down Expand Up @@ -368,6 +374,7 @@ impl fmt::Debug for NestedValue {
NestedValue::U8s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
NestedValue::U16s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
NestedValue::F32s(vec) if vec.len() > 3 => write_vec_truncated(vec, f),
NestedValue::Bytes(bytes) if bytes.len() > 3 => write_vec_truncated(bytes, f),
// Handle other variants as usual
NestedValue::Default(origin) => f.debug_tuple("Default").field(origin).finish(),
NestedValue::Bool(b) => f.debug_tuple("Bool").field(b).finish(),
Expand All @@ -385,6 +392,7 @@ impl fmt::Debug for NestedValue {
NestedValue::U8s(vec) => f.debug_list().entries(vec.iter()).finish(),
NestedValue::U16s(vec) => f.debug_list().entries(vec.iter()).finish(),
NestedValue::F32s(vec) => f.debug_list().entries(vec.iter()).finish(),
NestedValue::Bytes(bytes) => f.debug_list().entries(bytes.iter()).finish(),
}
}
}
6 changes: 5 additions & 1 deletion crates/burn-core/src/record/serde/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ impl<'de, A: BurnModuleAdapter> serde::Deserializer<'de> for Deserializer<A> {
where
V: Visitor<'de>,
{
visitor.visit_byte_buf(self.value.unwrap().as_bytes().unwrap())
let bytes = self.value.unwrap().as_bytes().unwrap();
match bytes.try_into_vec::<u8>() {
Ok(bytes) => visitor.visit_byte_buf(bytes),
Err(bytes) => visitor.visit_bytes(&bytes),
}
}

fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-core/src/record/serde/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,6 @@ mod tests {
.clone()
.as_bytes()
.expect("has bytes vec");
assert_eq!(bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened());
assert_eq!(&*bytes, [1.0f32; 4].map(|f| f.to_le_bytes()).as_flattened());
}
}
2 changes: 1 addition & 1 deletion crates/burn-import/src/pytorch/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ where
// Because serializer copies individual elements of TensorData `value` into a new Vec<u8>,
// which is not necessary and inefficient.
let mut tensor_data: HashMap<String, NestedValue> = HashMap::new();
tensor_data.insert("bytes".into(), NestedValue::U8s(bytes));
tensor_data.insert("bytes".into(), NestedValue::Bytes(bytes));
tensor_data.insert("shape".into(), shape.serialize(serializer.clone())?);
tensor_data.insert("dtype".into(), dtype.serialize(serializer)?);

Expand Down
3 changes: 1 addition & 2 deletions crates/burn-jit/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ use std::{marker::PhantomData, sync::Mutex};
#[cfg(not(feature = "fusion"))]
use burn_tensor::{
ops::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor},
quantization::QuantizationScheme,
repr::{HandleKind, ReprBackend, TensorHandle},
repr::{ReprBackend, TensorHandle},
};

pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
Expand Down
8 changes: 2 additions & 6 deletions crates/burn-jit/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::ops::Range;
use burn_tensor::{
ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType},
DType, Device, Shape, TensorData,
Bytes, DType, Device, Shape, TensorData,
};

use crate::{
Expand Down Expand Up @@ -82,12 +82,8 @@ where
let tensor = kernel::into_contiguous(tensor);
let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;

// TODO: this should be refactored such that the bytes type is opaque.
// With this, the logic for handling the bytes representation of quantized data
// (as well as all data manipulations) will be encapsulated in the type.
// Creating a TensorData struct directly from some bytes should probably not be possible outside of the crate.
TensorData {
bytes,
bytes: Bytes::from_bytes_vec(bytes),
shape: tensor.shape.into(),
dtype: tensor.dtype,
}
Expand Down
8 changes: 4 additions & 4 deletions crates/burn-jit/src/ops/transaction.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use burn_tensor::{
ops::{TransactionOps, TransactionPrimitiveResult},
DType, TensorData,
Bytes, DType, TensorData,
};

use crate::{element::BoolElement, FloatElement, IntElement, JitBackend, JitRuntime};
Expand Down Expand Up @@ -74,23 +74,23 @@ where
Kind::Float(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
result.read_floats.push(TensorData {
bytes,
bytes: Bytes::from_bytes_vec(bytes),
shape,
dtype,
});
}
Kind::Int(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
result.read_ints.push(TensorData {
bytes,
bytes: Bytes::from_bytes_vec(bytes),
shape,
dtype,
});
}
Kind::Bool(index, shape, dtype) => {
let bytes = data.get_mut(index).unwrap().take().unwrap();
result.read_bools.push(TensorData {
bytes,
bytes: Bytes::from_bytes_vec(bytes),
shape,
dtype,
});
Expand Down
1 change: 1 addition & 0 deletions crates/burn-tensor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ portable-atomic-util = { workspace = true }

[dev-dependencies]
rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std
bincode = { workspace = true }

[package.metadata.docs.rs]
features = ["doc"]
Expand Down
Loading

0 comments on commit 28f99d1

Please sign in to comment.