[Breaking] add runtime options in wgpu init methods (#1505)
nathanielsimard authored Mar 28, 2024
1 parent 32a8d80 commit efc3b2d
Showing 2 changed files with 43 additions and 18 deletions.
59 changes: 42 additions & 17 deletions crates/burn-wgpu/src/runtime.rs
@@ -43,7 +43,7 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> Runtime for WgpuRuntime<G,

     fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
         RUNTIME.client(device, move || {
-            pollster::block_on(create_client::<G>(device))
+            pollster::block_on(create_client::<G>(device, RuntimeOptions::default()))
         })
     }

@@ -52,16 +52,52 @@ impl<G: GraphicsApi, F: FloatElement, I: IntElement> Runtime for WgpuRuntime<G,
     }
 }

+/// The values that control how a WGPU Runtime will perform its calculations.
+pub struct RuntimeOptions {
+    /// How the buffers are deallocated.
+    pub dealloc_strategy: DeallocStrategy,
+    /// Control the slicing strategy.
+    pub slice_strategy: SliceStrategy,
+    /// Control the amount of compute tasks to be aggregated into a single GPU command.
+    pub max_tasks: usize,
+}
+
+impl Default for RuntimeOptions {
+    fn default() -> Self {
+        let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
+            Ok(value) => value
+                .parse::<usize>()
+                .expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
+            Err(_) => 64, // 64 tasks by default
+        };
+
+        Self {
+            dealloc_strategy: DeallocStrategy::new_period_tick(max_tasks * 2),
+            slice_strategy: SliceStrategy::Ratio(0.8),
+            max_tasks,
+        }
+    }
+}
+
+/// Init the client sync, useful to configure the runtime options.
+pub fn init_sync<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) {
+    let device = Arc::new(device);
+    let client = pollster::block_on(create_client::<G>(&device, options));
+
+    RUNTIME.register(&device, client)
+}
+
 /// Init the client async, necessary for wasm.
-pub async fn init_async<G: GraphicsApi>(device: &WgpuDevice) {
+pub async fn init_async<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) {
     let device = Arc::new(device);
-    let client = create_client::<G>(&device).await;
+    let client = create_client::<G>(&device, options).await;

     RUNTIME.register(&device, client)
 }

 async fn create_client<G: GraphicsApi>(
     device: &WgpuDevice,
+    options: RuntimeOptions,
 ) -> ComputeClient<
     WgpuServer<SimpleMemoryManagement<WgpuStorage>>,
     MutexComputeChannel<WgpuServer<SimpleMemoryManagement<WgpuStorage>>>,
@@ -74,22 +110,11 @@ async fn create_client<G: GraphicsApi>(
         info
     );

-    // TODO: Support a way to modify max_tasks without std.
-    let max_tasks = match std::env::var("BURN_WGPU_MAX_TASKS") {
-        Ok(value) => value
-            .parse::<usize>()
-            .expect("BURN_WGPU_MAX_TASKS should be a positive integer."),
-        Err(_) => 64, // 64 tasks by default
-    };
-
     let device = Arc::new(device_wgpu);
     let storage = WgpuStorage::new(device.clone());
-    let memory_management = SimpleMemoryManagement::new(
-        storage,
-        DeallocStrategy::new_period_tick(max_tasks * 2),
-        SliceStrategy::Ratio(0.8),
-    );
-    let server = WgpuServer::new(memory_management, device, queue, max_tasks);
+    let memory_management =
+        SimpleMemoryManagement::new(storage, options.dealloc_strategy, options.slice_strategy);
+    let server = WgpuServer::new(memory_management, device, queue, options.max_tasks);
     let channel = MutexComputeChannel::new(server);

     let tuner_device_id = tuner_device_id(info);
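For reference, a minimal sketch of how a caller might use the new API after this change. The import paths and re-exports below are assumptions (the diff does not show the crate's `use` statements or public exports); only `init_sync`, `RuntimeOptions`, `DeallocStrategy`, `SliceStrategy`, `WgpuDevice`, and `AutoGraphicsApi` appear in the diff itself.

// Sketch only: import paths are assumed, not taken from this diff.
use burn_wgpu::{init_sync, AutoGraphicsApi, RuntimeOptions, WgpuDevice};
use burn_wgpu::{DeallocStrategy, SliceStrategy};

fn main() {
    let device = WgpuDevice::default();

    // Configure the runtime explicitly instead of relying on the
    // BURN_WGPU_MAX_TASKS environment variable and the built-in defaults.
    let options = RuntimeOptions {
        dealloc_strategy: DeallocStrategy::new_period_tick(256),
        slice_strategy: SliceStrategy::Ratio(0.8),
        max_tasks: 128,
    };

    // Register the compute client for this device before using the backend.
    init_sync::<AutoGraphicsApi>(&device, options);
}

Callers that do not care about the options can pass `RuntimeOptions::default()`, which keeps the previous behaviour (64 tasks, or the value of `BURN_WGPU_MAX_TASKS`).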
2 changes: 1 addition & 1 deletion examples/image-classification-web/src/web.rs
@@ -106,7 +106,7 @@ impl ImageClassifier {
         log::info!("Loading the model to the Wgpu backend");
         let start = Instant::now();
         let device = WgpuDevice::default();
-        init_async::<AutoGraphicsApi>(&device).await;
+        init_async::<AutoGraphicsApi>(&device, Default::default()).await;
         self.model = ModelType::WithWgpuBackend(Model::new(&device));
         let duration = start.elapsed();
         log::debug!("Model is loaded to the Wgpu backend in {:?}", duration);
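Because this is a breaking change, existing async callers now have to pass options explicitly, as the example above does with `Default::default()`. A hedged sketch of overriding a single field while keeping the rest of the defaults, assuming `RuntimeOptions` is in scope in the calling module:

// Sketch only: keeps the env-var/default behaviour for everything except max_tasks.
init_async::<AutoGraphicsApi>(
    &device,
    RuntimeOptions {
        max_tasks: 16,
        ..Default::default()
    },
)
.await;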
