danieldk HF staff commited on
Commit
6eaa88c
·
1 Parent(s): f235406

Sync with upstream

Browse files
build.toml CHANGED
@@ -7,24 +7,39 @@ src = [
7
  "core/scalar_type.hpp",
8
  "ext-torch/registration.h",
9
  "ext-torch/torch_binding.cpp",
10
- "ext-torch/torch_binding.h"
11
  ]
12
- include = [ "." ]
13
  pyroot = "ext-torch"
14
- pyext = [ "py", "json" ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  [kernel.moe]
17
- capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
18
  src = [
19
  "cuda_compat.h",
20
  "dispatch_utils.h",
21
  "moe/moe_align_sum_kernels.cu",
22
  "moe/topk_softmax_kernels.cu",
23
  ]
24
- depends = [ "torch" ]
25
 
26
  [kernel.moe-marlin]
27
- capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0" ]
28
  src = [
29
  "core/exception.hpp",
30
  "core/scalar_type.hpp",
@@ -37,14 +52,14 @@ src = [
37
  "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu",
38
  "marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.h",
39
  ]
40
- include = [ "." ]
41
- depends = [ "torch" ]
42
 
43
  [kernel.activation]
44
- capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
45
  src = [
46
  "activation/activation_kernels.cu",
47
  "activation/cuda_compat.h",
48
  "activation/dispatch_utils.h",
49
  ]
50
- depends = [ "torch" ]
 
7
  "core/scalar_type.hpp",
8
  "ext-torch/registration.h",
9
  "ext-torch/torch_binding.cpp",
10
+ "ext-torch/torch_binding.h",
11
  ]
12
+ include = ["."]
13
  pyroot = "ext-torch"
14
+ pyext = ["py", "json"]
15
+
16
+ [kernel.fp8]
17
+ capabilities = ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0"]
18
+ src = [
19
+ "cuda_compat.h",
20
+ "dispatch_utils.h",
21
+ "fp8/amd/hip_float8.h",
22
+ "fp8/amd/hip_float8_impl.h",
23
+ "fp8/common.cu",
24
+ "fp8/common.cuh",
25
+ "fp8/vectorization.cuh",
26
+ ]
27
+ include = ["."]
28
+ depends = ["torch"]
29
+
30
 
31
  [kernel.moe]
32
+ capabilities = ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0"]
33
  src = [
34
  "cuda_compat.h",
35
  "dispatch_utils.h",
36
  "moe/moe_align_sum_kernels.cu",
37
  "moe/topk_softmax_kernels.cu",
38
  ]
39
+ depends = ["torch"]
40
 
41
  [kernel.moe-marlin]
42
+ capabilities = ["8.0", "8.6", "8.7", "8.9", "9.0"]
43
  src = [
44
  "core/exception.hpp",
45
  "core/scalar_type.hpp",
 
52
  "marlin-moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu",
53
  "marlin-moe/marlin_kernels/marlin_moe_kernel_ku8b128.h",
54
  ]
55
+ include = ["."]
56
+ depends = ["torch"]
57
 
58
  [kernel.activation]
59
+ capabilities = ["7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0"]
60
  src = [
61
  "activation/activation_kernels.cu",
62
  "activation/cuda_compat.h",
63
  "activation/dispatch_utils.h",
64
  ]
65
+ depends = ["torch"]
ext-torch/moe/fp8.py CHANGED
@@ -1,6 +1,11 @@
 
 
1
  import torch
 
 
2
 
3
- from typing import Tuple, Optional, Union
 
4
 
5
 
6
  def is_hip() -> bool:
@@ -49,15 +54,179 @@ def scaled_fp8_quant(
49
  if scale is None:
50
  if use_per_token_if_dynamic:
51
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
52
- torch.ops._C.dynamic_per_token_scaled_fp8_quant(
53
- output, input, scale, scale_ub
54
- )
55
  else:
56
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
57
- torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
58
  else:
59
  # num_token_padding not implemented for this case
60
  assert scale.numel() == 1 or num_token_padding is None
61
- torch.ops._C.static_scaled_fp8_quant(output, input, scale)
62
 
63
  return output, scale
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Optional, Union
2
+
3
  import torch
4
+ import triton
5
+ import triton.language as tl
6
 
7
+
8
+ from ._ops import ops
9
 
10
 
11
  def is_hip() -> bool:
 
54
  if scale is None:
55
  if use_per_token_if_dynamic:
56
  scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
57
+ ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub)
 
 
58
  else:
59
  scale = torch.zeros(1, device=input.device, dtype=torch.float32)
60
+ ops.dynamic_scaled_fp8_quant(output, input, scale)
61
  else:
62
  # num_token_padding not implemented for this case
63
  assert scale.numel() == 1 or num_token_padding is None
64
+ ops.static_scaled_fp8_quant(output, input, scale)
65
 
66
  return output, scale
67
+
68
+
69
+ @triton.jit
70
+ def _per_token_group_quant_fp8(
71
+ # Pointers to inputs and output
72
+ y_ptr,
73
+ y_q_ptr,
74
+ y_s_ptr,
75
+ group_size,
76
+ # Avoid to divide zero
77
+ eps,
78
+ # Information for float8
79
+ fp8_min,
80
+ fp8_max,
81
+ # Meta-parameters
82
+ BLOCK: tl.constexpr,
83
+ ):
84
+ """A Triton-accelerated function to perform per-token-group
85
+ quantization on a tensor.
86
+ This function converts the tensor values into float8 values.
87
+ """
88
+ # Map the program id to the row of X and Y it should compute.
89
+ g_id = tl.program_id(0)
90
+ y_ptr += g_id * group_size
91
+ y_q_ptr += g_id * group_size
92
+ y_s_ptr += g_id
93
+
94
+ cols = tl.arange(0, BLOCK) # N <= BLOCK
95
+ mask = cols < group_size
96
+
97
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
98
+ # Quant
99
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
100
+ y_s = _absmax / fp8_max
101
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
102
+
103
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
104
+ tl.store(y_s_ptr, y_s)
105
+
106
+
107
+ @triton.jit
108
+ def _per_token_group_quant_fp8_colmajor(
109
+ # Pointers to inputs and output
110
+ y_ptr,
111
+ y_q_ptr,
112
+ y_s_ptr,
113
+ group_size,
114
+ # Num columns of y
115
+ y_num_columns,
116
+ # Stride from one column to the next of y_s
117
+ y_s_col_stride,
118
+ # Avoid to divide zero
119
+ eps,
120
+ # Information for float8
121
+ fp8_min,
122
+ fp8_max,
123
+ # Meta-parameters
124
+ BLOCK: tl.constexpr,
125
+ ):
126
+ """A Triton-accelerated function to perform per-token-group
127
+ quantization on a tensor.
128
+ This function converts the tensor values into float8 values.
129
+ """
130
+ # Map the program id to the row of X and Y it should compute.
131
+ g_id = tl.program_id(0)
132
+ y_ptr += g_id * group_size
133
+ y_q_ptr += g_id * group_size
134
+
135
+ # Convert g_id the flattened block coordinate to 2D so we can index
136
+ # into the output y_scales matrix
137
+ blocks_per_row = y_num_columns // group_size
138
+ scale_col = g_id % blocks_per_row
139
+ scale_row = g_id // blocks_per_row
140
+ y_s_ptr += scale_col * y_s_col_stride + scale_row
141
+
142
+ cols = tl.arange(0, BLOCK) # group_size <= BLOCK
143
+ mask = cols < group_size
144
+
145
+ y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
146
+ # Quant
147
+ _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
148
+ y_s = _absmax / fp8_max
149
+ y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
150
+
151
+ tl.store(y_q_ptr + cols, y_q, mask=mask)
152
+ tl.store(y_s_ptr, y_s)
153
+
154
+
155
+ def per_token_group_quant_fp8(
156
+ x: torch.Tensor,
157
+ group_size: int,
158
+ eps: float = 1e-10,
159
+ dtype: Optional[torch.dtype] = None,
160
+ column_major_scales: bool = False,
161
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
162
+ """Function to perform per-token-group quantization on an input tensor `x`.
163
+ It converts the tensor values into signed float8 values and returns the
164
+ quantized tensor along with the scaling factor used for quantization.
165
+ Args:
166
+ x: The input tensor with ndim >= 2.
167
+ group_size: The group size used for quantization.
168
+ eps: The minimum to avoid dividing zero.
169
+ dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
170
+ is supported for now.
171
+ Returns:
172
+ Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
173
+ scaling factor for quantization.
174
+ """
175
+ if dtype is None:
176
+ dtype = (
177
+ torch.float8_e4m3fnuz if current_platform.is_rocm() else torch.float8_e4m3fn
178
+ )
179
+ assert x.shape[-1] % group_size == 0, (
180
+ f"the last dimension of `x` {x.shape[-1]} must be divisible "
181
+ f"by `group_size` {group_size}"
182
+ )
183
+ assert x.is_contiguous(), "`x` must be contiguous"
184
+
185
+ finfo = torch.finfo(dtype)
186
+ fp8_min = finfo.min
187
+ fp8_max = finfo.max
188
+
189
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
190
+ M = x.numel() // group_size
191
+ N = group_size
192
+ if column_major_scales:
193
+ shape = (x.shape[-1] // group_size,) + x.shape[:-1]
194
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
195
+ else:
196
+ shape = x.shape[:-1] + (x.shape[-1] // group_size,)
197
+ x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
198
+
199
+ BLOCK = triton.next_power_of_2(N)
200
+ # heuristics for number of warps
201
+ num_warps = min(max(BLOCK // 256, 1), 8)
202
+ num_stages = 1
203
+ if column_major_scales:
204
+ _per_token_group_quant_fp8_colmajor[(M,)](
205
+ x,
206
+ x_q,
207
+ x_s,
208
+ group_size,
209
+ x.shape[1],
210
+ x_s.stride(1),
211
+ eps,
212
+ fp8_min=fp8_min,
213
+ fp8_max=fp8_max,
214
+ BLOCK=BLOCK,
215
+ num_warps=num_warps,
216
+ num_stages=num_stages,
217
+ )
218
+ else:
219
+ _per_token_group_quant_fp8[(M,)](
220
+ x,
221
+ x_q,
222
+ x_s,
223
+ group_size,
224
+ eps,
225
+ fp8_min=fp8_min,
226
+ fp8_max=fp8_max,
227
+ BLOCK=BLOCK,
228
+ num_warps=num_warps,
229
+ num_stages=num_stages,
230
+ )
231
+
232
+ return x_q, x_s
ext-torch/moe/fused_marlin_moe.py CHANGED
@@ -40,7 +40,6 @@ def single_marlin_moe(
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
43
- override_config: Optional[Dict[str, Any]] = None,
44
  num_bits: int = 8,
45
  is_k_full: bool = True,
46
  ) -> torch.Tensor:
@@ -61,8 +60,6 @@ def single_marlin_moe(
61
  - topk (int): The number of top-k experts to select.
62
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
63
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
64
- - override_config (Optional[Dict[str, Any]]): Optional override
65
- for the kernel configuration.
66
  - num_bits (bool): The number of bits in expert weights quantization.
67
 
68
  Returns:
@@ -90,7 +87,6 @@ def single_marlin_moe(
90
  w.shape,
91
  topk_ids.shape[1],
92
  None,
93
- override_config=override_config,
94
  is_marlin=True,
95
  )
96
  config = get_config_func(M)
@@ -154,6 +150,25 @@ def single_marlin_moe(
154
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
155
 
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  def fused_marlin_moe(
158
  hidden_states: torch.Tensor,
159
  w1: torch.Tensor,
@@ -169,7 +184,6 @@ def fused_marlin_moe(
169
  sort_indices2: Optional[torch.Tensor] = None,
170
  w1_zeros: Optional[torch.Tensor] = None,
171
  w2_zeros: Optional[torch.Tensor] = None,
172
- override_config: Optional[Dict[str, Any]] = None,
173
  num_bits: int = 8,
174
  is_k_full: bool = True,
175
  ) -> torch.Tensor:
@@ -193,8 +207,6 @@ def fused_marlin_moe(
193
  permutation.
194
  - topk_weights (torch.Tensor): Top-k weights.
195
  - topk_ids (torch.Tensor): Indices of topk-k elements.
196
- - override_config (Optional[Dict[str, Any]]): Optional override
197
- for the kernel configuration.
198
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
199
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
200
  - num_bits (bool): The number of bits in expert weights quantization.
@@ -248,7 +260,6 @@ def fused_marlin_moe(
248
  w2.shape,
249
  topk_ids.shape[1],
250
  None,
251
- override_config=override_config,
252
  is_marlin=True,
253
  )
254
  config = get_config_func(M)
@@ -350,6 +361,30 @@ def fused_marlin_moe(
350
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
351
 
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  if hasattr(ops, "marlin_gemm_moe"):
354
 
355
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
 
40
  g_idx: Optional[torch.Tensor] = None,
41
  sort_indices: Optional[torch.Tensor] = None,
42
  w_zeros: Optional[torch.Tensor] = None,
 
43
  num_bits: int = 8,
44
  is_k_full: bool = True,
45
  ) -> torch.Tensor:
 
60
  - topk (int): The number of top-k experts to select.
61
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
62
  - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
 
 
63
  - num_bits (bool): The number of bits in expert weights quantization.
64
 
65
  Returns:
 
87
  w.shape,
88
  topk_ids.shape[1],
89
  None,
 
90
  is_marlin=True,
91
  )
92
  config = get_config_func(M)
 
150
  return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
151
 
152
 
153
+ if hasattr(ops, "single_marlin_gemm_moe"):
154
+
155
+ @register_fake(add_op_namespace_prefix("single_marlin_gemm_moe"))
156
+ def single_marlin_moe_fake(
157
+ hidden_states: torch.Tensor,
158
+ w: torch.Tensor,
159
+ scales: torch.Tensor,
160
+ gating_output: torch.Tensor,
161
+ topk: int,
162
+ renormalize: bool,
163
+ g_idx: Optional[torch.Tensor] = None,
164
+ sort_indices: Optional[torch.Tensor] = None,
165
+ w_zeros: Optional[torch.Tensor] = None,
166
+ num_bits: int = 8,
167
+ is_k_full: bool = True,
168
+ ) -> torch.Tensor:
169
+ return torch.empty_like(hidden_states)
170
+
171
+
172
  def fused_marlin_moe(
173
  hidden_states: torch.Tensor,
174
  w1: torch.Tensor,
 
184
  sort_indices2: Optional[torch.Tensor] = None,
185
  w1_zeros: Optional[torch.Tensor] = None,
186
  w2_zeros: Optional[torch.Tensor] = None,
 
187
  num_bits: int = 8,
188
  is_k_full: bool = True,
189
  ) -> torch.Tensor:
 
207
  permutation.
208
  - topk_weights (torch.Tensor): Top-k weights.
209
  - topk_ids (torch.Tensor): Indices of topk-k elements.
 
 
210
  - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
211
  - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
212
  - num_bits (bool): The number of bits in expert weights quantization.
 
260
  w2.shape,
261
  topk_ids.shape[1],
262
  None,
 
263
  is_marlin=True,
264
  )
265
  config = get_config_func(M)
 
361
  return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
362
 
363
 
364
+ if hasattr(ops, "fused_marlin_moe"):
365
+
366
+ @register_fake(add_op_namespace_prefix("fused_marlin_moe"))
367
+ def fused_marlin_moe_fake(
368
+ hidden_states: torch.Tensor,
369
+ w1: torch.Tensor,
370
+ w2: torch.Tensor,
371
+ w1_scale: torch.Tensor,
372
+ w2_scale: torch.Tensor,
373
+ gating_output: torch.Tensor,
374
+ topk_weights: torch.Tensor,
375
+ topk_ids: torch.Tensor,
376
+ g_idx1: Optional[torch.Tensor] = None,
377
+ g_idx2: Optional[torch.Tensor] = None,
378
+ sort_indices1: Optional[torch.Tensor] = None,
379
+ sort_indices2: Optional[torch.Tensor] = None,
380
+ w1_zeros: Optional[torch.Tensor] = None,
381
+ w2_zeros: Optional[torch.Tensor] = None,
382
+ num_bits: int = 8,
383
+ is_k_full: bool = True,
384
+ ) -> torch.Tensor:
385
+ return torch.empty_like(hidden_states)
386
+
387
+
388
  if hasattr(ops, "marlin_gemm_moe"):
389
 
390
  @register_fake(add_op_namespace_prefix("marlin_gemm_moe"))
ext-torch/moe/fused_moe.py CHANGED
@@ -1,21 +1,242 @@
 
1
  """Fused MoE kernel."""
2
 
3
  import functools
4
  import json
 
5
  import os
6
- from typing import Any, Callable, Dict, Optional, Tuple
7
 
8
  import torch
9
  import triton
10
  import triton.language as tl
11
 
 
12
  from ._ops import ops
13
- from .fp8 import scaled_fp8_quant
14
  from .platforms import current_platform
15
 
 
 
 
16
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  @triton.jit
20
  def fused_moe_kernel(
21
  # Pointers to matrices
@@ -44,8 +265,14 @@ def fused_moe_kernel(
44
  stride_bn,
45
  stride_cm,
46
  stride_cn,
 
 
47
  stride_bse,
 
48
  stride_bsn,
 
 
 
49
  # Meta-parameters
50
  BLOCK_SIZE_M: tl.constexpr,
51
  BLOCK_SIZE_N: tl.constexpr,
@@ -105,17 +332,17 @@ def fused_moe_kernel(
105
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
106
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
107
  return
108
- offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
110
  token_mask = offs_token < num_valid_tokens
111
 
112
- offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
113
  offs_k = tl.arange(0, BLOCK_SIZE_K)
114
  a_ptrs = a_ptr + (
115
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
116
  )
117
 
118
- off_experts = tl.load(expert_ids_ptr + pid_m)
119
  b_ptrs = (
120
  b_ptr
121
  + off_experts * stride_be
@@ -128,8 +355,15 @@ def fused_moe_kernel(
128
  b_scale = tl.load(b_scale_ptrs)
129
 
130
  if use_fp8_w8a8:
131
- a_scale = tl.load(a_scale_ptr)
132
- b_scale = tl.load(b_scale_ptr + off_experts)
 
 
 
 
 
 
 
133
 
134
  # -----------------------------------------------------------
135
  # Iterate to compute a block of the C matrix.
@@ -151,7 +385,17 @@ def fused_moe_kernel(
151
  if use_int8_w8a16:
152
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
153
  elif use_fp8_w8a8:
154
- accumulator = tl.dot(a, b, acc=accumulator)
 
 
 
 
 
 
 
 
 
 
155
  else:
156
  accumulator += tl.dot(a, b)
157
  # Advance the ptrs to the next K block.
@@ -164,7 +408,10 @@ def fused_moe_kernel(
164
  if use_int8_w8a16:
165
  accumulator = (accumulator * b_scale).to(compute_type)
166
  elif use_fp8_w8a8:
167
- accumulator = (accumulator * a_scale * b_scale).to(compute_type)
 
 
 
168
  else:
169
  accumulator = accumulator.to(compute_type)
170
  # -----------------------------------------------------------
@@ -175,6 +422,141 @@ def fused_moe_kernel(
175
  tl.store(c_ptrs, accumulator, mask=c_mask)
176
 
177
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def moe_align_block_size(
179
  topk_ids: torch.Tensor, block_size: int, num_experts: int
180
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -225,9 +607,34 @@ def moe_align_block_size(
225
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
226
  )
227
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
228
- ops.moe_align_block_size(
229
- topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
230
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  return sorted_ids, expert_ids, num_tokens_post_pad
232
 
233
 
@@ -237,6 +644,7 @@ def invoke_fused_moe_kernel(
237
  C: torch.Tensor,
238
  A_scale: Optional[torch.Tensor],
239
  B_scale: Optional[torch.Tensor],
 
240
  topk_weights: torch.Tensor,
241
  topk_ids: torch.Tensor,
242
  sorted_token_ids: torch.Tensor,
@@ -248,64 +656,147 @@ def invoke_fused_moe_kernel(
248
  compute_type: tl.dtype,
249
  use_fp8_w8a8: bool,
250
  use_int8_w8a16: bool,
 
 
251
  ) -> None:
252
  assert topk_weights.stride(1) == 1
253
  assert sorted_token_ids.stride(0) == 1
254
 
255
  if use_fp8_w8a8:
256
- A, A_scale = scaled_fp8_quant(A, A_scale)
257
  assert B_scale is not None
258
- elif use_int8_w8a16:
 
 
 
 
 
 
 
 
 
259
  assert B_scale is not None
 
260
  else:
261
  assert A_scale is None
262
  assert B_scale is None
263
 
 
 
 
 
 
 
 
264
  grid = lambda META: (
265
- triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
266
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
267
  )
268
 
269
- fused_moe_kernel[grid](
270
- A,
271
- B,
272
- C,
273
- A_scale,
274
- B_scale,
275
- topk_weights,
276
- sorted_token_ids,
277
- expert_ids,
278
- num_tokens_post_padded,
279
- B.shape[1],
280
- B.shape[2],
281
- sorted_token_ids.shape[0],
282
- topk_ids.numel(),
283
- A.stride(0),
284
- A.stride(1),
285
- B.stride(0),
286
- B.stride(2),
287
- B.stride(1),
288
- C.stride(1),
289
- C.stride(2),
290
- B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0,
291
- B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0,
292
- MUL_ROUTED_WEIGHT=mul_routed_weight,
293
- top_k=top_k,
294
- compute_type=compute_type,
295
- use_fp8_w8a8=use_fp8_w8a8,
296
- use_int8_w8a16=use_int8_w8a16,
297
- **config,
298
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
 
301
- def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 
 
 
302
  device_name = current_platform.get_device_name().replace(" ", "_")
303
  dtype_selector = "" if not dtype else f",dtype={dtype}"
304
- return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
 
 
 
305
 
306
 
 
307
  @functools.lru_cache
308
- def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
 
 
 
 
 
 
309
  """
310
  Return optimized configurations for the fused MoE kernel.
311
 
@@ -317,18 +808,27 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int,
317
 
318
  # First look up if an optimized configuration is available in the configs
319
  # directory
320
- json_file_name = get_config_file_name(E, N, dtype)
 
321
 
322
  config_file_path = os.path.join(
323
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
324
  )
325
  if os.path.exists(config_file_path):
326
  with open(config_file_path) as f:
 
327
  # If a configuration has been found, return it
328
  return {int(key): val for key, val in json.load(f).items()}
329
 
330
  # If no optimized configuration is available, we will use the default
331
  # configuration
 
 
 
 
 
 
 
332
  return None
333
 
334
 
@@ -340,21 +840,34 @@ def get_default_config(
340
  topk: int,
341
  dtype: Optional[str],
342
  is_marlin: bool,
 
343
  ) -> Dict[str, int]:
344
- config = {
345
- "BLOCK_SIZE_M": 64,
346
- "BLOCK_SIZE_N": 64,
347
- "BLOCK_SIZE_K": 32,
348
- "GROUP_SIZE_M": 8,
349
- }
350
- # A heuristic: fused marlin works faster with this config for small M
351
- if M <= E or (is_marlin and M <= 32):
352
  config = {
353
- "BLOCK_SIZE_M": 16,
354
- "BLOCK_SIZE_N": 32,
355
- "BLOCK_SIZE_K": 64,
356
- "GROUP_SIZE_M": 1,
 
 
357
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
  return config
359
 
360
 
@@ -364,15 +877,21 @@ def try_get_optimal_moe_config(
364
  top_k: int,
365
  dtype: Optional[str],
366
  M: int,
367
- override_config: Optional[Dict[str, Any]] = None,
368
  is_marlin: bool = False,
 
369
  ):
 
 
 
 
370
  if override_config:
371
  config = override_config
372
  else:
373
  # First try to load optimal config from the file
374
  E, _, N = w2_shape
375
- configs = get_moe_configs(E, N, dtype)
 
 
376
 
377
  if configs:
378
  # If an optimal configuration map has been found, look up the
@@ -380,7 +899,9 @@ def try_get_optimal_moe_config(
380
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
381
  else:
382
  # Else use the default config
383
- config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)
 
 
384
  return config
385
 
386
 
@@ -416,7 +937,8 @@ def fused_topk(
416
  return topk_weights, topk_ids
417
 
418
 
419
- # This is used by the Deepseek-V2 model
 
420
  def grouped_topk(
421
  hidden_states: torch.Tensor,
422
  gating_output: torch.Tensor,
@@ -424,11 +946,25 @@ def grouped_topk(
424
  renormalize: bool,
425
  num_expert_group: int = 0,
426
  topk_group: int = 0,
 
 
427
  ):
428
 
429
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
430
 
431
- scores = torch.softmax(gating_output, dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
432
  num_token = scores.shape[0]
433
  group_scores = (
434
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -444,7 +980,13 @@ def grouped_topk(
444
  .reshape(num_token, -1)
445
  ) # [n, e]
446
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
447
- topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
 
 
 
 
 
 
448
 
449
  if renormalize:
450
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@@ -454,6 +996,7 @@ def grouped_topk(
454
 
455
  def get_config_dtype_str(
456
  dtype: torch.dtype,
 
457
  use_int8_w8a16: Optional[bool] = False,
458
  use_fp8_w8a8: Optional[bool] = False,
459
  ):
@@ -461,6 +1004,8 @@ def get_config_dtype_str(
461
  return "fp8_w8a8"
462
  elif use_int8_w8a16:
463
  return "int8_w8a16"
 
 
464
  elif dtype == torch.float:
465
  # avoiding cases where kernel fails when float32 MoE
466
  # use fp16/bfloat16 configs
@@ -468,6 +1013,80 @@ def get_config_dtype_str(
468
  return None
469
 
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  def fused_experts(
472
  hidden_states: torch.Tensor,
473
  w1: torch.Tensor,
@@ -475,16 +1094,80 @@ def fused_experts(
475
  topk_weights: torch.Tensor,
476
  topk_ids: torch.Tensor,
477
  inplace: bool = False,
478
- override_config: Optional[Dict[str, Any]] = None,
479
  use_fp8_w8a8: bool = False,
480
  use_int8_w8a16: bool = False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  w1_scale: Optional[torch.Tensor] = None,
482
  w2_scale: Optional[torch.Tensor] = None,
 
 
483
  a1_scale: Optional[torch.Tensor] = None,
484
  a2_scale: Optional[torch.Tensor] = None,
 
485
  ):
486
  # Check constraints.
487
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
 
 
 
 
488
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
489
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
490
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -500,6 +1183,7 @@ def fused_experts(
500
  config_dtype = get_config_dtype_str(
501
  use_fp8_w8a8=use_fp8_w8a8,
502
  use_int8_w8a16=use_int8_w8a16,
 
503
  dtype=hidden_states.dtype,
504
  )
505
 
@@ -509,7 +1193,7 @@ def fused_experts(
509
  w2.shape,
510
  topk_ids.shape[1],
511
  config_dtype,
512
- override_config=override_config,
513
  )
514
 
515
  config = get_config_func(M)
@@ -530,7 +1214,14 @@ def fused_experts(
530
  dtype=hidden_states.dtype,
531
  )
532
 
533
- compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
 
 
 
 
 
 
534
 
535
  if inplace:
536
  out_hidden_states = hidden_states
@@ -571,6 +1262,7 @@ def fused_experts(
571
  intermediate_cache1,
572
  a1_scale,
573
  w1_scale,
 
574
  curr_topk_weights,
575
  curr_topk_ids,
576
  sorted_token_ids,
@@ -582,6 +1274,8 @@ def fused_experts(
582
  compute_type=compute_type,
583
  use_fp8_w8a8=use_fp8_w8a8,
584
  use_int8_w8a16=use_int8_w8a16,
 
 
585
  )
586
 
587
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
@@ -592,6 +1286,7 @@ def fused_experts(
592
  intermediate_cache3,
593
  a2_scale,
594
  w2_scale,
 
595
  curr_topk_weights,
596
  curr_topk_ids,
597
  sorted_token_ids,
@@ -603,6 +1298,8 @@ def fused_experts(
603
  compute_type=compute_type,
604
  use_fp8_w8a8=use_fp8_w8a8,
605
  use_int8_w8a16=use_int8_w8a16,
 
 
606
  )
607
 
608
  ops.moe_sum(
@@ -620,17 +1317,20 @@ def fused_moe(
620
  topk: int,
621
  renormalize: bool,
622
  inplace: bool = False,
623
- override_config: Optional[Dict[str, Any]] = None,
624
  use_grouped_topk: bool = False,
625
  num_expert_group: Optional[int] = None,
626
  topk_group: Optional[int] = None,
627
  custom_routing_function: Optional[Callable] = None,
628
  use_fp8_w8a8: bool = False,
629
  use_int8_w8a16: bool = False,
 
630
  w1_scale: Optional[torch.Tensor] = None,
631
  w2_scale: Optional[torch.Tensor] = None,
 
 
632
  a1_scale: Optional[torch.Tensor] = None,
633
  a2_scale: Optional[torch.Tensor] = None,
 
634
  ) -> torch.Tensor:
635
  """
636
  This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -646,20 +1346,28 @@ def fused_moe(
646
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
647
  - inplace (bool): If True, perform the operation in-place.
648
  Defaults to False.
649
- - override_config (Optional[Dict[str, Any]]): Optional override
650
- for the kernel configuration.
651
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
652
  - topk_group: Optional[int]: additional parameter for grouped_topk
653
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
654
  note: Deepseekv2 model uses grouped_topk
655
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
656
  products for w1 and w2. Defaults to False.
657
- - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
658
- products for w1 and w2. Defaults to False.
 
 
 
 
659
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
660
  w1.
661
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
662
  w2.
 
 
 
 
 
 
663
 
664
  Returns:
665
  - torch.Tensor: The output tensor after applying the MoE layer.
@@ -693,11 +1401,14 @@ def fused_moe(
693
  topk_weights,
694
  topk_ids,
695
  inplace=inplace,
696
- override_config=override_config,
697
  use_fp8_w8a8=use_fp8_w8a8,
698
  use_int8_w8a16=use_int8_w8a16,
 
699
  w1_scale=w1_scale,
700
  w2_scale=w2_scale,
 
 
701
  a1_scale=a1_scale,
702
  a2_scale=a2_scale,
 
703
  )
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
  """Fused MoE kernel."""
3
 
4
  import functools
5
  import json
6
+ import logging
7
  import os
8
+ from typing import Any, Callable, Dict, List, Optional, Tuple
9
 
10
  import torch
11
  import triton
12
  import triton.language as tl
13
 
14
+
15
  from ._ops import ops
16
+ from .fp8 import per_token_group_quant_fp8, scaled_fp8_quant
17
  from .platforms import current_platform
18
 
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
  VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768"))
23
 
24
 
25
+ @triton.jit
26
+ def fused_moe_kernel_gptq_awq(
27
+ # Pointers to matrices
28
+ a_ptr,
29
+ b_ptr,
30
+ c_ptr,
31
+ b_scale_ptr,
32
+ b_zp_ptr,
33
+ topk_weights_ptr,
34
+ sorted_token_ids_ptr,
35
+ expert_ids_ptr,
36
+ num_tokens_post_padded_ptr,
37
+ # Matrix dimensions
38
+ N: tl.constexpr,
39
+ K: tl.constexpr,
40
+ EM,
41
+ num_valid_tokens,
42
+ # The stride variables represent how much to increase the ptr by when
43
+ # moving by 1 element in a particular dimension. E.g. `stride_am` is
44
+ # how much to increase `a_ptr` by to get the element one row down
45
+ # (A has M rows).
46
+ stride_am,
47
+ stride_ak,
48
+ stride_be,
49
+ stride_bk,
50
+ stride_bn,
51
+ stride_cm,
52
+ stride_cn,
53
+ stride_bse,
54
+ stride_bsk,
55
+ stride_bsn,
56
+ stride_bze,
57
+ stride_bzk,
58
+ stride_bzn,
59
+ block_k_diviable: tl.constexpr,
60
+ group_size: tl.constexpr,
61
+ # Meta-parameters
62
+ BLOCK_SIZE_M: tl.constexpr,
63
+ BLOCK_SIZE_N: tl.constexpr,
64
+ BLOCK_SIZE_K: tl.constexpr,
65
+ GROUP_SIZE_M: tl.constexpr,
66
+ MUL_ROUTED_WEIGHT: tl.constexpr,
67
+ top_k: tl.constexpr,
68
+ compute_type: tl.constexpr,
69
+ has_zp: tl.constexpr,
70
+ use_int4_w4a16: tl.constexpr,
71
+ use_int8_w8a16: tl.constexpr,
72
+ ):
73
+ """
74
+ Implements the fused computation for a Mixture of Experts (MOE) using
75
+ token and expert matrices.
76
+
77
+ Key Parameters:
78
+ - A: The input tensor representing tokens with shape (*, K), where '*' can
79
+ be any shape representing batches and K is the feature dimension of
80
+ each token.
81
+ - B: The stacked MOE weight tensor with shape (E, N, K), where E is
82
+ the number of experts, K is the input feature dimension, and N is
83
+ the output feature dimension.
84
+ - C: The output cache tensor with shape (M, topk, N), where M is the
85
+ total number of tokens post padding, topk is the number of times
86
+ each token is repeated, and N is the output feature dimension.
87
+ - sorted_token_ids: A tensor containing the sorted indices of tokens,
88
+ repeated topk times and arranged by the expert index they are
89
+ assigned to.
90
+ - expert_ids: A tensor containing the indices of the expert for each
91
+ block. It determines which expert matrix from B should be used for
92
+ each block in A.
93
+ This kernel performs the multiplication of a token by its corresponding
94
+ expert matrix as determined by `expert_ids`. The sorting of
95
+ `sorted_token_ids` by expert index and padding ensures divisibility by
96
+ BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
97
+ multiplication across different blocks processed by the same expert.
98
+ """
99
+ # -----------------------------------------------------------
100
+ # Map program ids `pid` to the block of C it should compute.
101
+ # This is done in a grouped ordering to promote L2 data reuse.
102
+ pid = tl.program_id(axis=0)
103
+ num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
104
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
105
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
106
+ group_id = pid // num_pid_in_group
107
+ first_pid_m = group_id * GROUP_SIZE_M
108
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
109
+ pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
110
+ pid_n = (pid % num_pid_in_group) // group_size_m
111
+
112
+ # ----------------------------------------------------------
113
+ # Create pointers for the first blocks of A and B.
114
+ # We will advance this pointer as we move in the K direction
115
+ # and accumulate
116
+ # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
117
+ # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
118
+ num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
119
+ if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
120
+ return
121
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
122
+ offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
123
+ token_mask = offs_token < num_valid_tokens
124
+
125
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
126
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
127
+ a_ptrs = a_ptr + (
128
+ offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
129
+ )
130
+
131
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
132
+
133
+ if use_int4_w4a16:
134
+ b_ptrs = (
135
+ b_ptr
136
+ + off_experts * stride_be
137
+ + (offs_k[:, None] // 2) * stride_bk
138
+ + offs_bn[None, :] * stride_bn
139
+ )
140
+ b_shifter = (offs_k[:, None] % 2) * 4
141
+ elif use_int8_w8a16:
142
+ b_ptrs = (
143
+ b_ptr
144
+ + off_experts * stride_be
145
+ + offs_k[:, None] * stride_bk
146
+ + offs_bn[None, :] * stride_bn
147
+ )
148
+
149
+ if not has_zp and use_int4_w4a16:
150
+ b_zp_num = 8
151
+ if not has_zp and use_int8_w8a16:
152
+ b_zp_num = 128
153
+ elif has_zp and use_int4_w4a16:
154
+ b_zp_shifter = (offs_bn[None, :] % 2) * 4
155
+
156
+ # -----------------------------------------------------------
157
+ # Iterate to compute a block of the C matrix.
158
+ # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
159
+ # of fp32 values for higher accuracy.
160
+ # `accumulator` will be converted back to fp16 after the loop.
161
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
162
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
163
+ # Load the next block of A and B, generate a mask by checking the
164
+ # K dimension.
165
+
166
+ if not block_k_diviable:
167
+ k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
168
+ k_other = 0.0
169
+ else:
170
+ k_mask = None
171
+ k_other = None
172
+
173
+ a = tl.load(
174
+ a_ptrs,
175
+ mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
176
+ other=0.0,
177
+ )
178
+ b = tl.load(b_ptrs)
179
+ if use_int4_w4a16:
180
+ b = (b >> b_shifter) & 0xF
181
+
182
+ b_scale_ptrs = (
183
+ b_scale_ptr
184
+ + off_experts * stride_bse
185
+ + offs_bn[None, :] * stride_bsn
186
+ + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
187
+ )
188
+ b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
189
+ b_scale = b_scale.to(tl.float32)
190
+
191
+ if has_zp and use_int4_w4a16:
192
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
193
+ b_zp_ptrs = (
194
+ b_zp_ptr
195
+ + off_experts * stride_bze
196
+ + (offs_bn[None, :] // 2) * stride_bzn
197
+ + offs_k_true * stride_bzk
198
+ )
199
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
200
+ b_zp = (b_zp >> b_zp_shifter) & 0xF
201
+ b_zp = b_zp.to(tl.float32)
202
+ elif has_zp and use_int8_w8a16:
203
+ offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
204
+ b_zp_ptrs = (
205
+ b_zp_ptr
206
+ + off_experts * stride_bze
207
+ + offs_bn[None, :] * stride_bzn
208
+ + offs_k_true * stride_bzk
209
+ )
210
+ b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
211
+ b_zp = b_zp.to(tl.float32)
212
+
213
+ # We accumulate along the K dimension.
214
+ if has_zp:
215
+ b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
216
+ else:
217
+ b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
218
+ accumulator = tl.dot(a, b, acc=accumulator)
219
+
220
+ # Advance the ptrs to the next K block.
221
+ a_ptrs += BLOCK_SIZE_K * stride_ak
222
+ if use_int4_w4a16:
223
+ b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
224
+ else:
225
+ b_ptrs += BLOCK_SIZE_K * stride_bk
226
+
227
+ if MUL_ROUTED_WEIGHT:
228
+ moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
229
+ accumulator = accumulator * moe_weight[:, None]
230
+
231
+ accumulator = accumulator.to(compute_type)
232
+ # -----------------------------------------------------------
233
+ # Write back the block of the output
234
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
235
+ c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
236
+ c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
237
+ tl.store(c_ptrs, accumulator, mask=c_mask)
238
+
239
+
240
  @triton.jit
241
  def fused_moe_kernel(
242
  # Pointers to matrices
 
265
  stride_bn,
266
  stride_cm,
267
  stride_cn,
268
+ stride_asm,
269
+ stride_ask,
270
  stride_bse,
271
+ stride_bsk,
272
  stride_bsn,
273
+ # Block size for block-wise quantization
274
+ group_n: tl.constexpr,
275
+ group_k: tl.constexpr,
276
  # Meta-parameters
277
  BLOCK_SIZE_M: tl.constexpr,
278
  BLOCK_SIZE_N: tl.constexpr,
 
332
  num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
333
  if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
334
  return
335
+ offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
336
  offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
337
  token_mask = offs_token < num_valid_tokens
338
 
339
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
340
  offs_k = tl.arange(0, BLOCK_SIZE_K)
341
  a_ptrs = a_ptr + (
342
  offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
343
  )
344
 
345
+ off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
346
  b_ptrs = (
347
  b_ptr
348
  + off_experts * stride_be
 
355
  b_scale = tl.load(b_scale_ptrs)
356
 
357
  if use_fp8_w8a8:
358
+ if group_k > 0 and group_n > 0:
359
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
360
+ offs_bsn = offs_bn // group_n
361
+ b_scale_ptrs = (
362
+ b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
363
+ )
364
+ else:
365
+ a_scale = tl.load(a_scale_ptr)
366
+ b_scale = tl.load(b_scale_ptr + off_experts)
367
 
368
  # -----------------------------------------------------------
369
  # Iterate to compute a block of the C matrix.
 
385
  if use_int8_w8a16:
386
  accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
387
  elif use_fp8_w8a8:
388
+ if group_k > 0 and group_n > 0:
389
+ k_start = k * BLOCK_SIZE_K
390
+ offs_ks = k_start // group_k
391
+ a_scale = tl.load(
392
+ a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
393
+ )
394
+ b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
395
+
396
+ accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
397
+ else:
398
+ accumulator = tl.dot(a, b, acc=accumulator)
399
  else:
400
  accumulator += tl.dot(a, b)
401
  # Advance the ptrs to the next K block.
 
408
  if use_int8_w8a16:
409
  accumulator = (accumulator * b_scale).to(compute_type)
410
  elif use_fp8_w8a8:
411
+ if group_k > 0 and group_n > 0:
412
+ accumulator = accumulator.to(compute_type)
413
+ else:
414
+ accumulator = (accumulator * a_scale * b_scale).to(compute_type)
415
  else:
416
  accumulator = accumulator.to(compute_type)
417
  # -----------------------------------------------------------
 
422
  tl.store(c_ptrs, accumulator, mask=c_mask)
423
 
424
 
425
+ def ceil_div(a, b):
426
+ return (a + b - 1) // b
427
+
428
+
429
+ @triton.jit
430
+ def moe_align_block_size_stage1(
431
+ topk_ids_ptr,
432
+ tokens_cnts_ptr,
433
+ num_experts: tl.constexpr,
434
+ numel: tl.constexpr,
435
+ tokens_per_thread: tl.constexpr,
436
+ ):
437
+ pid = tl.program_id(0)
438
+
439
+ start_idx = pid * tokens_per_thread
440
+
441
+ off_c = (pid + 1) * num_experts
442
+
443
+ for i in range(tokens_per_thread):
444
+ if start_idx + i < numel:
445
+ idx = tl.load(topk_ids_ptr + start_idx + i)
446
+ token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
447
+ tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
448
+
449
+
450
+ @triton.jit
451
+ def moe_align_block_size_stage2(
452
+ tokens_cnts_ptr,
453
+ num_experts: tl.constexpr,
454
+ ):
455
+ pid = tl.program_id(0)
456
+
457
+ last_cnt = 0
458
+ for i in range(1, num_experts + 1):
459
+ token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
460
+ last_cnt = last_cnt + token_cnt
461
+ tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
462
+
463
+
464
+ @triton.jit
465
+ def moe_align_block_size_stage3(
466
+ total_tokens_post_pad_ptr,
467
+ tokens_cnts_ptr,
468
+ cumsum_ptr,
469
+ num_experts: tl.constexpr,
470
+ block_size: tl.constexpr,
471
+ ):
472
+ last_cumsum = 0
473
+ off_cnt = num_experts * num_experts
474
+ for i in range(1, num_experts + 1):
475
+ token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
476
+ last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
477
+ tl.store(cumsum_ptr + i, last_cumsum)
478
+ tl.store(total_tokens_post_pad_ptr, last_cumsum)
479
+
480
+
481
+ @triton.jit
482
+ def moe_align_block_size_stage4(
483
+ topk_ids_ptr,
484
+ sorted_token_ids_ptr,
485
+ expert_ids_ptr,
486
+ tokens_cnts_ptr,
487
+ cumsum_ptr,
488
+ num_experts: tl.constexpr,
489
+ block_size: tl.constexpr,
490
+ numel: tl.constexpr,
491
+ tokens_per_thread: tl.constexpr,
492
+ ):
493
+ pid = tl.program_id(0)
494
+ start_idx = tl.load(cumsum_ptr + pid)
495
+ end_idx = tl.load(cumsum_ptr + pid + 1)
496
+
497
+ for i in range(start_idx, end_idx, block_size):
498
+ tl.store(expert_ids_ptr + i // block_size, pid)
499
+
500
+ start_idx = pid * tokens_per_thread
501
+ off_t = pid * num_experts
502
+
503
+ for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
504
+ expert_id = tl.load(topk_ids_ptr + i)
505
+ token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
506
+ rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
507
+ tl.store(sorted_token_ids_ptr + rank_post_pad, i)
508
+ tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
509
+
510
+
511
+ # Triton implementation based on:
512
+ # https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
513
+ def moe_align_block_size_triton(
514
+ topk_ids: torch.Tensor,
515
+ num_experts: int,
516
+ block_size: int,
517
+ sorted_token_ids: torch.Tensor,
518
+ expert_ids: torch.Tensor,
519
+ num_tokens_post_pad: torch.Tensor,
520
+ ) -> None:
521
+ numel = topk_ids.numel()
522
+ grid = (num_experts,)
523
+ tokens_cnts = torch.zeros(
524
+ (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
525
+ )
526
+ cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
527
+ tokens_per_thread = ceil_div(numel, num_experts)
528
+
529
+ moe_align_block_size_stage1[grid](
530
+ topk_ids,
531
+ tokens_cnts,
532
+ num_experts,
533
+ numel,
534
+ tokens_per_thread,
535
+ )
536
+ moe_align_block_size_stage2[grid](
537
+ tokens_cnts,
538
+ num_experts,
539
+ )
540
+ moe_align_block_size_stage3[(1,)](
541
+ num_tokens_post_pad,
542
+ tokens_cnts,
543
+ cumsum,
544
+ num_experts,
545
+ block_size,
546
+ )
547
+ moe_align_block_size_stage4[grid](
548
+ topk_ids,
549
+ sorted_token_ids,
550
+ expert_ids,
551
+ tokens_cnts,
552
+ cumsum,
553
+ num_experts,
554
+ block_size,
555
+ numel,
556
+ tokens_per_thread,
557
+ )
558
+
559
+
560
  def moe_align_block_size(
561
  topk_ids: torch.Tensor, block_size: int, num_experts: int
562
  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
607
  (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
608
  )
609
  num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
610
+ if num_experts >= 224:
611
+ if VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
612
+ moe_align_block_size_triton(
613
+ topk_ids,
614
+ num_experts,
615
+ block_size,
616
+ sorted_ids,
617
+ expert_ids,
618
+ num_tokens_post_pad,
619
+ )
620
+ else:
621
+ ops.sgl_moe_align_block_size(
622
+ topk_ids,
623
+ num_experts,
624
+ block_size,
625
+ sorted_ids,
626
+ expert_ids,
627
+ num_tokens_post_pad,
628
+ )
629
+ else:
630
+ ops.moe_align_block_size(
631
+ topk_ids,
632
+ num_experts,
633
+ block_size,
634
+ sorted_ids,
635
+ expert_ids,
636
+ num_tokens_post_pad,
637
+ )
638
  return sorted_ids, expert_ids, num_tokens_post_pad
639
 
640
 
 
644
  C: torch.Tensor,
645
  A_scale: Optional[torch.Tensor],
646
  B_scale: Optional[torch.Tensor],
647
+ B_zp: Optional[torch.Tensor],
648
  topk_weights: torch.Tensor,
649
  topk_ids: torch.Tensor,
650
  sorted_token_ids: torch.Tensor,
 
656
  compute_type: tl.dtype,
657
  use_fp8_w8a8: bool,
658
  use_int8_w8a16: bool,
659
+ use_int4_w4a16: bool,
660
+ block_shape: Optional[List[int]] = None,
661
  ) -> None:
662
  assert topk_weights.stride(1) == 1
663
  assert sorted_token_ids.stride(0) == 1
664
 
665
  if use_fp8_w8a8:
 
666
  assert B_scale is not None
667
+ if block_shape is None:
668
+ A, A_scale = scaled_fp8_quant(A, A_scale)
669
+ else:
670
+ assert len(block_shape) == 2
671
+ block_n, block_k = block_shape[0], block_shape[1]
672
+ A, A_scale = per_token_group_quant_fp8(A, block_k)
673
+ assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
674
+ assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
675
+ assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
676
+ elif use_int8_w8a16 or use_int4_w4a16:
677
  assert B_scale is not None
678
+ assert block_shape is None or block_shape[0] == 0
679
  else:
680
  assert A_scale is None
681
  assert B_scale is None
682
 
683
+ EM = sorted_token_ids.shape[0]
684
+ if A.shape[0] < config["BLOCK_SIZE_M"]:
685
+ # optimize for small batch_size.
686
+ # We assume that top_ids of each token is unique, so
687
+ # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
688
+ # and we can skip some invalid blocks.
689
+ EM = min(sorted_token_ids.shape[0], A.shape[0] * top_k * config["BLOCK_SIZE_M"])
690
  grid = lambda META: (
691
+ triton.cdiv(EM, META["BLOCK_SIZE_M"])
692
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
693
  )
694
 
695
+ if (
696
+ (use_int8_w8a16 or use_int4_w4a16)
697
+ and block_shape is not None
698
+ and block_shape[1] > 0
699
+ ):
700
+ assert B_scale is not None and B_scale.ndim == 3
701
+ assert B_zp is None or B_zp.ndim == 3
702
+
703
+ fused_moe_kernel_gptq_awq[grid](
704
+ A,
705
+ B,
706
+ C,
707
+ B_scale,
708
+ B_zp,
709
+ topk_weights,
710
+ sorted_token_ids,
711
+ expert_ids,
712
+ num_tokens_post_padded,
713
+ B.shape[1],
714
+ A.shape[1],
715
+ EM,
716
+ topk_ids.numel(),
717
+ A.stride(0),
718
+ A.stride(1),
719
+ B.stride(0),
720
+ B.stride(2),
721
+ B.stride(1),
722
+ C.stride(1),
723
+ C.stride(2),
724
+ B_scale.stride(0),
725
+ B_scale.stride(2),
726
+ B_scale.stride(1),
727
+ B_zp.stride(0) if B_zp is not None else 0,
728
+ B_zp.stride(2) if B_zp is not None else 0,
729
+ B_zp.stride(1) if B_zp is not None else 0,
730
+ block_k_diviable=A.shape[1] % config["BLOCK_SIZE_K"] == 0,
731
+ group_size=block_shape[1],
732
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
733
+ top_k=top_k,
734
+ compute_type=compute_type,
735
+ has_zp=B_zp is not None,
736
+ use_int4_w4a16=use_int4_w4a16,
737
+ use_int8_w8a16=use_int8_w8a16,
738
+ **config,
739
+ )
740
+
741
+ else:
742
+ fused_moe_kernel[grid](
743
+ A,
744
+ B,
745
+ C,
746
+ A_scale,
747
+ B_scale,
748
+ topk_weights,
749
+ sorted_token_ids,
750
+ expert_ids,
751
+ num_tokens_post_padded,
752
+ B.shape[1],
753
+ A.shape[1],
754
+ EM,
755
+ topk_ids.numel(),
756
+ A.stride(0),
757
+ A.stride(1),
758
+ B.stride(0),
759
+ B.stride(2),
760
+ B.stride(1),
761
+ C.stride(1),
762
+ C.stride(2),
763
+ A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
764
+ A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
765
+ B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
766
+ B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
767
+ B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
768
+ 0 if block_shape is None else block_shape[0],
769
+ 0 if block_shape is None else block_shape[1],
770
+ MUL_ROUTED_WEIGHT=mul_routed_weight,
771
+ top_k=top_k,
772
+ compute_type=compute_type,
773
+ use_fp8_w8a8=use_fp8_w8a8,
774
+ use_int8_w8a16=use_int8_w8a16,
775
+ **config,
776
+ )
777
 
778
 
779
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
780
+ def get_config_file_name(
781
+ E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
782
+ ) -> str:
783
  device_name = current_platform.get_device_name().replace(" ", "_")
784
  dtype_selector = "" if not dtype else f",dtype={dtype}"
785
+ block_shape_selector = (
786
+ "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
787
+ )
788
+ return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
789
 
790
 
791
+ # Adapted from: https://github.com/sgl-project/sglang/pull/2628
792
  @functools.lru_cache
793
+ def get_moe_configs(
794
+ E: int,
795
+ N: int,
796
+ dtype: Optional[str],
797
+ block_n: Optional[int] = None,
798
+ block_k: Optional[int] = None,
799
+ ) -> Optional[Dict[int, Any]]:
800
  """
801
  Return optimized configurations for the fused MoE kernel.
802
 
 
808
 
809
  # First look up if an optimized configuration is available in the configs
810
  # directory
811
+ block_shape = [block_n, block_k] if block_n and block_k else None
812
+ json_file_name = get_config_file_name(E, N, dtype, block_shape)
813
 
814
  config_file_path = os.path.join(
815
  os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
816
  )
817
  if os.path.exists(config_file_path):
818
  with open(config_file_path) as f:
819
+ logger.info("Using configuration from %s for MoE layer.", config_file_path)
820
  # If a configuration has been found, return it
821
  return {int(key): val for key, val in json.load(f).items()}
822
 
823
  # If no optimized configuration is available, we will use the default
824
  # configuration
825
+ logger.warning(
826
+ (
827
+ "Using default MoE config. Performance might be sub-optimal! "
828
+ "Config file not found at %s"
829
+ ),
830
+ config_file_path,
831
+ )
832
  return None
833
 
834
 
 
840
  topk: int,
841
  dtype: Optional[str],
842
  is_marlin: bool,
843
+ block_shape: Optional[List[int]] = None,
844
  ) -> Dict[str, int]:
845
+ if dtype == "fp8_w8a8" and block_shape is not None:
846
+ # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
847
+ # BLOCK_SIZE_K must be divisible by block_shape[1]
 
 
 
 
 
848
  config = {
849
+ "BLOCK_SIZE_M": 64,
850
+ "BLOCK_SIZE_N": block_shape[0],
851
+ "BLOCK_SIZE_K": block_shape[1],
852
+ "GROUP_SIZE_M": 32,
853
+ "num_warps": 4,
854
+ "num_stages": 3,
855
  }
856
+ else:
857
+ config = {
858
+ "BLOCK_SIZE_M": 64,
859
+ "BLOCK_SIZE_N": 64,
860
+ "BLOCK_SIZE_K": 32,
861
+ "GROUP_SIZE_M": 8,
862
+ }
863
+ # A heuristic: fused marlin works faster with this config for small M
864
+ if M <= E or (is_marlin and M <= 32):
865
+ config = {
866
+ "BLOCK_SIZE_M": 16,
867
+ "BLOCK_SIZE_N": 32,
868
+ "BLOCK_SIZE_K": 64,
869
+ "GROUP_SIZE_M": 1,
870
+ }
871
  return config
872
 
873
 
 
877
  top_k: int,
878
  dtype: Optional[str],
879
  M: int,
 
880
  is_marlin: bool = False,
881
+ block_shape: Optional[List[int]] = None,
882
  ):
883
+ # from vllm.model_executor.layers.fused_moe import get_config
884
+ # TODO: removed when syncing to vLLM, do we need this?
885
+ # override_config = get_config()
886
+ override_config = None
887
  if override_config:
888
  config = override_config
889
  else:
890
  # First try to load optimal config from the file
891
  E, _, N = w2_shape
892
+ block_n = block_shape[0] if block_shape else 0
893
+ block_k = block_shape[1] if block_shape else 0
894
+ configs = get_moe_configs(E, N, dtype, block_n, block_k)
895
 
896
  if configs:
897
  # If an optimal configuration map has been found, look up the
 
899
  config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
900
  else:
901
  # Else use the default config
902
+ config = get_default_config(
903
+ M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
904
+ )
905
  return config
906
 
907
 
 
937
  return topk_weights, topk_ids
938
 
939
 
940
+ # This is used by the Deepseek-V2 and Deepseek-V3 model
941
+ @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
942
  def grouped_topk(
943
  hidden_states: torch.Tensor,
944
  gating_output: torch.Tensor,
 
946
  renormalize: bool,
947
  num_expert_group: int = 0,
948
  topk_group: int = 0,
949
+ scoring_func: str = "softmax",
950
+ e_score_correction_bias: Optional[torch.Tensor] = None,
951
  ):
952
 
953
  assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
954
 
955
+ if scoring_func == "softmax":
956
+ scores = torch.softmax(gating_output, dim=-1)
957
+ elif scoring_func == "sigmoid":
958
+ scores = gating_output.sigmoid()
959
+ else:
960
+ raise ValueError(f"Unsupported scoring function: {scoring_func}")
961
+
962
+ if e_score_correction_bias is not None:
963
+ # Store original scores before applying correction bias. We use biased
964
+ # scores for expert selection but original scores for routing weights
965
+ original_scores = scores
966
+ scores = scores + e_score_correction_bias.unsqueeze(0)
967
+
968
  num_token = scores.shape[0]
969
  group_scores = (
970
  scores.view(num_token, num_expert_group, -1).max(dim=-1).values
 
980
  .reshape(num_token, -1)
981
  ) # [n, e]
982
  tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
983
+
984
+ if e_score_correction_bias is not None:
985
+ topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
986
+ # Use original unbiased scores for the routing weights
987
+ topk_weights = original_scores.gather(1, topk_ids)
988
+ else:
989
+ topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
990
 
991
  if renormalize:
992
  topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
996
 
997
  def get_config_dtype_str(
998
  dtype: torch.dtype,
999
+ use_int4_w4a16: Optional[bool] = False,
1000
  use_int8_w8a16: Optional[bool] = False,
1001
  use_fp8_w8a8: Optional[bool] = False,
1002
  ):
 
1004
  return "fp8_w8a8"
1005
  elif use_int8_w8a16:
1006
  return "int8_w8a16"
1007
+ elif use_int4_w4a16:
1008
+ return "int4_w8a16"
1009
  elif dtype == torch.float:
1010
  # avoiding cases where kernel fails when float32 MoE
1011
  # use fp16/bfloat16 configs
 
1013
  return None
1014
 
1015
 
1016
+ def inplace_fused_experts(
1017
+ hidden_states: torch.Tensor,
1018
+ w1: torch.Tensor,
1019
+ w2: torch.Tensor,
1020
+ topk_weights: torch.Tensor,
1021
+ topk_ids: torch.Tensor,
1022
+ use_fp8_w8a8: bool = False,
1023
+ use_int8_w8a16: bool = False,
1024
+ use_int4_w4a16: bool = False,
1025
+ w1_scale: Optional[torch.Tensor] = None,
1026
+ w2_scale: Optional[torch.Tensor] = None,
1027
+ w1_zp: Optional[torch.Tensor] = None,
1028
+ w2_zp: Optional[torch.Tensor] = None,
1029
+ a1_scale: Optional[torch.Tensor] = None,
1030
+ a2_scale: Optional[torch.Tensor] = None,
1031
+ block_shape: Optional[List[int]] = None,
1032
+ ) -> None:
1033
+ fused_experts_impl(
1034
+ hidden_states,
1035
+ w1,
1036
+ w2,
1037
+ topk_weights,
1038
+ topk_ids,
1039
+ True,
1040
+ use_fp8_w8a8,
1041
+ use_int8_w8a16,
1042
+ use_int4_w4a16,
1043
+ w1_scale,
1044
+ w2_scale,
1045
+ w1_zp,
1046
+ w2_zp,
1047
+ a1_scale,
1048
+ a2_scale,
1049
+ block_shape,
1050
+ )
1051
+
1052
+
1053
+ def outplace_fused_experts(
1054
+ hidden_states: torch.Tensor,
1055
+ w1: torch.Tensor,
1056
+ w2: torch.Tensor,
1057
+ topk_weights: torch.Tensor,
1058
+ topk_ids: torch.Tensor,
1059
+ use_fp8_w8a8: bool = False,
1060
+ use_int8_w8a16: bool = False,
1061
+ use_int4_w4a16: bool = False,
1062
+ w1_scale: Optional[torch.Tensor] = None,
1063
+ w2_scale: Optional[torch.Tensor] = None,
1064
+ w1_zp: Optional[torch.Tensor] = None,
1065
+ w2_zp: Optional[torch.Tensor] = None,
1066
+ a1_scale: Optional[torch.Tensor] = None,
1067
+ a2_scale: Optional[torch.Tensor] = None,
1068
+ block_shape: Optional[List[int]] = None,
1069
+ ) -> torch.Tensor:
1070
+ return fused_experts_impl(
1071
+ hidden_states,
1072
+ w1,
1073
+ w2,
1074
+ topk_weights,
1075
+ topk_ids,
1076
+ False,
1077
+ use_fp8_w8a8,
1078
+ use_int8_w8a16,
1079
+ use_int4_w4a16,
1080
+ w1_scale,
1081
+ w2_scale,
1082
+ w1_zp,
1083
+ w2_zp,
1084
+ a1_scale,
1085
+ a2_scale,
1086
+ block_shape,
1087
+ )
1088
+
1089
+
1090
  def fused_experts(
1091
  hidden_states: torch.Tensor,
1092
  w1: torch.Tensor,
 
1094
  topk_weights: torch.Tensor,
1095
  topk_ids: torch.Tensor,
1096
  inplace: bool = False,
 
1097
  use_fp8_w8a8: bool = False,
1098
  use_int8_w8a16: bool = False,
1099
+ use_int4_w4a16: bool = False,
1100
+ w1_scale: Optional[torch.Tensor] = None,
1101
+ w2_scale: Optional[torch.Tensor] = None,
1102
+ w1_zp: Optional[torch.Tensor] = None,
1103
+ w2_zp: Optional[torch.Tensor] = None,
1104
+ a1_scale: Optional[torch.Tensor] = None,
1105
+ a2_scale: Optional[torch.Tensor] = None,
1106
+ block_shape: Optional[List[int]] = None,
1107
+ ):
1108
+ if inplace:
1109
+ inplace_fused_experts(
1110
+ hidden_states,
1111
+ w1,
1112
+ w2,
1113
+ topk_weights,
1114
+ topk_ids,
1115
+ use_fp8_w8a8,
1116
+ use_int8_w8a16,
1117
+ use_int4_w4a16,
1118
+ w1_scale,
1119
+ w2_scale,
1120
+ w1_zp,
1121
+ w2_zp,
1122
+ a1_scale,
1123
+ a2_scale,
1124
+ block_shape,
1125
+ )
1126
+ return hidden_states
1127
+ else:
1128
+ return outplace_fused_experts(
1129
+ hidden_states,
1130
+ w1,
1131
+ w2,
1132
+ topk_weights,
1133
+ topk_ids,
1134
+ use_fp8_w8a8,
1135
+ use_int8_w8a16,
1136
+ use_int4_w4a16,
1137
+ w1_scale,
1138
+ w2_scale,
1139
+ w1_zp,
1140
+ w2_zp,
1141
+ a1_scale,
1142
+ a2_scale,
1143
+ block_shape,
1144
+ )
1145
+
1146
+
1147
+ def fused_experts_impl(
1148
+ hidden_states: torch.Tensor,
1149
+ w1: torch.Tensor,
1150
+ w2: torch.Tensor,
1151
+ topk_weights: torch.Tensor,
1152
+ topk_ids: torch.Tensor,
1153
+ inplace: bool = False,
1154
+ use_fp8_w8a8: bool = False,
1155
+ use_int8_w8a16: bool = False,
1156
+ use_int4_w4a16: bool = False,
1157
  w1_scale: Optional[torch.Tensor] = None,
1158
  w2_scale: Optional[torch.Tensor] = None,
1159
+ w1_zp: Optional[torch.Tensor] = None,
1160
+ w2_zp: Optional[torch.Tensor] = None,
1161
  a1_scale: Optional[torch.Tensor] = None,
1162
  a2_scale: Optional[torch.Tensor] = None,
1163
+ block_shape: Optional[List[int]] = None,
1164
  ):
1165
  # Check constraints.
1166
+ if use_int4_w4a16:
1167
+ assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
1168
+ else:
1169
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
1170
+
1171
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
1172
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
1173
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
 
1183
  config_dtype = get_config_dtype_str(
1184
  use_fp8_w8a8=use_fp8_w8a8,
1185
  use_int8_w8a16=use_int8_w8a16,
1186
+ use_int4_w4a16=use_int4_w4a16,
1187
  dtype=hidden_states.dtype,
1188
  )
1189
 
 
1193
  w2.shape,
1194
  topk_ids.shape[1],
1195
  config_dtype,
1196
+ block_shape=block_shape,
1197
  )
1198
 
1199
  config = get_config_func(M)
 
1214
  dtype=hidden_states.dtype,
1215
  )
1216
 
1217
+ if hidden_states.dtype == torch.bfloat16:
1218
+ compute_type = tl.bfloat16
1219
+ elif hidden_states.dtype == torch.float16:
1220
+ compute_type = tl.float16
1221
+ elif hidden_states.dtype == torch.float32:
1222
+ compute_type = tl.float32
1223
+ else:
1224
+ raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
1225
 
1226
  if inplace:
1227
  out_hidden_states = hidden_states
 
1262
  intermediate_cache1,
1263
  a1_scale,
1264
  w1_scale,
1265
+ w1_zp,
1266
  curr_topk_weights,
1267
  curr_topk_ids,
1268
  sorted_token_ids,
 
1274
  compute_type=compute_type,
1275
  use_fp8_w8a8=use_fp8_w8a8,
1276
  use_int8_w8a16=use_int8_w8a16,
1277
+ use_int4_w4a16=use_int4_w4a16,
1278
+ block_shape=block_shape,
1279
  )
1280
 
1281
  ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
1286
  intermediate_cache3,
1287
  a2_scale,
1288
  w2_scale,
1289
+ w2_zp,
1290
  curr_topk_weights,
1291
  curr_topk_ids,
1292
  sorted_token_ids,
 
1298
  compute_type=compute_type,
1299
  use_fp8_w8a8=use_fp8_w8a8,
1300
  use_int8_w8a16=use_int8_w8a16,
1301
+ use_int4_w4a16=use_int4_w4a16,
1302
+ block_shape=block_shape,
1303
  )
1304
 
1305
  ops.moe_sum(
 
1317
  topk: int,
1318
  renormalize: bool,
1319
  inplace: bool = False,
 
1320
  use_grouped_topk: bool = False,
1321
  num_expert_group: Optional[int] = None,
1322
  topk_group: Optional[int] = None,
1323
  custom_routing_function: Optional[Callable] = None,
1324
  use_fp8_w8a8: bool = False,
1325
  use_int8_w8a16: bool = False,
1326
+ use_int4_w4a16: bool = False,
1327
  w1_scale: Optional[torch.Tensor] = None,
1328
  w2_scale: Optional[torch.Tensor] = None,
1329
+ w1_zp: Optional[torch.Tensor] = None,
1330
+ w2_zp: Optional[torch.Tensor] = None,
1331
  a1_scale: Optional[torch.Tensor] = None,
1332
  a2_scale: Optional[torch.Tensor] = None,
1333
+ block_shape: Optional[List[int]] = None,
1334
  ) -> torch.Tensor:
1335
  """
1336
  This function computes a Mixture of Experts (MoE) layer using two sets of
 
1346
  - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
1347
  - inplace (bool): If True, perform the operation in-place.
1348
  Defaults to False.
 
 
1349
  - num_expert_group: Optional[int]: additional parameter for grouped_topk
1350
  - topk_group: Optional[int]: additional parameter for grouped_topk
1351
  - use_grouped_topk: If True, use grouped_topk instead of fused_topk
1352
  note: Deepseekv2 model uses grouped_topk
1353
  - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
1354
  products for w1 and w2. Defaults to False.
1355
+ - use_int8_w8a16 (bool): If True, use matmul of int8 weight and bf16/fp16
1356
+ activation to compute the inner products for w1 and w2.
1357
+ Defaults to False.
1358
+ - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
1359
+ activation to compute the inner products for w1 and w2.
1360
+ Defaults to False.
1361
  - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
1362
  w1.
1363
  - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
1364
  w2.
1365
+ - a1_scale (Optional[torch.Tensor]): Optional scale to be used for
1366
+ a1.
1367
+ - a2_scale (Optional[torch.Tensor]): Optional scale to be used for
1368
+ a2.
1369
+ - block_shape: (Optional[List[int]]): Optional block size for block-wise
1370
+ quantization.
1371
 
1372
  Returns:
1373
  - torch.Tensor: The output tensor after applying the MoE layer.
 
1401
  topk_weights,
1402
  topk_ids,
1403
  inplace=inplace,
 
1404
  use_fp8_w8a8=use_fp8_w8a8,
1405
  use_int8_w8a16=use_int8_w8a16,
1406
+ use_int4_w4a16=use_int4_w4a16,
1407
  w1_scale=w1_scale,
1408
  w2_scale=w2_scale,
1409
+ w1_zp=w1_zp,
1410
+ w2_zp=w2_zp,
1411
  a1_scale=a1_scale,
1412
  a2_scale=a2_scale,
1413
+ block_shape=block_shape,
1414
  )
ext-torch/moe/platforms.py CHANGED
@@ -1,22 +1,32 @@
1
- from typing import Callable, ParamSpec, TypeVar
2
- import os
3
- from functools import lru_cache, wraps
4
 
5
  import torch
6
 
7
  IS_ROCM = torch.version.hip is not None
8
 
9
- class CudaPlatform:
 
 
 
 
 
10
  @classmethod
11
  @lru_cache(maxsize=8)
12
  def get_device_name(cls, device_id: int = 0) -> str:
13
  return torch.cuda.get_device_name(0)
14
 
15
- class RocmPlatform:
 
 
 
 
16
  @classmethod
17
  @lru_cache(maxsize=8)
18
  def get_device_name(cls, device_id: int = 0) -> str:
19
  return torch.cuda.get_device_name(device_id)
20
 
 
 
 
21
 
22
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
 
1
+ from functools import lru_cache
 
 
2
 
3
  import torch
4
 
5
  IS_ROCM = torch.version.hip is not None
6
 
7
+
8
+ class Platform:
9
+ simple_compile_backend: str = "inductor"
10
+
11
+
12
+ class CudaPlatform(Platform):
13
  @classmethod
14
  @lru_cache(maxsize=8)
15
  def get_device_name(cls, device_id: int = 0) -> str:
16
  return torch.cuda.get_device_name(0)
17
 
18
+ def is_rocm(self):
19
+ return False
20
+
21
+
22
+ class RocmPlatform(Platform):
23
  @classmethod
24
  @lru_cache(maxsize=8)
25
  def get_device_name(cls, device_id: int = 0) -> str:
26
  return torch.cuda.get_device_name(device_id)
27
 
28
+ def is_rocm(self):
29
+ return True
30
+
31
 
32
  current_platform = RocmPlatform() if IS_ROCM else CudaPlatform()
ext-torch/torch_binding.cpp CHANGED
@@ -26,6 +26,34 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
26
  " Tensor! num_tokens_post_pad) -> ()");
27
  ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  #ifndef USE_ROCM
30
  ops.def("marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
31
  "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
 
26
  " Tensor! num_tokens_post_pad) -> ()");
27
  ops.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
28
 
29
+ // temporarily adapted from
30
+ // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
31
+ ops.def("sgl_moe_align_block_size(Tensor topk_ids, int num_experts,"
32
+ " int block_size, Tensor! sorted_token_ids,"
33
+ " Tensor! experts_ids,"
34
+ " Tensor! num_tokens_post_pad) -> ()");
35
+ ops.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size);
36
+
37
+ // Compute FP8 quantized tensor for given scaling factor.
38
+ ops.def(
39
+ "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
40
+ "()");
41
+ ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
42
+
43
+ // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
44
+ ops.def(
45
+ "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
46
+ "-> "
47
+ "()");
48
+ ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
49
+
50
+ // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
51
+ ops.def("dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
52
+ "Tensor! scale, Tensor? scale_ub) -> "
53
+ "()");
54
+ ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
55
+ &dynamic_per_token_scaled_fp8_quant);
56
+
57
  #ifndef USE_ROCM
58
  ops.def("marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
59
  "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
ext-torch/torch_binding.h CHANGED
@@ -17,6 +17,22 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
17
  torch::Tensor experts_ids,
18
  torch::Tensor num_tokens_post_pad);
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  #ifndef USE_ROCM
21
  torch::Tensor marlin_gemm_moe(
22
  const torch::Tensor &a, const torch::Tensor &b_q_weights,
 
17
  torch::Tensor experts_ids,
18
  torch::Tensor num_tokens_post_pad);
19
 
20
+ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
21
+ int64_t block_size,
22
+ torch::Tensor sorted_token_ids,
23
+ torch::Tensor experts_ids,
24
+ torch::Tensor num_tokens_post_pad);
25
+
26
+ void static_scaled_fp8_quant(torch::Tensor &out, torch::Tensor const &input,
27
+ torch::Tensor const &scale);
28
+
29
+ void dynamic_scaled_fp8_quant(torch::Tensor &out, torch::Tensor const &input,
30
+ torch::Tensor &scale);
31
+
32
+ void dynamic_per_token_scaled_fp8_quant(
33
+ torch::Tensor &out, torch::Tensor const &input, torch::Tensor &scale,
34
+ std::optional<torch::Tensor> const &scale_ub);
35
+
36
  #ifndef USE_ROCM
37
  torch::Tensor marlin_gemm_moe(
38
  const torch::Tensor &a, const torch::Tensor &b_q_weights,
fp8/amd/hip_float8.h ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #ifdef __HIPCC__
4
+ #include <hip/hip_runtime.h>
5
+ #else
6
+ #include <type_traits>
7
+ #include <stdint.h>
8
+ #include <math.h>
9
+ #include <iostream>
10
+ #endif
11
+
12
+ #include "hip_float8_impl.h"
13
+
14
+ struct alignas(1) hip_fp8 {
15
+ struct from_bits_t {};
16
+ HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() {
17
+ return from_bits_t();
18
+ }
19
+ uint8_t data;
20
+
21
+ hip_fp8() = default;
22
+ HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default;
23
+ HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete;
24
+ explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t)
25
+ : data(v) {}
26
+
27
+ #ifdef __HIP__MI300__
28
+ // NOTE: ON-DEVICE... always optimal bias
29
+ explicit HIP_FP8_DEVICE hip_fp8(float v)
30
+ : data(hip_fp8_impl::to_fp8_from_fp32(v)) {}
31
+
32
+ explicit HIP_FP8_DEVICE hip_fp8(_Float16 v)
33
+ : hip_fp8(static_cast<float>(v)) {}
34
+
35
+ // Host only implementation using s/w simulation
36
+ explicit HIP_FP8_HOST
37
+ #else // __HIP__MI300__
38
+ // both Host and DEVICE for non-MI300 using s/w simulation
39
+ explicit HIP_FP8_HOST_DEVICE
40
+ #endif // __HIP__MI300__
41
+ hip_fp8(float v) {
42
+ data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/,
43
+ true /*clip*/>(v);
44
+ }
45
+
46
+ explicit HIP_FP8_HOST_DEVICE hip_fp8(double v)
47
+ : hip_fp8(static_cast<float>(v)) {}
48
+
49
+ #ifdef __HIP__MI300__
50
+ // upcast using device specific intrinsic
51
+ explicit inline HIP_FP8_DEVICE operator float() const {
52
+ float fval;
53
+ uint32_t i32val = static_cast<uint32_t>(data);
54
+
55
+ // upcast
56
+ asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
57
+ : "=v"(fval)
58
+ : "v"(i32val));
59
+
60
+ return fval;
61
+ }
62
+
63
+ explicit inline HIP_FP8_HOST operator float() const
64
+ #else // __HIP__MI300__
65
+ explicit inline HIP_FP8_HOST_DEVICE operator float() const
66
+ #endif // __HIP__MI300__
67
+ {
68
+ return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>(
69
+ data);
70
+ }
71
+ };
72
+
73
+ namespace std {
74
+ inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); }
75
+ inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); }
76
+ HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; }
77
+ } // namespace std
78
+
79
+ // Special operator overloading
80
+ inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) {
81
+ return os << float(f8);
82
+ }
83
+
84
+ // all + operator overloading with mixed types
85
+ // mixed types, always converts to f32, does computation in f32, and returns
86
+ // float
87
+ inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) {
88
+ return (fa + float(b));
89
+ }
90
+
91
+ inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) {
92
+ return (float(a) + fb);
93
+ }
94
+
95
+ inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) {
96
+ return hip_fp8(float(a) + float(b));
97
+ }
98
+
99
+ inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) {
100
+ return a = hip_fp8(float(a) + float(b));
101
+ }
102
+
103
+ // overloading multiplication, always returns float,
104
+ inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) {
105
+ return float(a) * float(b);
106
+ }
107
+
108
+ inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) {
109
+ return (a * float(b));
110
+ }
111
+
112
+ inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) {
113
+ return (float(a) * b);
114
+ }
115
+
116
+ inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) {
117
+ return ((float)a * float(b));
118
+ }
119
+
120
+ inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) {
121
+ return ((float)a * float(b));
122
+ }
123
+
124
+ // overloading for compare
125
+ inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) {
126
+ return (a.data == b.data);
127
+ }
128
+ inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) {
129
+ return (a.data != b.data);
130
+ }
131
+
132
+ inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) {
133
+ return static_cast<float>(a) >= static_cast<float>(b);
134
+ }
135
+ inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) {
136
+ return static_cast<float>(a) > static_cast<float>(b);
137
+ }
fp8/amd/hip_float8_impl.h ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #if defined(__HIPCC__) && \
4
+ (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
5
+ #define __HIP__MI300__
6
+ #endif
7
+
8
+ #ifdef __HIPCC__
9
+ #define HIP_FP8_HOST_DEVICE __host__ __device__
10
+ #define HIP_FP8_HOST __host__
11
+ #define HIP_FP8_DEVICE __device__
12
+ #else
13
+ #define HIP_FP8_HOST_DEVICE
14
+ #define HIP_FP8_HOST
15
+ #define HIP_FP8_DEVICE
16
+ #endif
17
+
18
+ namespace hip_fp8_impl {
19
+
20
+ #ifdef __HIP__MI300__
21
+ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) {
22
+ uint8_t i8data;
23
+ union {
24
+ float fval;
25
+ uint32_t i32val;
26
+ uint8_t i8val[4]; // NOTE: not endian independent
27
+ } val;
28
+
29
+ uint32_t ival = 0;
30
+ val.fval = v;
31
+
32
+ if ((val.i32val & 0x7F800000) !=
33
+ 0x7F800000) { /// propagate NAN/INF, no clipping
34
+ val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
35
+ }
36
+
37
+ ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
38
+ false); // false -> WORD0
39
+ val.i32val = ival;
40
+ i8data = val.i8val[0];
41
+
42
+ return i8data;
43
+ }
44
+ #endif // __HIP__MI300__
45
+
46
+ HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); }
47
+ #if defined(__HIPCC__) || defined(__CUDA_ARCH__)
48
+ HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); }
49
+ #endif
50
+
51
+ template <int we, int wm, typename T, bool negative_zero_nan, bool clip>
52
+ HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false,
53
+ uint32_t rng = 0) {
54
+ #ifdef __HIPCC__
55
+ constexpr bool is_half = std::is_same<T, _Float16>::value;
56
+ #else
57
+ constexpr bool is_half = false;
58
+ #endif
59
+ constexpr bool is_float = std::is_same<T, float>::value;
60
+ static_assert(wm + we == 7, "wm+we==7");
61
+ static_assert(is_half || is_float, "Only half and float can be cast to f8");
62
+
63
+ const int mfmt = (sizeof(T) == 4) ? 23 : 10;
64
+ uint32_t x;
65
+ if (sizeof(T) == 4) {
66
+ x = reinterpret_cast<uint32_t&>(_x);
67
+ } else {
68
+ x = reinterpret_cast<uint16_t&>(_x);
69
+ }
70
+
71
+ uint32_t head, mantissa;
72
+ int exponent, bias;
73
+ uint32_t sign;
74
+
75
+ if (sizeof(T) == 4) {
76
+ head = x & 0xFF800000;
77
+ mantissa = x & 0x7FFFFF;
78
+ exponent = (head >> 23) & 0xFF;
79
+ sign = head >> 31;
80
+ bias = 127;
81
+ } else {
82
+ head = x & 0xFC00;
83
+ mantissa = x & 0x3FF;
84
+ exponent = (head >> 10) & 0x1F;
85
+ sign = head >> 15;
86
+ bias = 15;
87
+ }
88
+
89
+ uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
90
+
91
+ // Deal with inf and NaNs
92
+ if (negative_zero_nan) {
93
+ if (sizeof(T) == 4) {
94
+ if ((x & 0x7F800000) == 0x7F800000) {
95
+ return 0x80;
96
+ }
97
+ } else {
98
+ // if(__hisinf(x) || __hisnan(x))
99
+ if ((x & 0x7C00) == 0x7C00) {
100
+ return 0x80;
101
+ }
102
+ }
103
+ } else {
104
+ if (sizeof(T) == 4) {
105
+ if ((x & 0x7F800000) == 0x7F800000) {
106
+ return signed_inf + (mantissa != 0 ? 1 : 0);
107
+ }
108
+ } else {
109
+ if ((x & 0x7C00) == 0x7C00) {
110
+ return signed_inf + (mantissa != 0 ? 1 : 0);
111
+ }
112
+ }
113
+ }
114
+ if (x == 0) {
115
+ return 0;
116
+ }
117
+
118
+ // First need to check if it is normal or denorm as there is a difference of
119
+ // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
120
+ // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
121
+ // to mantissa and truncate. And for RNE, no need to add rng. Then probably
122
+ // need to check whether there is carry and adjust exponent and mantissa again
123
+
124
+ // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
125
+ // bits
126
+ const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
127
+ const int f8_denormal_act_exponent =
128
+ 1 - f8_bias; // actual exponent of f8 denormal
129
+ // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
130
+ // f8_exponent is the converted f8 exponent with bias encoding
131
+ // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
132
+ // the difference needs to be adjusted and mantissa shifted
133
+ int act_exponent, f8_exponent, exponent_diff;
134
+
135
+ if (exponent == 0) { // fp32/fp16 is in denormal.
136
+ /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
137
+ mostly concern fp16 here. In this case, f8 is usually in denormal. But there
138
+ could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
139
+ exponent bias 16. It means that there are some numbers in fp16 denormal but they
140
+ are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
141
+ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
142
+ (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
143
+ act_exponent = exponent - bias + 1;
144
+ exponent_diff =
145
+ f8_denormal_act_exponent -
146
+ act_exponent; // actual exponent is exponent-bias+1 as it is denormal
147
+ } else { // fp32/fp16 is normal with implicit 1
148
+ act_exponent = exponent - bias;
149
+ if (act_exponent <= f8_denormal_act_exponent) {
150
+ /* This is the case where fp32/fp16 is normal but it is in f8 denormal
151
+ range. For example fp8 nanoo mode, denormal exponent is -7, but if the
152
+ fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
153
+ Therefore it needs to be adjust to -6 and mantissa shift right by 1.
154
+ So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
155
+ exponent_diff = f8_denormal_act_exponent - act_exponent;
156
+ } else { // both fp32/fp16 and f8 are in normal range
157
+ exponent_diff = 0; // exponent_diff=0 does not mean there is no
158
+ // difference for this case, act_exponent could be
159
+ // larger. Just that it does not need shift mantissa
160
+ }
161
+ mantissa += (1 << mfmt); // Add the implicit 1 into mantissa
162
+ }
163
+
164
+ bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) ==
165
+ static_cast<uint32_t>(1 << (mfmt - wm + exponent_diff - 1));
166
+ /* This part is a bit tricky. The judgment of whether it is a tie needs to be
167
+ done before we shift right as shift right could rip off some residual part
168
+ and make something not midpoint look like midpoint. For example, the fp16
169
+ number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
170
+ shift right by 4 bits, it would look like midpoint.
171
+ */
172
+
173
+ if (exponent_diff > 0) {
174
+ mantissa >>= exponent_diff;
175
+ } else if (exponent_diff == -1) {
176
+ mantissa <<= -exponent_diff;
177
+ }
178
+ bool implicit_one = mantissa & (1 << mfmt);
179
+ // if there is no implicit 1, it means the f8 is denormal and need to adjust
180
+ // to denorm exponent
181
+ f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ +
182
+ f8_bias - (implicit_one ? 0 : 1);
183
+
184
+ // Now we have the exponent and mantissa adjusted
185
+ uint32_t drop_mask = (1 << (mfmt - wm)) - 1;
186
+ bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit
187
+ // that is not truncated is 1
188
+ mantissa +=
189
+ (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) &
190
+ drop_mask;
191
+
192
+ // Now we deal with overflow
193
+ if (f8_exponent == 0) {
194
+ if ((1 << mfmt) & mantissa) {
195
+ f8_exponent = 1; // denormal overflow to become normal, promote exponent
196
+ }
197
+ } else {
198
+ if ((1 << (mfmt + 1)) & mantissa) {
199
+ mantissa >>= 1;
200
+ f8_exponent++;
201
+ }
202
+ }
203
+
204
+ mantissa >>= (mfmt - wm);
205
+
206
+ // above range: quantize to maximum possible float of the same sign
207
+ const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
208
+ if (f8_exponent > max_exp) {
209
+ if (clip) {
210
+ mantissa = (1 << wm) - 1;
211
+ f8_exponent = max_exp;
212
+ } else {
213
+ return signed_inf;
214
+ }
215
+ }
216
+
217
+ if (f8_exponent == 0 && mantissa == 0) {
218
+ return negative_zero_nan ? 0 : (sign << 7);
219
+ }
220
+ mantissa &= (1 << wm) - 1;
221
+ return (sign << 7) | (f8_exponent << wm) | mantissa;
222
+ }
223
+
224
+ template <int we, int wm, typename T = float, bool negative_zero_nan = true>
225
+ inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) {
226
+ #ifdef __HIPCC__
227
+ constexpr bool is_half = std::is_same<T, _Float16>::value;
228
+ #else
229
+ constexpr bool is_half = false;
230
+ #endif
231
+ constexpr bool is_float = std::is_same<T, float>::value;
232
+ static_assert(is_half || is_float, "only half and float are supported");
233
+
234
+ constexpr int weo = is_half ? 5 : 8;
235
+ constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7);
236
+
237
+ T fInf, fNegInf, fNaN, fNeg0;
238
+
239
+ #ifdef __HIPCC__
240
+ if (is_half) {
241
+ const uint16_t ihInf = 0x7C00;
242
+ const uint16_t ihNegInf = 0xFC00;
243
+ const uint16_t ihNaN = 0x7C01;
244
+ const uint16_t ihNeg0 = 0x8000;
245
+ fInf = reinterpret_cast<const _Float16&>(ihInf);
246
+ fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
247
+ fNaN = reinterpret_cast<const _Float16&>(ihNaN);
248
+ fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
249
+ } else
250
+ #endif
251
+ if (is_float) {
252
+ const uint32_t ifInf = 0x7F800000;
253
+ const uint32_t ifNegInf = 0xFF800000;
254
+ const uint32_t ifNaN = 0x7F800001;
255
+ const uint32_t ifNeg0 = 0x80000000;
256
+ fInf = reinterpret_cast<const float&>(ifInf);
257
+ fNegInf = reinterpret_cast<const float&>(ifNegInf);
258
+ fNaN = reinterpret_cast<const float&>(ifNaN);
259
+ fNeg0 = reinterpret_cast<const float&>(ifNeg0);
260
+ }
261
+
262
+ if (x == 0) {
263
+ return 0;
264
+ }
265
+
266
+ uint32_t sign = x >> 7;
267
+ uint32_t mantissa = x & ((1 << wm) - 1);
268
+ int exponent = (x & 0x7F) >> wm;
269
+ if (negative_zero_nan) {
270
+ if (x == 0x80) {
271
+ return fNaN;
272
+ }
273
+ } else {
274
+ if (x == 0x80) {
275
+ return fNeg0;
276
+ }
277
+ if (exponent == ((1 << we) - 1)) {
278
+ return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
279
+ }
280
+ }
281
+ typename std::conditional<sizeof(T) == 2, uint16_t, uint32_t>::type retval;
282
+ if (we == 5 && is_half && !negative_zero_nan) {
283
+ retval = x << 8;
284
+ return reinterpret_cast<const T&>(retval);
285
+ }
286
+
287
+ const int exp_low_cutoff =
288
+ (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
289
+
290
+ // subnormal input
291
+ if (exponent == 0) {
292
+ // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
293
+ int sh = 1 + clz(mantissa) - (32 - wm);
294
+ mantissa <<= sh;
295
+ exponent += 1 - sh;
296
+ mantissa &= ((1 << wm) - 1);
297
+ }
298
+ exponent += exp_low_cutoff - 1;
299
+ mantissa <<= wmo - wm;
300
+
301
+ // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
302
+ if (exponent <= 0) {
303
+ mantissa |= 1 << wmo;
304
+ mantissa >>= 1 - exponent;
305
+ exponent = 0;
306
+ }
307
+
308
+ if (sizeof(T) == 2) {
309
+ retval = (sign << 15) | (exponent << 10) | mantissa;
310
+ } else {
311
+ retval = (sign << 31) | (exponent << 23) | mantissa;
312
+ }
313
+ return reinterpret_cast<const T&>(retval);
314
+ }
315
+
316
+ } // namespace hip_fp8_impl
fp8/common.cu ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "common.cuh"
2
+ #include "dispatch_utils.h"
3
+
4
+ #include <c10/cuda/CUDAGuard.h>
5
+
6
+ #ifndef USE_ROCM
7
+ #include <cub/cub.cuh>
8
+ #else
9
+ #include <hipcub/hipcub.hpp>
10
+ #endif
11
+
12
+ namespace vllm {
13
+
14
+ template <typename scalar_t>
15
+ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
16
+ const scalar_t* __restrict__ input,
17
+ const float* __restrict__ scale,
18
+ int64_t num_elems) {
19
+ int tid = blockDim.x * blockIdx.x + threadIdx.x;
20
+
21
+ // Invert the scale so that we can use multiplications to avoid expensive
22
+ // division.
23
+ const float inverted_scale = 1.0f / (*scale);
24
+ scaled_fp8_conversion_vec<scalar_t, true>(
25
+ out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
26
+ }
27
+
28
+ template <typename scalar_t>
29
+ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
30
+ FP8_TYPE* __restrict__ out, float* __restrict__ scale,
31
+ scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
32
+ const int hidden_size) {
33
+ float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
34
+
35
+ int const tid = threadIdx.x;
36
+ int const token_idx = blockIdx.x;
37
+
38
+ // Use int64 to avoid overflowing an int32 when calculating this offset
39
+ int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
40
+ scalar_t const* __restrict__ token_input = &input[offset];
41
+ FP8_TYPE* __restrict__ token_output = &out[offset];
42
+
43
+ // For vectorization, token_input and token_output pointers need to be
44
+ // aligned at 8-byte and 4-byte addresses respectively.
45
+ bool const can_vectorize = hidden_size % 4 == 0;
46
+
47
+ float absmax_val = 0.0f;
48
+ if (can_vectorize) {
49
+ absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
50
+ } else {
51
+ for (int i = tid; i < hidden_size; i += blockDim.x) {
52
+ float const x = static_cast<float>(token_input[i]);
53
+ absmax_val = max(absmax_val, fabs(x));
54
+ }
55
+ }
56
+
57
+ using BlockReduce = cub::BlockReduce<float, 1024>;
58
+ __shared__ typename BlockReduce::TempStorage reduceStorage;
59
+ float const block_absmax_val_maybe =
60
+ BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
61
+ __shared__ float token_scale;
62
+ if (tid == 0) {
63
+ if (scale_ub) {
64
+ token_scale = min(block_absmax_val_maybe, *scale_ub);
65
+ } else {
66
+ token_scale = block_absmax_val_maybe;
67
+ }
68
+ // token scale computation
69
+ token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
70
+ scale[token_idx] = token_scale;
71
+ }
72
+ __syncthreads();
73
+
74
+ // Note that we don't use inverted scales so we can match FBGemm impl.
75
+ if (can_vectorize) {
76
+ scaled_fp8_conversion_vec<scalar_t, false>(
77
+ token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
78
+ } else {
79
+ for (int i = tid; i < hidden_size; i += blockDim.x) {
80
+ token_output[i] = scaled_fp8_conversion<false>(
81
+ static_cast<float>(token_input[i]), token_scale);
82
+ }
83
+ }
84
+ }
85
+
86
+ } // namespace vllm
87
+
88
+ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
89
+ torch::Tensor const& input, // [..., d]
90
+ torch::Tensor const& scale) // [1]
91
+ {
92
+ int64_t num_tokens = input.numel() / input.size(-1);
93
+ int64_t num_elems = input.numel();
94
+ dim3 grid(num_tokens);
95
+ dim3 block(1024);
96
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
97
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
98
+ VLLM_DISPATCH_FLOATING_TYPES(
99
+ input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
100
+ vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
101
+ out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
102
+ scale.data_ptr<float>(), num_elems);
103
+ });
104
+ }
105
+
106
+ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
107
+ torch::Tensor const& input, // [..., d]
108
+ torch::Tensor& scale) // [1]
109
+ {
110
+ int64_t num_tokens = input.numel() / input.size(-1);
111
+ int64_t num_elems = input.numel();
112
+ dim3 grid(num_tokens);
113
+ dim3 block(1024);
114
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
115
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
116
+ VLLM_DISPATCH_FLOATING_TYPES(
117
+ input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
118
+ vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
119
+ scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
120
+ vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
121
+ out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
122
+ scale.data_ptr<float>(), num_elems);
123
+ });
124
+ }
125
+
126
+ void dynamic_per_token_scaled_fp8_quant(
127
+ torch::Tensor& out, // [..., d]
128
+ torch::Tensor const& input, // [..., d]
129
+ torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
130
+ TORCH_CHECK(input.is_contiguous());
131
+ TORCH_CHECK(out.is_contiguous());
132
+
133
+ int const hidden_size = input.size(-1);
134
+ int const num_tokens = input.numel() / hidden_size;
135
+ dim3 const grid(num_tokens);
136
+ dim3 const block(std::min(hidden_size, 1024));
137
+
138
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
139
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
140
+ VLLM_DISPATCH_FLOATING_TYPES(
141
+ input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
142
+ vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
143
+ <<<grid, block, 0, stream>>>(
144
+ out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
145
+ input.data_ptr<scalar_t>(),
146
+ scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
147
+ hidden_size);
148
+ });
149
+ }
fp8/common.cuh ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include "vectorization.cuh"
4
+
5
+ #include <c10/core/ScalarType.h>
6
+ #include <cmath>
7
+
8
+ #ifndef USE_ROCM
9
+ #include <c10/util/Float8_e4m3fn.h>
10
+ using FP8_TYPE = c10::Float8_e4m3fn;
11
+ C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
12
+ std::numeric_limits<FP8_TYPE>::max();
13
+ #else
14
+ #include "amd/hip_float8.h"
15
+ #include <c10/util/Float8_e4m3fnuz.h>
16
+ using FP8_TYPE = c10::Float8_e4m3fnuz;
17
+ // Using the default max value from pytorch (240.0) will cause accuracy
18
+ // issue when running dynamic quantization. Here use 224.0f for rocm.
19
+ constexpr auto FP8_E4M3_MAX = 224.0f;
20
+ #endif
21
+ constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
22
+
23
+ namespace vllm {
24
+
25
+ __device__ __forceinline__ float atomicMaxFloat(float *addr, float value) {
26
+ float old;
27
+ old = (value >= 0)
28
+ ? __int_as_float(atomicMax((int *)addr, __float_as_int(value)))
29
+ : __uint_as_float(
30
+ atomicMin((unsigned int *)addr, __float_as_uint(value)));
31
+
32
+ return old;
33
+ }
34
+
35
+ template <bool is_scale_inverted>
36
+ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
37
+ float const scale) {
38
+ float x = 0.0f;
39
+ if constexpr (is_scale_inverted) {
40
+ x = val * scale;
41
+ } else {
42
+ x = val / scale;
43
+ }
44
+
45
+ float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
46
+ #ifndef USE_ROCM
47
+ return static_cast<c10::Float8_e4m3fn>(r);
48
+ #else
49
+ // Use hardware cvt instruction for fp8 on rocm
50
+ return c10::Float8_e4m3fnuz(hip_fp8(r).data,
51
+ c10::Float8_e4m3fnuz::from_bits());
52
+ #endif
53
+ }
54
+
55
+ // Compute the absolute maximum m of the input tensor and store
56
+ // m / float8_e4m3::max() in *scale. Each thread block performs a
57
+ // reduction tree and the memory in scale is atomically updated.
58
+ // So to get the right answer, *scale needs to be initialized to
59
+ // a value <= 0.0 and we need to wait for all thread blocks to
60
+ // finish before consuming *scale.
61
+ template <typename scalar_t>
62
+ __global__ void segmented_max_reduction(float *__restrict__ scale,
63
+ const scalar_t *__restrict__ input,
64
+ int64_t num_elems) {
65
+ __shared__ float cache[1024];
66
+ int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
67
+
68
+ // First store maximum for all values processes by
69
+ // the current thread in cache[threadIdx.x]
70
+ scalar_t tmp = 0.0;
71
+ while (i < num_elems) {
72
+ float x = static_cast<float>(input[i]);
73
+ tmp = max(tmp, fabs(x));
74
+ i += blockDim.x * gridDim.x;
75
+ }
76
+ cache[threadIdx.x] = tmp;
77
+
78
+ __syncthreads();
79
+
80
+ // Now perform parallel reduction within the thread block
81
+ int ib = blockDim.x / 2;
82
+ while (ib != 0) {
83
+ if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
84
+ cache[threadIdx.x] = cache[threadIdx.x + ib];
85
+ }
86
+ __syncthreads();
87
+ ib /= 2;
88
+ }
89
+ // Finally, since cache[0] contains the maximum for this thread block,
90
+ // atomically write the max to the target location
91
+ if (threadIdx.x == 0) {
92
+ atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
93
+ }
94
+ }
95
+
96
+ template <typename scalar_t>
97
+ __device__ float thread_max_vec(scalar_t const *__restrict__ input,
98
+ int64_t const num_elems, int const tid,
99
+ int const step) {
100
+ // Vectorized input/output to better utilize memory bandwidth.
101
+ vec4_t<scalar_t> const *vectorized_in =
102
+ reinterpret_cast<vec4_t<scalar_t> const *>(input);
103
+
104
+ int64_t const num_vec_elems = num_elems >> 2;
105
+ float absmax_val = 0.0f;
106
+
107
+ #pragma unroll 4
108
+ for (int64_t i = tid; i < num_vec_elems; i += step) {
109
+ vec4_t<scalar_t> in_vec = vectorized_in[i];
110
+ absmax_val = max(absmax_val, fabs(in_vec.x));
111
+ absmax_val = max(absmax_val, fabs(in_vec.y));
112
+ absmax_val = max(absmax_val, fabs(in_vec.z));
113
+ absmax_val = max(absmax_val, fabs(in_vec.w));
114
+ }
115
+
116
+ // Handle the remaining elements if num_elems is not divisible by 4
117
+ for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
118
+ absmax_val = max(absmax_val, fabs(input[i]));
119
+ }
120
+
121
+ return absmax_val;
122
+ }
123
+
124
+ template <typename scalar_t, bool is_scale_inverted>
125
+ __device__ void scaled_fp8_conversion_vec(FP8_TYPE *__restrict__ out,
126
+ scalar_t const *__restrict__ input,
127
+ float const scale,
128
+ int64_t const num_elems,
129
+ int const tid, int const step) {
130
+ using float8x4_t = q8x4_t<FP8_TYPE>;
131
+ // Vectorized input/output to better utilize memory bandwidth.
132
+ auto const *vectorized_in = reinterpret_cast<vec4_t<scalar_t> const *>(input);
133
+ auto *vectorized_out = reinterpret_cast<float8x4_t *>(out);
134
+
135
+ int64_t const num_vec_elems = num_elems >> 2;
136
+
137
+ #pragma unroll 4
138
+ for (int64_t i = tid; i < num_vec_elems; i += step) {
139
+ vec4_t<scalar_t> in_vec = vectorized_in[i];
140
+ float8x4_t out_vec;
141
+
142
+ out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
143
+ static_cast<float>(in_vec.x), scale);
144
+ out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
145
+ static_cast<float>(in_vec.y), scale);
146
+ out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
147
+ static_cast<float>(in_vec.z), scale);
148
+ out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
149
+ static_cast<float>(in_vec.w), scale);
150
+ vectorized_out[i] = out_vec;
151
+ }
152
+
153
+ // Handle the remaining elements if num_elems is not divisible by 4
154
+ for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
155
+ out[i] = scaled_fp8_conversion<is_scale_inverted>(
156
+ static_cast<float>(input[i]), scale);
157
+ }
158
+ }
159
+
160
+ } // namespace vllm
fp8/vectorization.cuh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+ /**
3
+ * __device__ datatypes vectorized by 4
4
+ */
5
+
6
+ // Include both AMD and NVIDIA fp8 types to avoid circular import
7
+ // TODO(luka/varun) use FP8_TYPE instead after refactoring
8
+ #include <c10/util/Float8_e4m3fnuz.h>
9
+ #include <c10/util/Float8_e4m3fn.h>
10
+
11
+ namespace vllm {
12
+
13
+ // Vectorization containers
14
+ template <typename scalar_t>
15
+ struct __align__(8) vec4_t {
16
+ scalar_t x;
17
+ scalar_t y;
18
+ scalar_t z;
19
+ scalar_t w;
20
+ };
21
+
22
+ template <typename quant_type_t>
23
+ struct __align__(4) q8x4_t {
24
+ static_assert(std::is_same_v<quant_type_t, int8_t> ||
25
+ std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
26
+ std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
27
+ quant_type_t x;
28
+ quant_type_t y;
29
+ quant_type_t z;
30
+ quant_type_t w;
31
+ };
32
+
33
+ } // namespace vllm
marlin-moe/marlin_kernels/marlin_moe_kernel.h CHANGED
@@ -138,8 +138,8 @@ __device__ inline FragB dequant<vllm::kU4B8.id()>(int q) {
138
  const int HI = 0x00f000f0;
139
  const int EX = 0x64006400;
140
  // Guarantee that the `(a & b) | c` operations are LOP3s.
141
- int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
142
- int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
143
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
144
  // directly into `SUB` and `ADD`.
145
  const int SUB = 0x64086408;
@@ -182,8 +182,8 @@ __device__ inline FragB dequant<vllm::kU4.id()>(int q) {
182
  const int HI = 0x00f000f0;
183
  const int EX = 0x64006400;
184
  // Guarantee that the `(a & b) | c` operations are LOP3s.
185
- int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
186
- int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
187
 
188
  const int SUB = 0x64006400;
189
  const int MUL = 0x2c002c00;
 
138
  const int HI = 0x00f000f0;
139
  const int EX = 0x64006400;
140
  // Guarantee that the `(a & b) | c` operations are LOP3s.
141
+ int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
142
+ int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
143
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
144
  // directly into `SUB` and `ADD`.
145
  const int SUB = 0x64086408;
 
182
  const int HI = 0x00f000f0;
183
  const int EX = 0x64006400;
184
  // Guarantee that the `(a & b) | c` operations are LOP3s.
185
+ int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX);
186
+ int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX);
187
 
188
  const int SUB = 0x64006400;
189
  const int MUL = 0x2c002c00;
moe/moe_align_sum_kernels.cu CHANGED
@@ -21,7 +21,7 @@ __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
21
  }
