-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
stack-info: PR: #9, branch: drisspg/stack/1
- Loading branch information
Showing 20 changed files with 1,390 additions and 19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "third_party/cutlass"] | ||
path = third_party/cutlass | ||
url = git@github.com:NVIDIA/cutlass.git |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#pragma once | ||
|
||
#include <torch/torch.h> | ||
|
||
namespace driss_torch { | ||
|
||
at::Tensor sweep_mm( | ||
at::Tensor XQ, | ||
at::Tensor WQ, | ||
at::Tensor x_scale, | ||
at::Tensor w_scale, | ||
std::optional<at::Tensor> bias, | ||
at::ScalarType out_dtype, | ||
bool use_fast_accum, | ||
const int64_t cluster_shape_x, | ||
const int64_t cluster_shape_y, | ||
const int64_t cluster_shape_z, | ||
bool transposed, | ||
const int64_t swizzle); | ||
|
||
} // namespace driss_torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.