Skip to content

Commit

Permalink
Custom fusion (#2486)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurBrussee authored Nov 13, 2024
1 parent a4567db commit c7233bf
Show file tree
Hide file tree
Showing 35 changed files with 354 additions and 166 deletions.
25 changes: 12 additions & 13 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.2", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3882ed25b47506d49562c501a179b7468e61702e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3882ed25b47506d49562c501a179b7468e61702e" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a1471a7ffa089ee2878bb8c140d09f66a2b2b664" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a1471a7ffa089ee2878bb8c140d09f66a2b2b664" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
20 changes: 16 additions & 4 deletions crates/burn-fusion/src/client/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,43 @@ where
&self,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by an int tensor.
fn read_tensor_int<B>(
&self,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by a bool tensor.
fn read_tensor_bool<B>(
&self,
tensor: TensorDescription,
stream: StreamId,
) -> impl Future<Output = TensorData> + Send
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
/// Read the values contained by a quantized tensor.
fn read_tensor_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
streams: Vec<StreamId>,
) -> impl Future<Output = TensorData> + Send
) -> impl Future<Output = TensorData> + Send + 'static
where
B: FusionBackend<FusionRuntime = R>;
/// Resolve the given float tensor to a primitive tensor.
///
/// Takes ownership of `tensor` and returns the float tensor primitive of the
/// backend `B`, whose fusion runtime must be `R`.
fn resolve_tensor_float<B>(&self, tensor: FusionTensor<R>) -> B::FloatTensorPrimitive
where
B: FusionBackend<FusionRuntime = R>;
/// Resolve the given int tensor to a primitive tensor.
///
/// Takes ownership of `tensor` and returns the int tensor primitive of the
/// backend `B`, whose fusion runtime must be `R`.
fn resolve_tensor_int<B>(&self, tensor: FusionTensor<R>) -> B::IntTensorPrimitive
where
B: FusionBackend<FusionRuntime = R>;
/// Resolve the given bool tensor to a primitive tensor.
///
/// Takes ownership of `tensor` and returns the bool tensor primitive of the
/// backend `B`, whose fusion runtime must be `R`.
fn resolve_tensor_bool<B>(&self, tensor: FusionTensor<R>) -> B::BoolTensorPrimitive
where
B: FusionBackend<FusionRuntime = R>;
/// Change the client of the given float tensor.
Expand Down
57 changes: 41 additions & 16 deletions crates/burn-fusion/src/client/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use burn_tensor::{
DType,
};
use spin::Mutex;
use std::sync::Arc;
use std::{future::Future, sync::Arc};

/// Use a mutex to communicate with the fusion server.
pub struct MutexFusionClient<R: FusionRuntime> {
Expand Down Expand Up @@ -79,51 +79,49 @@ where
FusionTensor::new(id, shape, dtype, self.clone(), stream)
}

async fn read_tensor_float<B>(
fn read_tensor_float<B>(
&self,
tensor: TensorDescription,
stream: StreamId,
) -> burn_tensor::TensorData
) -> impl Future<Output = burn_tensor::TensorData> + 'static
where
B: FusionBackend<FusionRuntime = R>,
{
self.server.lock().read_float::<B>(tensor, stream).await
let mut server = self.server.lock();
server.read_float::<B>(tensor, stream)
}