22
  } // namespace
23
 
24
- template <typename scalar_t>
25
  __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
26
  int32_t* sorted_token_ids,
27
  int32_t* expert_ids,
@@ -32,12 +32,96 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
32
  const size_t start_idx = threadIdx.x * tokens_per_thread;
33
 
34
  extern __shared__ int32_t shared_mem[];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- int32_t* tokens_cnts =
37
- shared_mem; // 2d tensor with shape (blockDim.x + 1, num_experts)
38
- int32_t* cumsum =
39
- shared_mem +
40
- (blockDim.x + 1) * num_experts; // 1d tensor with shape (num_experts + 1)
 
 
 
 
 
 
41
 
42
  for (int i = 0; i < num_experts; ++i) {
43
  tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
@@ -113,6 +197,72 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
113
  }
114
  }
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  template <typename scalar_t, int TOPK>
117
  __global__ void moe_sum_kernel(
118
  scalar_t* __restrict__ out, // [..., d]
@@ -137,24 +287,113 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
137
  torch::Tensor experts_ids,
138
  torch::Tensor num_tokens_post_pad) {
139
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  VLLM_DISPATCH_INTEGRAL_TYPES(
141
- topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
142
  // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
143
  // tensors
144
- const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
145
- const int32_t shared_mem =
146
- ((num_thread + 1) * num_experts + (num_experts + 1)) *
147
- sizeof(int32_t);
148
-
149
- // set dynamic shared mem
150
- auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
151
- AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
152
- (void*)kernel, shared_mem));
153
- kernel<<<1, num_thread, shared_mem, stream>>>(
154
  topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
155
  experts_ids.data_ptr<int32_t>(),
156
  num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
157
- topk_ids.numel());
158
  });
