From 40a26bd2ea0000845fae167a9a7cb28bd3a38d26 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 26 Mar 2024 19:24:45 -0400 Subject: [PATCH] Feat/backend bridge (#1529) --- crates/burn-autodiff/src/backend.rs | 4 +- crates/burn-autodiff/src/bridge.rs | 155 ++++++++++++++++++ crates/burn-autodiff/src/lib.rs | 3 + crates/burn-autodiff/src/ops/tensor.rs | 105 +----------- crates/burn-autodiff/src/tests/bridge.rs | 29 ++++ crates/burn-autodiff/src/tests/mod.rs | 2 + crates/burn-candle/src/backend.rs | 5 +- crates/burn-candle/src/bridge.rs | 35 ++++ crates/burn-candle/src/lib.rs | 3 + crates/burn-candle/src/ops/tensor.rs | 12 -- crates/burn-fusion/src/backend.rs | 6 +- crates/burn-fusion/src/bridge.rs | 25 +++ crates/burn-fusion/src/lib.rs | 2 + crates/burn-fusion/src/ops/float.rs | 14 +- crates/burn-jit/src/backend.rs | 5 +- crates/burn-jit/src/bridge.rs | 62 +++++++ crates/burn-jit/src/lib.rs | 5 +- crates/burn-jit/src/ops/float_ops.rs | 21 +-- crates/burn-jit/src/runtime.rs | 2 +- crates/burn-ndarray/src/backend.rs | 5 +- crates/burn-ndarray/src/bridge.rs | 35 ++++ crates/burn-ndarray/src/lib.rs | 2 + crates/burn-ndarray/src/ops/tensor.rs | 16 -- crates/burn-tch/src/backend.rs | 5 +- crates/burn-tch/src/bridge.rs | 49 ++++++ crates/burn-tch/src/lib.rs | 2 + crates/burn-tch/src/ops/tensor.rs | 14 -- .../burn-tensor/src/tensor/activation/base.rs | 2 +- crates/burn-tensor/src/tensor/api/float.rs | 7 +- crates/burn-tensor/src/tensor/backend/base.rs | 9 +- .../burn-tensor/src/tensor/backend/bridge.rs | 21 +++ crates/burn-tensor/src/tensor/backend/mod.rs | 3 + .../burn-tensor/src/tensor/ops/activation.rs | 19 ++- crates/burn-tensor/src/tensor/ops/alias.rs | 5 +- crates/burn-tensor/src/tensor/ops/tensor.rs | 13 +- crates/burn-wgpu/src/runtime.rs | 1 + 36 files changed, 483 insertions(+), 220 deletions(-) create mode 100644 crates/burn-autodiff/src/bridge.rs create mode 100644 crates/burn-autodiff/src/tests/bridge.rs create mode 100644 crates/burn-candle/src/bridge.rs create mode 100644 crates/burn-fusion/src/bridge.rs create mode 100644 crates/burn-jit/src/bridge.rs create mode 100644 crates/burn-ndarray/src/bridge.rs create mode 100644 crates/burn-tch/src/bridge.rs create mode 100644 crates/burn-tensor/src/tensor/backend/bridge.rs diff --git a/crates/burn-autodiff/src/backend.rs b/crates/burn-autodiff/src/backend.rs index 8c3e87b205..8bc5b57d66 100644 --- a/crates/burn-autodiff/src/backend.rs +++ b/crates/burn-autodiff/src/backend.rs @@ -3,6 +3,7 @@ use crate::{ grads::Gradients, graph::backward::backward, tensor::AutodiffTensor, + AutodiffBridge, }; use burn_tensor::backend::{AutodiffBackend, Backend}; use core::marker::PhantomData; @@ -20,8 +21,7 @@ pub struct Autodiff { impl Backend for Autodiff { type Device = B::Device; - type FullPrecisionElem = B::FullPrecisionElem; - type FullPrecisionBackend = Autodiff; + type FullPrecisionBridge = AutodiffBridge; type FloatTensorPrimitive = AutodiffTensor; type FloatElem = B::FloatElem; diff --git a/crates/burn-autodiff/src/bridge.rs b/crates/burn-autodiff/src/bridge.rs new file mode 100644 index 0000000000..3edf450db6 --- /dev/null +++ b/crates/burn-autodiff/src/bridge.rs @@ -0,0 +1,155 @@ +use std::marker::PhantomData; + +use burn_tensor::{ + backend::{Backend, BackendBridge}, + ops::FloatTensor, +}; + +use crate::{ + checkpoint::{ + base::Checkpointer, retro_forward::RetroForward, state::BackwardStates, + strategy::CheckpointStrategy, + }, + grads::Gradients, + ops::{unary_different_backend, Backward, Ops}, + Autodiff, NodeID, +}; + +/// 
Enable autodiff on a [backend bridge](BackendBridge). +#[derive(Debug)] +pub struct AutodiffBridge { + _p: PhantomData, +} + +impl BackendBridge> for AutodiffBridge +where + B: Backend, + C: CheckpointStrategy, + Bridge: BackendBridge + 'static, +{ + type Target = Autodiff; + + fn into_target( + tensor: burn_tensor::ops::FloatTensor, D>, + _device: Option>, + ) -> burn_tensor::ops::FloatTensor { + #[derive(Debug)] + struct IntoTarget> { + _backend: PhantomData, + _bridge: PhantomData, + } + + #[derive(new, Debug)] + struct RetroIntoTarget, const D: usize> { + tensor_id: NodeID, + _backend: PhantomData, + _bridge: PhantomData, + } + + impl RetroForward for RetroIntoTarget + where + B: Backend, + Bridge: BackendBridge + 'static, + { + fn forward(&self, states: &mut BackwardStates, out_node: NodeID) { + let tensor: FloatTensor = states.get_state(&self.tensor_id); + let out = Bridge::into_target(tensor, Default::default()); + states.save(out_node, out) + } + } + + impl Backward for IntoTarget + where + B: Backend, + Bridge: BackendBridge + 'static, + { + type State = (); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + _checkpointer: &mut Checkpointer, + ) { + unary_different_backend::( + ops.parents, + ops.node, + grads, + |grad| Bridge::from_target(grad, None), + ); + } + } + + IntoTarget:: { + _backend: PhantomData, + _bridge: PhantomData, + } + .prepare::([tensor.node.clone()], [tensor.graph.clone()]) + .memory_bound() + .retro_forward(RetroIntoTarget::::new(tensor.node.id.clone())) + .parents([&tensor]) + .stateless(Bridge::into_target(tensor.primitive, None)) + } + + fn from_target( + tensor: burn_tensor::ops::FloatTensor, + _device: Option>>, + ) -> burn_tensor::ops::FloatTensor, D> { + #[derive(Debug)] + struct FromTarget> { + _backend: PhantomData, + _bridge: PhantomData, + } + + #[derive(new, Debug)] + struct RetroFromTarget, const D: usize> { + tensor_id: NodeID, + _backend: PhantomData, + _bridge: PhantomData, + } + + impl RetroForward for RetroFromTarget + where + B: Backend, + Bridge: BackendBridge + 'static, + { + fn forward(&self, states: &mut BackwardStates, out_node: NodeID) { + let tensor: FloatTensor = states.get_state(&self.tensor_id); + let out = Bridge::from_target(tensor, None); + states.save(out_node, out) + } + } + + impl Backward for FromTarget + where + B: Backend, + Bridge: BackendBridge + 'static, + { + type State = (); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + _checkpointer: &mut Checkpointer, + ) { + unary_different_backend::( + ops.parents, + ops.node, + grads, + |grad| Bridge::into_target(grad, None), + ); + } + } + + FromTarget:: { + _backend: PhantomData, + _bridge: PhantomData, + } + .prepare::([tensor.node.clone()], [tensor.graph.clone()]) + .memory_bound() + .retro_forward(RetroFromTarget::::new(tensor.node.id.clone())) + .parents([&tensor]) + .stateless(Bridge::from_target(tensor.primitive, None)) + } +} diff --git a/crates/burn-autodiff/src/lib.rs b/crates/burn-autodiff/src/lib.rs index 22d84590ad..7c3c256605 100644 --- a/crates/burn-autodiff/src/lib.rs +++ b/crates/burn-autodiff/src/lib.rs @@ -26,7 +26,10 @@ pub(crate) mod tensor; pub(crate) mod utils; mod backend; +mod bridge; + pub use backend::*; +pub use bridge::*; #[cfg(feature = "export_tests")] mod tests; diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs index 0da3975731..58adc86acf 100644 --- a/crates/burn-autodiff/src/ops/tensor.rs +++ b/crates/burn-autodiff/src/ops/tensor.rs @@ -7,7 +7,7 @@ use crate::{ }, 
grads::Gradients, graph::{ComputingProperty, NodeID, NodeRef, Requirement, Step}, - ops::{binary, broadcast_shape, unary, unary_different_backend, Backward, Ops, OpsKind}, + ops::{binary, broadcast_shape, unary, Backward, Ops, OpsKind}, retro_binary, retro_unary, retro_unary_scalar, tensor::AutodiffTensor, utils::duplicate, @@ -16,7 +16,7 @@ use crate::{ use burn_tensor::{ backend::Backend, - ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, FullPrecisionBackend, IntTensor}, + ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, Data, Device, ElementConversion, Reader, Shape, Tensor, }; @@ -1621,107 +1621,6 @@ impl FloatTensorOps for Autodiff } } - fn float_to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - #[derive(Debug)] - struct ToFullPrecision { - phantom: PhantomData, - } - - #[derive(new, Debug)] - struct RetroToFullPrecision { - tensor_id: NodeID, - _backend: PhantomData, - } - - impl RetroForward for RetroToFullPrecision { - fn forward(&self, states: &mut BackwardStates, out_node: NodeID) { - let tensor = states.get_state::>(&self.tensor_id); - let out = B::float_to_full_precision(&tensor); - states.save(out_node, out) - } - } - - impl Backward for ToFullPrecision { - type State = (); - - fn backward( - self, - ops: Ops, - grads: &mut Gradients, - _checkpointer: &mut Checkpointer, - ) { - unary_different_backend::( - ops.parents, - ops.node, - grads, - |grad| B::float_from_full_precision(grad), - ); - } - } - - let ops = ToFullPrecision:: { - phantom: PhantomData, - }; - ops.prepare::([tensor.node.clone()], [tensor.graph.clone()]) - .memory_bound() - .retro_forward(RetroToFullPrecision::::new(tensor.node.id.clone())) - .parents([tensor]) - .stateless(B::float_to_full_precision(&tensor.primitive)) - } - - fn float_from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - #[derive(Debug)] - struct FromFullPrecision { - phantom: PhantomData, - } - - #[derive(new, Debug)] - struct RetroFromFullPrecision { - tensor_id: NodeID, - _backend: PhantomData, - } - - impl RetroForward for RetroFromFullPrecision { - fn forward(&self, states: &mut BackwardStates, out_node: NodeID) { - let tensor = states.get_state::<<::FullPrecisionBackend as Backend>::FloatTensorPrimitive>(&self.tensor_id); - let out = B::float_from_full_precision(tensor); - states.save(out_node, out) - } - } - - impl Backward for FromFullPrecision { - type State = (); - - fn backward( - self, - ops: Ops, - grads: &mut Gradients, - _checkpointer: &mut Checkpointer, - ) { - unary_different_backend::( - ops.parents, - ops.node, - grads, - |grad| B::float_to_full_precision(&grad), - ); - } - } - - let ops = FromFullPrecision:: { - phantom: PhantomData, - }; - - ops.prepare::([tensor.node.clone()], [tensor.graph.clone()]) - .memory_bound() - .retro_forward(RetroFromFullPrecision::::new(tensor.node.id.clone())) - .parents([&tensor]) - .stateless(B::float_from_full_precision(tensor.primitive)) - } - fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { B::float_argmax(tensor.primitive, dim) } diff --git a/crates/burn-autodiff/src/tests/bridge.rs b/crates/burn-autodiff/src/tests/bridge.rs new file mode 100644 index 0000000000..2a2a639d52 --- /dev/null +++ b/crates/burn-autodiff/src/tests/bridge.rs @@ -0,0 +1,29 @@ +#[burn_tensor_testgen::testgen(bridge)] +mod tests { + use super::*; + use burn_tensor::{backend::Backend, module::embedding, Data, Distribution, Int, Tensor}; + + #[test] + fn test_full_precision() { + let device = Default::default(); + let x1 = 
Tensor::::random([32, 32], Distribution::Default, &device) + .require_grad(); + let x2 = Tensor::::random([32, 32], Distribution::Default, &device) + .require_grad(); + + let x3 = x1.clone().into_full_precision(); + let x4 = x2.clone().into_full_precision(); + + let x5 = x3.matmul(x4); + let x6 = Tensor::::from_full_precision(x5); + let x7 = x6 * x1.clone() / x2.clone(); + + let mut grads = x7.backward(); + + let x1_grad = x1.grad(&mut grads); + let x2_grad = x2.grad(&mut grads); + + assert!(x1_grad.is_some()); + assert!(x2_grad.is_some()); + } +} diff --git a/crates/burn-autodiff/src/tests/mod.rs b/crates/burn-autodiff/src/tests/mod.rs index 9043b43307..e3f459eab3 100644 --- a/crates/burn-autodiff/src/tests/mod.rs +++ b/crates/burn-autodiff/src/tests/mod.rs @@ -8,6 +8,7 @@ mod aggregation; mod avgpool1d; mod avgpool2d; mod backward; +mod bridge; mod broadcast; mod cat; mod checkpoint; @@ -64,6 +65,7 @@ macro_rules! testgen_all { // Behavior burn_autodiff::testgen_ad_broadcast!(); burn_autodiff::testgen_gradients!(); + burn_autodiff::testgen_bridge!(); burn_autodiff::testgen_checkpoint!(); // Activation diff --git a/crates/burn-candle/src/backend.rs b/crates/burn-candle/src/backend.rs index 12de1f925d..4b0b12d7b6 100644 --- a/crates/burn-candle/src/backend.rs +++ b/crates/burn-candle/src/backend.rs @@ -5,7 +5,7 @@ use candle_core::DeviceLocation; use crate::{ element::{CandleElement, FloatCandleElement, IntCandleElement}, - CandleTensor, + CandleTensor, PrecisionBridge, }; /// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations. @@ -69,8 +69,7 @@ impl Default for CandleDevice { impl Backend for Candle { type Device = CandleDevice; - type FullPrecisionBackend = Candle; - type FullPrecisionElem = f32; + type FullPrecisionBridge = PrecisionBridge; type FloatTensorPrimitive = CandleTensor; type FloatElem = F; diff --git a/crates/burn-candle/src/bridge.rs b/crates/burn-candle/src/bridge.rs new file mode 100644 index 0000000000..d46e90e6e8 --- /dev/null +++ b/crates/burn-candle/src/bridge.rs @@ -0,0 +1,35 @@ +use crate::{ + element::{FloatCandleElement, IntCandleElement}, + Candle, CandleTensor, +}; +use burn_tensor::{backend::BackendBridge, ops::FloatTensor, Device}; +use std::marker::PhantomData; + +/// Handle precision conversion for the candle backend. 
+#[derive(Debug)] +pub struct PrecisionBridge { + _e: PhantomData, +} + +impl BackendBridge> for PrecisionBridge +where + TElem: FloatCandleElement, + OElem: FloatCandleElement, + IntElem: IntCandleElement, +{ + type Target = Candle; + + fn into_target( + tensor: FloatTensor, D>, + device: Option>, + ) -> FloatTensor { + CandleTensor::new(tensor.tensor.to_dtype(TElem::DTYPE).unwrap()) + } + + fn from_target( + tensor: FloatTensor, + device: Option>>, + ) -> FloatTensor, D> { + CandleTensor::new(tensor.tensor.to_dtype(OElem::DTYPE).unwrap()) + } +} diff --git a/crates/burn-candle/src/lib.rs b/crates/burn-candle/src/lib.rs index a6c6d1c5e0..60f576652f 100644 --- a/crates/burn-candle/src/lib.rs +++ b/crates/burn-candle/src/lib.rs @@ -7,10 +7,13 @@ extern crate derive_new; mod backend; +mod bridge; mod element; mod ops; mod tensor; + pub use backend::*; +pub use bridge::*; pub use tensor::*; #[cfg(test)] diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs index ca98b85c3a..5c5c5f351e 100644 --- a/crates/burn-candle/src/ops/tensor.rs +++ b/crates/burn-candle/src/ops/tensor.rs @@ -368,18 +368,6 @@ impl FloatTensorOps for Candle CandleTensor::new(tensor.tensor.mean_keepdim(dim).unwrap()) } - fn float_to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - CandleTensor::new(tensor.tensor.to_dtype(candle_core::DType::F32).unwrap()) - } - - fn float_from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - CandleTensor::new(tensor.tensor.to_dtype(F::DTYPE).unwrap()) - } - fn float_exp(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.exp().unwrap()) } diff --git a/crates/burn-fusion/src/backend.rs b/crates/burn-fusion/src/backend.rs index da0954b92f..79c3680472 100644 --- a/crates/burn-fusion/src/backend.rs +++ b/crates/burn-fusion/src/backend.rs @@ -1,7 +1,7 @@ use crate::{ client::FusionClient, stream::{Context, OperationDescription}, - FusionClientLocator, FusionTensor, + FusionClientLocator, FusionTensor, PrecisionBridge, }; use burn_tensor::{backend::Backend, Device, Shape}; use serde::{de::DeserializeOwned, Serialize}; @@ -22,9 +22,7 @@ pub struct Fusion { impl Backend for Fusion { type Device = B::Device; - // TODO: Find a better way to handle full precision. - type FullPrecisionBackend = Self; - type FullPrecisionElem = B::FloatElem; + type FullPrecisionBridge = PrecisionBridge; type FloatTensorPrimitive = FusionTensor; diff --git a/crates/burn-fusion/src/bridge.rs b/crates/burn-fusion/src/bridge.rs new file mode 100644 index 0000000000..375fd4fb52 --- /dev/null +++ b/crates/burn-fusion/src/bridge.rs @@ -0,0 +1,25 @@ +use burn_tensor::backend::BackendBridge; + +use crate::{Fusion, FusionBackend}; + +#[derive(Debug)] +/// Fusion bridge. 
+pub struct PrecisionBridge; + +impl BackendBridge> for PrecisionBridge { + type Target = Fusion; + + fn into_target( + tensor: burn_tensor::ops::FloatTensor, D>, + _device: Option>, + ) -> burn_tensor::ops::FloatTensor { + tensor + } + + fn from_target( + tensor: burn_tensor::ops::FloatTensor, + _device: Option>>, + ) -> burn_tensor::ops::FloatTensor, D> { + tensor + } +} diff --git a/crates/burn-fusion/src/lib.rs b/crates/burn-fusion/src/lib.rs index 67550661dc..217e434f73 100644 --- a/crates/burn-fusion/src/lib.rs +++ b/crates/burn-fusion/src/lib.rs @@ -14,6 +14,7 @@ pub mod client; pub mod stream; mod backend; +mod bridge; mod fusion; mod handle; mod ops; @@ -23,6 +24,7 @@ mod tensor; pub(crate) use server::*; pub use backend::*; +pub use bridge::*; pub use fusion::*; pub use handle::*; pub use tensor::*; diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 1a130f6ecd..3b5cecff10 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -18,7 +18,7 @@ use crate::{ unary_float_ops, Fusion, FusionBackend, TensorDescription, }; use burn_tensor::{ - ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, FullPrecisionBackend, IntTensor}, + ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor}, Data, Device, Distribution, ElementConversion, Reader, Shape, }; use std::ops::Range; @@ -1283,18 +1283,6 @@ impl FloatTensorOps for Fusion { out } - fn float_to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - tensor.clone() - } - - fn float_from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - tensor - } - fn float_exp(lhs: FloatTensor) -> FloatTensor { unary_float_ops!(ExpOps, B::float_exp); diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs index f0e090aaa1..e0784358fd 100644 --- a/crates/burn-jit/src/backend.rs +++ b/crates/burn-jit/src/backend.rs @@ -1,4 +1,4 @@ -use crate::{codegen::Compiler, tensor::JitTensor, Runtime}; +use crate::{codegen::Compiler, tensor::JitTensor, PrecisionBridge, Runtime}; use burn_tensor::backend::Backend; use rand::{rngs::StdRng, SeedableRng}; use std::{marker::PhantomData, sync::Mutex}; @@ -13,9 +13,8 @@ pub struct JitBackend { impl Backend for JitBackend { type Device = R::Device; - type FullPrecisionBackend = JitBackend; - type FullPrecisionElem = f32; + type FullPrecisionBridge = PrecisionBridge; type FloatElem = ::Float; type IntElem = ::Int; diff --git a/crates/burn-jit/src/bridge.rs b/crates/burn-jit/src/bridge.rs new file mode 100644 index 0000000000..4e1cf7e897 --- /dev/null +++ b/crates/burn-jit/src/bridge.rs @@ -0,0 +1,62 @@ +use crate::{kernel, ops::to_device, tensor::JitTensor, JitBackend, Runtime}; +use burn_tensor::{ + backend::BackendBridge, + ops::{FloatElem, FloatTensor}, +}; +use core::marker::PhantomData; + +/// Handle precision conversion for the jit backend. +#[derive(Debug)] +pub struct PrecisionBridge { + _runtime: PhantomData, +} + +impl BackendBridge> for PrecisionBridge +where + ROrigin: Runtime, + RTarget: + Runtime, +{ + type Target = JitBackend; + + fn into_target( + tensor: FloatTensor, D>, + device: Option>, + ) -> FloatTensor { + let tensor = kernel::cast::< + ROrigin, + FloatElem>, + FloatElem>, + D, + >(tensor); + + // The line below does the backend type cast. 
+ let tensor = JitTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); + + if let Some(device) = &device { + to_device(tensor, device) + } else { + tensor + } + } + + fn from_target( + tensor: FloatTensor, + device: Option>>, + ) -> FloatTensor, D> { + let tensor = kernel::cast::< + RTarget, + FloatElem>, + FloatElem>, + D, + >(tensor); + // The line below does the backend type cast. + let tensor = JitTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle); + + if let Some(device) = &device { + to_device(tensor, device) + } else { + tensor + } + } +} diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index 17de76fba4..929bbc4d9c 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -25,8 +25,11 @@ pub use codegen::dialect::gpu; pub use element::{FloatElement, IntElement, JitElement}; mod backend; -pub use backend::*; +mod bridge; mod runtime; + +pub use backend::*; +pub use bridge::*; pub use runtime::*; #[cfg(any(feature = "fusion", test))] diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index cfb68379af..d76f28998c 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -3,12 +3,9 @@ use crate::codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope, UnaryO use crate::kernel::matmul::{matmul, MatmulStrategy}; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; use crate::kernel::{self, reduce}; -use crate::tensor::JitTensor; use crate::Runtime; use crate::{unary, JitBackend}; -use burn_tensor::ops::{ - BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, -}; +use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor}; use burn_tensor::{ops::FloatTensorOps, Data, Distribution, Shape}; use burn_tensor::{ElementConversion, Reader}; use std::ops::Range; @@ -321,22 +318,6 @@ impl FloatTensorOps for JitBackend { reduce::prod_dim(tensor, dim, Default::default()) } - fn float_to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D> { - let tensor = kernel::cast::, f32, D>(tensor.clone()); - // The line bellow does the backend type cast. - JitTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle) - } - - fn float_from_full_precision( - tensor: FloatTensor, D>, - ) -> FloatTensor { - let tensor = kernel::cast::, D>(tensor); - // The line bellow does the backend type cast. - JitTensor::new(tensor.client, tensor.device, tensor.shape, tensor.handle) - } - fn float_exp(tensor: FloatTensor) -> FloatTensor { unary!( operation: |scope: &mut Scope, elem: Elem| Operator::Exp(UnaryOperator { diff --git a/crates/burn-jit/src/runtime.rs b/crates/burn-jit/src/runtime.rs index 8e6a6a2be0..bb2da61bf9 100644 --- a/crates/burn-jit/src/runtime.rs +++ b/crates/burn-jit/src/runtime.rs @@ -5,7 +5,7 @@ use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::Compu pub type RuntimeInt = <::Compiler as Compiler>::Int; /// Runtime for the [just-in-time backend](crate::JitBackend). -pub trait Runtime: Send + Sync + 'static { +pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { /// The compiler used to compile the inner representation into tokens. type Compiler: Compiler; /// The compute server used to run kernels and perform autotuning. 
diff --git a/crates/burn-ndarray/src/backend.rs b/crates/burn-ndarray/src/backend.rs index 0bfdc766e0..d9c5e493c3 100644 --- a/crates/burn-ndarray/src/backend.rs +++ b/crates/burn-ndarray/src/backend.rs @@ -1,5 +1,5 @@ -use crate::element::FloatNdArrayElement; use crate::NdArrayTensor; +use crate::{element::FloatNdArrayElement, PrecisionBridge}; use alloc::string::String; use burn_common::stub::Mutex; use burn_tensor::backend::Backend; @@ -32,8 +32,7 @@ pub struct NdArray { impl Backend for NdArray { type Device = NdArrayDevice; - type FullPrecisionElem = f32; - type FullPrecisionBackend = NdArray; + type FullPrecisionBridge = PrecisionBridge; type FloatTensorPrimitive = NdArrayTensor; type FloatElem = E; diff --git a/crates/burn-ndarray/src/bridge.rs b/crates/burn-ndarray/src/bridge.rs new file mode 100644 index 0000000000..27a01b140d --- /dev/null +++ b/crates/burn-ndarray/src/bridge.rs @@ -0,0 +1,35 @@ +use crate::{FloatNdArrayElement, NdArray, NdArrayDevice, NdArrayTensor}; +use burn_tensor::{backend::BackendBridge, ops::FloatTensor}; +use core::marker::PhantomData; + +/// Handle precision conversion for the ndarray backend. +#[derive(Debug)] +pub struct PrecisionBridge { + _e: PhantomData, +} + +impl BackendBridge> for PrecisionBridge +where + TElem: FloatNdArrayElement, + OElem: FloatNdArrayElement, +{ + type Target = NdArray; + + fn into_target( + tensor: FloatTensor, D>, + _device: Option, + ) -> FloatTensor { + let array = tensor.array.mapv(|a| a.elem()).into_shared(); + + NdArrayTensor::new(array) + } + + fn from_target( + tensor: FloatTensor, + _device: Option, + ) -> FloatTensor, D> { + let array = tensor.array.mapv(|a| a.elem()).into_shared(); + + NdArrayTensor::new(array) + } +} diff --git a/crates/burn-ndarray/src/lib.rs b/crates/burn-ndarray/src/lib.rs index c31781dd65..01ed8abdb3 100644 --- a/crates/burn-ndarray/src/lib.rs +++ b/crates/burn-ndarray/src/lib.rs @@ -14,6 +14,7 @@ extern crate derive_new; extern crate blas_src; mod backend; +mod bridge; mod element; mod ops; mod parallel; @@ -21,6 +22,7 @@ mod sharing; mod tensor; pub use backend::*; +pub use bridge::*; pub use element::FloatNdArrayElement; pub(crate) use sharing::*; pub use tensor::*; diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index ee3dec3e0e..f733d2b5d7 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -332,22 +332,6 @@ impl FloatTensorOps for NdArray { NdArrayMathOps::sum_dim(tensor, dim) } - fn float_to_full_precision( - tensor: &NdArrayTensor, - ) -> NdArrayTensor { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); - - NdArrayTensor::new(array) - } - - fn float_from_full_precision( - tensor: NdArrayTensor, - ) -> NdArrayTensor { - let array = tensor.array.mapv(|a| a.elem()).into_shared(); - - NdArrayTensor::new(array) - } - fn float_argmax( tensor: NdArrayTensor, dim: usize, diff --git a/crates/burn-tch/src/backend.rs b/crates/burn-tch/src/backend.rs index fdcec97a8e..94e183a0a7 100644 --- a/crates/burn-tch/src/backend.rs +++ b/crates/burn-tch/src/backend.rs @@ -1,3 +1,5 @@ +use crate::PrecisionBridge; + use super::element::TchElement; use super::TchTensor; use burn_tensor::backend::Backend; @@ -77,8 +79,7 @@ pub struct LibTorch { impl Backend for LibTorch { type Device = LibTorchDevice; - type FullPrecisionElem = f32; - type FullPrecisionBackend = LibTorch; + type FullPrecisionBridge = PrecisionBridge; type FloatTensorPrimitive = TchTensor; type FloatElem = E; diff --git 
a/crates/burn-tch/src/bridge.rs b/crates/burn-tch/src/bridge.rs new file mode 100644 index 0000000000..8f58bf517f --- /dev/null +++ b/crates/burn-tch/src/bridge.rs @@ -0,0 +1,49 @@ +use crate::{ops::TchOps, LibTorch, TchElement, TchTensor}; +use burn_tensor::{backend::BackendBridge, ops::FloatTensor, Device}; +use std::marker::PhantomData; + +/// Handle precision conversion for the candle backend. +#[derive(Debug)] +pub struct PrecisionBridge { + _e: PhantomData, +} + +impl BackendBridge> for PrecisionBridge +where + TElem: TchElement, + OElem: TchElement, +{ + type Target = LibTorch; + + fn into_target( + tensor: FloatTensor, D>, + device: Option>, + ) -> FloatTensor { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.to_kind(TElem::KIND); + + let tensor = TchTensor::from_existing(tensor, storage); + + if let Some(device) = &device { + TchOps::to_device(tensor, device) + } else { + tensor + } + } + + fn from_target( + tensor: FloatTensor, + device: Option>>, + ) -> FloatTensor, D> { + let storage = tensor.storage.clone(); + let tensor = tensor.tensor.to_kind(OElem::KIND); + + let tensor = TchTensor::from_existing(tensor, storage); + + if let Some(device) = &device { + TchOps::to_device(tensor, device) + } else { + tensor + } + } +} diff --git a/crates/burn-tch/src/lib.rs b/crates/burn-tch/src/lib.rs index 25d416757b..c858edd2bc 100644 --- a/crates/burn-tch/src/lib.rs +++ b/crates/burn-tch/src/lib.rs @@ -4,11 +4,13 @@ //! Burn Tch Backend mod backend; +mod bridge; mod element; mod ops; mod tensor; pub use backend::*; +pub use bridge::*; pub use element::*; pub use tensor::*; diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index 09974a8310..28c1f782d6 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -334,20 +334,6 @@ impl FloatTensorOps for LibTorch { TchOps::prod_dim(tensor, dim) } - fn float_to_full_precision(tensor: &TchTensor) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.to_kind(tch::Kind::Float); - - TchTensor::from_existing(tensor, storage) - } - - fn float_from_full_precision(tensor: TchTensor) -> TchTensor { - let storage = tensor.storage.clone(); - let tensor = tensor.tensor.to_kind(E::KIND); - - TchTensor::from_existing(tensor, storage) - } - fn float_argmax(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::argmax(tensor, dim) } diff --git a/crates/burn-tensor/src/tensor/activation/base.rs b/crates/burn-tensor/src/tensor/activation/base.rs index 7abfc57c92..ce2205b09a 100644 --- a/crates/burn-tensor/src/tensor/activation/base.rs +++ b/crates/burn-tensor/src/tensor/activation/base.rs @@ -114,7 +114,7 @@ pub fn sigmoid(tensor: Tensor) -> Tensor pub fn log_sigmoid(tensor: Tensor) -> Tensor { match B::FloatElem::precision() { Precision::Half => { - let tensor_full = tensor.to_full_precision(); + let tensor_full = tensor.into_full_precision(); let tensor_tmp = tensor_full.neg().exp().add_scalar(1.0_f32).log().neg(); Tensor::from_full_precision(tensor_tmp) } diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index 7e3d0440b2..cc23db05ec 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -3,6 +3,7 @@ use core::convert::TryInto; use crate::check; use crate::check::TensorCheck; +use crate::ops::FullPrecisionBackend; use crate::tensor::backend::Backend; use crate::tensor::stats; use crate::tensor::{Data, Distribution, Shape}; @@ -206,12 +207,12 @@ 
where } /// Returns a tensor with full precision based on the selected backend. - pub fn to_full_precision(&self) -> Tensor { - Tensor::new(B::float_to_full_precision(&self.primitive)) + pub fn into_full_precision(self) -> Tensor, D> { + Tensor::new(B::float_into_full_precision(self.primitive)) } /// Returns a tensor on the selected backend from a full precision tensor. - pub fn from_full_precision(tensor: Tensor) -> Self { + pub fn from_full_precision(tensor: Tensor, D>) -> Self { Self::new(B::float_from_full_precision(tensor.primitive)) } diff --git a/crates/burn-tensor/src/tensor/backend/base.rs b/crates/burn-tensor/src/tensor/backend/base.rs index d3c0b38f32..bf01ac8862 100644 --- a/crates/burn-tensor/src/tensor/backend/base.rs +++ b/crates/burn-tensor/src/tensor/backend/base.rs @@ -3,6 +3,8 @@ use alloc::string::String; use crate::ops::*; use crate::tensor::Element; +use super::BackendBridge; + /// This trait defines all types and functions needed for a backend to be used with burn. /// /// ## Design @@ -66,10 +68,8 @@ pub trait Backend: /// Device type. type Device: Clone + Default + PartialEq + core::fmt::Debug + Send + Sync; - /// Pointer to another backend that have a full precision float element type - type FullPrecisionBackend: Backend; - /// Full precision float element type. - type FullPrecisionElem: Element; + /// A bridge that can cast tensors to full precision. + type FullPrecisionBridge: BackendBridge + 'static; /// Tensor primitive to be used for all float operations. type FloatTensorPrimitive: Clone + Send + Sync + 'static + core::fmt::Debug; @@ -106,7 +106,6 @@ pub trait AutodiffBackend: Backend { Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem, - FullPrecisionElem = Self::FullPrecisionElem, >; /// Gradients type. diff --git a/crates/burn-tensor/src/tensor/backend/bridge.rs b/crates/burn-tensor/src/tensor/backend/bridge.rs new file mode 100644 index 0000000000..61fccf76d5 --- /dev/null +++ b/crates/burn-tensor/src/tensor/backend/bridge.rs @@ -0,0 +1,21 @@ +use crate::{ops::FloatTensor, Device}; + +use super::Backend; + +/// Allows tensors to be transferred between backends efficiently. +pub trait BackendBridge: Send + Sync + core::fmt::Debug { + /// The target backend + type Target: Backend; + + /// Transfer the tensor to the target backend. + fn into_target( + tensor: FloatTensor, + device: Option>, + ) -> FloatTensor; + + /// Transfer the tensor from the target backend. + fn from_target( + tensor: FloatTensor, + device: Option>, + ) -> FloatTensor; +} diff --git a/crates/burn-tensor/src/tensor/backend/mod.rs b/crates/burn-tensor/src/tensor/backend/mod.rs index 303a2b5269..64780085c1 100644 --- a/crates/burn-tensor/src/tensor/backend/mod.rs +++ b/crates/burn-tensor/src/tensor/backend/mod.rs @@ -1,5 +1,8 @@ mod base; +mod bridge; + pub use base::*; +pub use bridge::*; // Not needed for now, useful for different tensor memory layout // pub mod conversion; diff --git a/crates/burn-tensor/src/tensor/ops/activation.rs b/crates/burn-tensor/src/tensor/ops/activation.rs index 5c731bbac6..b4926385bb 100644 --- a/crates/burn-tensor/src/tensor/ops/activation.rs +++ b/crates/burn-tensor/src/tensor/ops/activation.rs @@ -2,7 +2,7 @@ use crate::tensor::ops::tensor::FloatTensorOps; use crate::{backend::Backend, ElementConversion}; use core::f64::consts::SQRT_2; -use super::FloatTensor; +use super::{FloatTensor, FullPrecisionBackend}; /// Activation function operations. /// @@ -126,13 +126,16 @@ pub trait ActivationOps { /// /// The output tensor. 
fn sigmoid(tensor: FloatTensor) -> FloatTensor { - let tensor_full = B::float_to_full_precision(&tensor); - let tensor_tmp = B::FullPrecisionBackend::float_exp(B::FullPrecisionBackend::float_neg( - B::FullPrecisionBackend::float_log(B::FullPrecisionBackend::float_add_scalar( - B::FullPrecisionBackend::float_exp(B::FullPrecisionBackend::float_neg(tensor_full)), - 1.0.elem(), - )), - )); + let tensor_full = B::float_into_full_precision(tensor); + let tensor_tmp = + FullPrecisionBackend::::float_exp(FullPrecisionBackend::::float_neg( + FullPrecisionBackend::::float_log(FullPrecisionBackend::::float_add_scalar( + FullPrecisionBackend::::float_exp(FullPrecisionBackend::::float_neg( + tensor_full, + )), + 1.0.elem(), + )), + )); B::float_from_full_precision(tensor_tmp) } diff --git a/crates/burn-tensor/src/tensor/ops/alias.rs b/crates/burn-tensor/src/tensor/ops/alias.rs index daa6567837..c3e74fbb57 100644 --- a/crates/burn-tensor/src/tensor/ops/alias.rs +++ b/crates/burn-tensor/src/tensor/ops/alias.rs @@ -1,4 +1,4 @@ -use crate::backend::Backend; +use crate::backend::{Backend, BackendBridge}; // We provide some type aliases to improve the readability of using associated types without // having to use the disambiguation syntax. @@ -11,7 +11,8 @@ pub type FloatElem = ::FloatElem; /// Integer element type used by backend. pub type IntElem = ::IntElem; /// Full precision float element type used by the backend. -pub type FullPrecisionBackend = ::FullPrecisionBackend; +pub type FullPrecisionBackend = + <::FullPrecisionBridge as BackendBridge>::Target; /// Float tensor primitive type used by the backend. pub type FloatTensor = ::FloatTensorPrimitive; diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index 22acb5dfd9..3e8a200ad9 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -1,5 +1,6 @@ use super::cat::cat_with_slice_assign; use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor}; +use crate::backend::BackendBridge; use crate::Tensor; use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Float}; use crate::{tensor::api::chunk, tensor::api::narrow}; @@ -909,9 +910,11 @@ pub trait FloatTensorOps { /// # Returns /// /// A tensor with the same values as `tensor` but with full precision. - fn float_to_full_precision( - tensor: &FloatTensor, - ) -> FloatTensor, D>; + fn float_into_full_precision( + tensor: FloatTensor, + ) -> FloatTensor, D> { + >::into_target(tensor, None) + } /// Converts a tensor from full precision. /// @@ -924,7 +927,9 @@ pub trait FloatTensorOps { /// A tensor with the same values as `tensor` but with the precision of the backend. fn float_from_full_precision( tensor: FloatTensor, D>, - ) -> FloatTensor; + ) -> FloatTensor { + >::from_target(tensor, None) + } /// Returns a new tensor with exponential values. /// diff --git a/crates/burn-wgpu/src/runtime.rs b/crates/burn-wgpu/src/runtime.rs index 2c5cd3dd5d..f1a7418534 100644 --- a/crates/burn-wgpu/src/runtime.rs +++ b/crates/burn-wgpu/src/runtime.rs @@ -20,6 +20,7 @@ use wgpu::{AdapterInfo, DeviceDescriptor}; /// /// The [graphics api](GraphicsApi), the [float element](FloatElement) and the /// [int element](IntElement) types are passed as generic. +#[derive(Debug)] pub struct WgpuRuntime { _g: PhantomData, _f: PhantomData,
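
Note for reviewers: at the tensor API level this patch replaces `to_full_precision(&self)` with a by-value `into_full_precision(self)`, and `from_full_precision` now takes a `Tensor<FullPrecisionBackend<B>, D>`. A minimal usage sketch of the new pair, assuming only what the patch itself adds; the helper name is illustrative and the alias is assumed to be importable as `burn_tensor::ops::FullPrecisionBackend`:

```rust
use burn_tensor::{backend::Backend, ops::FullPrecisionBackend, Tensor};

/// Run a matmul on the full-precision target backend, then cast the result
/// back to `B`'s float element type. `D` must be at least 2 for `matmul`.
fn matmul_full_precision<B: Backend, const D: usize>(
    lhs: Tensor<B, D>,
    rhs: Tensor<B, D>,
) -> Tensor<B, D> {
    // `into_full_precision` routes through `B::FullPrecisionBridge::into_target`.
    let lhs_full: Tensor<FullPrecisionBackend<B>, D> = lhs.into_full_precision();
    let rhs_full = rhs.into_full_precision();

    // `from_full_precision` routes back through `B::FullPrecisionBridge::from_target`.
    Tensor::from_full_precision(lhs_full.matmul(rhs_full))
}
```

This is the same pattern exercised by the `test_full_precision` test added in `burn-autodiff/src/tests/bridge.rs`.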
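For backend authors, the contract behind the new `FullPrecisionBridge` associated type is the `BackendBridge` trait introduced in `burn-tensor`. A minimal conforming implementation is an identity bridge whose target is the origin backend itself, the same pattern the fusion `PrecisionBridge` in this patch uses; `IdentityBridge` is a hypothetical name, and the sketch assumes `into_target`/`from_target` are generic over the tensor rank `D` as in the trait definition above:

```rust
use burn_tensor::{
    backend::{Backend, BackendBridge},
    ops::FloatTensor,
    Device,
};
use core::marker::PhantomData;

/// A bridge whose target is the origin backend itself, so both transfer
/// functions are identities (analogous to the fusion bridge in this patch).
#[derive(Debug)]
pub struct IdentityBridge<B> {
    _b: PhantomData<B>,
}

impl<B: Backend> BackendBridge<B> for IdentityBridge<B> {
    type Target = B;

    fn into_target<const D: usize>(
        tensor: FloatTensor<B, D>,
        _device: Option<Device<Self::Target>>,
    ) -> FloatTensor<Self::Target, D> {
        tensor
    }

    fn from_target<const D: usize>(
        tensor: FloatTensor<Self::Target, D>,
        _device: Option<Device<B>>,
    ) -> FloatTensor<B, D> {
        tensor
    }
}
```

Backends whose `FloatElem` can be half precision (candle, tch, jit, ndarray in this patch) instead implement an element-casting bridge, and `Autodiff` wraps whichever bridge its inner backend declares so the precision cast remains differentiable.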