generic sweep op #9

Open · wants to merge 1 commit into base: main
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = git@github.com:NVIDIA/cutlass.git
49 changes: 30 additions & 19 deletions CMakeLists.txt
@@ -1,9 +1,7 @@
cmake_minimum_required(VERSION 3.26 FATAL_ERROR)

project(
${SKBUILD_PROJECT_NAME}
VERSION ${SKBUILD_PROJECT_VERSION}
LANGUAGES CXX CUDA)
project(${SKBUILD_PROJECT_NAME}
VERSION ${SKBUILD_PROJECT_VERSION}
LANGUAGES CXX CUDA)

# Set the C++ standard for all targets
set(CMAKE_CXX_STANDARD 20) # This might be unsafe since PyTorch uses C++17
@@ -14,20 +12,21 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

find_package(Python REQUIRED COMPONENTS Interpreter Development)
execute_process(
COMMAND "${Python3_EXECUTABLE}" "-c" "import torch;print(torch.utils.cmake_prefix_path)"
OUTPUT_VARIABLE PT_CMAKE_PREFIX
COMMAND_ECHO STDOUT
OUTPUT_STRIP_TRAILING_WHITESPACE
COMMAND_ERROR_IS_FATAL ANY
COMMAND "${Python3_EXECUTABLE}" "-c" "import torch;print(torch.utils.cmake_prefix_path)"
OUTPUT_VARIABLE PT_CMAKE_PREFIX
COMMAND_ECHO STDOUT
OUTPUT_STRIP_TRAILING_WHITESPACE
COMMAND_ERROR_IS_FATAL ANY
)

# cache CUDA_ARCHITECTURES, which seems to be reset by Torch
set(TMP_STORE_CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES}")
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH};${PT_CMAKE_PREFIX})
# Set CUDA architecture to SM90a
set(CMAKE_CUDA_ARCHITECTURES 90a)
set(TORCH_CUDA_ARCH_LIST "9.0a")

set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH};${PT_CMAKE_PREFIX})
find_package(Torch REQUIRED CONFIG)

# simple_cuda source files
# driss_torch source files
file(GLOB_RECURSE CU_SOURCES src/*.cu)
file(GLOB_RECURSE CPP_SOURCES src/*.cpp)
MESSAGE(STATUS "CU_SOURCES: ${CU_SOURCES}")
@@ -38,17 +37,29 @@ add_library(driss_torch SHARED
${CPP_SOURCES}
)

# Set the library output directory, I think this makes ninja builds work
# Set the library output directory
set_target_properties(driss_torch PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/driss_torch/lib"
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/driss_torch/lib"
)

# Check for CUTLASS
if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/include/cutlass/cutlass.h")
message(FATAL_ERROR "The Cutlass submodule was not downloaded! Please update submodules and try again.")
endif()

# Include CUTLASS headers without building the entire library
target_include_directories(driss_torch PUBLIC
src/include
src/scaled_mm_kernels/
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/include
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/tools/util/include
${CMAKE_CURRENT_SOURCE_DIR}/third_party/cutlass/tools/library/include
)
# Add include directories to the library
target_include_directories(driss_torch PUBLIC src/include)

# Link the library to the Torch library
target_link_libraries(driss_torch PRIVATE ${TORCH_LIBRARIES} Python::Python)

# Install the library to the wheel distribution
install(TARGETS driss_torch
LIBRARY DESTINATION driss_torch/lib
LIBRARY DESTINATION driss_torch/lib
)
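Note: the CUTLASS submodule is used header-only here (no library target gets built). A quick post-build sanity check could look like the sketch below; it assumes the wheel was built with scikit-build so that `driss_torch/lib` is populated, and it reuses the existing `list_ops()` helper from `driss_torch/__init__.py`.

```python
import driss_torch

# After a successful build, the shared library in driss_torch/lib is loaded by
# the package __init__, and the custom ops show up under the DrissTorch namespace.
print(driss_torch.list_ops())  # expected to include saturated_cast, amax, and sweep_mm
```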
32 changes: 32 additions & 0 deletions driss_torch/__init__.py
@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Optional

import torch

@@ -9,6 +10,7 @@


ops = torch.ops.DrissTorch
Tensor = torch.Tensor


def list_ops():
@@ -40,3 +42,33 @@ def saturated_cast(
def amax(x: torch.Tensor) -> float:
"""This op takes in a tensor and returns the max absolute value of it."""
return ops.amax(x)


def sweep_mm(
    x: torch.Tensor,
    w: torch.Tensor,
    x_scale: torch.Tensor,
    w_scale: torch.Tensor,
    bias: Optional[torch.Tensor],
    out_dtype: torch.dtype,
    use_fast_accum: bool,
    cluster_shape_x: int,
    cluster_shape_y: int,
    cluster_shape_z: int,
    transposed: bool,
    swizzle: int,
) -> torch.Tensor:
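    """Thin wrapper that forwards to the DrissTorch::sweep_mm CUDA op (schema in src/register_ops.cpp)."""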
    return ops.sweep_mm(
        x,
        w,
        x_scale,
        w_scale,
        bias,
        out_dtype,
        use_fast_accum,
        cluster_shape_x,
        cluster_shape_y,
        cluster_shape_z,
        transposed,
        swizzle,
    )
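For reference, a rough usage sketch of the new wrapper is below. The concrete details (fp8 e4m3 inputs, row-wise fp32 scales, `w` laid out as `(N, K)`, and the particular cluster-shape/swizzle values) are assumptions on my part; the diff itself doesn't pin down the expected layouts.

```python
import torch
from driss_torch import sweep_mm

# Hypothetical shapes and dtypes; adjust to whatever the kernel actually expects.
M, K, N = 128, 256, 64
x = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
w = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
x_scale = torch.rand(M, 1, device="cuda", dtype=torch.float32)
w_scale = torch.rand(N, 1, device="cuda", dtype=torch.float32)

out = sweep_mm(
    x,
    w,
    x_scale,
    w_scale,
    bias=None,
    out_dtype=torch.bfloat16,
    use_fast_accum=True,
    cluster_shape_x=2,
    cluster_shape_y=1,
    cluster_shape_z=1,
    transposed=False,
    swizzle=1,
)
# If the op follows the usual scaled-mm convention, this is (M, N) in bfloat16.
print(out.shape, out.dtype)
```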
21 changes: 21 additions & 0 deletions src/include/sweep_mm.h
@@ -0,0 +1,21 @@
#pragma once

#include <torch/torch.h>

namespace driss_torch {

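// Generic scaled-mm sweep (per the PR title): the cluster-shape, transposed,
// and swizzle arguments appear to parameterize the CUTLASS kernel
// configuration provided under src/scaled_mm_kernels.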
at::Tensor sweep_mm(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
std::optional<at::Tensor> bias,
at::ScalarType out_dtype,
bool use_fast_accum,
const int64_t cluster_shape_x,
const int64_t cluster_shape_y,
const int64_t cluster_shape_z,
bool transposed,
const int64_t swizzle);

} // namespace driss_torch
4 changes: 4 additions & 0 deletions src/register_ops.cpp
@@ -4,6 +4,7 @@
// Custom op headers
#include "saturated_cast.h"
#include "amax.h"
#include "sweep_mm.h"

TORCH_LIBRARY(DrissTorch, m) {
m.impl_abstract_pystub("driss_torch.abstract_impls");
@@ -13,4 +14,7 @@ TORCH_LIBRARY(DrissTorch, m) {
// Amax func
m.def("amax(Tensor input) -> Tensor");
m.impl("amax", c10::DispatchKey::CUDA, TORCH_FN(driss_torch::amax));
// sweep_mm
m.def("sweep_mm(Tensor x, Tensor w, Tensor x_scale, Tensor w_scale, Tensor? bias , ScalarType out_dtype, bool use_fast_accum, int cluster_shape_x, int cluster_shape_y, int cluster_shape_z, bool transposed, int swizzle) -> Tensor");
m.impl("sweep_mm", c10::DispatchKey::CUDA, TORCH_FN(driss_torch::sweep_mm));
}
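One thing this PR doesn't add: since the library routes abstract impls through `driss_torch.abstract_impls` (the `impl_abstract_pystub` call above), `sweep_mm` will presumably also need a fake/meta implementation there for tracing and `torch.compile`. A minimal sketch is below, using `torch.library.impl_abstract` (the decorator matching this pystub API; newer PyTorch calls it `register_fake`) and assuming the output is `(M, N)` in `out_dtype` with `w` passed as `(N, K)` — both assumptions about the kernel's convention.

```python
# driss_torch/abstract_impls.py (sketch)
import torch


@torch.library.impl_abstract("DrissTorch::sweep_mm")
def sweep_mm_abstract(x, w, x_scale, w_scale, bias, out_dtype, use_fast_accum,
                      cluster_shape_x, cluster_shape_y, cluster_shape_z,
                      transposed, swizzle):
    # Shape/dtype inference only: no computation happens here.
    M = x.size(0)
    N = w.size(0)
    return x.new_empty((M, N), dtype=out_dtype)
```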