159
  }
160
 
 
21
  }
22
  } // namespace
23
 
24
+ template <typename scalar_t, typename token_cnts_t>
25
  __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
26
  int32_t* sorted_token_ids,
27
  int32_t* expert_ids,
 
32
  const size_t start_idx = threadIdx.x * tokens_per_thread;
33
 
34
  extern __shared__ int32_t shared_mem[];
35
+ int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1)
36
+ token_cnts_t* tokens_cnts =
37
+ (token_cnts_t*)(shared_mem + num_experts +
38
+ 1); // 2d tensor with shape (blockDim.x + 1, num_experts)
39
+
40
+ for (int i = 0; i < num_experts; ++i) {
41
+ tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
42
+ }
43
+
44
+ /**
45
+ * In the first step we compute token_cnts[thread_index + 1][expert_index],
46
+ * which counts how many tokens in the token shard of thread_index are
47
+ * assigned to expert expert_index.
48
+ */
49
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
50
+ ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
51
+ }
52
+
53
+ __syncthreads();
54
+
55
+ // For each expert we accumulate the token counts from the different threads.
56
+ if (threadIdx.x < num_experts) {
57
+ tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
58
+ for (int i = 1; i <= blockDim.x; ++i) {
59
+ tokens_cnts[index(num_experts, i, threadIdx.x)] +=
60
+ tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
61
+ }
62
+ }
63
+
64
+ __syncthreads();
65
+
66
+ // We accumulate the token counts of all experts in thread 0.
67
+ if (threadIdx.x == 0) {
68
+ cumsum[0] = 0;
69
+ for (int i = 1; i <= num_experts; ++i) {
70
+ cumsum[i] = cumsum[i - 1] +
71
+ CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
72
+ block_size) *
73
+ block_size;
74
+ }
75
+ *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
76
+ }
77
+
78
+ __syncthreads();
79
+
80
+ /**
81
+ * For each expert, each thread processes the tokens of the corresponding
82
+ * blocks and stores the corresponding expert_id for each block.
83
+ */
84
+ if (threadIdx.x < num_experts) {
85
+ for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
86
+ i += block_size) {
87
+ expert_ids[i / block_size] = threadIdx.x;
88
+ }
89
+ }
90
+
91
+ /**
92
+ * Each thread processes a token shard, calculating the index of each token
93
+ * after sorting by expert number. Given the example topk_ids =
94
+ * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
95
+ * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
96
+ * padding value(preset in python).
97
+ */
98
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
99
+ int32_t expert_id = topk_ids[i];
100
+ /** The cumsum[expert_id] stores the starting index of the tokens that the
101
+ * expert with expert_id needs to process, and
102
+ * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
103
+ * processed by the expert with expert_id within the current thread's token
104
+ * shard.
105
+ */
106
+ int32_t rank_post_pad =
107
+ tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
108
+ cumsum[expert_id];
109
+ sorted_token_ids[rank_post_pad] = i;
110
+ ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
111
+ }
112
+ }
113
 
