From f630b3bc7d2d7fac0b972ea001c33daa7c32dd22 Mon Sep 17 00:00:00 2001
From: jiawen wang
Date: Thu, 16 Jan 2025 00:45:20 +0800
Subject: [PATCH] Wasserstein Generative Adversarial Network (#2660)

* Add files via upload

Wasserstein Generative Adversarial Network

* Delete examples/wgan/readme
* Create README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update cli.rs
* Update cli.rs
* Update model.rs
* Update training.rs
* Update main.rs
* Update model.rs
* Update training.rs
* Update training.rs
* Update main.rs
* Update training.rs
* Update model.rs
* Update training.rs
* Update cli.rs
* Update cli.rs
* Update generating.rs
* Update lib.rs
* Update model.rs
* Update training.rs
* Update main.rs
* Update generating.rs
* Update model.rs
* Update training.rs
* Update generating.rs
* Update model.rs
* Update training.rs
* Update training.rs
* Update dataset.rs
* Update generating.rs
* Update model.rs
* Update training.rs
* Update training.rs
* Update training.rs
* Restructure as workspace example
* Add support for single range slice (fixes clippy)
* Update example usage + list

---------

Co-authored-by: Guillaume Lagrange
---
 Cargo.lock                                |   8 +
 README.md                                 |   2 +
 burn-book/src/examples.md                 |   1 +
 crates/burn-tensor/src/tensor/api/base.rs |   8 +
 crates/burn-tensor/src/tests/ops/slice.rs |  11 ++
 examples/wgan/Cargo.toml                  |  18 ++
 examples/wgan/README.md                   |  40 ++++
 examples/wgan/examples/wgan-generate.rs   |  95 ++++++++++
 examples/wgan/examples/wgan-mnist.rs      | 107 +++++++++++
 examples/wgan/src/dataset.rs              |  49 +++++
 examples/wgan/src/infer.rs                |  41 +++++
 examples/wgan/src/lib.rs                  |   4 +
 examples/wgan/src/model.rs                | 157 ++++++++++++++++
 examples/wgan/src/training.rs             | 211 ++++++++++++++++++++++
 14 files changed, 752 insertions(+)
 create mode 100644 examples/wgan/Cargo.toml
 create mode 100644 examples/wgan/README.md
 create mode 100644 examples/wgan/examples/wgan-generate.rs
 create mode 100644 examples/wgan/examples/wgan-mnist.rs
 create mode 100644 examples/wgan/src/dataset.rs
 create mode 100644 examples/wgan/src/infer.rs
 create mode 100644 examples/wgan/src/lib.rs
 create mode 100644 examples/wgan/src/model.rs
 create mode 100644 examples/wgan/src/training.rs

diff --git a/Cargo.lock b/Cargo.lock
index c34fb9cd03..1af3585919 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8002,6 +8002,14 @@ version = "0.1.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082"
 
+[[package]]
+name = "wgan"
+version = "0.1.0"
+dependencies = [
+ "burn",
+ "image",
+]
+
 [[package]]
 name = "wgpu"
 version = "23.0.1"
diff --git a/README.md b/README.md
index a0780dcc16..d0ccbcf411 100644
--- a/README.md
+++ b/README.md
@@ -567,6 +567,8 @@ Additional examples:
   sample.
 - [Text Generation](./examples/text-generation) : Trains a text generation transformer model on the
   DbPedia dataset.
+- [Wasserstein GAN MNIST](./examples/wgan) : Trains a WGAN model to generate new handwritten digits
+  based on MNIST.
 
 For more practical insights, you can clone the repository and run any of them directly on your
 computer!
diff --git a/burn-book/src/examples.md b/burn-book/src/examples.md
index c9703a4389..2b083b6fbe 100644
--- a/burn-book/src/examples.md
+++ b/burn-book/src/examples.md
@@ -85,6 +85,7 @@ The following additional examples are currently available if you want to check t
 | [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/pytorch-import) | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. |
 | [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. |
 | [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation) | Trains a text generation transformer model on the DbPedia dataset. |
+| [Wasserstein GAN MNIST](https://github.com/tracel-ai/burn/tree/main/examples/wgan) | Trains a WGAN model to generate new handwritten digits based on MNIST. |
 
 For more information on each example, see their respective `README.md` file. Be sure to check out
 the [examples](https://github.com/tracel-ai/burn/tree/main/examples) directory for an up-to-date
diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs
index fabf321d96..4bbc522f49 100644
--- a/crates/burn-tensor/src/tensor/api/base.rs
+++ b/crates/burn-tensor/src/tensor/api/base.rs
@@ -805,6 +805,7 @@ where
     /// # Arguments
     ///
     /// * `ranges` - A type implementing the `RangesArg` trait, which can be:
+    ///   - A single `core::ops::Range<usize>` (slices the first dimension)
     ///   - An array of `core::ops::Range<usize>`
     ///   - An array of `Option<(i64, i64)>`
     ///   - An array of `(i64, i64)` tuples
@@ -2988,6 +2989,13 @@ impl<const D2: usize> RangesArg<D2> for [(i64, i64); D2] {
     }
 }
 
+impl RangesArg<1> for core::ops::Range<usize> {
+    fn into_ranges(self, shape: Shape) -> [core::ops::Range<usize>; 1] {
+        let (start, end) = Self::clamp_range(self.start, self.end, shape.dims[0]);
+        [(start..end)]
+    }
+}
+
 /// Trait used for reshape arguments.
 pub trait ReshapeArgs<const D2: usize> {
     /// Converts to a shape.
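For illustration, a minimal sketch of the new single-range `slice` overload in use. The `NdArray` backend and the sample values are assumptions for this sketch (any backend behaves the same); it mirrors the test added below:

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let tensor =
        Tensor::<NdArray, 2>::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device);

    // Previously, slicing required one range per dimension.
    let first_row = tensor.clone().slice([0..1, 0..3]);

    // With this patch, a bare range slices the first dimension only.
    let also_first_row = tensor.slice(0..1);

    assert_eq!(first_row.into_data(), also_first_row.into_data());
}
```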
diff --git a/crates/burn-tensor/src/tests/ops/slice.rs b/crates/burn-tensor/src/tests/ops/slice.rs
index 61725a506a..1be5b76315 100644
--- a/crates/burn-tensor/src/tests/ops/slice.rs
+++ b/crates/burn-tensor/src/tests/ops/slice.rs
@@ -47,6 +47,17 @@ mod tests {
         output.into_data().assert_eq(&expected, false);
     }
 
+    #[test]
+    fn should_support_slice_range_first_dim() {
+        let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+        let tensor = TestTensor::<2>::from_data(data, &Default::default());
+
+        let output = tensor.slice(0..1);
+        let expected = TensorData::from([[0.0, 1.0, 2.0]]);
+
+        output.into_data().assert_eq(&expected, false);
+    }
+
     #[test]
     fn should_support_partial_sliceing_3d() {
         let tensor = TestTensor::<3>::from_floats(
diff --git a/examples/wgan/Cargo.toml b/examples/wgan/Cargo.toml
new file mode 100644
index 0000000000..48d5680f51
--- /dev/null
+++ b/examples/wgan/Cargo.toml
@@ -0,0 +1,18 @@
+[package]
+name = "wgan"
+version = "0.1.0"
+edition = "2021"
+
+[features]
+ndarray = ["burn/ndarray"]
+ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
+ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
+ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
+tch-cpu = ["burn/tch"]
+tch-gpu = ["burn/tch"]
+wgpu = ["burn/wgpu"]
+cuda-jit = ["burn/cuda-jit"]
+
+[dependencies]
+burn = { path = "../../crates/burn", features = ["train", "vision"] }
+image = { workspace = true }
diff --git a/examples/wgan/README.md b/examples/wgan/README.md
new file mode 100644
index 0000000000..d7252ba520
--- /dev/null
+++ b/examples/wgan/README.md
@@ -0,0 +1,40 @@
+# Wasserstein Generative Adversarial Network
+
+A Burn implementation of an exemplar WGAN model that generates MNIST digits, inspired by
+[the PyTorch implementation](https://bytepawn.com/training-a-pytorch-wasserstain-mnist-gan-on-google-colab.html).
+Note that better performance may be gained by adopting convolutional layers, as in
+[some other models](https://github.com/Lornatang/WassersteinGAN-PyTorch).
+
+## Usage
+
+
+### Training
+
+```sh
+# Cuda backend
+cargo run --example wgan-mnist --release --features cuda-jit
+
+# Wgpu backend
+cargo run --example wgan-mnist --release --features wgpu
+
+# Tch GPU backend
+export TORCH_CUDA_VERSION=cu121 # Set the cuda version
+cargo run --example wgan-mnist --release --features tch-gpu
+
+# Tch CPU backend
+cargo run --example wgan-mnist --release --features tch-cpu
+
+# NdArray backend (CPU)
+cargo run --example wgan-mnist --release --features ndarray                 # f32 - single thread
+cargo run --example wgan-mnist --release --features ndarray-blas-openblas   # f32 - blas with openblas
+cargo run --example wgan-mnist --release --features ndarray-blas-netlib    # f32 - blas with netlib
+```
+
+
+### Generating
+
+To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend.
+
+```sh
+cargo run --example wgan-generate --release --features cuda-jit
+```
diff --git a/examples/wgan/examples/wgan-generate.rs b/examples/wgan/examples/wgan-generate.rs
new file mode 100644
index 0000000000..fa66623ca3
--- /dev/null
+++ b/examples/wgan/examples/wgan-generate.rs
@@ -0,0 +1,95 @@
+use burn::tensor::backend::Backend;
+
+pub fn launch<B: Backend>(device: B::Device) {
+    wgan::infer::generate::<B>("/tmp/wgan-mnist", device);
+}
+
+#[cfg(any(
+    feature = "ndarray",
+    feature = "ndarray-blas-netlib",
+    feature = "ndarray-blas-openblas",
+    feature = "ndarray-blas-accelerate",
+))]
+mod ndarray {
+    use burn::backend::{
+        ndarray::{NdArray, NdArrayDevice},
+        Autodiff,
+    };
+
+    use crate::launch;
+
+    pub fn run() {
+        launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
+    }
+}
+
+#[cfg(feature = "tch-gpu")]
+mod tch_gpu {
+    use burn::backend::{
+        libtorch::{LibTorch, LibTorchDevice},
+        Autodiff,
+    };
+
+    use crate::launch;
+
+    pub fn run() {
+        #[cfg(not(target_os = "macos"))]
+        let device = LibTorchDevice::Cuda(0);
+        #[cfg(target_os = "macos")]
+        let device = LibTorchDevice::Mps;
+
+        launch::<Autodiff<LibTorch>>(device);
+    }
+}
+
+#[cfg(feature = "tch-cpu")]
+mod tch_cpu {
+    use burn::backend::{
+        libtorch::{LibTorch, LibTorchDevice},
+        Autodiff,
+    };
+
+    use crate::launch;
+
+    pub fn run() {
+        launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);
+    }
+}
+
+#[cfg(feature = "wgpu")]
+mod wgpu {
+    use crate::launch;
+    use burn::backend::{wgpu::Wgpu, Autodiff};
+
+    pub fn run() {
+        launch::<Autodiff<Wgpu>>(Default::default());
+    }
+}
+
+#[cfg(feature = "cuda-jit")]
+mod cuda_jit {
+    use crate::launch;
+    use burn::backend::{Autodiff, CudaJit};
+
+    pub fn run() {
+        launch::<Autodiff<CudaJit>>(Default::default());
+    }
+}
+
+fn main() {
+    #[cfg(any(
+        feature = "ndarray",
+        feature = "ndarray-blas-netlib",
+        feature = "ndarray-blas-openblas",
+        feature = "ndarray-blas-accelerate",
+    ))]
+    ndarray::run();
+    #[cfg(feature = "tch-gpu")]
+    tch_gpu::run();
+    #[cfg(feature = "tch-cpu")]
+    tch_cpu::run();
+    #[cfg(feature = "wgpu")]
+    wgpu::run();
+    #[cfg(feature = "cuda-jit")]
+    cuda_jit::run();
+}
diff --git a/examples/wgan/examples/wgan-mnist.rs b/examples/wgan/examples/wgan-mnist.rs
new file mode 100644
index 0000000000..d964b07844
--- /dev/null
+++ b/examples/wgan/examples/wgan-mnist.rs
@@ -0,0 +1,107 @@
+use burn::{optim::RmsPropConfig, tensor::backend::AutodiffBackend};
+
+use wgan::{model::ModelConfig, training::TrainingConfig};
+
+pub fn launch<B: AutodiffBackend>(device: B::Device) {
+    let config = TrainingConfig::new(
+        ModelConfig::new(),
+        RmsPropConfig::new()
+            .with_alpha(0.99)
+            .with_momentum(0.0)
+            .with_epsilon(0.00000008)
+            .with_weight_decay(None)
+            .with_centered(false),
+    );
+
+    wgan::training::train::<B>("/tmp/wgan-mnist", config, device);
+}
+
+#[cfg(any(
+    feature = "ndarray",
+    feature = "ndarray-blas-netlib",
+    feature = "ndarray-blas-openblas",
+    feature = "ndarray-blas-accelerate",
+))]
+mod ndarray {
+    use burn::backend::{
+        ndarray::{NdArray, NdArrayDevice},
+        Autodiff,
+    };
+
+    use crate::launch;
+
+    pub fn run() {
+        launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
+    }
+}
+
+#[cfg(feature = "tch-gpu")]
+mod tch_gpu {
+    use burn::backend::{
+        libtorch::{LibTorch, LibTorchDevice},
+        Autodiff,
+    };
+
+    use crate::launch;
+
+    pub fn run() {
+        #[cfg(not(target_os = "macos"))]
+        let device = LibTorchDevice::Cuda(0);
+        #[cfg(target_os = "macos")]
+        let device = LibTorchDevice::Mps;
+
+        launch::<Autodiff<LibTorch>>(device);
+    }
+}
+
+#[cfg(feature = "tch-cpu")]
+mod tch_cpu {
+    use burn::backend::{
+        libtorch::{LibTorch, LibTorchDevice},
+        Autodiff,
+    };
+
+    use crate::launch;
+
+    pub fn run() {
+        launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);
+    }
+}
+
+#[cfg(feature = "wgpu")]
+mod wgpu {
+    use crate::launch;
+    use burn::backend::{wgpu::Wgpu, Autodiff};
+
+    pub fn run() {
+        launch::<Autodiff<Wgpu>>(Default::default());
+    }
+}
+
+#[cfg(feature = "cuda-jit")]
+mod cuda_jit {
+    use crate::launch;
+    use burn::backend::{cuda_jit::CudaDevice, Autodiff, CudaJit};
+
+    pub fn run() {
+        launch::<Autodiff<CudaJit>>(CudaDevice::default());
+    }
+}
+
+fn main() {
+    #[cfg(any(
+        feature = "ndarray",
+        feature = "ndarray-blas-netlib",
+        feature = "ndarray-blas-openblas",
+        feature = "ndarray-blas-accelerate",
+    ))]
+    ndarray::run();
+    #[cfg(feature = "tch-gpu")]
+    tch_gpu::run();
+    #[cfg(feature = "tch-cpu")]
+    tch_cpu::run();
+    #[cfg(feature = "wgpu")]
+    wgpu::run();
+    #[cfg(feature = "cuda-jit")]
+    cuda_jit::run();
+}
diff --git a/examples/wgan/src/dataset.rs b/examples/wgan/src/dataset.rs
new file mode 100644
index 0000000000..46848d4ffb
--- /dev/null
+++ b/examples/wgan/src/dataset.rs
@@ -0,0 +1,49 @@
+use burn::{
+    data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
+    prelude::*,
+};
+
+#[derive(Clone, Debug)]
+pub struct MnistBatcher<B: Backend> {
+    device: B::Device,
+}
+
+#[derive(Clone, Debug)]
+pub struct MnistBatch<B: Backend> {
+    pub images: Tensor<B, 4>,
+    pub targets: Tensor<B, 1, Int>,
+}
+
+impl<B: Backend> MnistBatcher<B> {
+    pub fn new(device: B::Device) -> Self {
+        Self { device }
+    }
+}
+
+impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
+    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
+        let images = items
+            .iter()
+            .map(|item| TensorData::from(item.image))
+            .map(|data| Tensor::<B, 2>::from_data(data.convert::<B::FloatElem>(), &self.device))
+            .map(|tensor| tensor.reshape([1, 28, 28]))
+            // Set std=0.5 and mean=0.5 to keep consistent with the PyTorch WGAN example
+            .map(|tensor| ((tensor / 255) - 0.5) / 0.5)
+            .collect();
+
+        let targets = items
+            .iter()
+            .map(|item| {
+                Tensor::<B, 1, Int>::from_data(
+                    TensorData::from([(item.label as i64).elem::<B::IntElem>()]),
+                    &self.device,
+                )
+            })
+            .collect();
+
+        let images = Tensor::stack(images, 0);
+        let targets = Tensor::cat(targets, 0);
+
+        MnistBatch { images, targets }
+    }
+}
diff --git a/examples/wgan/src/infer.rs b/examples/wgan/src/infer.rs
new file mode 100644
index 0000000000..25ca984feb
--- /dev/null
+++ b/examples/wgan/src/infer.rs
@@ -0,0 +1,41 @@
+use crate::training::{save_image, TrainingConfig};
+use burn::{
+    prelude::*,
+    record::{CompactRecorder, Recorder},
+    tensor::Distribution,
+};
+
+pub fn generate<B: Backend>(artifact_dir: &str, device: B::Device) {
+    // Load the model
+    let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
+        .expect("Config should exist for the model; run train first");
+    let record = CompactRecorder::new()
+        .load(format!("{artifact_dir}/generator").into(), &device)
+        .expect("Trained model should exist; run train first");
+    let (mut generator, _) = config.model.init::<B>(&device);
+    generator = generator.load_record(record);
+
+    // Get a batch of noise
+    let noise = Tensor::<B, 2>::random(
+        [config.batch_size, config.model.latent_dim],
+        Distribution::Normal(0.0, 1.0),
+        &device,
+    );
+    let fake_images = generator.forward(noise); // [batch_size, channels*height*width]
+    let fake_images = fake_images.reshape([
+        config.batch_size,
+        config.model.channels,
+        config.model.image_size,
+        config.model.image_size,
+    ]);
+    // [B, C, H, W] to [B, H, C, W] to [B, H, W, C]
+    let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25);
+    // Normalize the images. The Rgb32 values should be in the range 0.0-1.0
+    let fake_images = (fake_images.clone() - fake_images.clone().min().reshape([1, 1, 1, 1]))
+        / (fake_images.clone().max().reshape([1, 1, 1, 1])
+            - fake_images.clone().min().reshape([1, 1, 1, 1]));
+    // Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer; refer to the PyTorch save_image source
+    let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0);
+    // Save images in artifact directory
+    save_image::<B, _>(fake_images, 5, format!("{artifact_dir}/fake_image.png")).unwrap();
+}
diff --git a/examples/wgan/src/lib.rs b/examples/wgan/src/lib.rs
new file mode 100644
index 0000000000..021f62278a
--- /dev/null
+++ b/examples/wgan/src/lib.rs
@@ -0,0 +1,4 @@
+pub mod dataset;
+pub mod infer;
+pub mod model;
+pub mod training;
diff --git a/examples/wgan/src/model.rs b/examples/wgan/src/model.rs
new file mode 100644
index 0000000000..ddb84ff6d3
--- /dev/null
+++ b/examples/wgan/src/model.rs
@@ -0,0 +1,157 @@
+use burn::{
+    module::{Module, ModuleMapper, ParamId},
+    nn::BatchNorm,
+    prelude::*,
+    tensor::backend::AutodiffBackend,
+};
+
+/// Layer block of the generator model
+#[derive(Module, Debug)]
+pub struct LayerBlock<B: Backend> {
+    fc: nn::Linear<B>,
+    bn: nn::BatchNorm<B, 0>,
+    leakyrelu: nn::LeakyRelu,
+}
+
+impl<B: Backend> LayerBlock<B> {
+    pub fn new(input: usize, output: usize, device: &B::Device) -> Self {
+        let fc = nn::LinearConfig::new(input, output)
+            .with_bias(true)
+            .init(device);
+        let bn: BatchNorm<B, 0> = nn::BatchNormConfig::new(output)
+            .with_epsilon(0.8)
+            .init(device);
+        let leakyrelu = nn::LeakyReluConfig::new().with_negative_slope(0.2).init();
+
+        Self { fc, bn, leakyrelu }
+    }
+
+    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
+        let output = self.fc.forward(input); // output: [Batch, x]
+        let output = self.bn.forward(output); // output: [Batch, x]
+
+        self.leakyrelu.forward(output) // output: [Batch, x]
+    }
+}
+
+/// Generator model
+#[derive(Module, Debug)]
+pub struct Generator<B: Backend> {
+    layer1: LayerBlock<B>,
+    layer2: LayerBlock<B>,
+    layer3: LayerBlock<B>,
+    layer4: LayerBlock<B>,
+    fc: nn::Linear<B>,
+    tanh: nn::Tanh,
+}
+
+impl<B: Backend> Generator<B> {
+    /// Applies the forward pass on the input tensor in the specified order
+    pub fn forward(&self, noise: Tensor<B, 2>) -> Tensor<B, 2> {
+        let output = self.layer1.forward(noise);
+        let output = self.layer2.forward(output);
+        let output = self.layer3.forward(output);
+        let output = self.layer4.forward(output);
+        let output = self.fc.forward(output);
+
+        self.tanh.forward(output) // [batch_size, channels*height*width]
+    }
+}
+
+/// Discriminator model
+#[derive(Module, Debug)]
+pub struct Discriminator<B: Backend> {
+    fc1: nn::Linear<B>,
+    leakyrelu1: nn::LeakyRelu,
+    fc2: nn::Linear<B>,
+    leakyrelu2: nn::LeakyRelu,
+    fc3: nn::Linear<B>,
+}
+
+impl<B: Backend> Discriminator<B> {
+    /// Applies the forward pass on the input tensor in the specified order.
+    /// The input image shape is [batch, channels, height, width]
+    pub fn forward(&self, images: Tensor<B, 4>) -> Tensor<B, 2> {
+        // Flatten each image for the fully-connected layers
+        let output = images.flatten(1, 3); // output: [batch, channels*height*width]
+        let output = self.fc1.forward(output); // output: [batch, 512]
+        let output = self.leakyrelu1.forward(output); // output: [batch, 512]
+        let output = self.fc2.forward(output); // output: [batch, 256]
+        let output = self.leakyrelu2.forward(output); // output: [batch, 256]
+
+        self.fc3.forward(output) // output: [batch, 1]
+    }
+}
+
+// Use the model config to construct the generator and the discriminator
+#[derive(Config, Debug)]
+pub struct ModelConfig {
+    /// Dimensionality of the latent space
+    #[config(default = 100)]
+    pub latent_dim: usize,
+    #[config(default = 28)]
+    pub image_size: usize,
+    #[config(default = 1)]
+    pub channels: usize,
+}
+
+impl ModelConfig {
+    /// `init` is used to create other objects, while `new` is usually used to create itself.
+    pub fn init<B: Backend>(&self, device: &B::Device) -> (Generator<B>, Discriminator<B>) {
+        // Construct the initialized generator
+        let layer1 = LayerBlock::new(self.latent_dim, 128, device);
+        let layer2 = LayerBlock::new(128, 256, device);
+        let layer3 = LayerBlock::new(256, 512, device);
+        let layer4 = LayerBlock::new(512, 1024, device);
+        let fc = nn::LinearConfig::new(1024, self.channels * self.image_size * self.image_size)
+            .with_bias(true)
+            .init(device);
+
+        let generator = Generator {
+            layer1,
+            layer2,
+            layer3,
+            layer4,
+            fc,
+            tanh: nn::Tanh::new(),
+        };
+
+        // Construct the initialized discriminator
+        let fc1 = nn::LinearConfig::new(self.channels * self.image_size * self.image_size, 512)
+            .init(device);
+        let leakyrelu1 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init();
+        let fc2 = nn::LinearConfig::new(512, 256).init(device);
+        let leakyrelu2 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init();
+        let fc3 = nn::LinearConfig::new(256, 1).init(device);
+
+        let discriminator = Discriminator {
+            fc1,
+            leakyrelu1,
+            fc2,
+            leakyrelu2,
+            fc3,
+        };
+
+        (generator, discriminator)
+    }
+}
+
+/// Clip module mapper to clip all module parameters between a range of values
+#[derive(Module, Clone, Debug)]
+pub struct Clip {
+    pub min: f32,
+    pub max: f32,
+}
+
+impl<B: AutodiffBackend> ModuleMapper<B> for Clip {
+    fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
+        let is_require_grad = tensor.is_require_grad();
+
+        let mut tensor = Tensor::from_inner(tensor.inner().clamp(self.min, self.max));
+
+        if is_require_grad {
+            tensor = tensor.require_grad();
+        }
+
+        tensor
+    }
+}
diff --git a/examples/wgan/src/training.rs b/examples/wgan/src/training.rs
new file mode 100644
index 0000000000..db1f594b46
--- /dev/null
+++ b/examples/wgan/src/training.rs
@@ -0,0 +1,211 @@
+use crate::dataset::MnistBatcher;
+use crate::model::{Clip, ModelConfig};
+use burn::optim::{GradientsParams, Optimizer, RmsPropConfig};
+use burn::{
+    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
+    prelude::*,
+    record::CompactRecorder,
+    tensor::{backend::AutodiffBackend, Distribution},
+};
+use image::{buffer::ConvertBuffer, error::ImageResult, Rgb32FImage, RgbImage};
+use std::path::Path;
+
+#[derive(Config)]
+pub struct TrainingConfig {
+    pub model: ModelConfig,
+    pub optimizer: RmsPropConfig,
+
+    #[config(default = 200)]
+    pub num_epochs: usize,
+    #[config(default = 512)]
+    pub batch_size: usize,
+    #[config(default = 8)]
+    pub num_workers: usize,
+    #[config(default = 5)]
+    pub seed: u64,
+    #[config(default = 5e-5)]
+    pub lr: f64,
+
+    /// Number of discriminator training steps per generator training step
+    #[config(default = 5)]
+    pub num_critic: usize,
+    /// Lower and upper clip value for discriminator weights
+    #[config(default = 0.01)]
+    pub clip_value: f32,
+    /// Save a sample of images every `sample_interval` epochs
+    #[config(default = 10)]
+    pub sample_interval: usize,
+}
+
+// Create the directory to save the model and model config
+fn create_artifact_dir(artifact_dir: &str) {
+    // Remove existing artifacts
+    std::fs::remove_dir_all(artifact_dir).ok();
+    std::fs::create_dir_all(artifact_dir).ok();
+}
+
+/// Save the generated images
+// The image format is [B, H, W, C]
+pub fn save_image<B: Backend, Q: AsRef<Path>>(
+    images: Tensor<B, 4>,
+    nrow: u32,
+    path: Q,
+) -> ImageResult<()> {
+    let ncol = (images.dims()[0] as f32 / nrow as f32).ceil() as u32;
+
+    let width = images.dims()[2] as u32;
+    let height = images.dims()[1] as u32;
+
+    // Supports both 1- and 3-channel images: the repeat factor expands
+    // grayscale values to RGB
+    let channels = match images.dims()[3] {
+        1 => 3,
+        3 => 1,
+        _ => panic!("Wrong number of channels"),
+    };
+
+    let mut imgbuf = RgbImage::new(nrow * width, ncol * height);
+    // Write images into an nrow*ncol grid layout
+    for row in 0..nrow {
+        for col in 0..ncol {
+            let image: Tensor<B, 3> = images
+                .clone()
+                .slice((row * nrow + col) as usize..(row * nrow + col + 1) as usize)
+                .squeeze(0);
+            // The Rgb32 values should be in the range 0.0-1.0
+            let image = image.into_data().iter::<f32>().collect::<Vec<f32>>();
+            // Supports both 1- and 3-channel images
+            let image = image
+                .into_iter()
+                .flat_map(|n| std::iter::repeat(n).take(channels))
+                .collect();
+
+            let image = Rgb32FImage::from_vec(width, height, image).unwrap();
+            let image: RgbImage = image.convert();
+            for (x, y, pixel) in image.enumerate_pixels() {
+                imgbuf.put_pixel(row * width + x, col * height + y, *pixel);
+            }
+        }
+    }
+    imgbuf.save(path)
+}
+
+pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
+    create_artifact_dir(artifact_dir);
+
+    // Create the Clip module mapper
+    let mut clip = Clip {
+        min: -config.clip_value,
+        max: config.clip_value,
+    };
+
+    // Save training config
+    config
+        .save(format!("{artifact_dir}/config.json"))
+        .expect("Config should be saved successfully");
+    B::seed(config.seed);
+
+    // Create the models and optimizers
+    let (mut generator, mut discriminator) = config.model.init::<B>(&device);
+    let mut optimizer_g = config.optimizer.init();
+    let mut optimizer_d = config.optimizer.init();
+
+    // Create the dataset batcher
+    let batcher_train = MnistBatcher::<B>::new(device.clone());
+
+    // Create the dataloaders
+    let dataloader_train = DataLoaderBuilder::new(batcher_train)
+        .batch_size(config.batch_size)
+        .shuffle(config.seed)
+        .num_workers(config.num_workers)
+        .build(MnistDataset::train());
+
+    // Iterate over our training for X epochs
+    for epoch in 0..config.num_epochs {
+        // Implement our training loop
+        for (iteration, batch) in dataloader_train.iter().enumerate() {
+            // Generate a batch of fake images from noise (standard normal distribution)
+            let noise = Tensor::<B, 2>::random(
+                [config.batch_size, config.model.latent_dim],
+                Distribution::Normal(0.0, 1.0),
+                &device,
+            );
+            // Detach: do not update the generator here; only the discriminator is updated
+            let fake_images = generator.forward(noise.clone()).detach(); // [batch_size, channels*height*width]
+            let fake_images = fake_images.reshape([
+                config.batch_size,
+                config.model.channels,
+                config.model.image_size,
+                config.model.image_size,
+            ]);
+            // Adversarial loss
+            let loss_d = -discriminator.forward(batch.images).mean()
+                + discriminator.forward(fake_images.clone()).mean();
+
+            // Gradients for the current backward pass
+            let grads = loss_d.backward();
+            // Gradients linked to each parameter of the discriminator
+            let grads = GradientsParams::from_grads(grads, &discriminator);
+            // Update the discriminator using the optimizer
+            discriminator = optimizer_d.step(config.lr, discriminator, grads);
+            // Clip parameters (weights) of the discriminator
+            discriminator = discriminator.map(&mut clip);
+
+            // Train the generator every num_critic iterations
+            if iteration % config.num_critic == 0 {
+                // Generate a batch of images again, this time without detaching
+                let critic_fake_images = generator.forward(noise.clone());
+                let critic_fake_images = critic_fake_images.reshape([
+                    config.batch_size,
+                    config.model.channels,
+                    config.model.image_size,
+                    config.model.image_size,
+                ]);
+                // Adversarial loss. Minimize it so the fake images pass as real
+                let loss_g = -discriminator.forward(critic_fake_images).mean();
+
+                let grads = loss_g.backward();
+                let grads = GradientsParams::from_grads(grads, &generator);
+                generator = optimizer_g.step(config.lr, generator, grads);
+
+                // Print the progress
+                let batch_num = (dataloader_train.num_items() as f32 / config.batch_size as f32)
+                    .ceil() as usize;
+                println!(
+                    "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}]",
+                    epoch + 1,
+                    config.num_epochs,
+                    iteration,
+                    batch_num,
+                    loss_d.into_scalar(),
+                    loss_g.into_scalar()
+                );
+            }
+            // If at the save interval, save the first 25 generated images
+            if epoch % config.sample_interval == 0 && iteration == 0 {
+                // [B, C, H, W] to [B, H, C, W] to [B, H, W, C]
+                let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25);
+                // Normalize the images. The Rgb32 values should be in the range 0.0-1.0
+                let fake_images = (fake_images.clone()
+                    - fake_images.clone().min().reshape([1, 1, 1, 1]))
+                    / (fake_images.clone().max().reshape([1, 1, 1, 1])
+                        - fake_images.clone().min().reshape([1, 1, 1, 1]));
+                // Add 0.5/255.0 to the images; refer to the PyTorch save_image source
+                let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0);
+                // Save images in artifact directory
+                let path = format!("{artifact_dir}/image-{}.png", epoch);
+                save_image::<B, _>(fake_images, 5, path).unwrap();
+            }
+        }
+    }
+
+    // Save the trained models
+    generator
+        .save_file(format!("{artifact_dir}/generator"), &CompactRecorder::new())
+        .expect("Generator should be saved successfully");
+    discriminator
+        .save_file(
+            format!("{artifact_dir}/discriminator"),
+            &CompactRecorder::new(),
+        )
+        .expect("Discriminator should be saved successfully");
+}
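The two entry points above can also be driven from a small program rather than the packaged examples. A minimal end-to-end sketch, assuming the `ndarray` feature is enabled and reusing the `/tmp/wgan-mnist` artifact directory and RMSProp settings from `wgan-mnist.rs`:

```rust
use burn::backend::{
    ndarray::{NdArray, NdArrayDevice},
    Autodiff,
};
use burn::optim::RmsPropConfig;
use wgan::{model::ModelConfig, training::TrainingConfig};

fn main() {
    // Training requires gradients, so the base backend is wrapped in Autodiff.
    type B = Autodiff<NdArray>;
    let device = NdArrayDevice::Cpu;

    // Same hyperparameters as the packaged `wgan-mnist` example.
    let config = TrainingConfig::new(
        ModelConfig::new(),
        RmsPropConfig::new()
            .with_alpha(0.99)
            .with_momentum(0.0)
            .with_epsilon(0.00000008)
            .with_weight_decay(None)
            .with_centered(false),
    );

    // Train, then sample a 5x5 grid of digits from the saved generator.
    wgan::training::train::<B>("/tmp/wgan-mnist", config, device.clone());
    wgan::infer::generate::<B>("/tmp/wgan-mnist", device);
}
```

Training alternates `num_critic` critic (discriminator) updates per generator update and clips the critic's weights to `[-clip_value, clip_value]`, the weight-clipping scheme the original WGAN paper uses to keep the critic approximately Lipschitz, which is also why RMSProp is used rather than a momentum-based optimizer.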