generic sweep op
stack-info: PR: #9, branch: drisspg/stack/1
drisspg committed Oct 15, 2024
1 parent 1f4c4dc commit b652881
Showing 20 changed files with 1,390 additions and 19 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = git@github.com:NVIDIA/cutlass.git
49 changes: 30 additions & 19 deletions CMakeLists.txt
@@ -1,9 +1,7 @@
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)

project(
${SKBUILD_PROJECT_NAME}
VERSION ${SKBUILD_PROJECT_VERSION}
LANGUAGES CXX CUDA)
project(${SKBUILD_PROJECT_NAME}
VERSION ${SKBUILD_PROJECT_VERSION}
LANGUAGES CXX CUDA)

# Set the C++ standard for all targets
set(CMAKE_CXX_STANDARD 20) # This might be unsafe since PyTorch uses C++17
@@ -14,20 +12,21 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

find_package(Python REQUIRED COMPONENTS Interpreter Development)
execute_process(
COMMAND "${Python3_EXECUTABLE}" "-c" "import torch;print(torch.utils.cmake_prefix_path)"
OUTPUT_VARIABLE PT_CMAKE_PREFIX
COMMAND_ECHO STDOUT
OUTPUT_STRIP_TRAILING_WHITESPACE
COMMAND_ERROR_IS_FATAL ANY
COMMAND "${Python3_EXECUTABLE}" "-c" "import torch;print(torch.utils.cmake_prefix_path)"
OUTPUT_VARIABLE PT_CMAKE_PREFIX
COMMAND_ECHO STDOUT
OUTPUT_STRIP_TRAILING_WHITESPACE
COMMAND_ERROR_IS_FATAL ANY
)

# cache CUDA_ARCHITECTURES, which seems to be reset by Torch
set(TMP_STORE_CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES}")
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH};${PT_CMAKE_PREFIX})
# Set CUDA architecture to SM90a
set(CMAKE_CUDA_ARCHITECTURES 90a)
set(TORCH_CUDA_ARCH_LIST "9.0a")

set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH};${PT_CMAKE_PREFIX})
find_package(Torch REQUIRED CONFIG)

# simple_cuda source files
# driss_torch source files
file(GLOB_RECURSE CU_SOURCES src/*.cu)
file(GLOB_RECURSE CPP_SOURCES src/*.cpp)
MESSAGE(STATUS "CU_SOURCES: ${CU_SOURCES}")
@@ -38,17 +37,29 @@ add_library(driss_torch SHARED
${CPP_SOURCES}
)

# Set the library output directory, I think this makes ninja builds work
# Set the library output directory
set_target_properties(driss_torch PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/driss_torch/lib"
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/driss_torch/lib"
)

# Check for CUTLASS
if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/include/cutlass/cutlass.h")
message(FATAL_ERROR "The Cutlass submodule was not downloaded! Please update submodules and try again.")
endif()

# Include CUTLASS headers without building the entire library
target_include_directories(driss_torch PUBLIC
src/include
src/scaled_mm_kernels/
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/include
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/tools/util/include
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/tools/library/include
)
# Add include directories to the library
target_include_directories(driss_torch PUBLIC src/include)

# Link the library to the Torch library
target_link_libraries(driss_torch PRIVATE ${TORCH_LIBRARIES} Python::Python)

# Install the library to the wheel distribution
install(TARGETS driss_torch
LIBRARY DESTINATION driss_torch/lib
LIBRARY DESTINATION driss_torch/lib
)
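
The install() rule above copies the built libdriss_torch into driss_torch/lib inside the wheel, and the Python package is expected to load it at import time so the TORCH_LIBRARY registrations below become visible. The loader itself is collapsed in this diff; the following is only a minimal sketch of the usual pattern, assuming the shared library lands in driss_torch/lib as configured by LIBRARY_OUTPUT_DIRECTORY:

# Hypothetical loader sketch -- the actual code in driss_torch/__init__.py is
# collapsed in this view. It assumes the shared library is placed in
# driss_torch/lib by the LIBRARY_OUTPUT_DIRECTORY / install() rules above.
from pathlib import Path

import torch

_lib_dir = Path(__file__).parent / "lib"
for _lib in _lib_dir.glob("*.so"):
    torch.ops.load_library(str(_lib))  # registers the DrissTorch ops with the dispatcher

ops = torch.ops.DrissTorch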
32 changes: 32 additions & 0 deletions driss_torch/__init__.py
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Optional

import torch

@@ -9,6 +10,7 @@


ops = torch.ops.DrissTorch
Tensor = torch.Tensor


def list_ops():
@@ -40,3 +42,33 @@ def saturated_cast(
def amax(x: torch.Tensor) -> float:
"""This op takes in a tensor and returns the max absolute value of it."""
return ops.amax(x)


def sweep_mm(
x: torch.Tensor,
w: torch.Tensor,
x_scale: torch.Tensor,
w_scale: torch.Tensor,
bias: Optional[torch.Tensor],
out_dtype: torch.dtype,
use_fast_accum: bool,
cluster_shape_x: int,
cluster_shape_y: int,
cluster_shape_z: int,
transposed: bool,
swizzle: int,
) -> torch.Tensor:
return ops.sweep_mm(
x,
w,
x_scale,
w_scale,
bias,
out_dtype,
use_fast_accum,
cluster_shape_x,
cluster_shape_y,
cluster_shape_z,
transposed,
swizzle,
)
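
The Python wrapper is a thin passthrough to the registered op. Below is a hedged usage sketch; the FP8 dtypes, shapes, and per-tensor scale layout are assumptions on my part, since the kernel's actual requirements live in the CUTLASS kernels not shown in this excerpt:

import torch
from driss_torch import sweep_mm

# Assumed example inputs: FP8 activations and weights with per-tensor scales.
# Treat this purely as a call-shape sketch, not as the kernel's contract.
x = torch.randn(128, 256, device="cuda").to(torch.float8_e4m3fn)
w = torch.randn(256, 512, device="cuda").to(torch.float8_e4m3fn)
x_scale = torch.tensor(1.0, device="cuda")
w_scale = torch.tensor(1.0, device="cuda")

out = sweep_mm(
    x,
    w,
    x_scale,
    w_scale,
    bias=None,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
    cluster_shape_x=2,
    cluster_shape_y=1,
    cluster_shape_z=1,
    transposed=False,
    swizzle=1,
)
print(out.shape, out.dtype)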
21 changes: 21 additions & 0 deletions src/include/sweep_mm.h
@@ -0,0 +1,21 @@
#pragma once

#include <torch/torch.h>

namespace driss_torch {

at::Tensor sweep_mm(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
at::ScalarType out_dtype,
bool use_fast_accum,
const int64_t cluster_shape_x,
const int64_t cluster_shape_y,
const int64_t cluster_shape_z,
bool transposed,
const int64_t swizzle);

} // namespace driss_torch
4 changes: 4 additions & 0 deletions src/register_ops.cpp
@@ -4,6 +4,7 @@
// Custom op headers
#include "saturated_cast.h"
#include "amax.h"
#include "sweep_mm.h"

TORCH_LIBRARY(DrissTorch, m) {
m.impl_abstract_pystub("driss_torch.abstract_impls");
@@ -13,4 +14,7 @@ TORCH_LIBRARY(DrissTorch, m) {
// Amax func
m.def("amax(Tensor input) -> Tensor");
m.impl("amax", c10::DispatchKey::CUDA, TORCH_FN(driss_torch::amax));
// sweep_mm
m.def("sweep_mm(Tensor x, Tensor w, Tensor x_scale, Tensor w_scale, Tensor? bias , ScalarType out_dtype, bool use_fast_accum, int cluster_shape_x, int cluster_shape_y, int cluster_shape_z, bool transposed, int swizzle) -> Tensor");
m.impl("sweep_mm", c10::DispatchKey::CUDA, TORCH_FN(driss_torch::sweep_mm));
}
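
Once the shared library is loaded, the schema string above is what makes the op resolvable from Python: Tensor? maps to Optional[Tensor], ScalarType to torch.dtype, and the int/bool arguments map directly. A quick sanity-check sketch, assuming importing the package loads the library:

import torch
import driss_torch  # assumed to load libdriss_torch and register the ops on import

# The TORCH_LIBRARY(DrissTorch, ...) block above exposes the op here.
sweep = torch.ops.DrissTorch.sweep_mm
print(sweep)  # OpOverloadPacket for the schema defined in register_ops.cpp

# Only a CUDA kernel is registered (c10::DispatchKey::CUDA), so all tensor
# arguments must live on a CUDA device when the op is called.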