114
+ // TODO(simon): this is temporarily adapted from
115
+ // https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
116
+ // we did this to unblock Deepseek V3 but there should be a better
117
+ // implementation to manage shared memory.
118
+ template <typename scalar_t>
119
+ __global__ void moe_align_block_size_global_mem_kernel(
120
+ scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
121
+ int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
122
+ int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
123
+ const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
124
+ const size_t start_idx = threadIdx.x * tokens_per_thread;
125
 
126
  for (int i = 0; i < num_experts; ++i) {
127
  tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
 
197
  }
198
  }
199
 
200
+ // taken from
201
+ // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a
202
+ template <typename scalar_t>
203
+ __global__ void sgl_moe_align_block_size_kernel(
204
+ scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
205
+ int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
206
+ int32_t block_size, size_t numel, int32_t* cumsum) {
207
+ __shared__ int32_t shared_counts[32][8];
208
+ __shared__ int32_t local_offsets[256];
209
+
210
+ const int warp_id = threadIdx.x / WARP_SIZE;
211
+ const int lane_id = threadIdx.x % WARP_SIZE;
212
+ const int experts_per_warp = 8;
213
+ const int my_expert_start = warp_id * experts_per_warp;
214
+
215
+ for (int i = 0; i < experts_per_warp; ++i) {
216
+ if (my_expert_start + i < num_experts) {
217
+ shared_counts[warp_id][i] = 0;
218
+ }
219
+ }
220
+
221
+ const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
222
+ const size_t start_idx = threadIdx.x * tokens_per_thread;
223
+
224
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
225
+ int expert_id = topk_ids[i];
226
+ int warp_idx = expert_id / experts_per_warp;
227
+ int expert_offset = expert_id % experts_per_warp;
228
+ atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
229
+ }
230
+
231
+ __syncthreads();
232
+
233
+ if (threadIdx.x == 0) {
234
+ cumsum[0] = 0;
235
+ for (int i = 1; i <= num_experts; ++i) {
236
+ int expert_count = 0;
237
+ int warp_idx = (i - 1) / experts_per_warp;
238
+ int expert_offset = (i - 1) % experts_per_warp;
239
+ expert_count = shared_counts[warp_idx][expert_offset];
240
+
241
+ cumsum[i] =
242
+ cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
243
+ }
244
+ *total_tokens_post_pad = cumsum[num_experts];
245
+ }
246
+
247
+ __syncthreads();
248
+
249
+ if (threadIdx.x < num_experts) {
250
+ for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
251
+ i += block_size) {
252
+ expert_ids[i / block_size] = threadIdx.x;
253
+ }
254
+ local_offsets[threadIdx.x] = cumsum[threadIdx.x];
255
+ }
256
+
257
+ __syncthreads();
258
+
259
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
260
+ int32_t expert_id = topk_ids[i];
261
+ int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1);
262
+ sorted_token_ids[rank_post_pad] = i;
263
+ }
264
+ }
265
+
266
  template <typename scalar_t, int TOPK>