async fn read_tensor_int<B>(
fn read_tensor_int<B>(
&self,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::TensorData
) -> impl Future<Output = burn_tensor::TensorData> + 'static
where
B: FusionBackend<FusionRuntime = R>,
{
self.server.lock().read_int::<B>(tensor, id).await
self.server.lock().read_int::<B>(tensor, id)
}

async fn read_tensor_bool<B>(
fn read_tensor_bool<B>(
&self,
tensor: TensorDescription,
stream: StreamId,
) -> burn_tensor::TensorData
) -> impl Future<Output = burn_tensor::TensorData> + 'static
where
B: FusionBackend<FusionRuntime = R>,
{
self.server.lock().read_bool::<B>(tensor, stream).await
self.server.lock().read_bool::<B>(tensor, stream)
}

async fn read_tensor_quantized<B>(
fn read_tensor_quantized<B>(
&self,
tensor: QuantizedTensorDescription,
streams: Vec<StreamId>,
) -> burn_tensor::TensorData
) -> impl Future<Output = burn_tensor::TensorData> + 'static
where
B: FusionBackend<FusionRuntime = R>,
{
self.server
.lock()
.read_quantized::<B>(tensor, streams)
.await
self.server.lock().read_quantized::<B>(tensor, streams)
}

fn change_client_float<B>(
Expand Down Expand Up @@ -246,4 +244,31 @@ where
fn register_orphan(&self, id: &TensorId) {
self.server.lock().drop_tensor_handle(*id);
}

/// Resolve a fusion float tensor into the float primitive of backend `B`.
///
/// The server lock is taken once; under it, the tensor's stream is drained and
/// the tensor's description is then handed to `resolve_server_float`.
fn resolve_tensor_float<B>(&self, tensor: FusionTensor<R>) -> B::FloatTensorPrimitive
where
    B: FusionBackend<FusionRuntime = R>,
{
    // A single guard covers both the drain and the lookup.
    let mut guard = self.server.lock();
    guard.drain_stream(tensor.stream);

    // The tensor is consumed only after its stream has been drained.
    let description = tensor.into_description();
    guard.resolve_server_float::<B>(&description)
}

/// Resolve a fusion int tensor into the int primitive of backend `B`.
///
/// The server lock is taken once; under it, the tensor's stream is drained and
/// the tensor's description is then handed to `resolve_server_int`.
fn resolve_tensor_int<B>(&self, tensor: FusionTensor<R>) -> B::IntTensorPrimitive
where
    B: FusionBackend<FusionRuntime = R>,
{
    // A single guard covers both the drain and the lookup.
    let mut guard = self.server.lock();
    guard.drain_stream(tensor.stream);

    // The tensor is consumed only after its stream has been drained.
    let description = tensor.into_description();
    guard.resolve_server_int::<B>(&description)
}

/// Resolve a fusion bool tensor into the bool primitive of backend `B`.
///
/// The server lock is taken once; under it, the tensor's stream is drained and
/// the tensor's description is then handed to `resolve_server_bool`.
fn resolve_tensor_bool<B>(&self, tensor: FusionTensor<R>) -> B::BoolTensorPrimitive
where
    B: FusionBackend<FusionRuntime = R>,
{
    // A single guard covers both the drain and the lookup.
    let mut guard = self.server.lock();
    guard.drain_stream(tensor.stream);

    // The tensor is consumed only after its stream has been drained.
    let description = tensor.into_description();
    guard.resolve_server_bool::<B>(&description)
}
}
47 changes: 34 additions & 13 deletions crates/burn-fusion/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use burn_tensor::repr::{
HandleContainer, OperationDescription, QuantizedKind, QuantizedTensorDescription,
TensorDescription, TensorId,
};
use std::sync::Arc;
use std::{future::Future, sync::Arc};

pub struct FusionServer<R: FusionRuntime> {
streams: MultiStream<R>,
Expand Down Expand Up @@ -42,11 +42,11 @@ where
self.handles.create_tensor_uninit()
}

pub async fn read_float<B>(
pub fn read_float<B>(
&mut self,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::TensorData
) -> impl Future<Output = burn_tensor::TensorData> + 'static
where
B: FusionBackend<FusionRuntime = R>,
{
Expand All @@ -55,14 +55,14 @@ where
self.drain_stream(id);

let tensor = self.handles.get_float_tensor::<B>(&tensor);
B::float_into_data(tensor).await
B::float_into_data(tensor)
}

pub async fn read_int<B>(
pub fn read_int<B>(
&mut self,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::TensorData
) -> impl Future<Output = burn_tensor::TensorData> + 'static
where
B: FusionBackend<FusionRuntime = R>,
{
Expand All @@ -71,14 +71,14 @@ where
self.drain_stream(id);

let tensor = self.handles.get_int_tensor::<B>(&tensor);
B::int_into_data(tensor).await
B::int_into_data(tensor)
}

pub async fn read_bool<B>(
pub fn read_bool<B>(
&mut self,
tensor: TensorDescription,
id: StreamId,
) -> burn_tensor::TensorData
) -> impl Future<Output = burn_tensor::TensorData> + 'static
where
B: FusionBackend<FusionRuntime = R>,
{
Expand All @@ -87,14 +87,14 @@ where
self.drain_stream(id);

let tensor = self.handles.get_bool_tensor::<B>(&tensor);
B::bool_into_data(tensor).await
B::bool_into_data(tensor)
}

pub async fn read_quantized<B>(
pub fn read_quantized<B>(
&mut self,
tensor: QuantizedTensorDescription,
ids: Vec<StreamId>,
) -> burn_tensor::TensorData
) -> impl Future<Output = burn_tensor::TensorData> + 'static
where
B: FusionBackend<FusionRuntime = R>,
{
Expand All @@ -105,7 +105,7 @@ where
}

let tensor = self.handles.get_quantized_tensor::<B>(&tensor);
B::q_into_data(tensor).await
B::q_into_data(tensor)
}

pub fn change_server_float<B>(
Expand All @@ -128,6 +128,27 @@ where
id
}

/// Look up the float primitive for `tensor` in `self.handles`.
///
/// This method does not drain any stream itself; it only forwards the
/// description to `get_float_tensor`.
pub fn resolve_server_float<B>(&mut self, tensor: &TensorDescription) -> B::FloatTensorPrimitive
where
    B: FusionBackend<FusionRuntime = R>,
{
    let handles = &mut self.handles;
    handles.get_float_tensor::<B>(tensor)
}

/// Look up the int primitive for `tensor` in `self.handles`.
///
/// This method does not drain any stream itself; it only forwards the
/// description to `get_int_tensor`.
pub fn resolve_server_int<B>(&mut self, tensor: &TensorDescription) -> B::IntTensorPrimitive
where
    B: FusionBackend<FusionRuntime = R>,
{
    let handles = &mut self.handles;
    handles.get_int_tensor::<B>(tensor)
}

/// Look up the bool primitive for `tensor` in `self.handles`.
///
/// This method does not drain any stream itself; it only forwards the
/// description to `get_bool_tensor`.
pub fn resolve_server_bool<B>(&mut self, tensor: &TensorDescription) -> B::BoolTensorPrimitive
where
    B: FusionBackend<FusionRuntime = R>,
{
    let handles = &mut self.handles;
    handles.get_bool_tensor::<B>(tensor)
}

pub fn change_server_int<B>(
&mut self,
tensor: &TensorDescription,
Expand Down
Loading

0 comments on commit c7233bf

Please sign in to comment.