diff --git a/crates/burn-compute/src/client.rs b/crates/burn-compute/src/client.rs index cb1ec1a8ba..c946682b9e 100644 --- a/crates/burn-compute/src/client.rs +++ b/crates/burn-compute/src/client.rs @@ -14,7 +14,7 @@ use burn_common::{reader::Reader, sync_type::SyncType}; #[derive(Debug)] pub struct ComputeClient { channel: Channel, - tuner: Arc>>, + tuner: Arc>>, } impl Clone for ComputeClient @@ -36,7 +36,7 @@ where Channel: ComputeChannel, { /// Create a new client. - pub fn new(channel: Channel, tuner: Arc>>) -> Self { + pub fn new(channel: Channel, tuner: Arc>>) -> Self { Self { channel, tuner } } diff --git a/crates/burn-compute/src/tune/tune_cache.rs b/crates/burn-compute/src/tune/tune_cache.rs index 7200942525..0d41ac9657 100644 --- a/crates/burn-compute/src/tune/tune_cache.rs +++ b/crates/burn-compute/src/tune/tune_cache.rs @@ -52,6 +52,8 @@ pub(crate) struct TuneCache { persistent_cache: HashMap, #[cfg(feature = "autotune-persistent-cache")] device_id: String, + #[cfg(feature = "autotune-persistent-cache")] + name: String, } /// Result of the cache try @@ -64,6 +66,7 @@ pub enum TuneCacheResult { impl TuneCache { pub(crate) fn new( + #[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))] name: &str, #[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))] device_id: &str, ) -> Self { @@ -73,6 +76,7 @@ impl TuneCache { in_memory_cache: HashMap::new(), persistent_cache: HashMap::new(), device_id: device_id.to_string(), + name: name.to_string(), }; if let Err(e) = cache.load() { log::warn!( @@ -239,6 +243,6 @@ impl TuneCache { /// Return the file path for the persistent cache on disk #[cfg(feature = "autotune-persistent-cache")] pub fn get_persistent_cache_file_path(&self) -> PathBuf { - get_persistent_cache_file_path(&self.device_id) + get_persistent_cache_file_path(&format!("{}-{}", self.name, self.device_id)) } } diff --git a/crates/burn-compute/src/tune/tuner.rs b/crates/burn-compute/src/tune/tuner.rs index 4fafa5adf0..bac874c4e1 100644 --- a/crates/burn-compute/src/tune/tuner.rs +++ b/crates/burn-compute/src/tune/tuner.rs @@ -1,4 +1,3 @@ -use core::marker::PhantomData; #[cfg(target_family = "wasm")] use web_time::Duration; @@ -15,32 +14,37 @@ use crate::client::ComputeClient; use crate::server::ComputeServer; use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCache}; +use super::AutotuneKey; + #[derive(Debug)] /// Executes autotune benchmarking and caching -pub struct Tuner { - tune_cache: TuneCache, - _channel: PhantomData, +pub struct Tuner { + tune_cache: TuneCache, } #[allow(clippy::new_without_default)] -impl> Tuner { +impl Tuner { /// Returns a tuner with cache initialized from persistent cache - pub fn new(device_id: &str) -> Self { + pub fn new(name: &str, device_id: &str) -> Self { Self { - tune_cache: TuneCache::new(device_id), - _channel: PhantomData, + tune_cache: TuneCache::new(name, device_id), } } - pub(crate) fn autotune_fastest(&self, key: &S::AutotuneKey) -> Option { + /// Fetch the fastest autotune operation index for an autotune key. + pub fn autotune_fastest(&self, key: &K) -> Option { self.tune_cache.find_fastest(key) } - pub(crate) fn execute_autotune( + /// Execute the fastest autotune operation if known, otherwise perform some benchmarks before. + pub fn execute_autotune( &mut self, - autotune_operation_set: Box>, + autotune_operation_set: Box>, client: &ComputeClient, - ) { + ) where + S: ComputeServer, + C: ComputeChannel, + { let operation = match self.tune_cache.try_cache(autotune_operation_set) { super::TuneCacheResult::Hit(ops) => ops, super::TuneCacheResult::Miss(set) => self.autotuning(set, client), @@ -49,11 +53,15 @@ impl> Tuner { AutotuneOperation::execute(operation); } - fn autotuning( + fn autotuning( &mut self, - autotune_operation_set: Box>, + autotune_operation_set: Box>, client: &ComputeClient, - ) -> Box { + ) -> Box + where + S: ComputeServer, + C: ComputeChannel, + { let key = autotune_operation_set.key(); let autotunables = autotune_operation_set.autotunables(); let mut names = Vec::with_capacity(autotunables.len()); @@ -86,11 +94,15 @@ impl> Tuner { } } - fn run_benchmark( + fn run_benchmark( &mut self, operation: Box, client: &ComputeClient, - ) -> BenchmarkDurations { + ) -> BenchmarkDurations + where + S: ComputeServer, + C: ComputeChannel, + { TuneBenchmark::new(operation, client.clone()).run() } diff --git a/crates/burn-compute/tests/dummy/compute.rs b/crates/burn-compute/tests/dummy/compute.rs index 875f79705a..1699b170ec 100644 --- a/crates/burn-compute/tests/dummy/compute.rs +++ b/crates/burn-compute/tests/dummy/compute.rs @@ -20,6 +20,7 @@ pub type DummyClient = ComputeClient; static RUNTIME: ComputeRuntime = ComputeRuntime::new(); pub static TUNER_DEVICE_ID: &str = "tests/dummy-device"; +pub static TUNER_PREFIX: &str = "dummy-tests/dummy-device"; pub fn init_client() -> ComputeClient> { let storage = BytesStorage::default(); @@ -27,7 +28,7 @@ pub fn init_client() -> ComputeClient; @@ -176,8 +175,7 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() { #[cfg(feature = "std")] fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() { // delete the cache file - let file_path = - burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_DEVICE_ID); + let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX); let parent_dir = file_path .parent() .expect("Cache file should have a parent directory"); diff --git a/crates/burn-cuda/src/runtime.rs b/crates/burn-cuda/src/runtime.rs index 1cb0180639..aab907c93f 100644 --- a/crates/burn-cuda/src/runtime.rs +++ b/crates/burn-cuda/src/runtime.rs @@ -65,7 +65,7 @@ impl Runtime for CudaRuntime { let tuner_device_id = tuner_device_id(); ComputeClient::new( MutexComputeChannel::new(server), - Arc::new(RwLock::new(Tuner::new(&tuner_device_id))), + Arc::new(RwLock::new(Tuner::new("cuda", &tuner_device_id))), ) }) } diff --git a/crates/burn-wgpu/src/runtime.rs b/crates/burn-wgpu/src/runtime.rs index 782dc723da..8b1ee1451b 100644 --- a/crates/burn-wgpu/src/runtime.rs +++ b/crates/burn-wgpu/src/runtime.rs @@ -168,7 +168,10 @@ fn create_client( let channel = MutexComputeChannel::new(server); let tuner_device_id = tuner_device_id(adapter.get_info()); - ComputeClient::new(channel, Arc::new(RwLock::new(Tuner::new(&tuner_device_id)))) + ComputeClient::new( + channel, + Arc::new(RwLock::new(Tuner::new("wgpu", &tuner_device_id))), + ) } /// Select the wgpu device and queue based on the provided [device](WgpuDevice).