267
  __global__ void moe_sum_kernel(
268
  scalar_t* __restrict__ out, // [..., d]
 
287
  torch::Tensor experts_ids,
288
  torch::Tensor num_tokens_post_pad) {
289
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
290
+
291
+ int device_max_shared_mem;
292
+ auto dev = topk_ids.get_device();
293
+ cudaDeviceGetAttribute(&device_max_shared_mem,
294
+ cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
295
+
296
+ const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
297
+ const int32_t shared_mem_i32 =
298
+ ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
299
+ const int32_t shared_mem_i16 =
300
+ ((num_thread + 1) * num_experts) * sizeof(uint16_t) +
301
+ (num_experts + 1) * sizeof(int32_t);
302
+
303
+ bool use_global_memory = false;
304
+ bool use_i16 = false; // Use uint16_t for shared memory token counts
305
+ if (shared_mem_i32 < device_max_shared_mem) {
306
+ // Do nothing in this case. We're all set to use int32_t token counts
307
+ } else if (shared_mem_i16 < device_max_shared_mem &&
308
+ topk_ids.numel() <= 65535) {
309
+ // when nelements of topk_ids is smaller than 65535 (max value of uint16),
310
+ // element value of token_cnts would also smaller than 65535,
311
+ // so we can use uint16 as dtype of token_cnts
312
+ use_i16 = true;
313
+ } else {
314
+ use_global_memory = true;
315
+ }
316
+
317
+ if (use_global_memory) {
318
+ VLLM_DISPATCH_INTEGRAL_TYPES(
319
+ topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
320
+ // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
321
+ // tensors
322
+ const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
323
+
324
+ auto options_int = torch::TensorOptions()
325
+ .dtype(torch::kInt)
326
+ .device(topk_ids.device());
327
+ torch::Tensor token_cnts_buffer =
328
+ torch::empty({(num_experts + 1) * num_experts}, options_int);
329
+ torch::Tensor cumsum_buffer =
330
+ torch::empty({num_experts + 1}, options_int);
331
+
332
+ auto kernel =
333
+ vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
334
+ kernel<<<1, num_thread, 0, stream>>>(
335
+ topk_ids.data_ptr<scalar_t>(),
336
+ sorted_token_ids.data_ptr<int32_t>(),
337
+ experts_ids.data_ptr<int32_t>(),
338
+ num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
339
+ topk_ids.numel(), token_cnts_buffer.data_ptr<int32_t>(),
340
+ cumsum_buffer.data_ptr<int32_t>());
341
+ });
342
+ } else if (use_i16) {
343
+ VLLM_DISPATCH_INTEGRAL_TYPES(
344
+ topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
345
+ // set dynamic shared mem
346
+ auto kernel =
347
+ vllm::moe::moe_align_block_size_kernel<scalar_t, uint16_t>;
348
+ AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
349
+ (void*)kernel, shared_mem_i16));
350
+ kernel<<<1, num_thread, shared_mem_i16, stream>>>(
351
+ topk_ids.data_ptr<scalar_t>(),
352
+ sorted_token_ids.data_ptr<int32_t>(),
353
+ experts_ids.data_ptr<int32_t>(),
354
+ num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
355
+ topk_ids.numel());
356
+ });
357
+ } else {
358
+ VLLM_DISPATCH_INTEGRAL_TYPES(
359
+ topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
360
+ auto kernel =
361
+ vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
362
+ AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
363
+ (void*)kernel, shared_mem_i32));
364
+ kernel<<<1, num_thread, shared_mem_i32, stream>>>(
365
+ topk_ids.data_ptr<scalar_t>(),
366
+ sorted_token_ids.data_ptr<int32_t>(),
367
+ experts_ids.data_ptr<int32_t>(),
368
+ num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
369
+ topk_ids.numel());
370
+ });
371
+ }
372
+ }
373
+
374
+ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
375
+ int64_t block_size,
376
+ torch::Tensor sorted_token_ids,
377
+ torch::Tensor experts_ids,
378
+ torch::Tensor num_tokens_post_pad) {
379
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
380
  VLLM_DISPATCH_INTEGRAL_TYPES(
381
+ topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
382
  // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
383
  // tensors
384
+ auto options_int =
385
+ torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
386
+ // torch::Tensor token_cnts_buffer =
387
+ // torch::empty({(num_experts + 1) * num_experts}, options_int);
388
+ torch::Tensor cumsum_buffer =
389
+ torch::empty({num_experts + 1}, options_int);
390
+
391
+ auto kernel = vllm::moe::sgl_moe_align_block_size_kernel<scalar_t>;
392
+ kernel<<<1, 1024, 0, stream>>>(
 
393
  topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
394
  experts_ids.data_ptr<int32_t>(),
395
  num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
396
+ topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
397
  });
398
  }
