From 9228d9156de05dc32c08fdf7a6fa689c1357647d Mon Sep 17 00:00:00 2001 From: Maxime Tremblay Date: Mon, 13 Jan 2025 17:04:53 -0500 Subject: [PATCH 01/17] Merge reduce (#2673) --- Cargo.lock | 480 +++++++++--------- Cargo.toml | 4 +- crates/burn-jit/Cargo.toml | 2 +- crates/burn-jit/src/kernel/reduce/base.rs | 142 +++--- crates/burn-jit/src/kernel/reduce/mod.rs | 7 - .../src/kernel/reduce/naive/argmax.rs | 36 -- .../src/kernel/reduce/naive/argmin.rs | 36 -- .../burn-jit/src/kernel/reduce/naive/base.rs | 25 - .../src/kernel/reduce/naive/kernel.rs | 71 --- .../src/kernel/reduce/naive/mean_dim.rs | 27 - .../burn-jit/src/kernel/reduce/naive/mod.rs | 7 - .../src/kernel/reduce/naive/prod_dim.rs | 26 - .../src/kernel/reduce/naive/sum_dim.rs | 26 - crates/burn-jit/src/kernel/reduce/prod.rs | 15 - .../src/kernel/reduce/shared/argmax.rs | 63 --- .../src/kernel/reduce/shared/argmin.rs | 64 --- .../burn-jit/src/kernel/reduce/shared/base.rs | 33 -- .../src/kernel/reduce/shared/kernel.rs | 117 ----- .../src/kernel/reduce/shared/mean_dim.rs | 44 -- .../burn-jit/src/kernel/reduce/shared/mod.rs | 7 - .../src/kernel/reduce/shared/prod_dim.rs | 43 -- .../src/kernel/reduce/shared/sum_dim.rs | 43 -- .../src/kernel/reduce/subcube/argmax.rs | 54 -- .../src/kernel/reduce/subcube/argmin.rs | 54 -- .../src/kernel/reduce/subcube/base.rs | 15 - .../src/kernel/reduce/subcube/kernel.rs | 134 ----- .../src/kernel/reduce/subcube/mean_dim.rs | 45 -- .../burn-jit/src/kernel/reduce/subcube/mod.rs | 7 - .../src/kernel/reduce/subcube/prod_dim.rs | 44 -- .../src/kernel/reduce/subcube/sum_dim.rs | 44 -- crates/burn-jit/src/kernel/reduce/sum.rs | 15 - crates/burn-jit/src/kernel/reduce/tune.rs | 222 ++++++++ .../burn-jit/src/kernel/reduce/tune/base.rs | 94 ---- crates/burn-jit/src/kernel/reduce/tune/key.rs | 39 -- crates/burn-jit/src/kernel/reduce/tune/mod.rs | 7 - crates/burn-jit/src/ops/float_ops.rs | 14 +- crates/burn-jit/src/ops/int_ops.rs | 16 +- crates/burn-jit/src/tests/mod.rs | 3 + crates/burn-jit/src/tests/reduce.rs | 128 +++++ crates/burn-jit/src/tune_key.rs | 4 +- crates/burn-tensor/src/tensor/shape.rs | 7 + 41 files changed, 708 insertions(+), 1556 deletions(-) delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/argmax.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/argmin.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/base.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/kernel.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/mod.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/prod.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/argmax.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/argmin.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/base.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/kernel.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/mod.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/argmax.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/argmin.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/base.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/kernel.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/mod.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/sum.rs create mode 100644 crates/burn-jit/src/kernel/reduce/tune.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/tune/base.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/tune/key.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/tune/mod.rs create mode 100644 crates/burn-jit/src/tests/reduce.rs diff --git a/Cargo.lock b/Cargo.lock index b3c03aa56d..feff4ed96a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -139,19 +139,20 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.6" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", + "once_cell", "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" [[package]] name = "arbitrary" @@ -188,7 +189,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -269,18 +270,18 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] name = "async-trait" -version = "0.1.83" +version = "0.1.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -510,9 +511,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be" dependencies = [ "serde", ] @@ -594,9 +595,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.11.1" +version = "1.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "786a307d683a5bf92e6fd5fd69a7eb613751668d1d8d67d802846dfe367c62c8" +checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" dependencies = [ "memchr", "serde", @@ -753,7 +754,7 @@ dependencies = [ "derive-new 0.7.0", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -804,7 +805,7 @@ dependencies = [ "rust-format", "serde", "serde_json", - "syn 2.0.95", + "syn 2.0.96", "thiserror 2.0.11", "tracing-core", "tracing-subscriber", @@ -987,13 +988,13 @@ dependencies = [ [[package]] name = "bytemuck_derive" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" +checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1043,9 +1044,9 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1e306c8a4276ba57ce9fac76d823cc8c8a7fca14bf222ac20ad8b12c4273152" +checksum = "855dfedff437d2681d68e1f34ae559d88b0dd84aa5a6b63f2c8e75ebdd875bbf" dependencies = [ "accelerate-src", "byteorder", @@ -1073,18 +1074,18 @@ dependencies = [ [[package]] name = "candle-kernels" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbd8ea6588f3c6286ea89a52dad3365f0536fd0b71e729fa998cc2347f1df3b6" +checksum = "53343628fa470b7075c28c589b98735b4220b464e37ddbb8e117040e199f4787" dependencies = [ "bindgen_cuda", ] [[package]] name = "candle-metal-kernels" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbc6621c7e2202f4f129bcc3185c2c6d4fa2fc6b8f3f2b07eaf7c06042910c83" +checksum = "50fa64274a009a5d95c542b10bf3a4ea809bd394654c6ae99233bcc35b3a33ef" dependencies = [ "metal 0.27.0", "once_cell", @@ -1118,9 +1119,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.4" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" +checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" dependencies = [ "jobserver", "libc", @@ -1260,7 +1261,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1333,9 +1334,9 @@ dependencies = [ [[package]] name = "compact_str" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6050c3a16ddab2e412160b31f2c871015704239bca62f72f6e5f0be631d3f644" +checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32" dependencies = [ "castaway", "cfg-if", @@ -1522,7 +1523,7 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "crossterm_winapi", "mio", "parking_lot 0.12.3", @@ -1581,12 +1582,13 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "cubecl-core", "cubecl-cuda", "cubecl-hip", "cubecl-linalg", + "cubecl-reduce", "cubecl-runtime", "cubecl-wgpu", "half", @@ -1595,7 +1597,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1612,7 +1614,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "bytemuck", "cubecl-common", @@ -1631,7 +1633,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "bytemuck", "cubecl-common", @@ -1645,7 +1647,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "bytemuck", "cubecl-common", @@ -1661,7 +1663,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "bytemuck", "cubecl-common", @@ -1677,9 +1679,9 @@ dependencies = [ [[package]] name = "cubecl-hip-sys" -version = "6.3.0" +version = "6.3.1000" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9974218b3ff1f1e7b2f11ce254fd90b3ebcc2af6b4d084f7f6a0c351fb16112c" +checksum = "d4d987c1720eab39c72c515377a8001f683a4c4d99232a29fc0de389d9a8ce4f" dependencies = [ "libc", ] @@ -1687,7 +1689,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "bytemuck", "cubecl-core", @@ -1699,7 +1701,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "cubecl-common", "darling", @@ -1708,13 +1710,13 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "cubecl-common", "cubecl-core", @@ -1727,10 +1729,20 @@ dependencies = [ "type-map", ] +[[package]] +name = "cubecl-reduce" +version = "0.4.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +dependencies = [ + "cubecl-core", + "cubecl-runtime", + "num-traits", +] + [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "async-channel", "async-lock", @@ -1751,7 +1763,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "cubecl-common", "cubecl-core", @@ -1765,7 +1777,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "ash", "async-channel", @@ -1882,7 +1894,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1893,7 +1905,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1939,7 +1951,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1950,7 +1962,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1961,7 +1973,7 @@ checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1982,7 +1994,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1992,7 +2004,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2003,7 +2015,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2023,7 +2035,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", "unicode-xid", ] @@ -2079,7 +2091,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2162,7 +2174,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2174,14 +2186,14 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] name = "env_filter" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" dependencies = [ "log", "regex", @@ -2189,9 +2201,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" dependencies = [ "anstream", "anstyle", @@ -2236,9 +2248,9 @@ checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" [[package]] name = "event-listener" -version = "5.3.1" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", @@ -2272,9 +2284,9 @@ dependencies = [ [[package]] name = "fake" -version = "3.0.1" +version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "661cb0601b5f4050d1e65452c5b0ea555c0b3e88fb5ed7855906adc6c42523ef" +checksum = "aef603df4ba9adbca6a332db7da6f614f21eafefbaf8e087844e452fdec152d0" dependencies = [ "deunicode", "rand", @@ -2373,9 +2385,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" [[package]] name = "foreign-types" @@ -2404,7 +2416,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2488,9 +2500,9 @@ checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cef40d21ae2c515b51041df9ed313ed21e572df340ea58a922a0aefe7e8891a1" +checksum = "f5edaec856126859abb19ed65f39e90fea3a9574b9707f13539acf4abf7eb532" dependencies = [ "fastrand", "futures-core", @@ -2507,7 +2519,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2727,9 +2739,9 @@ dependencies = [ [[package]] name = "gix-fs" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34740384d8d763975858fa2c176b68652a6fcc09f616e24e3ce967b0d370e4d8" +checksum = "3b3d4fac505a621f97e5ce2c69fdc425742af00c0920363ca4074f0eb48b1db9" dependencies = [ "fastrand", "gix-features", @@ -2791,9 +2803,9 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "globset" @@ -2814,7 +2826,7 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf760ebf69878d9fd8f110c89703d90ce35095324d1f1edcb595c63945ee757" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "ignore", "walkdir", ] @@ -2833,9 +2845,9 @@ dependencies = [ [[package]] name = "glutin_wgl_sys" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e1951bbd9434a81aa496fe59ccc2235af3820d27b85f9314e279609211e2c" +checksum = "2c4ee00b289aba7a9e5306d57c2d05499b2e5dc427f84ac708bd2c090212cf3e" dependencies = [ "gl_generator", ] @@ -2846,7 +2858,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbcd2dba93594b227a1f57ee09b8b9da8892c34d55aa332e034a228d0fe6a171" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "gpu-alloc-types", ] @@ -2856,7 +2868,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98ff03b468aa837d70984d55f5d3f846f6ec31fe34bbb97c4f85219caeee1ca4" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", ] [[package]] @@ -2873,13 +2885,13 @@ dependencies = [ [[package]] name = "gpu-descriptor" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c08c1f623a8d0b722b8b99f821eb0ba672a1618f0d3b16ddbee1cedd2dd8557" +checksum = "dcf29e94d6d243368b7a56caa16bc213e4f9f8ed38c4d9557069527b5d5281ca" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "gpu-descriptor-types", - "hashbrown 0.14.5", + "hashbrown 0.15.2", ] [[package]] @@ -2888,7 +2900,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdf242682df893b86f33a73828fb09ca4b2d3bb6cc95249707fc684d27484b91" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", ] [[package]] @@ -3133,9 +3145,9 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.3" +version = "0.27.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", "http", @@ -3322,7 +3334,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -3412,9 +3424,9 @@ dependencies = [ [[package]] name = "image-webp" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e031e8e3d94711a9ccb5d6ea357439ef3dcbed361798bd4071dc4d9793fbe22f" +checksum = "b77d01e822461baa8409e156015a1d91735549f0f2c17691bd2d996bef238f7f" dependencies = [ "byteorder-lite", "quick-error", @@ -3467,16 +3479,15 @@ dependencies = [ [[package]] name = "instability" -version = "0.3.3" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b829f37dead9dc39df40c2d3376c179fdfd2ac771f53f55d3c30dc096a3c0c6e" +checksum = "0bf9fed6d91cfb734e7476a06bde8300a1b94e217e1b523b6f0cd1a01998c71d" dependencies = [ "darling", "indoc", - "pretty_assertions", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -3496,7 +3507,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -3573,9 +3584,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" [[package]] name = "js-sys" -version = "0.3.76" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ "once_cell", "wasm-bindgen", @@ -3648,7 +3659,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "libc", "redox_syscall 0.5.8", ] @@ -3666,9 +3677,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "litemap" @@ -3724,9 +3735,9 @@ dependencies = [ [[package]] name = "lz4" -version = "1.28.0" +version = "1.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d1febb2b4a79ddd1980eede06a8f7902197960aa0383ffcfdd62fe723036725" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" dependencies = [ "lz4-sys", ] @@ -3870,7 +3881,7 @@ version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block", "core-graphics-types", "foreign-types 0.5.0", @@ -3885,7 +3896,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block", "core-graphics-types", "foreign-types 0.5.0", @@ -3975,7 +3986,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -4008,7 +4019,7 @@ checksum = "364f94bc34f61332abebe8cad6f6cd82a5b65cff22c828d05d0968911462ca4f" dependencies = [ "arrayvec", "bit-set", - "bitflags 2.6.0", + "bitflags 2.7.0", "cfg_aliases 0.1.1", "codespan-reporting", "hexf-parse", @@ -4199,7 +4210,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -4271,7 +4282,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -4295,7 +4306,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c9bff0aa1d48904a1385ea2a8b97576fbdcbc9a3cfccd0d31fe978e1c4038c5" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "libloading", "nvml-wrapper-sys", "static_assertions", @@ -4344,7 +4355,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block2", "libc", "objc2", @@ -4360,7 +4371,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block2", "objc2", "objc2-foundation", @@ -4390,7 +4401,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block2", "libc", "objc2", @@ -4402,7 +4413,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block2", "objc2", "objc2-foundation", @@ -4414,7 +4425,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block2", "objc2", "objc2-foundation", @@ -4432,9 +4443,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.5" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] @@ -4536,9 +4547,9 @@ dependencies = [ [[package]] name = "openblas-build" -version = "0.10.10" +version = "0.10.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ca8f8c64eb5b43f5538059ccbc71391420bba14d987d7e8ab99ed62ed33e26b" +checksum = "b8140c0c1afaf88d2d30c48abad86b3bdd2334d691e08f7325a960d784240647" dependencies = [ "anyhow", "cc", @@ -4567,7 +4578,7 @@ version = "0.10.68" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "cfg-if", "foreign-types 0.3.2", "libc", @@ -4584,7 +4595,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -4613,9 +4624,9 @@ checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" [[package]] name = "os_info" -version = "3.9.0" +version = "3.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ca711d8b83edbb00b44d504503cd247c9c0bd8b0fa2694f2a1a3d8165379ce" +checksum = "6e6520c8cc998c5741ee68ec1dc369fc47e5f0ea5320018ecf2a1ccd6328f48b" dependencies = [ "log", "serde", @@ -4748,18 +4759,18 @@ dependencies = [ [[package]] name = "phf" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ "phf_shared", ] [[package]] name = "phf_codegen" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ "phf_generator", "phf_shared", @@ -4767,9 +4778,9 @@ dependencies = [ [[package]] name = "phf_generator" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", "rand", @@ -4777,18 +4788,18 @@ dependencies = [ [[package]] name = "phf_shared" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ "siphasher", ] [[package]] name = "pin-project-lite" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" @@ -4813,9 +4824,9 @@ dependencies = [ [[package]] name = "png" -version = "0.17.15" +version = "0.17.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67582bd5b65bdff614270e2ea89a1cf15bef71245cc1e5f7ea126977144211d" +checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526" dependencies = [ "bitflags 1.3.2", "crc32fast", @@ -4915,7 +4926,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd5df9b55e614088a3270b06f8649dce76537c268d6b1ca4d9c37008b2be5949" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.7.0", "bytemuck", "chrono", "chrono-tz", @@ -4964,7 +4975,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea1b431ed816cba1120cff200f06b962748001bbb2e615ce53cfbbdf701cc136" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.7.0", "hashbrown 0.15.2", "num-traits", "once_cell", @@ -5056,7 +5067,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a8ca74f42e7b47cad241b36b98d991cc7fbb51b8d0695a055eb937588d1f310" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.7.0", "memchr", "once_cell", "polars-arrow", @@ -5201,7 +5212,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23de436f33f4d1134c58f24e7059a221b957ec20730807e0ef0c80c8e4b3d06a" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.7.0", "bytemuck", "bytes", "chrono", @@ -5403,12 +5414,12 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.25" +version = "0.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" dependencies = [ "proc-macro2", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -5445,7 +5456,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30" dependencies = [ "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -5568,7 +5579,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -5581,7 +5592,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -5670,9 +5681,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.8" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52cd4b1eff68bf27940dd39811292c49e007f4d0b4c357358dc9b0197be6b527" +checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904" dependencies = [ "cfg_aliases 0.2.1", "libc", @@ -5765,7 +5776,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabd94c2f37801c20583fc49dd5cd6b0ba68c716787c2dd6ed18571e1e63117b" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "cassowary", "compact_str", "crossterm", @@ -5846,7 +5857,7 @@ version = "11.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", ] [[package]] @@ -5915,7 +5926,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -5933,7 +5944,7 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", ] [[package]] @@ -5964,7 +5975,7 @@ checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -6157,7 +6168,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.95", + "syn 2.0.96", "unicode-ident", ] @@ -6167,7 +6178,7 @@ version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "fallible-iterator", "fallible-streaming-iterator", "hashlink", @@ -6214,11 +6225,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.42" +version = "0.38.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" +checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "errno", "libc", "linux-raw-sys", @@ -6227,9 +6238,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.20" +version = "0.23.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" +checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" dependencies = [ "log", "once_cell", @@ -6262,7 +6273,7 @@ dependencies = [ "openssl-probe", "rustls-pki-types", "schannel", - "security-framework 3.1.0", + "security-framework 3.2.0", ] [[package]] @@ -6296,9 +6307,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" @@ -6356,9 +6367,9 @@ dependencies = [ [[package]] name = "scc" -version = "2.2.6" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b13f8ea6177672c49d12ed964cca44836f59621981b04a3e26b87e675181de" +checksum = "28e1c91382686d21b5ac7959341fcb9780fa7c03773646995a87c950fa7be640" dependencies = [ "sdd", ] @@ -6399,7 +6410,7 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "core-foundation 0.9.4", "core-foundation-sys", "libc", @@ -6408,11 +6419,11 @@ dependencies = [ [[package]] name = "security-framework" -version = "3.1.0" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81d3f8c9bfcc3cbb6b0179eb57042d75b1582bdc65c3cb95f3fa999509c03cbc" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "core-foundation 0.10.0", "core-foundation-sys", "libc", @@ -6421,9 +6432,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.13.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1863fd3768cd83c56a7f60faa4dc0d403f1b6df0a38c3c25f44b7894e45370d5" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" dependencies = [ "core-foundation-sys", "libc", @@ -6478,7 +6489,7 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -6556,7 +6567,7 @@ checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -6685,9 +6696,9 @@ dependencies = [ [[package]] name = "siphasher" -version = "0.3.11" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" [[package]] name = "slab" @@ -6767,7 +6778,7 @@ version = "0.3.0+sdk-1.3.268.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", ] [[package]] @@ -6871,7 +6882,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -6893,9 +6904,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.95" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f71c0377baf4ef1cc3e3402ded576dccc315800fbc62dfc7fe04b009773b4a" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -6919,7 +6930,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -6928,7 +6939,7 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "byteorder", "enum-as-inner", "libc", @@ -6970,7 +6981,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "core-foundation 0.9.4", "system-configuration-sys", ] @@ -7054,12 +7065,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.14.0" +version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" dependencies = [ "cfg-if", "fastrand", + "getrandom", "once_cell", "rustix", "windows-sys 0.59.0", @@ -7142,7 +7154,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -7153,7 +7165,7 @@ checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -7231,9 +7243,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" dependencies = [ "tinyvec_macros", ] @@ -7278,9 +7290,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.42.0" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" dependencies = [ "backtrace", "bytes", @@ -7294,13 +7306,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -7452,7 +7464,7 @@ checksum = "5a3a646485f7cd8f580749ab94718ad3d344bcc0cc5b0fefe43c15fdd898bb96" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -7487,7 +7499,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -7753,9 +7765,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +checksum = "b913a3b5fe84142e269d63cc62b64319ccaf89b748fc31fe025177f767a756c4" dependencies = [ "getrandom", "rand", @@ -7835,34 +7847,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.49" +version = "0.4.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2" +checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" dependencies = [ "cfg-if", "js-sys", @@ -7873,9 +7886,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -7883,22 +7896,25 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "wasm-logger" @@ -7941,9 +7957,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.76" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" dependencies = [ "js-sys", "wasm-bindgen", @@ -8007,7 +8023,7 @@ checksum = "d63c3c478de8e7e01786479919c8769f62a22eec16788d8c2ac77ce2c132778a" dependencies = [ "arrayvec", "bit-vec", - "bitflags 2.6.0", + "bitflags 2.7.0", "cfg_aliases 0.1.1", "document-features", "indexmap", @@ -8034,7 +8050,7 @@ dependencies = [ "arrayvec", "ash", "bit-set", - "bitflags 2.6.0", + "bitflags 2.7.0", "block", "bytemuck", "cfg_aliases 0.1.1", @@ -8075,7 +8091,7 @@ version = "23.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "610f6ff27778148c31093f3b03abc4840f9636d58d597ca2f5977433acfe0068" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "js-sys", "web-sys", ] @@ -8185,7 +8201,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8196,7 +8212,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8207,7 +8223,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8218,7 +8234,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8410,9 +8426,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.20" +version = "0.6.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" dependencies = [ "memchr", ] @@ -8426,7 +8442,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8466,9 +8482,9 @@ checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d" [[package]] name = "xattr" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" +checksum = "e105d177a3871454f754b33bb0ee637ecaaac997446375fd3e5d43a2ed00c909" dependencies = [ "libc", "linux-raw-sys", @@ -8477,9 +8493,9 @@ dependencies = [ [[package]] name = "xml-rs" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea8b391c9a790b496184c29f7f93b9ed5b16abb306c05415b68bcc16e4d06432" +checksum = "c5b940ebc25896e71dd073bad2dbaa2abfe97b0a391415e22ad1326d9c54e3c4" [[package]] name = "xtask" @@ -8493,9 +8509,9 @@ dependencies = [ [[package]] name = "xxhash-rust" -version = "0.8.12" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a5cbf750400958819fb6178eaa83bee5cd9c29a26a40cc241df8c70fdd46984" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" [[package]] name = "yansi" @@ -8523,7 +8539,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", "synstructure", ] @@ -8545,7 +8561,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8565,7 +8581,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", "synstructure", ] @@ -8586,7 +8602,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8608,7 +8624,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a741216cfe..5dfebaf2b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "707093234f11b78fb6630b98fea5d13870f94282" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "707093234f11b78fb6630b98fea5d13870f94282" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index b6836a1b82..2bd0ba6f7e 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -37,7 +37,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = "cubecl", "repr", ] } -cubecl = { workspace = true, features = ["linalg"] } +cubecl = { workspace = true, features = ["linalg", "reduce"] } bytemuck = { workspace = true } derive-new = { workspace = true } diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index 57cdf13b1e..9ab1f5d2b6 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -1,83 +1,101 @@ -use cubecl::prelude::Numeric; - -#[cfg(feature = "autotune")] -use crate::kernel::reduce::reduce_dim_autotune; use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; -use super::{ - naive::{base::ReduceDimNaiveFamily, kernel::reduce_dim_naive}, - shared::{base::ReduceDimShared, kernel::reduce_dim_shared}, - subcube::{base::ReduceDimSubcube, kernel::reduce_dim_subcube}, -}; +use super::autotune_reduce; + +pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum}; -#[allow(dead_code)] -pub(crate) trait ReduceDimAlgorithm: - core::fmt::Debug + ReduceDimNaiveFamily + ReduceDimShared + ReduceDimSubcube -{ +/// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). +/// +/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. +/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid. +/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`. +/// +/// If there is no error, the output is a tensor with decreasing strides +/// where the shape of reduced dim is set to 1 but all shape are similar to the input. +pub fn reduce( + mut input: JitTensor, + strategy: ReduceStrategy, +) -> Result, cubecl::reduce::ReduceError> { + input.shape = input.shape.flatten(); + input.strides = vec![1]; + reduce_dim::(input, 0, strategy) } -/// Creates an empty output tensor with reduce output shape -pub fn init_reduce_output( - input: &JitTensor, - reduce_dim: usize, -) -> JitTensor { - let mut shape_out = input.shape.clone(); - shape_out.dims[reduce_dim] = 1; +/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). +/// +/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. +/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid. +/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`. +/// +/// If there is no error, the output is a tensor with decreasing strides +/// where the shape of reduced dim is set to 1 but all shape are similar to the input. +pub fn reduce_dim( + input: JitTensor, + dim: usize, + strategy: ReduceStrategy, +) -> Result, cubecl::reduce::ReduceError> { + let client = input.client.clone(); + let output = init_reduce_output::(&input, dim).ok_or( + cubecl::reduce::ReduceError::InvalidAxis { + axis: dim, + rank: input.shape.num_dims(), + }, + )?; + let result = match strategy { + ReduceStrategy::Unspecified => cubecl::reduce::reduce::( + &client, + input.as_handle_ref(), + output.as_handle_ref(), + dim, + None, + ), + ReduceStrategy::Specific(strategy) => cubecl::reduce::reduce::( + &client, + input.as_handle_ref(), + output.as_handle_ref(), + dim, + Some(strategy), + ), + #[cfg(feature = "autotune")] + ReduceStrategy::Autotune => { + autotune_reduce::(&client, input, output.clone(), dim) + } + }; + result.map(|_| output) +} - empty_device::(input.client.clone(), input.device.clone(), shape_out) +/// Creates an empty output tensor with the proper shape and decreasing strides to reduce the given `axis` of `input` +/// or return `None` if `axis` is out-of-bound. +pub fn init_reduce_output( + input: &JitTensor, + dim: usize, +) -> Option> { + (dim < input.shape.num_dims()).then(|| { + let mut shape_out = input.shape.clone(); + shape_out.dims[dim] = 1; + empty_device::(input.client.clone(), input.device.clone(), shape_out) + }) } +/// Select a strategy to perform a reduction. #[derive(Copy, Clone, Debug)] -#[allow(missing_docs)] pub enum ReduceStrategy { - /// Naive - Naive, - /// Use shared memory as an accumulator - SharedMemory, - /// Use subcube functions - Subcube, + /// Use a best-effort strategy based on the hardware capacity. + /// This differs from Autotune as it doesn't try and compare many strategies to select the best. + Unspecified, + /// Fix the exact strategy for the reduction. + Specific(cubecl::reduce::ReduceStrategy), + /// Use autotune to find the best strategy given the hardware and the inputs. #[cfg(feature = "autotune")] Autotune, } impl Default for ReduceStrategy { fn default() -> Self { - // if autotune is enabled, default to autotune #[cfg(feature = "autotune")] - return ReduceStrategy::Autotune; + return Self::Autotune; #[cfg(not(feature = "autotune"))] - ReduceStrategy::Naive + return Self::Unspecified; } } - -macro_rules! reduce_operation { - ($name:ident, $ops:ident) => { - #[derive(Debug)] - pub(crate) struct $ops; - - impl ReduceDimAlgorithm for $ops {} - - /// Executes the reduce operation with the given strategy. - pub fn $name( - tensor: JitTensor, - dim: usize, - strategy: ReduceStrategy, - ) -> Result, String> { - match strategy { - ReduceStrategy::Naive => reduce_dim_naive::<$ops, R, EI, EO>(tensor, dim), - ReduceStrategy::SharedMemory => reduce_dim_shared::<$ops, R, EI, EO>(tensor, dim), - ReduceStrategy::Subcube => reduce_dim_subcube::<$ops, R, EI, EO>(tensor, dim), - #[cfg(feature = "autotune")] - ReduceStrategy::Autotune => Ok(reduce_dim_autotune::<$ops, R, EI, EO>(tensor, dim)), - } - } - }; -} - -// Autotunable reduce operation variants -reduce_operation!(sum_dim, SumDim); -reduce_operation!(mean_dim, MeanDim); -reduce_operation!(prod_dim, ProdDim); -reduce_operation!(argmin, Argmin); -reduce_operation!(argmax, Argmax); diff --git a/crates/burn-jit/src/kernel/reduce/mod.rs b/crates/burn-jit/src/kernel/reduce/mod.rs index 2401f9467e..8ff38a9da7 100644 --- a/crates/burn-jit/src/kernel/reduce/mod.rs +++ b/crates/burn-jit/src/kernel/reduce/mod.rs @@ -1,12 +1,5 @@ mod base; -mod naive; -mod prod; -mod shared; -mod subcube; -mod sum; mod tune; pub use base::*; -pub use prod::*; -pub use sum::*; pub use tune::*; diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs b/crates/burn-jit/src/kernel/reduce/naive/argmax.rs deleted file mode 100644 index d577d3decf..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs +++ /dev/null @@ -1,36 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::Argmax; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for Argmax { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for Argmax { - type Accumulator = (EI, u32); - - fn initialize_naive() -> Self::Accumulator { - // TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 - (EI::min_value(), 0u32) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { - let (max, index) = accumulator; - if current_value > *max { - *max = current_value; - *index = i; - } - } - - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - _shape_reduce_dim: u32, - ) { - let (_, index) = accumulator; - output[ABSOLUTE_POS] = EO::cast_from(index); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs b/crates/burn-jit/src/kernel/reduce/naive/argmin.rs deleted file mode 100644 index 2302a2b205..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs +++ /dev/null @@ -1,36 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::Argmin; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for Argmin { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for Argmin { - type Accumulator = (EI, u32); - - fn initialize_naive() -> Self::Accumulator { - // TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 - (EI::max_value(), 0u32) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { - let (min, index) = accumulator; - if current_value < *min { - *min = current_value; - *index = i; - } - } - - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - _shape_reduce_dim: u32, - ) { - let (_, index) = accumulator; - output[ABSOLUTE_POS] = EO::cast_from(index); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/base.rs b/crates/burn-jit/src/kernel/reduce/naive/base.rs deleted file mode 100644 index 7512103ebb..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/base.rs +++ /dev/null @@ -1,25 +0,0 @@ -use cubecl::prelude::*; - -pub trait ReduceDimNaiveFamily: Send + Sync + 'static { - type Reduce: ReduceDimNaive; -} - -/// Specifies the reduce dim algorithm in use -#[cube] -pub trait ReduceDimNaive: Send + Sync + 'static { - /// The reduction accumulator - type Accumulator: CubeType; - - /// Initialization for naive algorithm - fn initialize_naive() -> Self::Accumulator; - - /// Inner loop for naive algorithm - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32); - - /// Assignation for naive algorithm - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - shape_reduce_dim: u32, - ); -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs b/crates/burn-jit/src/kernel/reduce/naive/kernel.rs deleted file mode 100644 index c862e7070d..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs +++ /dev/null @@ -1,71 +0,0 @@ -use crate::{ - element::JitElement, kernel::reduce::init_reduce_output, tensor::JitTensor, JitRuntime, -}; -use cubecl::calculate_cube_count_elemwise; -use cubecl::prelude::*; - -use super::base::ReduceDimNaive; -use super::base::ReduceDimNaiveFamily; - -#[cube(launch_unchecked)] -pub(crate) fn naive_reduce_dim_kernel( - input: &Tensor, - output: &mut Tensor, - dim: u32, -) { - naive_reduce::, EI, EO>(input, output, dim) -} - -#[cube] -fn naive_reduce, EI: Numeric, EO: Numeric>( - input: &Tensor, - output: &mut Tensor, - dim: u32, -) { - if ABSOLUTE_POS >= output.len() { - return; - } - - let mut offset_input = 0; - - for i in 0..input.rank() { - let mut offset_local = ABSOLUTE_POS / output.stride(i); - offset_local %= output.shape(i); - if i != dim { - offset_input += offset_local * input.stride(i); - } - } - - let mut accumulator = RD::initialize_naive(); - - for i in 0..input.shape(dim) { - let index = i * input.stride(dim) + offset_input; - RD::inner_loop_naive(&mut accumulator, input[index], i); - } - - RD::assign_naive::(output, accumulator, input.shape(dim)); -} - -/// Executes the naive kernel for reduce dim -pub fn reduce_dim_naive( - input: JitTensor, - dim: usize, -) -> Result, String> { - let output = init_reduce_output::(&input, dim); - - let cube_dim = CubeDim::default(); - let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); - - unsafe { - naive_reduce_dim_kernel::launch_unchecked::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - ScalarArg::new(dim as u32), - ); - } - - Ok(output) -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs deleted file mode 100644 index 774c9b251c..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs +++ /dev/null @@ -1,27 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::MeanDim; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for MeanDim { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for MeanDim { - type Accumulator = EI; - - fn initialize_naive() -> EI { - EI::from_int(0) - } - - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { - *accumulator += current_value; - } - - fn assign_naive(output: &mut Tensor, accumulator: EI, shape_reduce_dim: u32) { - let mean = accumulator / EI::cast_from(shape_reduce_dim); - output[ABSOLUTE_POS] = EO::cast_from(mean); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/mod.rs b/crates/burn-jit/src/kernel/reduce/naive/mod.rs deleted file mode 100644 index b11ee5e2da..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub(crate) mod argmax; -pub(crate) mod argmin; -pub(crate) mod base; -pub(crate) mod kernel; -pub(crate) mod mean_dim; -pub(crate) mod prod_dim; -pub(crate) mod sum_dim; diff --git a/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs deleted file mode 100644 index 1ea52a149c..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs +++ /dev/null @@ -1,26 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::ProdDim; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for ProdDim { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for ProdDim { - type Accumulator = EI; - - fn initialize_naive() -> EI { - EI::from_int(1) - } - - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { - *accumulator *= current_value; - } - - fn assign_naive(output: &mut Tensor, accumulator: EI, _shape_reduce_dim: u32) { - output[ABSOLUTE_POS] = EO::cast_from(accumulator); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs deleted file mode 100644 index 7168e07ff3..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs +++ /dev/null @@ -1,26 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::SumDim; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for SumDim { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for SumDim { - type Accumulator = EI; - - fn initialize_naive() -> EI { - EI::from_int(0) - } - - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { - *accumulator += current_value; - } - - fn assign_naive(output: &mut Tensor, accumulator: EI, _shape_reduce_dim: u32) { - output[ABSOLUTE_POS] = EO::cast_from(accumulator); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/prod.rs b/crates/burn-jit/src/kernel/reduce/prod.rs deleted file mode 100644 index 1b156157fa..0000000000 --- a/crates/burn-jit/src/kernel/reduce/prod.rs +++ /dev/null @@ -1,15 +0,0 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_tensor::Shape; - -use super::{prod_dim, ReduceStrategy}; - -/// Multiply all elements in the input buffer. -pub fn prod( - input: JitTensor, - strategy: ReduceStrategy, -) -> JitTensor { - let shape = Shape::new([input.shape.num_elements()]); - let input: JitTensor = - JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - prod_dim::(input, 0, strategy).unwrap() -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs deleted file mode 100644 index 43c03c09ce..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::kernel::reduce::Argmax; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for Argmax { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - /// Initialization for shared algorithm - fn initialize_shared( - shared_memory_size: u32, - write_position: u32, - ) -> (SharedMemory, SharedMemory) { - let mut value_shared = SharedMemory::new(shared_memory_size); - let mut index_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::min_value(); - index_shared[write_position] = 0; - (value_shared, index_shared) - } - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut (SharedMemory, SharedMemory), - write_position: u32, - value: (EIn, u32), - ) { - let (values, indices) = shared_memory; - let (value, index) = value; - - if value > values[write_position] { - values[write_position] = value; - indices[write_position] = index; - } - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> (EIn, u32) { - (input[read_position], i) - } - - /// How to read from shared memory - fn read_from_shared( - shared_memory: &(SharedMemory, SharedMemory), - read_position: u32, - ) -> (EIn, u32) { - let (values, indices) = shared_memory; - (values[read_position], indices[read_position]) - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &(SharedMemory, SharedMemory), - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - let (_, indices) = shared_memory; - output[write_position] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs deleted file mode 100644 index 0e47693c5a..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs +++ /dev/null @@ -1,64 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::Argmin; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for Argmin { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - /// Initialization for shared algorithm - fn initialize_shared( - shared_memory_size: u32, - write_position: u32, - ) -> (SharedMemory, SharedMemory) { - let mut value_shared = SharedMemory::new(shared_memory_size); - let mut index_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::max_value(); - index_shared[write_position] = 0; - (value_shared, index_shared) - } - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut (SharedMemory, SharedMemory), - write_position: u32, - value: (EIn, u32), - ) { - let (values, indices) = shared_memory; - let (value, index) = value; - - if value < values[write_position] { - values[write_position] = value; - indices[write_position] = index; - } - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> (EIn, u32) { - (input[read_position], i) - } - - /// How to read from shared memory - fn read_from_shared( - shared_memory: &(SharedMemory, SharedMemory), - read_position: u32, - ) -> (EIn, u32) { - let (values, indices) = shared_memory; - (values[read_position], indices[read_position]) - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &(SharedMemory, SharedMemory), - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - let (_, indices) = shared_memory; - output[write_position] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/base.rs b/crates/burn-jit/src/kernel/reduce/shared/base.rs deleted file mode 100644 index 256123fe1b..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/base.rs +++ /dev/null @@ -1,33 +0,0 @@ -use cubecl::prelude::*; - -/// Specifies the reduce dim algorithm in use -#[cube] -pub trait ReduceDimShared: Send + Sync + 'static { - /// The reduction accumulator - type Accumulator: CubeType; - type Value: CubeType; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> Self::Accumulator; - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut Self::Accumulator, - write_position: u32, - value: Self::Value, - ); - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> Self::Value; - - /// How to read from shared memory - fn read_from_shared(shared_memory: &Self::Accumulator, read_position: u32) -> Self::Value; - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &Self::Accumulator, - output: &mut Tensor, - write_position: u32, - shape_reduce_dim: u32, - ); -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs deleted file mode 100644 index 1c15e4523f..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs +++ /dev/null @@ -1,117 +0,0 @@ -use cubecl::prelude::*; - -use crate::{kernel::reduce::init_reduce_output, tensor::JitTensor, JitElement, JitRuntime}; - -use super::base::ReduceDimShared; - -#[cube(launch)] -pub fn reduce_dim_shared_kernel< - RD: ReduceDimShared, - EIn: JitElement, - EOut: JitElement, ->( - input: &Tensor, - output: &mut Tensor, - #[comptime] dim: u32, - #[comptime] smem_size: u32, - #[comptime] elems_per_thread: u32, - #[comptime] divisible_shape: bool, -) { - let reduce_group_id = CUBE_POS; - - let stride_reduce_dim_input = input.stride(dim); - let shape_reduce_dim_input = input.shape(dim); - - let mut shared_memory = RD::initialize_shared(smem_size, UNIT_POS); - - let mut index_offset = 0; - - for i in 0..input.rank() { - let num_block = reduce_group_id / output.stride(i) % output.shape(i); - index_offset += num_block * input.stride(i); - } - - for i in 0..elems_per_thread { - let nth = i * CUBE_DIM + UNIT_POS; - - #[allow(clippy::collapsible_else_if)] - if divisible_shape { - let current_pos = nth * stride_reduce_dim_input + index_offset; - - let new_value = RD::read_from_input(input, current_pos, nth); - RD::write_to_shared(&mut shared_memory, UNIT_POS, new_value); - } else { - if nth < shape_reduce_dim_input { - let current_pos = nth * stride_reduce_dim_input + index_offset; - - let new_value = RD::read_from_input(input, current_pos, nth); - RD::write_to_shared(&mut shared_memory, UNIT_POS, new_value); - } - } - } - - sync_units(); - - let mut n_threads = CUBE_DIM; - - while n_threads > 1 { - n_threads /= 2; - - if UNIT_POS < n_threads { - let read_pos = n_threads + UNIT_POS; - let read_value = RD::read_from_shared(&shared_memory, read_pos); - RD::write_to_shared(&mut shared_memory, UNIT_POS, read_value); - } - - sync_units(); - } - - if UNIT_POS == 0 { - RD::assign_shared( - &shared_memory, - output, - reduce_group_id, - shape_reduce_dim_input, - ); - } -} - -/// Executes the shared memory kernel for reduce dim -pub fn reduce_dim_shared< - RD: ReduceDimShared, - R: JitRuntime, - EI: JitElement, - EO: JitElement, ->( - input: JitTensor, - dim: usize, -) -> Result, String> { - let output = init_reduce_output::(&input, dim); - - let num_elems_output = output.shape.num_elements(); - let cube_dim = CubeDim::default(); - let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); - let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x); - let cube_count = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1); - - let reduce_group_size = input.shape.dims[dim]; - let n_invocation_per_cube = cube_dim.num_elems(); - let elems_per_thread = - f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; - - let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; - - reduce_dim_shared_kernel::launch::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - dim as u32, - cube_dim.num_elems(), - elems_per_thread, - divisible_shape, - ); - - Ok(output) -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs deleted file mode 100644 index eef8f5f478..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::kernel::reduce::MeanDim; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for MeanDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(0); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] += value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - shape_reduce_dim: u32, - ) { - let mean = shared_memory[0] / EIn::cast_from(shape_reduce_dim); - output[write_position] = EOut::cast_from(mean); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/mod.rs b/crates/burn-jit/src/kernel/reduce/shared/mod.rs deleted file mode 100644 index b11ee5e2da..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub(crate) mod argmax; -pub(crate) mod argmin; -pub(crate) mod base; -pub(crate) mod kernel; -pub(crate) mod mean_dim; -pub(crate) mod prod_dim; -pub(crate) mod sum_dim; diff --git a/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs deleted file mode 100644 index 594f2fec11..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::kernel::reduce::ProdDim; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for ProdDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(1); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] *= value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - output[write_position] = EOut::cast_from(shared_memory[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs deleted file mode 100644 index 476dd554a4..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::kernel::reduce::SumDim; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for SumDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(0); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] += value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - output[write_position] = EOut::cast_from(shared_memory[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/argmax.rs b/crates/burn-jit/src/kernel/reduce/subcube/argmax.rs deleted file mode 100644 index c8e567e816..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/argmax.rs +++ /dev/null @@ -1,54 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::Argmax; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for Argmax { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - let value_shared = SharedMemory::new(size); - let index_shared = SharedMemory::new(size); - (value_shared, index_shared) - } - - fn init_value() -> Self::Value { - (comptime![EIn::min_value()], 0u32) - } - - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value { - (input[pos], i) - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - let (values, indices) = acc; - (values[pos], indices[pos]) - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - let (current_val, current_idx) = current; - let (new_val, new_idx) = new; - *current_val = Max::max(*current_val, new_val); - *current_idx = select(*current_val == new_val, new_idx, *current_idx); - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let (val, index) = value; - let (val_smem, index_smem) = acc; - let max = plane_max(val); - - if max == val { - val_smem[write_position] = val; - index_smem[write_position] = index; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - let (_, indices) = acc; - out[pos] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/argmin.rs b/crates/burn-jit/src/kernel/reduce/subcube/argmin.rs deleted file mode 100644 index b7950ebfe2..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/argmin.rs +++ /dev/null @@ -1,54 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::Argmin; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for Argmin { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - let value_shared = SharedMemory::new(size); - let index_shared = SharedMemory::new(size); - (value_shared, index_shared) - } - - fn init_value() -> Self::Value { - (comptime![EIn::max_value()], 0u32) - } - - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value { - (input[pos], i) - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - let (values, indices) = acc; - (values[pos], indices[pos]) - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - let (current_val, current_idx) = current; - let (new_val, new_idx) = new; - *current_val = Min::min(*current_val, new_val); - *current_idx = select(*current_val == new_val, new_idx, *current_idx); - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let (val, index) = value; - let (val_smem, index_smem) = acc; - let min = plane_min(val); - - if min == val { - val_smem[write_position] = val; - index_smem[write_position] = index; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - let (_, indices) = acc; - out[pos] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/base.rs b/crates/burn-jit/src/kernel/reduce/subcube/base.rs deleted file mode 100644 index f20e538914..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/base.rs +++ /dev/null @@ -1,15 +0,0 @@ -use cubecl::prelude::*; - -#[cube] -pub trait ReduceDimSubcube: Send + Sync + 'static { - type Accumulator: CubeType; - type Value: CubeType; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator; - fn init_value() -> Self::Value; - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value; - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value; - fn update_value(current: &mut Self::Value, new: Self::Value); - fn reduce_subcube(acc: &mut Self::Accumulator, pos: u32, value: Self::Value); - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, dim_len: u32); -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs deleted file mode 100644 index 26f65f5d68..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs +++ /dev/null @@ -1,134 +0,0 @@ -use cubecl::{prelude::*, CubeCount, CubeDim, Feature}; - -use crate::{ - kernel::reduce::{init_reduce_output, shared::kernel::reduce_dim_shared, ReduceDimAlgorithm}, - tensor::JitTensor, - JitElement, JitRuntime, -}; - -use super::base::ReduceDimSubcube; - -#[cube(launch)] -pub fn reduce_dim_subcube_kernel< - RD: ReduceDimSubcube, - EIn: JitElement, - EOut: JitElement, ->( - input: &Tensor, - output: &mut Tensor, - #[comptime] dim: u32, - #[comptime] subcube_size: u32, - #[comptime] elems_per_thread: u32, - #[comptime] divisible_shape: bool, -) { - let reduce_group_id = CUBE_POS; - - let stride_reduce_dim_input = input.stride(dim); - let shape_reduce_dim_input = input.shape(dim); - - let should_unroll = elems_per_thread <= 8; - - let warp_id = plane_broadcast(UNIT_POS / PLANE_DIM, 0); - - let mut shared_memory = RD::init_shared(subcube_size); - - let mut index_offset = 0; - - for i in 0..input.rank() { - let num_block = reduce_group_id / output.stride(i) % output.shape(i); - index_offset += num_block * input.stride(i); - } - - let mut value = RD::init_value(); - - #[unroll(should_unroll)] - for i in 0..elems_per_thread { - let nth = i * CUBE_DIM + UNIT_POS; - let current_pos = nth * stride_reduce_dim_input + index_offset; - - #[allow(clippy::collapsible_else_if)] - if divisible_shape { - let next = RD::read_value(input, current_pos, nth); - RD::update_value(&mut value, next); - } else { - if nth < shape_reduce_dim_input { - let next = RD::read_value(input, current_pos, nth); - RD::update_value(&mut value, next); - } - } - } - - RD::reduce_subcube(&mut shared_memory, warp_id, value); - - sync_units(); - - if UNIT_POS >= PLANE_DIM { - return; - } - - let value = RD::read_from_shared(&shared_memory, UNIT_POS); - RD::reduce_subcube(&mut shared_memory, 0, value); - - if UNIT_POS == 0 { - RD::store( - &shared_memory, - output, - reduce_group_id, - shape_reduce_dim_input, - ); - } -} - -/// Executes the shared memory kernel for reduce dim -pub fn reduce_dim_subcube< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement, - EO: JitElement, ->( - input: JitTensor, - dim: usize, -) -> Result, String> { - let topology = input.client.properties().hardware_properties(); - - if !input.client.properties().feature_enabled(Feature::Plane) - || topology.plane_size_min != topology.plane_size_max - { - return reduce_dim_shared::(input, dim); - } - - let subcube_size = topology.plane_size_min; - - let output = init_reduce_output::(&input, dim); - - let num_elems_output = output.shape.num_elements(); - let cube_dim = CubeDim { - x: subcube_size, - y: subcube_size, - z: 1, - }; - let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); - let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x); - let cube_count = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1); - - let reduce_group_size = input.shape.dims[dim]; - let n_invocation_per_cube = cube_dim.num_elems(); - let elems_per_thread = - f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; - - let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; - - reduce_dim_subcube_kernel::launch::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - dim as u32, - subcube_size, - elems_per_thread, - divisible_shape, - ); - - Ok(output) -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs deleted file mode 100644 index fb8c0b41d6..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs +++ /dev/null @@ -1,45 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::MeanDim; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for MeanDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - EIn::cast_from(0u32) - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current += new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let sum = plane_sum(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = sum; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, dim_length: u32) { - let denom = EIn::cast_from(dim_length); - out[pos] = EOut::cast_from(acc[0] / denom); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/mod.rs b/crates/burn-jit/src/kernel/reduce/subcube/mod.rs deleted file mode 100644 index 183c1e2daf..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod argmax; -pub mod argmin; -pub mod base; -pub mod kernel; -pub mod mean_dim; -pub mod prod_dim; -pub mod sum_dim; diff --git a/crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs deleted file mode 100644 index cccec95167..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::ProdDim; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for ProdDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - EIn::from_int(1) - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current *= new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let prod = plane_prod(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = prod; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - out[pos] = EOut::cast_from(acc[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs deleted file mode 100644 index 1059432eb2..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::SumDim; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for SumDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - EIn::cast_from(0u32) - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current += new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let sum = plane_sum(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = sum; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - out[pos] = EOut::cast_from(acc[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/sum.rs b/crates/burn-jit/src/kernel/reduce/sum.rs deleted file mode 100644 index d3c9416dc1..0000000000 --- a/crates/burn-jit/src/kernel/reduce/sum.rs +++ /dev/null @@ -1,15 +0,0 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_tensor::Shape; - -use super::{sum_dim, ReduceStrategy}; - -/// Sum all elements in the input buffer. -pub fn sum( - input: JitTensor, - strategy: ReduceStrategy, -) -> JitTensor { - let shape = Shape::new([input.shape.num_elements()]); - let input: JitTensor = - JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - sum_dim::(input, 0, strategy).unwrap() -} diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs new file mode 100644 index 0000000000..6816196a37 --- /dev/null +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -0,0 +1,222 @@ +#![allow(missing_docs)] + +use burn_tensor::ElementConversion; +use cubecl::{ + client::ComputeClient, + tune, + tune::{local_tuner, tune_with, LocalTuner}, + AutotuneKey, +}; +use serde::{Deserialize, Serialize}; + +use crate::{ + kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor, + JitAutotuneKey, JitElement, JitRuntime, JitTuneId, +}; + +/// Executes autotune on reduce operations. +pub fn autotune_reduce< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, +>( + client: &ComputeClient, + input: JitTensor, + output: JitTensor, + dim: usize, +) -> Result<(), cubecl::reduce::ReduceError> { + static TUNER: LocalTuner = local_tuner!(); + + TUNER.execute( + &JitTuneId::new::(&input.device), + client, + Box::new(ReduceOps::::new(input, output, dim)), + ); + + Ok(()) +} + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +/// Autotune key representative of redue versions +pub struct ReduceAutotuneKey { + dtype: burn_tensor::DType, + #[autotune(anchor)] + reduce_axis_shape: usize, + #[autotune(anchor)] + reduce_axis_stride: usize, + #[autotune(anchor)] + outer_axes_product: usize, // The product of the shapes of all axes with greater strides. +} + +impl ReduceAutotuneKey { + pub(crate) fn generate(input: &JitTensor, axis: usize) -> Self { + let rank = input.shape.num_dims(); + + if axis > rank { + panic!("axis {axis} is out-of-bound for a rank of {rank}"); + } + + let dtype = input.dtype; + let reduce_axis_shape = input.shape.dims[axis]; + let reduce_axis_stride = input.strides[axis]; + + let outer_axes_product = input + .strides + .iter() + .zip(input.shape.dims.iter()) + .filter_map(|(stride, shape)| (*stride > reduce_axis_stride).then_some(shape)) + .product(); + + Self { + dtype, + reduce_axis_shape, + reduce_axis_stride, + outer_axes_product, + } + } +} + +pub(crate) fn create_key( + input: &JitTensor, + _output: &JitTensor, + dim: &usize, +) -> JitAutotuneKey { + JitAutotuneKey::Reduce(ReduceAutotuneKey::generate(input, *dim)) +} + +pub use reduce_ops::*; +mod reduce_ops { + #![allow(missing_docs)] + + use super::*; + + #[tune( + operations(reduce, reduce_shared, reduce_plane, reduce_shared_plane), + create_key = create_key::, + should_run = should_run +)] + fn reduce_ops( + key: JitAutotuneKey, + input: JitTensor, + output: JitTensor, + dim: usize, + ) { + let random_bounds: (In, In) = ((-10.0_f32).elem::(), (10.0_f32).elem::()); + let input = random_like_uniform(input, random_bounds.0, random_bounds.1); + + let output = empty_device::( + output.client.clone(), + output.device.clone(), + output.shape.clone(), + ); + + tune_with!(input, output, dim) + } + + fn should_run( + op: &ReduceOps, + _key: &JitAutotuneKey, + index: usize, + ) -> bool { + match index { + // if strategy uses planes + 2 | 3 => { + let properties = op.input.client.properties(); + properties.feature_enabled(cubecl::Feature::Plane) + && properties + .hardware_properties() + .defined_plane_size() + .is_some() + } + _ => true, + } + } + + fn reduce( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: false, + use_planes: false, + }), + ) + .map_err(|e| format!("{e}")) + } + + fn reduce_shared< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: true, + use_planes: false, + }), + ) + .map_err(|e| format!("{e}")) + } + + fn reduce_plane< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: false, + use_planes: true, + }), + ) + .map_err(|e| format!("{e}")) + } + + fn reduce_shared_plane< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: true, + use_planes: true, + }), + ) + .map_err(|e| format!("{e}")) + } +} diff --git a/crates/burn-jit/src/kernel/reduce/tune/base.rs b/crates/burn-jit/src/kernel/reduce/tune/base.rs deleted file mode 100644 index f52bfd7ca0..0000000000 --- a/crates/burn-jit/src/kernel/reduce/tune/base.rs +++ /dev/null @@ -1,94 +0,0 @@ -use burn_tensor::{Element, ElementConversion}; -use cubecl::tune::{local_tuner, tune_with, LocalTuner}; -use cubecl::{tune, Feature}; - -use crate::{ - element::JitElement, - kernel::{ - prng::random_like_uniform, - reduce::{ - naive::kernel::reduce_dim_naive, shared::kernel::reduce_dim_shared, - subcube::kernel::reduce_dim_subcube, ReduceDimAlgorithm, - }, - }, - tensor::JitTensor, - tune_key::JitAutotuneKey, - JitRuntime, JitTuneId, -}; - -use super::create_key; - -/// Set of reduce_dim implementations available for autotune -/// Autotune key is given by concatenating the closest upper power of 2 of -/// dim to reduce, and product of others -#[tune( - operations(reduce_dim_naive, reduce_dim_shared, reduce_dim_subcube), - create_key = create_key::, - should_run = should_run -)] -pub fn reduce_dim_operations< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - key: JitAutotuneKey, - input: JitTensor, - reduce_dim: usize, -) -> JitTensor { - let random_bounds: (EI, EI) = ((-10.0).elem::(), (10.0).elem::()); - let input = random_like_uniform(input, random_bounds.0, random_bounds.1); - - tune_with!(input, reduce_dim) -} - -/// Executes autotune on reduce_dim operation -pub(crate) fn reduce_dim_autotune< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - input: JitTensor, - reduce_dim: usize, -) -> JitTensor { - let client = input.client.clone(); - - let id = JitTuneId::new::(&input.device); - - let operation_set = Box::new(ReduceDimOperations::::new(input, reduce_dim)); - - static TUNER: LocalTuner = local_tuner!(); - - TUNER.execute(&id, &client, operation_set) -} - -fn should_run< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - op: &ReduceDimOperations, - key: &JitAutotuneKey, - index: usize, -) -> bool { - let JitAutotuneKey::ReduceDim(key) = key else { - unreachable!() - }; - - match index { - // Naive - 0 => key.reduce_dim_length <= 8192, - // Shared - 1 => key.reduce_dim_length >= 16, - // Subcube - 2 => { - let props = op.input.client.properties(); - let hardware = props.hardware_properties(); - props.feature_enabled(Feature::Plane) - && hardware.plane_size_min == hardware.plane_size_max - } - _ => true, - } -} diff --git a/crates/burn-jit/src/kernel/reduce/tune/key.rs b/crates/burn-jit/src/kernel/reduce/tune/key.rs deleted file mode 100644 index 3634022bc7..0000000000 --- a/crates/burn-jit/src/kernel/reduce/tune/key.rs +++ /dev/null @@ -1,39 +0,0 @@ -use cubecl::AutotuneKey; -use serde::{Deserialize, Serialize}; - -use burn_tensor::DType; - -use crate::{tensor::JitTensor, JitAutotuneKey, JitElement, JitRuntime}; - -/// Autotune key representative of reduce versions -#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] -pub struct ReduceAutotuneKey { - #[autotune(anchor)] - pub(crate) reduce_dim_length: usize, - #[autotune(anchor)] - pub(crate) reduce_dim_stride: usize, - #[autotune(anchor)] - pub(crate) others_product: usize, - dtype: DType, -} - -pub(crate) fn create_key( - input: &JitTensor, - reduce_dim: &usize, -) -> JitAutotuneKey { - let dims = &input.shape.dims; - let reduce_dim = *reduce_dim; - - let mut others_product = 1; - for (d, len) in dims.iter().enumerate() { - if d != reduce_dim { - others_product *= len - } - } - JitAutotuneKey::ReduceDim(ReduceAutotuneKey::new( - dims[reduce_dim], - input.strides[reduce_dim], - others_product, - EI::dtype(), - )) -} diff --git a/crates/burn-jit/src/kernel/reduce/tune/mod.rs b/crates/burn-jit/src/kernel/reduce/tune/mod.rs deleted file mode 100644 index aee5569b6b..0000000000 --- a/crates/burn-jit/src/kernel/reduce/tune/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -#[cfg(feature = "autotune")] -mod base; -mod key; - -#[cfg(feature = "autotune")] -pub(crate) use base::*; -pub use key::*; diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index c59b9df83c..d32de97436 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -355,7 +355,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::sum::(tensor, Default::default()) + reduce::reduce::(tensor, Default::default()).unwrap() ) } @@ -363,7 +363,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::sum_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -371,7 +371,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::mean_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -379,7 +379,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::prod::(tensor, Default::default()) + reduce::reduce::(tensor, Default::default()).unwrap() ) } @@ -387,7 +387,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::prod_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -467,7 +467,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::argmax::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -475,7 +475,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::argmin::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index ed99258826..5702a90849 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,5 +1,5 @@ use super::{expand, numeric, permute}; -use crate::kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}; +use crate::kernel::{launch_unary_numeric, reduce, NumericUnaryOp, NumericUnaryOpFamily}; use crate::{ element::BoolElement, kernel::prng::{random_bernoulli, random_normal, random_uniform}, @@ -193,31 +193,31 @@ where } fn int_sum(tensor: IntTensor) -> IntTensor { - kernel::reduce::sum::(tensor, Default::default()) + reduce::reduce::(tensor, Default::default()).unwrap() } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::sum_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_prod(tensor: IntTensor) -> IntTensor { - kernel::reduce::prod::(tensor, Default::default()) + reduce::reduce::(tensor, Default::default()).unwrap() } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::prod_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::mean_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmax::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmin::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_clamp( diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index 378eb035ed..a79ac3c437 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -17,6 +17,7 @@ mod max_pool2d; mod max_pool2d_backward; mod normal; mod quantization; +mod reduce; mod repeat_dim; mod scatter; mod select; @@ -78,6 +79,8 @@ macro_rules! testgen_all { burn_jit::testgen_clamp!(); burn_jit::testgen_unary!(); + burn_jit::testgen_reduce!(); + burn_jit::testgen_quantization!(); } } diff --git a/crates/burn-jit/src/tests/reduce.rs b/crates/burn-jit/src/tests/reduce.rs new file mode 100644 index 0000000000..8e533361e9 --- /dev/null +++ b/crates/burn-jit/src/tests/reduce.rs @@ -0,0 +1,128 @@ +#[burn_tensor_testgen::testgen(reduce)] +mod reduce { + use super::*; + use burn_jit::kernel::reduce::{ + reduce, reduce_dim, ArgMax, ArgMin, Mean, Prod, ReduceStrategy, Sum, + }; + use burn_tensor::{ + backend::Backend, ops::IntTensorOps, Distribution, Int, Shape, Tensor, TensorData, + TensorPrimitive, + }; + + const RANK: usize = 4; + const SHAPE: [usize; RANK] = [2, 4, 8, 16]; + + #[test] + fn reduction_argmax_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .argmax(dim) + .into_data() + .assert_eq(&tensor_ref.clone().argmax(dim).into_data(), false); + } + } + + #[test] + fn reduction_argmin_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .argmin(dim) + .into_data() + .assert_eq(&tensor_ref.clone().argmin(dim).into_data(), false); + } + } + + #[test] + fn reduction_mean_dim_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .mean_dim(dim) + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().mean_dim(dim).into_data(), 1e-6); + } + } + + #[test] + fn reduction_mean_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + tensor + .clone() + .mean() + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().mean().into_data(), 1e-6); + } + + #[test] + fn reduction_prod_dim_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .prod_dim(dim) + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().prod_dim(dim).into_data(), 1e-6); + } + } + + #[test] + fn reduction_prod_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + tensor + .clone() + .prod() + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().prod().into_data(), 1e-6); + } + + #[test] + fn reduction_sum_dim_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .sum_dim(dim) + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().sum_dim(dim).into_data(), 1e-6); + } + } + + #[test] + fn reduction_sum_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + tensor + .clone() + .sum() + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().sum().into_data(), 1e-6); + } +} diff --git a/crates/burn-jit/src/tune_key.rs b/crates/burn-jit/src/tune_key.rs index 0a7ae855b9..cb29e2fe0c 100644 --- a/crates/burn-jit/src/tune_key.rs +++ b/crates/burn-jit/src/tune_key.rs @@ -13,7 +13,7 @@ pub enum JitAutotuneKey { /// Key for matmul operation Matmul(MatmulAutotuneKey), /// Key for reduce dim operations - ReduceDim(ReduceAutotuneKey), + Reduce(ReduceAutotuneKey), /// Key for convolution operations Conv2d(Conv2dAutotuneKey), /// Key for transpose convolution operations @@ -24,7 +24,7 @@ impl Display for JitAutotuneKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { JitAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), - JitAutotuneKey::ReduceDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + JitAutotuneKey::Reduce(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), } diff --git a/crates/burn-tensor/src/tensor/shape.rs b/crates/burn-tensor/src/tensor/shape.rs index 8ad54ba4d9..29eebd549e 100644 --- a/crates/burn-tensor/src/tensor/shape.rs +++ b/crates/burn-tensor/src/tensor/shape.rs @@ -33,6 +33,13 @@ impl Shape { dims[..D].copy_from_slice(&self.dims[..D]); dims } + + /// Change the shape to one dimensional with the same number of elements. + pub fn flatten(&self) -> Self { + Self { + dims: [self.dims.iter().product()].into(), + } + } } impl From<[usize; D]> for Shape { From 1902c90d067faaa29b20b97917d05b5f912d4748 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Mon, 13 Jan 2025 18:34:47 -0500 Subject: [PATCH 02/17] Update cubecl (#2693) --- Cargo.lock | 26 +++++++++++++------------- Cargo.toml | 4 ++-- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index feff4ed96a..02c5f0fda5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1582,7 +1582,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1597,7 +1597,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1614,7 +1614,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "bytemuck", "cubecl-common", @@ -1633,7 +1633,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "bytemuck", "cubecl-common", @@ -1647,7 +1647,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "bytemuck", "cubecl-common", @@ -1663,7 +1663,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "bytemuck", "cubecl-common", @@ -1689,7 +1689,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "bytemuck", "cubecl-core", @@ -1701,7 +1701,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "cubecl-common", "darling", @@ -1716,7 +1716,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "cubecl-common", "cubecl-core", @@ -1732,7 +1732,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1742,7 +1742,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "async-channel", "async-lock", @@ -1763,7 +1763,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "cubecl-common", "cubecl-core", @@ -1777,7 +1777,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 5dfebaf2b0..508868e381 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "707093234f11b78fb6630b98fea5d13870f94282" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "707093234f11b78fb6630b98fea5d13870f94282" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c63a62be7238bd28b999160aba6a6bbdabdfb7d3" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c63a62be7238bd28b999160aba6a6bbdabdfb7d3" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } From ddc43398037b50c94fc2665c54ccf447c3eb179c Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 09:05:56 -0500 Subject: [PATCH 03/17] Add dropout prob check (#2695) * Add dropout prob check * Add test --- crates/burn-core/src/nn/dropout.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/crates/burn-core/src/nn/dropout.rs b/crates/burn-core/src/nn/dropout.rs index d03e95c1f3..79fc12ecbf 100644 --- a/crates/burn-core/src/nn/dropout.rs +++ b/crates/burn-core/src/nn/dropout.rs @@ -30,6 +30,12 @@ pub struct Dropout { impl DropoutConfig { /// Initialize a new [dropout](Dropout) module. pub fn init(&self) -> Dropout { + if self.prob < 0.0 || self.prob > 1.0 { + panic!( + "Dropout probability should be between 0 and 1, but got {}", + self.prob + ); + } Dropout { prob: self.prob } } } @@ -108,4 +114,11 @@ mod tests { assert_eq!(alloc::format!("{}", layer), "Dropout {prob: 0.5}"); } + + #[test] + #[should_panic = "Dropout probability should be between 0 and 1,"] + fn dropout_prob_invalid() { + let config = DropoutConfig::new(-10.); + let _layer = config.init(); + } } From d30f71c53343879d5d429fccc2fec74c026d298d Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 14 Jan 2025 13:36:13 -0500 Subject: [PATCH 04/17] Fix reduce autotune key no anchor (#2696) --- Cargo.lock | 26 +++++++++++------------ Cargo.toml | 4 ++-- crates/burn-jit/src/kernel/reduce/tune.rs | 4 ++-- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 02c5f0fda5..3c8a8b414f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1582,7 +1582,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1597,7 +1597,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1614,7 +1614,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "bytemuck", "cubecl-common", @@ -1633,7 +1633,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "bytemuck", "cubecl-common", @@ -1647,7 +1647,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "bytemuck", "cubecl-common", @@ -1663,7 +1663,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "bytemuck", "cubecl-common", @@ -1689,7 +1689,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "bytemuck", "cubecl-core", @@ -1701,7 +1701,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "cubecl-common", "darling", @@ -1716,7 +1716,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "cubecl-common", "cubecl-core", @@ -1732,7 +1732,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1742,7 +1742,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "async-channel", "async-lock", @@ -1763,7 +1763,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "cubecl-common", "cubecl-core", @@ -1777,7 +1777,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 508868e381..bf14a247d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c63a62be7238bd28b999160aba6a6bbdabdfb7d3" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c63a62be7238bd28b999160aba6a6bbdabdfb7d3" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs index 6816196a37..c5659cc1cc 100644 --- a/crates/burn-jit/src/kernel/reduce/tune.rs +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -68,12 +68,12 @@ impl ReduceAutotuneKey { .filter_map(|(stride, shape)| (*stride > reduce_axis_stride).then_some(shape)) .product(); - Self { + Self::new( dtype, reduce_axis_shape, reduce_axis_stride, outer_axes_product, - } + ) } } From cdcff034f57ce8d17108ffb8bbaa6b42ed840f88 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 14:52:05 -0500 Subject: [PATCH 05/17] Set cubecl version (#2697) --- Cargo.lock | 39 ++++++++++++++++++++++++++------------- Cargo.toml | 8 ++++---- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3c8a8b414f..09f9e5dc14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1582,7 +1582,8 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aecf090429a4172d94c819e2977f440d7f5846c09f31d36937de309f986c878e" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1597,7 +1598,8 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10239ee4800968f367fbc4828250d38acf5d14fa53e8d0370d5f474387591322" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1614,7 +1616,8 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d249976814abe45ee5d04bdfd5e2359558b409affdc03914625bea778dab5ade" dependencies = [ "bytemuck", "cubecl-common", @@ -1633,7 +1636,8 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8463629d0bdf4d09d47150bce35132236c1a597f65eba213b45073406048a596" dependencies = [ "bytemuck", "cubecl-common", @@ -1647,7 +1651,8 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12c0b49113ba986e984538cf54c3d7390c0af934a80f083b6c99cad737d22c59" dependencies = [ "bytemuck", "cubecl-common", @@ -1663,7 +1668,8 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "976e150315f9d7d6bb84c51cb13c19221ea5d185bb6d61347a3c392dd29720de" dependencies = [ "bytemuck", "cubecl-common", @@ -1689,7 +1695,8 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "640c379e225fecb1336f963affd3b8f1ff66b9320a972dfe92d8158dca8b6382" dependencies = [ "bytemuck", "cubecl-core", @@ -1701,7 +1708,8 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05d95f3be436814f909a3ac97209159f63076d3d2b254914bc02db2ac7faefb" dependencies = [ "cubecl-common", "darling", @@ -1716,7 +1724,8 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42c0593efee028e010a1a7e8646a8a405f6a653fe194bc8c5b46189245ecaeec" dependencies = [ "cubecl-common", "cubecl-core", @@ -1732,7 +1741,8 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0912890b52cc6f9636e0070320ff93dec27af15d57453789081b9a8bdb49786d" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1742,7 +1752,8 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e84f4ae5a096e4d0c410db01d18b673d6efcd6eea1724d1a001ab60484df87" dependencies = [ "async-channel", "async-lock", @@ -1763,7 +1774,8 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5d88e7d35a58a40991e42c4492739d4b89b6046ac75126cb5f10b190032012c" dependencies = [ "cubecl-common", "cubecl-core", @@ -1777,7 +1789,8 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cf8105d01ef4cd103d4e31bee9ae583fabc807253234923fb08218b28db7d15" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index bf14a247d7..e79690fd28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,14 +153,14 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } +# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } +# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } ### For the release. ### -# cubecl = { version = "0.3.0", default-features = false } -# cubecl-common = { version = "0.3.0", default-features = false } +cubecl = { version = "0.4.0", default-features = false } +cubecl-common = { version = "0.4.0", default-features = false } ### For xtask crate ### tracel-xtask = { version = "=1.1.8" } From 59a2e3bc395cfaca694dff170727dc1632e6c078 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 15:14:22 -0500 Subject: [PATCH 06/17] Add burn-remote to publish workflow --- .github/workflows/publish.yml | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 8b8916c631..46446ced01 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,6 +6,31 @@ on: - "v*" jobs: + publish-burn-router: + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 + with: + crate: burn-router + needs: + - publish-burn-common + - publish-burn-tensor + # dev dependencies + - publish-burn-autodiff + - publish-burn-ndarray + - publish-burn-wgpu + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + + publish-burn-remote: + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 + with: + crate: burn-derive + needs: + - publish-burn-common + - publish-burn-tensor + - publish-burn-router + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + publish-burn-derive: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 with: @@ -162,6 +187,7 @@ jobs: - publish-burn-tch - publish-burn-ndarray - publish-burn-candle + - publish-burn-remote with: crate: burn-core secrets: From 1b91e32e86d44c33ec4a87d468954a45dcb3e554 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 15:25:24 -0500 Subject: [PATCH 07/17] Fix licenses --- NOTICES.md | 43 +++++++++++++++++++++++++++++++++++++++++++ deny.toml | 3 +++ 2 files changed, 46 insertions(+) diff --git a/NOTICES.md b/NOTICES.md index c41a90d952..0f559e27a4 100644 --- a/NOTICES.md +++ b/NOTICES.md @@ -601,3 +601,46 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +## ICU + +UNICODE LICENSE V3 + +COPYRIGHT AND PERMISSION NOTICE + +Copyright © 2016-2024 Unicode, Inc. + +NOTICE TO USER: Carefully read the following legal agreement. BY +DOWNLOADING, INSTALLING, COPYING OR OTHERWISE USING DATA FILES, AND/OR +SOFTWARE, YOU UNEQUIVOCALLY ACCEPT, AND AGREE TO BE BOUND BY, ALL OF THE +TERMS AND CONDITIONS OF THIS AGREEMENT. IF YOU DO NOT AGREE, DO NOT +DOWNLOAD, INSTALL, COPY, DISTRIBUTE OR USE THE DATA FILES OR SOFTWARE. + +Permission is hereby granted, free of charge, to any person obtaining a +copy of data files and any associated documentation (the "Data Files") or +software and any associated documentation (the "Software") to deal in the +Data Files or Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, and/or sell +copies of the Data Files or Software, and to permit persons to whom the +Data Files or Software are furnished to do so, provided that either (a) +this copyright and permission notice appear with all copies of the Data +Files or Software, or (b) this copyright and permission notice appear in +associated Documentation. + +THE DATA FILES AND SOFTWARE ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY +KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF +THIRD PARTY RIGHTS. + +IN NO EVENT SHALL THE COPYRIGHT HOLDER OR HOLDERS INCLUDED IN THIS NOTICE +BE LIABLE FOR ANY CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, +OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, +ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE DATA +FILES OR SOFTWARE. + +Except as contained in this notice, the name of a copyright holder shall +not be used in advertising or otherwise to promote the sale, use or other +dealings in these Data Files or Software without prior written +authorization of the copyright holder. diff --git a/deny.toml b/deny.toml index a9a4506064..e8a251eb1c 100644 --- a/deny.toml +++ b/deny.toml @@ -75,12 +75,14 @@ allow = [ "Apache-2.0 WITH LLVM-exception", "Apache-2.0", "BSD-3-Clause", + "BSD-2-Clause", "CC0-1.0", "ISC", "MIT", "MPL-2.0", "OpenSSL", "Unicode-DFS-2016", + "Unicode-3.0", "Unlicense", "Zlib", ] @@ -90,4 +92,5 @@ exceptions = [ # Each entry is the crate and version constraint, and its specific allow # list #{ allow = ["license_name"], name = "crate", version = "*" }, + { allow = ["BSL-1.0"], name = "clipboard-win", version = "*" }, # in NOTICES.md ] From 3a6a456d2bad0501cd4de213208d82f746e57745 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 15:35:53 -0500 Subject: [PATCH 08/17] Fix typo --- .github/workflows/publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 46446ced01..de956c243e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -23,7 +23,7 @@ jobs: publish-burn-remote: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 with: - crate: burn-derive + crate: burn-remote needs: - publish-burn-common - publish-burn-tensor From dd628ec91c8dafa7f5767d85f822d46dec8f4707 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 15:49:32 -0500 Subject: [PATCH 09/17] Remove self dep on burn-remote --- Cargo.lock | 1 - crates/burn-remote/Cargo.toml | 4 ---- deny.toml | 2 +- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 09f9e5dc14..92be93458c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -874,7 +874,6 @@ dependencies = [ "async-channel", "axum", "burn-common", - "burn-remote", "burn-router", "burn-tensor", "derive-new 0.7.0", diff --git a/crates/burn-remote/Cargo.toml b/crates/burn-remote/Cargo.toml index fa6034c681..de772861c9 100644 --- a/crates/burn-remote/Cargo.toml +++ b/crates/burn-remote/Cargo.toml @@ -43,10 +43,6 @@ axum = { version = "0.8.1", features = ["ws"], optional = true } tracing-core = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true } -[dev-dependencies] -# We activate the features client and server during dev. -burn-remote = { path = ".", version = "0.16.0", features=["client", "server"] } - [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] diff --git a/deny.toml b/deny.toml index e8a251eb1c..ac64c923fe 100644 --- a/deny.toml +++ b/deny.toml @@ -76,6 +76,7 @@ allow = [ "Apache-2.0", "BSD-3-Clause", "BSD-2-Clause", + "BSL-1.0", # in NOTICES.md "CC0-1.0", "ISC", "MIT", @@ -92,5 +93,4 @@ exceptions = [ # Each entry is the crate and version constraint, and its specific allow # list #{ allow = ["license_name"], name = "crate", version = "*" }, - { allow = ["BSL-1.0"], name = "clipboard-win", version = "*" }, # in NOTICES.md ] From 93cafc41b509e618ca46b680b3015659b97a0de3 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 18:43:58 -0500 Subject: [PATCH 10/17] Bump next version of Burn to 0.17.0 (#2698) --- Cargo.lock | 88 ++++++++++---------- Cargo.toml | 2 +- backend-comparison/Cargo.toml | 2 +- burn-book/src/advanced/no-std.md | 2 +- burn-book/src/basic-workflow/README.md | 2 +- burn-book/src/basic-workflow/model.md | 2 +- burn-book/src/import/onnx-model.md | 2 +- crates/burn-autodiff/Cargo.toml | 8 +- crates/burn-candle/Cargo.toml | 8 +- crates/burn-core/Cargo.toml | 32 +++---- crates/burn-cuda/Cargo.toml | 8 +- crates/burn-dataset/Cargo.toml | 2 +- crates/burn-fusion/Cargo.toml | 4 +- crates/burn-hip/Cargo.toml | 8 +- crates/burn-import/Cargo.toml | 6 +- crates/burn-jit/Cargo.toml | 12 +-- crates/burn-ndarray/Cargo.toml | 10 +-- crates/burn-no-std-tests/Cargo.toml | 4 +- crates/burn-remote/Cargo.toml | 6 +- crates/burn-router/Cargo.toml | 12 +-- crates/burn-tch/Cargo.toml | 6 +- crates/burn-tensor/Cargo.toml | 4 +- crates/burn-train/Cargo.toml | 6 +- crates/burn-wgpu/Cargo.toml | 8 +- crates/burn/Cargo.toml | 4 +- examples/image-classification-web/Cargo.toml | 4 +- examples/pytorch-import/Cargo.toml | 2 +- examples/pytorch-import/model/Cargo.toml | 2 +- examples/raspberry-pi-pico/Cargo.lock | 34 ++++---- examples/server/Cargo.toml | 2 +- xtask/Cargo.toml | 2 +- 31 files changed, 147 insertions(+), 147 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 92be93458c..c34fb9cd03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -399,7 +399,7 @@ dependencies = [ [[package]] name = "backend-comparison" -version = "0.16.0" +version = "0.17.0" dependencies = [ "arboard", "burn", @@ -617,7 +617,7 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "burn" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-core", "burn-train", @@ -625,7 +625,7 @@ dependencies = [ [[package]] name = "burn-autodiff" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -637,7 +637,7 @@ dependencies = [ [[package]] name = "burn-candle" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-tch", @@ -649,7 +649,7 @@ dependencies = [ [[package]] name = "burn-common" -version = "0.16.0" +version = "0.17.0" dependencies = [ "cubecl-common", "dashmap", @@ -664,7 +664,7 @@ dependencies = [ [[package]] name = "burn-core" -version = "0.16.0" +version = "0.17.0" dependencies = [ "ahash", "bincode", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "burn-cuda" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -717,7 +717,7 @@ dependencies = [ [[package]] name = "burn-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "csv", @@ -749,7 +749,7 @@ dependencies = [ [[package]] name = "burn-derive" -version = "0.16.0" +version = "0.17.0" dependencies = [ "derive-new 0.7.0", "proc-macro2", @@ -759,7 +759,7 @@ dependencies = [ [[package]] name = "burn-fusion" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -773,7 +773,7 @@ dependencies = [ [[package]] name = "burn-hip" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -788,7 +788,7 @@ dependencies = [ [[package]] name = "burn-import" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-ndarray", @@ -814,7 +814,7 @@ dependencies = [ [[package]] name = "burn-jit" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-common", @@ -840,7 +840,7 @@ dependencies = [ [[package]] name = "burn-ndarray" -version = "0.16.0" +version = "0.17.0" dependencies = [ "atomic_float", "blas-src", @@ -860,7 +860,7 @@ dependencies = [ [[package]] name = "burn-no-std-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-ndarray", @@ -869,7 +869,7 @@ dependencies = [ [[package]] name = "burn-remote" -version = "0.16.0" +version = "0.17.0" dependencies = [ "async-channel", "axum", @@ -890,7 +890,7 @@ dependencies = [ [[package]] name = "burn-router" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-common", @@ -904,7 +904,7 @@ dependencies = [ [[package]] name = "burn-tch" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-tensor", @@ -917,7 +917,7 @@ dependencies = [ [[package]] name = "burn-tensor" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bincode", "burn-common", @@ -938,7 +938,7 @@ dependencies = [ [[package]] name = "burn-tensor-testgen" -version = "0.16.0" +version = "0.17.0" dependencies = [ "proc-macro2", "quote", @@ -946,7 +946,7 @@ dependencies = [ [[package]] name = "burn-train" -version = "0.16.0" +version = "0.17.0" dependencies = [ "async-channel", "burn-core", @@ -966,7 +966,7 @@ dependencies = [ [[package]] name = "burn-wgpu" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -1819,7 +1819,7 @@ dependencies = [ [[package]] name = "custom-csv-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "csv", @@ -1829,7 +1829,7 @@ dependencies = [ [[package]] name = "custom-cubecl-kernel" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-jit", @@ -1842,7 +1842,7 @@ dependencies = [ [[package]] name = "custom-image-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "flate2", @@ -1851,7 +1851,7 @@ dependencies = [ [[package]] name = "custom-renderer" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "bytemuck", @@ -1863,7 +1863,7 @@ dependencies = [ [[package]] name = "custom-training-loop" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "bytemuck", @@ -1875,7 +1875,7 @@ dependencies = [ [[package]] name = "custom-wgpu-kernel" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "bytemuck", @@ -2917,7 +2917,7 @@ dependencies = [ [[package]] name = "guide" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "log", @@ -3417,7 +3417,7 @@ dependencies = [ [[package]] name = "image-classification-web" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-candle", @@ -3953,7 +3953,7 @@ dependencies = [ [[package]] name = "mnist" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "log", @@ -3962,7 +3962,7 @@ dependencies = [ [[package]] name = "mnist-inference-web" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "console_error_panic_hook", @@ -3974,7 +3974,7 @@ dependencies = [ [[package]] name = "model" -version = "0.5.0" +version = "0.6.0" dependencies = [ "burn", "burn-import", @@ -4046,7 +4046,7 @@ dependencies = [ [[package]] name = "named-tensor" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "serde", @@ -4522,7 +4522,7 @@ dependencies = [ [[package]] name = "onnx-inference" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-import", @@ -4531,7 +4531,7 @@ dependencies = [ [[package]] name = "onnx-ir" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bytemuck", "half", @@ -4548,7 +4548,7 @@ dependencies = [ [[package]] name = "onnx-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-import", @@ -5609,7 +5609,7 @@ dependencies = [ [[package]] name = "pytorch-import" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-import", @@ -5618,7 +5618,7 @@ dependencies = [ [[package]] name = "pytorch-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-autodiff", @@ -6584,7 +6584,7 @@ dependencies = [ [[package]] name = "server" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "cfg-if", @@ -6697,7 +6697,7 @@ checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" [[package]] name = "simple-regression" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "log", @@ -7100,7 +7100,7 @@ dependencies = [ [[package]] name = "text-classification" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "derive-new 0.7.0", @@ -7110,7 +7110,7 @@ dependencies = [ [[package]] name = "text-generation" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "derive-new 0.7.0", @@ -8511,7 +8511,7 @@ checksum = "c5b940ebc25896e71dd073bad2dbaa2abfe97b0a391415e22ad1326d9c54e3c4" [[package]] name = "xtask" -version = "1.1.0" +version = "1.2.0" dependencies = [ "log", "rstest", diff --git a/Cargo.toml b/Cargo.toml index e79690fd28..870902a9ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ exclude = [ edition = "2021" license = "MIT OR Apache-2.0" readme = "README.md" -version = "0.16.0" +version = "0.17.0" [workspace.dependencies] atomic_float = "1" diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 9e5054c6af..265dbeaaf0 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -33,7 +33,7 @@ wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"] [dependencies] arboard = { workspace = true } burn = { path = "../crates/burn", default-features = false } -burn-common = { path = "../crates/burn-common", version = "0.16.0" } +burn-common = { path = "../crates/burn-common", version = "0.17.0" } clap = { workspace = true } colored = { workspace = true } diff --git a/burn-book/src/advanced/no-std.md b/burn-book/src/advanced/no-std.md index 7689d25354..5f5621cc51 100644 --- a/burn-book/src/advanced/no-std.md +++ b/burn-book/src/advanced/no-std.md @@ -23,7 +23,7 @@ Some other dependencies have to be added ```toml [dependencies] embedded-alloc = "0.5.1" # Only if there is no default allocator for your chip -burn = { version = "0.16", default-features = false, features = ["ndarray"] } # Backend must be ndarray +burn = { version = "0.17", default-features = false, features = ["ndarray"] } # Backend must be ndarray [build-dependencies] burn-import = { version = "0.14" } # Used to auto generate the rust code to import the model diff --git a/burn-book/src/basic-workflow/README.md b/burn-book/src/basic-workflow/README.md index 5b32591a58..8515d73d2c 100644 --- a/burn-book/src/basic-workflow/README.md +++ b/burn-book/src/basic-workflow/README.md @@ -14,7 +14,7 @@ automatically add the missing imports as you add the code snippets to your code. Be sure to checkout the git branch corresponding to the version of Burn you are using to follow this guide. -The current version of Burn is `0.16` and the corresponding branch to checkout is `main`. +The current version of Burn is `0.17` and the corresponding branch to checkout is `main`. The code for this demo can be executed from Burn's base directory using the command: diff --git a/burn-book/src/basic-workflow/model.md b/burn-book/src/basic-workflow/model.md index adce46b297..ac4b16dbce 100644 --- a/burn-book/src/basic-workflow/model.md +++ b/burn-book/src/basic-workflow/model.md @@ -20,7 +20,7 @@ version = "0.1.0" edition = "2021" [dependencies] -burn = { version = "~0.16", features = ["train", "wgpu", "vision"] } +burn = { version = "~0.17", features = ["train", "wgpu", "vision"] } ``` Our goal will be to create a basic convolutional neural network used for image classification. We diff --git a/burn-book/src/import/onnx-model.md b/burn-book/src/import/onnx-model.md index 05b9d5de81..9b3b7917fd 100644 --- a/burn-book/src/import/onnx-model.md +++ b/burn-book/src/import/onnx-model.md @@ -74,7 +74,7 @@ First, add the `burn-import` crate to your `Cargo.toml`: ```toml [build-dependencies] -burn-import = "~0.16" +burn-import = "~0.17" ``` Then, in your `build.rs` file: diff --git a/crates/burn-autodiff/Cargo.toml b/crates/burn-autodiff/Cargo.toml index 2144d46885..5e221f887f 100644 --- a/crates/burn-autodiff/Cargo.toml +++ b/crates/burn-autodiff/Cargo.toml @@ -18,16 +18,16 @@ std = [] async = [] # Require std [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0" } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false } -burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true } +burn-common = { path = "../burn-common", version = "0.17.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } derive-new = { workspace = true } spin = { workspace = true } log = { workspace = true } [dev-dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-candle/Cargo.toml b/crates/burn-candle/Cargo.toml index 62af31d5fb..65fbf416ca 100644 --- a/crates/burn-candle/Cargo.toml +++ b/crates/burn-candle/Cargo.toml @@ -21,17 +21,17 @@ accelerate = ["candle-core/accelerate"] [dependencies] derive-new = { workspace = true } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } half = { workspace = true } candle-core = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tch = { path = "../burn-tch", version = "0.16.0", default-features = false, features = [ +burn-tch = { path = "../burn-tch", version = "0.17.0", default-features = false, features = [ ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index e63af0fba5..b968e28a68 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -129,21 +129,21 @@ test-wgpu-spirv = [ # ** Please make sure all dependencies support no_std when std is disabled ** -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false } -burn-dataset = { path = "../burn-dataset", version = "0.16.0", optional = true, default-features = false } -burn-derive = { path = "../burn-derive", version = "0.16.0" } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } +burn-dataset = { path = "../burn-dataset", version = "0.17.0", optional = true, default-features = false } +burn-derive = { path = "../burn-derive", version = "0.17.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } # Backends -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", optional = true } -burn-candle = { path = "../burn-candle", version = "0.16.0", optional = true } -burn-cuda = { path = "../burn-cuda", version = "0.16.0", optional = true, default-features = false } -burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-features = false } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false } -burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true } -burn-router = { path = "../burn-router", version = "0.16.0", default-features = false, optional = true } -burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true } -burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } +burn-candle = { path = "../burn-candle", version = "0.17.0", optional = true } +burn-cuda = { path = "../burn-cuda", version = "0.17.0", optional = true, default-features = false } +burn-hip = { path = "../burn-hip", version = "0.17.0", optional = true, default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true, default-features = false } +burn-remote = { path = "../burn-remote", version = "0.17.0", default-features = false, optional = true } +burn-router = { path = "../burn-router", version = "0.17.0", default-features = false, optional = true } +burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true } +burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", optional = true, default-features = false } data-encoding = { workspace = true } uuid = { workspace = true } @@ -173,13 +173,13 @@ thiserror = { workspace = true, optional = true } portable-atomic-util = { workspace = true } [dev-dependencies] -burn-dataset = { path = "../burn-dataset", version = "0.16.0", features = [ +burn-dataset = { path = "../burn-dataset", version = "0.17.0", features = [ "fake", ] } tempfile = { workspace = true } -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0" } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index c366386b0e..1a92e695b2 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -19,9 +19,9 @@ fusion = ["burn-fusion", "burn-jit/fusion"] std = ["burn-jit/std", "cubecl/std"] [dependencies] -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", features = [ "cubecl-cuda", ] } cubecl = { workspace = true, features = ["cuda"] } @@ -34,7 +34,7 @@ log = { workspace = true } [dev-dependencies] -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } paste = { workspace = true } diff --git a/crates/burn-dataset/Cargo.toml b/crates/burn-dataset/Cargo.toml index 0237765973..c7ddbebc41 100644 --- a/crates/burn-dataset/Cargo.toml +++ b/crates/burn-dataset/Cargo.toml @@ -30,7 +30,7 @@ __sqlite-shared = [ dataframe = ["dep:polars"] [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0", optional = true, features = [ +burn-common = { path = "../burn-common", version = "0.17.0", optional = true, features = [ "network", ] } csv = { workspace = true } diff --git a/crates/burn-fusion/Cargo.toml b/crates/burn-fusion/Cargo.toml index eb4296097b..1f2f785940 100644 --- a/crates/burn-fusion/Cargo.toml +++ b/crates/burn-fusion/Cargo.toml @@ -17,8 +17,8 @@ std = ["serde/std"] doc = ["default"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0" } -burn-common = { path = "../burn-common", version = "0.16.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0" } +burn-common = { path = "../burn-common", version = "0.17.0" } hashbrown = { workspace = true } derive-new = {workspace = true } spin = { workspace = true } diff --git a/crates/burn-hip/Cargo.toml b/crates/burn-hip/Cargo.toml index d5f0bb70f5..206f56e8fe 100644 --- a/crates/burn-hip/Cargo.toml +++ b/crates/burn-hip/Cargo.toml @@ -20,9 +20,9 @@ std = ["burn-jit/std", "cubecl/std"] [dependencies] cubecl = { workspace = true, features = ["hip"] } -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = ["cubecl-hip"] } -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", features = ["cubecl-hip"] } +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } half = { workspace = true } bytemuck = { workspace = true } @@ -31,7 +31,7 @@ log = { workspace = true } derive-new = { workspace = true } [dev-dependencies] -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } paste = { workspace = true } diff --git a/crates/burn-import/Cargo.toml b/crates/burn-import/Cargo.toml index 14a8bffa32..ee7c1c559e 100644 --- a/crates/burn-import/Cargo.toml +++ b/crates/burn-import/Cargo.toml @@ -20,9 +20,9 @@ onnx = [] pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"] [dependencies] -burn = { path = "../burn", version = "0.16.0", default-features = false, features = ["std"]} -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false } -onnx-ir = { path = "../onnx-ir", version = "0.16.0" } +burn = { path = "../burn", version = "0.17.0", default-features = false, features = ["std"]} +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false } +onnx-ir = { path = "../onnx-ir", version = "0.17.0" } candle-core = { workspace = true } derive-new = { workspace = true } half = { workspace = true } diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 2bd0ba6f7e..214b21eef3 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -31,9 +31,9 @@ std = ["cubecl/std", "burn-tensor/std"] template = [] [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0" } -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-common = { path = "../burn-common", version = "0.17.0" } +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "cubecl", "repr", ] } @@ -54,12 +54,12 @@ futures-lite = { workspace = true, features = ["std"] } serde = { workspace = true } text_placeholder = { workspace = true, features = ["struct_context"] } -burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } hashbrown = { workspace = true } # When exporting tests -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, optional = true } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, optional = true } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true } paste = { workspace = true, optional = true } serial_test = { workspace = true, optional = true } diff --git a/crates/burn-ndarray/Cargo.toml b/crates/burn-ndarray/Cargo.toml index 89253cd7e8..167cf88c1a 100644 --- a/crates/burn-ndarray/Cargo.toml +++ b/crates/burn-ndarray/Cargo.toml @@ -43,9 +43,9 @@ blas-openblas-system = [ # ** Please make sure all dependencies support no_std when std is disabled ** -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", optional = true } -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = ["repr"] } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = ["repr"] } atomic_float = { workspace = true } blas-src = { workspace = true, default-features = false, optional = true } # no-std compatible @@ -62,10 +62,10 @@ spin = { workspace = true } # usi portable-atomic-util = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-no-std-tests/Cargo.toml b/crates/burn-no-std-tests/Cargo.toml index e15ce56d15..77c7524f6f 100644 --- a/crates/burn-no-std-tests/Cargo.toml +++ b/crates/burn-no-std-tests/Cargo.toml @@ -14,7 +14,7 @@ version.workspace = true # ** Please make sure all dependencies support no_std ** -burn = { path = "../burn", version = "0.16.0", default-features = false } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false } +burn = { path = "../burn", version = "0.17.0", default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false } serde = { workspace = true } diff --git a/crates/burn-remote/Cargo.toml b/crates/burn-remote/Cargo.toml index de772861c9..9ebd0c8568 100644 --- a/crates/burn-remote/Cargo.toml +++ b/crates/burn-remote/Cargo.toml @@ -19,9 +19,9 @@ server = ["axum", "tracing-core", "tracing-subscriber"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = true, features = ["repr"]} -burn-common = { path = "../burn-common", version = "0.16.0", default-features = true} -burn-router = { path = "../burn-router", version = "0.16.0", default-features = true} +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = true, features = ["repr"]} +burn-common = { path = "../burn-common", version = "0.17.0", default-features = true} +burn-router = { path = "../burn-router", version = "0.17.0", default-features = true} # Basic dependencies derive-new = {workspace = true } diff --git a/crates/burn-router/Cargo.toml b/crates/burn-router/Cargo.toml index f6df54e59f..6f21d63640 100644 --- a/crates/burn-router/Cargo.toml +++ b/crates/burn-router/Cargo.toml @@ -17,22 +17,22 @@ std = ["burn-tensor/std", "burn-common/std"] doc = ["default"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = ["repr"]} -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false} +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = ["repr"]} +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false} hashbrown = { workspace = true } spin = { workspace = true } log = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } -burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } +burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", default-features = false } [package.metadata.docs.rs] diff --git a/crates/burn-tch/Cargo.toml b/crates/burn-tch/Cargo.toml index 44702c21ba..69b0240c34 100644 --- a/crates/burn-tch/Cargo.toml +++ b/crates/burn-tch/Cargo.toml @@ -16,7 +16,7 @@ default = [] doc = ["tch/doc-only"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0" } half = { workspace = true, features = ["std"] } libc = { workspace = true } @@ -25,10 +25,10 @@ tch = { workspace = true, features = ["download-libtorch"] } log = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 55e14e174c..318912b2f7 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -30,8 +30,8 @@ std = [ ] [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false } -burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } cubecl = { workspace = true, optional = true, default-features = true } bytemuck = { workspace = true, features = ["extern_crate_alloc"] } diff --git a/crates/burn-train/Cargo.toml b/crates/burn-train/Cargo.toml index 35707f5052..b922a1a59e 100644 --- a/crates/burn-train/Cargo.toml +++ b/crates/burn-train/Cargo.toml @@ -18,7 +18,7 @@ metrics = ["nvml-wrapper", "sysinfo", "systemstat"] tui = ["ratatui"] [dependencies] -burn-core = { path = "../burn-core", version = "0.16.0", features = [ +burn-core = { path = "../burn-core", version = "0.17.0", features = [ "dataset", "std", ], default-features = false } @@ -40,11 +40,11 @@ ratatui = { workspace = true, optional = true, features = ["all-widgets", "cross derive-new = { workspace = true } serde = { workspace = true, features = ["std", "derive"] } async-channel = { workspace = true } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } rstest.workspace = true [dev-dependencies] -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index 055b53ae2f..c2e034ada5 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -24,15 +24,15 @@ template = ["burn-jit/template", "cubecl/template"] [dependencies] cubecl = { workspace = true, features = ["wgpu"] } -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "cubecl-wgpu", ] } [dev-dependencies] -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } half = { workspace = true } diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 7f6af14fbb..0e7ff51e88 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -74,5 +74,5 @@ record-item-custom-serde = ["burn-core/record-item-custom-serde"] # ** Please make sure all dependencies support no_std when std is disabled ** -burn-core = { path = "../burn-core", version = "0.16.0", default-features = false } -burn-train = { path = "../burn-train", version = "0.16.0", optional = true, default-features = false } +burn-core = { path = "../burn-core", version = "0.17.0", default-features = false } +burn-train = { path = "../burn-train", version = "0.17.0", optional = true, default-features = false } diff --git a/examples/image-classification-web/Cargo.toml b/examples/image-classification-web/Cargo.toml index 44591bdad1..9429b24d25 100644 --- a/examples/image-classification-web/Cargo.toml +++ b/examples/image-classification-web/Cargo.toml @@ -14,10 +14,10 @@ default = [] half_precision = [] [dependencies] -burn = { path = "../../crates/burn", version = "0.16.0", default-features = false, features = [ +burn = { path = "../../crates/burn", version = "0.17.0", default-features = false, features = [ "ndarray", "wgpu", ] } -burn-candle = { path = "../../crates/burn-candle", version = "0.16.0", default-features = false } +burn-candle = { path = "../../crates/burn-candle", version = "0.17.0", default-features = false } log = { workspace = true } serde = { workspace = true } diff --git a/examples/pytorch-import/Cargo.toml b/examples/pytorch-import/Cargo.toml index a7b3305689..dd2b56e92d 100644 --- a/examples/pytorch-import/Cargo.toml +++ b/examples/pytorch-import/Cargo.toml @@ -4,7 +4,7 @@ edition = "2021" license = "MIT OR Apache-2.0" name = "pytorch-import" publish = false -version = "0.16.0" +version = "0.17.0" [dependencies] burn = { path = "../../crates/burn", features = [ diff --git a/examples/pytorch-import/model/Cargo.toml b/examples/pytorch-import/model/Cargo.toml index 894ac7e48f..f2678bfcbc 100644 --- a/examples/pytorch-import/model/Cargo.toml +++ b/examples/pytorch-import/model/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "model" -version = "0.5.0" +version = "0.6.0" edition = "2021" [dependencies] diff --git a/examples/raspberry-pi-pico/Cargo.lock b/examples/raspberry-pi-pico/Cargo.lock index 2cbc8fb721..a2f5e866d3 100644 --- a/examples/raspberry-pi-pico/Cargo.lock +++ b/examples/raspberry-pi-pico/Cargo.lock @@ -286,7 +286,7 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "burn" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-core", "burn-train", @@ -294,7 +294,7 @@ dependencies = [ [[package]] name = "burn-autodiff" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -305,7 +305,7 @@ dependencies = [ [[package]] name = "burn-candle" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-tensor", "candle-core", @@ -315,7 +315,7 @@ dependencies = [ [[package]] name = "burn-common" -version = "0.16.0" +version = "0.17.0" dependencies = [ "cubecl-common", "data-encoding", @@ -326,7 +326,7 @@ dependencies = [ [[package]] name = "burn-core" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bincode", "burn-autodiff", @@ -357,7 +357,7 @@ dependencies = [ [[package]] name = "burn-cuda" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -371,7 +371,7 @@ dependencies = [ [[package]] name = "burn-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "csv", "derive-new", @@ -395,7 +395,7 @@ dependencies = [ [[package]] name = "burn-derive" -version = "0.16.0" +version = "0.17.0" dependencies = [ "derive-new", "proc-macro2", @@ -405,7 +405,7 @@ dependencies = [ [[package]] name = "burn-fusion" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -418,7 +418,7 @@ dependencies = [ [[package]] name = "burn-import" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "candle-core", @@ -441,7 +441,7 @@ dependencies = [ [[package]] name = "burn-jit" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-fusion", @@ -461,7 +461,7 @@ dependencies = [ [[package]] name = "burn-ndarray" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-common", @@ -478,7 +478,7 @@ dependencies = [ [[package]] name = "burn-tch" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-tensor", "half", @@ -489,7 +489,7 @@ dependencies = [ [[package]] name = "burn-tensor" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "bytemuck", @@ -507,7 +507,7 @@ dependencies = [ [[package]] name = "burn-train" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-core", "crossterm", @@ -525,7 +525,7 @@ dependencies = [ [[package]] name = "burn-wgpu" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -2959,7 +2959,7 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "onnx-ir" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bytemuck", "half", diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 5d06497e08..bb4824fba9 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -15,4 +15,4 @@ ndarray = ["burn/ndarray"] [dependencies] cfg-if = { workspace = true } -burn = { path = "../../crates/burn", version = "0.16.0", features = ["server"] } +burn = { path = "../../crates/burn", version = "0.17.0", features = ["server"] } diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index ce796eb7b1..63ac5e4c70 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xtask" -version = "1.1.0" +version = "1.2.0" edition = "2021" license = "MIT OR Apache-2.0" From ad81344821bbdca2fde1b5debda4d83d96370876 Mon Sep 17 00:00:00 2001 From: tiruka <33803972+tiruka@users.noreply.github.com> Date: Thu, 16 Jan 2025 01:44:50 +0900 Subject: [PATCH 11/17] Feature add new one hot function meeting multi-dimensions (ranks) (#2613) * add one hot with axis and values function * update one hot multidimentional function * implementing on numeric.rs * update one hot method in numeric * update one hot function to deal with additional dims add one hot test * added tests for one hot * modify function name modify format add tests * modify to respond to difference between Tensor type and values type * fix clippy point out and doc test * do refactoring modify comments * update burn book to publish one hot plus method * modify one_hot_plus to one_hot_fill and args names * modify one_hot function in int impl and float impl modify one_hot tests * modify numeric to clear logic * modify miscs due to validation, linnter and formatter * modify documents for tensor api * modify codes to follow review comments * modify codes to follow reviews * modify tests to follow reviews comments * Improve check message --------- Co-authored-by: Guillaume Lagrange --- burn-book/src/building-blocks/tensor.md | 4 +- crates/burn-tensor/src/tensor/api/check.rs | 34 +++--- crates/burn-tensor/src/tensor/api/float.rs | 34 +----- crates/burn-tensor/src/tensor/api/int.rs | 30 ----- crates/burn-tensor/src/tensor/api/numeric.rs | 97 ++++++++++++++++ crates/burn-tensor/src/tests/ops/one_hot.rs | 112 +++++++++++++------ 6 files changed, 193 insertions(+), 118 deletions(-) diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index fb429ffd0f..8a7c01bbc9 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -228,6 +228,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `tensor.neg()` or `-tensor` | `-tensor` | | `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` | | `tensor.ones_like()` | `torch.ones_like(tensor)` | +| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` | +| `tensor.one_hot_fill(num_classes, on_value, off_value, axis)` | N/A | | `tensor.pad(pads, value)` | `torch.nn.functional.pad(input, pad, value)` | | `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` | | `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` | @@ -258,7 +260,6 @@ Those operations are only available for `Float` tensors. | Burn API | PyTorch Equivalent | | --------------------------------------------- | ---------------------------------- | -| `Tensor::one_hot(index, num_classes, device)` | N/A | | `tensor.cast(dtype)` | `tensor.to(dtype)` | | `tensor.ceil()` | `tensor.ceil()` | | `tensor.cos()` | `tensor.cos()` | @@ -296,7 +297,6 @@ Those operations are only available for `Int` tensors. | `tensor.from_ints(ints)` | N/A | | `tensor.int_random(shape, distribution, device)` | N/A | | `tensor.cartesian_grid(shape, device)` | N/A | -| `tensor.one_hot(num_classes)` | N/A | ### Bool Operations diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index d4ab13faf4..8a6fb2ad78 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1,4 +1,4 @@ -use crate::{backend::Backend, BasicOps, Int, Shape, Tensor}; +use crate::{backend::Backend, BasicOps, Numeric, Shape, Tensor}; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; @@ -447,22 +447,8 @@ impl TensorCheck { check } - pub(crate) fn one_hot_index(index: usize, num_classes: usize) -> Self { - let mut check = Self::Ok; - if index >= num_classes { - check = check.register( - "One Hot", - TensorError::new(format!( - "Can't create a one hot tensor with index ({index}) greater or equal to the number of classes ({num_classes})", - )), - ); - } - - check - } - - pub(crate) fn one_hot_tensor( - index_tensor: Tensor, + pub(crate) fn one_hot_tensor>( + index_tensor: Tensor, num_classes: usize, ) -> Self { let mut check = Self::Ok; @@ -487,6 +473,20 @@ impl TensorCheck { check } + pub(crate) fn one_hot_tensor_rank() -> Self { + let mut check = Self::Ok; + if D + 1 != D2 { + check = check.register( + "One Hot", + TensorError::new( + "The one-hot tensor rank must correspond to the rank of the tensor + 1", + ) + .details(format!("Expected D2={}, got {D2}", D + 1)), + ); + } + check + } + pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { let mut check = Self::Ok; diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index a6f59f6e88..b50d0d0596 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -1,11 +1,8 @@ -use alloc::vec::Vec; -use core::convert::TryInto; - use crate::check::TensorCheck; use crate::quantization::{QuantizationParameters, QuantizationScheme}; use crate::tensor::backend::Backend; use crate::tensor::stats; -use crate::tensor::{Distribution, Shape, TensorData}; +use crate::tensor::{Distribution, TensorData}; use crate::Tensor; use crate::{check, FloatDType}; use crate::{Int, TensorPrimitive}; @@ -174,35 +171,6 @@ where ))) } - /// Create a one hot tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let device = Default::default(); - /// let one_hot = Tensor::::one_hot(2, 10, &device); - /// println!("{}", one_hot.to_data()); - /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - /// } - /// ``` - pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self { - check!(TensorCheck::one_hot_index(index, num_classes)); - - let mut dims = [1; D]; - dims[D - 1] = num_classes; - let shape = Shape::new(dims); - let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); - let tensor = Tensor::zeros(shape, device); - let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); - ranges[D - 1] = index..index + 1; - - tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device)) - } - /// Applies the matrix multiplication operation. /// /// `C = AB` diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 08bdab0fe7..e882a107c7 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -1,5 +1,3 @@ -use crate::check; -use crate::check::TensorCheck; use crate::{ backend::Backend, cartesian_grid, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; @@ -29,34 +27,6 @@ where pub fn arange_step(range: Range, step: usize, device: &B::Device) -> Self { Tensor::new(B::int_arange_step(range, step, device)) } - - /// Create a one hot tensor from an index tensor. - /// - /// # Arguments - /// - /// * `num_classes` - The number of classes to use in encoding. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Int}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let indices: Tensor = Tensor::from_ints([0, 1, 2, 3], &device); - /// let one_hot = indices.one_hot(4); - /// println!("{}", one_hot.to_data()); - /// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] - /// } - /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor { - check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); - let [num_samples] = self.dims(); - let indices = self.unsqueeze_dim(1); - let values = indices.ones_like(); - Tensor::zeros([num_samples, num_classes], &indices.device()).scatter(1, indices, values) - } } impl Tensor diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 59dc44b7e6..b82175c3fe 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2034,6 +2034,103 @@ where // Assign the original tensor data to the appropriate slice of the padded tensor padded_tensor.slice_assign(ranges, self) } + /// Create a one hot tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example(){ + /// let device = Default::default(); + /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); + /// let one_hot: Tensor = indices.one_hot(4); + /// println!("{}", one_hot.to_data()); + /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + /// } + /// ``` + pub fn one_hot(self, num_classes: usize) -> Tensor { + check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); + self.one_hot_fill(num_classes, 1.0, 0.0, -1) + } + + /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors. + /// + /// # Arguments + /// + /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension. + /// * `on_value`: The value to assign for active positions (corresponding to indices). + /// * `off_value`: The value to assign for inactive positions. + /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing. + /// + /// # Returns + /// + /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Float}; + /// fn example>>() { + /// let device = B::Device::default(); + /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); + /// // One-hot encoding + /// let tensor:Tensor = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1); + /// println!("{tensor}"); + /// // [[[5.0, 0.0, 0.0], + /// // [0.0, 0.0, 5.0]], + /// // [[0.0, 5.0, 0.0], + /// // [0.0, 0.0, 5.0]]] + /// } + /// ``` + pub fn one_hot_fill( + self, + num_classes: usize, + on_value: f32, + off_value: f32, + axis: i64, + ) -> Tensor { + check!(TensorCheck::one_hot_tensor_rank::()); + // Initialize shape from the current tensor dimensions and prepare for modification + let mut shape = self.shape().dims::().to_vec(); + let device = self.device(); + let rank = self.dims().len(); + + // Adjust negative axis to a positive index + let axis = if axis < 0 { + axis + rank as i64 + 1 + } else { + axis + }; + + // Ensure axis is within valid range + if axis < 0 || axis > rank as i64 { + panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); + } + // Convert the input tensor to integer indices + let indices: Tensor = + Tensor::from_data(self.to_data().convert::(), &device); + // Insert the new dimension for the one-hot representation + shape.insert(axis as usize, num_classes); + // Adjust indices to valid range and handle invalid indices + let adjusted_indices = indices + .clone() + .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices + .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices + // Unsqueeze the indices tensor along the specified axis + let indices_unsqueezed: Tensor = adjusted_indices.unsqueeze_dim(axis as usize); + + // Initialize the output tensor with the off_value + let output = Tensor::full(shape.clone(), off_value, &device); + + // Prepare scatter tensor for on_value and off_value adjustments + let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) + - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); + + // Scatter on_value at the appropriate indices to create the one-hot representation + output.scatter(axis as usize, indices_unsqueezed, scatter_on_values) + } /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. /// diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index 310399119f..24e8f24b38 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -1,74 +1,114 @@ #[burn_tensor_testgen::testgen(one_hot)] mod tests { use super::*; - use burn_tensor::{Int, TensorData}; + use burn_tensor::{ + as_type, + backend::Backend, + tests::{Float as _, Int as _}, + Float, Int, Numeric, Shape, Tensor, TensorData, + }; #[test] fn float_should_support_one_hot() { - let device = Default::default(); - - let tensor = TestTensor::<1>::one_hot(0, 5, &device); - let expected = TensorData::from([1., 0., 0., 0., 0.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(1, 5, &device); - let expected = TensorData::from([0., 1., 0., 0., 0.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(4, 5, &device); - let expected = TensorData::from([0., 0., 0., 0., 1.]); - tensor.into_data().assert_eq(&expected, false); + let tensor = TestTensor::<1>::from([0.0, 1.0, 4.0]); + let one_hot_tensor: Tensor = tensor.one_hot(5); + let expected = TensorData::from([ + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ]); + one_hot_tensor.into_data().assert_eq(&expected, false); + } - let tensor = TestTensor::<1>::one_hot(1, 2, &device); - let expected = TensorData::from([0., 1.]); - tensor.into_data().assert_eq(&expected, false); + #[test] + fn float_should_support_one_hot_index() { + let tensor = TestTensor::<1>::from([2.0]); + let one_hot_tensor: Tensor = tensor.one_hot::<2>(10); + let expected = TensorData::from([[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]); + one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() { - let device = Default::default(); - let tensor = TestTensor::<1>::one_hot(1, 1, &device); + let tensor = TestTensor::<1>::from([5.0]); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn float_one_hot_should_panic_when_number_of_classes_is_zero() { - let device = Default::default(); - let tensor = TestTensor::<1>::one_hot(0, 0, &device); + let tensor = TestTensor::<1>::from([0.0]); + let result: Tensor = tensor.one_hot(0); } #[test] fn int_should_support_one_hot() { - let device = Default::default(); - - let index_tensor = TestTensorInt::<1>::arange(0..5, &device); - let one_hot_tensor = index_tensor.one_hot(5); - let expected = TestTensorInt::eye(5, &device).into_data(); + let tensor = TestTensorInt::<1>::from([0, 1, 4]); + let one_hot_tensor: Tensor = tensor.one_hot(5); + let expected = TensorData::from([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]); one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..6, &device); - let one_hot_tensor = index_tensor.one_hot(5); + let tensor = TestTensorInt::<1>::from([5]); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn int_one_hot_should_panic_when_number_of_classes_is_zero() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let one_hot_tensor = index_tensor.one_hot(0); + let tensor = TestTensorInt::<1>::from([2]); + let result: Tensor = tensor.one_hot(0); + } + + #[test] + fn one_hot_fill_with_positive_axis_and_indices() { + let tensor = TestTensorInt::<2>::from([[1, 9], [2, 4]]); + let expected = TensorData::from(as_type!(IntType: [ + [[1, 1], [3, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 3]], + [[1, 1], [1, 1], [3, 1], [1, 1], [1, 3], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); + + one_hot_tensor.into_data().assert_eq(&expected, true); + } + + #[test] + fn one_hot_fill_with_negative_axis_and_indices() { + let tensor = TestTensor::<2>::from([[0, 2], [1, -1]]); + let expected = TensorData::from(as_type!(FloatType: [ + [[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], + [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(3, 5.0, 0.0, -1); + + one_hot_tensor.into_data().assert_eq(&expected, true); } #[test] + fn one_hot_fill_with_negative_indices() { + let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]); + let expected = TensorData::from(as_type!(FloatType: [ + [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); + + one_hot_tensor.into_data().assert_eq(&expected, true); + } + #[should_panic] - fn int_one_hot_should_panic_when_number_of_classes_is_1() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let one_hot_tensor = index_tensor.one_hot(1); + #[test] + fn one_hot_fill_should_panic_when_axis_out_range_of_rank() { + let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(2, 5.0, 0.0, 3); } } From f630b3bc7d2d7fac0b972ea001c33daa7c32dd22 Mon Sep 17 00:00:00 2001 From: jiawen wang Date: Thu, 16 Jan 2025 00:45:20 +0800 Subject: [PATCH 12/17] Wasserstein Generative Adversarial Network (#2660) * Add files via upload Wasserstein Generative Adversarial Network * Delete examples/wgan/readme * Create README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update cli.rs * Update cli.rs * Update model.rs * Update training.rs * Update main.rs * Update model.rs * Update training.rs * Update training.rs * Update main.rs * Update training.rs * Update model.rs * Update training.rs * Update cli.rs * Update cli.rs * Update generating.rs * Update lib.rs * Update model.rs * Update training.rs * Update main.rs * Update generating.rs * Update model.rs * Update training.rs * Update generating.rs * Update model.rs * Update training.rs * Update training.rs * Update dataset.rs * Update generating.rs * Update model.rs * Update training.rs * Update training.rs * Update training.rs * Restructure as workspace example * Add support for single range slice (fixes clippy) * Update example usage + list --------- Co-authored-by: Guillaume Lagrange --- Cargo.lock | 8 + README.md | 2 + burn-book/src/examples.md | 1 + crates/burn-tensor/src/tensor/api/base.rs | 8 + crates/burn-tensor/src/tests/ops/slice.rs | 11 ++ examples/wgan/Cargo.toml | 18 ++ examples/wgan/README.md | 40 ++++ examples/wgan/examples/wgan-generate.rs | 95 ++++++++++ examples/wgan/examples/wgan-mnist.rs | 107 +++++++++++ examples/wgan/src/dataset.rs | 49 +++++ examples/wgan/src/infer.rs | 41 +++++ examples/wgan/src/lib.rs | 4 + examples/wgan/src/model.rs | 157 ++++++++++++++++ examples/wgan/src/training.rs | 211 ++++++++++++++++++++++ 14 files changed, 752 insertions(+) create mode 100644 examples/wgan/Cargo.toml create mode 100644 examples/wgan/README.md create mode 100644 examples/wgan/examples/wgan-generate.rs create mode 100644 examples/wgan/examples/wgan-mnist.rs create mode 100644 examples/wgan/src/dataset.rs create mode 100644 examples/wgan/src/infer.rs create mode 100644 examples/wgan/src/lib.rs create mode 100644 examples/wgan/src/model.rs create mode 100644 examples/wgan/src/training.rs diff --git a/Cargo.lock b/Cargo.lock index c34fb9cd03..1af3585919 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8002,6 +8002,14 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" +[[package]] +name = "wgan" +version = "0.1.0" +dependencies = [ + "burn", + "image", +] + [[package]] name = "wgpu" version = "23.0.1" diff --git a/README.md b/README.md index a0780dcc16..d0ccbcf411 100644 --- a/README.md +++ b/README.md @@ -567,6 +567,8 @@ Additional examples: sample. - [Text Generation](./examples/text-generation) : Trains a text generation transformer model on the DbPedia dataset. +- [Wasserstein GAN MNIST](./examples/wgan) : Trains a WGAN model to generate new handwritten digits + based on MNIST. For more practical insights, you can clone the repository and run any of them directly on your computer! diff --git a/burn-book/src/examples.md b/burn-book/src/examples.md index c9703a4389..2b083b6fbe 100644 --- a/burn-book/src/examples.md +++ b/burn-book/src/examples.md @@ -85,6 +85,7 @@ The following additional examples are currently available if you want to check t | [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/pytorch-import) | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. | | [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. | | [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation) | Trains a text generation transformer model on the DbPedia dataset. | +| [Wasserstein GAN MNIST](https://github.com/tracel-ai/burn/tree/main/examples/wgan) | Trains a WGAN model to generate new handwritten digits based on MNIST. | For more information on each example, see their respective `README.md` file. Be sure to check out the [examples](https://github.com/tracel-ai/burn/tree/main/examples) directory for an up-to-date diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index fabf321d96..4bbc522f49 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -805,6 +805,7 @@ where /// # Arguments /// /// * `ranges` - A type implementing the `RangesArg` trait, which can be: + /// - A single `core::ops::Range` (slice the first dimension) /// - An array of `core::ops::Range` /// - An array of `Option<(i64, i64)>` /// - An array of `(i64, i64)` tuples @@ -2988,6 +2989,13 @@ impl RangesArg for [(i64, i64); D2] { } } +impl RangesArg<1> for core::ops::Range { + fn into_ranges(self, shape: Shape) -> [core::ops::Range; 1] { + let (start, end) = Self::clamp_range(self.start, self.end, shape.dims[0]); + [(start..end)] + } +} + /// Trait used for reshape arguments. pub trait ReshapeArgs { /// Converts to a shape. diff --git a/crates/burn-tensor/src/tests/ops/slice.rs b/crates/burn-tensor/src/tests/ops/slice.rs index 61725a506a..1be5b76315 100644 --- a/crates/burn-tensor/src/tests/ops/slice.rs +++ b/crates/burn-tensor/src/tests/ops/slice.rs @@ -47,6 +47,17 @@ mod tests { output.into_data().assert_eq(&expected, false); } + #[test] + fn should_support_slice_range_first_dim() { + let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.slice(0..1); + let expected = TensorData::from([[0.0, 1.0, 2.0]]); + + output.into_data().assert_eq(&expected, false); + } + #[test] fn should_support_partial_sliceing_3d() { let tensor = TestTensor::<3>::from_floats( diff --git a/examples/wgan/Cargo.toml b/examples/wgan/Cargo.toml new file mode 100644 index 0000000000..48d5680f51 --- /dev/null +++ b/examples/wgan/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "wgan" +version = "0.1.0" +edition = "2021" + +[features] +ndarray = ["burn/ndarray"] +ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] +ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] +ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] +tch-cpu = ["burn/tch"] +tch-gpu = ["burn/tch"] +wgpu = ["burn/wgpu"] +cuda-jit = ["burn/cuda-jit"] + +[dependencies] +burn = { path = "../../crates/burn", features=["train", "vision"] } +image = { workspace = true } diff --git a/examples/wgan/README.md b/examples/wgan/README.md new file mode 100644 index 0000000000..d7252ba520 --- /dev/null +++ b/examples/wgan/README.md @@ -0,0 +1,40 @@ +# Wasserstein Generative Adversarial Network + +A burn implementation of examplar WGAN model to generate MNIST digits inspired by +[the PyTorch implementation](https://bytepawn.com/training-a-pytorch-wasserstain-mnist-gan-on-google-colab.html). +Please note that better performance maybe gained by adopting a convolution layer in +[some other models](https://github.com/Lornatang/WassersteinGAN-PyTorch). + +## Usage + + +## Training + +```sh +# Cuda backend +cargo run --example wgan-mnist --release --features cuda-jit + +# Wgpu backend +cargo run --example wgan-mnist --release --features wgpu + +# Tch GPU backend +export TORCH_CUDA_VERSION=cu121 # Set the cuda version +cargo run --example wgan-mnist --release --features tch-gpu + +# Tch CPU backend +cargo run --example wgan-mnist --release --features tch-cpu + +# NdArray backend (CPU) +cargo run --example wgan-mnist --release --features ndarray # f32 - single thread +cargo run --example wgan-mnist --release --features ndarray-blas-openblas # f32 - blas with openblas +cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32 - blas with netlib +``` + + +### Generating + +To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend. + +```sh +cargo run --example wgan-generate --release --features cuda-jit +``` diff --git a/examples/wgan/examples/wgan-generate.rs b/examples/wgan/examples/wgan-generate.rs new file mode 100644 index 0000000000..fa66623ca3 --- /dev/null +++ b/examples/wgan/examples/wgan-generate.rs @@ -0,0 +1,95 @@ +use burn::tensor::backend::Backend; + +pub fn launch(device: B::Device) { + wgan::infer::generate::("/tmp/wgan-mnist", device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::{ + ndarray::{NdArray, NdArrayDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::>(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::{wgpu::Wgpu, Autodiff}; + + pub fn run() { + launch::>(Default::default()); + } +} + +#[cfg(feature = "cuda-jit")] +mod cuda_jit { + use crate::launch; + use burn::backend::{Autodiff, CudaJit}; + + pub fn run() { + launch::>(Default::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda-jit")] + cuda_jit::run(); +} diff --git a/examples/wgan/examples/wgan-mnist.rs b/examples/wgan/examples/wgan-mnist.rs new file mode 100644 index 0000000000..d964b07844 --- /dev/null +++ b/examples/wgan/examples/wgan-mnist.rs @@ -0,0 +1,107 @@ +use burn::{optim::RmsPropConfig, tensor::backend::AutodiffBackend}; + +use wgan::{model::ModelConfig, training::TrainingConfig}; + +pub fn launch(device: B::Device) { + let config = TrainingConfig::new( + ModelConfig::new(), + RmsPropConfig::new() + .with_alpha(0.99) + .with_momentum(0.0) + .with_epsilon(0.00000008) + .with_weight_decay(None) + .with_centered(false), + ); + + wgan::training::train::("/tmp/wgan-mnist", config, device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::{ + ndarray::{NdArray, NdArrayDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::>(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::{wgpu::Wgpu, Autodiff}; + + pub fn run() { + launch::>(Default::default()); + } +} + +#[cfg(feature = "cuda-jit")] +mod cuda_jit { + use crate::launch; + use burn::backend::{cuda_jit::CudaDevice, Autodiff, CudaJit}; + + pub fn run() { + launch::>(CudaDevice::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda-jit")] + cuda_jit::run(); +} diff --git a/examples/wgan/src/dataset.rs b/examples/wgan/src/dataset.rs new file mode 100644 index 0000000000..46848d4ffb --- /dev/null +++ b/examples/wgan/src/dataset.rs @@ -0,0 +1,49 @@ +use burn::{ + data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, + prelude::*, +}; + +#[derive(Clone, Debug)] +pub struct MnistBatcher { + device: B::Device, +} + +#[derive(Clone, Debug)] +pub struct MnistBatch { + pub images: Tensor, + pub targets: Tensor, +} + +impl MnistBatcher { + pub fn new(device: B::Device) -> Self { + Self { device } + } +} + +impl Batcher> for MnistBatcher { + fn batch(&self, items: Vec) -> MnistBatch { + let images = items + .iter() + .map(|item| TensorData::from(item.image)) + .map(|data| Tensor::::from_data(data.convert::(), &self.device)) + .map(|tensor| tensor.reshape([1, 28, 28])) + // Set std=0.5 and mean=0.5 to keep consistent with pytorch WGAN example + .map(|tensor| ((tensor / 255) - 0.5) / 0.5) + .collect(); + + let targets = items + .iter() + .map(|item| { + Tensor::::from_data( + TensorData::from([(item.label as i64).elem::()]), + &self.device, + ) + }) + .collect(); + + let images = Tensor::stack(images, 0); + let targets = Tensor::cat(targets, 0); + + MnistBatch { images, targets } + } +} diff --git a/examples/wgan/src/infer.rs b/examples/wgan/src/infer.rs new file mode 100644 index 0000000000..25ca984feb --- /dev/null +++ b/examples/wgan/src/infer.rs @@ -0,0 +1,41 @@ +use crate::training::{save_image, TrainingConfig}; +use burn::{ + prelude::*, + record::{CompactRecorder, Recorder}, + tensor::Distribution, +}; + +pub fn generate(artifact_dir: &str, device: B::Device) { + // Loading model + let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) + .expect("Config should exist for the model; run train first"); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/generator").into(), &device) + .expect("Trained model should exist; run train first"); + let (mut generator, _) = config.model.init::(&device); + generator = generator.load_record(record); + + // Get a batch of noise + let noise = Tensor::::random( + [config.batch_size, config.model.latent_dim], + Distribution::Normal(0.0, 1.0), + &device, + ); + let fake_images = generator.forward(noise); // [batch_size, channesl*height*width] + let fake_images = fake_images.reshape([ + config.batch_size, + config.model.channels, + config.model.image_size, + config.model.image_size, + ]); + // [B, C, H, W] to [B, H, C, W] to [B, H, W, C] + let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25); + // Normalize the images. The Rgb32 images should be in range 0.0-1.0 + let fake_images = (fake_images.clone() - fake_images.clone().min().reshape([1, 1, 1, 1])) + / (fake_images.clone().max().reshape([1, 1, 1, 1]) + - fake_images.clone().min().reshape([1, 1, 1, 1])); + // Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer, refer to pytorch save_image source + let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0); + // Save images in artifact directory + save_image::(fake_images, 5, format!("{artifact_dir}/fake_image.png")).unwrap(); +} diff --git a/examples/wgan/src/lib.rs b/examples/wgan/src/lib.rs new file mode 100644 index 0000000000..021f62278a --- /dev/null +++ b/examples/wgan/src/lib.rs @@ -0,0 +1,4 @@ +pub mod dataset; +pub mod infer; +pub mod model; +pub mod training; diff --git a/examples/wgan/src/model.rs b/examples/wgan/src/model.rs new file mode 100644 index 0000000000..ddb84ff6d3 --- /dev/null +++ b/examples/wgan/src/model.rs @@ -0,0 +1,157 @@ +use burn::{ + module::{Module, ModuleMapper, ParamId}, + nn::BatchNorm, + prelude::*, + tensor::backend::AutodiffBackend, +}; + +/// Layer block of generator model +#[derive(Module, Debug)] +pub struct LayerBlock { + fc: nn::Linear, + bn: nn::BatchNorm, + leakyrelu: nn::LeakyRelu, +} + +impl LayerBlock { + pub fn new(input: usize, output: usize, device: &B::Device) -> Self { + let fc = nn::LinearConfig::new(input, output) + .with_bias(true) + .init(device); + let bn: BatchNorm = nn::BatchNormConfig::new(output) + .with_epsilon(0.8) + .init(device); + let leakyrelu = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); + + Self { fc, bn, leakyrelu } + } + + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.fc.forward(input); // output: [Batch, x] + let output = self.bn.forward(output); // output: [Batch, x] + + self.leakyrelu.forward(output) // output: [Batch, x] + } +} + +/// Generator model +#[derive(Module, Debug)] +pub struct Generator { + layer1: LayerBlock, + layer2: LayerBlock, + layer3: LayerBlock, + layer4: LayerBlock, + fc: nn::Linear, + tanh: nn::Tanh, +} + +impl Generator { + /// Applies the forward pass on the input tensor by specified order + pub fn forward(&self, noise: Tensor) -> Tensor { + let output = self.layer1.forward(noise); + let output = self.layer2.forward(output); + let output = self.layer3.forward(output); + let output = self.layer4.forward(output); + let output = self.fc.forward(output); + + self.tanh.forward(output) // [batch_size, channels*height*width] + } +} + +/// Discriminator model +#[derive(Module, Debug)] +pub struct Discriminator { + fc1: nn::Linear, + leakyrelu1: nn::LeakyRelu, + fc2: nn::Linear, + leakyrelu2: nn::LeakyRelu, + fc3: nn::Linear, +} + +impl Discriminator { + /// Applies the forward pass on the input tensor by specified order. + /// The input image shape is [batch, channels, height, width] + pub fn forward(&self, images: Tensor) -> Tensor { + // Full connection for each batch + let output = images.flatten(1, 3); // output: [batch, channels*height*width] + let output = self.fc1.forward(output); // output: [batch, 512] + let output = self.leakyrelu1.forward(output); // output: [batch, 512] + let output = self.fc2.forward(output); // output: [batch, 256] + let output = self.leakyrelu2.forward(output); // output: [batch, 256] + + self.fc3.forward(output) // output: [batch, 1] + } +} + +// Use model config to construct a generative and adverserial model +#[derive(Config, Debug)] +pub struct ModelConfig { + /// Dimensionality of the latent space + #[config(default = 100)] + pub latent_dim: usize, + #[config(default = 28)] + pub image_size: usize, + #[config(default = 1)] + pub channels: usize, +} + +impl ModelConfig { + /// "init" is used to create other objects, while "new" is usally used to create itself. + pub fn init(&self, device: &B::Device) -> (Generator, Discriminator) { + // Construct the initialized generator + let layer1 = LayerBlock::new(self.latent_dim, 128, device); + let layer2 = LayerBlock::new(128, 256, device); + let layer3 = LayerBlock::new(256, 512, device); + let layer4 = LayerBlock::new(512, 1024, device); + let fc = nn::LinearConfig::new(1024, self.channels * self.image_size * self.image_size) + .with_bias(true) + .init(device); + + let generator = Generator { + layer1, + layer2, + layer3, + layer4, + fc, + tanh: nn::Tanh::new(), + }; + + // Construct the initialized discriminator + let fc1 = nn::LinearConfig::new(self.channels * self.image_size * self.image_size, 512) + .init(device); + let leakyrelu1 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); + let fc2 = nn::LinearConfig::new(512, 256).init(device); + let leakyrelu2 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); + let fc3 = nn::LinearConfig::new(256, 1).init(device); + + let discriminator = Discriminator { + fc1, + leakyrelu1, + fc2, + leakyrelu2, + fc3, + }; + + (generator, discriminator) + } +} + +/// Clip module mapper to clip all module parameters between a range of values +#[derive(Module, Clone, Debug)] +pub struct Clip { + pub min: f32, + pub max: f32, +} + +impl ModuleMapper for Clip { + fn map_float(&mut self, _id: ParamId, tensor: Tensor) -> Tensor { + let is_require_grad = tensor.is_require_grad(); + + let mut tensor = Tensor::from_inner(tensor.inner().clamp(self.min, self.max)); + + if is_require_grad { + tensor = tensor.require_grad(); + } + tensor + } +} diff --git a/examples/wgan/src/training.rs b/examples/wgan/src/training.rs new file mode 100644 index 0000000000..db1f594b46 --- /dev/null +++ b/examples/wgan/src/training.rs @@ -0,0 +1,211 @@ +use crate::dataset::MnistBatcher; +use crate::model::{Clip, ModelConfig}; +use burn::optim::{GradientsParams, Optimizer, RmsPropConfig}; +use burn::{ + data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, + prelude::*, + record::CompactRecorder, + tensor::{backend::AutodiffBackend, Distribution}, +}; +use image::{buffer::ConvertBuffer, error::ImageResult, Rgb32FImage, RgbImage}; +use std::path::Path; + +#[derive(Config)] +pub struct TrainingConfig { + pub model: ModelConfig, + pub optimizer: RmsPropConfig, + + #[config(default = 200)] + pub num_epochs: usize, + #[config(default = 512)] + pub batch_size: usize, + #[config(default = 8)] + pub num_workers: usize, + #[config(default = 5)] + pub seed: u64, + #[config(default = 5e-5)] + pub lr: f64, + + /// Number of training steps for discriminator before generator is trained per iteration + #[config(default = 5)] + pub num_critic: usize, + /// Lower and upper clip value for disc. weights + #[config(default = 0.01)] + pub clip_value: f32, + /// Save a sample of images every `sample_interval` epochs + #[config(default = 10)] + pub sample_interval: usize, +} + +// Create the directory to save the model and model config +fn create_artifact_dir(artifact_dir: &str) { + // Remove existing artifacts + std::fs::remove_dir_all(artifact_dir).ok(); + std::fs::create_dir_all(artifact_dir).ok(); +} + +/// Save the generated images +// The images format is [B, H, W, C] +pub fn save_image>( + images: Tensor, + nrow: u32, + path: Q, +) -> ImageResult<()> { + let ncol = (images.dims()[0] as f32 / nrow as f32).ceil() as u32; + + let width = images.dims()[2] as u32; + let height = images.dims()[1] as u32; + + // Supports both 1 and 3 channels image + let channels = match images.dims()[3] { + 1 => 3, + 3 => 1, + _ => panic!("Wrong channels number"), + }; + + let mut imgbuf = RgbImage::new(nrow * width, ncol * height); + // Write images into a nrow*ncol grid layout + for row in 0..nrow { + for col in 0..ncol { + let image: Tensor = images + .clone() + .slice((row * nrow + col) as usize..(row * nrow + col + 1) as usize) + .squeeze(0); + // The Rgb32 should be in range 0.0-1.0 + let image = image.into_data().iter::().collect::>(); + // Supports both 1 and 3 channels image + let image = image + .into_iter() + .flat_map(|n| std::iter::repeat(n).take(channels)) + .collect(); + + let image = Rgb32FImage::from_vec(width, height, image).unwrap(); + let image: RgbImage = image.convert(); + for (x, y, pixel) in image.enumerate_pixels() { + imgbuf.put_pixel(row * width + x, col * height + y, *pixel); + } + } + } + imgbuf.save(path) +} + +pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { + create_artifact_dir(artifact_dir); + + // Create the Clip module mapper + let mut clip = Clip { + min: -config.clip_value, + max: config.clip_value, + }; + + // Save training config + config + .save(format!("{artifact_dir}/config.json")) + .expect("Config should be saved successfully"); + B::seed(config.seed); + + // Create the model and optimizer + let (mut generator, mut discriminator) = config.model.init::(&device); + let mut optimizer_g = config.optimizer.init(); + let mut optimizer_d = config.optimizer.init(); + + // Create the dataset batcher + let batcher_train = MnistBatcher::::new(device.clone()); + + // Create the dataloaders + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MnistDataset::train()); + + // Iterate over our training for X epochs + for epoch in 0..config.num_epochs { + // Implement our training loop + for (iteration, batch) in dataloader_train.iter().enumerate() { + // Generate a batch of fake images from noise (standarded normal distribution) + let noise = Tensor::::random( + [config.batch_size, config.model.latent_dim], + Distribution::Normal(0.0, 1.0), + &device, + ); + // datach: do not update gerenator, only discriminator is updated + let fake_images = generator.forward(noise.clone()).detach(); // [batch_size, channels*height*width] + let fake_images = fake_images.reshape([ + config.batch_size, + config.model.channels, + config.model.image_size, + config.model.image_size, + ]); + // Adversarial loss + let loss_d = -discriminator.forward(batch.images).mean() + + discriminator.forward(fake_images.clone()).mean(); + + // Gradients for the current backward pass + let grads = loss_d.backward(); + // Gradients linked to each parameter of the discriminator + let grads = GradientsParams::from_grads(grads, &discriminator); + // Update the discriminator using the optimizer + discriminator = optimizer_d.step(config.lr, discriminator, grads); + // Clip parameters (weights) of discriminator + discriminator = discriminator.map(&mut clip); + + // Train the generator every num_critic iterations + if iteration % config.num_critic == 0 { + // Generate a batch of images again without detaching + let critic_fake_images = generator.forward(noise.clone()); + let critic_fake_images = critic_fake_images.reshape([ + config.batch_size, + config.model.channels, + config.model.image_size, + config.model.image_size, + ]); + // Adversarial loss. Minimize it to make the fake images as truth + let loss_g = -discriminator.forward(critic_fake_images).mean(); + + let grads = loss_g.backward(); + let grads = GradientsParams::from_grads(grads, &generator); + generator = optimizer_g.step(config.lr, generator, grads); + + // Print the progression + let batch_num = (dataloader_train.num_items() as f32 / config.batch_size as f32) + .ceil() as usize; + println!( + "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}]", + epoch + 1, + config.num_epochs, + iteration, + batch_num, + loss_d.into_scalar(), + loss_g.into_scalar() + ); + } + // If at save interval => save the first 25 generated images + if epoch % config.sample_interval == 0 && iteration == 0 { + // [B, C, H, W] to [B, H, C, W] to [B, H, W, C] + let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25); + // Normalize the images. The Rgb32 images should be in range 0.0-1.0 + let fake_images = (fake_images.clone() + - fake_images.clone().min().reshape([1, 1, 1, 1])) + / (fake_images.clone().max().reshape([1, 1, 1, 1]) + - fake_images.clone().min().reshape([1, 1, 1, 1])); + // Add 0.5/255.0 to the images, refer to pytorch save_image source + let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0); + // Save images in artifact directory + let path = format!("{artifact_dir}/image-{}.png", epoch); + save_image::(fake_images, 5, path).unwrap(); + } + } + } + + // Save the trained models + generator + .save_file(format!("{artifact_dir}/generator"), &CompactRecorder::new()) + .expect("Generator should be saved successfully"); + discriminator + .save_file( + format!("{artifact_dir}/discriminator"), + &CompactRecorder::new(), + ) + .expect("Discriminator should be saved successfully"); +} From 93f8bad67198a0d9cf576c412b39c1864ed040a0 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 16 Jan 2025 09:30:32 -0500 Subject: [PATCH 13/17] Remove deprecated Data and DataSerialize (#2703) --- README.md | 35 +- crates/burn-core/Cargo.toml | 3 - crates/burn-core/src/record/primitive.rs | 21 +- crates/burn-core/src/record/tensor.rs | 50 +-- crates/burn-tensor/src/lib.rs | 4 +- crates/burn-tensor/src/tensor/data.rs | 400 +---------------------- crates/burn/Cargo.toml | 1 - 7 files changed, 37 insertions(+), 477 deletions(-) diff --git a/README.md b/README.md index d0ccbcf411..951b2a9f24 100644 --- a/README.md +++ b/README.md @@ -621,19 +621,20 @@ leads to more reliable, bug-free solutions built faster (after some practice
> **Deprecation Note**
Since `0.14.0`, the internal structure for tensor data has changed. The -> previous `Data` struct is being deprecated in favor of the new `TensorData` struct, which allows -> for more flexibility by storing the underlying data as bytes and keeping the data type as a field. -> If you are using `Data` in your code, make sure to switch to `TensorData`. +> previous `Data` struct was deprecated and officially removed since `0.17.0` in favor of the new +> `TensorData` struct, which allows for more flexibility by storing the underlying data as bytes and +> keeping the data type as a field. If you are using `Data` in your code, make sure to switch to +> `TensorData`.
@@ -642,8 +643,9 @@ Loading Model Records From Previous Versions ⚠️
-In the event that you are trying to load a model record saved in a previous version, make sure to -enable the `record-backward-compat` feature flag. +In the event that you are trying to load a model record saved in a version older than `0.14.0`, make +sure to use a compatible version (`0.14`, `0.15` or `0.16`) with the `record-backward-compat` +feature flag. ``` features = [..., "record-backward-compat"] @@ -652,13 +654,14 @@ features = [..., "record-backward-compat"] Otherwise, the record won't be deserialized correctly and you will get an error message. This error will also point you to the backward compatible feature flag. -The backward compatibility is maintained for deserialization when loading records. Therefore, as -soon as you have saved the record again it will be saved according to the new structure and you -won't need the backward compatible feature flag anymore. +The backward compatibility was maintained for deserialization when loading records. Therefore, as +soon as you have saved the record again it will be saved according to the new structure and you can +upgrade back to the current version Please note that binary formats are not backward compatible. Thus, you will need to load your record in a previous version and save it in any of the other self-describing record format (e.g., using the -`NamedMpkFileRecorder`) before using the new version with the `record-backward-compat` feature flag. +`NamedMpkFileRecorder`) before using a compatible version (as described) with the +`record-backward-compat` feature flag.
diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index b968e28a68..e895cc4572 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -113,9 +113,6 @@ record-item-custom-serde = ["thiserror", "regex"] # Serialization formats experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] -# Backwards compatibility with previous serialized data format. -record-backward-compat = [] - test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray. test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray. test-tch = ["tch"] # To use tch during testing, default uses ndarray. diff --git a/crates/burn-core/src/record/primitive.rs b/crates/burn-core/src/record/primitive.rs index 9dd921e824..2f9fa3e83c 100644 --- a/crates/burn-core/src/record/primitive.rs +++ b/crates/burn-core/src/record/primitive.rs @@ -5,9 +5,7 @@ use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde}; use super::{PrecisionSettings, Record}; use crate::module::{Param, ParamId}; -#[allow(deprecated)] -use burn_tensor::DataSerialize; -use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor}; +use burn_tensor::{backend::Backend, Bool, Int, Tensor}; use hashbrown::HashMap; use serde::{ @@ -143,23 +141,6 @@ where } } -#[allow(deprecated)] -impl Record for DataSerialize -where - E: Element, - B: Backend, -{ - type Item = DataSerialize; - - fn into_item(self) -> Self::Item { - self.convert() - } - - fn from_item(item: Self::Item, _device: &B::Device) -> Self { - item.convert() - } -} - /// (De)serialize parameters into a clean format. #[derive(new, Debug, Clone, Serialize, Deserialize)] pub struct ParamSerde { diff --git a/crates/burn-core/src/record/tensor.rs b/crates/burn-core/src/record/tensor.rs index ab6f448b7e..a07453bcba 100644 --- a/crates/burn-core/src/record/tensor.rs +++ b/crates/burn-core/src/record/tensor.rs @@ -4,20 +4,7 @@ use super::{PrecisionSettings, Record}; use burn_tensor::{backend::Backend, Bool, DType, Element, Int, Tensor, TensorData}; use serde::{Deserialize, Serialize}; -#[cfg(not(feature = "record-backward-compat"))] use alloc::format; -#[cfg(feature = "record-backward-compat")] -use burn_tensor::DataSerialize; - -/// Versioned serde data deserialization to maintain backward compatibility between formats. -#[cfg(feature = "record-backward-compat")] -#[allow(deprecated)] -#[derive(Serialize, Deserialize)] -#[serde(untagged)] -enum TensorDataSerde { - V1(DataSerialize), - V2(TensorData), -} /// Deserialize the value into [`TensorData`]. fn deserialize_data<'de, E, De>(deserializer: De) -> Result @@ -25,31 +12,18 @@ where E: Element + Deserialize<'de>, De: serde::Deserializer<'de>, { - #[cfg(feature = "record-backward-compat")] - { - let data = match TensorDataSerde::::deserialize(deserializer)? { - TensorDataSerde::V1(data) => data.into_tensor_data(), - // NOTE: loading f32 weights with f16 precision will deserialize the f32 weights (bytes) first and then convert to f16 - TensorDataSerde::V2(data) => data.convert::(), - }; - Ok(data) - } - - #[cfg(not(feature = "record-backward-compat"))] - { - let data = TensorData::deserialize(deserializer).map_err(|e| { - serde::de::Error::custom(format!( - "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag. Once you have saved the record in the new format, you can disable the feature flag.\n", - e - )) - })?; - let data = if let DType::QFloat(_) = data.dtype { - data // do not convert quantized tensors - } else { - data.convert::() - }; - Ok(data) - } + let data = TensorData::deserialize(deserializer).map_err(|e| { + serde::de::Error::custom(format!( + "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n", + e + )) + })?; + let data = if let DType::QFloat(_) = data.dtype { + data // do not convert quantized tensors + } else { + data.convert::() + }; + Ok(data) } /// This struct implements serde to lazily serialize and deserialize a float tensor diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs index d3cb280e90..0376da57a2 100644 --- a/crates/burn-tensor/src/lib.rs +++ b/crates/burn-tensor/src/lib.rs @@ -1,8 +1,6 @@ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] -// Allow deprecated `Data` and `DataSerialize` -#![allow(deprecated)] //! This library provides multiple tensor implementations hidden behind an easy to use API //! that supports reverse mode automatic differentiation. @@ -59,6 +57,8 @@ mod cube_wgpu { use crate::backend::{DeviceId, DeviceOps}; use cubecl::wgpu::WgpuDevice; + // Allow deprecated `WgpuDevice::BestAvailable` + #[allow(deprecated)] impl DeviceOps for WgpuDevice { fn id(&self) -> DeviceId { match self { diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 5fa6f765fc..bd144e397f 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -1,7 +1,4 @@ -use core::{ - any::{Any, TypeId}, - f32, -}; +use core::f32; use alloc::boxed::Box; use alloc::format; @@ -14,7 +11,7 @@ use crate::{ quantization::{ Quantization, QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes, }, - tensor::{bytes::Bytes, Shape}, + tensor::bytes::Bytes, DType, Distribution, Element, ElementConversion, }; @@ -777,396 +774,6 @@ impl core::fmt::Display for TensorData { } } -/// Data structure for serializing and deserializing tensor data. -#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone, new)] -#[deprecated( - since = "0.14.0", - note = "the internal data format has changed, please use `TensorData` instead" -)] -pub struct DataSerialize { - /// The values of the tensor. - pub value: Vec, - /// The shape of the tensor. - pub shape: Vec, -} - -/// Data structure for tensors. -#[derive(new, Debug, Clone, PartialEq, Eq)] -#[deprecated( - since = "0.14.0", - note = "the internal data format has changed, please use `TensorData` instead" -)] -pub struct Data { - /// The values of the tensor. - pub value: Vec, - - /// The shape of the tensor. - pub shape: Shape, -} - -#[allow(deprecated)] -impl Data { - /// Converts the data to a different element type. - pub fn convert(self) -> Data { - let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); - - Data { - value, - shape: self.shape, - } - } - - /// Asserts each value is within a given range. - /// - /// # Arguments - /// - /// * `range` - The range. - /// - /// # Panics - /// - /// If any value is not within the half-open range bounded inclusively below - /// and exclusively above (`start..end`). - pub fn assert_within_range(&self, range: core::ops::Range) { - let start = range.start.elem::(); - let end = range.end.elem::(); - - for elem in self.value.iter() { - let elem = elem.elem::(); - if elem < start || elem >= end { - panic!("Element ({elem:?}) is not within range {range:?}"); - } - } - } -} - -#[allow(deprecated)] -impl DataSerialize { - /// Converts the data to a different element type. - pub fn convert(self) -> DataSerialize { - if TypeId::of::() == TypeId::of::() { - let cast: Box = Box::new(self); - let cast: Box> = cast.downcast().unwrap(); - return *cast; - } - - let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); - - DataSerialize { - value, - shape: self.shape, - } - } - - /// Converts the data to the new [TensorData] format. - pub fn into_tensor_data(self) -> TensorData { - TensorData::new(self.value, self.shape) - } -} - -#[allow(deprecated)] -impl Data { - /// Populates the data with random values. - pub fn random(shape: Shape, distribution: Distribution, rng: &mut R) -> Self { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(E::random(distribution, rng)); - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data -where - E: Element, -{ - /// Populates the data with zeros. - pub fn zeros>(shape: S) -> Data { - let shape = shape.into(); - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(0.elem()); - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data -where - E: Element, -{ - /// Populates the data with ones. - pub fn ones(shape: Shape) -> Data { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(1.elem()); - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data -where - E: Element, -{ - /// Populates the data with the given value - pub fn full(shape: Shape, fill_value: E) -> Data { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - for _ in 0..num_elements { - data.push(fill_value) - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data { - /// Serializes the data. - /// - /// # Returns - /// - /// The serialized data. - pub fn serialize(&self) -> DataSerialize { - DataSerialize { - value: self.value.clone(), - shape: self.shape.dims.to_vec(), - } - } -} - -#[allow(deprecated)] -impl + Clone + core::fmt::Debug + PartialEq + Element, const D: usize> Data { - /// Asserts the data is approximately equal to another data. - /// - /// # Arguments - /// - /// * `other` - The other data. - /// * `precision` - The precision of the comparison. - /// - /// # Panics - /// - /// Panics if the data is not approximately equal. - #[track_caller] - pub fn assert_approx_eq(&self, other: &Self, precision: usize) { - let tolerance = 0.1.pow(precision as f64); - - self.assert_approx_eq_diff(other, tolerance) - } - - /// Asserts the data is approximately equal to another data. - /// - /// # Arguments - /// - /// * `other` - The other data. - /// * `tolerance` - The tolerance of the comparison. - /// - /// # Panics - /// - /// Panics if the data is not approximately equal. - #[track_caller] - pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) { - let mut message = String::new(); - if self.shape != other.shape { - message += format!( - "\n => Shape is different: {:?} != {:?}", - self.shape.dims, other.shape.dims - ) - .as_str(); - } - - let iter = self.value.clone().into_iter().zip(other.value.clone()); - - let mut num_diff = 0; - let max_num_diff = 5; - - for (i, (a, b)) in iter.enumerate() { - let a: f64 = a.into(); - let b: f64 = b.into(); - - //if they are both nan, then they are equally nan - let both_nan = a.is_nan() && b.is_nan(); - //this works for both infinities - let both_inf = a.is_infinite() && b.is_infinite() && ((a > 0.) == (b > 0.)); - - if both_nan || both_inf { - continue; - } - - let err = (a - b).abs(); - - if E::dtype().is_float() { - if let Some((err, tolerance)) = compare_floats(a, b, E::dtype(), tolerance) { - // Only print the first 5 different values. - if num_diff < max_num_diff { - message += format!( - "\n => Position {i}: {a} != {b} | difference {err} > tolerance \ - {tolerance}" - ) - .as_str(); - } - num_diff += 1; - } - } else if err > tolerance || err.is_nan() { - // Only print the first 5 different values. - if num_diff < max_num_diff { - message += format!( - "\n => Position {i}: {a} != {b} | difference {err} > tolerance \ - {tolerance}" - ) - .as_str(); - } - num_diff += 1; - } - } - - if num_diff >= max_num_diff { - message += format!("\n{} more errors...", num_diff - 5).as_str(); - } - - if !message.is_empty() { - panic!("Tensors are not approx eq:{}", message); - } - } -} - -#[allow(deprecated)] -impl Data { - /// Converts the usize data to a different element type. - pub fn from_usize(self) -> Data { - let value: Vec = self - .value - .into_iter() - .map(|a| num_traits::FromPrimitive::from_usize(a).unwrap()) - .collect(); - - Data { - value, - shape: self.shape, - } - } -} - -#[allow(deprecated)] -impl From<&DataSerialize> for Data { - fn from(data: &DataSerialize) -> Self { - let mut dims = [0; D]; - dims[..D].copy_from_slice(&data.shape[..D]); - Data::new(data.value.clone(), Shape::new(dims)) - } -} - -#[allow(deprecated)] -impl From> for Data { - fn from(data: DataSerialize) -> Self { - let mut dims = [0; D]; - dims[..D].copy_from_slice(&data.shape[..D]); - Data::new(data.value, Shape::new(dims)) - } -} - -#[allow(deprecated)] -impl From<[E; A]> for Data { - fn from(elems: [E; A]) -> Self { - let mut data = Vec::with_capacity(2 * A); - for elem in elems.into_iter() { - data.push(elem); - } - - Data::new(data, Shape::new([A])) - } -} - -#[allow(deprecated)] -impl From<&[E]> for Data { - fn from(elems: &[E]) -> Self { - let mut data = Vec::with_capacity(elems.len()); - for elem in elems.iter() { - data.push(*elem); - } - - Data::new(data, Shape::new([elems.len()])) - } -} - -#[allow(deprecated)] -impl From<[[E; B]; A]> for Data { - fn from(elems: [[E; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B); - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - data.push(elem); - } - } - - Data::new(data, Shape::new([A, B])) - } -} - -#[allow(deprecated)] -impl - From<[[[E; C]; B]; A]> for Data -{ - fn from(elems: [[[E; C]; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B * C); - - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - for elem in elem.into_iter().take(C) { - data.push(elem); - } - } - } - - Data::new(data, Shape::new([A, B, C])) - } -} - -#[allow(deprecated)] -impl< - E: core::fmt::Debug + Copy, - const A: usize, - const B: usize, - const C: usize, - const D: usize, - > From<[[[[E; D]; C]; B]; A]> for Data -{ - fn from(elems: [[[[E; D]; C]; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B * C * D); - - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - for elem in elem.into_iter().take(C) { - for elem in elem.into_iter().take(D) { - data.push(elem); - } - } - } - } - - Data::new(data, Shape::new([A, B, C, D])) - } -} - -#[allow(deprecated)] -impl core::fmt::Display for Data { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("{:?}", &self.value).as_str()) - } -} - fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<(f64, f64)> { let epsilon_deviations = tolerance / f32::EPSILON as f64; let epsilon = match ty { @@ -1192,9 +799,8 @@ fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<( } #[cfg(test)] -#[allow(deprecated)] mod tests { - use crate::quantization::AffineQuantization; + use crate::{quantization::AffineQuantization, Shape}; use super::*; use alloc::vec; diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 0e7ff51e88..d54233f993 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -67,7 +67,6 @@ network = ["burn-core/network"] experimental-named-tensor = ["burn-core/experimental-named-tensor"] # Records -record-backward-compat = ["burn-core/record-backward-compat"] record-item-custom-serde = ["burn-core/record-item-custom-serde"] [dependencies] From 9d9ea8b7013313ceb992d9eb4ef9d3e30c804851 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 16 Jan 2025 10:07:31 -0500 Subject: [PATCH 14/17] Add hardsigmoid formula and fix WGAN doc + default lr (#2706) --- crates/burn-tensor/src/tensor/activation/base.rs | 2 ++ examples/wgan/src/model.rs | 2 +- examples/wgan/src/training.rs | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/burn-tensor/src/tensor/activation/base.rs b/crates/burn-tensor/src/tensor/activation/base.rs index cc5990d375..15fcc7ab50 100644 --- a/crates/burn-tensor/src/tensor/activation/base.rs +++ b/crates/burn-tensor/src/tensor/activation/base.rs @@ -144,6 +144,8 @@ pub fn sigmoid(tensor: Tensor) -> Tensor } /// Applies the hard sigmoid function +/// +/// `hard_sigmoid(x) = max(0, min(1, alpha * x + beta))` pub fn hard_sigmoid( tensor: Tensor, alpha: f64, diff --git a/examples/wgan/src/model.rs b/examples/wgan/src/model.rs index ddb84ff6d3..b9615f5270 100644 --- a/examples/wgan/src/model.rs +++ b/examples/wgan/src/model.rs @@ -96,7 +96,7 @@ pub struct ModelConfig { } impl ModelConfig { - /// "init" is used to create other objects, while "new" is usally used to create itself. + /// Initialize the generator and discriminator models based on the config. pub fn init(&self, device: &B::Device) -> (Generator, Discriminator) { // Construct the initialized generator let layer1 = LayerBlock::new(self.latent_dim, 128, device); diff --git a/examples/wgan/src/training.rs b/examples/wgan/src/training.rs index db1f594b46..25fbef21c1 100644 --- a/examples/wgan/src/training.rs +++ b/examples/wgan/src/training.rs @@ -23,7 +23,7 @@ pub struct TrainingConfig { pub num_workers: usize, #[config(default = 5)] pub seed: u64, - #[config(default = 5e-5)] + #[config(default = 3e-4)] pub lr: f64, /// Number of training steps for discriminator before generator is trained per iteration From 9daf0486ec71c2d11d23bffd7d5d8ebdcfd6de37 Mon Sep 17 00:00:00 2001 From: Nathan Whitehead Date: Thu, 16 Jan 2025 08:08:07 -0700 Subject: [PATCH 15/17] Fix GRU (#2704) * Fix GRU to match pytorch (#2701). Update GRU implementation of new gate to match pytorch implementation. This can change numerical output in some cases. Add GRU unit test with sequence length > 1. Fix GRU input state dimensions and hidden state handling. This is an API change since the dimensions of the optional hidden state input are being corrected to the right sizes. Just updating to the correct dimensions seems like the best thing since the previous implementation was incorrect, not just different than pytorch. * Add GruConfig option reset_after to allow both reset behaviors. * Fix clippy and keep previous test --------- Co-authored-by: Guillaume Lagrange --- crates/burn-core/src/nn/rnn/gru.rs | 198 ++++++++++++++++++++--------- 1 file changed, 141 insertions(+), 57 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/gru.rs b/crates/burn-core/src/nn/rnn/gru.rs index c66ad631b6..e2f8b2425e 100644 --- a/crates/burn-core/src/nn/rnn/gru.rs +++ b/crates/burn-core/src/nn/rnn/gru.rs @@ -20,6 +20,21 @@ pub struct GruConfig { pub d_hidden: usize, /// If a bias should be applied during the Gru transformation. pub bias: bool, + /// If reset gate should be applied after weight multiplication. + /// + /// This configuration option controls how the reset gate is applied to the hidden state. + /// * `true` - (Default) Match the initial arXiv version of the paper [Learning Phrase Representations using RNN Encoder-Decoder for + /// Statistical Machine Translation (v1)](https://arxiv.org/abs/1406.1078v1) and apply the reset gate after multiplication by + /// the weights. This matches the behavior of [PyTorch GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU). + /// * `false` - Match the most recent revision of [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine + /// Translation (v3)](https://arxiv.org/abs/1406.1078) and apply the reset gate before the weight multiplication. + /// + /// The differing implementations can give slightly different numerical results and have different efficiencies. For more + /// motivation for why the `true` can be more efficient see [Optimizing RNNs with Differentiable Graphs](https://svail.github.io/diff_graphs). + /// + /// To set this field to `false` use [`with_reset_after`](`GruConfig::with_reset_after`). + #[config(default = "true")] + pub reset_after: bool, /// Gru initializer #[config(default = "Initializer::XavierNormal{gain:1.0}")] pub initializer: Initializer, @@ -41,6 +56,8 @@ pub struct Gru { pub new_gate: GateController, /// The size of the hidden state. pub d_hidden: usize, + /// If reset gate should be applied after weight multiplication. + pub reset_after: bool, } impl ModuleDisplay for Gru { @@ -58,6 +75,7 @@ impl ModuleDisplay for Gru { .add("d_input", &d_input) .add("d_hidden", &self.d_hidden) .add("bias", &bias) + .add("reset_after", &self.reset_after) .optional() } } @@ -94,86 +112,92 @@ impl GruConfig { reset_gate, new_gate, d_hidden: self.d_hidden, + reset_after: self.reset_after, } } } impl Gru { /// Applies the forward pass on the input tensor. This GRU implementation - /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size]. + /// returns a state tensor with dimensions `[batch_size, sequence_length, hidden_size]`. /// - /// # Shapes + /// # Parameters /// - batched_input: `[batch_size, sequence_length, input_size]`. - /// - state: An optional tensor representing an initial cell state with the same dimensions - /// as batched_input. If none is provided, one will be generated. - /// - output: `[batch_size, sequence_length, hidden_size]`. + /// - state: An optional tensor representing an initial cell state with dimensions + /// `[batch_size, hidden_size]`. If none is provided, an empty state will be used. + /// + /// # Returns + /// - output: `[batch_size, sequence_length, hidden_size]` pub fn forward( &self, batched_input: Tensor, - state: Option>, + state: Option>, ) -> Tensor { + let device = batched_input.device(); let [batch_size, seq_length, _] = batched_input.shape().dims(); - let mut hidden_state = match state { + let mut batched_hidden_state = + Tensor::empty([batch_size, seq_length, self.d_hidden], &device); + + let mut hidden_t = match state { Some(state) => state, - None => Tensor::zeros( - [batch_size, seq_length, self.d_hidden], - &batched_input.device(), - ), + None => Tensor::zeros([batch_size, self.d_hidden], &device), }; - for (t, (input_t, hidden_t)) in batched_input - .iter_dim(1) - .zip(hidden_state.clone().iter_dim(1)) - .enumerate() - { + for (t, input_t) in batched_input.iter_dim(1).enumerate() { let input_t = input_t.squeeze(1); - let hidden_t = hidden_t.squeeze(1); // u(pdate)g(ate) tensors - let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate); + let biased_ug_input_sum = + self.gate_product(&input_t, &hidden_t, None, &self.update_gate); let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t) // r(eset)g(ate) tensors - let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate); + let biased_rg_input_sum = + self.gate_product(&input_t, &hidden_t, None, &self.reset_gate); let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t) - let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate // n(ew)g(ate) tensor - let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate); + let biased_ng_input_sum = if self.reset_after { + self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate) + } else { + let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate + self.gate_product(&input_t, &reset_t, None, &self.new_gate) + }; let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t) // calculate linear interpolation between previous hidden state and candidate state: // g(t) * (1 - z(t)) + z(t) * hidden_t - let state_vector = candidate_state + hidden_t = candidate_state .clone() .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1) + update_values.clone().mul(hidden_t); - let current_shape = state_vector.shape().dims; - let unsqueezed_shape = [current_shape[0], 1, current_shape[1]]; - let reshaped_state_vector = state_vector.reshape(unsqueezed_shape); - hidden_state = hidden_state.slice_assign( + let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1); + + batched_hidden_state = batched_hidden_state.slice_assign( [0..batch_size, t..(t + 1), 0..self.d_hidden], - reshaped_state_vector, + unsqueezed_hidden_state, ); } - hidden_state + batched_hidden_state } /// Helper function for performing weighted matrix product for a gate and adds - /// bias, if any. + /// bias, if any, and optionally applies reset to hidden state. /// - /// Mathematically, performs `Wx*X + Wh*H + b`, where: + /// Mathematically, performs `Wx*X + r .* (Wh*H + b)`, where: /// Wx = weight matrix for the connection to input vector X /// Wh = weight matrix for the connection to hidden state H /// X = input vector /// H = hidden state /// b = bias terms + /// r = reset state fn gate_product( &self, input: &Tensor, hidden: &Tensor, + reset: Option<&Tensor>, gate: &GateController, ) -> Tensor { let input_product = input.clone().matmul(gate.input_transform.weight.val()); @@ -190,13 +214,29 @@ impl Gru { .as_ref() .map(|bias_param| bias_param.val()); - match (input_bias, hidden_bias) { - (Some(input_bias), Some(hidden_bias)) => { + match (input_bias, hidden_bias, reset) { + (Some(input_bias), Some(hidden_bias), Some(r)) => { + input_product + + input_bias.unsqueeze() + + r.clone().mul(hidden_product + hidden_bias.unsqueeze()) + } + (Some(input_bias), Some(hidden_bias), None) => { input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() } - (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, - (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), - (None, None) => input_product + hidden_product, + (Some(input_bias), None, Some(r)) => { + input_product + input_bias.unsqueeze() + r.clone().mul(hidden_product) + } + (Some(input_bias), None, None) => { + input_product + input_bias.unsqueeze() + hidden_product + } + (None, Some(hidden_bias), Some(r)) => { + input_product + r.clone().mul(hidden_product + hidden_bias.unsqueeze()) + } + (None, Some(hidden_bias), None) => { + input_product + hidden_product + hidden_bias.unsqueeze() + } + (None, None, Some(r)) => input_product + r.clone().mul(hidden_product), + (None, None, None) => input_product + hidden_product, } } } @@ -207,29 +247,16 @@ mod tests { use crate::tensor::{Distribution, TensorData}; use crate::{module::Param, nn::LinearRecord, TestBackend}; - /// Test forward pass with simple input vector. - /// - /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 - /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 - /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 - /// - /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 - #[test] - fn tests_forward_single_input_single_feature() { - TestBackend::seed(0); - let config = GruConfig::new(1, 1, false); - let device = Default::default(); - let mut gru = config.init::(&device); - - fn create_gate_controller( + fn init_gru(reset_after: bool, device: &B::Device) -> Gru { + fn create_gate_controller( weights: f32, biases: f32, d_input: usize, d_output: usize, bias: bool, initializer: Initializer, - device: &::Device, - ) -> GateController { + device: &B::Device, + ) -> GateController { let record_1 = LinearRecord { weight: Param::from_data(TensorData::from([[weights]]), device), bias: Some(Param::from_data(TensorData::from([biases]), device)), @@ -248,6 +275,9 @@ mod tests { ) } + let config = GruConfig::new(1, 1, false).with_reset_after(reset_after); + let mut gru = config.init::(device); + gru.update_gate = create_gate_controller( 0.5, 0.0, @@ -255,7 +285,7 @@ mod tests { 1, false, Initializer::XavierNormal { gain: 1.0 }, - &device, + device, ); gru.reset_gate = create_gate_controller( 0.6, @@ -264,7 +294,7 @@ mod tests { 1, false, Initializer::XavierNormal { gain: 1.0 }, - &device, + device, ); gru.new_gate = create_gate_controller( 0.7, @@ -273,18 +303,72 @@ mod tests { 1, false, Initializer::XavierNormal { gain: 1.0 }, - &device, + device, ); + gru + } + + /// Test forward pass with simple input vector. + /// + /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 + /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 + /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 + /// + /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 + #[test] + fn tests_forward_single_input_single_feature() { + TestBackend::seed(0); + let device = Default::default(); + let mut gru = init_gru::(false, &device); let input = Tensor::::from_data(TensorData::from([[[0.1]]]), &device); + let expected = TensorData::from([[0.034]]); + // Reset gate applied to hidden state before the matrix multiplication + let state = gru.forward(input.clone(), None); + + let output = state + .select(0, Tensor::arange(0..1, &device)) + .squeeze::<2>(0); + + output.to_data().assert_approx_eq(&expected, 3); + + // Reset gate applied to hidden state after the matrix multiplication + gru.reset_after = true; // override forward behavior + let state = gru.forward(input, None); + + let output = state + .select(0, Tensor::arange(0..1, &device)) + .squeeze::<2>(0); + + output.to_data().assert_approx_eq(&expected, 3); + } + + #[test] + fn tests_forward_seq_len_3() { + TestBackend::seed(0); + let device = Default::default(); + let mut gru = init_gru::(true, &device); + + let input = + Tensor::::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device); + let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]); + + let result = gru.forward(input.clone(), None); + let output = result + .select(0, Tensor::arange(0..1, &device)) + .squeeze::<2>(0); + + output.to_data().assert_approx_eq(&expected, 3); + + // Reset gate applied to hidden state before the matrix multiplication + gru.reset_after = false; // override forward behavior let state = gru.forward(input, None); let output = state .select(0, Tensor::arange(0..1, &device)) .squeeze::<2>(0); - let expected = TensorData::from([[0.034]]); output.to_data().assert_approx_eq(&expected, 3); } @@ -308,7 +392,7 @@ mod tests { assert_eq!( alloc::format!("{}", layer), - "Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}" + "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}" ); } } From 05925f187fea94fd7cf6ed6bb087a5e8fbb3eea0 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 16 Jan 2025 12:54:04 -0500 Subject: [PATCH 16/17] Clean up train system metrics (#2707) --- crates/burn-train/Cargo.toml | 6 ++-- crates/burn-train/src/metric/mod.rs | 49 +++++++++++------------------ crates/burn/Cargo.toml | 2 +- 3 files changed, 23 insertions(+), 34 deletions(-) diff --git a/crates/burn-train/Cargo.toml b/crates/burn-train/Cargo.toml index b922a1a59e..8c024c88f8 100644 --- a/crates/burn-train/Cargo.toml +++ b/crates/burn-train/Cargo.toml @@ -12,9 +12,9 @@ documentation = "https://docs.rs/burn-train" version.workspace = true [features] -default = ["metrics", "tui"] +default = ["sys-metrics", "tui"] doc = ["default"] -metrics = ["nvml-wrapper", "sysinfo", "systemstat"] +sys-metrics = ["nvml-wrapper", "sysinfo", "systemstat"] tui = ["ratatui"] [dependencies] @@ -28,7 +28,7 @@ tracing-subscriber = { workspace = true } tracing-appender = { workspace = true } tracing-core = { workspace = true } -# Metrics +# System Metrics nvml-wrapper = { workspace = true, optional = true } sysinfo = { workspace = true, optional = true } systemstat = { workspace = true, optional = true } diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs index 191099a383..ac8211e884 100644 --- a/crates/burn-train/src/metric/mod.rs +++ b/crates/burn-train/src/metric/mod.rs @@ -3,65 +3,54 @@ pub mod state; /// Module responsible to save and exposes data collected during training. pub mod store; +// System metrics +#[cfg(feature = "sys-metrics")] +mod cpu_temp; +#[cfg(feature = "sys-metrics")] +mod cpu_use; +#[cfg(feature = "sys-metrics")] +mod cuda; +#[cfg(feature = "sys-metrics")] +mod memory_use; +#[cfg(feature = "sys-metrics")] +pub use cpu_temp::*; +#[cfg(feature = "sys-metrics")] +pub use cpu_use::*; +#[cfg(feature = "sys-metrics")] +pub use cuda::*; +#[cfg(feature = "sys-metrics")] +pub use memory_use::*; + +// Training metrics mod acc; mod auroc; mod base; -#[cfg(feature = "metrics")] mod confusion_stats; -#[cfg(feature = "metrics")] -mod cpu_temp; -#[cfg(feature = "metrics")] -mod cpu_use; -#[cfg(feature = "metrics")] -mod cuda; -#[cfg(feature = "metrics")] mod fbetascore; mod hamming; -#[cfg(feature = "metrics")] mod iteration; mod learning_rate; mod loss; -#[cfg(feature = "metrics")] -mod memory_use; -#[cfg(feature = "metrics")] mod precision; -#[cfg(feature = "metrics")] mod recall; -#[cfg(feature = "metrics")] mod top_k_acc; pub use acc::*; pub use auroc::*; pub use base::*; -#[cfg(feature = "metrics")] pub use confusion_stats::ConfusionStatsInput; -#[cfg(feature = "metrics")] -pub use cpu_temp::*; -#[cfg(feature = "metrics")] -pub use cpu_use::*; -#[cfg(feature = "metrics")] -pub use cuda::*; -#[cfg(feature = "metrics")] pub use fbetascore::*; pub use hamming::*; -#[cfg(feature = "metrics")] pub use iteration::*; pub use learning_rate::*; pub use loss::*; -#[cfg(feature = "metrics")] -pub use memory_use::*; -#[cfg(feature = "metrics")] pub use precision::*; -#[cfg(feature = "metrics")] pub use recall::*; -#[cfg(feature = "metrics")] pub use top_k_acc::*; -#[cfg(feature = "metrics")] pub(crate) mod classification; pub(crate) mod processor; -#[cfg(feature = "metrics")] pub use crate::metric::classification::ClassReduction; // Expose `ItemLazy` so it can be implemented for custom types pub use processor::ItemLazy; diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index d54233f993..cd13682a4b 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -24,7 +24,7 @@ train = ["burn-train", "autodiff", "dataset"] tui = ["burn-train?/tui"] ## Includes system info metrics (CPU/GPU usage, etc) -metrics = ["burn-train?/metrics"] +metrics = ["burn-train?/sys-metrics"] # Datasets dataset = ["burn-core/dataset"] From 6750fd689059e628ed763e389c6bd5d8350ce77f Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 16 Jan 2025 14:33:28 -0600 Subject: [PATCH 17/17] Code generation bug fix for ONNX import (#2708) --- crates/burn-import/src/burn/codegen.rs | 5 +++-- crates/burn-import/src/burn/node/resize.rs | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/crates/burn-import/src/burn/codegen.rs b/crates/burn-import/src/burn/codegen.rs index 798636c323..7f511dafd4 100644 --- a/crates/burn-import/src/burn/codegen.rs +++ b/crates/burn-import/src/burn/codegen.rs @@ -5,8 +5,9 @@ use burn::nn::PaddingConfig1d; use burn::nn::PaddingConfig2d; use burn::nn::PaddingConfig3d; -fn convert_primitive(primitive: T) -> TokenStream { - let value = primitive.to_string(); +fn convert_primitive(primitive: T) -> TokenStream { + let value = format!("{:?}", primitive); + value.parse().unwrap() } diff --git a/crates/burn-import/src/burn/node/resize.rs b/crates/burn-import/src/burn/node/resize.rs index 59afcfb607..606f3ef38d 100644 --- a/crates/burn-import/src/burn/node/resize.rs +++ b/crates/burn-import/src/burn/node/resize.rs @@ -228,7 +228,7 @@ mod tests { TensorType::new_float("tensor1", 3), TensorType::new_float("tensor2", 3), "cubic".to_string(), - vec![], + vec![2.0], vec![20], )); @@ -253,7 +253,7 @@ mod tests { pub fn new(device: &B::Device) -> Self { let resize = Interpolate1dConfig::new() .with_output_size(Some(20)) - .with_scale_factor(None) + .with_scale_factor(Some(2.0)) .with_mode(InterpolateMode::Cubic) .init(); Self {