Skip to content

Commit

Permalink
Refactor the tuner to be used standalone (#1884)
Browse files Browse the repository at this point in the history
* Refactor the tuner to be used standalone

* Add a name for the autotune cache

* Fix tests

* Fix typo
  • Loading branch information
nathanielsimard authored Jun 13, 2024
1 parent 5de1517 commit 5e58ae1
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 27 deletions.
4 changes: 2 additions & 2 deletions crates/burn-compute/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use burn_common::{reader::Reader, sync_type::SyncType};
#[derive(Debug)]
pub struct ComputeClient<Server: ComputeServer, Channel> {
channel: Channel,
tuner: Arc<RwLock<Tuner<Server, Channel>>>,
tuner: Arc<RwLock<Tuner<Server::AutotuneKey>>>,
}

impl<S, C> Clone for ComputeClient<S, C>
Expand All @@ -36,7 +36,7 @@ where
Channel: ComputeChannel<Server>,
{
/// Create a new client.
pub fn new(channel: Channel, tuner: Arc<RwLock<Tuner<Server, Channel>>>) -> Self {
pub fn new(channel: Channel, tuner: Arc<RwLock<Tuner<Server::AutotuneKey>>>) -> Self {
Self { channel, tuner }
}

Expand Down
6 changes: 5 additions & 1 deletion crates/burn-compute/src/tune/tune_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ pub(crate) struct TuneCache<K> {
persistent_cache: HashMap<K, PersistentCacheEntry>,
#[cfg(feature = "autotune-persistent-cache")]
device_id: String,
#[cfg(feature = "autotune-persistent-cache")]
name: String,
}

/// Result of a cache lookup attempt
Expand All @@ -64,6 +66,7 @@ pub enum TuneCacheResult<K> {

impl<K: AutotuneKey> TuneCache<K> {
pub(crate) fn new(
#[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))] name: &str,
#[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))]
device_id: &str,
) -> Self {
Expand All @@ -73,6 +76,7 @@ impl<K: AutotuneKey> TuneCache<K> {
in_memory_cache: HashMap::new(),
persistent_cache: HashMap::new(),
device_id: device_id.to_string(),
name: name.to_string(),
};
if let Err(e) = cache.load() {
log::warn!(
Expand Down Expand Up @@ -239,6 +243,6 @@ impl<K: AutotuneKey> TuneCache<K> {
/// Return the file path for the persistent cache on disk
#[cfg(feature = "autotune-persistent-cache")]
pub fn get_persistent_cache_file_path(&self) -> PathBuf {
get_persistent_cache_file_path(&self.device_id)
get_persistent_cache_file_path(&format!("{}-{}", self.name, self.device_id))
}
}
46 changes: 29 additions & 17 deletions crates/burn-compute/src/tune/tuner.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use core::marker::PhantomData;
#[cfg(target_family = "wasm")]
use web_time::Duration;

Expand All @@ -15,32 +14,37 @@ use crate::client::ComputeClient;
use crate::server::ComputeServer;
use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCache};

use super::AutotuneKey;

#[derive(Debug)]
/// Executes autotune benchmarking and caching
pub struct Tuner<S: ComputeServer, C> {
tune_cache: TuneCache<S::AutotuneKey>,
_channel: PhantomData<C>,
pub struct Tuner<K: AutotuneKey> {
tune_cache: TuneCache<K>,
}

#[allow(clippy::new_without_default)]
impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
impl<K: AutotuneKey> Tuner<K> {
/// Returns a tuner with cache initialized from persistent cache
pub fn new(device_id: &str) -> Self {
pub fn new(name: &str, device_id: &str) -> Self {
Self {
tune_cache: TuneCache::new(device_id),
_channel: PhantomData,
tune_cache: TuneCache::new(name, device_id),
}
}

pub(crate) fn autotune_fastest(&self, key: &S::AutotuneKey) -> Option<usize> {
/// Fetch the fastest autotune operation index for an autotune key.
pub fn autotune_fastest(&self, key: &K) -> Option<usize> {
self.tune_cache.find_fastest(key)
}

pub(crate) fn execute_autotune(
/// Execute the fastest autotune operation if it is already known; otherwise, run benchmarks first to find it.
pub fn execute_autotune<S, C>(
&mut self,
autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
client: &ComputeClient<S, C>,
) {
) where
S: ComputeServer,
C: ComputeChannel<S>,
{
let operation = match self.tune_cache.try_cache(autotune_operation_set) {
super::TuneCacheResult::Hit(ops) => ops,
super::TuneCacheResult::Miss(set) => self.autotuning(set, client),
Expand All @@ -49,11 +53,15 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
AutotuneOperation::execute(operation);
}

fn autotuning(
fn autotuning<S, C>(
&mut self,
autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
client: &ComputeClient<S, C>,
) -> Box<dyn AutotuneOperation> {
) -> Box<dyn AutotuneOperation>
where
S: ComputeServer,
C: ComputeChannel<S>,
{
let key = autotune_operation_set.key();
let autotunables = autotune_operation_set.autotunables();
let mut names = Vec::with_capacity(autotunables.len());
Expand Down Expand Up @@ -86,11 +94,15 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
}
}

fn run_benchmark(
fn run_benchmark<S, C>(
&mut self,
operation: Box<dyn AutotuneOperation>,
client: &ComputeClient<S, C>,
) -> BenchmarkDurations {
) -> BenchmarkDurations
where
S: ComputeServer,
C: ComputeChannel<S>,
{
TuneBenchmark::new(operation, client.clone()).run()
}

Expand Down
3 changes: 2 additions & 1 deletion crates/burn-compute/tests/dummy/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ pub type DummyClient = ComputeClient<DummyServer, DummyChannel>;

static RUNTIME: ComputeRuntime<DummyDevice, DummyServer, DummyChannel> = ComputeRuntime::new();
pub static TUNER_DEVICE_ID: &str = "tests/dummy-device";
pub static TUNER_PREFIX: &str = "dummy-tests/dummy-device";

pub fn init_client() -> ComputeClient<DummyServer, MutexComputeChannel<DummyServer>> {
let storage = BytesStorage::default();
let memory_management =
SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never);
let server = DummyServer::new(memory_management);
let channel = MutexComputeChannel::new(server);
let tuner = Arc::new(RwLock::new(Tuner::new(TUNER_DEVICE_ID)));
let tuner = Arc::new(RwLock::new(Tuner::new("dummy", TUNER_DEVICE_ID)));
ComputeClient::new(channel, tuner)
}

Expand Down
6 changes: 2 additions & 4 deletions crates/burn-compute/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,7 @@ fn autotune_cache_same_key_return_a_cache_hit() {
#[cfg(feature = "std")]
fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
// delete the cache file
let file_path =
burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_DEVICE_ID);
let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX);
let _ = std::fs::remove_file(file_path);

type Runtime = ComputeRuntime<DummyDevice, dummy::DummyServer, dummy::DummyChannel>;
Expand Down Expand Up @@ -176,8 +175,7 @@ fn autotune_cache_no_cache_on_disk_return_a_cache_miss() {
#[cfg(feature = "std")]
fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() {
// delete the cache file
let file_path =
burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_DEVICE_ID);
let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX);
let parent_dir = file_path
.parent()
.expect("Cache file should have a parent directory");
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-cuda/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl Runtime for CudaRuntime {
let tuner_device_id = tuner_device_id();
ComputeClient::new(
MutexComputeChannel::new(server),
Arc::new(RwLock::new(Tuner::new(&tuner_device_id))),
Arc::new(RwLock::new(Tuner::new("cuda", &tuner_device_id))),
)
})
}
Expand Down
5 changes: 4 additions & 1 deletion crates/burn-wgpu/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,10 @@ fn create_client(
let channel = MutexComputeChannel::new(server);
let tuner_device_id = tuner_device_id(adapter.get_info());

ComputeClient::new(channel, Arc::new(RwLock::new(Tuner::new(&tuner_device_id))))
ComputeClient::new(
channel,
Arc::new(RwLock::new(Tuner::new("wgpu", &tuner_device_id))),
)
}

/// Select the wgpu device and queue based on the provided [device](WgpuDevice).
Expand Down

0 comments on commit 5e58ae1

Please sign in to comment.