Skip to content

Commit

Permalink
Add hardsigmoid formula and fix WGAN doc + default lr (#2706)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Jan 16, 2025
1 parent 93f8bad commit 9d9ea8b
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 2 deletions.
2 changes: 2 additions & 0 deletions crates/burn-tensor/src/tensor/activation/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ pub fn sigmoid<const D: usize, B: Backend>(tensor: Tensor<B, D>) -> Tensor<B, D>
}

/// Applies the hard sigmoid function
///
/// `hard_sigmoid(x) = max(0, min(1, alpha * x + beta))`
pub fn hard_sigmoid<const D: usize, B: Backend>(
tensor: Tensor<B, D>,
alpha: f64,
Expand Down
2 changes: 1 addition & 1 deletion examples/wgan/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pub struct ModelConfig {
}

impl ModelConfig {
/// "init" is used to create other objects, while "new" is usally used to create itself.
/// Initialize the generator and discriminator models based on the config.
pub fn init<B: Backend>(&self, device: &B::Device) -> (Generator<B>, Discriminator<B>) {
// Construct the initialized generator
let layer1 = LayerBlock::new(self.latent_dim, 128, device);
Expand Down
2 changes: 1 addition & 1 deletion examples/wgan/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct TrainingConfig {
pub num_workers: usize,
#[config(default = 5)]
pub seed: u64,
#[config(default = 5e-5)]
#[config(default = 3e-4)]
pub lr: f64,

/// Number of training steps for discriminator before generator is trained per iteration
Expand Down

0 comments on commit 9d9ea8b

Please sign in to comment.