danieldk (HF staff) committed
Commit 0da5bf5 · 1 Parent(s): 0bc428e

Sync with vLLM


This fixes (among other things) a race condition in GPTQ-Marlin.

build.toml CHANGED
@@ -15,6 +15,7 @@ pyroot = "ext-torch"
[kernel.cutlass_w8a8]
capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
src = [
+ "core/math.hpp",
  "cutlass_w8a8/common.hpp",
  "cutlass_w8a8/scaled_mm_c2x.cu",
  "cutlass_w8a8/scaled_mm_c2x.cuh",
@@ -27,25 +28,32 @@ src = [
  "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp",
]
include = [ "." ]
- depends = [ "cutlass", "torch" ]
+ depends = [ "cutlass_3_6", "torch" ]

[kernel.cutlass_w8a8_hopper]
capabilities = [ "9.0", "9.0a" ]
src = [
+ "core/math.hpp",
  "cutlass_w8a8/common.hpp",
  "cutlass_w8a8/scaled_mm_c3x.cu",
+ "cutlass_w8a8/scaled_mm_c3x.cuh",
+ "cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh",
+ "cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh",
+ "cutlass_extensions/common.cpp",
+ "cutlass_extensions/common.hpp",
  "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp",
  "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp",
]
include = [ "." ]
- depends = [ "cutlass", "torch" ]
+ depends = [ "cutlass_3_6", "torch" ]

[kernel.fp8_common]
capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
src = [
  "fp8/common.cu",
  "fp8/common.cuh",
- "dispatch_utils.h"
+ "dispatch_utils.h",
+ "vectorization.cuh"
]
include = [ "." ]
depends = [ "torch" ]
compressed_tensors/int8_quant_kernels.cu CHANGED
@@ -226,7 +226,7 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale,
-                              c10::optional<torch::Tensor> const& azp) {
+                              std::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);
@@ -257,7 +257,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
-     torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
+     torch::Tensor& scales, std::optional<torch::Tensor> const& azp) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scales.is_contiguous());
core/math.hpp ADDED
@@ -0,0 +1,7 @@
+ #include <climits>
+ #include <iostream>
+
+ inline uint32_t next_pow_2(uint32_t const num) {
+   if (num <= 1) return num;
+   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+ }
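A small sanity check (illustration only, not part of the commit) of what next_pow_2 computes; the SM90 dispatch headers added below use it to bucket the M dimension into tile configurations:

// Illustration: exercises next_pow_2 from the new core/math.hpp.
#include <cassert>
#include <cstdint>
#include "core/math.hpp"

int main() {
  assert(next_pow_2(1) == 1);      // values <= 1 pass through unchanged
  assert(next_pow_2(33) == 64);    // rounded up to the next power of two
  assert(next_pow_2(128) == 128);  // exact powers of two are unchanged
  assert(next_pow_2(129) == 256);
  return 0;
}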
cutlass_extensions/common.cpp ADDED
@@ -0,0 +1,11 @@
+ #include "cutlass_extensions/common.hpp"
+
+ int32_t get_sm_version_num() {
+   int32_t major_capability, minor_capability;
+   cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
+                          0);
+   cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
+                          0);
+   int32_t version_num = major_capability * 10 + minor_capability;
+   return version_num;
+ }
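A minimal sketch (an assumption for illustration, not part of the commit) of how a caller can branch on get_sm_version_num(); the real dispatch lives in cutlass_w8a8/scaled_mm_entry.cu further down:

// Illustration: label the scaled_mm path that matches the current device.
#include <cstdint>
#include <cstdio>
#include "cutlass_extensions/common.hpp"

void report_scaled_mm_path() {
  int32_t version_num = get_sm_version_num();  // e.g. 90 on H100, 89 on L40S
  if (version_num >= 90) {
    std::printf("sm%d: CUTLASS 3.x (Hopper) kernels\n", static_cast<int>(version_num));
  } else if (version_num >= 75) {
    std::printf("sm%d: CUTLASS 2.x kernels\n", static_cast<int>(version_num));
  } else {
    std::printf("sm%d: not supported by cutlass_scaled_mm\n", static_cast<int>(version_num));
  }
}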
cutlass_extensions/common.hpp ADDED
@@ -0,0 +1,35 @@
+ #pragma once
+
+ #include "cutlass/cutlass.h"
+ #include <climits>
+ #include "cuda_runtime.h"
+ #include <iostream>
+
+ /**
+  * Helper function for checking CUTLASS errors
+  */
+ #define CUTLASS_CHECK(status)                       \
+   {                                                 \
+     cutlass::Status error = status;                 \
+     TORCH_CHECK(error == cutlass::Status::kSuccess, \
+                 cutlassGetStatusString(error));     \
+   }
+
+ /**
+  * Panic wrapper for unwinding CUDA runtime errors
+  */
+ #define CUDA_CHECK(status)                                        \
+   {                                                               \
+     cudaError_t error = status;                                   \
+     TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
+   }
+
+ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
+   int max_shared_mem_per_block_opt_in = 0;
+   cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
+                          cudaDevAttrMaxSharedMemoryPerBlockOptin,
+                          device);
+   return max_shared_mem_per_block_opt_in;
+ }
+
+ int32_t get_sm_version_num();
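A brief usage sketch (assumption, not part of the commit) for the CUDA_CHECK macro above; CUTLASS_CHECK is exercised for real by cutlass_gemm_caller in cutlass_w8a8/scaled_mm_c3x.cuh below. Torch headers are assumed to be on the include path, since both macros expand to TORCH_CHECK:

// Illustration: CUDA_CHECK turns a cudaError_t into a TORCH_CHECK failure
// carrying the runtime's error string.
#include <cstddef>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"

void device_to_device_copy(void* dst, void const* src, std::size_t bytes) {
  CUDA_CHECK(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice));
}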
cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp CHANGED
@@ -1,3 +1,5 @@
+ #pragma once
+
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"

/*
@@ -66,7 +68,7 @@ struct ScaledEpilogueBase {
  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
-   static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
+   static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
@@ -221,7 +223,7 @@ struct ScaledEpilogueBiasAzp
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
-                                    c10::optional<torch::Tensor> const& bias) {
+                                    std::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -299,7 +301,7 @@ struct ScaledEpilogueBiasAzpToken
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
-                                    c10::optional<torch::Tensor> const& bias) {
+                                    std::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp CHANGED
@@ -1,3 +1,5 @@
+ #pragma once
+
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"

/*
@@ -36,13 +38,13 @@ struct ScaledEpilogueBase {
  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
-       0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+       0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
      Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
-       0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+       0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // This utility function constructs the arguments for the load descriptors
@@ -65,7 +67,7 @@ struct ScaledEpilogueBase {
  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
-   static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
+   static auto args_from_tensor(std::optional<torch::Tensor> const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
@@ -221,7 +223,7 @@ struct ScaledEpilogueBiasAzp
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
-                                    c10::optional<torch::Tensor> const& bias) {
+                                    std::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
@@ -297,7 +299,7 @@ struct ScaledEpilogueBiasAzpToken
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
-                                    c10::optional<torch::Tensor> const& bias) {
+                                    std::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
cutlass_w8a8/scaled_mm_c2x.cu CHANGED
@@ -39,7 +39,7 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
-                             c10::optional<torch::Tensor> const& bias) {
+                             std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
@@ -58,8 +58,8 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
-                                 c10::optional<torch::Tensor> const& azp,
-                                 c10::optional<torch::Tensor> const& bias) {
+                                 std::optional<torch::Tensor> const& azp,
+                                 std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

@@ -94,7 +94,7 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
-                             c10::optional<torch::Tensor> const& bias) {
+                             std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
@@ -113,8 +113,8 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
-                                 c10::optional<torch::Tensor> const& azp,
-                                 c10::optional<torch::Tensor> const& bias) {
+                                 std::optional<torch::Tensor> const& azp,
+                                 std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

@@ -165,7 +165,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
-                             c10::optional<torch::Tensor> const& bias) {
+                             std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
@@ -184,8 +184,8 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
-                                 c10::optional<torch::Tensor> const& azp,
-                                 c10::optional<torch::Tensor> const& bias) {
+                                 std::optional<torch::Tensor> const& azp,
+                                 std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

cutlass_w8a8/scaled_mm_c2x.cuh CHANGED
@@ -21,15 +21,16 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

- #include "common.hpp"
+ #include "core/math.hpp"
+ #include "cutlass_extensions/common.hpp"
// clang-format on

using namespace cute;

/*
-    Epilogue functions can be defined to post-process the output before it is
-    written to GPU memory.
-    Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+    Epilogues defined in,
+    csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+    must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/
cutlass_w8a8/scaled_mm_c3x.cu CHANGED
@@ -1,384 +1,18 @@
- // clang-format will break include orders
- // clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

- #include <torch/all.h>
+ #include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
+ #include "scaled_mm_c3x_sm90_int8_dispatch.cuh"

- #include <ATen/cuda/CUDAContext.h>
-
- #include <iostream>
- #include <sstream>
- #include <vector>
-
- #include "cutlass/cutlass.h"
-
- #include "cute/tensor.hpp"
- #include "cute/atom/mma_atom.hpp"
- #include "cutlass/numeric_types.h"
-
- #include "cutlass/gemm/device/gemm_universal_adapter.h"
- #include "cutlass/gemm/kernel/gemm_universal.hpp"
- #include "cutlass/epilogue/collective/collective_builder.hpp"
- #include "cutlass/gemm/collective/collective_builder.hpp"
-
- #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
- #include "common.hpp"
- // clang-format on
-
- using namespace cute;
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using namespace vllm;

/*
  This file defines quantized GEMM operations using the CUTLASS 3.x API, for
  NVIDIA GPUs with sm90a (Hopper) or later.
-
-   Epilogue functions can be defined to post-process the output before it is
-   written to GPU memory.
-   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
-   as well as a static prepare_args function that constructs an
-   EVTCompute::Arguments struct.
*/

- namespace {
-
- // A wrapper for the GEMM kernel that is used to guard against compilation on
- // architectures that will never use the kernel. The purpose of this is to
- // reduce the size of the compiled binary.
- // __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
- // into code that will be executed on the device where it is defined.
- template <typename Kernel>
- struct enable_sm90_or_later : Kernel {
-   template <typename... Args>
-   CUTLASS_DEVICE void operator()(Args&&... args) {
- #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
-     Kernel::operator()(std::forward<Args>(args)...);
- #endif
-   }
- };
- template <typename ElementAB_, typename ElementD_,
-           template <typename, typename, typename> typename Epilogue_,
-           typename TileShape, typename ClusterShape, typename KernelSchedule,
-           typename EpilogueSchedule>
- struct cutlass_3x_gemm {
-   using ElementAB = ElementAB_;
-   using ElementD = ElementD_;
-   using ElementAcc =
-       typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
-                                 float>::type;
-
-   using EpilogueDescriptor =
-       cutlass::epilogue::collective::detail::EpilogueDescriptor<
-           TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
-           ElementD, EpilogueSchedule>;
-
-   using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
-
-   using StrideD = Stride<int64_t, Int<1>, Int<0>>;
-   using ElementC = void;
-   using StrideC = StrideD;
-
-   using EVTCompute = typename Epilogue::EVTCompute;
-
-   using CollectiveEpilogue =
-       typename cutlass::epilogue::collective::CollectiveBuilder<
-           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
-           ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
-           ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
-           EpilogueSchedule, EVTCompute>::CollectiveOp;
-
-   static constexpr size_t CEStorageSize =
-       sizeof(typename CollectiveEpilogue::SharedStorage);
-   using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
-       static_cast<int>(CEStorageSize)>;
-
-   // clang-format off
-   using CollectiveMainloop =
-       typename cutlass::gemm::collective::CollectiveBuilder<
-           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
-           ElementAB, cutlass::layout::RowMajor, 16,
-           ElementAB, cutlass::layout::ColumnMajor, 16,
-           ElementAcc, TileShape, ClusterShape,
-           Stages,
-           KernelSchedule>::CollectiveOp;
-   // clang-format on
-
-   using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
-       cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
-       cutlass::gemm::PersistentScheduler>>;
-
-   struct GemmKernel : public KernelType {};
- };
-
- template <typename Gemm, typename... EpilogueArgs>
- void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
-                          torch::Tensor const& b,
-                          EpilogueArgs&&... epilogue_params) {
-   using ElementAB = typename Gemm::ElementAB;
-   using ElementD = typename Gemm::ElementD;
-
-   int32_t m = a.size(0);
-   int32_t n = b.size(1);
-   int32_t k = a.size(1);
-
-   int64_t lda = a.stride(0);
-   int64_t ldb = b.stride(1);
-   int64_t ldc = out.stride(0);
-
-   using StrideA = Stride<int64_t, Int<1>, int64_t>;
-   using StrideB = Stride<int64_t, Int<1>, int64_t>;
-   using StrideC = typename Gemm::StrideC;
-
-   StrideA a_stride{lda, Int<1>{}, 0};
-   StrideB b_stride{ldb, Int<1>{}, 0};
-   StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
-
-   using GemmKernel = typename Gemm::GemmKernel;
-   typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
-
-   auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
-   auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
-   typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
-                                                        b_stride};
-
-   auto c_ptr = static_cast<ElementD*>(out.data_ptr());
-   typename GemmKernel::EpilogueArguments epilogue_args{
-       Gemm::Epilogue::prepare_args(
-           std::forward<EpilogueArgs>(epilogue_params)...),
-       c_ptr, c_stride, c_ptr, c_stride};
-
-   typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
-                                       prob_shape, mainloop_args, epilogue_args};
-
-   // Launch the CUTLASS GEMM kernel.
-   using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
-   GemmOp gemm_op;
-   CUTLASS_CHECK(gemm_op.can_implement(args));
-
-   size_t workspace_size = gemm_op.get_workspace_size(args);
-   auto const workspace_options =
-       torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
-   auto workspace = torch::empty(workspace_size, workspace_options);
-
-   auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
-
-   cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
-   CUTLASS_CHECK(status);
- }
-
- template <typename InType, typename OutType,
-           template <typename, typename, typename> typename Epilogue>
- struct sm90_fp8_config_default {
-   // M in (128, inf)
-   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
-   using KernelSchedule =
-       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
-   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-   using TileShape = Shape<_128, _128, _128>;
-   using ClusterShape = Shape<_2, _1, _1>;
-   using Cutlass3xGemm =
-       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                       KernelSchedule, EpilogueSchedule>;
- };
-
- template <typename InType, typename OutType,
-           template <typename, typename, typename> typename Epilogue>
- struct sm90_fp8_config_M128 {
-   // M in (64, 128]
-   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
-   using KernelSchedule =
-       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
-   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-   using TileShape = Shape<_64, _128, _128>;
-   using ClusterShape = Shape<_2, _1, _1>;
-   using Cutlass3xGemm =
-       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                       KernelSchedule, EpilogueSchedule>;
- };
-
- template <typename InType, typename OutType,
-           template <typename, typename, typename> typename Epilogue>
- struct sm90_fp8_config_M64 {
-   // M in [1, 64]
-   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
-   using KernelSchedule =
-       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
-   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-   using TileShape = Shape<_64, _64, _128>;
-   using ClusterShape = Shape<_1, _8, _1>;
-
-   using Cutlass3xGemm =
-       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                       KernelSchedule, EpilogueSchedule>;
- };
-
- template <typename InType, typename OutType,
-           template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_default {
-   // For M > 128 and any N
-   static_assert(std::is_same<InType, int8_t>());
-   using KernelSchedule =
-       typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-   using TileShape = Shape<_128, _128, _128>;
-   using ClusterShape = Shape<_2, _1, _1>;
-   using Cutlass3xGemm =
-       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                       KernelSchedule, EpilogueSchedule>;
- };
-
- template <typename InType, typename OutType,
-           template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M128 {
-   // For M in (64, 128] and any N
-   static_assert(std::is_same<InType, int8_t>());
-   using KernelSchedule =
-       typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-   using TileShape = Shape<_64, _128, _128>;
-   using ClusterShape = Shape<_2, _1, _1>;
-   using Cutlass3xGemm =
-       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                       KernelSchedule, EpilogueSchedule>;
- };
-
- template <typename InType, typename OutType,
-           template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M64 {
-   // For M in (32, 64] and any N
-   static_assert(std::is_same<InType, int8_t>());
-   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
-   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-   using TileShape = Shape<_64, _64, _256>;
-   using ClusterShape = Shape<_1, _1, _1>;
-   using Cutlass3xGemm =
-       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                       KernelSchedule, EpilogueSchedule>;
- };
-
- template <typename InType, typename OutType,
-           template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M32_NBig {
-   // For M in [1, 32] and N >= 8192
-   static_assert(std::is_same<InType, int8_t>());
-   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
-   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-   using TileShape = Shape<_64, _128, _256>;
-   using ClusterShape = Shape<_1, _4, _1>;
-   using Cutlass3xGemm =
-       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                       KernelSchedule, EpilogueSchedule>;
- };
-
- template <typename InType, typename OutType,
-           template <typename, typename, typename> typename Epilogue>
- struct sm90_int8_config_M32_NSmall {
-   // For M in [1, 32] and N < 8192
-   static_assert(std::is_same<InType, int8_t>());
-   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
-   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
-   using TileShape = Shape<_64, _64, _256>;
-   using ClusterShape = Shape<_1, _8, _1>;
-   using Cutlass3xGemm =
-       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
-                       KernelSchedule, EpilogueSchedule>;
- };
-
- }  // namespace
-
- template <typename InType, typename OutType,
-           template <typename, typename, typename> typename Epilogue,
-           typename... EpilogueArgs>
- void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                     torch::Tensor const& b,
-                                     EpilogueArgs&&... args) {
-   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
-   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
-   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
-
-   using Cutlass3xGemmDefault =
-       typename sm90_fp8_config_default<InType, OutType,
-                                        Epilogue>::Cutlass3xGemm;
-   using Cutlass3xGemmM64 =
-       typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
-   using Cutlass3xGemmM128 =
-       typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
-
-   uint32_t const m = a.size(0);
-   uint32_t const mp2 =
-       std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
-
-   if (mp2 <= 64) {
-     // m in [1, 64]
-     return cutlass_gemm_caller<Cutlass3xGemmM64>(
-         out, a, b, std::forward<EpilogueArgs>(args)...);
-   } else if (mp2 <= 128) {
-     // m in (64, 128]
-     return cutlass_gemm_caller<Cutlass3xGemmM128>(
-         out, a, b, std::forward<EpilogueArgs>(args)...);
-   } else {
-     // m in (128, inf)
-     return cutlass_gemm_caller<Cutlass3xGemmDefault>(
-         out, a, b, std::forward<EpilogueArgs>(args)...);
-   }
- }
-
- template <typename InType, typename OutType,
-           template <typename, typename, typename> typename Epilogue,
-           typename... EpilogueArgs>
- void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
-                                      torch::Tensor const& b,
-                                      EpilogueArgs&&... args) {
-   static_assert(std::is_same<InType, int8_t>());
-   TORCH_CHECK(a.dtype() == torch::kInt8);
-   TORCH_CHECK(b.dtype() == torch::kInt8);
-
-   using Cutlass3xGemmDefault =
-       typename sm90_int8_config_default<InType, OutType,
-                                         Epilogue>::Cutlass3xGemm;
-   using Cutlass3xGemmM128 =
-       typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
-   using Cutlass3xGemmM64 =
-       typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
-   using Cutlass3xGemmM32NBig =
-       typename sm90_int8_config_M32_NBig<InType, OutType,
-                                          Epilogue>::Cutlass3xGemm;
-   using Cutlass3xGemmM32NSmall =
-       typename sm90_int8_config_M32_NSmall<InType, OutType,
-                                            Epilogue>::Cutlass3xGemm;
-
-   uint32_t const n = out.size(1);
-   bool const is_small_n = n < 8192;
-
-   uint32_t const m = a.size(0);
-   uint32_t const mp2 =
-       std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
-
-   if (mp2 <= 32) {
-     // m in [1, 32]
-     if (is_small_n) {
-       return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
-           out, a, b, std::forward<EpilogueArgs>(args)...);
-     } else {
-       return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
-           out, a, b, std::forward<EpilogueArgs>(args)...);
-     }
-   } else if (mp2 <= 64) {
-     // m in (32, 64]
-     return cutlass_gemm_caller<Cutlass3xGemmM64>(
-         out, a, b, std::forward<EpilogueArgs>(args)...);
-   } else if (mp2 <= 128) {
-     // m in (64, 128]
-     return cutlass_gemm_caller<Cutlass3xGemmM128>(
-         out, a, b, std::forward<EpilogueArgs>(args)...);
-   } else {
-     // m in (128, inf)
-     return cutlass_gemm_caller<Cutlass3xGemmDefault>(
-         out, a, b, std::forward<EpilogueArgs>(args)...);
-   }
- }
-
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
@@ -417,7 +51,7 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
-                             c10::optional<torch::Tensor> const& bias) {
+                             std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
@@ -436,8 +70,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
-                                 c10::optional<torch::Tensor> const& azp,
-                                 c10::optional<torch::Tensor> const& bias) {
+                                 std::optional<torch::Tensor> const& azp,
+                                 std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

cutlass_w8a8/scaled_mm_c3x.cuh ADDED
@@ -0,0 +1,160 @@
+ #pragma once
+
+ // clang-format will break include orders
+ // clang-format off
+ #include <torch/all.h>
+
+ #include <ATen/cuda/CUDAContext.h>
+
+ #include "cutlass/cutlass.h"
+
+ #include "cute/tensor.hpp"
+ #include "cute/atom/mma_atom.hpp"
+ #include "cutlass/numeric_types.h"
+
+ #include "cutlass/gemm/device/gemm_universal_adapter.h"
+ #include "cutlass/gemm/kernel/gemm_universal.hpp"
+ #include "cutlass/epilogue/collective/collective_builder.hpp"
+ #include "cutlass/gemm/collective/collective_builder.hpp"
+
+ #include "core/math.hpp"
+ #include "cutlass_extensions/common.hpp"
+ // clang-format on
+
+ /*
+    Epilogues defined in,
+    csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
+    must contain a public type named EVTCompute of type Sm90EVT, as well as a
+    static prepare_args function that constructs an EVTCompute::Arguments struct.
+ */
+
+ using namespace cute;
+
+ namespace vllm {
+
+ // A wrapper for the GEMM kernel that is used to guard against compilation on
+ // architectures that will never use the kernel. The purpose of this is to
+ // reduce the size of the compiled binary.
+ // __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+ // into code that will be executed on the device where it is defined.
+ template <typename Kernel>
+ struct enable_sm90_or_later : Kernel {
+   template <typename... Args>
+   CUTLASS_DEVICE void operator()(Args&&... args) {
+ #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
+     Kernel::operator()(std::forward<Args>(args)...);
+ #endif
+   }
+ };
+
+ template <typename ElementAB_, typename ElementD_,
+           template <typename, typename, typename> typename Epilogue_,
+           typename TileShape, typename ClusterShape, typename KernelSchedule,
+           typename EpilogueSchedule>
+ struct cutlass_3x_gemm {
+   using ElementAB = ElementAB_;
+   using ElementD = ElementD_;
+   using ElementAcc =
+       typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
+                                 float>::type;
+
+   using EpilogueDescriptor =
+       cutlass::epilogue::collective::detail::EpilogueDescriptor<
+           TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
+           ElementD, EpilogueSchedule>;
+
+   using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
+
+   using StrideD = Stride<int64_t, Int<1>, Int<0>>;
+   using ElementC = void;
+   using StrideC = StrideD;
+
+   using EVTCompute = typename Epilogue::EVTCompute;
+
+   using CollectiveEpilogue =
+       typename cutlass::epilogue::collective::CollectiveBuilder<
+           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
+           ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+           ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
+           EpilogueSchedule, EVTCompute>::CollectiveOp;
+
+   static constexpr size_t CEStorageSize =
+       sizeof(typename CollectiveEpilogue::SharedStorage);
+   using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
+       static_cast<int>(CEStorageSize)>;
+
+   // clang-format off
+   using CollectiveMainloop =
+       typename cutlass::gemm::collective::CollectiveBuilder<
+           cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+           ElementAB, cutlass::layout::RowMajor, 16,
+           ElementAB, cutlass::layout::ColumnMajor, 16,
+           ElementAcc, TileShape, ClusterShape,
+           Stages,
+           KernelSchedule>::CollectiveOp;
+   // clang-format on
+
+   using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
+       cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
+       cutlass::gemm::PersistentScheduler>>;
+
+   struct GemmKernel : public KernelType {};
+ };
+
+ template <typename Gemm, typename... EpilogueArgs>
+ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
+                          torch::Tensor const& b,
+                          EpilogueArgs&&... epilogue_params) {
+   using ElementAB = typename Gemm::ElementAB;
+   using ElementD = typename Gemm::ElementD;
+
+   int32_t m = a.size(0);
+   int32_t n = b.size(1);
+   int32_t k = a.size(1);
+
+   int64_t lda = a.stride(0);
+   int64_t ldb = b.stride(1);
+   int64_t ldc = out.stride(0);
+
+   using StrideA = Stride<int64_t, Int<1>, int64_t>;
+   using StrideB = Stride<int64_t, Int<1>, int64_t>;
+   using StrideC = typename Gemm::StrideC;
+
+   StrideA a_stride{lda, Int<1>{}, 0};
+   StrideB b_stride{ldb, Int<1>{}, 0};
+   StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
+
+   using GemmKernel = typename Gemm::GemmKernel;
+   typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
+
+   auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
+   auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
+   typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
+                                                        b_stride};
+
+   auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+   typename GemmKernel::EpilogueArguments epilogue_args{
+       Gemm::Epilogue::prepare_args(
+           std::forward<EpilogueArgs>(epilogue_params)...),
+       c_ptr, c_stride, c_ptr, c_stride};
+
+   typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
+                                       prob_shape, mainloop_args, epilogue_args};
+
+   // Launch the CUTLASS GEMM kernel.
+   using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+   GemmOp gemm_op;
+   CUTLASS_CHECK(gemm_op.can_implement(args));
+
+   size_t workspace_size = gemm_op.get_workspace_size(args);
+   auto const workspace_options =
+       torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+   auto workspace = torch::empty(workspace_size, workspace_options);
+
+   auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
+   cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
+   CUTLASS_CHECK(status);
+ }
+
+ }  // namespace vllm
cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh ADDED
@@ -0,0 +1,96 @@
+ #pragma once
+
+ #include "scaled_mm_c3x.cuh"
+
+ /**
+  * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
+  * shape.
+  */
+
+ namespace vllm {
+
+ template <typename InType, typename OutType,
+           template <typename, typename, typename> typename Epilogue>
+ struct sm90_fp8_config_default {
+   // M in (128, inf)
+   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+   using KernelSchedule =
+       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+   using TileShape = Shape<_128, _128, _128>;
+   using ClusterShape = Shape<_2, _1, _1>;
+   using Cutlass3xGemm =
+       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                       KernelSchedule, EpilogueSchedule>;
+ };
+
+ template <typename InType, typename OutType,
+           template <typename, typename, typename> typename Epilogue>
+ struct sm90_fp8_config_M128 {
+   // M in (64, 128]
+   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+   using KernelSchedule =
+       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+   using TileShape = Shape<_64, _128, _128>;
+   using ClusterShape = Shape<_2, _1, _1>;
+   using Cutlass3xGemm =
+       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                       KernelSchedule, EpilogueSchedule>;
+ };
+
+ template <typename InType, typename OutType,
+           template <typename, typename, typename> typename Epilogue>
+ struct sm90_fp8_config_M64 {
+   // M in [1, 64]
+   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+   using KernelSchedule =
+       cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+   using TileShape = Shape<_64, _64, _128>;
+   using ClusterShape = Shape<_1, _8, _1>;
+
+   using Cutlass3xGemm =
+       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                       KernelSchedule, EpilogueSchedule>;
+ };
+
+ template <typename InType, typename OutType,
+           template <typename, typename, typename> typename Epilogue,
+           typename... EpilogueArgs>
+ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
+                                            torch::Tensor const& a,
+                                            torch::Tensor const& b,
+                                            EpilogueArgs&&... args) {
+   static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+   TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
+   TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+
+   using Cutlass3xGemmDefault =
+       typename sm90_fp8_config_default<InType, OutType,
+                                        Epilogue>::Cutlass3xGemm;
+   using Cutlass3xGemmM64 =
+       typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
+   using Cutlass3xGemmM128 =
+       typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
+
+   uint32_t const m = a.size(0);
+   uint32_t const mp2 =
+       std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2
+
+   if (mp2 <= 64) {
+     // m in [1, 64]
+     return cutlass_gemm_caller<Cutlass3xGemmM64>(
+         out, a, b, std::forward<EpilogueArgs>(args)...);
+   } else if (mp2 <= 128) {
+     // m in (64, 128]
+     return cutlass_gemm_caller<Cutlass3xGemmM128>(
+         out, a, b, std::forward<EpilogueArgs>(args)...);
+   } else {
+     // m in (128, inf)
+     return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+         out, a, b, std::forward<EpilogueArgs>(args)...);
+   }
+ }
+
+ }  // namespace vllm
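For reference, a host-only sketch (an assumption for illustration, not part of the commit) that mirrors the M-bucketing above without touching the GPU, showing which fp8 configuration a given M would select:

// Illustration: same thresholds as cutlass_gemm_sm90_fp8_dispatch.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include "core/math.hpp"

const char* sm90_fp8_config_for_m(uint32_t m) {
  uint32_t const mp2 = std::max(static_cast<uint32_t>(64), next_pow_2(m));
  if (mp2 <= 64) return "sm90_fp8_config_M64";    // m in [1, 64]
  if (mp2 <= 128) return "sm90_fp8_config_M128";  // m in (64, 128]
  return "sm90_fp8_config_default";               // m in (128, inf)
}

int main() {
  std::printf("%s\n", sm90_fp8_config_for_m(16));    // -> M64
  std::printf("%s\n", sm90_fp8_config_for_m(100));   // -> M128
  std::printf("%s\n", sm90_fp8_config_for_m(4096));  // -> default
  return 0;
}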
cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh ADDED
@@ -0,0 +1,140 @@
+ #pragma once
+
+ #include "scaled_mm_c3x.cuh"
+
+ /**
+  * This file defines Gemm kernel configurations for SM90 (int8) based on the
+  * Gemm shape.
+  */
+
+ namespace vllm {
+
+ template <typename InType, typename OutType,
+           template <typename, typename, typename> typename Epilogue>
+ struct sm90_int8_config_default {
+   // For M > 128 and any N
+   static_assert(std::is_same<InType, int8_t>());
+   using KernelSchedule =
+       typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+   using TileShape = Shape<_128, _128, _128>;
+   using ClusterShape = Shape<_2, _1, _1>;
+   using Cutlass3xGemm =
+       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                       KernelSchedule, EpilogueSchedule>;
+ };
+
+ template <typename InType, typename OutType,
+           template <typename, typename, typename> typename Epilogue>
+ struct sm90_int8_config_M128 {
+   // For M in (64, 128] and any N
+   static_assert(std::is_same<InType, int8_t>());
+   using KernelSchedule =
+       typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+   using TileShape = Shape<_64, _128, _128>;
+   using ClusterShape = Shape<_2, _1, _1>;
+   using Cutlass3xGemm =
+       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                       KernelSchedule, EpilogueSchedule>;
+ };
+
+ template <typename InType, typename OutType,
+           template <typename, typename, typename> typename Epilogue>
+ struct sm90_int8_config_M64 {
+   // For M in (32, 64] and any N
+   static_assert(std::is_same<InType, int8_t>());
+   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+   using TileShape = Shape<_64, _64, _256>;
+   using ClusterShape = Shape<_1, _1, _1>;
+   using Cutlass3xGemm =
+       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                       KernelSchedule, EpilogueSchedule>;
+ };
+
+ template <typename InType, typename OutType,
+           template <typename, typename, typename> typename Epilogue>
+ struct sm90_int8_config_M32_NBig {
+   // For M in [1, 32] and N >= 8192
+   static_assert(std::is_same<InType, int8_t>());
+   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+   using TileShape = Shape<_64, _128, _256>;
+   using ClusterShape = Shape<_1, _4, _1>;
+   using Cutlass3xGemm =
+       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                       KernelSchedule, EpilogueSchedule>;
+ };
+
+ template <typename InType, typename OutType,
+           template <typename, typename, typename> typename Epilogue>
+ struct sm90_int8_config_M32_NSmall {
+   // For M in [1, 32] and N < 8192
+   static_assert(std::is_same<InType, int8_t>());
+   using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+   using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+   using TileShape = Shape<_64, _64, _256>;
+   using ClusterShape = Shape<_1, _8, _1>;
+   using Cutlass3xGemm =
+       cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                       KernelSchedule, EpilogueSchedule>;
+ };
+
+ template <typename InType, typename OutType,
+           template <typename, typename, typename> typename Epilogue,
+           typename... EpilogueArgs>
+ inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
+                                             torch::Tensor const& a,
+                                             torch::Tensor const& b,
+                                             EpilogueArgs&&... args) {
+   static_assert(std::is_same<InType, int8_t>());
+   TORCH_CHECK(a.dtype() == torch::kInt8);
+   TORCH_CHECK(b.dtype() == torch::kInt8);
+
+   using Cutlass3xGemmDefault =
+       typename sm90_int8_config_default<InType, OutType,
+                                         Epilogue>::Cutlass3xGemm;
+   using Cutlass3xGemmM128 =
+       typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
+   using Cutlass3xGemmM64 =
+       typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
+   using Cutlass3xGemmM32NBig =
+       typename sm90_int8_config_M32_NBig<InType, OutType,
+                                          Epilogue>::Cutlass3xGemm;
+   using Cutlass3xGemmM32NSmall =
+       typename sm90_int8_config_M32_NSmall<InType, OutType,
+                                            Epilogue>::Cutlass3xGemm;
+
+   uint32_t const n = out.size(1);
+   bool const is_small_n = n < 8192;
+
+   uint32_t const m = a.size(0);
+   uint32_t const mp2 =
+       std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
+
+   if (mp2 <= 32) {
+     // m in [1, 32]
+     if (is_small_n) {
+       return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
+           out, a, b, std::forward<EpilogueArgs>(args)...);
+     } else {
+       return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
+           out, a, b, std::forward<EpilogueArgs>(args)...);
+     }
+   } else if (mp2 <= 64) {
+     // m in (32, 64]
+     return cutlass_gemm_caller<Cutlass3xGemmM64>(
+         out, a, b, std::forward<EpilogueArgs>(args)...);
+   } else if (mp2 <= 128) {
+     // m in (64, 128]
+     return cutlass_gemm_caller<Cutlass3xGemmM128>(
+         out, a, b, std::forward<EpilogueArgs>(args)...);
+   } else {
+     // m in (128, inf)
+     return cutlass_gemm_caller<Cutlass3xGemmDefault>(
+         out, a, b, std::forward<EpilogueArgs>(args)...);
+   }
+ }
+
+ }  // namespace vllm
cutlass_w8a8/scaled_mm_entry.cu CHANGED
@@ -3,30 +3,32 @@
3
  #include <c10/cuda/CUDAGuard.h>
4
  #include <torch/all.h>
5
 
 
 
6
  void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
7
  torch::Tensor const& b,
8
  torch::Tensor const& a_scales,
9
  torch::Tensor const& b_scales,
10
- c10::optional<torch::Tensor> const& bias);
11
 
12
  void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
13
  torch::Tensor const& b,
14
  torch::Tensor const& a_scales,
15
  torch::Tensor const& b_scales,
16
- c10::optional<torch::Tensor> const& bias);
17
 
18
  void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
19
  torch::Tensor const& b,
20
  torch::Tensor const& a_scales,
21
  torch::Tensor const& b_scales,
22
- c10::optional<torch::Tensor> const& bias);
23
 
24
  #if defined CUDA_VERSION && CUDA_VERSION >= 12000
25
  void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
26
  torch::Tensor const& b,
27
  torch::Tensor const& a_scales,
28
  torch::Tensor const& b_scales,
29
- c10::optional<torch::Tensor> const& bias);
30
  #endif
31
 
32
  void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
@@ -34,24 +36,24 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
34
  torch::Tensor const& a_scales,
35
  torch::Tensor const& b_scales,
36
  torch::Tensor const& azp_adj,
37
- c10::optional<torch::Tensor> const& azp,
38
- c10::optional<torch::Tensor> const& bias);
39
 
40
  void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
41
  torch::Tensor const& b,
42
  torch::Tensor const& a_scales,
43
  torch::Tensor const& b_scales,
44
  torch::Tensor const& azp_adj,
45
- c10::optional<torch::Tensor> const& azp,
46
- c10::optional<torch::Tensor> const& bias);
47
 
48
  void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
49
  torch::Tensor const& b,
50
  torch::Tensor const& a_scales,
51
  torch::Tensor const& b_scales,
52
  torch::Tensor const& azp_adj,
53
- c10::optional<torch::Tensor> const& azp,
54
- c10::optional<torch::Tensor> const& bias);
55
 
56
  #if defined CUDA_VERSION && CUDA_VERSION >= 12000
57
  void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
@@ -59,8 +61,8 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
59
  torch::Tensor const& a_scales,
60
  torch::Tensor const& b_scales,
61
  torch::Tensor const& azp_adj,
62
- c10::optional<torch::Tensor> const& azp,
63
- c10::optional<torch::Tensor> const& bias);
64
  #endif
65
 
66
  bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
@@ -79,20 +81,10 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
79
  return false;
80
  }
81
 
82
- int32_t get_sm_version_num() {
83
- int32_t major_capability, minor_capability;
84
- cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
85
- 0);
86
- cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
87
- 0);
88
- int32_t version_num = major_capability * 10 + minor_capability;
89
- return version_num;
90
- }
91
-
92
  void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
93
  torch::Tensor const& b, torch::Tensor const& a_scales,
94
  torch::Tensor const& b_scales,
95
- c10::optional<torch::Tensor> const& bias) {
96
  // Checks for conformality
97
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
98
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
@@ -154,8 +146,8 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
154
  torch::Tensor const& a_scales,
155
  torch::Tensor const& b_scales,
156
  torch::Tensor const& azp_adj,
157
- c10::optional<torch::Tensor> const& azp,
158
- c10::optional<torch::Tensor> const& bias) {
159
  // Checks for conformality
160
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
161
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
 
3
  #include <c10/cuda/CUDAGuard.h>
4
  #include <torch/all.h>
5
 
6
+ #include "cutlass_extensions/common.hpp"
7
+
8
  void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
9
  torch::Tensor const& b,
10
  torch::Tensor const& a_scales,
11
  torch::Tensor const& b_scales,
12
+ std::optional<torch::Tensor> const& bias);
13
 
14
  void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
15
  torch::Tensor const& b,
16
  torch::Tensor const& a_scales,
17
  torch::Tensor const& b_scales,
18
+ std::optional<torch::Tensor> const& bias);
19
 
20
  void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
21
  torch::Tensor const& b,
22
  torch::Tensor const& a_scales,
23
  torch::Tensor const& b_scales,
24
+ std::optional<torch::Tensor> const& bias);
25
 
26
  #if defined CUDA_VERSION && CUDA_VERSION >= 12000
27
  void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
28
  torch::Tensor const& b,
29
  torch::Tensor const& a_scales,
30
  torch::Tensor const& b_scales,
31
+ std::optional<torch::Tensor> const& bias);
32
  #endif
33
 
34
  void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
 
36
  torch::Tensor const& a_scales,
37
  torch::Tensor const& b_scales,
38
  torch::Tensor const& azp_adj,
39
+ std::optional<torch::Tensor> const& azp,
40
+ std::optional<torch::Tensor> const& bias);
41
 
42
  void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
43
  torch::Tensor const& b,
44
  torch::Tensor const& a_scales,
45
  torch::Tensor const& b_scales,
46
  torch::Tensor const& azp_adj,
47
+ std::optional<torch::Tensor> const& azp,
48
+ std::optional<torch::Tensor> const& bias);
49
 
50
  void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
51
  torch::Tensor const& b,
52
  torch::Tensor const& a_scales,
53
  torch::Tensor const& b_scales,
54
  torch::Tensor const& azp_adj,
55
+ std::optional<torch::Tensor> const& azp,
56
+ std::optional<torch::Tensor> const& bias);
57
 
58
  #if defined CUDA_VERSION && CUDA_VERSION >= 12000
59
  void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
 
61
  torch::Tensor const& a_scales,
62
  torch::Tensor const& b_scales,
63
  torch::Tensor const& azp_adj,
64
+ std::optional<torch::Tensor> const& azp,
65
+ std::optional<torch::Tensor> const& bias);
66
  #endif
67
 
68
  bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
 
81
  return false;
82
  }
83
 
84
  void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
85
  torch::Tensor const& b, torch::Tensor const& a_scales,
86
  torch::Tensor const& b_scales,
87
+ std::optional<torch::Tensor> const& bias) {
88
  // Checks for conformality
89
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
90
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
 
146
  torch::Tensor const& a_scales,
147
  torch::Tensor const& b_scales,
148
  torch::Tensor const& azp_adj,
149
+ std::optional<torch::Tensor> const& azp,
150
+ std::optional<torch::Tensor> const& bias) {
151
  // Checks for conformality
152
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
153
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
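Note on the scaled_mm_entry.cu diff above: the optional bias/azp arguments move from c10::optional to std::optional (callers simply pass std::nullopt when no bias is supplied), and get_sm_version_num() is no longer defined inline here but comes in through the new cutlass_extensions/common.hpp include. A minimal sketch of that relocated helper, assuming it keeps the behavior of the definition removed above:

    // Sketch of the compute-capability helper now provided by
    // cutlass_extensions/common.hpp; the body mirrors the definition that
    // was removed from scaled_mm_entry.cu.
    #include <cstdint>
    #include <cuda_runtime.h>

    inline int32_t get_sm_version_num() {
      int32_t major_capability, minor_capability;
      cudaDeviceGetAttribute(&major_capability,
                             cudaDevAttrComputeCapabilityMajor, 0);
      cudaDeviceGetAttribute(&minor_capability,
                             cudaDevAttrComputeCapabilityMinor, 0);
      return major_capability * 10 + minor_capability;
    }

With the version number available, cutlass_scaled_mm can keep dispatching to the sm75/sm80/sm89/sm90 entry points declared above.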
fp8/common.cuh CHANGED
@@ -1,6 +1,9 @@
1
  #pragma once
2
 
 
 
3
  #include <cmath>
 
4
 
5
  #ifndef USE_ROCM
6
  #include <c10/util/Float8_e4m3fn.h>
@@ -15,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz;
15
  // issue when running dynamic quantization. Here use 224.0f for rocm.
16
  constexpr auto FP8_E4M3_MAX = 224.0f;
17
  #endif
 
18
 
19
  namespace vllm {
20
 
@@ -89,22 +93,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
89
  }
90
  }
91
 
92
- template <typename scalar_t>
93
- struct __align__(8) vec4_t {
94
- scalar_t x;
95
- scalar_t y;
96
- scalar_t z;
97
- scalar_t w;
98
- };
99
-
100
- typedef struct __align__(4) {
101
- FP8_TYPE x;
102
- FP8_TYPE y;
103
- FP8_TYPE z;
104
- FP8_TYPE w;
105
- }
106
- float8x4_t;
107
-
108
  template <typename scalar_t>
109
  __device__ float thread_max_vec(scalar_t const* __restrict__ input,
110
  int64_t const num_elems, int const tid,
@@ -139,10 +127,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
139
  float const scale,
140
  int64_t const num_elems,
141
  int const tid, int const step) {
 
142
  // Vectorized input/output to better utilize memory bandwidth.
143
- vec4_t<scalar_t> const* vectorized_in =
144
- reinterpret_cast<vec4_t<scalar_t> const*>(input);
145
- float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
146
 
147
  int64_t const num_vec_elems = num_elems >> 2;
148
 
@@ -169,4 +157,4 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
169
  }
170
  }
171
 
172
- } // namespace vllm
 
1
  #pragma once
2
 
3
+ #include "vectorization.cuh"
4
+
5
  #include <cmath>
6
+ #include <c10/core/ScalarType.h>
7
 
8
  #ifndef USE_ROCM
9
  #include <c10/util/Float8_e4m3fn.h>
 
18
  // issue when running dynamic quantization. Here use 224.0f for rocm.
19
  constexpr auto FP8_E4M3_MAX = 224.0f;
20
  #endif
21
+ constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
22
 
23
  namespace vllm {
24
 
 
93
  }
94
  }
95
 
96
  template <typename scalar_t>
97
  __device__ float thread_max_vec(scalar_t const* __restrict__ input,
98
  int64_t const num_elems, int const tid,
 
127
  float const scale,
128
  int64_t const num_elems,
129
  int const tid, int const step) {
130
+ using float8x4_t = q8x4_t<FP8_TYPE>;
131
  // Vectorized input/output to better utilize memory bandwidth.
132
+ auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
133
+ auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
 
134
 
135
  int64_t const num_vec_elems = num_elems >> 2;
136
 
 
157
  }
158
  }
159
 
160
+ } // namespace vllm
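Note on the fp8/common.cuh diff above: the per-element-group containers now come from vectorization.cuh, float8x4_t becomes a local alias for q8x4_t<FP8_TYPE>, and kFp8Type exposes the torch ScalarType that corresponds to FP8_TYPE (handy for checking a tensor's scalar_type() against the FP8 dtype). Below is a hedged sketch of the 4-wide load/convert/store pattern these containers enable; the function is illustrative, not the kernel body, and omits the saturation to FP8_E4M3_MAX that the real conversion performs:

    // Illustrative only: convert 4 elements per iteration using the
    // vec4_t / q8x4_t containers from vectorization.cuh.
    #include <c10/util/Float8_e4m3fn.h>
    #include "vectorization.cuh"

    template <typename scalar_t>
    __device__ void convert_4wide(c10::Float8_e4m3fn* __restrict__ out,
                                  scalar_t const* __restrict__ input,
                                  float const inverted_scale,
                                  int64_t const num_elems,
                                  int const tid, int const step) {
      auto const* vectorized_in =
          reinterpret_cast<vllm::vec4_t<scalar_t> const*>(input);
      auto* vectorized_out =
          reinterpret_cast<vllm::q8x4_t<c10::Float8_e4m3fn>*>(out);
      int64_t const num_vec_elems = num_elems >> 2;  // 4 elements per vector
      for (int64_t i = tid; i < num_vec_elems; i += step) {
        vllm::vec4_t<scalar_t> const in_vec = vectorized_in[i];
        vllm::q8x4_t<c10::Float8_e4m3fn> out_vec;
        out_vec.x = c10::Float8_e4m3fn(static_cast<float>(in_vec.x) * inverted_scale);
        out_vec.y = c10::Float8_e4m3fn(static_cast<float>(in_vec.y) * inverted_scale);
        out_vec.z = c10::Float8_e4m3fn(static_cast<float>(in_vec.z) * inverted_scale);
        out_vec.w = c10::Float8_e4m3fn(static_cast<float>(in_vec.w) * inverted_scale);
        vectorized_out[i] = out_vec;
      }
    }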
gptq_marlin/gptq_marlin.cu CHANGED
@@ -832,6 +832,7 @@ __global__ void Marlin(
832
  int4* sh_g_idx = sh_b + (stages * b_sh_stage);
833
  int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
834
  int4* sh_s = sh_zp + (stages * zp_sh_stage);
 
835
 
836
  // Register storage for double buffer of shared memory reads.
837
  FragA frag_a[2][thread_m_blocks];
@@ -930,11 +931,11 @@ __global__ void Marlin(
930
  int4* sh_s_stage = sh_s + s_sh_stage * pipe;
931
 
932
  if constexpr (group_blocks >= thread_k_blocks) {
 
 
 
933
  // Only fetch scales if this tile starts a new group
934
- if (pipe % (group_blocks / thread_k_blocks) == 0) {
935
- if (s_sh_wr_pred) {
936
- cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
937
- }
938
  s_gl_rd += s_gl_rd_delta;
939
  }
940
  } else {
@@ -1036,9 +1037,7 @@ __global__ void Marlin(
1036
  // No act-order case
1037
  if constexpr (group_blocks != -1) {
1038
  if constexpr (group_blocks >= thread_k_blocks) {
1039
- int4* sh_s_stage =
1040
- sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
1041
- (pipe / (group_blocks / thread_k_blocks)));
1042
  reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
1043
  } else {
1044
  int warp_id = threadIdx.x / 32;
@@ -1337,15 +1336,15 @@ __global__ void Marlin(
1337
  int red_sh_wr =
1338
  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
1339
  if (i < red_off) {
1340
- float* c_rd =
1341
- reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
1342
- float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
1343
  #pragma unroll
1344
  for (int k = 0; k < 4; k++)
1345
  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
1346
  c_rd[k] + c_wr[k];
1347
  }
1348
- sh[red_sh_wr] =
1349
  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
1350
  }
1351
  }
@@ -1355,7 +1354,7 @@ __global__ void Marlin(
1355
  #pragma unroll
1356
  for (int i = 0; i < 4 * 2; i++) {
1357
  float* c_rd =
1358
- reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
1359
  #pragma unroll
1360
  for (int j = 0; j < 4; j++)
1361
  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
@@ -1395,7 +1394,7 @@ __global__ void Marlin(
1395
  #pragma unroll
1396
  for (int i = 0; i < thread_m_blocks * 4; i++) {
1397
  cp_async4_pred(
1398
- &sh[c_sh_wr + c_sh_wr_delta * i],
1399
  &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
1400
  c_gl_wr_delta_i * (i % 2)],
1401
  i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
@@ -1408,7 +1407,7 @@ __global__ void Marlin(
1408
  for (int i = 0; i < thread_m_blocks * 4; i++) {
1409
  if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
1410
  if (!first) {
1411
- int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
1412
  #pragma unroll
1413
  for (int j = 0; j < 2 * 4; j++) {
1414
  reinterpret_cast<float*>(
@@ -1459,10 +1458,10 @@ __global__ void Marlin(
1459
  float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
1460
  #pragma unroll
1461
  for (int k = 0; k < th_size; k++) {
1462
- sh[threadIdx.x] =
1463
  C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
1464
 
1465
- float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
1466
  #pragma unroll
1467
  for (int f = 0; f < 4; f++) {
1468
  frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
@@ -1513,7 +1512,7 @@ __global__ void Marlin(
1513
  res = __hmul2(res, s[0]);
1514
  }
1515
 
1516
- ((scalar_t2*)sh)[idx] = res;
1517
  };
1518
 
1519
  if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1541,7 +1540,7 @@ __global__ void Marlin(
1541
  i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
1542
  i++) {
1543
  if (c_gl_wr < c_gl_wr_end) {
1544
- C[c_gl_wr] = sh[c_sh_rd];
1545
  c_gl_wr += c_gl_wr_delta;
1546
  c_sh_rd += c_sh_rd_delta;
1547
  }
@@ -1863,9 +1862,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
1863
 
1864
  float pipe_size = (a_size + b_size) * pipe_stages;
1865
 
 
 
 
1866
  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
1867
 
1868
- return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
1869
  }
1870
 
1871
  bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
 
832
  int4* sh_g_idx = sh_b + (stages * b_sh_stage);
833
  int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
834
  int4* sh_s = sh_zp + (stages * zp_sh_stage);
835
+ int4* sh_red = sh_s + (stages * s_sh_stage);
836
 
837
  // Register storage for double buffer of shared memory reads.
838
  FragA frag_a[2][thread_m_blocks];
 
931
  int4* sh_s_stage = sh_s + s_sh_stage * pipe;
932
 
933
  if constexpr (group_blocks >= thread_k_blocks) {
934
+ if (s_sh_wr_pred) {
935
+ cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
936
+ }
937
  // Only fetch scales if this tile starts a new group
938
+ if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
 
 
 
939
  s_gl_rd += s_gl_rd_delta;
940
  }
941
  } else {
 
1037
  // No act-order case
1038
  if constexpr (group_blocks != -1) {
1039
  if constexpr (group_blocks >= thread_k_blocks) {
1040
+ int4* sh_s_stage = sh_s + s_sh_stage * pipe;
 
 
1041
  reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
1042
  } else {
1043
  int warp_id = threadIdx.x / 32;
 
1336
  int red_sh_wr =
1337
  red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
1338
  if (i < red_off) {
1339
+ float* c_rd = reinterpret_cast<float*>(
1340
+ &sh_red[red_sh_delta * j + red_sh_rd]);
1341
+ float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
1342
  #pragma unroll
1343
  for (int k = 0; k < 4; k++)
1344
  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
1345
  c_rd[k] + c_wr[k];
1346
  }
1347
+ sh_red[red_sh_wr] =
1348
  reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
1349
  }
1350
  }
 
1354
  #pragma unroll
1355
  for (int i = 0; i < 4 * 2; i++) {
1356
  float* c_rd =
1357
+ reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
1358
  #pragma unroll
1359
  for (int j = 0; j < 4; j++)
1360
  reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
 
1394
  #pragma unroll
1395
  for (int i = 0; i < thread_m_blocks * 4; i++) {
1396
  cp_async4_pred(
1397
+ &sh_red[c_sh_wr + c_sh_wr_delta * i],
1398
  &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
1399
  c_gl_wr_delta_i * (i % 2)],
1400
  i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
 
1407
  for (int i = 0; i < thread_m_blocks * 4; i++) {
1408
  if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
1409
  if (!first) {
1410
+ int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
1411
  #pragma unroll
1412
  for (int j = 0; j < 2 * 4; j++) {
1413
  reinterpret_cast<float*>(
 
1458
  float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
1459
  #pragma unroll
1460
  for (int k = 0; k < th_size; k++) {
1461
+ sh_red[threadIdx.x] =
1462
  C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
1463
 
1464
+ float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
1465
  #pragma unroll
1466
  for (int f = 0; f < 4; f++) {
1467
  frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
 
1512
  res = __hmul2(res, s[0]);
1513
  }
1514
 
1515
+ ((scalar_t2*)sh_red)[idx] = res;
1516
  };
1517
 
1518
  if (threadIdx.x / 32 < thread_n_blocks / 4) {
 
1540
  i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
1541
  i++) {
1542
  if (c_gl_wr < c_gl_wr_end) {
1543
+ C[c_gl_wr] = sh_red[c_sh_rd];
1544
  c_gl_wr += c_gl_wr_delta;
1545
  c_sh_rd += c_sh_rd_delta;
1546
  }
 
1862
 
1863
  float pipe_size = (a_size + b_size) * pipe_stages;
1864
 
1865
+ float reduce_size = max(th_config.num_threads * 32 * 4,
1866
+ (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
1867
+
1868
  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
1869
 
1870
+ return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
1871
  }
1872
 
1873
  bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
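Note on the gptq_marlin.cu diff above: the output-reduction paths now use a dedicated shared-memory region, sh_red, carved out after the scale buffer, instead of reusing the base sh buffer, presumably so the reduction can no longer race with tile data the cp.async pipeline still has in flight; scales are copied into the per-stage slot on every pipeline iteration, with s_gl_rd advancing only once per scale group; and is_valid_cache_size charges the new reduction scratch against the shared-memory budget. A standalone restatement of that budget check, using the same arithmetic as the diff with illustrative parameter types (the free-function wrapper itself is hypothetical):

    // The reduction scratch (reduce_size) now counts against the 95% shared
    // memory budget alongside the A/B pipeline buffers and the scale cache.
    #include <algorithm>

    bool fits_shared_mem_budget(float a_size, float b_size, int pipe_stages,
                                int num_threads, int tb_n, int tb_max_m,
                                float scales_cache_size, int max_shared_mem) {
      float const pipe_size = (a_size + b_size) * pipe_stages;
      float const reduce_size =
          std::max(num_threads * 32 * 4,
                   (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
      return pipe_size + reduce_size <
             0.95f * (max_shared_mem - scales_cache_size);
    }

Configurations that previously only just fit the pipeline buffers are therefore rejected now that sh_red needs its own space.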
vectorization.cuh ADDED
@@ -0,0 +1,33 @@
1
+ #pragma once
2
+ /**
3
+ * __device__ datatypes vectorized by 4
4
+ */
5
+
6
+ // Include both AMD and NVIDIA fp8 types to avoid circular import
7
+ // TODO(luka/varun) use FP8_TYPE instead after refactoring
8
+ #include <c10/util/Float8_e4m3fnuz.h>
9
+ #include <c10/util/Float8_e4m3fn.h>
10
+
11
+ namespace vllm {
12
+
13
+ // Vectorization containers
14
+ template <typename scalar_t>
15
+ struct __align__(8) vec4_t {
16
+ scalar_t x;
17
+ scalar_t y;
18
+ scalar_t z;
19
+ scalar_t w;
20
+ };
21
+
22
+ template <typename quant_type_t>
23
+ struct __align__(4) q8x4_t {
24
+ static_assert(std::is_same_v<quant_type_t, int8_t> ||
25
+ std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
26
+ std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
27
+ quant_type_t x;
28
+ quant_type_t y;
29
+ quant_type_t z;
30
+ quant_type_t w;
31
+ };
32
+
33
+ } // namespace vllm
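The static_assert above restricts q8x4_t to the three quantized element types the kernels use, while vec4_t stays generic over the scalar type. A small sketch of valid and invalid instantiations (the wrapper function exists only to host the declarations; the commented-out line would fail to compile):

    #include <cstdint>
    #include "vectorization.cuh"

    void q8x4_instantiation_examples() {
      vllm::vec4_t<float> four_floats{};                   // any scalar type
      vllm::q8x4_t<int8_t> four_int8{};                    // OK: int8 path
      vllm::q8x4_t<c10::Float8_e4m3fn> four_fp8{};         // OK: NVIDIA fp8
      vllm::q8x4_t<c10::Float8_e4m3fnuz> four_fp8_fnuz{};  // OK: AMD fp8
      // vllm::q8x4_t<uint8_t> rejected{};                 // static_assert fires
      (void)four_floats;
      (void)four_int8;
      (void)four_fp8;
      (void)four_fp8_fnuz;
    }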