diff --git a/crates/burn-core/src/nn/rnn/gru.rs b/crates/burn-core/src/nn/rnn/gru.rs
index c66ad631b6..e2f8b2425e 100644
--- a/crates/burn-core/src/nn/rnn/gru.rs
+++ b/crates/burn-core/src/nn/rnn/gru.rs
@@ -20,6 +20,21 @@ pub struct GruConfig {
     pub d_hidden: usize,
     /// If a bias should be applied during the Gru transformation.
    pub bias: bool,
+    /// If the reset gate should be applied after the weight multiplication.
+    ///
+    /// This configuration option controls how the reset gate is applied to the hidden state.
+    /// * `true` - (Default) Match the initial arXiv version of the paper [Learning Phrase Representations using RNN Encoder-Decoder for
+    ///   Statistical Machine Translation (v1)](https://arxiv.org/abs/1406.1078v1) and apply the reset gate after multiplication by
+    ///   the weights. This matches the behavior of [PyTorch GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU).
+    /// * `false` - Match the most recent revision of [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine
+    ///   Translation (v3)](https://arxiv.org/abs/1406.1078) and apply the reset gate before the weight multiplication.
+    ///
+    /// The differing implementations can give slightly different numerical results and have different efficiencies. For more detail
+    /// on why `true` can be more efficient, see [Optimizing RNNs with Differentiable Graphs](https://svail.github.io/diff_graphs).
+    ///
+    /// To set this field to `false`, use [`with_reset_after`](`GruConfig::with_reset_after`).
+    #[config(default = "true")]
+    pub reset_after: bool,
     /// Gru initializer
     #[config(default = "Initializer::XavierNormal{gain:1.0}")]
     pub initializer: Initializer,
@@ -41,6 +56,8 @@ pub struct Gru<B: Backend> {
     pub new_gate: GateController<B>,
     /// The size of the hidden state.
     pub d_hidden: usize,
+    /// If the reset gate should be applied after the weight multiplication.
+    pub reset_after: bool,
 }
 
 impl<B: Backend> ModuleDisplay for Gru<B> {
@@ -58,6 +75,7 @@ impl<B: Backend> ModuleDisplay for Gru<B> {
             .add("d_input", &d_input)
             .add("d_hidden", &self.d_hidden)
             .add("bias", &bias)
+            .add("reset_after", &self.reset_after)
             .optional()
     }
 }
@@ -94,86 +112,92 @@ impl GruConfig {
             reset_gate,
             new_gate,
             d_hidden: self.d_hidden,
+            reset_after: self.reset_after,
         }
     }
 }
 
 impl<B: Backend> Gru<B> {
     /// Applies the forward pass on the input tensor. This GRU implementation
-    /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size].
+    /// returns a state tensor with dimensions `[batch_size, sequence_length, hidden_size]`.
     ///
-    /// # Shapes
+    /// # Parameters
     /// - batched_input: `[batch_size, sequence_length, input_size]`.
-    /// - state: An optional tensor representing an initial cell state with the same dimensions
-    ///   as batched_input. If none is provided, one will be generated.
-    /// - output: `[batch_size, sequence_length, hidden_size]`.
+    /// - state: An optional tensor representing the initial hidden state with dimensions
+    ///   `[batch_size, hidden_size]`. If none is provided, a zero-initialized state is used.
+    ///
+    /// # Returns
+    /// - output: `[batch_size, sequence_length, hidden_size]`.
     pub fn forward(
         &self,
         batched_input: Tensor<B, 3>,
-        state: Option<Tensor<B, 3>>,
+        state: Option<Tensor<B, 2>>,
     ) -> Tensor<B, 3> {
+        let device = batched_input.device();
         let [batch_size, seq_length, _] = batched_input.shape().dims();
-        let mut hidden_state = match state {
+        let mut batched_hidden_state =
+            Tensor::empty([batch_size, seq_length, self.d_hidden], &device);
+
+        let mut hidden_t = match state {
             Some(state) => state,
-            None => Tensor::zeros(
-                [batch_size, seq_length, self.d_hidden],
-                &batched_input.device(),
-            ),
+            None => Tensor::zeros([batch_size, self.d_hidden], &device),
         };
 
-        for (t, (input_t, hidden_t)) in batched_input
-            .iter_dim(1)
-            .zip(hidden_state.clone().iter_dim(1))
-            .enumerate()
-        {
+        for (t, input_t) in batched_input.iter_dim(1).enumerate() {
             let input_t = input_t.squeeze(1);
-            let hidden_t = hidden_t.squeeze(1);
 
             // u(pdate)g(ate) tensors
-            let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate);
+            let biased_ug_input_sum =
+                self.gate_product(&input_t, &hidden_t, None, &self.update_gate);
             let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t)
 
             // r(eset)g(ate) tensors
-            let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate);
+            let biased_rg_input_sum =
+                self.gate_product(&input_t, &hidden_t, None, &self.reset_gate);
             let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t)
-            let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate
 
             // n(ew)g(ate) tensor
-            let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate);
+            let biased_ng_input_sum = if self.reset_after {
+                self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate)
+            } else {
+                let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate
+                self.gate_product(&input_t, &reset_t, None, &self.new_gate)
+            };
             let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t)
 
             // calculate linear interpolation between previous hidden state and candidate state:
             // g(t) * (1 - z(t)) + z(t) * hidden_t
-            let state_vector = candidate_state
+            hidden_t = candidate_state
                 .clone()
                 .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1)
                 + update_values.clone().mul(hidden_t);
 
-            let current_shape = state_vector.shape().dims;
-            let unsqueezed_shape = [current_shape[0], 1, current_shape[1]];
-            let reshaped_state_vector = state_vector.reshape(unsqueezed_shape);
-            hidden_state = hidden_state.slice_assign(
+            let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1);
+
+            batched_hidden_state = batched_hidden_state.slice_assign(
                 [0..batch_size, t..(t + 1), 0..self.d_hidden],
-                reshaped_state_vector,
+                unsqueezed_hidden_state,
             );
         }
 
-        hidden_state
+        batched_hidden_state
     }
 
     /// Helper function for performing weighted matrix product for a gate and adds
-    /// bias, if any.
+    /// bias, if any, and optionally applies the reset gate to the hidden state.
     ///
-    /// Mathematically, performs `Wx*X + Wh*H + b`, where:
+    /// Mathematically, performs `Wx*X + r .* (Wh*H + b)` (or `Wx*X + Wh*H + b` when no reset is given), where:
     ///     Wx = weight matrix for the connection to input vector X
     ///     Wh = weight matrix for the connection to hidden state H
     ///     X = input vector
     ///     H = hidden state
     ///     b = bias terms
+    ///     r = reset gate values
     fn gate_product(
         &self,
         input: &Tensor<B, 2>,
         hidden: &Tensor<B, 2>,
+        reset: Option<&Tensor<B, 2>>,
         gate: &GateController<B>,
     ) -> Tensor<B, 2> {
         let input_product = input.clone().matmul(gate.input_transform.weight.val());
@@ -190,13 +214,29 @@ impl<B: Backend> Gru<B> {
             .as_ref()
             .map(|bias_param| bias_param.val());
 
-        match (input_bias, hidden_bias) {
-            (Some(input_bias), Some(hidden_bias)) => {
+        match (input_bias, hidden_bias, reset) {
+            (Some(input_bias), Some(hidden_bias), Some(r)) => {
+                input_product
+                    + input_bias.unsqueeze()
+                    + r.clone().mul(hidden_product + hidden_bias.unsqueeze())
+            }
+            (Some(input_bias), Some(hidden_bias), None) => {
                 input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
             }
-            (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product,
-            (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(),
-            (None, None) => input_product + hidden_product,
+            (Some(input_bias), None, Some(r)) => {
+                input_product + input_bias.unsqueeze() + r.clone().mul(hidden_product)
+            }
+            (Some(input_bias), None, None) => {
+                input_product + input_bias.unsqueeze() + hidden_product
+            }
+            (None, Some(hidden_bias), Some(r)) => {
+                input_product + r.clone().mul(hidden_product + hidden_bias.unsqueeze())
+            }
+            (None, Some(hidden_bias), None) => {
+                input_product + hidden_product + hidden_bias.unsqueeze()
+            }
+            (None, None, Some(r)) => input_product + r.clone().mul(hidden_product),
+            (None, None, None) => input_product + hidden_product,
         }
     }
 }
@@ -207,29 +247,16 @@ mod tests {
     use crate::tensor::{Distribution, TensorData};
     use crate::{module::Param, nn::LinearRecord, TestBackend};
 
-    /// Test forward pass with simple input vector.
-    ///
-    /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125
-    /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150
-    /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699
-    ///
-    /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341
-    #[test]
-    fn tests_forward_single_input_single_feature() {
-        TestBackend::seed(0);
-        let config = GruConfig::new(1, 1, false);
-        let device = Default::default();
-        let mut gru = config.init::<TestBackend>(&device);
-
-        fn create_gate_controller(
+    fn init_gru<B: Backend>(reset_after: bool, device: &B::Device) -> Gru<B> {
+        fn create_gate_controller<B: Backend>(
             weights: f32,
             biases: f32,
             d_input: usize,
             d_output: usize,
             bias: bool,
             initializer: Initializer,
-            device: &<TestBackend as Backend>::Device,
-        ) -> GateController<TestBackend> {
+            device: &B::Device,
+        ) -> GateController<B> {
             let record_1 = LinearRecord {
                 weight: Param::from_data(TensorData::from([[weights]]), device),
                 bias: Some(Param::from_data(TensorData::from([biases]), device)),
@@ -248,6 +275,9 @@ mod tests {
             )
         }
 
+        let config = GruConfig::new(1, 1, false).with_reset_after(reset_after);
+        let mut gru = config.init::<B>(device);
+
         gru.update_gate = create_gate_controller(
             0.5,
             0.0,
@@ -255,7 +285,7 @@ mod tests {
             1,
             false,
             Initializer::XavierNormal { gain: 1.0 },
-            &device,
+            device,
         );
         gru.reset_gate = create_gate_controller(
             0.6,
@@ -264,7 +294,7 @@ mod tests {
             1,
             false,
             Initializer::XavierNormal { gain: 1.0 },
-            &device,
+            device,
         );
         gru.new_gate = create_gate_controller(
             0.7,
@@ -273,18 +303,72 @@ mod tests {
             1,
             false,
             Initializer::XavierNormal { gain: 1.0 },
-            &device,
+            device,
         );
+        gru
+    }
+
+    /// Test forward pass with simple input vector.
+    ///
+    /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125
+    /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150
+    /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699
+    ///
+    /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341
+    #[test]
+    fn tests_forward_single_input_single_feature() {
+        TestBackend::seed(0);
+        let device = Default::default();
+        let mut gru = init_gru::<TestBackend>(false, &device);
 
         let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
+        let expected = TensorData::from([[0.034]]);
+        // Reset gate applied to hidden state before the matrix multiplication
+        let state = gru.forward(input.clone(), None);
+
+        let output = state
+            .select(0, Tensor::arange(0..1, &device))
+            .squeeze::<2>(0);
+
+        output.to_data().assert_approx_eq(&expected, 3);
+
+        // Reset gate applied to hidden state after the matrix multiplication
+        gru.reset_after = true; // override forward behavior
+        let state = gru.forward(input, None);
+
+        let output = state
+            .select(0, Tensor::arange(0..1, &device))
+            .squeeze::<2>(0);
+
+        output.to_data().assert_approx_eq(&expected, 3);
+    }
+
+    #[test]
+    fn tests_forward_seq_len_3() {
+        TestBackend::seed(0);
+        let device = Default::default();
+        let mut gru = init_gru::<TestBackend>(true, &device);
+
+        let input =
+            Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);
+        let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]);
+
+        let result = gru.forward(input.clone(), None);
+        let output = result
+            .select(0, Tensor::arange(0..1, &device))
+            .squeeze::<2>(0);
+
+        output.to_data().assert_approx_eq(&expected, 3);
+
+        // Reset gate applied to hidden state before the matrix multiplication
+        gru.reset_after = false; // override forward behavior
 
         let state = gru.forward(input, None);
 
         let output = state
             .select(0, Tensor::arange(0..1, &device))
             .squeeze::<2>(0);
 
-        let expected = TensorData::from([[0.034]]);
         output.to_data().assert_approx_eq(&expected, 3);
     }
 
@@ -308,7 +392,7 @@ mod tests {
 
         assert_eq!(
             alloc::format!("{}", layer),
-            "Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}"
+            "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}"
        );
     }
 }
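
Usage sketch (not part of the patch): the snippet below shows how the new `reset_after` option would be driven from user code through `GruConfig`. The backend choice (`NdArray`) and the `burn::nn::GruConfig` / `burn::tensor::Tensor` import paths are assumptions for illustration; only the three-argument `GruConfig::new`, `with_reset_after`, and the 2D initial-state `forward` signature come from the patch itself.

```rust
// Illustrative only; backend and import paths are assumptions, not part of the patch.
use burn::backend::NdArray;
use burn::nn::GruConfig;
use burn::tensor::Tensor;

type B = NdArray;

fn main() {
    let device = Default::default();

    // Default (`reset_after = true`): reset gate applied after the hidden-state
    // weight multiplication, matching the PyTorch GRU formulation.
    let gru = GruConfig::new(16, 32, true).init::<B>(&device);

    // Opt out to match the v3 paper formulation: reset gate applied to the
    // hidden state before the weight multiplication.
    let gru_v3 = GruConfig::new(16, 32, true)
        .with_reset_after(false)
        .init::<B>(&device);

    // Input is `[batch_size, seq_length, input_size]`; the optional initial
    // state is now `[batch_size, hidden_size]` (2D, per this patch).
    let input = Tensor::<B, 3>::zeros([2, 5, 16], &device);
    let h0 = Tensor::<B, 2>::zeros([2, 32], &device);

    let output = gru.forward(input.clone(), Some(h0));
    let output_v3 = gru_v3.forward(input, None);

    assert_eq!(output.dims(), [2, 5, 32]);
    assert_eq!(output_v3.dims(), [2, 5, 32]);
}
```

Keeping `reset_after = true` as the default gives PyTorch-compatible behavior out of the box, while `with_reset_after(false)` reproduces the pre-patch behavior of applying the reset gate to the hidden state before the multiplication.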