399
 
test/kernels/test_moe.py CHANGED
@@ -11,10 +11,11 @@ import torch
11
  from moe._ops import ops
12
  from moe.fused_moe import fused_moe, fused_topk, moe_align_block_size
13
  from moe.fused_marlin_moe import fused_marlin_moe
 
14
  from moe.scalar_type import scalar_types
15
- from moe.utils.marlin_utils_test import marlin_quantize
16
 
17
- from .utils import compute_max_diff, opcheck
18
 
19
 
20
  def stack_and_dev(tensors: List[torch.Tensor]):
@@ -26,6 +27,136 @@ NUM_EXPERTS = [8, 64]
26
  TOP_KS = [2, 6]
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  @pytest.mark.parametrize("m", [1, 33, 64, 222])
30
  @pytest.mark.parametrize("n", [128, 2048])
31
  @pytest.mark.parametrize("k", [128, 1024])
@@ -35,7 +166,7 @@ TOP_KS = [2, 6]
35
  @pytest.mark.parametrize("act_order", [True, False])
36
  @pytest.mark.parametrize("num_bits", [4, 8])
37
  @pytest.mark.parametrize("is_k_full", [True, False])
38
- # @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
39
  def test_fused_marlin_moe(
40
  m: int,
41
  n: int,
 
11
  from moe._ops import ops
12
  from moe.fused_moe import fused_moe, fused_topk, moe_align_block_size
13
  from moe.fused_marlin_moe import fused_marlin_moe
14
+ from moe.platforms import current_platform
15
  from moe.scalar_type import scalar_types
16
+ from moe.utils.marlin_utils_test import marlin_quantize, quantize_weights
17
 
18
+ from .utils import compute_max_diff, opcheck, torch_moe
19
 
20
 
21
  def stack_and_dev(tensors: List[torch.Tensor]):
 
27
  TOP_KS = [2, 6]
28
 
29
 
30
+ @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
31
+ @pytest.mark.parametrize("n", [128, 1024, 2048])
32
+ @pytest.mark.parametrize("k", [128, 511, 1024])
33
+ @pytest.mark.parametrize("e", NUM_EXPERTS)
34
+ @pytest.mark.parametrize("topk", TOP_KS)
35
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
36
+ def test_fused_moe(
37
+ m: int,
38
+ n: int,
39
+ k: int,
40
+ e: int,
41
+ topk: int,
42
+ dtype: torch.dtype,
43
+ ):
44
+ a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
45
+ w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
46
+ w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
47
+
48
+ score = torch.randn((m, e), device="cuda", dtype=dtype)
49
+ triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
50
+ torch_output = torch_moe(a, w1, w2, score, topk)
51
+ torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
52
+ # iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False)
53
+ # torch.testing.assert_close(iterative_output, torch_output, atol=2e-2, rtol=0)
54
+
55
+
56
+ @pytest.mark.parametrize("m", [1, 32, 222])
57
+ @pytest.mark.parametrize("n", [128, 1024, 2048])
58
+ @pytest.mark.parametrize("k", [128, 1024])
59
+ @pytest.mark.parametrize("e", NUM_EXPERTS)
60
+ @pytest.mark.parametrize("topk", TOP_KS)
61
+ @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
62
+ @pytest.mark.parametrize("group_size", [64, 128])
63
+ @pytest.mark.parametrize("has_zp", [True, False])
64
+ @pytest.mark.parametrize("weight_bits", [4, 8])
65
+ def test_fused_moe_wn16(
66
+ m: int,
67
+ n: int,
68
+ k: int,
69
+ e: int,
70
+ topk: int,
71
+ dtype: torch.dtype,
72
+ group_size: int,
73
+ has_zp: bool,
74
+ weight_bits: int,
75
+ ):
76
+ print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
77
+ a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
78
+ w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
79
+ w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
80
+ score = torch.randn((m, e), device="cuda", dtype=dtype)
81
+
82
+ if weight_bits == 4:
83
+ pack_factor = 2
84
+ quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
85
+ elif weight_bits == 8:
86
+ pack_factor = 1
87
+ quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
88
+
89
+ w1_ref = w1.clone()
90
+ w2_ref = w2.clone()
91
+ w1_qweight = torch.empty(
92
+ (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
93
+ )
94
+ w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
95
+ w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
96
+ w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
97
+ w1_qzeros = torch.empty(
98
+ (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
99
+ )
100
+ w2_qzeros = torch.empty(
101
+ (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
102
+ )
103
+
104
+ for i in range(e * 2):
105
+ expert_id = i % e
106
+ if i // e == 0:
107
+ w, w_ref, w_qweight, w_scales, w_qzeros = (
108
+ w1,
109
+ w1_ref,
110
+ w1_qweight,
111
+ w1_scales,
112
+ w1_qzeros,
113
+ )
114
+ else:
115
+ w, w_ref, w_qweight, w_scales, w_qzeros = (
116
+ w2,
117
+ w2_ref,
118
+ w2_qweight,
119
+ w2_scales,
120
+ w2_qzeros,
121
+ )
122
+ weight, qweight, scales, qzeros = quantize_weights(
123
+ w[expert_id].T, quant_type, group_size, has_zp, False
124
+ )
125
+ weight = weight.T
126
+ qweight = qweight.T.contiguous().to(torch.uint8)
127
+ scales = scales.T
128
+ if has_zp:
129
+ qzeros = qzeros.T.contiguous().to(torch.uint8)
130
+ if weight_bits == 4:
131
+ qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
132
+ if has_zp:
133
+ qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
134
+
135
+ w_ref[expert_id] = weight
136
+ w_qweight[expert_id] = qweight
137
+ w_scales[expert_id] = scales
138
+ if has_zp:
139
+ w_qzeros[expert_id] = qzeros
140
+
141
+ triton_output = fused_moe(
142
+ a,
143
+ w1_qweight,
144
+ w2_qweight,
145
+ score,
146
+ topk,
147
+ renormalize=False,
148
+ use_int4_w4a16=weight_bits == 4,
149
+ use_int8_w8a16=weight_bits == 8,
150
+ w1_scale=w1_scales,
151
+ w2_scale=w2_scales,
152
+ w1_zp=w1_qzeros if has_zp else None,
153
+ w2_zp=w2_qzeros if has_zp else None,
154
+ block_shape=[0, group_size],
155
+ )
156
+ torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
157
+ torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
158
+
159
+
160
  @pytest.mark.parametrize("m", [1, 33, 64, 222])
161
  @pytest.mark.parametrize("n", [128, 2048])
162
  @pytest.mark.parametrize("k", [128, 1024])
 
166
  @pytest.mark.parametrize("act_order", [True, False])
167
  @pytest.mark.parametrize("num_bits", [4, 8])
168
  @pytest.mark.parametrize("is_k_full", [True, False])
169
+ @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
170
  def test_fused_marlin_moe(
171
  m: int,
172
  n: int,
test/kernels/utils.py CHANGED
@@ -8,6 +8,8 @@ from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
8
 
9
  import pytest
10
  import torch
 
 
11
  from torch._prims_common import TensorLikeType
12
 
13
  # For now, disable "test_aot_dispatch_dynamic" since there are some
@@ -26,6 +28,35 @@ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
26
  )
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  # Copied/modified from torch._refs.__init__.py
30
  def fp8_allclose(
31
  a: TensorLikeType,
@@ -50,7 +81,8 @@ def fp8_allclose(
50
 
51
  def compute_max_diff(output, output_ref):
52
  return torch.mean(torch.abs(output - output_ref)) / torch.mean(
53
- torch.abs(output_ref))
 
54
 
55
 
56
  # A special version of op check that has a restricted default set of test_utils
 
8
 
9
  import pytest
10
  import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
  from torch._prims_common import TensorLikeType
14
 
15
  # For now, disable "test_aot_dispatch_dynamic" since there are some
 
28
  )
29
 
30
 
31
+ class SiluAndMul(nn.Module):
32
+ def __init__(self):
33
+ super().__init__()
34
+
35
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
36
+ """PyTorch-native implementation equivalent to forward()."""
37
+ d = x.shape[-1] // 2
38
+ return F.silu(x[..., :d]) * x[..., d:]
39
+
40
+
41
+ def torch_moe(a, w1, w2, score, topk):
42
+ B, D = a.shape
43
+ a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
44
+ out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
45
+ score = torch.softmax(score, dim=-1, dtype=torch.float32)
46
+ topk_weight, topk_ids = torch.topk(score, topk)
47
+ topk_weight = topk_weight.view(-1)
48
+ topk_ids = topk_ids.view(-1)
49
+ for i in range(w1.shape[0]):
50
+ mask = topk_ids == i
51
+ if mask.sum():
52
+ out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
53
+ 0, 1
54
+ )
55
+ return (
56
+ out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
57
+ ).sum(dim=1)
58
+
59
+
60
  # Copied/modified from torch._refs.__init__.py
61
  def fp8_allclose(
62
  a: TensorLikeType,
 
81
 
82
  def compute_max_diff(output, output_ref):
83
  return torch.mean(torch.abs(output - output_ref)) / torch.mean(
84
+ torch.abs(output_ref)
85
+ )
86
 
87
 
88
  # A special version of op check that has a restricted default set of test_utils