Feat/backend bridge (#1529)
nathanielsimard authored Mar 26, 2024
1 parent 28233d9 commit 40a26bd
Showing 36 changed files with 483 additions and 220 deletions.
4 changes: 2 additions & 2 deletions crates/burn-autodiff/src/backend.rs
@@ -3,6 +3,7 @@ use crate::{
     grads::Gradients,
     graph::backward::backward,
     tensor::AutodiffTensor,
+    AutodiffBridge,
 };
 use burn_tensor::backend::{AutodiffBackend, Backend};
 use core::marker::PhantomData;
@@ -20,8 +21,7 @@ pub struct Autodiff<B, C = NoCheckpointing> {
 impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
     type Device = B::Device;

-    type FullPrecisionElem = B::FullPrecisionElem;
-    type FullPrecisionBackend = Autodiff<B::FullPrecisionBackend>;
+    type FullPrecisionBridge = AutodiffBridge<B::FullPrecisionBridge>;

     type FloatTensorPrimitive<const D: usize> = AutodiffTensor<B, D>;
     type FloatElem = B::FloatElem;
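
Note: the FullPrecisionBridge associated type introduced above refers to a BackendBridge implementation. Judging from how the bridge is used in the new bridge.rs module below, the trait presumably looks roughly like the following sketch (inferred from the call sites in this diff, not copied from burn-tensor):

use burn_tensor::{backend::Backend, ops::FloatTensor, Device};

/// Sketch of the bridge contract as it appears from this diff: a bridge moves float
/// tensors from an origin backend to a target backend and back, optionally placing
/// them on a given device.
pub trait BackendBridge<Origin: Backend> {
    /// Backend the tensors are bridged into (e.g. a full-precision backend).
    type Target: Backend;

    /// Move a float tensor from the origin backend to the target backend.
    fn into_target<const D: usize>(
        tensor: FloatTensor<Origin, D>,
        device: Option<Device<Self::Target>>,
    ) -> FloatTensor<Self::Target, D>;

    /// Move a float tensor from the target backend back to the origin backend.
    fn from_target<const D: usize>(
        tensor: FloatTensor<Self::Target, D>,
        device: Option<Device<Origin>>,
    ) -> FloatTensor<Origin, D>;
}
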
155 changes: 155 additions & 0 deletions crates/burn-autodiff/src/bridge.rs
@@ -0,0 +1,155 @@
use std::marker::PhantomData;

use burn_tensor::{
    backend::{Backend, BackendBridge},
    ops::FloatTensor,
};

use crate::{
    checkpoint::{
        base::Checkpointer, retro_forward::RetroForward, state::BackwardStates,
        strategy::CheckpointStrategy,
    },
    grads::Gradients,
    ops::{unary_different_backend, Backward, Ops},
    Autodiff, NodeID,
};

/// Enable autodiff on a [backend bridge](BackendBridge).
#[derive(Debug)]
pub struct AutodiffBridge<Bridge> {
    _p: PhantomData<Bridge>,
}

impl<B, C, Bridge> BackendBridge<Autodiff<B, C>> for AutodiffBridge<Bridge>
where
    B: Backend,
    C: CheckpointStrategy,
    Bridge: BackendBridge<B> + 'static,
{
    type Target = Autodiff<Bridge::Target, C>;

    fn into_target<const D: usize>(
        tensor: burn_tensor::ops::FloatTensor<Autodiff<B, C>, D>,
        _device: Option<burn_tensor::Device<Self::Target>>,
    ) -> burn_tensor::ops::FloatTensor<Self::Target, D> {
        #[derive(Debug)]
        struct IntoTarget<B: Backend, Bridge: BackendBridge<B>> {
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        #[derive(new, Debug)]
        struct RetroIntoTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
            tensor_id: NodeID,
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        impl<B, Bridge, const D: usize> RetroForward for RetroIntoTarget<B, Bridge, D>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
                let tensor: FloatTensor<B, D> = states.get_state(&self.tensor_id);
                let out = Bridge::into_target(tensor, Default::default());
                states.save(out_node, out)
            }
        }

        impl<B, Bridge, const D: usize> Backward<Bridge::Target, D, 1> for IntoTarget<B, Bridge>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            type State = ();

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                _checkpointer: &mut Checkpointer,
            ) {
                unary_different_backend::<B, Bridge::Target, D, D, _>(
                    ops.parents,
                    ops.node,
                    grads,
                    |grad| Bridge::from_target(grad, None),
                );
            }
        }

        IntoTarget::<B, Bridge> {
            _backend: PhantomData,
            _bridge: PhantomData,
        }
        .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
        .memory_bound()
        .retro_forward(RetroIntoTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
        .parents([&tensor])
        .stateless(Bridge::into_target(tensor.primitive, None))
    }

    fn from_target<const D: usize>(
        tensor: burn_tensor::ops::FloatTensor<Self::Target, D>,
        _device: Option<burn_tensor::Device<Autodiff<B, C>>>,
    ) -> burn_tensor::ops::FloatTensor<Autodiff<B, C>, D> {
        #[derive(Debug)]
        struct FromTarget<B: Backend, Bridge: BackendBridge<B>> {
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        #[derive(new, Debug)]
        struct RetroFromTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
            tensor_id: NodeID,
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        impl<B, Bridge, const D: usize> RetroForward for RetroFromTarget<B, Bridge, D>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
                let tensor: FloatTensor<Bridge::Target, D> = states.get_state(&self.tensor_id);
                let out = Bridge::from_target(tensor, None);
                states.save(out_node, out)
            }
        }

        impl<B, Bridge, const D: usize> Backward<B, D, 1> for FromTarget<B, Bridge>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            type State = ();

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                _checkpointer: &mut Checkpointer,
            ) {
                unary_different_backend::<Bridge::Target, B, D, D, _>(
                    ops.parents,
                    ops.node,
                    grads,
                    |grad| Bridge::into_target(grad, None),
                );
            }
        }

        FromTarget::<B, Bridge> {
            _backend: PhantomData,
            _bridge: PhantomData,
        }
        .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
        .memory_bound()
        .retro_forward(RetroFromTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
        .parents([&tensor])
        .stateless(Bridge::from_target(tensor.primitive, None))
    }
}
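
To make the trait's shape concrete, here is a minimal, purely illustrative bridge whose target is the origin backend itself, so both directions reduce to an optional device move. This is a hypothetical sketch (the exact supertrait bounds on BackendBridge may differ), not code from this commit:

use std::marker::PhantomData;

use burn_tensor::{
    backend::{Backend, BackendBridge},
    ops::FloatTensor,
    Device,
};

/// Illustrative no-op bridge: origin and target are the same backend.
#[derive(Debug)]
pub struct IdentityBridge<B> {
    _backend: PhantomData<B>,
}

impl<B: Backend> BackendBridge<B> for IdentityBridge<B> {
    type Target = B;

    fn into_target<const D: usize>(
        tensor: FloatTensor<B, D>,
        device: Option<Device<B>>,
    ) -> FloatTensor<B, D> {
        // Nothing to convert; only honor an explicit device request.
        match device {
            Some(device) => B::float_to_device(tensor, &device),
            None => tensor,
        }
    }

    fn from_target<const D: usize>(
        tensor: FloatTensor<B, D>,
        device: Option<Device<B>>,
    ) -> FloatTensor<B, D> {
        match device {
            Some(device) => B::float_to_device(tensor, &device),
            None => tensor,
        }
    }
}
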
3 changes: 3 additions & 0 deletions crates/burn-autodiff/src/lib.rs
@@ -26,7 +26,10 @@ pub(crate) mod tensor;
pub(crate) mod utils;

mod backend;
mod bridge;

pub use backend::*;
pub use bridge::*;

#[cfg(feature = "export_tests")]
mod tests;
105 changes: 2 additions & 103 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -7,7 +7,7 @@ use crate::{
     },
     grads::Gradients,
     graph::{ComputingProperty, NodeID, NodeRef, Requirement, Step},
-    ops::{binary, broadcast_shape, unary, unary_different_backend, Backward, Ops, OpsKind},
+    ops::{binary, broadcast_shape, unary, Backward, Ops, OpsKind},
     retro_binary, retro_unary, retro_unary_scalar,
     tensor::AutodiffTensor,
     utils::duplicate,
@@ -16,7 +16,7 @@ use crate::{

 use burn_tensor::{
     backend::Backend,
-    ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, FullPrecisionBackend, IntTensor},
+    ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
     Data, Device, ElementConversion, Reader, Shape, Tensor,
 };

@@ -1621,107 +1621,6 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         }
     }

-    fn float_to_full_precision<const D: usize>(
-        tensor: &FloatTensor<Self, D>,
-    ) -> FloatTensor<FullPrecisionBackend<Self>, D> {
-        #[derive(Debug)]
-        struct ToFullPrecision<B: Backend> {
-            phantom: PhantomData<B>,
-        }
-
-        #[derive(new, Debug)]
-        struct RetroToFullPrecision<B: Backend, const D: usize> {
-            tensor_id: NodeID,
-            _backend: PhantomData<B>,
-        }
-
-        impl<B: Backend, const D: usize> RetroForward for RetroToFullPrecision<B, D> {
-            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
-                let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id);
-                let out = B::float_to_full_precision(&tensor);
-                states.save(out_node, out)
-            }
-        }
-
-        impl<B: Backend, const D: usize> Backward<B::FullPrecisionBackend, D, 1> for ToFullPrecision<B> {
-            type State = ();
-
-            fn backward(
-                self,
-                ops: Ops<Self::State, 1>,
-                grads: &mut Gradients,
-                _checkpointer: &mut Checkpointer,
-            ) {
-                unary_different_backend::<B, B::FullPrecisionBackend, D, D, _>(
-                    ops.parents,
-                    ops.node,
-                    grads,
-                    |grad| B::float_from_full_precision(grad),
-                );
-            }
-        }
-
-        let ops = ToFullPrecision::<B> {
-            phantom: PhantomData,
-        };
-        ops.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
-            .memory_bound()
-            .retro_forward(RetroToFullPrecision::<B, D>::new(tensor.node.id.clone()))
-            .parents([tensor])
-            .stateless(B::float_to_full_precision(&tensor.primitive))
-    }
-
-    fn float_from_full_precision<const D: usize>(
-        tensor: FloatTensor<FullPrecisionBackend<Self>, D>,
-    ) -> FloatTensor<Self, D> {
-        #[derive(Debug)]
-        struct FromFullPrecision<B: Backend> {
-            phantom: PhantomData<B>,
-        }
-
-        #[derive(new, Debug)]
-        struct RetroFromFullPrecision<B: Backend, const D: usize> {
-            tensor_id: NodeID,
-            _backend: PhantomData<B>,
-        }
-
-        impl<B: Backend, const D: usize> RetroForward for RetroFromFullPrecision<B, D> {
-            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
-                let tensor = states.get_state::<<<B as Backend>::FullPrecisionBackend as Backend>::FloatTensorPrimitive<D>>(&self.tensor_id);
-                let out = B::float_from_full_precision(tensor);
-                states.save(out_node, out)
-            }
-        }
-
-        impl<B: Backend, const D: usize> Backward<B, D, 1> for FromFullPrecision<B::FullPrecisionBackend> {
-            type State = ();
-
-            fn backward(
-                self,
-                ops: Ops<Self::State, 1>,
-                grads: &mut Gradients,
-                _checkpointer: &mut Checkpointer,
-            ) {
-                unary_different_backend::<B::FullPrecisionBackend, B, D, D, _>(
-                    ops.parents,
-                    ops.node,
-                    grads,
-                    |grad| B::float_to_full_precision(&grad),
-                );
-            }
-        }
-
-        let ops = FromFullPrecision::<B::FullPrecisionBackend> {
-            phantom: PhantomData,
-        };
-
-        ops.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
-            .memory_bound()
-            .retro_forward(RetroFromFullPrecision::<B, D>::new(tensor.node.id.clone()))
-            .parents([&tensor])
-            .stateless(B::float_from_full_precision(tensor.primitive))
-    }
-
     fn float_argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<B, D> {
         B::float_argmax(tensor.primitive, dim)
     }
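
The two ops removed above are now covered by the bridge: bridging into the target backend replaces float_to_full_precision, and bridging back replaces float_from_full_precision. A hedged sketch of how the same conversions can be expressed through the new associated type (assuming Backend::FullPrecisionBridge is bound by BackendBridge<Self>, as the impls in this commit suggest; the helper names below are mine):

use burn_tensor::backend::{Backend, BackendBridge};
use burn_tensor::ops::FloatTensor;

/// Sketch: what float_to_full_precision becomes when routed through the bridge.
fn to_full_precision<B: Backend, const D: usize>(
    tensor: FloatTensor<B, D>,
) -> FloatTensor<<B::FullPrecisionBridge as BackendBridge<B>>::Target, D> {
    // Passing None lets the bridge pick the device.
    <B::FullPrecisionBridge as BackendBridge<B>>::into_target(tensor, None)
}

/// Sketch: the inverse direction, replacing float_from_full_precision.
fn from_full_precision<B: Backend, const D: usize>(
    tensor: FloatTensor<<B::FullPrecisionBridge as BackendBridge<B>>::Target, D>,
) -> FloatTensor<B, D> {
    <B::FullPrecisionBridge as BackendBridge<B>>::from_target(tensor, None)
}
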
29 changes: 29 additions & 0 deletions crates/burn-autodiff/src/tests/bridge.rs
@@ -0,0 +1,29 @@
#[burn_tensor_testgen::testgen(bridge)]
mod tests {
    use super::*;
    use burn_tensor::{backend::Backend, module::embedding, Data, Distribution, Int, Tensor};

    #[test]
    fn test_full_precision() {
        let device = Default::default();
        let x1 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
            .require_grad();
        let x2 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
            .require_grad();

        let x3 = x1.clone().into_full_precision();
        let x4 = x2.clone().into_full_precision();

        let x5 = x3.matmul(x4);
        let x6 = Tensor::<TestAutodiffBackend, 2>::from_full_precision(x5);
        let x7 = x6 * x1.clone() / x2.clone();

        let mut grads = x7.backward();

        let x1_grad = x1.grad(&mut grads);
        let x2_grad = x2.grad(&mut grads);

        assert!(x1_grad.is_some());
        assert!(x2_grad.is_some());
    }
}
2 changes: 2 additions & 0 deletions crates/burn-autodiff/src/tests/mod.rs
@@ -8,6 +8,7 @@ mod aggregation;
mod avgpool1d;
mod avgpool2d;
mod backward;
mod bridge;
mod broadcast;
mod cat;
mod checkpoint;
@@ -64,6 +65,7 @@ macro_rules! testgen_all {
        // Behavior
        burn_autodiff::testgen_ad_broadcast!();
        burn_autodiff::testgen_gradients!();
        burn_autodiff::testgen_bridge!();
        burn_autodiff::testgen_checkpoint!();

        // Activation
5 changes: 2 additions & 3 deletions crates/burn-candle/src/backend.rs
@@ -5,7 +5,7 @@ use candle_core::DeviceLocation;

 use crate::{
     element::{CandleElement, FloatCandleElement, IntCandleElement},
-    CandleTensor,
+    CandleTensor, PrecisionBridge,
 };

 /// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations.
@@ -69,8 +69,7 @@ impl Default for CandleDevice {
 impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
     type Device = CandleDevice;

-    type FullPrecisionBackend = Candle<Self::FullPrecisionElem, Self::IntElem>;
-    type FullPrecisionElem = f32;
+    type FullPrecisionBridge = PrecisionBridge<f32>;

     type FloatTensorPrimitive<const D: usize> = CandleTensor<Self::FloatElem, D>;
     type FloatElem = F;
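
On the Candle side, PrecisionBridge<f32> presumably converts the underlying candle tensor to and from f32. A hypothetical sketch of the kind of conversion such a bridge would perform inside burn-candle (the field access, the CandleTensor::new constructor, and the generics are assumptions, not the actual implementation):

use candle_core::DType;

use crate::{element::CandleElement, CandleTensor};

/// Hypothetical helper: cast a candle-backed float primitive to f32, roughly the
/// work an into_target implementation of a precision bridge would do.
fn cast_to_f32<E: CandleElement, const D: usize>(tensor: CandleTensor<E, D>) -> CandleTensor<f32, D> {
    CandleTensor::new(
        tensor
            .tensor
            .to_dtype(DType::F32)
            .expect("candle dtype cast should succeed"),
    )
}
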