Sync with vLLM
This fixes (among other things) a race condition in GPTQ-Marlin.
- build.toml +11 -3
- compressed_tensors/int8_quant_kernels.cu +2 -2
- core/math.hpp +7 -0
- cutlass_extensions/common.cpp +11 -0
- cutlass_extensions/common.hpp +35 -0
- cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +5 -3
- cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +7 -5
- cutlass_w8a8/scaled_mm_c2x.cu +9 -9
- cutlass_w8a8/scaled_mm_c2x.cuh +5 -4
- cutlass_w8a8/scaled_mm_c3x.cu +6 -372
- cutlass_w8a8/scaled_mm_c3x.cuh +160 -0
- cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh +96 -0
- cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh +140 -0
- cutlass_w8a8/scaled_mm_entry.cu +17 -25
- fp8/common.cuh +8 -20
- gptq_marlin/gptq_marlin.cu +21 -19
- vectorization.cuh +33 -0
build.toml
CHANGED
@@ -15,6 +15,7 @@ pyroot = "ext-torch"
 [kernel.cutlass_w8a8]
 capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
 src = [
+  "core/math.hpp",
   "cutlass_w8a8/common.hpp",
   "cutlass_w8a8/scaled_mm_c2x.cu",
   "cutlass_w8a8/scaled_mm_c2x.cuh",
@@ -27,25 +28,32 @@ src = [
   "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp",
 ]
 include = [ "." ]
-depends = [ "
+depends = [ "cutlass_3_6", "torch" ]

 [kernel.cutlass_w8a8_hopper]
 capabilities = [ "9.0", "9.0a" ]
 src = [
+  "core/math.hpp",
   "cutlass_w8a8/common.hpp",
   "cutlass_w8a8/scaled_mm_c3x.cu",
+  "cutlass_w8a8/scaled_mm_c3x.cuh",
+  "cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh",
+  "cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh",
+  "cutlass_extensions/common.cpp",
+  "cutlass_extensions/common.hpp",
   "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp",
   "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp",
 ]
 include = [ "." ]
-depends = [ "
+depends = [ "cutlass_3_6", "torch" ]

 [kernel.fp8_common]
 capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
 src = [
   "fp8/common.cu",
   "fp8/common.cuh",
-  "dispatch_utils.h"
+  "dispatch_utils.h",
+  "vectorization.cuh"
 ]
 include = [ "." ]
 depends = [ "torch" ]
compressed_tensors/int8_quant_kernels.cu
CHANGED
@@ -226,7 +226,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
 void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                               torch::Tensor const& input,  // [..., hidden_size]
                               torch::Tensor const& scale,
-
+                              std::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
@@ -257,7 +257,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     torch::Tensor const& input,  // [..., hidden_size]
-    torch::Tensor& scales,
+    torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scales.is_contiguous());
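A minimal caller sketch of the updated signature (my illustration, not part of the diff): with azp now a std::optional, the symmetric-quantization path simply passes std::nullopt. The wrapper function name is an assumption; the declaration matches the diff above.

#include <optional>
#include <torch/all.h>

// Declaration copied from the diff above.
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale,
                              std::optional<torch::Tensor> const& azp);

// Hypothetical wrapper: symmetric (no zero point) int8 quantization passes
// std::nullopt for azp.
void quantize_symmetric(torch::Tensor& out, torch::Tensor const& input,
                        torch::Tensor const& scale) {
  static_scaled_int8_quant(out, input, scale, std::nullopt);
}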
core/math.hpp
ADDED
#include <climits>
#include <iostream>

inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
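This helper is used by the new sm90 dispatch headers below to bucket the GEMM M dimension. A quick standalone check of its behaviour (my own illustration, not from the diff):

#include <cassert>
#include "core/math.hpp"  // next_pow_2

int main() {
  assert(next_pow_2(0) == 0);    // values <= 1 are returned unchanged
  assert(next_pow_2(1) == 1);
  assert(next_pow_2(3) == 4);
  assert(next_pow_2(64) == 64);  // powers of two map to themselves
  assert(next_pow_2(65) == 128);
  return 0;
}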
cutlass_extensions/common.cpp
ADDED
#include "cutlass_extensions/common.hpp"

int32_t get_sm_version_num() {
  int32_t major_capability, minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         0);
  int32_t version_num = major_capability * 10 + minor_capability;
  return version_num;
}
cutlass_extensions/common.hpp
ADDED
#pragma once

#include "cutlass/cutlass.h"
#include <climits>
#include "cuda_runtime.h"
#include <iostream>

/**
 * Helper function for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status)                         \
  {                                                   \
    cutlass::Status error = status;                   \
    TORCH_CHECK(error == cutlass::Status::kSuccess,   \
                cutlassGetStatusString(error));       \
  }

/**
 * Panic wrapper for unwinding CUDA runtime errors
 */
#define CUDA_CHECK(status)                                        \
  {                                                               \
    cudaError_t error = status;                                   \
    TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
  }

inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int max_shared_mem_per_block_opt_in = 0;
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device);
  return max_shared_mem_per_block_opt_in;
}

int32_t get_sm_version_num();
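A small usage sketch of these helpers (assumed usage, not taken from the diff): get_sm_version_num() can gate the Hopper-only (sm90) kernels, and CUDA_CHECK wraps CUDA runtime calls so failures surface as TORCH_CHECK errors. The function names below are hypothetical.

#include <torch/all.h>  // TORCH_CHECK, used by the macros above
#include "cutlass_extensions/common.hpp"

// Hypothetical gate: get_sm_version_num() returns major * 10 + minor,
// e.g. 90 on Hopper, so the sm90 kernels require >= 90.
bool can_use_sm90_kernels() { return get_sm_version_num() >= 90; }

// Hypothetical helper: CUDA_CHECK converts a cudaError_t into a
// TORCH_CHECK failure carrying the runtime's error string.
void sync_device_or_throw() { CUDA_CHECK(cudaDeviceSynchronize()); }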
cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
CHANGED
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
 
 /*
@@ -66,7 +68,7 @@ struct ScaledEpilogueBase {
   // This overload handles the case where there might not be a tensor, in which
   // case a nullptr is passed and a constant (0) is used.
   template <typename Descriptor, typename T>
-  static auto args_from_tensor(
+  static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
     static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
     using Arguments = typename Descriptor::Arguments;
     auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
@@ -221,7 +223,7 @@ struct ScaledEpilogueBiasAzp
   static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales,
                                    torch::Tensor const& azp_adj,
-
+                                   std::optional<torch::Tensor> const& bias) {
     auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
     auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
     auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -299,7 +301,7 @@ struct ScaledEpilogueBiasAzpToken
                                    torch::Tensor const& b_scales,
                                    torch::Tensor const& azp_adj,
                                    torch::Tensor const& azp,
-
+                                   std::optional<torch::Tensor> const& bias) {
     auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
     auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
     auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
CHANGED
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
 
 /*
@@ -36,13 +38,13 @@ struct ScaledEpilogueBase {
   // Don't want to support nullptr by default
   template <typename T, bool EnableNullPtr = false>
   using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
       Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
 
   // Don't want to support nullptr by default
   template <typename T, bool EnableNullPtr = false>
   using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
-      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
       Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
 
   // This utility function constructs the arguments for the load descriptors
@@ -65,7 +67,7 @@ struct ScaledEpilogueBase {
   // This overload handles the case where there might not be a tensor, in which
   // case a nullptr is passed and a constant (0) is used.
   template <typename Descriptor, typename T>
-  static auto args_from_tensor(
+  static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
     using Arguments = typename Descriptor::Arguments;
     auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
     static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
@@ -221,7 +223,7 @@ struct ScaledEpilogueBiasAzp
   static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales,
                                    torch::Tensor const& azp_adj,
-
+                                   std::optional<torch::Tensor> const& bias) {
     auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
     auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
     auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -297,7 +299,7 @@ struct ScaledEpilogueBiasAzpToken
                                    torch::Tensor const& b_scales,
                                    torch::Tensor const& azp_adj,
                                    torch::Tensor const& azp,
-
+                                   std::optional<torch::Tensor> const& bias) {
     auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
     auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
     auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
cutlass_w8a8/scaled_mm_c2x.cu
CHANGED
@@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
-
+                            std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {
@@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 torch::Tensor const& azp_adj,
-
-
+                                std::optional<torch::Tensor> const& azp,
+                                std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
@@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
-
+                            std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {
@@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 torch::Tensor const& azp_adj,
-
-
+                                std::optional<torch::Tensor> const& azp,
+                                std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
@@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
-
+                            std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {
@@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 torch::Tensor const& azp_adj,
-
-
+                                std::optional<torch::Tensor> const& azp,
+                                std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
cutlass_w8a8/scaled_mm_c2x.cuh
CHANGED
@@ -21,15 +21,16 @@
 #include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
 #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
 
-#include "
+#include "core/math.hpp"
+#include "cutlass_extensions/common.hpp"
 // clang-format on
 
 using namespace cute;
 
 /*
-
-
-
+   Epilogues defined in,
+   csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+   must contain a public type named EVTCompute of type Sm80EVT,
    as well as a static prepare_args function that constructs an
    EVTCompute::Arguments struct.
 */
cutlass_w8a8/scaled_mm_c3x.cu
CHANGED
@@ -1,384 +1,18 @@
 #include <cudaTypedefs.h>
 
 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
 
+#include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
+#include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
 
+#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
 using namespace vllm;
 
 /*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper) or later.
 */
 
 template <template <typename, typename, typename> typename Epilogue,
           typename... EpilogueArgs>
 void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,

(The rest of this hunk deletes the old in-file plumbing: the clang-format guards, the extra system and CUTLASS includes, "using namespace cute", and the anonymous namespace holding the enable_sm90_or_later wrapper, cutlass_3x_gemm, cutlass_gemm_caller, and the sm90 fp8/int8 config structs and dispatch functions. That code moves, essentially unchanged, into the new scaled_mm_c3x.cuh, scaled_mm_c3x_sm90_fp8_dispatch.cuh, and scaled_mm_c3x_sm90_int8_dispatch.cuh headers shown below.)

@@ -417,7 +51,7 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
-
+                            std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
   if (bias) {
@@ -436,8 +70,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 torch::Tensor const& azp_adj,
-
-
+                                std::optional<torch::Tensor> const& azp,
+                                std::optional<torch::Tensor> const& bias) {
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
cutlass_w8a8/scaled_mm_c3x.cuh
ADDED
#pragma once

// clang-format will break include orders
// clang-format off
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
// clang-format on

/*
   Epilogues defined in,
   csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
   must contain a public type named EVTCompute of type Sm90EVT, as well as a
   static prepare_args function that constructs an EVTCompute::Arguments struct.
*/

using namespace cute;

namespace vllm {

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};

template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;

  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  using ElementC = void;
  using StrideC = StrideD;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
          ElementAB, cutlass::layout::RowMajor, 16,
          ElementAB, cutlass::layout::ColumnMajor, 16,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideB = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;

  StrideA a_stride{lda, Int<1>{}, 0};
  StrideB b_stride{ldb, Int<1>{}, 0};
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  using GemmKernel = typename Gemm::GemmKernel;
  typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};

  auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
  typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
                                                       b_stride};

  auto c_ptr = static_cast<ElementD*>(out.data_ptr());
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      c_ptr, c_stride, c_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

}  // namespace vllm
cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh
ADDED
#pragma once

#include "scaled_mm_c3x.cuh"

/**
 * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
 * shape.
 */

namespace vllm {

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
  // M in (128, inf)
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M128 {
  // M in (64, 128]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_M64 {
  // M in [1, 64]
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
  using TileShape = Shape<_64, _64, _128>;
  using ClusterShape = Shape<_1, _8, _1>;

  using Cutlass3xGemm =
      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
                      KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  using Cutlass3xGemmDefault =
      typename sm90_fp8_config_default<InType, OutType,
                                       Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  if (mp2 <= 64) {
    // m in [1, 64]
    return cutlass_gemm_caller<Cutlass3xGemmM64>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return cutlass_gemm_caller<Cutlass3xGemmM128>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // m in (128, inf)
    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

}  // namespace vllm
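For reference, the M-to-configuration mapping implemented above, written out as a small standalone helper (my illustration only; the bounds come from the comments in the config structs):

#include <algorithm>
#include <cstdint>
#include "core/math.hpp"

// Which fp8 config the dispatch above selects for a given M.
const char* sm90_fp8_config_for(uint32_t m) {
  uint32_t mp2 = std::max(static_cast<uint32_t>(64), next_pow_2(m));
  if (mp2 <= 64) return "sm90_fp8_config_M64";    // M in [1, 64]
  if (mp2 <= 128) return "sm90_fp8_config_M128";  // M in (64, 128]
  return "sm90_fp8_config_default";               // M in (128, inf)
}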
cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include "scaled_mm_c3x.cuh"
|
4 |
+
|
5 |
+
/**
|
6 |
+
* This file defines Gemm kernel configurations for SM90 (int8) based on the
|
7 |
+
* Gemm shape.
|
8 |
+
*/
|
9 |
+
|
10 |
+
namespace vllm {
|
11 |
+
|
12 |
+
template <typename InType, typename OutType,
|
13 |
+
template <typename, typename, typename> typename Epilogue>
|
14 |
+
struct sm90_int8_config_default {
|
15 |
+
// For M > 128 and any N
|
16 |
+
static_assert(std::is_same<InType, int8_t>());
|
+  using KernelSchedule =
+      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_128, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M128 {
+  // For M in (64, 128] and any N
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule =
+      typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _128, _128>;
+  using ClusterShape = Shape<_2, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M64 {
+  // For M in (32, 64] and any N
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _64, _256>;
+  using ClusterShape = Shape<_1, _1, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M32_NBig {
+  // For M in [1, 32] and N >= 8192
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _128, _256>;
+  using ClusterShape = Shape<_1, _4, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm90_int8_config_M32_NSmall {
+  // For M in [1, 32] and N < 8192
+  static_assert(std::is_same<InType, int8_t>());
+  using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+  using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+  using TileShape = Shape<_64, _64, _256>;
+  using ClusterShape = Shape<_1, _8, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                      KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue,
+          typename... EpilogueArgs>
+inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
+                                            torch::Tensor const& a,
+                                            torch::Tensor const& b,
+                                            EpilogueArgs&&... args) {
+  static_assert(std::is_same<InType, int8_t>());
+  TORCH_CHECK(a.dtype() == torch::kInt8);
+  TORCH_CHECK(b.dtype() == torch::kInt8);
+
+  using Cutlass3xGemmDefault =
+      typename sm90_int8_config_default<InType, OutType,
+                                        Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM128 =
+      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM64 =
+      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM32NBig =
+      typename sm90_int8_config_M32_NBig<InType, OutType,
+                                         Epilogue>::Cutlass3xGemm;
+  using Cutlass3xGemmM32NSmall =
+      typename sm90_int8_config_M32_NSmall<InType, OutType,
+                                           Epilogue>::Cutlass3xGemm;
+
+  uint32_t const n = out.size(1);
+  bool const is_small_n = n < 8192;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
+
+  if (mp2 <= 32) {
+    // m in [1, 32]
+    if (is_small_n) {
+      return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    } else {
+      return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
+          out, a, b, std::forward<EpilogueArgs>(args)...);
+    }
+  } else if (mp2 <= 64) {
+    // m in (32, 64]
+    return cutlass_gemm_caller<Cutlass3xGemmM64>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else if (mp2 <= 128) {
+    // m in (64, 128]
+    return cutlass_gemm_caller<Cutlass3xGemmM128>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  } else {
+    // m in (128, inf)
+    return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+        out, a, b, std::forward<EpilogueArgs>(args)...);
+  }
+}
+
+}  // namespace vllm
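The dispatcher above buckets on M by rounding it up to the next power of two with a floor of 32, and only the smallest bucket additionally splits on whether N < 8192. The standalone sketch below reproduces that bucketing rule so it can be read (or unit-tested) outside the kernel sources; the helper names are illustrative only, with next_pow_2 mirroring the math helper the dispatcher calls.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Round up to the next power of two (GCC/Clang builtin; illustrative only).
static uint32_t next_pow_2(uint32_t x) {
  return x <= 1 ? 1 : 1u << (32 - __builtin_clz(x - 1));
}

// Mirrors the if/else ladder in cutlass_gemm_sm90_int8_dispatch above.
static const char* pick_sm90_int8_config(uint32_t m, uint32_t n) {
  uint32_t const mp2 = std::max(32u, next_pow_2(m));
  if (mp2 <= 32) return n < 8192 ? "M32_NSmall" : "M32_NBig";
  if (mp2 <= 64) return "M64";
  if (mp2 <= 128) return "M128";
  return "default";
}

int main() {
  std::printf("%s\n", pick_sm90_int8_config(48, 4096));  // prints "M64"
}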
cutlass_w8a8/scaled_mm_entry.cu
CHANGED
@@ -3,30 +3,32 @@
 #include <c10/cuda/CUDAGuard.h>
 #include <torch/all.h>

+#include "cutlass_extensions/common.hpp"
+
 void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
-
+                            std::optional<torch::Tensor> const& bias);

 void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
-
+                            std::optional<torch::Tensor> const& bias);

 void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
-
+                            std::optional<torch::Tensor> const& bias);

 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
 void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
-
+                            std::optional<torch::Tensor> const& bias);
 #endif

 void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@@ -34,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 torch::Tensor const& azp_adj,
-
-
+                                std::optional<torch::Tensor> const& azp,
+                                std::optional<torch::Tensor> const& bias);

 void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 torch::Tensor const& azp_adj,
-
-
+                                std::optional<torch::Tensor> const& azp,
+                                std::optional<torch::Tensor> const& bias);

 void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 torch::Tensor const& azp_adj,
-
-
+                                std::optional<torch::Tensor> const& azp,
+                                std::optional<torch::Tensor> const& bias);

 #if defined CUDA_VERSION && CUDA_VERSION >= 12000
 void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
@@ -59,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 torch::Tensor const& azp_adj,
-
-
+                                std::optional<torch::Tensor> const& azp,
+                                std::optional<torch::Tensor> const& bias);
 #endif

 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
@@ -79,20 +81,10 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
   return false;
 }

-int32_t get_sm_version_num() {
-  int32_t major_capability, minor_capability;
-  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
-                         0);
-  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
-                         0);
-  int32_t version_num = major_capability * 10 + minor_capability;
-  return version_num;
-}
-
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,
-
+                       std::optional<torch::Tensor> const& bias) {
   // Checks for conformality
   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
@@ -154,8 +146,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            torch::Tensor const& azp_adj,
-
-
+                           std::optional<torch::Tensor> const& azp,
+                           std::optional<torch::Tensor> const& bias) {
   // Checks for conformality
   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
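With this change every entry point takes the bias (and, for the azp variants, the azp tensor) as std::optional instead of a required tensor, and get_sm_version_num() moves out of this file into cutlass_extensions/common.hpp. A minimal caller-side sketch, assuming the extension is built and linked; run_scaled_mm_no_bias is a hypothetical wrapper, not part of the sources.

#include <optional>
#include <torch/all.h>

// Declaration as in scaled_mm_entry.cu above.
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       std::optional<torch::Tensor> const& bias);

void run_scaled_mm_no_bias(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales) {
  // No fused bias: pass std::nullopt; a torch::Tensor converts implicitly
  // into the optional when a bias should be fused.
  cutlass_scaled_mm(c, a, b, a_scales, b_scales, std::nullopt);
}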
fp8/common.cuh
CHANGED
@@ -1,6 +1,9 @@
 #pragma once

+#include "vectorization.cuh"
+
 #include <cmath>
+#include <c10/core/ScalarType.h>

 #ifndef USE_ROCM
 #include <c10/util/Float8_e4m3fn.h>
@@ -15,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz;
 // issue when running dynamic quantization. Here use 224.0f for rocm.
 constexpr auto FP8_E4M3_MAX = 224.0f;
 #endif
+constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;

 namespace vllm {

@@ -89,22 +93,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   }
 }

-template <typename scalar_t>
-struct __align__(8) vec4_t {
-  scalar_t x;
-  scalar_t y;
-  scalar_t z;
-  scalar_t w;
-};
-
-typedef struct __align__(4) {
-  FP8_TYPE x;
-  FP8_TYPE y;
-  FP8_TYPE z;
-  FP8_TYPE w;
-}
-float8x4_t;
-
 template <typename scalar_t>
 __device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                 int64_t const num_elems, int const tid,
@@ -139,10 +127,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                           float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
+  using float8x4_t = q8x4_t<FP8_TYPE>;
   // Vectorized input/output to better utilize memory bandwidth.
-
-
-  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
+  auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
+  auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);

   int64_t const num_vec_elems = num_elems >> 2;

@@ -169,4 +157,4 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
   }
 }

-} // namespace vllm
+}  // namespace vllm
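scaled_fp8_conversion_vec now takes its 4-wide containers from vectorization.cuh: vec4_t<scalar_t> on the load side and q8x4_t<FP8_TYPE> on the store side. The device-side sketch below shows that access pattern in isolation; it is a simplification (the real helper also handles the num_elems % 4 tail and the per-element scaling routine), and the kernel name is hypothetical.

#include <cstdint>
#include <c10/util/Float8_e4m3fn.h>
#include "vectorization.cuh"

using FP8 = c10::Float8_e4m3fn;

__global__ void scaled_cast_vec4(FP8* __restrict__ out,
                                 float const* __restrict__ in,
                                 float inv_scale, int64_t num_vec_elems) {
  auto const* vin = reinterpret_cast<vllm::vec4_t<float> const*>(in);
  auto* vout = reinterpret_cast<vllm::q8x4_t<FP8>*>(out);
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_vec_elems;
       i += int64_t(gridDim.x) * blockDim.x) {
    vllm::vec4_t<float> const v = vin[i];  // four elements per load
    vllm::q8x4_t<FP8> q;
    q.x = static_cast<FP8>(v.x * inv_scale);
    q.y = static_cast<FP8>(v.y * inv_scale);
    q.z = static_cast<FP8>(v.z * inv_scale);
    q.w = static_cast<FP8>(v.w * inv_scale);
    vout[i] = q;  // four elements per store
  }
}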
gptq_marlin/gptq_marlin.cu
CHANGED
@@ -832,6 +832,7 @@ __global__ void Marlin(
   int4* sh_g_idx = sh_b + (stages * b_sh_stage);
   int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
   int4* sh_s = sh_zp + (stages * zp_sh_stage);
+  int4* sh_red = sh_s + (stages * s_sh_stage);

   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
@@ -930,11 +931,11 @@ __global__ void Marlin(
       int4* sh_s_stage = sh_s + s_sh_stage * pipe;

       if constexpr (group_blocks >= thread_k_blocks) {
+        if (s_sh_wr_pred) {
+          cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
+        }
         // Only fetch scales if this tile starts a new group
-        if (pipe % (group_blocks / thread_k_blocks) == 0) {
-          if (s_sh_wr_pred) {
-            cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
-          }
+        if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
           s_gl_rd += s_gl_rd_delta;
         }
       } else {
@@ -1036,9 +1037,7 @@ __global__ void Marlin(
     // No act-order case
     if constexpr (group_blocks != -1) {
       if constexpr (group_blocks >= thread_k_blocks) {
-        int4* sh_s_stage =
-            sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
-                                 (pipe / (group_blocks / thread_k_blocks)));
+        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
         reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
       } else {
         int warp_id = threadIdx.x / 32;
@@ -1337,15 +1336,15 @@ __global__ void Marlin(
             int red_sh_wr =
                 red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
             if (i < red_off) {
-              float* c_rd =
-
-              float* c_wr = reinterpret_cast<float*>(&
+              float* c_rd = reinterpret_cast<float*>(
+                  &sh_red[red_sh_delta * j + red_sh_rd]);
+              float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
   #pragma unroll
               for (int k = 0; k < 4; k++)
                 reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                     c_rd[k] + c_wr[k];
             }
-
+            sh_red[red_sh_wr] =
                 reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
           }
         }
@@ -1355,7 +1354,7 @@ __global__ void Marlin(
   #pragma unroll
         for (int i = 0; i < 4 * 2; i++) {
           float* c_rd =
-              reinterpret_cast<float*>(&
+              reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
   #pragma unroll
           for (int j = 0; j < 4; j++)
             reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
@@ -1395,7 +1394,7 @@ __global__ void Marlin(
   #pragma unroll
       for (int i = 0; i < thread_m_blocks * 4; i++) {
         cp_async4_pred(
-            &
+            &sh_red[c_sh_wr + c_sh_wr_delta * i],
             &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
               c_gl_wr_delta_i * (i % 2)],
             i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
@@ -1408,7 +1407,7 @@ __global__ void Marlin(
      for (int i = 0; i < thread_m_blocks * 4; i++) {
        if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
          if (!first) {
-           int4 c_red =
+           int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
   #pragma unroll
            for (int j = 0; j < 2 * 4; j++) {
              reinterpret_cast<float*>(
@@ -1459,10 +1458,10 @@ __global__ void Marlin(
     float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
   #pragma unroll
     for (int k = 0; k < th_size; k++) {
-
+      sh_red[threadIdx.x] =
           C_tmp[c_cur_offset + active_threads * k + threadIdx.x];

-      float* sh_c_ptr = reinterpret_cast<float*>(&
+      float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
   #pragma unroll
       for (int f = 0; f < 4; f++) {
         frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
@@ -1513,7 +1512,7 @@ __global__ void Marlin(
       res = __hmul2(res, s[0]);
     }

-    ((scalar_t2*)
+    ((scalar_t2*)sh_red)[idx] = res;
   };

   if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1541,7 +1540,7 @@ __global__ void Marlin(
          i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
          i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] =
+        C[c_gl_wr] = sh_red[c_sh_rd];
         c_gl_wr += c_gl_wr_delta;
         c_sh_rd += c_sh_rd_delta;
       }
@@ -1863,9 +1862,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,

   float pipe_size = (a_size + b_size) * pipe_stages;

+  float reduce_size = max(th_config.num_threads * 32 * 4,
+                          (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
+
   TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

-  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
+  return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
 }

 bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
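The Marlin changes carve out a dedicated shared-memory region, sh_red, for the output reduction and the C write-back path, which appears to be the race-condition fix the commit message refers to, and is_valid_cache_size() now reserves space for that region. A host-side sketch of the updated budget check, with simplified inputs (the function and parameter names here are illustrative):

#include <algorithm>
#include <cassert>

// Mirrors the check in is_valid_cache_size() after this change: the pipeline
// buffers plus the new reduction buffer must fit next to the scales cache.
bool fits_in_shared_memory(float a_size, float b_size, int pipe_stages,
                           float scales_cache_size, int num_threads, int tb_n,
                           int tb_max_m, int max_shared_mem) {
  float const pipe_size = (a_size + b_size) * pipe_stages;
  float const reduce_size =
      std::max(num_threads * 32 * 4,
               (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
  assert(max_shared_mem / 2 > scales_cache_size);  // sanity, as in the kernel
  return pipe_size + reduce_size <
         0.95f * (max_shared_mem - scales_cache_size);
}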
vectorization.cuh
ADDED
@@ -0,0 +1,33 @@
+#pragma once
+/**
+ * __device__ datatypes vectorized by 4
+ */
+
+// Include both AMD and NVIDIA fp8 types to avoid circular import
+// TODO(luka/varun) use FP8_TYPE instead after refactoring
+#include <c10/util/Float8_e4m3fnuz.h>
+#include <c10/util/Float8_e4m3fn.h>
+
+namespace vllm {
+
+// Vectorization containers
+template <typename scalar_t>
+struct __align__(8) vec4_t {
+  scalar_t x;
+  scalar_t y;
+  scalar_t z;
+  scalar_t w;
+};
+
+template <typename quant_type_t>
+struct __align__(4) q8x4_t {
+  static_assert(std::is_same_v<quant_type_t, int8_t> ||
+                std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
+                std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
+  quant_type_t x;
+  quant_type_t y;
+  quant_type_t z;
+  quant_type_t w;
+};
+
+}  // namespace vllm
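vectorization.cuh only defines these two packed-by-4 containers; the alignment attributes are what make the reinterpret_cast-based 4-wide loads and stores in the kernels above well-formed. A small compile-time sanity sketch (assumes compilation with nvcc, which provides __align__ for both host and device passes):

#include <cstdint>
#include "vectorization.cuh"

// Four 16-bit scalars pack into 8 aligned bytes; four 8-bit quantized values
// pack into 4 aligned bytes.
static_assert(sizeof(vllm::vec4_t<uint16_t>) == 8 &&
              alignof(vllm::vec4_t<uint16_t>) == 8);
static_assert(sizeof(vllm::q8x4_t<int8_t>) == 4 &&
              alignof(vllm::q8x4_t<int8_t>) == 4);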