Skip to content

Commit

Permalink
Wasserstein Generative Adversarial Network (#2660)
Browse files Browse the repository at this point in the history
* Add files via upload

Wasserstein Generative Adversarial Network

* Delete examples/wgan/readme

* Create README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update cli.rs

* Update cli.rs

* Update model.rs

* Update training.rs

* Update main.rs

* Update model.rs

* Update training.rs

* Update training.rs

* Update main.rs

* Update training.rs

* Update model.rs

* Update training.rs

* Update cli.rs

* Update cli.rs

* Update generating.rs

* Update lib.rs

* Update model.rs

* Update training.rs

* Update main.rs

* Update generating.rs

* Update model.rs

* Update training.rs

* Update generating.rs

* Update model.rs

* Update training.rs

* Update training.rs

* Update dataset.rs

* Update generating.rs

* Update model.rs

* Update training.rs

* Update training.rs

* Update training.rs

* Restructure as workspace example

* Add support for single range slice (fixes clippy)

* Update example usage + list

---------

Co-authored-by: Guillaume Lagrange <[email protected]>
  • Loading branch information
wangjiawen2013 and laggui authored Jan 15, 2025
1 parent ad81344 commit f630b3b
Show file tree
Hide file tree
Showing 14 changed files with 752 additions and 0 deletions.
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,8 @@ Additional examples:
sample.
- [Text Generation](./examples/text-generation) : Trains a text generation transformer model on the
DbPedia dataset.
- [Wasserstein GAN MNIST](./examples/wgan) : Trains a WGAN model to generate new handwritten digits
based on MNIST.

For more practical insights, you can clone the repository and run any of them directly on your
computer!
Expand Down
1 change: 1 addition & 0 deletions burn-book/src/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ The following additional examples are currently available if you want to check t
| [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/pytorch-import) | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. |
| [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. |
| [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation) | Trains a text generation transformer model on the DbPedia dataset. |
| [Wasserstein GAN MNIST](https://github.com/tracel-ai/burn/tree/main/examples/wgan) | Trains a WGAN model to generate new handwritten digits based on MNIST. |

For more information on each example, see their respective `README.md` file. Be sure to check out
the [examples](https://github.com/tracel-ai/burn/tree/main/examples) directory for an up-to-date
Expand Down
8 changes: 8 additions & 0 deletions crates/burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,7 @@ where
/// # Arguments
///
/// * `ranges` - A type implementing the `RangesArg` trait, which can be:
/// - A single `core::ops::Range<usize>` (slice the first dimension)
/// - An array of `core::ops::Range<usize>`
/// - An array of `Option<(i64, i64)>`
/// - An array of `(i64, i64)` tuples
Expand Down Expand Up @@ -2988,6 +2989,13 @@ impl<const D2: usize> RangesArg<D2> for [(i64, i64); D2] {
}
}

impl RangesArg<1> for core::ops::Range<usize> {
fn into_ranges(self, shape: Shape) -> [core::ops::Range<usize>; 1] {
let (start, end) = Self::clamp_range(self.start, self.end, shape.dims[0]);
[(start..end)]
}
}

/// Trait used for reshape arguments.
pub trait ReshapeArgs<const D2: usize> {
/// Converts to a shape.
Expand Down
11 changes: 11 additions & 0 deletions crates/burn-tensor/src/tests/ops/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@ mod tests {
output.into_data().assert_eq(&expected, false);
}

#[test]
fn should_support_slice_range_first_dim() {
let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = TestTensor::<2>::from_data(data, &Default::default());

let output = tensor.slice(0..1);
let expected = TensorData::from([[0.0, 1.0, 2.0]]);

output.into_data().assert_eq(&expected, false);
}

#[test]
fn should_support_partial_sliceing_3d() {
let tensor = TestTensor::<3>::from_floats(
Expand Down
18 changes: 18 additions & 0 deletions examples/wgan/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[package]
name = "wgan"
version = "0.1.0"
edition = "2021"

[features]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
cuda-jit = ["burn/cuda-jit"]

[dependencies]
burn = { path = "../../crates/burn", features=["train", "vision"] }
image = { workspace = true }
40 changes: 40 additions & 0 deletions examples/wgan/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Wasserstein Generative Adversarial Network

A burn implementation of examplar WGAN model to generate MNIST digits inspired by
[the PyTorch implementation](https://bytepawn.com/training-a-pytorch-wasserstain-mnist-gan-on-google-colab.html).
Please note that better performance maybe gained by adopting a convolution layer in
[some other models](https://github.com/Lornatang/WassersteinGAN-PyTorch).

## Usage


## Training

```sh
# Cuda backend
cargo run --example wgan-mnist --release --features cuda-jit

# Wgpu backend
cargo run --example wgan-mnist --release --features wgpu

# Tch GPU backend
export TORCH_CUDA_VERSION=cu121 # Set the cuda version
cargo run --example wgan-mnist --release --features tch-gpu

# Tch CPU backend
cargo run --example wgan-mnist --release --features tch-cpu

# NdArray backend (CPU)
cargo run --example wgan-mnist --release --features ndarray # f32 - single thread
cargo run --example wgan-mnist --release --features ndarray-blas-openblas # f32 - blas with openblas
cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32 - blas with netlib
```


### Generating

To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend.

```sh
cargo run --example wgan-generate --release --features cuda-jit
```
95 changes: 95 additions & 0 deletions examples/wgan/examples/wgan-generate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
use burn::tensor::backend::Backend;

pub fn launch<B: Backend>(device: B::Device) {
wgan::infer::generate::<B>("/tmp/wgan-mnist", device);
}

#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::{
ndarray::{NdArray, NdArrayDevice},
Autodiff,
};

use crate::launch;

pub fn run() {
launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};

use crate::launch;

pub fn run() {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

launch::<Autodiff<LibTorch>>(device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};

use crate::launch;

pub fn run() {
launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use crate::launch;
use burn::backend::{wgpu::Wgpu, Autodiff};

pub fn run() {
launch::<Autodiff<Wgpu>>(Default::default());
}
}

#[cfg(feature = "cuda-jit")]
mod cuda_jit {
use crate::launch;
use burn::backend::{Autodiff, CudaJit};

pub fn run() {
launch::<Autodiff<CudaJit>>(Default::default());
}
}

fn main() {
#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
ndarray::run();
#[cfg(feature = "tch-gpu")]
tch_gpu::run();
#[cfg(feature = "tch-cpu")]
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
}
107 changes: 107 additions & 0 deletions examples/wgan/examples/wgan-mnist.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
use burn::{optim::RmsPropConfig, tensor::backend::AutodiffBackend};

use wgan::{model::ModelConfig, training::TrainingConfig};

pub fn launch<B: AutodiffBackend>(device: B::Device) {
let config = TrainingConfig::new(
ModelConfig::new(),
RmsPropConfig::new()
.with_alpha(0.99)
.with_momentum(0.0)
.with_epsilon(0.00000008)
.with_weight_decay(None)
.with_centered(false),
);

wgan::training::train::<B>("/tmp/wgan-mnist", config, device);
}

#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::{
ndarray::{NdArray, NdArrayDevice},
Autodiff,
};

use crate::launch;

pub fn run() {
launch::<Autodiff<NdArray>>(NdArrayDevice::Cpu);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};

use crate::launch;

pub fn run() {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

launch::<Autodiff<LibTorch>>(device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};

use crate::launch;

pub fn run() {
launch::<Autodiff<LibTorch>>(LibTorchDevice::Cpu);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use crate::launch;
use burn::backend::{wgpu::Wgpu, Autodiff};

pub fn run() {
launch::<Autodiff<Wgpu>>(Default::default());
}
}

#[cfg(feature = "cuda-jit")]
mod cuda_jit {
use crate::launch;
use burn::backend::{cuda_jit::CudaDevice, Autodiff, CudaJit};

pub fn run() {
launch::<Autodiff<CudaJit>>(CudaDevice::default());
}
}

fn main() {
#[cfg(any(
feature = "ndarray",
feature = "ndarray-blas-netlib",
feature = "ndarray-blas-openblas",
feature = "ndarray-blas-accelerate",
))]
ndarray::run();
#[cfg(feature = "tch-gpu")]
tch_gpu::run();
#[cfg(feature = "tch-cpu")]
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
}
49 changes: 49 additions & 0 deletions examples/wgan/src/dataset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MnistItem},
prelude::*,
};

#[derive(Clone, Debug)]
pub struct MnistBatcher<B: Backend> {
device: B::Device,
}

#[derive(Clone, Debug)]
pub struct MnistBatch<B: Backend> {
pub images: Tensor<B, 4>,
pub targets: Tensor<B, 1, Int>,
}

impl<B: Backend> MnistBatcher<B> {
pub fn new(device: B::Device) -> Self {
Self { device }
}
}

impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {
fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {
let images = items
.iter()
.map(|item| TensorData::from(item.image))
.map(|data| Tensor::<B, 2>::from_data(data.convert::<B::FloatElem>(), &self.device))
.map(|tensor| tensor.reshape([1, 28, 28]))
// Set std=0.5 and mean=0.5 to keep consistent with pytorch WGAN example
.map(|tensor| ((tensor / 255) - 0.5) / 0.5)
.collect();

let targets = items
.iter()
.map(|item| {
Tensor::<B, 1, Int>::from_data(
TensorData::from([(item.label as i64).elem::<B::IntElem>()]),
&self.device,
)
})
.collect();

let images = Tensor::stack(images, 0);
let targets = Tensor::cat(targets, 0);

MnistBatch { images, targets }
}
}
Loading

0 comments on commit f630b3b

Please sign in to comment.