From 9d9ea8b7013313ceb992d9eb4ef9d3e30c804851 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 16 Jan 2025 10:07:31 -0500 Subject: [PATCH] Add hardsigmoid formula and fix WGAN doc + default lr (#2706) --- crates/burn-tensor/src/tensor/activation/base.rs | 2 ++ examples/wgan/src/model.rs | 2 +- examples/wgan/src/training.rs | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/burn-tensor/src/tensor/activation/base.rs b/crates/burn-tensor/src/tensor/activation/base.rs index cc5990d375..15fcc7ab50 100644 --- a/crates/burn-tensor/src/tensor/activation/base.rs +++ b/crates/burn-tensor/src/tensor/activation/base.rs @@ -144,6 +144,8 @@ pub fn sigmoid(tensor: Tensor) -> Tensor } /// Applies the hard sigmoid function +/// +/// `hard_sigmoid(x) = max(0, min(1, alpha * x + beta))` pub fn hard_sigmoid( tensor: Tensor, alpha: f64, diff --git a/examples/wgan/src/model.rs b/examples/wgan/src/model.rs index ddb84ff6d3..b9615f5270 100644 --- a/examples/wgan/src/model.rs +++ b/examples/wgan/src/model.rs @@ -96,7 +96,7 @@ pub struct ModelConfig { } impl ModelConfig { - /// "init" is used to create other objects, while "new" is usally used to create itself. + /// Initialize the generator and discriminator models based on the config. pub fn init(&self, device: &B::Device) -> (Generator, Discriminator) { // Construct the initialized generator let layer1 = LayerBlock::new(self.latent_dim, 128, device); diff --git a/examples/wgan/src/training.rs b/examples/wgan/src/training.rs index db1f594b46..25fbef21c1 100644 --- a/examples/wgan/src/training.rs +++ b/examples/wgan/src/training.rs @@ -23,7 +23,7 @@ pub struct TrainingConfig { pub num_workers: usize, #[config(default = 5)] pub seed: u64, - #[config(default = 5e-5)] + #[config(default = 3e-4)] pub lr: f64, /// Number of training steps for discriminator before generator is trained per iteration