-
Notifications
You must be signed in to change notification settings - Fork 478
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
28233d9
commit 40a26bd
Showing
36 changed files
with
483 additions
and
220 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
use std::marker::PhantomData; | ||
|
||
use burn_tensor::{ | ||
backend::{Backend, BackendBridge}, | ||
ops::FloatTensor, | ||
}; | ||
|
||
use crate::{ | ||
checkpoint::{ | ||
base::Checkpointer, retro_forward::RetroForward, state::BackwardStates, | ||
strategy::CheckpointStrategy, | ||
}, | ||
grads::Gradients, | ||
ops::{unary_different_backend, Backward, Ops}, | ||
Autodiff, NodeID, | ||
}; | ||
|
||
/// Enable autodiff on a [backend bridge](BackendBridge).
///
/// Wraps a concrete `Bridge: BackendBridge<B>` so that crossing the bridge on an
/// [`Autodiff`] backend also registers the crossing in the autodiff graph.
#[derive(Debug)]
pub struct AutodiffBridge<Bridge> {
    // Zero-sized marker: the struct carries no runtime state, only the `Bridge`
    // type parameter used by the `BackendBridge` impl below.
    _p: PhantomData<Bridge>,
}
|
||
impl<B, C, Bridge> BackendBridge<Autodiff<B, C>> for AutodiffBridge<Bridge>
where
    B: Backend,
    C: CheckpointStrategy,
    Bridge: BackendBridge<B> + 'static,
{
    // Bridging an autodiff backend yields the autodiff-wrapped target backend,
    // with the same checkpoint strategy `C`.
    type Target = Autodiff<Bridge::Target, C>;

    fn into_target<const D: usize>(
        tensor: burn_tensor::ops::FloatTensor<Autodiff<B, C>, D>,
        _device: Option<burn_tensor::Device<Self::Target>>,
    ) -> burn_tensor::ops::FloatTensor<Self::Target, D> {
        // Backward op for the crossing: zero-sized, all information is in the types.
        #[derive(Debug)]
        struct IntoTarget<B: Backend, Bridge: BackendBridge<B>> {
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        // Retro-forward op: re-executes the bridge crossing from a checkpointed
        // input when the memory-bound strategy recomputes instead of storing.
        #[derive(new, Debug)]
        struct RetroIntoTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
            tensor_id: NodeID,
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        impl<B, Bridge, const D: usize> RetroForward for RetroIntoTarget<B, Bridge, D>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
                // Fetch the checkpointed input, redo the crossing on the inner
                // (non-autodiff) backend, and save the result for `out_node`.
                let tensor: FloatTensor<B, D> = states.get_state(&self.tensor_id);
                let out = Bridge::into_target(tensor, Default::default());
                states.save(out_node, out)
            }
        }

        impl<B, Bridge, const D: usize> Backward<Bridge::Target, D, 1> for IntoTarget<B, Bridge>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            // The crossing is linear, so no forward state is needed for backward.
            type State = ();

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                _checkpointer: &mut Checkpointer,
            ) {
                // Gradient of a bridge crossing is the inverse crossing: bring
                // the incoming gradient back from the target backend to `B`.
                unary_different_backend::<B, Bridge::Target, D, D, _>(
                    ops.parents,
                    ops.node,
                    grads,
                    |grad| Bridge::from_target(grad, None),
                );
            }
        }

        // Register the op as memory-bound (recomputable via retro-forward) and
        // stateless, performing the actual crossing on the inner primitive.
        IntoTarget::<B, Bridge> {
            _backend: PhantomData,
            _bridge: PhantomData,
        }
        .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
        .memory_bound()
        .retro_forward(RetroIntoTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
        .parents([&tensor])
        .stateless(Bridge::into_target(tensor.primitive, None))
    }

    fn from_target<const D: usize>(
        tensor: burn_tensor::ops::FloatTensor<Self::Target, D>,
        _device: Option<burn_tensor::Device<Autodiff<B, C>>>,
    ) -> burn_tensor::ops::FloatTensor<Autodiff<B, C>, D> {
        // Backward op for the reverse crossing; mirrors `IntoTarget` above.
        #[derive(Debug)]
        struct FromTarget<B: Backend, Bridge: BackendBridge<B>> {
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        // Retro-forward op: re-executes the reverse crossing from a
        // checkpointed input during recomputation.
        #[derive(new, Debug)]
        struct RetroFromTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
            tensor_id: NodeID,
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        impl<B, Bridge, const D: usize> RetroForward for RetroFromTarget<B, Bridge, D>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
                // Fetch the checkpointed target-backend tensor, cross back to
                // `B`, and save the result for `out_node`.
                let tensor: FloatTensor<Bridge::Target, D> = states.get_state(&self.tensor_id);
                let out = Bridge::from_target(tensor, None);
                states.save(out_node, out)
            }
        }

        impl<B, Bridge, const D: usize> Backward<B, D, 1> for FromTarget<B, Bridge>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            // No forward state required: the crossing is linear.
            type State = ();

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                _checkpointer: &mut Checkpointer,
            ) {
                // Inverse of the reverse crossing: push the gradient from `B`
                // back into the target backend.
                unary_different_backend::<Bridge::Target, B, D, D, _>(
                    ops.parents,
                    ops.node,
                    grads,
                    |grad| Bridge::into_target(grad, None),
                );
            }
        }

        // Same registration pattern as `into_target`, with the direction reversed.
        FromTarget::<B, Bridge> {
            _backend: PhantomData,
            _bridge: PhantomData,
        }
        .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
        .memory_bound()
        .retro_forward(RetroFromTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
        .parents([&tensor])
        .stateless(Bridge::from_target(tensor.primitive, None))
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#[burn_tensor_testgen::testgen(bridge)] | ||
mod tests { | ||
use super::*; | ||
use burn_tensor::{backend::Backend, module::embedding, Data, Distribution, Int, Tensor}; | ||
|
||
#[test] | ||
fn test_full_precision() { | ||
let device = Default::default(); | ||
let x1 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device) | ||
.require_grad(); | ||
let x2 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device) | ||
.require_grad(); | ||
|
||
let x3 = x1.clone().into_full_precision(); | ||
let x4 = x2.clone().into_full_precision(); | ||
|
||
let x5 = x3.matmul(x4); | ||
let x6 = Tensor::<TestAutodiffBackend, 2>::from_full_precision(x5); | ||
let x7 = x6 * x1.clone() / x2.clone(); | ||
|
||
let mut grads = x7.backward(); | ||
|
||
let x1_grad = x1.grad(&mut grads); | ||
let x2_grad = x2.grad(&mut grads); | ||
|
||
assert!(x1_grad.is_some()); | ||
assert!(x2_grad.is_some()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.