From c33ca147f619c3b148034d4b885c2f39b63dfc08 Mon Sep 17 00:00:00 2001
From: Nathaniel Simard
Date: Wed, 18 Dec 2024 09:51:06 -0500
Subject: [PATCH] Fused/matmul (#2622)

---
 .github/workflows/test.yml                    |  10 +-
 Cargo.lock                                    | 523 +++++++--------
 Cargo.toml                                    |   6 +-
 backend-comparison/Cargo.toml                 |   7 +-
 backend-comparison/benches/matmul.rs          |  13 +-
 backend-comparison/benches/matmul_fused.rs    |  74 +++
 backend-comparison/src/burnbenchapp/base.rs   |   2 +
 .../src/persistence/system_info.rs            |   7 +-
 crates/burn-core/src/nn/linear.rs             |   9 +-
 crates/burn-fusion/src/stream/context.rs      |   9 +-
 crates/burn-jit/src/fusion/base.rs            |  40 +-
 .../burn-jit/src/fusion/elemwise/builder.rs   |   2 +-
 .../src/fusion/elemwise/optimization.rs       |  77 +--
 crates/burn-jit/src/fusion/matmul/args.rs     | 250 ++++++++
 crates/burn-jit/src/fusion/matmul/builder.rs  |  97 +++
 crates/burn-jit/src/fusion/matmul/mod.rs      |   4 +
 .../src/fusion/matmul/optimization.rs         | 304 +++++++++
 crates/burn-jit/src/fusion/matmul/spec.rs     |  22 +
 crates/burn-jit/src/fusion/mod.rs             |   1 +
 .../burn-jit/src/fusion/on_write/builder.rs   |  16 +
 crates/burn-jit/src/fusion/on_write/io.rs     | 603 ++++++++++++------
 crates/burn-jit/src/fusion/on_write/ir.rs     | 100 ++-
 crates/burn-jit/src/fusion/on_write/trace.rs  | 186 ++++--
 .../src/fusion/on_write/trace_builder.rs      |  39 +-
 .../src/kernel/conv/conv2d/gemm/algorithm.rs  |  58 +-
 .../src/kernel/conv/conv2d/gemm/base.rs       |  35 +-
 .../conv/conv2d/gemm/homogeneous/base.rs      | 125 ++--
 .../src/kernel/conv/conv2d/gemm/launch.rs     |  21 +-
 .../kernel/conv/conv2d/gemm/loader/bias.rs    |  82 +--
 .../kernel/conv/conv2d/gemm/loader/im2col.rs  |  60 +-
 .../src/kernel/conv/conv2d/gemm/mod.rs        |   4 +-
 .../kernel/conv/conv2d/gemm/reader/bias.rs    |  18 +-
 .../kernel/conv/conv2d/gemm/reader/im2col.rs  |  56 +-
 .../src/kernel/conv/conv2d/gemm/spec.rs       |  33 +
 .../src/kernel/conv/conv2d/tune/conv2d.rs     |  21 +-
 crates/burn-tensor/src/lib.rs                 |   8 +-
 36 files changed, 2074 insertions(+), 848 deletions(-)
 create mode 100644 backend-comparison/benches/matmul_fused.rs
 create mode 100644 crates/burn-jit/src/fusion/matmul/args.rs
 create mode 100644 crates/burn-jit/src/fusion/matmul/builder.rs
 create mode 100644 crates/burn-jit/src/fusion/matmul/mod.rs
 create mode 100644 crates/burn-jit/src/fusion/matmul/optimization.rs
 create mode 100644 crates/burn-jit/src/fusion/matmul/spec.rs
 create mode 100644 crates/burn-jit/src/kernel/conv/conv2d/gemm/spec.rs

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 2721233595..d3e5a2dc36 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -44,16 +44,16 @@ env:
   # Sourced from https://archive.mesa3d.org/. Bumping this requires
   # updating the mesa build in https://github.com/gfx-rs/ci-build and creating a new release.
- MESA_VERSION: "23.3.1" + MESA_VERSION: "24.2.3" # Corresponds to https://github.com/gfx-rs/ci-build/releases - MESA_CI_BINARY_BUILD: "build18" + MESA_CI_BINARY_BUILD: "build19" # Sourced from https://www.nuget.org/packages/Microsoft.Direct3D.WARP - WARP_VERSION: "1.0.8" + WARP_VERSION: "1.0.13" # Sourced from https://github.com/microsoft/DirectXShaderCompiler/releases # Must also be changed in shaders.yaml - DXC_RELEASE: "v1.7.2308" - DXC_FILENAME: "dxc_2023_08_14.zip" + DXC_RELEASE: "v1.8.2407" + DXC_FILENAME: "dxc_2024_07_31_clang_cl.zip" # Mozilla Grcov GRCOV_LINK: "https://github.com/mozilla/grcov/releases/download" diff --git a/Cargo.lock b/Cargo.lock index 0fa466bb87..bc66a5f125 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -79,9 +79,9 @@ dependencies = [ [[package]] name = "allocator-api2" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "android-tzdata" @@ -149,9 +149,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.93" +version = "1.0.94" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c95c10ba0b00a02636238b814946408b1322d5ac4760326e6fb8ec956d85775" +checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7" [[package]] name = "arbitrary" @@ -361,10 +361,10 @@ dependencies = [ "base64 0.22.1", "bytes", "futures-util", - "http 1.1.0", + "http 1.2.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.1", + "hyper 1.5.2", "hyper-util", "itoa", "matchit", @@ -396,7 +396,7 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http 1.1.0", + "http 1.2.0", "http-body 1.0.1", "http-body-util", "mime", @@ -415,7 +415,6 @@ dependencies = [ "arboard", "burn", "burn-common", - "burn-wgpu", "clap 4.5.23", "colored", "cubecl", @@ -502,18 +501,18 @@ dependencies = [ [[package]] name = "bit-set" -version = "0.6.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0481a0e032742109b1133a095184ee93d88f3dc9e0d28a5d033dc77a073f44f" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" dependencies = [ "bit-vec", ] [[package]] name = "bit-vec" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c54ff287cfc0a34f38a6b832ea1bd8e448a330b3e40a50859e6488bee07f22" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" [[package]] name = "bit_field" @@ -613,9 +612,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a68f1f47cdf0ec8ee4b941b2eee2a80cb796db73118c0dd09ac63fbe405be22" +checksum = "786a307d683a5bf92e6fd5fd69a7eb613751668d1d8d67d802846dfe367c62c8" dependencies = [ "memchr", "serde", @@ -714,7 +713,7 @@ dependencies = [ "serde_json", "spin", "tempfile", - "thiserror 2.0.6", + "thiserror 2.0.8", "uuid", ] @@ -762,7 +761,7 @@ dependencies = [ "strum", "strum_macros", "tempfile", - "thiserror 2.0.6", + "thiserror 2.0.8", ] [[package]] @@ -823,10 +822,10 @@ dependencies = [ "serde", "serde_json", "syn 2.0.90", - "thiserror 2.0.6", + "thiserror 2.0.8", "tracing-core", "tracing-subscriber", - "zip 2.2.1", + "zip 2.2.2", ] [[package]] @@ -1136,9 +1135,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.1" +version = "1.2.4" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47" +checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" dependencies = [ "jobserver", "libc", @@ -1175,9 +1174,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.38" +version = "0.4.39" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" dependencies = [ "android-tzdata", "iana-time-zone", @@ -1368,43 +1367,12 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "colored" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" +checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" dependencies = [ "lazy_static", - "windows-sys 0.48.0", -] - -[[package]] -name = "com" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e17887fd17353b65b1b2ef1c526c83e26cd72e74f598a8dc1bee13a48f3d9f6" -dependencies = [ - "com_macros", -] - -[[package]] -name = "com_macros" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d375883580a668c7481ea6631fc1a8863e33cc335bf56bfad8d7e6d4b04b13a5" -dependencies = [ - "com_macros_support", - "proc-macro2", - "syn 1.0.109", -] - -[[package]] -name = "com_macros_support" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad899a1087a9296d5644792d7cb72b8e34c1bec8e7d4fbc002230169a6e8710c" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", + "windows-sys 0.59.0", ] [[package]] @@ -1445,15 +1413,15 @@ dependencies = [ [[package]] name = "console" -version = "0.15.8" +version = "0.15.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" dependencies = [ "encode_unicode", - "lazy_static", "libc", - "unicode-width 0.1.14", - "windows-sys 0.52.0", + "once_cell", + "unicode-width 0.2.0", + "windows-sys 0.59.0", ] [[package]] @@ -1563,18 +1531,18 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.13" +version = "0.5.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33480d6946193aa8033910124896ca395333cae7e2d1113d1fef6c3272217df2" +checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-deque" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" dependencies = [ "crossbeam-epoch", "crossbeam-utils", @@ -1591,18 +1559,18 @@ dependencies = [ [[package]] name = "crossbeam-queue" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +checksum = 
"0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crossterm" @@ -1669,7 +1637,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1701,7 +1669,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1718,7 +1686,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1736,7 +1704,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1750,7 +1718,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1766,7 +1734,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1782,9 +1750,9 @@ dependencies = [ [[package]] name = "cubecl-hip-sys" -version = "0.0.7" +version = "6.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d4c2b571b09b04b85d669c301042a45fa22a8a043504152f38bf5ba5b69e414" +checksum = "9974218b3ff1f1e7b2f11ce254fd90b3ebcc2af6b4d084f7f6a0c351fb16112c" dependencies = [ "libc", ] @@ -1792,7 +1760,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = 
"git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "bytemuck", "cubecl-core", @@ -1803,7 +1771,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1818,7 +1786,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1855,7 +1823,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "async-channel", "async-lock", @@ -1876,7 +1844,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1890,7 +1858,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=6e6fb265346c6378e939573900c5d32b722569fa#6e6fb265346c6378e939573900c5d32b722569fa" +source = "git+https://github.com/tracel-ai/cubecl?rev=4372c4109ec36b368236ad895df5009f6b6f1a18#4372c4109ec36b368236ad895df5009f6b6f1a18" dependencies = [ "ash", "async-channel", @@ -1910,9 +1878,9 @@ dependencies = [ [[package]] name = "cudarc" -version = "0.12.1" +version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384" +checksum = "8cd76de2aa3a7bdb9a65941ea5a3c688d941688f736a81b2fc5beb88747a7f25" dependencies = [ "half", "libloading", @@ -1986,17 +1954,6 @@ dependencies = [ "serde", ] -[[package]] -name = "d3d12" -version = "22.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdbd1f579714e3c809ebd822c81ef148b1ceaeb3d535352afc73fd0c4c6a0017" -dependencies = [ - "bitflags 2.6.0", - "libloading", - "winapi", -] - [[package]] name = "darling" version = "0.20.10" @@ -2276,9 +2233,9 @@ checksum = "1f878075b9794c1e4ac788c95b728f26aa6366d32eeb10c7051389f898f7d067" [[package]] name = "encode_unicode" -version = "0.3.6" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" [[package]] name = "encoding_rs" @@ -2383,9 +2340,9 @@ dependencies = [ [[package]] name = "event-listener-strategy" -version = "0.5.2" +version = "0.5.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1" +checksum = "3c3e4e0dd3673c1139bf041f3008816d9cf2946bbfac2945c09e523b8d7b05b2" dependencies = [ "event-listener", "pin-project-lite", @@ -2445,15 +2402,15 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" [[package]] name = "fdeflate" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07c6f4c64c1d33a3111c4466f7365ebdcc37c5bd1ea0d62aae2e3d722aacbedb" +checksum = "1e6853b52649d4ac5c0bd02320cddc5ba956bdb407c4b75a2c6b75bf51500f8c" dependencies = [ "simd-adler32", ] @@ -2893,7 +2850,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b5eccc17194ed0e67d49285e4853307e4147e95407f91c1c3e4a13ba9f4e4ce" dependencies = [ "faster-hex", - "thiserror 2.0.6", + "thiserror 2.0.8", ] [[package]] @@ -2971,9 +2928,9 @@ dependencies = [ [[package]] name = "glow" -version = "0.13.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd348e04c43b32574f2de31c8bb397d96c9fcfa1371bd4ca6d8bdc464ab121b1" +checksum = "d51fa363f025f5c111e03f13eda21162faeacb6911fe8caa0c0349f9cf0c4483" dependencies = [ "js-sys", "slotmap", @@ -3011,15 +2968,14 @@ dependencies = [ [[package]] name = "gpu-allocator" -version = "0.26.0" +version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdd4240fc91d3433d5e5b0fc5b67672d771850dc19bbee03c1381e19322803d7" +checksum = "c151a2a5ef800297b4e79efa4f4bec035c5f51d5ae587287c9b952bdf734cacd" dependencies = [ "log", "presser", "thiserror 1.0.69", - "winapi", - "windows 0.52.0", + "windows 0.58.0", ] [[package]] @@ -3063,7 +3019,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.6.0", + "indexmap 2.7.0", "slab", "tokio", "tokio-util", @@ -3081,8 +3037,8 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http 1.1.0", - "indexmap 2.6.0", + "http 1.2.0", + "indexmap 2.7.0", "slab", "tokio", "tokio-util", @@ -3163,21 +3119,6 @@ dependencies = [ "hashbrown 0.14.5", ] -[[package]] -name = "hassle-rs" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af2a7e73e1f34c48da31fb668a907f250794837e08faa144fd24f0b8b741e890" -dependencies = [ - "bitflags 2.6.0", - "com", - "libc", - "libloading", - "thiserror 1.0.69", - "widestring", - "winapi", -] - [[package]] name = "heck" version = "0.4.1" @@ -3245,11 +3186,11 @@ dependencies = [ [[package]] name = "home" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -3271,9 +3212,9 @@ dependencies = [ [[package]] name = "http" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" dependencies = [ "bytes", "fnv", @@ -3298,7 
+3239,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.1.0", + "http 1.2.0", ] [[package]] @@ -3309,7 +3250,7 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.2.0", "http-body 1.0.1", "pin-project-lite", ] @@ -3334,9 +3275,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "0.14.31" +version = "0.14.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c08302e8fa335b151b788c775ff56e7a03ae64ff85c548ee820fecb70356e85" +checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" dependencies = [ "bytes", "futures-channel", @@ -3358,15 +3299,15 @@ dependencies = [ [[package]] name = "hyper" -version = "1.5.1" +version = "1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" +checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" dependencies = [ "bytes", "futures-channel", "futures-util", "h2 0.4.7", - "http 1.1.0", + "http 1.2.0", "http-body 1.0.1", "httparse", "httpdate", @@ -3384,8 +3325,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", - "http 1.1.0", - "hyper 1.5.1", + "http 1.2.0", + "hyper 1.5.2", "hyper-util", "rustls", "rustls-native-certs 0.8.1", @@ -3402,7 +3343,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" dependencies = [ "bytes", - "hyper 0.14.31", + "hyper 0.14.32", "native-tls", "tokio", "tokio-native-tls", @@ -3416,7 +3357,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper 1.5.1", + "hyper 1.5.2", "hyper-util", "native-tls", "tokio", @@ -3433,9 +3374,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.1.0", + "http 1.2.0", "http-body 1.0.1", - "hyper 1.5.1", + "hyper 1.5.2", "pin-project-lite", "socket2", "tokio", @@ -3698,9 +3639,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" +checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -3843,9 +3784,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" [[package]] name = "js-sys" -version = "0.3.74" +version = "0.3.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a865e038f7f6ed956f788f0d7d60c541fff74c7bd74272c5d4cf15c63743e705" +checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" dependencies = [ "once_cell", "wasm-bindgen", @@ -3898,9 +3839,9 @@ dependencies = [ [[package]] name = "libloading" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", 
"windows-targets 0.52.6", @@ -3920,7 +3861,7 @@ checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ "bitflags 2.6.0", "libc", - "redox_syscall 0.5.7", + "redox_syscall 0.5.8", ] [[package]] @@ -4178,9 +4119,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" dependencies = [ "adler2", "simd-adler32", @@ -4188,11 +4129,10 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ - "hermit-abi 0.3.9", "libc", "log", "wasi", @@ -4274,9 +4214,9 @@ dependencies = [ [[package]] name = "naga" -version = "22.1.0" +version = "23.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bd5a652b6faf21496f2cfd88fc49989c8db0825d1f6746b1a71a6ede24a63ad" +checksum = "364f94bc34f61332abebe8cad6f6cd82a5b65cff22c828d05d0968911462ca4f" dependencies = [ "arrayvec", "bit-set", @@ -4284,7 +4224,7 @@ dependencies = [ "cfg_aliases 0.1.1", "codespan-reporting", "hexf-parse", - "indexmap 2.6.0", + "indexmap 2.7.0", "log", "rustc-hash 1.1.0", "spirv", @@ -4723,7 +4663,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.5.1", + "hyper 1.5.2", "itertools 0.13.0", "md-5", "parking_lot 0.12.3", @@ -4817,7 +4757,7 @@ dependencies = [ "flate2", "native-tls", "tar", - "thiserror 2.0.6", + "thiserror 2.0.8", "ureq", ] @@ -4955,7 +4895,7 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.7", + "redox_syscall 0.5.8", "smallvec", "windows-targets 0.52.6", ] @@ -5021,7 +4961,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.6.0", + "indexmap 2.7.0", ] [[package]] @@ -5091,9 +5031,9 @@ dependencies = [ [[package]] name = "png" -version = "0.17.14" +version = "0.17.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52f9d46a34a05a6a57566bc2bfae066ef07585a6e3fa30fbbdff5936380623f0" +checksum = "b67582bd5b65bdff614270e2ea89a1cf15bef71245cc1e5f7ea126977144211d" dependencies = [ "bitflags 1.3.2", "crc32fast", @@ -5201,7 +5141,7 @@ dependencies = [ "either", "hashbrown 0.14.5", "hashbrown 0.15.2", - "indexmap 2.6.0", + "indexmap 2.7.0", "num-traits", "once_cell", "polars-arrow", @@ -5316,7 +5256,7 @@ dependencies = [ "chrono", "fallible-streaming-iterator", "hashbrown 0.15.2", - "indexmap 2.6.0", + "indexmap 2.7.0", "itoa", "num-traits", "polars-arrow", @@ -5390,7 +5330,7 @@ dependencies = [ "either", "hashbrown 0.15.2", "hex", - "indexmap 2.6.0", + "indexmap 2.7.0", "memchr", "num-traits", "polars-arrow", @@ -5527,7 +5467,7 @@ version = "0.44.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d88667f770291cefa2e8cd366a54f29dc6fe362e9a263914c903db411a58ac1d" dependencies = [ - "indexmap 2.6.0", + "indexmap 2.7.0", "polars-error", "polars-utils", "serde", @@ -5618,7 +5558,7 @@ dependencies 
= [ "bytes", "compact_str", "hashbrown 0.15.2", - "indexmap 2.6.0", + "indexmap 2.7.0", "libc", "memmap2 0.7.1", "num-traits", @@ -5784,7 +5724,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "322330e133eab455718444b4e033ebfac7c6528972c784fcde28d2cc783c6257" dependencies = [ "anyhow", - "indexmap 2.6.0", + "indexmap 2.7.0", "log", "protobuf", "protobuf-support", @@ -5942,10 +5882,10 @@ dependencies = [ "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.0.0", + "rustc-hash 2.1.0", "rustls", "socket2", - "thiserror 2.0.6", + "thiserror 2.0.8", "tokio", "tracing", ] @@ -5960,11 +5900,11 @@ dependencies = [ "getrandom", "rand", "ring", - "rustc-hash 2.0.0", + "rustc-hash 2.1.0", "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.6", + "thiserror 2.0.8", "tinyvec", "tracing", "web-time", @@ -5972,9 +5912,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d5a626c6807713b15cac82a6acaccd6043c9a5408c24baae07611fec3f243da" +checksum = "52cd4b1eff68bf27940dd39811292c49e007f4d0b4c357358dc9b0197be6b527" dependencies = [ "cfg_aliases 0.2.1", "libc", @@ -6231,9 +6171,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ "bitflags 2.6.0", ] @@ -6339,7 +6279,7 @@ dependencies = [ "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper 0.14.31", + "hyper 0.14.32", "hyper-tls 0.5.0", "ipnet", "js-sys", @@ -6378,10 +6318,10 @@ dependencies = [ "futures-core", "futures-util", "h2 0.4.7", - "http 1.1.0", + "http 1.2.0", "http-body 1.0.1", "http-body-util", - "hyper 1.5.1", + "hyper 1.5.2", "hyper-rustls", "hyper-tls 0.6.0", "hyper-util", @@ -6540,9 +6480,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustc-hash" -version = "2.0.0" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152" +checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" [[package]] name = "rustc_version" @@ -6555,22 +6495,22 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.41" +version = "0.38.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" +checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" dependencies = [ "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] name = "rustls" -version = "0.23.19" +version = "0.23.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "934b404430bb06b3fae2cba809eb45a1ab1aecd64491213d7c3301b88393f8d1" +checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" dependencies = [ "log", "once_cell", @@ -6603,7 +6543,7 @@ dependencies = [ "openssl-probe", "rustls-pki-types", "schannel", - "security-framework 3.0.1", + "security-framework 3.1.0", ] [[package]] @@ -6626,9 +6566,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "16f1201b3c9a7ee8039bcadc17b7e605e2945b27eee7631788c1bd2b0643674b" +checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" dependencies = [ "web-time", ] @@ -6706,9 +6646,9 @@ dependencies = [ [[package]] name = "scc" -version = "2.2.5" +version = "2.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66b202022bb57c049555430e11fc22fea12909276a80a4c3d368da36ac1d88ed" +checksum = "94b13f8ea6177672c49d12ed964cca44836f59621981b04a3e26b87e675181de" dependencies = [ "sdd", ] @@ -6739,9 +6679,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sdd" -version = "3.0.4" +version = "3.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49c1eeaf4b6a87c7479688c6d52b9f1153cedd3c489300564f932b065c6eab95" +checksum = "478f121bb72bbf63c52c93011ea1791dca40140dfe13f8336c4c5ac952c33aa9" [[package]] name = "security-framework" @@ -6758,9 +6698,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "3.0.1" +version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1415a607e92bec364ea2cf9264646dcce0f91e6d65281bd6f2819cca3bf39c8" +checksum = "81d3f8c9bfcc3cbb6b0179eb57042d75b1582bdc65c3cb95f3fa999509c03cbc" dependencies = [ "bitflags 2.6.0", "core-foundation 0.10.0", @@ -6771,9 +6711,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.12.1" +version = "2.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa39c7303dc58b5543c94d22c1766b0d31f2ee58306363ea622b10bbc075eaa2" +checksum = "1863fd3768cd83c56a7f60faa4dc0d403f1b6df0a38c3c25f44b7894e45370d5" dependencies = [ "core-foundation-sys", "libc", @@ -6781,9 +6721,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.23" +version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" [[package]] name = "seq-macro" @@ -7383,9 +7323,9 @@ dependencies = [ [[package]] name = "systemstat" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a24aec24a9312c83999a28e3ef9db7e2afd5c64bf47725b758cdc1cafd5b0bd2" +checksum = "668a4db78b439df482c238f559e4ea869017f9e62ef0a059c8bfcd841a4df544" dependencies = [ "bytesize", "lazy_static", @@ -7516,11 +7456,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.6" +version = "2.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fec2a1820ebd077e2b90c4df007bebf344cd394098a13c563957d0afc83ea47" +checksum = "08f5383f3e0071702bf93ab5ee99b52d26936be9dedd9413067cbdcddcb6141a" dependencies = [ - "thiserror-impl 2.0.6", + "thiserror-impl 2.0.8", ] [[package]] @@ -7536,9 +7476,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.6" +version = "2.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d65750cab40f4ff1929fb1ba509e9914eb756131cef4210da8d5d700d26f6312" +checksum = "f2f357fcec90b3caef6623a099691be676d033b40a058ac95d2a6ade6fa0c943" dependencies = [ "proc-macro2", "quote", @@ -7577,9 +7517,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.36" +version = "0.3.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +checksum = 
"35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" dependencies = [ "deranged", "itoa", @@ -7600,9 +7540,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" dependencies = [ "num-conv", "time-core", @@ -7704,12 +7644,11 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.0" +version = "0.26.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" dependencies = [ "rustls", - "rustls-pki-types", "tokio", ] @@ -7727,9 +7666,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.12" +version = "0.7.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" dependencies = [ "bytes", "futures-core", @@ -7765,7 +7704,7 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap 2.6.0", + "indexmap 2.7.0", "serde", "serde_spanned", "toml_datetime", @@ -7789,14 +7728,14 @@ dependencies = [ [[package]] name = "tower" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper 0.1.2", + "sync_wrapper 1.0.2", "tokio", "tower-layer", "tower-service", @@ -7934,7 +7873,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.2.0", "httparse", "log", "rand", @@ -8079,9 +8018,9 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "ureq" -version = "2.10.1" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b74fc6b57825be3373f7054754755f03ac3a8f5d70015ccad699ba2029956f4a" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" dependencies = [ "base64 0.22.1", "flate2", @@ -8216,9 +8155,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d15e63b4482863c109d70a7b8706c1e364eb6ea449b201a76c5b89cedcec2d5c" +checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" dependencies = [ "cfg-if", "once_cell", @@ -8227,13 +8166,12 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d36ef12e3aaca16ddd3f67922bc63e48e953f126de60bd33ccc0101ef9998cd" +checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" dependencies = [ "bumpalo", "log", - "once_cell", "proc-macro2", "quote", "syn 2.0.90", @@ -8242,9 +8180,9 @@ 
dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.47" +version = "0.4.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dfaf8f50e5f293737ee323940c7d8b08a66a95a419223d9f41610ca08b0833d" +checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2" dependencies = [ "cfg-if", "js-sys", @@ -8255,9 +8193,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "705440e08b42d3e4b36de7d66c944be628d579796b8090bfa3471478a2260051" +checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -8265,9 +8203,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98c9ae5a76e46f4deecd0f0255cc223cfa18dc9b261213b8aa0c7b36f61b3f1d" +checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" dependencies = [ "proc-macro2", "quote", @@ -8278,9 +8216,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.97" +version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ee99da9c5ba11bd675621338ef6fa52296b76b83305e9b6e5c77d4c286d6d49" +checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" [[package]] name = "wasm-logger" @@ -8323,9 +8261,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.74" +version = "0.3.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a98bc3c33f0fe7e59ad7cd041b89034fa82a7c2d4365ca538dda6cdaf513863c" +checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" dependencies = [ "js-sys", "wasm-bindgen", @@ -8358,9 +8296,9 @@ checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" [[package]] name = "wgpu" -version = "22.1.0" +version = "23.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1d1c4ba43f80542cf63a0a6ed3134629ae73e8ab51e4b765a67f3aa062eb433" +checksum = "80f70000db37c469ea9d67defdc13024ddf9a5f1b89cb2941b812ad7cde1735a" dependencies = [ "arrayvec", "cfg_aliases 0.1.1", @@ -8383,16 +8321,16 @@ dependencies = [ [[package]] name = "wgpu-core" -version = "22.1.0" +version = "23.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0348c840d1051b8e86c3bcd31206080c5e71e5933dabd79be1ce732b0b2f089a" +checksum = "d63c3c478de8e7e01786479919c8769f62a22eec16788d8c2ac77ce2c132778a" dependencies = [ "arrayvec", "bit-vec", "bitflags 2.6.0", "cfg_aliases 0.1.1", "document-features", - "indexmap 2.6.0", + "indexmap 2.7.0", "log", "naga", "once_cell", @@ -8408,9 +8346,9 @@ dependencies = [ [[package]] name = "wgpu-hal" -version = "22.0.0" +version = "23.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6bbf4b4de8b2a83c0401d9e5ae0080a2792055f25859a02bf9be97952bbed4f" +checksum = "89364b8a0b211adc7b16aeaf1bd5ad4a919c1154b44c9ce27838213ba05fd821" dependencies = [ "android_system_properties", "arrayvec", @@ -8418,15 +8356,14 @@ dependencies = [ "bit-set", "bitflags 2.6.0", "block", + "bytemuck", "cfg_aliases 0.1.1", "core-graphics-types", - "d3d12", "glow", "glutin_wgl_sys", "gpu-alloc", "gpu-allocator", "gpu-descriptor", - "hassle-rs", "js-sys", "khronos-egl", "libc", @@ -8448,14 +8385,15 @@ dependencies = [ 
"wasm-bindgen", "web-sys", "wgpu-types", - "winapi", + "windows 0.58.0", + "windows-core 0.58.0", ] [[package]] name = "wgpu-types" -version = "22.0.0" +version = "23.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc9d91f0e2c4b51434dfa6db77846f2793149d8e73f800fa2e41f52b8eac3c5d" +checksum = "610f6ff27778148c31093f3b03abc4840f9636d58d597ca2f5977433acfe0068" dependencies = [ "bitflags 2.6.0", "js-sys", @@ -8474,12 +8412,6 @@ dependencies = [ "rustix", ] -[[package]] -name = "widestring" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311" - [[package]] name = "winapi" version = "0.3.9" @@ -8513,21 +8445,21 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows" -version = "0.52.0" +version = "0.57.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" +checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" dependencies = [ - "windows-core 0.52.0", + "windows-core 0.57.0", "windows-targets 0.52.6", ] [[package]] name = "windows" -version = "0.57.0" +version = "0.58.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12342cb4d8e3b046f3d80effd474a7a02447231330ef77d71daa6fbc40681143" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" dependencies = [ - "windows-core 0.57.0", + "windows-core 0.58.0", "windows-targets 0.52.6", ] @@ -8546,12 +8478,25 @@ version = "0.57.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2ed2439a290666cd67ecce2b0ffaad89c2a56b976b736e6ece670297897832d" dependencies = [ - "windows-implement", - "windows-interface", + "windows-implement 0.57.0", + "windows-interface 0.57.0", "windows-result 0.1.2", "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" +dependencies = [ + "windows-implement 0.58.0", + "windows-interface 0.58.0", + "windows-result 0.2.0", + "windows-strings", + "windows-targets 0.52.6", +] + [[package]] name = "windows-implement" version = "0.57.0" @@ -8563,6 +8508,17 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "windows-interface" version = "0.57.0" @@ -8574,6 +8530,17 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.90", +] + [[package]] name = "windows-registry" version = "0.2.0" @@ -8840,9 +8807,9 @@ dependencies = [ [[package]] name = "xml-rs" -version = "0.8.23" +version = "0.8.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af310deaae937e48a26602b730250b4949e125f468f11e6990be3e5304ddd96f" +checksum = "ea8b391c9a790b496184c29f7f93b9ed5b16abb306c05415b68bcc16e4d06432" [[package]] name = "xtask" @@ -9004,16 +8971,16 @@ 
dependencies = [
  "crc32fast",
  "crossbeam-utils",
  "displaydoc",
- "indexmap 2.6.0",
+ "indexmap 2.7.0",
  "num_enum",
  "thiserror 1.0.69",
 ]
 
 [[package]]
 name = "zip"
-version = "2.2.1"
+version = "2.2.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "99d52293fc86ea7cf13971b3bb81eb21683636e7ae24c729cdaf1b7c4157a352"
+checksum = "ae9c1ea7b3a5e1f4b922ff856a129881167511563dc219869afe3787fc0c1a45"
 dependencies = [
  "aes",
  "arbitrary",
@@ -9025,13 +8992,13 @@ dependencies = [
  "displaydoc",
  "flate2",
  "hmac",
- "indexmap 2.6.0",
+ "indexmap 2.7.0",
  "lzma-rs",
  "memchr",
  "pbkdf2 0.12.2",
  "rand",
  "sha1",
- "thiserror 2.0.6",
+ "thiserror 2.0.8",
  "time",
  "zeroize",
  "zopfli",
 ]
 
@@ -9116,9 +9083,9 @@ dependencies = [
 
 [[package]]
 name = "zune-jpeg"
-version = "0.4.13"
+version = "0.4.14"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "16099418600b4d8f028622f73ff6e3deaabdff330fb9a2a131dea781ee8b0768"
+checksum = "99a5bab8d7dedf81405c4bb1f2b83ea057643d9cb28778cea9eecddeedd2e028"
 dependencies = [
  "zune-core",
 ]
diff --git a/Cargo.toml b/Cargo.toml
index 208cfbc17f..ef66550b12 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -101,7 +101,7 @@ ratatui = "0.29.0"
 
 # WGPU stuff
 text_placeholder = "0.5.1"
-wgpu = "22.1.0"
+wgpu = "23.0.0"
 
 # Benchmarks and Burnbench
 arboard = "3.4.1"
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
 portable-atomic-util = { version = "0.2.4", features = ["alloc"] }
 
 ### For the main burn branch. ###
-cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6e6fb265346c6378e939573900c5d32b722569fa" }
-cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "6e6fb265346c6378e939573900c5d32b722569fa" }
+cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4372c4109ec36b368236ad895df5009f6b6f1a18" }
+cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4372c4109ec36b368236ad895df5009f6b6f1a18" }
 
 ### For local development. ###
 # cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
 # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml
index 39ddd0c6f5..ee5f0bd8a2 100644
--- a/backend-comparison/Cargo.toml
+++ b/backend-comparison/Cargo.toml
@@ -34,7 +34,7 @@ wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"]
 arboard = { workspace = true }
 burn = { path = "../crates/burn", default-features = false }
 burn-common = { path = "../crates/burn-common", version = "0.16.0" }
-burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.16.0", optional = true }
+
 clap = { workspace = true }
 colored = { workspace = true }
 cubecl = { workspace = true, features = ["wgpu"], default-features = true }
@@ -96,6 +96,11 @@ name = "conv3d"
 harness = false
 name = "matmul"
 
+[[bench]]
+harness = false
+name = "matmul-fused"
+path = "benches/matmul_fused.rs"
+
 [[bench]]
 harness = false
 name = "data"
diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs
index 0e9f3622b1..d950c71926 100644
--- a/backend-comparison/benches/matmul.rs
+++ b/backend-comparison/benches/matmul.rs
@@ -1,5 +1,5 @@
 use backend_comparison::persistence::save;
-use burn::tensor::{backend::Backend, Shape, Tensor};
+use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
 use burn_common::benchmark::{run_benchmark, Benchmark};
 use derive_new::new;
 
@@ -26,8 +26,8 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
     }
 
     fn prepare(&self) -> Self::Args {
-        let lhs = Tensor::zeros(self.shape_lhs.clone(), &self.device);
-        let rhs = Tensor::zeros(self.shape_rhs.clone(), &self.device);
+        let lhs = Tensor::random(self.shape_lhs.clone(), Distribution::Default, &self.device);
+        let rhs = Tensor::random(self.shape_rhs.clone(), Distribution::Default, &self.device);
 
         (lhs, rhs)
     }
@@ -45,9 +45,10 @@ fn bench<B: Backend>(
     token: Option<&str>,
 ) {
     let benchmarks = [
-        (3, 4096, 4096, 4096),
-        (8, 2048, 2048, 2048),
-        (2, 4096, 4096, 512),
+        (2, 4096, 4096, 4096),
+        (32, 2048, 2048, 2048),
+        (256, 1024, 1024, 1024),
+        (1024, 256, 256, 256),
     ]
     .into_iter()
     .map(|(b, m, n, k)| {
diff --git a/backend-comparison/benches/matmul_fused.rs b/backend-comparison/benches/matmul_fused.rs
new file mode 100644
index 0000000000..375be97b4e
--- /dev/null
+++ b/backend-comparison/benches/matmul_fused.rs
@@ -0,0 +1,74 @@
+use backend_comparison::persistence::save;
+use burn::tensor::{activation::relu, backend::Backend, Distribution, Shape, Tensor};
+use burn_common::benchmark::{run_benchmark, Benchmark};
+use derive_new::new;
+
+#[derive(new)]
+struct MatmulBenchmark<B: Backend, const D: usize> {
+    shape_lhs: Shape,
+    shape_rhs: Shape,
+    device: B::Device,
+}
+
+impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
+    type Args = (Tensor<B, D>, Tensor<B, D>, Tensor<B, 1>);
+
+    fn name(&self) -> String {
+        "matmul_bias_relu".into()
+    }
+
+    fn shapes(&self) -> Vec<Vec<usize>> {
+        vec![self.shape_lhs.dims.clone(), self.shape_rhs.dims.clone()]
+    }
+
+    fn execute(&self, (lhs, rhs, bias): Self::Args) {
+        let bias = bias.unsqueeze();
+        relu(lhs.matmul(rhs) + bias);
+    }
+
+    fn prepare(&self) -> Self::Args {
+        let lhs = Tensor::random(self.shape_lhs.clone(), Distribution::Default, &self.device);
+        let rhs = Tensor::random(self.shape_rhs.clone(), Distribution::Default, &self.device);
+        let bias = Tensor::random(
+            [self.shape_rhs.dims[2]],
+            Distribution::Default,
+            &self.device,
+        );
+
+        (lhs, rhs, bias)
+    }
+
+    fn sync(&self) {
+        B::sync(&self.device)
+    }
+}
+
+#[allow(dead_code)]
+fn bench<B: Backend>(
+    device: &B::Device,
+    feature_name: &str,
+    url: Option<&str>,
+    token: Option<&str>,
+) {
+    let benchmarks = [
+        (2, 4096, 4096, 4096),
+        (32, 2048, 2048, 2048),
+        (256, 1024, 1024, 1024),
+        (1024, 256, 256, 256),
+    ]
+    .into_iter()
+    .map(|(b, m, n, k)| {
+        let shape_lhs = [b, m, k].into();
+        let shape_rhs = [b, k, n].into();
+
+        MatmulBenchmark::<B, 3>::new(shape_lhs, shape_rhs, device.clone())
+    })
+    .map(run_benchmark)
+    .collect();
+
+    save::<B>(benchmarks, device, feature_name, url, token).unwrap();
+}
+
+fn main() {
+    backend_comparison::bench_on_backend!();
+}
diff --git a/backend-comparison/src/burnbenchapp/base.rs b/backend-comparison/src/burnbenchapp/base.rs
index 0424e00c4e..83c5060a6b 100644
--- a/backend-comparison/src/burnbenchapp/base.rs
+++ b/backend-comparison/src/burnbenchapp/base.rs
@@ -103,6 +103,8 @@ enum BenchmarkValues {
     Data,
     #[strum(to_string = "matmul")]
     Matmul,
+    #[strum(to_string = "matmul-fused")]
+    MatmulFused,
     #[strum(to_string = "unary")]
     Unary,
     #[strum(to_string = "max-pool2d")]
diff --git a/backend-comparison/src/persistence/system_info.rs b/backend-comparison/src/persistence/system_info.rs
index 171877f69f..287b629c21 100644
--- a/backend-comparison/src/persistence/system_info.rs
+++ b/backend-comparison/src/persistence/system_info.rs
@@ -2,7 +2,7 @@ use burn::serde::{Deserialize, Serialize};
 use cubecl::wgpu::GraphicsApi;
 use std::collections::HashSet;
 use sysinfo;
-use wgpu;
+use wgpu::{self, Backends};
 
 #[derive(Default, Clone, Serialize, Deserialize)]
 pub struct BenchmarkSystemInfo {
@@ -51,7 +51,10 @@ impl BenchmarkSystemInfo {
     fn enumerate_gpus() -> Vec<String> {
         let instance = wgpu::Instance::default();
         let adapters: Vec<wgpu::Adapter> = instance
-            .enumerate_adapters(cubecl::wgpu::AutoGraphicsApi::backend().into())
+            .enumerate_adapters({
+                let backend = cubecl::wgpu::AutoGraphicsApi::backend();
+                Backends::from_bits(1 << backend as u32).unwrap()
+            })
             .into_iter()
             .filter(|adapter| {
                 let info = adapter.get_info();
diff --git a/crates/burn-core/src/nn/linear.rs b/crates/burn-core/src/nn/linear.rs
index 462e0ea215..738dd87c80 100644
--- a/crates/burn-core/src/nn/linear.rs
+++ b/crates/burn-core/src/nn/linear.rs
@@ -75,10 +75,13 @@ impl<B: Backend> Linear<B> {
             return Self::forward::<2>(self, input.unsqueeze()).flatten(0, 1);
         }
 
-        let output = input.matmul(self.weight.val().unsqueeze());
+        let weight = self.weight.val().unsqueeze();
+        let bias = self.bias.as_ref().map(|b| b.val().unsqueeze());
 
-        match &self.bias {
-            Some(bias) => output + bias.val().unsqueeze(),
+        let output = input.matmul(weight);
+
+        match bias {
+            Some(bias) => output + bias,
             None => output,
         }
     }
diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs
index 7dcc81cb77..d5e1ee9e38 100644
--- a/crates/burn-fusion/src/stream/context.rs
+++ b/crates/burn-fusion/src/stream/context.rs
@@ -12,7 +12,7 @@ use hashbrown::HashMap;
 #[derive(new)]
 pub struct Context<'a, H> {
     /// The tensor mapping where local tensor id points to the updated tensor description.
-    pub tensors: &'a HashMap<TensorId, TensorDescription>,
+    pub tensors: &'a mut HashMap<TensorId, TensorDescription>,
     /// Handle container to retrieve tensors based on their description.
     pub handles: &'a mut HandleContainer<H>,
     /// F32 scalars found in the graph in the order they appeared.
@@ -78,10 +78,13 @@ trait RelativeOpsScalar {
 }
 
 impl OperationConverter {
-    pub(crate) fn context<'a, H>(&'a self, handles: &'a mut HandleContainer<H>) -> Context<'a, H> {
+    pub(crate) fn context<'a, H>(
+        &'a mut self,
+        handles: &'a mut HandleContainer<H>,
+    ) -> Context<'a, H> {
         Context {
             handles,
-            tensors: &self.tensors_relative2global,
+            tensors: &mut self.tensors_relative2global,
             scalar_f32: &self.scalar_f32,
             scalar_f16: &self.scalar_f16,
             scalar_bf16: &self.scalar_bf16,
diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs
index 96c22d0898..512a23a9dc 100644
--- a/crates/burn-jit/src/fusion/base.rs
+++ b/crates/burn-jit/src/fusion/base.rs
@@ -1,6 +1,10 @@
 use super::elemwise::optimization::{ElemwiseOptimization, ElemwiseOptimizationState};
-use crate::{element::BoolElement, fusion::elemwise::builder::ElementWiseBuilder};
+use super::matmul::optimization::{MatmulOptimization, MatmulOptimizationState};
+use crate::fusion::elemwise::builder::ElementWiseBuilder;
+use crate::fusion::matmul::builder::MatmulBuilder;
+use crate::BoolElement;
 use crate::{kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend, JitRuntime};
+
 use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
 use burn_tensor::repr::TensorHandle;
 use burn_tensor::DType;
@@ -16,7 +20,9 @@ use serde::{Deserialize, Serialize};
 /// More optimization variants should be added here.
 pub enum JitOptimization<R: JitRuntime> {
     /// Element wise optimization.
-    ElementWise2(ElemwiseOptimization<R>),
+    ElementWise(ElemwiseOptimization<R>),
+    /// Matrix multiplication optimization.
+    Matmul(MatmulOptimization<R>),
 }
 
 /// Fusion optimization state type for JIT.
@@ -26,6 +32,8 @@ pub enum JitOptimizationState {
     /// Element wise state.
     ElementWise(ElemwiseOptimizationState),
+    /// Matrix multiplication optimization state.
+    Matmul(MatmulOptimizationState),
 }
 
 impl<R, BT> burn_fusion::Optimization<FusionJitRuntime<R, BT>> for JitOptimization<R>
 where
     R: JitRuntime,
     BT: BoolElement,
 {
     fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitFusionHandle<R>>) {
         match self {
-            Self::ElementWise2(op) => op.execute::<BT>(context),
+            Self::ElementWise(op) => op.execute::<BT>(context),
+            Self::Matmul(op) => op.execute::<BT>(context),
         }
     }
 
     fn len(&self) -> usize {
         match self {
-            Self::ElementWise2(op) => op.num_ops_fused(),
+            Self::ElementWise(op) => op.num_ops_fused(),
+            Self::Matmul(op) => op.num_ops_fused(),
         }
     }
 
     fn to_state(&self) -> JitOptimizationState {
         match self {
-            Self::ElementWise2(value) => JitOptimizationState::ElementWise(value.to_state()),
+            Self::ElementWise(value) => JitOptimizationState::ElementWise(value.to_state()),
+            Self::Matmul(value) => JitOptimizationState::Matmul(value.to_state()),
         }
     }
 
     fn from_state(device: &R::Device, state: JitOptimizationState) -> Self {
         match state {
             JitOptimizationState::ElementWise(state) => {
-                Self::ElementWise2(ElemwiseOptimization::from_state(device, state))
+                Self::ElementWise(ElemwiseOptimization::from_state(device, state))
+            }
+            JitOptimizationState::Matmul(state) => {
+                Self::Matmul(MatmulOptimization::from_state(device, state))
             }
         }
     }
@@ -111,10 +125,16 @@ impl<R: JitRuntime, BT: BoolElement> FusionRuntime for FusionJitRuntime<R, BT> {
     fn optimizations(
         device: R::Device,
     ) -> Vec<Box<dyn burn_fusion::OptimizationBuilder<Self::Optimization>>> {
-        vec![Box::new(ElementWiseBuilder::<R>::new(
-            device.clone(),
-            BT::as_elem().into(),
-        ))]
+        vec![
+            Box::new(ElementWiseBuilder::<R>::new(
+                device.clone(),
+                BT::as_elem().into(),
+            )),
+            Box::new(MatmulBuilder::<R>::new(
+                device.clone(),
+                BT::as_elem().into(),
+            )),
+        ]
     }
 }
diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs
index e37196bc2a..461767e9fc 100644
--- a/crates/burn-jit/src/fusion/elemwise/builder.rs
+++ b/crates/burn-jit/src/fusion/elemwise/builder.rs
@@ -40,7 +40,7 @@ impl<R: JitRuntime> OptimizationBuilder<JitOptimization<R>> for ElementWiseBuilder<R> {
         let elementwise =
             ElemwiseOptimization::<R>::new(trace, client, self.device.clone(), self.len());
 
-        JitOptimization::ElementWise2(elementwise)
+        JitOptimization::ElementWise(elementwise)
     }
 
     fn reset(&mut self) {
diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs
index d3e8e35b50..2e33eefc20 100644
--- a/crates/burn-jit/src/fusion/elemwise/optimization.rs
+++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs
@@ -1,7 +1,6 @@
 use crate::{fusion::on_write::kernel::fuse_on_write, BoolElement};
 use crate::{fusion::JitFusionHandle, JitRuntime};
 use burn_fusion::stream::Context;
-use burn_tensor::repr::TensorDescription;
 use cubecl::{calculate_cube_count_elemwise, client::ComputeClient, prelude::*, CubeDim};
 use serde::{Deserialize, Serialize};
 
@@ -30,7 +29,8 @@ impl<R: JitRuntime> ElemwiseOptimization<R> {
     /// Execute the optimization.
     pub fn execute<BT: BoolElement>(&mut self, context: &mut Context<'_, JitFusionHandle<R>>) {
         self.trace
-            .run::<R, BT>(&self.client, &self.device, context)
+            .run::<R, BT>(&self.client, &self.device, context, &ElemwiseRunner)
+            .unwrap();
     }
 
     /// Number of element wise operations fused.
@@ -57,13 +57,18 @@ impl<R: JitRuntime> ElemwiseOptimization<R> {
     }
 }
 
-impl<R: JitRuntime> TraceRunner<R> for ElemwiseOptimization<R> {
+pub struct ElemwiseRunner;
+
+impl<R: JitRuntime> TraceRunner<R> for ElemwiseRunner {
+    type Error = (); // No error possible
+
     fn run<'a>(
-        client: &ComputeClient<R::Server, R::Channel>,
+        &'a self,
+        client: &'a ComputeClient<R::Server, R::Channel>,
         inputs: GlobalArgsLaunch<'a, R>,
         outputs: GlobalArgsLaunch<'a, R>,
-        config: ElemwiseConfig,
-    ) {
+        config: &'a ElemwiseConfig,
+    ) -> Result<(), Self::Error> {
         let arg = match config.ref_layout {
             Arg::Input(index, precision, _) => match precision {
                 ElemwisePrecision::F32 => inputs.t_f32.values.get(index as usize),
@@ -111,59 +116,17 @@ impl<R: JitRuntime> TraceRunner<R> for ElemwiseOptimization<R> {
         let cube_count = calculate_cube_count_elemwise(total_elem, cube_dim);
 
         unsafe {
-            elemwise_fuse::launch_unchecked(client, cube_count, cube_dim, inputs, outputs, config);
-        }
-    }
-
-    fn vectorization<'a>(
-        handles_inputs: impl Iterator<Item = &'a JitFusionHandle<R>>,
-        inputs: impl Iterator<Item = &'a TensorDescription>,
-        outputs: impl Iterator<Item = &'a TensorDescription>,
-    ) -> u8 {
-        let factors = R::supported_line_sizes();
-
-        let vectorization_input = |handle: &JitFusionHandle<R>, desc: &TensorDescription| {
-            let rank = handle.strides.len();
-
-            // Last dimension strides should be 1, otherwise vecX won't be contiguous.
-            if handle.strides[rank - 1] != 1 {
-                return 1;
-            }
-
-            for s in factors {
-                // The last dimension should be a multiple of the vector size.
-                if desc.shape[rank - 1] % *s as usize == 0 {
-                    return *s;
-                }
-            }
-
-            1
-        };
-
-        let vectorization_output = |desc: &TensorDescription| {
-            let rank = desc.shape.len();
-
-            for s in factors {
-                // The last dimension should be a multiple of the vector size.
-                if desc.shape[rank - 1] % *s as usize == 0 {
-                    return *s;
-                }
-            }
-
-            1
+            elemwise_fuse::launch_unchecked(
+                client,
+                cube_count,
+                cube_dim,
+                inputs,
+                outputs,
+                config.clone(),
+            );
         };
 
-        let mut output = u8::MAX;
-
-        for (handle, tensor) in handles_inputs.zip(inputs) {
-            output = Ord::min(vectorization_input(handle, tensor), output);
-        }
-
-        for tensor in outputs {
-            output = Ord::min(vectorization_output(tensor), output);
-        }
-
-        output
+        Ok(())
     }
 }
diff --git a/crates/burn-jit/src/fusion/matmul/args.rs b/crates/burn-jit/src/fusion/matmul/args.rs
new file mode 100644
index 0000000000..aa2eb6cbf0
--- /dev/null
+++ b/crates/burn-jit/src/fusion/matmul/args.rs
@@ -0,0 +1,250 @@
+use cubecl::{linalg::matmul::components::global::args::MatmulArgs, prelude::*};
+
+use crate::fusion::on_write::{
+    io::{global_rank, global_shape, global_stride, read_input},
+    ir::{Arg, ElemwiseConfig, GlobalArgs, GlobalArgsExpand, LayoutInfo},
+    kernel::fuse_on_write,
+};
+
+#[derive(Clone)]
+pub struct FusedMatmulArgs;
+
+#[derive(CubeLaunch)]
+pub struct FusedMatmulInput {
+    global: GlobalArgs,
+    #[cube(comptime)]
+    config: ElemwiseConfig,
+    #[cube(comptime)]
+    lhs: Arg,
+    #[cube(comptime)]
+    rhs: Arg,
+    #[cube(comptime)]
+    out: Arg,
+}
+
+#[cube]
+impl<EG: Numeric> MatmulArgs<EG> for FusedMatmulArgs {
+    type Output = GlobalArgs;
+    type Input = FusedMatmulInput;
+    type State = FusedMatmulState;
+
+    fn init_state(inputs: &Self::Input, outputs: &mut Self::Output) -> Self::State {
+        FusedMatmulState::new(inputs, outputs, &inputs.config)
+    }
+
+    fn read_lhs(state: &Self::State, coordinate: u32) -> Line<EG> {
+        let (pos, precision) = comptime!
{
+            match state.lhs {
+                Arg::Input(pos, precision, _) => (pos, precision),
+                _ => panic!("Lhs isn't an input"),
+            }
+        };
+
+        read_input(
+            unsafe { &(*state.inputs) },
+            unsafe { &(*state.outputs) },
+            pos,
+            coordinate,
+            LayoutInfo::IsRef,
+            precision,
+            &state.config,
+        )
+    }
+
+    fn read_rhs(state: &Self::State, coordinate: u32) -> Line<EG> {
+        let (pos, precision) = comptime! {
+            match state.rhs {
+                Arg::Input(pos, precision, _) => (pos, precision),
+                _ => panic!("Rhs isn't an input"),
+            }
+        };
+
+        read_input(
+            unsafe { &(*state.inputs) },
+            unsafe { &(*state.outputs) },
+            pos,
+            coordinate,
+            LayoutInfo::IsRef,
+            precision,
+            &state.config,
+        )
+    }
+
+    fn write_out(state: &mut Self::State, coordinate: u32, value: Line<EG>) {
+        let mut values = Registry::<Arg, Line<EG>>::new();
+        let mut args = comptime![Sequence::<Arg>::new()];
+
+        values.insert(state.out, value);
+        comptime![args.push(state.out)];
+
+        fuse_on_write(
+            unsafe { &(*state.inputs) },
+            unsafe { &mut (*state.outputs) },
+            coordinate,
+            values,
+            args,
+            &state.config,
+        );
+    }
+
+    fn rank_lhs(state: &Self::State) -> u32 {
+        let (pos, precision) = comptime! {
+            match state.lhs {
+                Arg::Input(pos, precision, _) => (pos, precision),
+                _ => panic!("Lhs isn't an input"),
+            }
+        };
+
+        global_rank(unsafe { &(*state.inputs) }, pos, precision)
+    }
+
+    fn rank_rhs(state: &Self::State) -> u32 {
+        let (pos, precision) = comptime! {
+            match state.rhs {
+                Arg::Input(pos, precision, _) => (pos, precision),
+                _ => panic!("Rhs isn't an input"),
+            }
+        };
+
+        global_rank(unsafe { &(*state.inputs) }, pos, precision)
+    }
+
+    fn rank_out(state: &Self::State) -> u32 {
+        let (pos, precision, is_input) = comptime! {
+            match state.config.ref_layout {
+                Arg::Input(pos, precision, _) => (pos, precision, true),
+                Arg::Output(pos, precision, _) => (pos, precision, false),
+                _ => panic!("Out isn't an input or output"),
+            }
+        };
+
+        if is_input {
+            global_rank(unsafe { &(*state.inputs) }, pos, precision)
+        } else {
+            global_rank(unsafe { &(*state.outputs) }, pos, precision)
+        }
+    }
+
+    fn shape_lhs(state: &Self::State, dim: u32) -> u32 {
+        let (pos, precision) = comptime! {
+            match state.lhs {
+                Arg::Input(pos, precision, _) => (pos, precision),
+                _ => panic!("Lhs isn't an input"),
+            }
+        };
+
+        global_shape(unsafe { &(*state.inputs) }, dim, pos, precision)
+    }
+
+    fn shape_rhs(state: &Self::State, dim: u32) -> u32 {
+        let (pos, precision) = comptime! {
+            match state.rhs {
+                Arg::Input(pos, precision, _) => (pos, precision),
+                _ => panic!("Rhs isn't an input"),
+            }
+        };
+
+        global_shape(unsafe { &(*state.inputs) }, dim, pos, precision)
+    }
+
+    fn shape_out(state: &Self::State, dim: u32) -> u32 {
+        let (pos, precision, is_input) = comptime! {
+            match state.config.ref_layout {
+                Arg::Input(pos, precision, _) => (pos, precision, true),
+                Arg::Output(pos, precision, _) => (pos, precision, false),
+                _ => panic!("Out isn't an input or output"),
+            }
+        };
+
+        if is_input {
+            global_shape(unsafe { &(*state.inputs) }, dim, pos, precision)
+        } else {
+            global_shape(unsafe { &(*state.outputs) }, dim, pos, precision)
+        }
+    }
+
+    fn stride_lhs(state: &Self::State, dim: u32) -> u32 {
+        let (pos, precision) = comptime! {
+            match state.lhs {
+                Arg::Input(pos, precision, _) => (pos, precision),
+                _ => panic!("Lhs isn't an input"),
+            }
+        };
+
+        global_stride(unsafe { &(*state.inputs) }, dim, pos, precision)
+    }
+
+    fn stride_rhs(state: &Self::State, dim: u32) -> u32 {
+        let (pos, precision) = comptime!
{
+            match state.rhs {
+                Arg::Input(pos, precision, _) => (pos, precision),
+                _ => panic!("Rhs isn't an input"),
+            }
+        };
+
+        global_stride(unsafe { &(*state.inputs) }, dim, pos, precision)
+    }
+
+    fn stride_out(state: &Self::State, dim: u32) -> u32 {
+        let (pos, precision, is_input) = comptime! {
+            match state.config.ref_layout {
+                Arg::Input(pos, precision, _) => (pos, precision, true),
+                Arg::Output(pos, precision, _) => (pos, precision, false),
+                _ => panic!("Out isn't an input or output"),
+            }
+        };
+
+        if is_input {
+            global_stride(unsafe { &(*state.inputs) }, dim, pos, precision)
+        } else {
+            global_stride(unsafe { &(*state.outputs) }, dim, pos, precision)
+        }
+    }
+}
+
+pub struct FusedMatmulState {
+    inputs: *const GlobalArgs,
+    outputs: *mut GlobalArgs,
+    config: ElemwiseConfig,
+    lhs: Arg,
+    rhs: Arg,
+    out: Arg,
+}
+
+#[cube]
+impl FusedMatmulState {
+    pub fn new(
+        inputs: &FusedMatmulInput,
+        outputs: &mut GlobalArgs,
+        #[comptime] config: &ElemwiseConfig,
+    ) -> FusedMatmulState {
+        FusedMatmulState {
+            inputs: &inputs.global,
+            outputs,
+            config: comptime![config.clone()],
+            lhs: comptime![inputs.lhs],
+            rhs: comptime![inputs.rhs],
+            out: comptime![inputs.out],
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct FusedMatmulStateExpand {
+    inputs: GlobalArgsExpand,
+    outputs: GlobalArgsExpand,
+    config: ElemwiseConfig,
+    lhs: Arg,
+    rhs: Arg,
+    out: Arg,
+}
+
+impl CubeType for FusedMatmulState {
+    type ExpandType = FusedMatmulStateExpand;
+}
+
+impl Init for FusedMatmulStateExpand {
+    fn init(self, _context: &mut CubeContext) -> Self {
+        self
+    }
+}
diff --git a/crates/burn-jit/src/fusion/matmul/builder.rs b/crates/burn-jit/src/fusion/matmul/builder.rs
new file mode 100644
index 0000000000..986332914f
--- /dev/null
+++ b/crates/burn-jit/src/fusion/matmul/builder.rs
@@ -0,0 +1,97 @@
+use burn_fusion::{OptimizationBuilder, OptimizationStatus};
+use burn_tensor::repr::{FloatOperationDescription, OperationDescription};
+
+use crate::{
+    fusion::{
+        on_write::{builder::FuseOnWriteBuilder, ir::ElemwisePrecision},
+        JitOptimization,
+    },
+    JitRuntime,
+};
+
+use super::optimization::{FusedMatmul, MatmulOptimization};
+
+/// Fuses a matrix multiplication with the element wise operations that follow it.
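// A note on the design, inferred from this file and `optimization.rs` below:
// every operation registered after the matmul is mirrored into two traces.
// `builder` produces the fused matmul + element wise trace, while
// `builder_fallback` holds the element wise operations alone, so that they
// can run after a plain matmul kernel whenever the fused launch fails.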
+pub(crate) struct MatmulBuilder { + builder: FuseOnWriteBuilder, + builder_fallback: FuseOnWriteBuilder, + device: R::Device, + matmul: Option, +} + +impl MatmulBuilder { + pub fn new(device: R::Device, bool_precision: ElemwisePrecision) -> Self { + let client = R::client(&device); + let props = client.properties(); + let max_bindings = props.hardware_properties().max_bindings; + + Self { + builder: FuseOnWriteBuilder::new(max_bindings, bool_precision), + builder_fallback: FuseOnWriteBuilder::new(max_bindings, bool_precision), + device, + matmul: None, + } + } +} + +impl OptimizationBuilder> for MatmulBuilder { + fn register(&mut self, operation: &OperationDescription) { + if let OptimizationStatus::Closed = self.builder.status() { + return; + } + + if self.matmul.is_none() { + if let OperationDescription::Float(_, FloatOperationDescription::Matmul(op)) = operation + { + let lhs = self.builder.input_unhandled(&op.lhs); + let rhs = self.builder.input_unhandled(&op.rhs); + let out = self.builder.output_unhandled(&op.out); + + self.matmul = Some(FusedMatmul::new(lhs, rhs, out, op.clone())); + } else { + self.builder.close(); + } + } else { + self.builder.register(operation); + self.builder_fallback.register(operation); + } + } + + fn build(&self) -> JitOptimization { + let client = R::client(&self.device); + let trace = self.builder.build(); + let trace_fallback = self.builder_fallback.build(); + + let matmul = MatmulOptimization::::new( + trace, + trace_fallback, + client, + self.device.clone(), + self.len(), + self.matmul.as_ref().unwrap().clone(), + ); + + JitOptimization::Matmul(matmul) + } + + fn reset(&mut self) { + self.builder.reset(); + self.builder_fallback.reset(); + self.matmul = None; + } + + fn status(&self) -> burn_fusion::OptimizationStatus { + self.builder.status() + } + + fn properties(&self) -> burn_fusion::OptimizationProperties { + let mut properties = self.builder.properties(); + properties.score += 1; + properties + } + + fn len(&self) -> usize { + // Matmul operation isn't registered in the builder + self.builder.len() + 1 + } +} diff --git a/crates/burn-jit/src/fusion/matmul/mod.rs b/crates/burn-jit/src/fusion/matmul/mod.rs new file mode 100644 index 0000000000..1afeef9c88 --- /dev/null +++ b/crates/burn-jit/src/fusion/matmul/mod.rs @@ -0,0 +1,4 @@ +pub(crate) mod args; +pub(crate) mod builder; +pub(crate) mod optimization; +pub(crate) mod spec; diff --git a/crates/burn-jit/src/fusion/matmul/optimization.rs b/crates/burn-jit/src/fusion/matmul/optimization.rs new file mode 100644 index 0000000000..66c372564a --- /dev/null +++ b/crates/burn-jit/src/fusion/matmul/optimization.rs @@ -0,0 +1,304 @@ +use crate::fusion::elemwise::optimization::ElemwiseRunner; +use crate::fusion::on_write::ir::ElemwisePrecision; +use crate::kernel::matmul; +use crate::{fusion::JitFusionHandle, JitRuntime}; +use crate::{BoolElement, FloatElement}; + +use burn_fusion::stream::Context; +use burn_tensor::repr::{BinaryOperationDescription, TensorStatus}; +use burn_tensor::Shape; +use cubecl::linalg::matmul::components; +use cubecl::linalg::matmul::components::MatmulProblem; +use cubecl::linalg::matmul::kernels::matmul::{CmmaSelector, PlaneMmaSelector}; +use cubecl::linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}; +use cubecl::linalg::tensor::{matrix_layout, MatrixLayout}; +use cubecl::{client::ComputeClient, prelude::*}; +use half::{bf16, f16}; +use serde::{Deserialize, Serialize}; +use std::any::TypeId; + +use crate::fusion::on_write::{ + ir::{Arg, ElemwiseConfig, 
GlobalArgsLaunch}, + trace::{FuseOnWriteTrace, TraceRunner}, +}; + +use super::args::FusedMatmulInputLaunch; +use super::spec::FusedMatmulSpec; + +#[derive(new)] +/// Fuse matmul operation followed by elemwise operations into a single kernel. +pub struct MatmulOptimization { + trace: FuseOnWriteTrace, + trace_fallback: FuseOnWriteTrace, + client: ComputeClient, + device: R::Device, + len: usize, + matmul: FusedMatmul, +} + +#[derive(Serialize, Deserialize, Debug)] +/// State for the [matrix optimization](MatmulOptimizationState). +pub struct MatmulOptimizationState { + trace: FuseOnWriteTrace, + trace_fallback: FuseOnWriteTrace, + matmul: FusedMatmul, + len: usize, +} + +impl MatmulOptimization { + /// Execute the optimization. + pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { + if self.execute_fused::(context).is_err() { + self.execute_fallback::(context); + } + } + + /// Number of operations fused. + pub fn num_ops_fused(&self) -> usize { + self.len + } + + /// Create an optimization from its [state](MatmulOptimizationState). + pub fn from_state(device: &R::Device, state: MatmulOptimizationState) -> Self { + Self { + trace: state.trace, + trace_fallback: state.trace_fallback, + len: state.len, + client: R::client(device), + device: device.clone(), + matmul: state.matmul.clone(), + } + } + + /// Convert the optimization to its [state](MatmulOptimizationState). + pub fn to_state(&self) -> MatmulOptimizationState { + MatmulOptimizationState { + trace: self.trace.clone(), + trace_fallback: self.trace_fallback.clone(), + matmul: self.matmul.clone(), + len: self.len, + } + } + + fn execute_fused( + &mut self, + context: &mut Context<'_, JitFusionHandle>, + ) -> Result<(), FusedMatmulError> { + self.trace + .run::(&self.client, &self.device, context, &self.matmul) + } + + fn execute_fallback(&mut self, context: &mut Context<'_, JitFusionHandle>) { + match self.matmul.lhs.precision() { + ElemwisePrecision::F32 => self.run_fallback::(context), + ElemwisePrecision::F16 => self.run_fallback::(context), + ElemwisePrecision::BF16 => self.run_fallback::(context), + _ => panic!("Unsupported precision"), + } + } + + fn run_fallback( + &mut self, + context: &mut Context<'_, JitFusionHandle>, + ) { + let (out_tensor, out_desc) = { + let lhs = context.tensors.get(&self.matmul.op.lhs.id).unwrap().clone(); + let rhs = context.tensors.get(&self.matmul.op.rhs.id).unwrap().clone(); + let out = context.tensors.get(&self.matmul.op.out.id).unwrap().clone(); + + let lhs_handle = context.handles.get_handle(&lhs.id, &TensorStatus::ReadOnly); + let rhs_handle = context.handles.get_handle(&rhs.id, &TensorStatus::ReadOnly); + + let lhs_tensor = lhs_handle.into_tensor(Shape { + dims: lhs.shape.clone(), + }); + let rhs_tensor = rhs_handle.into_tensor(Shape { + dims: rhs.shape.clone(), + }); + let out_tensor = matmul::matmul::( + lhs_tensor, + rhs_tensor, + None, + matmul::MatmulStrategy::default(), + ); + (out_tensor, out) + }; + context + .handles + .register_handle(out_desc.id, JitFusionHandle::from(out_tensor)); + + self.trace_fallback + .run::(&self.client, &self.device, context, &ElemwiseRunner) + .unwrap(); + } +} + +#[derive(new, Clone, Serialize, Deserialize, Debug)] +pub struct FusedMatmul { + lhs: Arg, + rhs: Arg, + out: Arg, + op: BinaryOperationDescription, +} + +#[derive(Debug)] +pub enum FusedMatmulError { + LaunchError(MatmulLaunchError), + InvalidInput, +} + +impl From for FusedMatmulError { + fn from(value: MatmulLaunchError) -> Self { + Self::LaunchError(value) + } +} + +impl 
TraceRunner for FusedMatmul { + type Error = FusedMatmulError; + + fn run<'a>( + &'a self, + client: &'a ComputeClient, + inputs: GlobalArgsLaunch<'a, R>, + outputs: GlobalArgsLaunch<'a, R>, + config: &'a ElemwiseConfig, + ) -> Result<(), FusedMatmulError> { + match self.out.precision() { + ElemwisePrecision::F32 => self.matmul_fused::(client, inputs, outputs, config), + ElemwisePrecision::F16 => self.matmul_fused::(client, inputs, outputs, config), + ElemwisePrecision::BF16 => { + self.matmul_fused::(client, inputs, outputs, config) + } + _ => panic!("Unsupported precision"), + } + } +} + +impl FusedMatmul { + fn matmul_fused<'a, R: JitRuntime, EG: Numeric>( + &'a self, + client: &'a ComputeClient, + inputs: GlobalArgsLaunch<'a, R>, + outputs: GlobalArgsLaunch<'a, R>, + config: &'a ElemwiseConfig, + ) -> Result<(), FusedMatmulError> { + let lhs_shape = inputs.shape(&self.lhs); + let rhs_shape = inputs.shape(&self.rhs); + + let lhs_strides = inputs.strides(&self.lhs); + let rhs_strides = inputs.strides(&self.rhs); + + let check_layout = |strides| match matrix_layout(strides) { + MatrixLayout::Contiguous => (false, false), + MatrixLayout::MildlyPermuted { + transposed, + batch_swap: _, + } => (false, transposed), + MatrixLayout::HighlyPermuted => (true, false), + }; + + let (lhs_make_contiguous, lhs_transposed) = check_layout(lhs_strides); + let (rhs_make_contiguous, rhs_transposed) = check_layout(rhs_strides); + + if lhs_make_contiguous || rhs_make_contiguous { + return Err(FusedMatmulError::InvalidInput); + } + + let rank = lhs_shape.len(); + + let m = lhs_shape[rank - 2] as u32; + let k = lhs_shape[rank - 1] as u32; + let n = rhs_shape[rank - 1] as u32; + + let lhs_line_size = inputs.line_size(&self.lhs); + let rhs_line_size = inputs.line_size(&self.rhs); + let out_line_size = match config.ref_layout { + Arg::Input(..) => inputs.line_size(&config.ref_layout), + Arg::Output(..) 
=> outputs.line_size(&config.ref_layout), + _ => panic!("Invalid ref layout"), + }; + + if out_line_size == 1 && (lhs_line_size > 1 || rhs_line_size > 1) { + return Err(FusedMatmulError::InvalidInput); + } + + let problem = MatmulProblem { + m: m as usize, + n: n as usize, + k: k as usize, + batches: ( + lhs_shape[..lhs_shape.len() - 2].to_vec(), + rhs_shape[..rhs_shape.len() - 2].to_vec(), + ), + lhs_layout: match lhs_transposed { + true => components::MatrixLayout::ColMajor, + false => components::MatrixLayout::RowMajor, + }, + rhs_layout: match rhs_transposed { + true => components::MatrixLayout::ColMajor, + false => components::MatrixLayout::RowMajor, + }, + lhs_line_size, + rhs_line_size, + out_line_size, + }; + + let plane_size = client + .properties() + .hardware_properties() + .defined_plane_size(); + + match plane_size { + Some(32) => matmul_launch_kernel::<32, R, EG>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + false, + problem, + ), + Some(64) => matmul_launch_kernel::<64, R, EG>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + false, + problem, + ), + Some(plane_dim) => Err(MatmulLaunchError::Unavailable( + MatmulAvailabilityError::PlaneDimUnsupported { plane_dim }, + )), + None => Err(MatmulLaunchError::Unavailable( + MatmulAvailabilityError::PlaneDimUnknown, + )), + }?; + + Ok(()) + } +} + +fn matmul_launch_kernel<'a, const PLANE_DIM: u32, R: Runtime, EG: Numeric>( + client: &ComputeClient, + input: FusedMatmulInputLaunch<'a, R>, + output: GlobalArgsLaunch<'a, R>, + disable_cmma: bool, + problem: MatmulProblem, +) -> Result<(), MatmulLaunchError> { + if disable_cmma { + PlaneMmaSelector::select_kernel::, R>( + client, input, output, problem, + ) + } else if TypeId::of::() == TypeId::of::() + || TypeId::of::() == TypeId::of::() + { + CmmaSelector::select_kernel::, R>( + client, input, output, problem, + ) + } else if TypeId::of::() == TypeId::of::() { + CmmaSelector::select_kernel::, R>( + client, input, output, problem, + ) + } else { + CmmaSelector::select_kernel::, R>( + client, input, output, problem, + ) + } +} diff --git a/crates/burn-jit/src/fusion/matmul/spec.rs b/crates/burn-jit/src/fusion/matmul/spec.rs new file mode 100644 index 0000000000..c7e5c910e8 --- /dev/null +++ b/crates/burn-jit/src/fusion/matmul/spec.rs @@ -0,0 +1,22 @@ +use super::args::FusedMatmulArgs; +use cubecl::{linalg::matmul::components::MatmulSpec, prelude::Numeric}; +use std::marker::PhantomData; + +/// Specification for a fused standard matmul. 
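// The spec below glues the fused arguments into cubecl's matmul stack: EG,
// ES and EA are the global, stage and accumulator element types, and `Args`
// injects the fused reads/writes. Illustrative only, assuming the generic
// order `<PLANE_DIM, EG, ES, EA>` used by the selector calls above, an f16
// kernel accumulating in f32 on 32-wide planes would be instantiated as:
//
//     type FusedF16 = FusedMatmulSpec<32, half::f16, half::f16, f32>;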
+#[derive(Clone)] +pub struct FusedMatmulSpec { + _eg: PhantomData, + _es: PhantomData, + _ea: PhantomData, +} + +impl MatmulSpec + for FusedMatmulSpec +{ + const PLANE_DIM: u32 = PLANE_DIM; + + type EG = EG; + type ES = ES; + type EA = EA; + type Args = FusedMatmulArgs; +} diff --git a/crates/burn-jit/src/fusion/mod.rs b/crates/burn-jit/src/fusion/mod.rs index f15f4263c6..4c44770b4e 100644 --- a/crates/burn-jit/src/fusion/mod.rs +++ b/crates/burn-jit/src/fusion/mod.rs @@ -1,6 +1,7 @@ mod base; pub(crate) mod elemwise; +pub(crate) mod matmul; pub(crate) mod on_write; pub use base::*; diff --git a/crates/burn-jit/src/fusion/on_write/builder.rs b/crates/burn-jit/src/fusion/on_write/builder.rs index 287b656274..bf31ef78ea 100644 --- a/crates/burn-jit/src/fusion/on_write/builder.rs +++ b/crates/burn-jit/src/fusion/on_write/builder.rs @@ -147,6 +147,22 @@ impl FuseOnWriteBuilder { } } + pub fn close(&mut self) { + self.status = OptimizationStatus::Closed; + } + + pub fn input_unhandled(&mut self, tensor: &TensorDescription) -> Arg { + self.builder.builder.input_unhandled(tensor) + } + + pub fn output_unhandled(&mut self, tensor: &TensorDescription) -> Arg { + if self.current_output_shape.is_empty() { + self.current_output_shape = tensor.shape.clone(); + } + + self.builder.builder.output_unhandled(tensor) + } + fn register_base(&mut self, ops: &BaseOperationDescription) -> bool { match ops { BaseOperationDescription::Equal(desc) => self diff --git a/crates/burn-jit/src/fusion/on_write/io.rs b/crates/burn-jit/src/fusion/on_write/io.rs index bd1fe3ddf3..497bc510df 100644 --- a/crates/burn-jit/src/fusion/on_write/io.rs +++ b/crates/burn-jit/src/fusion/on_write/io.rs @@ -12,210 +12,12 @@ pub fn read( #[comptime] config: &ElemwiseConfig, ) -> Line { match arg { - Arg::Input(pos, precision, layout) => match comptime![precision] { - ElemwisePrecision::F32 => { - let tensor = inputs.t_f32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::F16 => { - let tensor = inputs.t_f16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::BF16 => { - let tensor = inputs.t_bf16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U64 => { - let tensor = inputs.t_u64.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U32 => { - let tensor = inputs.t_u32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U16 => { - let tensor = inputs.t_u16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; 
- Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U8 => { - let tensor = inputs.t_u8.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I64 => { - let tensor = inputs.t_i64.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I32 => { - let tensor = inputs.t_i32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I16 => { - let tensor = inputs.t_i16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I8 => { - let tensor = inputs.t_i8.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - }, - Arg::Output(pos, precision, layout) => match comptime![precision] { - ElemwisePrecision::F32 => { - let tensor = outputs.t_f32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::F16 => { - let tensor = outputs.t_f16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::BF16 => { - let tensor = outputs.t_bf16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U64 => { - let tensor = outputs.t_u64.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U32 => { - let tensor = outputs.t_u32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U16 => { - let tensor = outputs.t_u16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::U8 => { - let tensor = outputs.t_u8.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - 
LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I64 => { - let tensor = outputs.t_i64.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I32 => { - let tensor = outputs.t_i32.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I16 => { - let tensor = outputs.t_i16.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - ElemwisePrecision::I8 => { - let tensor = outputs.t_i8.index(pos); - let offset = match layout { - LayoutInfo::SameAsRef => ref_pos, - LayoutInfo::IsRef => ref_pos, - LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), - }; - Line::cast_from(tensor[offset]) - } - _ => comptime![panic!("Unsupported precision {precision:?}")], - }, + Arg::Input(pos, precision, layout) => { + read_input(inputs, outputs, pos, ref_pos, layout, precision, config) + } + Arg::Output(pos, precision, layout) => { + read_output(inputs, outputs, pos, ref_pos, layout, precision, config) + } Arg::Local(pos, precision) => match comptime![precision] { ElemwisePrecision::F32 => Line::cast_from(locals.l_f32.find(pos)), ElemwisePrecision::F16 => Line::cast_from(locals.l_f16.find(pos)), @@ -248,6 +50,234 @@ pub fn read( } } +#[cube] +pub fn read_input( + inputs: &GlobalArgs, + outputs: &GlobalArgs, + #[comptime] pos: u32, + ref_pos: u32, + #[comptime] layout: LayoutInfo, + #[comptime] precision: ElemwisePrecision, + #[comptime] config: &ElemwiseConfig, +) -> Line { + match comptime![precision] { + ElemwisePrecision::F32 => { + let tensor = inputs.t_f32.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::F16 => { + let tensor = inputs.t_f16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::BF16 => { + let tensor = inputs.t_bf16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::U64 => { + let tensor = inputs.t_u64.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::U32 => { + let tensor = inputs.t_u32.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + 
ElemwisePrecision::U16 => { + let tensor = inputs.t_u16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::U8 => { + let tensor = inputs.t_u8.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I64 => { + let tensor = inputs.t_i64.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I32 => { + let tensor = inputs.t_i32.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I16 => { + let tensor = inputs.t_i16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I8 => { + let tensor = inputs.t_i8.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + _ => comptime![panic!("Unsupported precision {precision:?}")], + } +} + +#[cube] +pub fn read_output( + inputs: &GlobalArgs, + outputs: &GlobalArgs, + pos: u32, + ref_pos: u32, + #[comptime] layout: LayoutInfo, + #[comptime] precision: ElemwisePrecision, + #[comptime] config: &ElemwiseConfig, +) -> Line { + match comptime![precision] { + ElemwisePrecision::F32 => { + let tensor = outputs.t_f32.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::F16 => { + let tensor = outputs.t_f16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::BF16 => { + let tensor = outputs.t_bf16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::U64 => { + let tensor = outputs.t_u64.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::U32 => { + let tensor = outputs.t_u32.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + 
ElemwisePrecision::U16 => { + let tensor = outputs.t_u16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::U8 => { + let tensor = outputs.t_u8.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I64 => { + let tensor = outputs.t_i64.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I32 => { + let tensor = outputs.t_i32.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I16 => { + let tensor = outputs.t_i16.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + ElemwisePrecision::I8 => { + let tensor = outputs.t_i8.index(pos); + let offset = match layout { + LayoutInfo::SameAsRef => ref_pos, + LayoutInfo::IsRef => ref_pos, + LayoutInfo::Unknown => get_offset(inputs, outputs, tensor, ref_pos, config), + }; + Line::cast_from(tensor[offset]) + } + _ => comptime![panic!("Unsupported precision {precision:?}")], + } +} + #[cube] /// Write the given value at the [arg](Arg) position. 
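// The former monolithic `read` now delegates to the standalone `read_input`
// and `read_output` helpers above, which lets other fused kernels reuse them
// directly: the fused matmul arguments call `read_input` to feed lhs/rhs
// lines into the matmul (see `fusion/matmul/args.rs` earlier in this patch).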
pub fn write( @@ -497,3 +527,168 @@ fn get_offset( _ => comptime![panic!("Invalid ref layout.")], } } +#[cube] +pub fn global_rank( + global: &GlobalArgs, + #[comptime] pos: u32, + #[comptime] precision: ElemwisePrecision, +) -> u32 { + match comptime![precision] { + ElemwisePrecision::F32 => { + let tensor = global.t_f32.index(pos); + tensor.rank() + } + ElemwisePrecision::F16 => { + let tensor = global.t_f16.index(pos); + tensor.rank() + } + ElemwisePrecision::BF16 => { + let tensor = global.t_bf16.index(pos); + tensor.rank() + } + ElemwisePrecision::U64 => { + let tensor = global.t_u64.index(pos); + tensor.rank() + } + ElemwisePrecision::U32 => { + let tensor = global.t_u32.index(pos); + tensor.rank() + } + ElemwisePrecision::U16 => { + let tensor = global.t_u16.index(pos); + tensor.rank() + } + ElemwisePrecision::U8 => { + let tensor = global.t_u8.index(pos); + tensor.rank() + } + ElemwisePrecision::I64 => { + let tensor = global.t_i64.index(pos); + tensor.rank() + } + ElemwisePrecision::I32 => { + let tensor = global.t_i32.index(pos); + tensor.rank() + } + ElemwisePrecision::I16 => { + let tensor = global.t_i16.index(pos); + tensor.rank() + } + ElemwisePrecision::I8 => { + let tensor = global.t_i8.index(pos); + tensor.rank() + } + _ => comptime![panic!("Unsupported precision {precision:?}")], + } +} +#[cube] +pub fn global_shape( + global: &GlobalArgs, + dim: u32, + #[comptime] pos: u32, + #[comptime] precision: ElemwisePrecision, +) -> u32 { + match comptime![precision] { + ElemwisePrecision::F32 => { + let tensor = global.t_f32.index(pos); + tensor.shape(dim) + } + ElemwisePrecision::F16 => { + let tensor = global.t_f16.index(pos); + tensor.shape(dim) + } + ElemwisePrecision::BF16 => { + let tensor = global.t_bf16.index(pos); + tensor.shape(dim) + } + ElemwisePrecision::U64 => { + let tensor = global.t_u64.index(pos); + tensor.shape(dim) + } + ElemwisePrecision::U32 => { + let tensor = global.t_u32.index(pos); + tensor.shape(dim) + } + ElemwisePrecision::U16 => { + let tensor = global.t_u16.index(pos); + tensor.shape(dim) + } + ElemwisePrecision::U8 => { + let tensor = global.t_u8.index(pos); + tensor.shape(dim) + } + ElemwisePrecision::I64 => { + let tensor = global.t_i64.index(pos); + tensor.shape(dim) + } + ElemwisePrecision::I32 => { + let tensor = global.t_i32.index(pos); + tensor.shape(dim) + } + ElemwisePrecision::I16 => { + let tensor = global.t_i16.index(pos); + tensor.shape(dim) + } + ElemwisePrecision::I8 => { + let tensor = global.t_i8.index(pos); + tensor.shape(dim) + } + _ => comptime![panic!("Unsupported precision {precision:?}")], + } +} + +#[cube] +pub fn global_stride( + global: &GlobalArgs, + dim: u32, + #[comptime] pos: u32, + #[comptime] precision: ElemwisePrecision, +) -> u32 { + match comptime![precision] { + ElemwisePrecision::F32 => { + let tensor = global.t_f32.index(pos); + tensor.stride(dim) + } + ElemwisePrecision::F16 => { + let tensor = global.t_f16.index(pos); + tensor.stride(dim) + } + ElemwisePrecision::BF16 => { + let tensor = global.t_bf16.index(pos); + tensor.stride(dim) + } + ElemwisePrecision::U64 => { + let tensor = global.t_u64.index(pos); + tensor.stride(dim) + } + ElemwisePrecision::U32 => { + let tensor = global.t_u32.index(pos); + tensor.stride(dim) + } + ElemwisePrecision::U16 => { + let tensor = global.t_u16.index(pos); + tensor.stride(dim) + } + ElemwisePrecision::U8 => { + let tensor = global.t_u8.index(pos); + tensor.stride(dim) + } + ElemwisePrecision::I64 => { + let tensor = global.t_i64.index(pos); + tensor.stride(dim) + } + 
ElemwisePrecision::I32 => { + let tensor = global.t_i32.index(pos); + tensor.stride(dim) + } + ElemwisePrecision::I16 => { + let tensor = global.t_i16.index(pos); + tensor.stride(dim) + } + ElemwisePrecision::I8 => { + let tensor = global.t_i8.index(pos); + tensor.stride(dim) + } + _ => comptime![panic!("Unsupported precision {precision:?}")], + } +} diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index eb99d38323..9e83ba1c37 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -4,9 +4,7 @@ use cubecl::prelude::*; use half::{bf16, f16}; use serde::{Deserialize, Serialize}; -#[derive( - CubeType, Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord, -)] +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, PartialOrd, Ord)] /// Argument to an [elemwise operation](ElemwiseOp). pub enum Arg { Input(u32, ElemwisePrecision, LayoutInfo), @@ -42,6 +40,22 @@ impl Arg { } } +impl CubeType for Arg { + type ExpandType = Self; +} + +impl Init for Arg { + fn init(self, _context: &mut CubeContext) -> Self { + self + } +} + +impl IntoRuntime for Arg { + fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType { + self + } +} + #[derive(CubeType, Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)] /// Operations that can be executed and fused. pub enum ElemwiseOp { @@ -100,6 +114,86 @@ pub struct GlobalArgs { pub s_u8: Sequence, } +impl GlobalArgsLaunch<'_, R> { + /// Get the shape of the given [argument](Arg). + /// + /// # Panics + /// + /// If the argument doesn't have an handle. + pub fn shape(&self, arg: &Arg) -> &[usize] { + match self.resolve_arg(arg) { + TensorArg::Handle { handle, .. } => handle.shape, + TensorArg::Alias { .. } => panic!("Unsupported yet"), + } + } + + /// Get the strides of the given [argument](Arg). + /// + /// # Panics + /// + /// If the argument doesn't have an handle. + pub fn strides(&self, arg: &Arg) -> &[usize] { + match self.resolve_arg(arg) { + TensorArg::Handle { handle, .. } => handle.strides, + TensorArg::Alias { .. } => panic!("Unsupported yet"), + } + } + + /// Get the line size of the given [argument](Arg). + /// + /// # Panics + /// + /// If the argument doesn't have an handle. + pub fn line_size(&self, arg: &Arg) -> u8 { + match self.resolve_arg(arg) { + TensorArg::Handle { + vectorization_factor, + .. + } => *vectorization_factor, + TensorArg::Alias { .. } => panic!("Unsupported yet"), + } + } + + /// Resolve the [argument](Arg) to a [tensor arguemnt](TensorArg). + /// + /// # Panics + /// + /// If the argument isn't a global input or output tensor. 
+ pub fn resolve_arg(&self, arg: &Arg) -> &TensorArg<'_, R> { + match arg { + Arg::Input(pos, precision, _) => match precision { + ElemwisePrecision::F32 => &self.t_f32.values[*pos as usize], + ElemwisePrecision::F16 => &self.t_f16.values[*pos as usize], + ElemwisePrecision::BF16 => &self.t_bf16.values[*pos as usize], + ElemwisePrecision::I64 => &self.t_i64.values[*pos as usize], + ElemwisePrecision::I32 => &self.t_i32.values[*pos as usize], + ElemwisePrecision::I16 => &self.t_i16.values[*pos as usize], + ElemwisePrecision::I8 => &self.t_i8.values[*pos as usize], + ElemwisePrecision::U64 => &self.t_u64.values[*pos as usize], + ElemwisePrecision::U32 => &self.t_u32.values[*pos as usize], + ElemwisePrecision::U16 => &self.t_u16.values[*pos as usize], + ElemwisePrecision::U8 => &self.t_u8.values[*pos as usize], + ElemwisePrecision::Bool => panic!("Unsupported yet"), + }, + Arg::Output(pos, precision, _) => match precision { + ElemwisePrecision::F32 => &self.t_f32.values[*pos as usize], + ElemwisePrecision::F16 => &self.t_f16.values[*pos as usize], + ElemwisePrecision::BF16 => &self.t_bf16.values[*pos as usize], + ElemwisePrecision::I64 => &self.t_i64.values[*pos as usize], + ElemwisePrecision::I32 => &self.t_i32.values[*pos as usize], + ElemwisePrecision::I16 => &self.t_i16.values[*pos as usize], + ElemwisePrecision::I8 => &self.t_i8.values[*pos as usize], + ElemwisePrecision::U64 => &self.t_u64.values[*pos as usize], + ElemwisePrecision::U32 => &self.t_u32.values[*pos as usize], + ElemwisePrecision::U16 => &self.t_u16.values[*pos as usize], + ElemwisePrecision::U8 => &self.t_u8.values[*pos as usize], + ElemwisePrecision::Bool => panic!("Unsupported yet"), + }, + _ => panic!("Only input & output can have a shape"), + } + } +} + #[derive(CubeType, Clone)] /// Keep track of all local variables that are used as argument in fused /// [element wise operations](ElemwiseOp). diff --git a/crates/burn-jit/src/fusion/on_write/trace.rs b/crates/burn-jit/src/fusion/on_write/trace.rs index d9ec09aea8..2c29d05ce8 100644 --- a/crates/burn-jit/src/fusion/on_write/trace.rs +++ b/crates/burn-jit/src/fusion/on_write/trace.rs @@ -13,7 +13,7 @@ use cubecl::{ir::Elem, prelude::*}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; -#[derive(new, Clone, Serialize, Deserialize)] +#[derive(new, Clone, Serialize, Deserialize, Debug)] /// Trace containing all element wise operations as well as reads and writes. pub struct FuseOnWriteTrace { outputs: RegisteredTensors, @@ -22,33 +22,86 @@ pub struct FuseOnWriteTrace { ops: Vec, reads: BTreeMap, writes: BTreeMap, + inputs_unhandled: Vec, } /// A trace runner is responsible for determining the vectorization factor as well as launching /// a kernel based on global [inputs](GlobalArgsLaunch) and [outputs](GlobalArgsLaunch) /// with a provided [element wise config](ElemwiseConfig). pub trait TraceRunner { + /// The error that might happen while running the trace. + type Error; + /// Run the trace. fn run<'a>( - client: &ComputeClient, + &'a self, + client: &'a ComputeClient, inputs: GlobalArgsLaunch<'a, R>, outputs: GlobalArgsLaunch<'a, R>, - config: ElemwiseConfig, - ); + config: &'a ElemwiseConfig, + ) -> Result<(), Self::Error>; + /// The vectorization factor for all inputs and outputs. fn vectorization<'a>( handles_inputs: impl Iterator>, inputs: impl Iterator, outputs: impl Iterator, - ) -> u8; + ) -> u8 { + // The default version uses the last dimension as vectorization axis and assumes a + // perpendicular contiguous line. 
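// Worked example, assuming supported line sizes of [8, 4, 2, 1]: a contiguous
// f32 input of shape [32, 10] gets a factor of 2, since 10 % 8 != 0 and
// 10 % 4 != 0 but 10 % 2 == 0, while any input whose last-dimension stride
// isn't 1 is forced to a factor of 1. The kernel then uses the minimum factor
// over all inputs and outputs, as computed below.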
+ + let vectorization_input = |handle: &JitFusionHandle, desc: &TensorDescription| { + let rank = handle.strides.len(); + + // Last dimension strides should be 1, otherwise vecX won't be contiguous. + if handle.strides[rank - 1] != 1 { + return 1; + } + + for s in R::line_size_elem(&desc.dtype.into()) { + // The last dimension should be a multiple of the vector size. + if desc.shape[rank - 1] % s as usize == 0 { + return s; + } + } + + 1 + }; + + let vectorization_output = |desc: &TensorDescription| { + let rank = desc.shape.len(); + + for s in R::line_size_elem(&desc.dtype.into()) { + // The last dimension should be a multiple of the vector size. + if desc.shape[rank - 1] % s as usize == 0 { + return s; + } + } + + 1 + }; + + let mut output = u8::MAX; + + for (handle, tensor) in handles_inputs.zip(inputs) { + output = Ord::min(vectorization_input(handle, tensor), output); + } + + for tensor in outputs { + output = Ord::min(vectorization_output(tensor), output); + } + + output + } } -struct LaunchAnalysis<'a, 'c, R: JitRuntime> { +#[derive(Debug)] +struct LaunchAnalysis<'a, R: JitRuntime> { potential_inplaces: Vec>, - global_inputs: Vec<&'c TensorDescription>, - global_outputs: Vec<&'c TensorDescription>, - handle_inputs: Vec>, - handle_outputs: Vec>, + global_inputs: Vec, + global_outputs: Vec, + handle_inputs: Vec>, + handle_outputs: Vec>, reference: Option, reads: BTreeMap, writes: BTreeMap, @@ -57,31 +110,36 @@ struct LaunchAnalysis<'a, 'c, R: JitRuntime> { } #[derive(Debug)] -enum HandleOutput<'c, R: JitRuntime> { +enum HandleOutput { Alias { input_pos: usize, precision: ElemwisePrecision, }, Owned { + global_id: TensorId, precision: ElemwisePrecision, handle: JitFusionHandle, - global_shape: &'c [usize], + global_shape: Vec, }, } -struct HandleInput<'c, R: JitRuntime> { +#[derive(Debug)] +struct HandleInput { relative_id: TensorId, + global_id: TensorId, precision: ElemwisePrecision, handle: JitFusionHandle, - global_shape: &'c [usize], + global_shape: Vec, } +#[derive(Debug)] struct Reference { layout: Arg, shape: Vec, strides: Vec, } +#[derive(Debug)] struct PotentialInplace<'a> { input_pos: usize, tensor_relative: &'a TensorDescription, @@ -95,7 +153,8 @@ impl FuseOnWriteTrace { client: &ComputeClient, device: &R::Device, context: &mut Context<'_, JitFusionHandle>, - ) { + runner: &Runner, + ) -> Result<(), Runner::Error> { let analysis = self.analyse::(client, device, context); let inputs = self.register_inputs(context, &analysis.handle_inputs, analysis.vectorization); @@ -124,15 +183,42 @@ impl FuseOnWriteTrace { ops, }; - Runner::run(client, inputs, outputs, config) + match Runner::run(runner, client, inputs, outputs, &config) { + Err(err) => { + self.rollback(context, analysis.handle_inputs, analysis.handle_outputs); + Err(err) + } + Ok(val) => Ok(val), + } + } + + fn rollback( + &self, + context: &mut Context<'_, JitFusionHandle>, + handle_inputs: Vec>, + handle_outputs: Vec>, + ) { + for input in handle_inputs { + context + .handles + .register_handle(input.global_id, input.handle); + } + for output in handle_outputs { + if let HandleOutput::Owned { + global_id, handle, .. 
+ } = output + { + context.handles.register_handle(global_id, handle); + } + } } - fn analyse<'a, 'c, R: JitRuntime, BT: BoolElement, Runner: TraceRunner>( + fn analyse<'a, R: JitRuntime, BT: BoolElement, Runner: TraceRunner>( &'a self, client: &ComputeClient, device: &R::Device, - context: &mut Context<'c, JitFusionHandle>, - ) -> LaunchAnalysis<'a, 'c, R> { + context: &mut Context<'_, JitFusionHandle>, + ) -> LaunchAnalysis<'a, R> { let mut analysis = LaunchAnalysis { potential_inplaces: Vec::new(), global_inputs: Vec::new(), @@ -151,27 +237,30 @@ impl FuseOnWriteTrace { analysis.vectorization = Runner::vectorization( analysis.handle_inputs.iter().map(|item| &item.handle), - analysis.global_inputs.iter().copied(), - analysis.global_outputs.iter().copied(), + analysis.global_inputs.iter(), + analysis.global_outputs.iter(), ); analysis } - fn analyse_inputs<'a, 'c, R: JitRuntime>( + fn analyse_inputs<'a, R: JitRuntime>( &'a self, - context: &mut Context<'c, JitFusionHandle>, - analysis: &mut LaunchAnalysis<'a, 'c, R>, + context: &mut Context<'_, JitFusionHandle>, + analysis: &mut LaunchAnalysis<'a, R>, ) { for (i, (precision, tensor_relative)) in self.inputs.iter().enumerate() { - let tensor_global = context.tensors.get(&tensor_relative.id).unwrap(); + let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); // Important to take the status of the relative graph and not // the global graph, since the status of the global graph // might be of a later operation on the same tensor id. let status = &tensor_relative.status; let handle = context.handles.get_handle(&tensor_global.id, status); - if status == &TensorStatus::ReadWrite && handle.handle.can_mut() { + if status == &TensorStatus::ReadWrite + && handle.handle.can_mut() + && !self.inputs_unhandled.contains(&tensor_relative.id) + { analysis.potential_inplaces.push(PotentialInplace { input_pos: i, tensor_relative, @@ -179,28 +268,28 @@ impl FuseOnWriteTrace { }); } - analysis.global_inputs.push(tensor_global); analysis.rank = usize::max(tensor_global.shape.len(), analysis.rank); analysis.handle_inputs.push(HandleInput { precision, handle, relative_id: tensor_relative.id, - global_shape: &tensor_global.shape, + global_id: tensor_global.id, + global_shape: tensor_global.shape.clone(), }); + analysis.global_inputs.push(tensor_global); } } - fn analyse_outputs<'a, 'c, R: JitRuntime, BT: BoolElement>( + fn analyse_outputs<'a, R: JitRuntime, BT: BoolElement>( &'a self, client: &ComputeClient, device: &R::Device, - context: &mut Context<'c, JitFusionHandle>, - analysis: &mut LaunchAnalysis<'a, 'c, R>, + context: &mut Context<'_, JitFusionHandle>, + analysis: &mut LaunchAnalysis<'a, R>, ) { for (precision, tensor_relative) in self.outputs.iter() { - let tensor_global = context.tensors.get(&tensor_relative.id).unwrap(); + let tensor_global = context.tensors.get(&tensor_relative.id).unwrap().clone(); let strides = strides_dyn_rank(&tensor_global.shape); - analysis.global_outputs.push(tensor_global); if let Some(index) = analysis .potential_inplaces @@ -231,14 +320,14 @@ impl FuseOnWriteTrace { strides: handle_input.handle.strides.clone(), }); - if let ElemwiseOp::Assign(op) = - analysis.reads.get_mut(&handle_input.relative_id).unwrap() + if let Some(ElemwiseOp::Assign(op)) = + analysis.reads.get_mut(&handle_input.relative_id) { op.input.add_layout_info(LayoutInfo::IsRef); }; - if let ElemwiseOp::Assign(op) = - analysis.writes.get_mut(&tensor_relative.id).unwrap() + if let Some(ElemwiseOp::Assign(op)) = + 
analysis.writes.get_mut(&tensor_relative.id) { op.out.add_layout_info(LayoutInfo::IsRef); }; @@ -251,6 +340,7 @@ impl FuseOnWriteTrace { input_pos: potential_inplace.input_pos, precision, }); + analysis.global_outputs.push(tensor_global); } else { if analysis.reference.is_none() { analysis.reference = Some(Reference { @@ -297,20 +387,21 @@ impl FuseOnWriteTrace { analysis.handle_outputs.push(HandleOutput::Owned { precision, handle, - global_shape: &tensor_global.shape, + global_shape: tensor_global.shape.clone(), + global_id: tensor_global.id, }); + analysis.global_outputs.push(tensor_global); } } Self::add_layout_info_inputs(analysis); } - fn add_layout_info_inputs(analysis: &mut LaunchAnalysis<'_, '_, R>) { + fn add_layout_info_inputs(analysis: &mut LaunchAnalysis<'_, R>) { for hi in analysis.handle_inputs.iter() { if let Some(reference) = analysis.reference.as_ref() { if reference.strides == hi.handle.strides && reference.shape == hi.global_shape { - if let ElemwiseOp::Assign(op) = analysis.reads.get_mut(&hi.relative_id).unwrap() - { + if let Some(ElemwiseOp::Assign(op)) = analysis.reads.get_mut(&hi.relative_id) { op.input.add_layout_info(LayoutInfo::SameAsRef); } } @@ -318,10 +409,10 @@ impl FuseOnWriteTrace { } } - fn register_inputs<'c, 'h, R: JitRuntime>( + fn register_inputs<'h, R: JitRuntime>( &self, - context: &mut Context<'c, JitFusionHandle>, - handle_inputs: &'h [HandleInput<'c, R>], + context: &mut Context<'_, JitFusionHandle>, + handle_inputs: &'h [HandleInput], vectorization: u8, ) -> GlobalArgsLaunch<'h, R> { let mut inputs = GlobalArgsLaunch::new( @@ -350,7 +441,7 @@ impl FuseOnWriteTrace { ); for hi in handle_inputs.iter() { - let arg = hi.handle.as_tensor_arg(hi.global_shape, vectorization); + let arg = hi.handle.as_tensor_arg(&hi.global_shape, vectorization); match hi.precision { ElemwisePrecision::F32 => inputs.t_f32.push(arg), ElemwisePrecision::F16 => inputs.t_f16.push(arg), @@ -409,7 +500,7 @@ impl FuseOnWriteTrace { fn register_outputs<'s, R: JitRuntime, BT: BoolElement>( &self, - handle_outputs: &'s [HandleOutput<'_, R>], + handle_outputs: &'s [HandleOutput], vectorization: u8, ) -> GlobalArgsLaunch<'s, R> { let mut outputs = GlobalArgsLaunch::new( @@ -459,6 +550,7 @@ impl FuseOnWriteTrace { precision, handle, global_shape, + .. 
} => { let arg = handle.as_tensor_arg(global_shape, vectorization); @@ -488,7 +580,7 @@ impl FuseOnWriteTrace { } } -#[derive(Default, Clone, Serialize, Deserialize)] +#[derive(Default, Clone, Serialize, Deserialize, Debug)] pub struct RegisteredTensors { tensors: BTreeMap>, } diff --git a/crates/burn-jit/src/fusion/on_write/trace_builder.rs b/crates/burn-jit/src/fusion/on_write/trace_builder.rs index 5cb427814d..c37237ae4e 100644 --- a/crates/burn-jit/src/fusion/on_write/trace_builder.rs +++ b/crates/burn-jit/src/fusion/on_write/trace_builder.rs @@ -17,6 +17,8 @@ pub struct FuseOnWriteTraceBuilder { ops: Vec, reads: BTreeMap, pub bool_precision: ElemwisePrecision, + outputs_unhandled: Vec, + inputs_unhandled: Vec, } impl FuseOnWriteTraceBuilder { @@ -29,6 +31,8 @@ impl FuseOnWriteTraceBuilder { ops: Vec::new(), reads: BTreeMap::new(), bool_precision, + outputs_unhandled: Vec::new(), + inputs_unhandled: Vec::new(), } } @@ -48,6 +52,27 @@ impl FuseOnWriteTraceBuilder { meta + inputs + outputs + scalar } + pub fn output_unhandled(&mut self, tensor: &TensorDescription) -> Arg { + let arg = self.output(tensor); + self.outputs_unhandled.push(arg); + arg + } + + pub fn input_unhandled(&mut self, tensor: &TensorDescription) -> Arg { + let precision = tensor.dtype.into(); + + // Bool tensors are encoded as bool_precision. + let precision_input = match precision { + ElemwisePrecision::Bool => self.bool_precision, + _ => precision, + }; + let new_input = self.inputs.insert(precision_input, tensor.clone()); + let arg = Arg::Input(new_input, precision_input, LayoutInfo::Unknown); + + self.inputs_unhandled.push(tensor.id); + arg + } + pub fn input(&mut self, tensor: &TensorDescription) -> Arg { let precision = tensor.dtype.into(); @@ -141,7 +166,15 @@ impl FuseOnWriteTraceBuilder { } // Current problem is that I need btreemap instead of sequences. - FuseOnWriteTrace::new(outputs, inputs, scalars, ops, reads, writes) + FuseOnWriteTrace::new( + outputs, + inputs, + scalars, + ops, + reads, + writes, + self.inputs_unhandled.clone(), + ) } fn output_tensors(&self) -> RegisteredTensors { @@ -309,6 +342,10 @@ impl FuseOnWriteTraceBuilder { mark_op(op); } + for arg in self.outputs_unhandled.iter() { + mark(arg, &mut local_tensor_ids_output); + } + // All output tensors that are never read by a following operation should be written to // since they are essentially the "logical" output of the shader. 
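// Note, inferred from the surrounding code: the matmul output is produced by
// the matmul itself rather than by a registered element wise operation, so it
// is marked above through `outputs_unhandled`; this keeps it in the set that
// the loop below scans, ensuring the shader still registers a write for it.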
         for entry in local_tensor_ids_output {
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs
index 7a210e7c70..c132576534 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs
@@ -4,11 +4,7 @@ use cubecl::{
     linalg::matmul::{
         components::{
             stage::{self, StageSize},
-            tile::{
-                self,
-                accelerated::{Accelerated16x16x16, CmmaValid},
-                Matmul as _,
-            },
+            tile::{self, accelerated::Accelerated16x16x16, Matmul as _},
             MatmulKernel,
         },
         kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
@@ -19,23 +15,17 @@ use cubecl::{
 use super::{
     base::{Convolution, ConvolutionKernel, ConvolutionLaunch, ConvolutionProblem},
     homogeneous::base::ImplicitGemmConvolution,
+    spec::ConvSpec,
 };
 
 /// Specifications for a convolution algorithm
-pub trait Algorithm {
-    const PLANE_DIM: u32;
-
-    type EG: Numeric;
-    type ES: Numeric;
-    type EA: Numeric;
-
-    type TileMatmul: tile::Matmul<Self::ES, Self::EA> + MatmulKernel<Self::ES, Self::EA>;
+pub trait Algorithm<CS: ConvSpec> {
+    type TileMatmul: tile::Matmul<CS::ES, CS::EA> + MatmulKernel<CS::ES, CS::EA>;
     type StageSize: StageSize;
 
-    type StageMatmul: stage::Matmul<Self::ES, Self::EG, Self::EA> + MatmulKernel<Self::ES, Self::EG>;
+    type StageMatmul: stage::Matmul<CS::ES, CS::EG, CS::EA> + MatmulKernel<CS::ES, CS::EG>;
 
-    type GlobalConvolution: Convolution<Self::EG, Self::ES, Self::EA, Self::StageMatmul>
-        + ConvolutionLaunch<Self::EG, Self::EG>;
+    type GlobalConvolution: Convolution<CS, Self::StageMatmul> + ConvolutionLaunch<CS::EG, CS::EG>;
 
     /// Cube dim for launch
     fn cube_dim() -> CubeDim;
@@ -48,7 +38,7 @@ pub trait Algorithm {
         cube_dim: &CubeDim,
         cube_count: &CubeCount,
         advanced_config: &AdvancedConfig,
-    ) -> <Self::GlobalConvolution as ConvolutionKernel<Self::EG, Self::EG>>::Config {
+    ) -> <Self::GlobalConvolution as ConvolutionKernel<CS::EG, CS::EG>>::Config {
         Self::GlobalConvolution::make_config(problem, cube_dim, cube_count, advanced_config)
     }
 
@@ -78,39 +68,21 @@ pub trait Algorithm {
 }
 
 /// Cmma convolution
-pub struct Cmma<EG: Numeric, ES: Numeric, EA: Numeric, Stage: StageSize> {
-    pub _eg: PhantomData<EG>,
-    pub _es: PhantomData<ES>,
-    pub _ea: PhantomData<EA>,
+pub struct Cmma<CS: ConvSpec, Stage: StageSize> {
+    pub _cp: PhantomData<CS>,
     pub _stage: PhantomData<Stage>,
 }
 
-impl<EG: Numeric, ES: Numeric, EA: Numeric, Stage: StageSize> Algorithm
-    for Cmma<EG, ES, EA, Stage>
-where
-    (ES, EA): CmmaValid,
-{
-    const PLANE_DIM: u32 = 32;
-    type EG = EG;
-    type ES = ES;
-    type EA = EA;
-
-    type TileMatmul = Accelerated16x16x16<ES, EA>;
-
+impl<CS: ConvSpec, Stage: StageSize> Algorithm<CS> for Cmma<CS, Stage> {
+    type TileMatmul = Accelerated16x16x16<CS::ES, CS::EA>;
     type StageSize = Stage;
-    type StageMatmul = stage::multi_buffer::Matmul<
-        Self::ES,
-        Self::EG,
-        Self::EA,
-        Self::TileMatmul,
-        Self::StageSize,
-    >;
+    type StageMatmul =
+        stage::multi_buffer::Matmul<CS::ES, CS::EG, CS::EA, Self::TileMatmul, Self::StageSize>;
 
-    type GlobalConvolution =
-        ImplicitGemmConvolution<Self::EG, Self::ES, Self::EA, Self::StageMatmul>;
+    type GlobalConvolution = ImplicitGemmConvolution<CS, Self::StageMatmul>;
 
     fn cube_dim() -> CubeDim {
-        CubeDim::new(Self::PLANE_DIM, Self::StageSize::NUM_M, 1)
+        CubeDim::new(CS::PLANE_DIM, Self::StageSize::NUM_M, 1)
     }
 
     fn cube_count(problem: &ConvolutionProblem) -> CubeCount {
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs
index bc242107f9..22977f293a 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs
@@ -1,24 +1,27 @@
 use burn_tensor::ops::ConvOptions;
-use cubecl::linalg::matmul::{
-    components::{
-        global::{AccumulatorLoader, Unloader},
-        stage, MatmulProblem, MatrixLayout,
+use cubecl::linalg::{
+    matmul::{
+        components::{
+            global::{AccumulatorLoader, Unloader},
+            stage, MatmulProblem, MatrixLayout,
+        },
+        kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
     },
-    kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
+    tensor::{ReadWrite, VirtualTensor},
 };
 use cubecl::prelude::*;
 
-use super::Config;
+use super::{spec::ConvSpec, Config};
 
 #[cube]
-pub trait Convolution<EG: Numeric, ES: Numeric, Acc: Numeric, SMM: stage::Matmul<ES, EG, Acc>>:
-    'static + Send + Sync + ConvolutionKernel<EG, EG>
+pub trait Convolution<CS: ConvSpec, SMM: stage::Matmul<CS::ES, CS::EG, CS::EA>>:
+    'static + Send + Sync + ConvolutionKernel<CS::EG, CS::EG>
 {
     type LhsLoader: CubeType;
     type RhsLoader: CubeType;
-    type AccumulatorLoader: AccumulatorLoader<EG, Acc, SMM::Config>;
+    type AccumulatorLoader: AccumulatorLoader<CS::EG, CS::EA, SMM::Config>;
 
-    type Out: Unloader<EG>;
+    type Out: Unloader<CS::EG>;
     type Accumulator: CubeType;
 
     /// Performs the convolution over data loaded by the
@@ -38,27 +41,31 @@ pub trait Convolution
 
     fn init_lhs_loader(
-        lhs: &Tensor<Line<EG>>,
+        lhs: VirtualTensor<CS::EG>,
         x_offset: u32,
         y_offset: u32,
         #[comptime] config: Self::Config,
     ) -> Self::LhsLoader;
 
     fn init_rhs_loader(
-        rhs: &Tensor<Line<EG>>,
+        rhs: VirtualTensor<CS::EG>,
         x_offset: u32,
         y_offset: u32,
         #[comptime] config: Self::Config,
     ) -> Self::RhsLoader;
 
     fn init_bias_loader(
-        rhs: &Tensor<Line<EG>>,
+        bias: VirtualTensor<CS::EG>,
         n_offset: u32,
         #[comptime] config: Self::Config,
         #[comptime] has_bias: bool,
     ) -> Self::AccumulatorLoader;
 
-    fn init_unloader(out: &mut Tensor<Line<EG>>, x_offset: u32, y_offset: u32) -> Self::Out;
+    fn init_unloader(
+        out: VirtualTensor<CS::EG, ReadWrite>,
+        x_offset: u32,
+        y_offset: u32,
+    ) -> Self::Out;
 
     fn init_accumulator(#[comptime] config: Self::Config) -> Self::Accumulator;
 }
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs
index 3a69960972..a3ccc73296 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs
@@ -1,20 +1,23 @@
 use cubecl::{
-    linalg::matmul::{
-        components::{
-            global::{
-                self,
-                full_load::{self, CyclicLoading, RhsLoader},
-                unloader::Unloader,
-                AccumulatorLoader, Config as _, Loader,
+    linalg::{
+        matmul::{
+            components::{
+                global::{
+                    self,
+                    full_load::{self, CyclicLoading, RhsLoader},
+                    unloader::Unloader,
+                    AccumulatorLoader, Config as _, Loader,
+                },
+                stage::{
+                    self,
+                    multi_buffer::{LhsReader, RhsReader},
+                    TilingOrderConfig,
+                },
+                Ident, MatrixLayout, StageDim,
             },
-            stage::{
-                self,
-                multi_buffer::{LhsReader, RhsReader},
-                TilingOrderConfig,
-            },
-            Ident, MatrixLayout, StageDim,
+            kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
         },
-        kernels::{matmul::AdvancedConfig, MatmulAvailabilityError},
+        tensor::{ReadWrite, VirtualTensor},
     },
     prelude::*,
 };
@@ -23,46 +26,36 @@ use std::marker::PhantomData;
 
 use crate::kernel::conv::{
     conv2d::gemm::base::{Convolution, ConvolutionKernel, ConvolutionLaunch, ConvolutionProblem},
     loader::im2col::SimpleIm2colLoader,
+    spec::ConvSpec,
 };
 use crate::kernel::conv::{conv2d::gemm::Config as _, loader::bias::BiasLoader};
 
 /// Performs matrix multiplication at the global level, with each plane sharing the same responsibilities
 /// - All planes load data to the stage
 /// - All planes are used in the stage matmul computation
-pub struct ImplicitGemmConvolution<
-    EG: Numeric,
-    ES: Numeric,
-    Acc: Numeric,
-    SMM: stage::Matmul<ES, EG, Acc>,
-> {
-    _eg: PhantomData<EG>,
-    _es: PhantomData<ES>,
-    _acc: PhantomData<Acc>,
+pub struct ImplicitGemmConvolution<CS: ConvSpec, SMM: stage::Matmul<CS::ES, CS::EG, CS::EA>> {
+    _cs: PhantomData<CS>,
     _stage_matmul: PhantomData<SMM>,
 }
 
 #[cube]
-impl<EG, ES, Acc, SMM, SMMConf> Convolution<EG, ES, Acc, SMM>
-    for ImplicitGemmConvolution<EG, ES, Acc, SMM>
+impl<CS: ConvSpec, SMM, SMMConf> Convolution<CS, SMM> for ImplicitGemmConvolution<CS, SMM>
 where
-    EG: Numeric,
-    ES: Numeric,
-    Acc: Numeric,
     SMMConf: stage::Config,
     SMM: stage::Matmul<
-        ES,
-        EG,
-        Acc,
-        LhsReader = LhsReader<ES>,
-        RhsReader = RhsReader<ES>,
+        CS::ES,
+        CS::EG,
+        CS::EA,
+        LhsReader = LhsReader<CS::ES>,
+        RhsReader = RhsReader<CS::ES>,
         Config = SMMConf,
     >,
 {
-    type LhsLoader = SimpleIm2colLoader<EG, ES, Self::Config>;
-    type RhsLoader = RhsLoader<EG, ES, Self::Config, CyclicLoading>;
-    type AccumulatorLoader = BiasLoader<EG, Acc, SMM::Config>;
+    type LhsLoader = SimpleIm2colLoader<CS, Self::Config>;
+    type RhsLoader = RhsLoader<CS::EG, CS::ES, Self::Config, CyclicLoading>;
+    type AccumulatorLoader = BiasLoader<CS, SMM::Config>;
 
-    type Out = Unloader<EG>;
+    type Out = Unloader<CS::EG>;
     type Accumulator = SMM::Accumulator;
 
     fn execute(
@@ -126,7 +119,7 @@ where
     }
 
     fn init_lhs_loader(
-        lhs: &Tensor<Line<EG>>,
+        lhs: VirtualTensor<CS::EG>,
         x_offset: u32,
         y_offset: u32,
         #[comptime] config: Self::Config,
@@ -142,7 +135,7 @@ where
     }
 
     fn init_rhs_loader(
-        rhs: &Tensor<Line<EG>>,
+        rhs: VirtualTensor<CS::EG>,
         x_offset: u32,
         y_offset: u32,
         #[comptime] config: Self::Config,
@@ -151,7 +144,7 @@ where
     }
 
     fn init_bias_loader(
-        bias: &Tensor<Line<EG>>,
+        bias: VirtualTensor<CS::EG>,
         n_offset: u32,
         #[comptime] config: Self::Config,
         #[comptime] has_bias: bool,
@@ -159,7 +152,11 @@ where
         Self::AccumulatorLoader::new(bias, n_offset, config.to_smm_config(), has_bias)
     }
 
-    fn init_unloader(out: &mut Tensor<Line<EG>>, x_offset: u32, y_offset: u32) -> Self::Out {
+    fn init_unloader(
+        out: VirtualTensor<CS::EG, ReadWrite>,
+        x_offset: u32,
+        y_offset: u32,
+    ) -> Self::Out {
         Self::Out::new(out, x_offset, y_offset, 0)
     }
 
@@ -168,12 +165,9 @@ where
     }
 }
 
-impl<EG, ES, Acc, SMM> ConvolutionKernel<EG, EG> for ImplicitGemmConvolution<EG, ES, Acc, SMM>
+impl<CS: ConvSpec, SMM> ConvolutionKernel<CS::EG, CS::EG> for ImplicitGemmConvolution<CS, SMM>
 where
-    EG: Numeric,
-    ES: Numeric,
-    Acc: Numeric,
-    SMM: stage::Matmul<ES, EG, Acc>,
+    SMM: stage::Matmul<CS::ES, CS::EG, CS::EA>,
 {
     type Config = config::Config<full_load::Config<SMM::Config>>;
 
@@ -221,11 +215,15 @@ where
 }
 
 impl<
-        EG: Numeric,
-        ES: Numeric,
-        Acc: Numeric,
-        SMM: stage::Matmul<ES, EG, Acc, LhsReader = LhsReader<ES>, RhsReader = RhsReader<ES>>,
-    > ConvolutionLaunch<EG, EG> for ImplicitGemmConvolution<EG, ES, Acc, SMM>
+        CS: ConvSpec,
+        SMM: stage::Matmul<
+            CS::ES,
+            CS::EG,
+            CS::EA,
+            LhsReader = LhsReader<CS::ES>,
+            RhsReader = RhsReader<CS::ES>,
+        >,
+    > ConvolutionLaunch<CS::EG, CS::EG> for ImplicitGemmConvolution<CS, SMM>
 {
     unsafe fn launch_unchecked<R: Runtime>(
         client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
@@ -235,11 +233,11 @@ impl<
         weight: TensorArg<'_, R>,
         bias: TensorArg<'_, R>,
         out: TensorArg<'_, R>,
-        config: <Self as ConvolutionKernel<EG, EG>>::Config,
+        config: <Self as ConvolutionKernel<CS::EG, CS::EG>>::Config,
     ) {
         Self::check_config(config);
 
-        implicit_conv::launch_unchecked::<EG, ES, Acc, Self, SMM, R>(
+        implicit_conv::launch_unchecked::<CS, Self, SMM, R>(
             client,
             cube_count,
             cube_dim,
@@ -255,16 +253,14 @@ impl<
 
 #[cube(launch_unchecked)]
 pub(crate) fn implicit_conv<
-    EG: Numeric,
-    ES: Numeric,
-    Acc: Numeric,
-    GMM: Convolution<EG, ES, Acc, SMM>,
-    SMM: stage::Matmul<ES, EG, Acc>,
+    CS: ConvSpec,
+    GMM: Convolution<CS, SMM>,
+    SMM: stage::Matmul<CS::ES, CS::EG, CS::EA>,
 >(
-    lhs: &Tensor<Line<EG>>,
-    rhs: &Tensor<Line<EG>>,
-    bias: &Tensor<Line<EG>>,
-    out: &mut Tensor<Line<EG>>,
+    lhs: &Tensor<Line<CS::EG>>,
+    rhs: &Tensor<Line<CS::EG>>,
+    bias: &Tensor<Line<CS::EG>>,
+    out: &mut Tensor<Line<CS::EG>>,
     #[comptime] config: GMM::Config,
     #[comptime] has_bias: bool,
 ) {
@@ -272,6 +268,11 @@ pub(crate) fn implicit_conv<
     let y_offset = CUBE_POS_Y * config.stage_dim(Ident::Rhs).num_elements_y_dim();
     let k_range = (0, rhs.shape(0));
 
+    let lhs = VirtualTensor::<CS::EG>::new::<Tensor<Line<CS::EG>>>(lhs);
+    let rhs = VirtualTensor::<CS::EG>::new::<Tensor<Line<CS::EG>>>(rhs);
+    let bias = VirtualTensor::<CS::EG>::new::<Tensor<Line<CS::EG>>>(bias);
+    let out = VirtualTensor::<CS::EG, ReadWrite>::new::<Tensor<Line<CS::EG>>>(out);
+
     GMM::execute(
         GMM::init_lhs_loader(lhs, x_offset, k_range.0, config),
         GMM::init_rhs_loader(rhs, k_range.0, y_offset, config),
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs
index 0ecf1880a6..fc3afe6eda 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs
@@ -15,6 +15,7 @@ use cubecl::{
 };
 use half::{bf16, f16};
 
+use super::spec::{ConvSpec, SingleConvSpec};
 use crate::{
     kernel::{
         conv::{
@@ -33,24 +34,30 @@ use crate::{
 
 /// Large m stage size for the usual case where `batch_size * out_h * out_w` is significantly larger
 /// than `out_channels`
-pub type CmmaLargeMAlgorithm = Cmma;
+pub type CmmaLargeMAlgorithm = Cmma;
 
 /// Balanced stage size for cases where `batch_size * out_h * out_w` is relatively small and `k` or
 /// `out_channels` is relatively large
-pub type CmmaBalancedAlgorithm = Cmma;
+pub type CmmaBalancedAlgorithm = Cmma;
 
 macro_rules! select_launch_algo {
     ($algo:tt, $float:ty, $input:expr) => {
         match (<$float>::as_elem(), has_tf32(&$input)) {
             (Elem::Float(FloatKind::F32), true) => {
-                conv2d_gemm_with_algo::<R, F, $algo<$float, tf32, f32>>
+                type Spec = SingleConvSpec<32, F, tf32, f32>;
+                conv2d_gemm_with_algo::<R, F, Spec, $algo<Spec>>
             }
             (Elem::Float(FloatKind::F16), _) => {
-                conv2d_gemm_with_algo::<R, F, $algo<$float, f16, f16>>
+                type Spec = SingleConvSpec<32, $float, f16, f16>;
+                conv2d_gemm_with_algo::<R, F, Spec, $algo<Spec>>
             }
             (Elem::Float(FloatKind::BF16), _) => {
-                conv2d_gemm_with_algo::<R, F, $algo<$float, bf16, f32>>
+                type Spec = SingleConvSpec<32, $float, bf16, f32>;
+                conv2d_gemm_with_algo::<R, F, Spec, $algo<Spec>>
+            }
+            _ => {
+                type Spec = SingleConvSpec<32, $float, f16, f32>;
+                conv2d_gemm_with_algo::<R, F, Spec, $algo<Spec>>
             }
-            _ => conv2d_gemm_with_algo::<R, F, $algo<$float, f16, f32>>,
         }
     };
 }
@@ -100,7 +107,7 @@ pub fn conv2d_gemm_cmma_balanced(
 /// * `options` - The options to use for the convolution
 ///
 ///
-pub fn conv2d_gemm_with_algo<R: JitRuntime, F: FloatElement, Alg: Algorithm<EG = F>>(
+pub fn conv2d_gemm_with_algo<R: JitRuntime, F: FloatElement, SP: ConvSpec, Alg: Algorithm<SP>>(
     input: JitTensor<R>,
     weight: JitTensor<R>,
     bias: Option<JitTensor<R>>,
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs
index bb4c5bb017..3d5b5493d3 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs
@@ -1,30 +1,31 @@
 use std::marker::PhantomData;
 
 use cubecl::{
-    linalg::matmul::components::{
-        global::AccumulatorLoader,
-        stage::{self, Stage},
-        tile::{self, Config as _},
-        Ident,
+    linalg::{
+        matmul::components::{
+            global::AccumulatorLoader,
+            stage::{self, Stage},
+            tile::{self, Config as _},
+            Ident,
+        },
+        tensor::VirtualTensor,
     },
     prelude::*,
 };
 
-use crate::kernel::conv::reader::bias::BiasReader;
+use crate::kernel::conv::{reader::bias::BiasReader, spec::ConvSpec};
 
 /// Special loader to broadcast the 1D bias to the 2D accumulator matrix
 #[derive(CubeType)]
-pub struct BiasLoader<EG: Numeric, Acc: Numeric, G: stage::Config> {
-    pub tensor_view: BiasReader<EG>,
-    pub stage: Stage<Acc>,
+pub struct BiasLoader<CS: ConvSpec, G: stage::Config> {
+    pub tensor_view: BiasReader<CS::EG>,
+    pub stage: Stage<CS::EA>,
     pub has_bias: bool,
     _config: PhantomData<G>,
 }
 
 #[cube]
-impl<EG: Numeric, Acc: Numeric, G: stage::Config> AccumulatorLoader<EG, Acc, G>
-    for BiasLoader<EG, Acc, G>
-{
+impl<CS: ConvSpec, G: stage::Config> AccumulatorLoader<CS::EG, CS::EA, G> for BiasLoader<CS, G> {
     fn fill_stage(this: &mut Self, #[comptime] config: G) {
         if this.has_bias {
             let stage_dim = config.stage_dim(Ident::Rhs);
@@ -47,7 +48,7 @@ impl AccumulatorLoader
     }
 
     /// Load accumulator
-    fn load<Tile: tile::Matmul<Acc, Acc>>(
+    fn load<Tile: tile::Matmul<CS::EA, CS::EA>>(
         this: &mut Self,
         acc: &mut Tile::Accumulator,
         tile_n: u32,
@@ -66,46 +67,28 @@ impl AccumulatorLoader
 }
 
 #[cube]
-impl<EG: Numeric, Acc: Numeric, G: stage::Config> BiasLoader<EG, Acc, G> {
+impl<CS: ConvSpec, G: stage::Config> BiasLoader<CS, G> {
     pub fn new(
-        tensor: &Tensor<Line<EG>>,
+        tensor: VirtualTensor<CS::EG>,
         n_offset: u32,
         #[comptime] config: G,
         #[comptime] has_bias: bool,
     ) -> Self {
         if has_bias {
-            let stage = {
-                let line_size = config.line_size(Ident::Out);
+            let stage = init_stage::<CS::EA, G>(config);
+            let shape_n = tensor.shape(0);
+            let tensor_view = BiasReader::<CS::EG>::new(tensor, n_offset, shape_n);
 
-                let smem = SharedMemory::new_lined(
-                    comptime!(config.stage_dim(Ident::Rhs).num_elements_y_dim() / line_size),
-                    line_size,
-                );
-
-                Stage::<Acc> { smem }
-            };
-            let tensor_view = BiasReader::<EG> {
-                tensor,
-                n_offset,
-                shape_n: tensor.shape(0),
-            };
-
-            BiasLoader::<EG, Acc, G> {
+            BiasLoader::<CS, G> {
                 tensor_view,
                 stage,
                 has_bias,
                 _config: PhantomData::<G>.runtime(),
             }
         } else {
-            let stage = Stage::<Acc> {
-                smem: SharedMemory::new(1),
-            };
-            let tensor_view = BiasReader::<EG> {
-                tensor,
-                n_offset: 0,
-                shape_n: 0,
-            };
-            BiasLoader::<EG, Acc, G> {
+            let stage = init_empty_stage::<CS::EA>();
+            let tensor_view = BiasReader::<CS::EG>::new(tensor, 0, 0);
+            BiasLoader::<CS, G> {
                 stage,
                 tensor_view,
                 has_bias,
@@ -114,3 +97,22 @@ impl BiasLoader {
         }
     }
 }
+
+#[cube]
+fn init_stage<E: Numeric, G: stage::Config>(#[comptime] config: G) -> Stage<E> {
+    let line_size = config.line_size(Ident::Out);
+
+    let smem = SharedMemory::new_lined(
+        comptime!(config.stage_dim(Ident::Rhs).num_elements_y_dim() / line_size),
+        line_size,
+    );
+
+    Stage::<E> { smem }
+}
+
+#[cube]
+fn init_empty_stage<E: Numeric>() -> Stage<E> {
+    Stage::<E> {
+        smem: SharedMemory::new(1),
+    }
+}
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs
index 11ee03d83e..a5ce5876bd 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs
@@ -1,32 +1,35 @@
 use cubecl::{
-    linalg::matmul::components::{
-        global::Loader,
-        stage::{
-            multi_buffer::LhsReader, ColMajorTiling, RowMajorTiling, Stage, TilingOrder as _,
-            TilingOrderConfig,
+    linalg::{
+        matmul::components::{
+            global::Loader,
+            stage::{
+                multi_buffer::LhsReader, ColMajorTiling, RowMajorTiling, Stage, TilingOrder as _,
+                TilingOrderConfig,
+            },
+            Ident,
         },
-        Ident,
+        tensor::VirtualTensor,
     },
     prelude::*,
 };
 use std::marker::PhantomData;
 
-use crate::kernel::conv::{reader::im2col::Im2colReader, Config};
+use crate::kernel::conv::{reader::im2col::Im2colReader, spec::ConvSpec, Config};
 
 /// Loader that translates matrix coordinates to input coordinates using the `im2col` algorithm
 #[derive(CubeType)]
-pub struct SimpleIm2colLoader<EG: Numeric, ES: Numeric, G: Config> {
-    pub tensor_view: Im2colReader<EG>,
-    pub stage: Stage<ES>,
+pub struct SimpleIm2colLoader<CS: ConvSpec, G: Config> {
+    pub tensor_view: Im2colReader<CS::EG>,
+    pub stage: Stage<CS::ES>,
     _config: PhantomData<G>,
 }
 
 #[cube]
-impl<EG: Numeric, ES: Numeric, G: Config> Loader<EG, ES, G> for SimpleIm2colLoader<EG, ES, G> {
-    type StageReader = LhsReader<ES>;
+impl<CS: ConvSpec, G: Config> Loader<CS::EG, CS::ES, G> for SimpleIm2colLoader<CS, G> {
+    type StageReader = LhsReader<CS::ES>;
 
     fn fill_stage(this: &mut Self, #[comptime] config: G) {
-        SimpleIm2col::load_to_slice::<EG, ES, G>(
+        SimpleIm2col::load_to_slice::<CS::EG, CS::ES, G>(
             &this.tensor_view,
             &mut this.stage.as_slice_mut(),
             Ident::Lhs,
@@ -44,9 +47,9 @@ impl Loader for SimpleIm2colLoad
 }
 
 #[cube]
-impl<EG: Numeric, ES: Numeric, G: Config> SimpleIm2colLoader<EG, ES, G> {
+impl<CS: ConvSpec, G: Config> SimpleIm2colLoader<CS, G> {
     pub fn new(
-        tensor: &Tensor<Line<EG>>,
+        tensor: VirtualTensor<CS::EG>,
         shape_out_y: u32,
         shape_out_x: u32,
         x_offset: u32,
@@ -60,25 +63,18 @@ impl SimpleIm2colLoader {
         let shape_m = shape_batch * shape_out_y * shape_out_x;
         let shape_k = shape_channel * config.kernel_size(0) * config.kernel_size(1);
 
-        let tensor_view = Im2colReader::<EG> {
+        let tensor_view = Im2colReader::<CS::EG>::new(
             tensor,
-            m_offset: x_offset,
-            k_offset: y_offset,
-            stride_batch: tensor.stride(0),
-            stride_y: tensor.stride(1),
-            stride_x: tensor.stride(2),
-            stride_channel: tensor.stride(3),
-            shape_y: tensor.shape(1),
-            shape_x: tensor.shape(2),
-            shape_channel,
             shape_out_y,
             shape_out_x,
-
-            shape_m,
+            x_offset,
+            y_offset,
             shape_k,
-        };
+            shape_channel,
+            shape_m,
+        );
 
-        SimpleIm2colLoader::<EG, ES, G> {
+        SimpleIm2colLoader::<CS, G> {
             tensor_view,
             stage,
             _config: PhantomData::<G>.runtime(),
@@ -93,9 +89,9 @@ pub struct SimpleIm2col;
 
 #[cube]
 impl SimpleIm2col {
-    pub fn load_to_slice<F: Numeric, G: Config>(
-        read_view: &Im2colReader<F>,
-        slice: &mut SliceMut<Line<F>>,
+    pub fn load_to_slice<EG: Numeric, ES: Numeric, G: Config>(
+        read_view: &Im2colReader<EG>,
+        slice: &mut SliceMut<Line<ES>>,
         #[comptime] ident: Ident,
         #[comptime] config: G,
     ) {
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs
index 5fd4a309b9..aaad6c639f 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/mod.rs
@@ -1,10 +1,12 @@
+mod config;
+
 pub mod algorithm;
 pub mod base;
-mod config;
 pub mod homogeneous;
 pub mod launch;
 pub mod loader;
 pub mod reader;
+pub mod spec;
 
 pub use config::*;
 pub use launch::*;
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs
index 67162a28a8..cc86503b7b 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs
@@ -1,5 +1,8 @@
 use cubecl::{
-    linalg::matmul::components::{stage, Ident},
+    linalg::{
+        matmul::components::{stage, Ident},
+        tensor::VirtualTensor,
+    },
     prelude::*,
 };
 
@@ -8,7 +11,7 @@ use cubecl::{
 /// Ensures safe access by preventing out-of-bounds errors.
 /// Includes pre-fetched shapes and strides for optimized performance.
 pub struct BiasReader<E: Numeric> {
-    pub tensor: *const Tensor<Line<E>>,
+    pub tensor: VirtualTensor<E>,
     pub n_offset: u32,
     pub shape_n: u32,
 }
 
 unsafe impl<E: Numeric> Send for BiasReader<E> {}
 
 #[cube]
 impl<E: Numeric> BiasReader<E> {
+    /// Create a new bias reader over the given virtual tensor
+    pub fn new(tensor: VirtualTensor<E>, n_offset: u32, shape_n: u32) -> BiasReader<E> {
+        BiasReader::<E> {
+            tensor,
+            n_offset,
+            shape_n,
+        }
+    }
+
     /// Load the 1D bias into shared memory
     pub fn load_simple<G: stage::Config>(&self, unit_id: u32, #[comptime] config: G) -> Line<E> {
         let line_size = config.line_size(Ident::Out);
@@ -33,6 +45,6 @@ impl BiasReader {
     }
 
     fn read(&self, position: u32) -> Line<E> {
-        unsafe { *(*self.tensor).index_unchecked(position) }
+        self.tensor.read(position)
     }
 }
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs
index b278bb051b..bcf7b2529f 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs
@@ -1,4 +1,7 @@
-use cubecl::{linalg::matmul::components::Ident, prelude::*};
+use cubecl::{
+    linalg::{matmul::components::Ident, tensor::VirtualTensor},
+    prelude::*,
+};
 
 use crate::kernel::conv::Config;
 
@@ -7,7 +10,7 @@ use crate::kernel::conv::Config;
 /// Ensures safe access by preventing out-of-bounds errors.
 /// Includes pre-fetched shapes and strides for optimized performance.
 pub struct Im2colReader<E: Numeric> {
-    pub tensor: *const Tensor<Line<E>>,
+    pub tensor: VirtualTensor<E>,
     pub m_offset: u32,
     pub k_offset: u32,
 
@@ -27,11 +30,50 @@ pub struct Im2colReader<E: Numeric> {
     pub shape_k: u32,
 }
 
+#[cube]
+impl<E: Numeric> Im2colReader<E> {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        tensor: VirtualTensor<E>,
+        shape_out_y: u32,
+        shape_out_x: u32,
+        x_offset: u32,
+        y_offset: u32,
+        shape_k: u32,
+        shape_channel: u32,
+        shape_m: u32,
+    ) -> Im2colReader<E> {
+        let stride_batch = tensor.stride(0);
+        let stride_y = tensor.stride(1);
+        let stride_x = tensor.stride(2);
+        let stride_channel = tensor.stride(3);
+        let shape_y = tensor.shape(1);
+        let shape_x = tensor.shape(2);
+
+        Im2colReader::<E> {
+            tensor,
+            m_offset: x_offset,
+            k_offset: y_offset,
+            stride_batch,
+            stride_y,
+            stride_x,
+            stride_channel,
+            shape_y,
+            shape_x,
+            shape_channel,
+            shape_out_y,
+            shape_out_x,
+            shape_m,
+            shape_k,
+        }
+    }
+}
+
 unsafe impl<E: Numeric> Sync for Im2colReader<E> {}
 unsafe impl<E: Numeric> Send for Im2colReader<E> {}
 
 #[cube]
-impl<F: Numeric> Im2colReader<F> {
+impl<E: Numeric> Im2colReader<E> {
     /// Advance the view along the k dimension by a specified offset, `k_offset`.
     pub fn update_view(&mut self, k_offset: u32) {
         self.k_offset += k_offset;
@@ -54,7 +96,7 @@ impl Im2colReader {
         unit_id: u32,
         #[comptime] ident: Ident,
         #[comptime] config: G,
-    ) -> Line<F> {
+    ) -> Line<E> {
         let line_size = config.global_line_size(ident);
         let tile_size_x = config.stage_dim(ident).tile_size_x_dim();
         let tile_size_y = config.stage_dim(ident).tile_size_y_dim();
@@ -98,7 +140,7 @@ impl Im2colReader {
 
         let read_pos = read_pos / line_size;
 
-        let mut res = Line::empty(line_size).fill(F::from_int(0));
+        let mut res = Line::empty(line_size).fill(E::from_int(0));
         if in_bounds {
             res = self.read(read_pos);
         }
@@ -106,7 +148,7 @@ impl Im2colReader {
         res
     }
 
-    fn read(&self, position: u32) -> Line<F> {
-        unsafe { *(*self.tensor).index_unchecked(position) }
+    fn read(&self, position: u32) -> Line<E> {
+        self.tensor.read(position)
     }
 }
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/spec.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/spec.rs
new file mode 100644
index 0000000000..ad2069c505
--- /dev/null
+++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/spec.rs
@@ -0,0 +1,33 @@
+use cubecl::prelude::Numeric;
+use std::marker::PhantomData;
+
+/// Implicit convolution spec defining the element types used in the computation.
+pub trait ConvSpec: Send + Sync + Clone + 'static {
+    /// The plane size used by this kernel.
+    const PLANE_DIM: u32;
+
+    /// Element type of each input and output tensor of the kernel.
+    type EG: Numeric;
+    /// Element type of the intermediate representation of the inputs.
+    type ES: Numeric;
+    /// Element type of the intermediate representation of the output accumulator.
+    type EA: Numeric;
+}
+
+/// Specification for a single conv using global tensor as inputs.
+#[derive(Clone)]
+pub struct SingleConvSpec<const PLANE_DIM: u32, EG: Numeric, ES: Numeric, EA: Numeric> {
+    _eg: PhantomData<EG>,
+    _es: PhantomData<ES>,
+    _ea: PhantomData<EA>,
+}
+
+impl<const PLANE_DIM: u32, EG: Numeric, ES: Numeric, EA: Numeric> ConvSpec
+    for SingleConvSpec<PLANE_DIM, EG, ES, EA>
+{
+    const PLANE_DIM: u32 = PLANE_DIM;
+
+    type EG = EG;
+    type ES = ES;
+    type EA = EA;
+}
diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs
index 56fc73965e..67858b38cd 100644
--- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs
+++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs
@@ -9,13 +9,14 @@ use cubecl::{
 };
 use half::{bf16, f16};
 
+use super::Conv2dAutotuneKey;
 use crate::{
     kernel::{
         conv::{
             algorithm::Algorithm, batches_per_run, can_do_implicit_gemm, conv2d_direct,
             conv2d_gemm_cmma_balanced, conv2d_gemm_cmma_large_m, conv2d_im2col,
-            conv2d_implicit_gemm, has_tf32, problem_from_key, CmmaBalancedAlgorithm,
-            CmmaLargeMAlgorithm,
+            conv2d_implicit_gemm, has_tf32, problem_from_key, spec::SingleConvSpec,
+            CmmaBalancedAlgorithm, CmmaLargeMAlgorithm,
         },
         prng::random_uniform,
     },
@@ -23,8 +24,6 @@ use crate::{
     FloatElement, JitAutotuneKey, JitRuntime, JitTuneId,
 };
 
-use super::Conv2dAutotuneKey;
-
 /// Executes autotune on conv2d operations
 pub fn conv2d_autotune<R: JitRuntime, E: FloatElement>(
     input: JitTensor<R>,
@@ -86,15 +85,21 @@ macro_rules! check_algo {
     ($algo:tt, $float:ty, $input:expr, $problem:expr) => {
         match (<$float>::as_elem(), has_tf32(&$input)) {
             (Elem::Float(FloatKind::F32), true) => {
-                $algo::<$float, tf32, f32>::can_launch::<R>(&$input.client, &$problem)
+                type Spec = SingleConvSpec<32, F, tf32, f32>;
+                $algo::<Spec>::can_launch::<R>(&$input.client, &$problem)
             }
             (Elem::Float(FloatKind::F16), _) => {
-                $algo::<$float, f16, f16>::can_launch::<R>(&$input.client, &$problem)
+                type Spec = SingleConvSpec<32, $float, f16, f16>;
+                $algo::<Spec>::can_launch::<R>(&$input.client, &$problem)
            }
            (Elem::Float(FloatKind::BF16), _) => {
-                $algo::<$float, bf16, f32>::can_launch::<R>(&$input.client, &$problem)
+                type Spec = SingleConvSpec<32, $float, bf16, f32>;
+                $algo::<Spec>::can_launch::<R>(&$input.client, &$problem)
+            }
+            _ => {
+                type Spec = SingleConvSpec<32, $float, f16, f32>;
+                $algo::<Spec>::can_launch::<R>(&$input.client, &$problem)
            }
-            _ => $algo::<$float, f16, f32>::can_launch::<R>(&$input.client, &$problem),
        }
    };
 }
diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs
index 8b5adfbbf3..d3cb280e90 100644
--- a/crates/burn-tensor/src/lib.rs
+++ b/crates/burn-tensor/src/lib.rs
@@ -67,13 +67,7 @@ mod cube_wgpu {
             WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32),
             WgpuDevice::Cpu => DeviceId::new(3, 0),
             WgpuDevice::BestAvailable | WgpuDevice::DefaultDevice => DeviceId::new(4, 0),
-            // For an existing device, use the 64 bit wgpu device ID as the burn DeviceID.
-            // We're only storing 32 bits, so wrap the the 64 bit value to 32 bits. This
-            // might collide - but a 1 in 4 billion chance seems ok given there's only a few
-            // devices in flight at any time.
-            WgpuDevice::Existing(id) => {
-                DeviceId::new(5, (id.inner() % (u32::MAX as u64)) as u32)
-            }
+            WgpuDevice::Existing(id) => DeviceId::new(5, *id),
         }
     }
 }
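
Reviewer note (a sketch, not part of the patch): the new `ConvSpec` trait in
crates/burn-jit/src/kernel/conv/conv2d/gemm/spec.rs bundles the plane size and the
three element types that previously travelled as separate `EG`/`ES`/`EA` generic
parameters into a single `CS: ConvSpec` parameter. A custom spec can be written the
same way `SingleConvSpec` is; the name `HalfPrecisionSpec` below is hypothetical and
only illustrates the pattern, assuming `ConvSpec` is in scope as defined above.

    use half::f16;

    // Equivalent to SingleConvSpec<32, f16, f16, f32>: f16 global tensors and
    // f16 staging, with accumulation carried out in f32.
    #[derive(Clone)]
    struct HalfPrecisionSpec;

    impl ConvSpec for HalfPrecisionSpec {
        const PLANE_DIM: u32 = 32; // plane size assumed by the Cmma algorithms

        type EG = f16; // element type of global input/output tensors
        type ES = f16; // element type staged for the tile matmul
        type EA = f32; // element type of the accumulator
    }

Every item generic over `CS: ConvSpec` (the `Algorithm`, `Convolution`, loader, and
reader types above) then reads its precisions as `CS::EG`, `CS::ES`, and `CS::EA`
instead of threading three type parameters through each signature.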