marcsun13 (HF Staff) committed
Commit: e6bb45f (verified)
1 Parent(s): de70d68

Upload folder using huggingface_hub

Files changed (41)
  1. build/torch-universal/triton_kernels/__init__.py +0 -0
  2. build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc +0 -0
  3. build/torch-universal/triton_kernels/_ops.py +8 -0
  4. build/torch-universal/triton_kernels/compaction.py +69 -0
  5. build/torch-universal/triton_kernels/compaction_details/_masked_compaction.py +20 -0
  6. build/torch-universal/triton_kernels/matmul_ogs.py +662 -0
  7. build/torch-universal/triton_kernels/matmul_ogs_details/_common.py +165 -0
  8. build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py +377 -0
  9. build/torch-universal/triton_kernels/matmul_ogs_details/_matmul_ogs.py +464 -0
  10. build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +505 -0
  11. build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py +298 -0
  12. build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py +33 -0
  13. build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py +111 -0
  14. build/torch-universal/triton_kernels/numerics.py +42 -0
  15. build/torch-universal/triton_kernels/numerics_details/__init__.py +0 -0
  16. build/torch-universal/triton_kernels/numerics_details/flexpoint.py +195 -0
  17. build/torch-universal/triton_kernels/numerics_details/mxfp.py +303 -0
  18. build/torch-universal/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py +158 -0
  19. build/torch-universal/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py +122 -0
  20. build/torch-universal/triton_kernels/proton_opts.py +17 -0
  21. build/torch-universal/triton_kernels/reduction_details/reduce_bitmatrix.py +111 -0
  22. build/torch-universal/triton_kernels/routing.py +386 -0
  23. build/torch-universal/triton_kernels/routing_details/_expt_data.py +64 -0
  24. build/torch-universal/triton_kernels/routing_details/_routing_compute.py +148 -0
  25. build/torch-universal/triton_kernels/specialize.py +132 -0
  26. build/torch-universal/triton_kernels/swiglu.py +100 -0
  27. build/torch-universal/triton_kernels/swiglu_details/_swiglu.py +102 -0
  28. build/torch-universal/triton_kernels/target_info.py +77 -0
  29. build/torch-universal/triton_kernels/tensor.py +211 -0
  30. build/torch-universal/triton_kernels/tensor_details/layout.py +32 -0
  31. build/torch-universal/triton_kernels/tensor_details/layout_details/base.py +19 -0
  32. build/torch-universal/triton_kernels/tensor_details/layout_details/blackwell_scale.py +58 -0
  33. build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_scale.py +80 -0
  34. build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_value.py +323 -0
  35. build/torch-universal/triton_kernels/tensor_details/layout_details/strided.py +17 -0
  36. build/torch-universal/triton_kernels/testing.py +192 -0
  37. build/torch-universal/triton_kernels/topk.py +92 -0
  38. build/torch-universal/triton_kernels/topk_details/__init__.py +0 -0
  39. build/torch-universal/triton_kernels/topk_details/_topk_backward.py +51 -0
  40. build/torch-universal/triton_kernels/topk_details/_topk_forward.py +146 -0
  41. flake.lock +168 -0
build/torch-universal/triton_kernels/__init__.py ADDED
File without changes
build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (220 Bytes)
 
build/torch-universal/triton_kernels/_ops.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+ ops = torch.ops._triton_kernels_de70d68_dirty
+
+ def add_op_namespace_prefix(op_name: str):
+     """
+     Prefix op by namespace.
+     """
+     return f"_triton_kernels_de70d68_dirty::{op_name}"
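For illustration, a minimal usage sketch of the namespace helper above (the `triton_kernels._ops` import path is assumed from the folder layout; the op namespace only resolves to real kernels once the compiled extension is loaded):

    # Hypothetical usage sketch, not part of the uploaded files.
    from triton_kernels._ops import add_op_namespace_prefix, ops

    qualified_name = add_op_namespace_prefix("my_kernel")
    # -> "_triton_kernels_de70d68_dirty::my_kernel", i.e. the fully qualified
    #    name under which an op would be registered via torch.library and
    #    later looked up as `ops.my_kernel`.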
build/torch-universal/triton_kernels/compaction.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ from .compaction_details._masked_compaction import _masked_compaction
+ from .tensor import Bitmatrix
+
+
+ def compaction(yv, yi, bitmask, sentinel=-1):
+     """
+     Return compacted copies of *yv* and *yi* based on a per-row bitmask.
+
+     Only the elements whose index appears among the active bits of *bitmask*
+     are kept; the rest are replaced by *sentinel*. Kept elements preserve
+     their original left-to-right order.
+
+     Parameters
+     ----------
+     yv : torch.Tensor, shape (B, K)
+         Values tensor.
+     yi : torch.Tensor, shape (B, K), dtype torch.long
+         Integer indices (0 ≤ index < 32) associated with *yv*.
+     bitmask : torch.Tensor, shape (B,) **or** (B, 32)
+         Per-row mask of active indices. See the in-place version for details.
+     sentinel : int, default -1
+         Value written into dropped positions of the returned tensors.
+
+     Returns
+     -------
+     (yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
+         New tensors with the same dtype/device as the inputs.
+
+     """
+
+     n_rows, n_cols = yi.shape
+     ret_yv = torch.empty_like(yv)
+     ret_yi = torch.empty_like(yi)
+     if isinstance(bitmask, Bitmatrix):
+         bitmask = bitmask.storage.data
+
+     _masked_compaction[(n_rows, )](
+         yv, yi, bitmask, bitmask.stride(0), bitmask.stride(1),  # inputs
+         ret_yv, ret_yi,  # outputs
+         sentinel,  # sentinel
+         K=n_cols  # constants
+     )
+     return ret_yv, ret_yi
+
+
+ def compaction_torch(yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1):
+     """
+     Reference implementation of `compaction`.
+     """
+     B, K = yi.shape
+     device = yi.device
+     # Expand bitmask to a boolean matrix of active bits (B, 32)
+     w = (1 << torch.arange(32, device=device, dtype=bitmask.dtype))
+     bits = (bitmask.unsqueeze(-1) & w) != 0
+     mask = bits.flatten(start_dim=-2)  # or bits.reshape(B, -1)
+     # For every yi element decide whether it should be kept
+     keep = mask.gather(1, yi.long())
+     # Build a stable permutation that brings all "keep" items forward
+     # False→0, True→1 ==> invert so kept==0, dropped==1, then argsort
+     order = (~keep).to(torch.int).argsort(dim=1, stable=True)
+     # Re-order tensors according to above permutation
+     yi_sorted = yi.gather(1, order)
+     yv_sorted = yv.gather(1, order)
+     # Fill dropped positions with the sentinel
+     keep_sorted = keep.gather(1, order)
+     yi_sorted[~keep_sorted] = sentinel
+     yv_sorted[~keep_sorted] = sentinel
+     return yv_sorted, yi_sorted
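As a quick illustration of the semantics documented above, the reference path can be exercised on a tiny example (shapes, values, and the `triton_kernels.compaction` import path are assumptions for illustration only):

    import torch
    from triton_kernels.compaction import compaction_torch  # assumed module path

    yv = torch.tensor([[0.9, 0.1], [0.7, 0.3]])
    yi = torch.tensor([[3, 7], [0, 5]])
    # One 32-bit mask word per row: row 0 keeps indices {3, 7}, row 1 keeps {0}.
    bitmask = torch.tensor([[0b1000_1000], [0b0000_0001]])

    yv_out, yi_out = compaction_torch(yv, yi, bitmask)
    # yi_out == [[3, 7], [0, -1]]: kept entries are packed to the left and
    # dropped slots are filled with the sentinel (-1 by default).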
build/torch-universal/triton_kernels/compaction_details/_masked_compaction.py ADDED
@@ -0,0 +1,20 @@
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def _masked_compaction(Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr):
+     pid_m = tl.program_id(0)
+     yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
+     yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
+     div = yi // 32
+     rem = yi % 32
+     active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
+     exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
+     active_flags = active_bits.to(tl.int1)
+     rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K))
+     write_indx = exc_cumsum + rev_arange
+     yv = tl.where(active_flags, yv, sentinel)
+     yi = tl.where(active_flags, yi, sentinel)
+     tl.store(RetYv + pid_m * K + write_indx, yv)
+     tl.store(RetYi + pid_m * K + write_indx, yi)
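The index arithmetic in the kernel above can be traced in plain Python; the following is only an illustrative re-derivation of `write_indx` for a single row (values chosen arbitrarily):

    # One row with K = 4 slots; `active` marks which entries survive the bitmask.
    active = [1, 0, 1, 0]
    K = len(active)

    exc_cumsum = [sum(active[:i]) for i in range(K)]                    # [0, 1, 1, 2]
    rev_arange = [0 if a else K - 1 - i for i, a in enumerate(active)]  # [0, 2, 0, 0]
    write_indx = [c + r for c, r in zip(exc_cumsum, rev_arange)]        # [0, 3, 1, 2]
    # Kept entries are packed to the front (slots 0 and 1); dropped entries land
    # in the remaining slots, where the kernel stores the sentinel value.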
build/torch-universal/triton_kernels/matmul_ogs.py ADDED
@@ -0,0 +1,662 @@
1
+ # isort: off
2
+ # fmt: off
3
+ from dataclasses import dataclass
4
+ import itertools
5
+ import sys
6
+ import torch
7
+ import triton
8
+ from enum import Enum, auto
9
+ # utilities
10
+ from triton_kernels import target_info
11
+ from triton_kernels.numerics import InFlexData, OutFlexData
12
+ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
13
+ from triton_kernels.target_info import is_cuda
14
+ # details
15
+ from .matmul_ogs_details._matmul_ogs import _compute_writeback_idx
16
+ from .matmul_ogs_details._matmul_ogs import _matmul_ogs
17
+ from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
18
+ from .matmul_ogs_details._finalize_matmul import _finalize_matmul
19
+ from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints
20
+ from .numerics_details.mxfp import MXFP_BLOCK_SIZE
21
+ from .specialize import specialize
22
+ from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class FnSpecs:
27
+ name: str
28
+ fn: "triton.runtime.jit.JITFunction"
29
+ fn_arg_names: tuple[str]
30
+ fn_arg_do_not_specialize: tuple[str] = tuple()
31
+
32
+ @staticmethod
33
+ def default():
34
+ return FnSpecs("dflt", None, tuple())
35
+
36
+
37
+ @dataclass(frozen=True)
38
+ class FusedActivation:
39
+ specs: FnSpecs = FnSpecs.default()
40
+ fn_args: tuple[object] = tuple()
41
+ reduction_n: int = 1
42
+
43
+
44
+ @dataclass(frozen=True)
45
+ class Epilogue:
46
+ specs: FnSpecs = FnSpecs.default()
47
+ fn_arg_values_matmul: tuple[object] = tuple()
48
+ fn_arg_values_finalize: tuple[object] = tuple()
49
+ effective_itemsize: float = None
50
+
51
+ class FnName(Enum):
52
+ DEQUANTIZE_MXFP8 = auto()
53
+
54
+
55
+ EpilogueSpecs = FnSpecs # TODO: remove this alias when callers are updated
56
+
57
+ _kernels = dict()
58
+
59
+
60
+ def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()):
61
+ global _kernels
62
+ key = (fused_activation.name, epilogue.name)
63
+ if key in _kernels:
64
+ return _kernels[key]
65
+ spec_constants = {
66
+ "ACTIVATION_FN": fused_activation.fn,
67
+ "EPILOGUE_FN": epilogue.fn,
68
+ }
69
+ spec_tuples = {
70
+ "activation_fn_args": fused_activation.fn_arg_names,
71
+ "epilogue_fn_args": epilogue.fn_arg_names,
72
+ }
73
+ do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize
74
+ import types
75
+
76
+ module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}")
77
+ sys.modules[module.__name__] = module
78
+ module._finalize_matmul = specialize(_finalize_matmul, module, spec_constants, spec_tuples,
79
+ do_not_specialize=do_not_specialize)
80
+ module._matmul_ogs = specialize(_matmul_ogs, module, spec_constants, spec_tuples,
81
+ do_not_specialize=do_not_specialize)
82
+ module._p_matmul_ogs = specialize(_p_matmul_ogs, module, spec_constants, spec_tuples,
83
+ do_not_specialize=do_not_specialize)
84
+ _kernels[key] = module
85
+ return module
86
+
87
+
88
+ # -----------------------------------------------------------------------------
89
+ # Matrix Multiplication + Outer Gather/Scatter
90
+ # -----------------------------------------------------------------------------
91
+
92
+
93
+ def can_overflow_int32(tensor: torch.Tensor):
94
+ max_int32 = (1 << 31) - 1
95
+ offset = 0
96
+ for i in range(tensor.ndim):
97
+ offset += (tensor.shape[i] - 1) * tensor.stride(i)
98
+ return offset > max_int32
99
+
100
+
101
+ def should_upcast_indices(*args):
102
+ return any(tensor is not None and can_overflow_int32(tensor) for tensor in args)
103
+
104
+
105
+ # ---------------------
106
+ # Numerics
107
+ # ---------------------
108
+
109
+ # fmt: off
110
+
111
+ @dataclass(frozen=True)
112
+ class FlexCtx:
113
+ lhs_data: InFlexData = InFlexData()
114
+ rhs_data: InFlexData = InFlexData()
115
+ out_data: OutFlexData = OutFlexData()
116
+
117
+ @dataclass
118
+ class PrecisionConfig:
119
+ max_num_imprecise_acc: int = None
120
+ allow_tf32: bool = True
121
+ flex_ctx: FlexCtx = FlexCtx()
122
+ acc_scale: int = 1.0
123
+ flexpoint_saturate_inf: bool = False
124
+ report_quantization_err_fn: callable = None
125
+ act_scale: Tensor | None = None
126
+ weight_scale: Tensor| None = None
127
+ out_scale: Tensor | None = None
128
+ out_dtype: torch.dtype = None
129
+ enforce_bitwise_invariance: bool = False
130
+
131
+ # ---------------------
132
+ # Preprocessing
133
+ # ---------------------
134
+
135
+ @dataclass(frozen=True)
136
+ class PreprocessingFeatures:
137
+ swap_xw: bool
138
+
139
+
140
+ def init_preprocessing_features(w, precision_config, opt_flags):
141
+ swap_xw = False # Whether or not to swap X and W operands to the tl.dot
142
+ if target_info.cuda_capability_geq(10, 0):
143
+ swap_xw = precision_config.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent
144
+ return PreprocessingFeatures(swap_xw)
145
+
146
+ def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features):
147
+ has_fused_scatter_scratchpad = opt_flags.fused_scatter and routing_data.n_expts_act > 1
148
+ if has_fused_scatter_scratchpad:
149
+ M = scatter_indx.src_indx.shape[0]
150
+ writeback_idxs = torch.zeros((M,), dtype=torch.int32, device=x.device)
151
+ writeback_size = writeback_idxs.shape[0]
152
+ finalize_scatter_idxs = torch.zeros((M // routing_data.n_expts_act + M + 1,), dtype=torch.int32, device=x.device)
153
+ BLOCK_M=256
154
+ _compute_writeback_idx[(triton.cdiv(M, BLOCK_M),)](
155
+ writeback_idxs,
156
+ finalize_scatter_idxs,
157
+ scatter_indx.dst_indx,
158
+ scatter_indx.src_indx,
159
+ M // routing_data.n_expts_act,
160
+ M,
161
+ BLOCK_M=BLOCK_M,
162
+ N_EXPTS_ACT=routing_data.n_expts_act,
163
+ )
164
+ elif scatter_indx is not None and routing_data.n_expts_act == 1:
165
+ writeback_idxs = scatter_indx.dst_indx
166
+ writeback_size = scatter_indx.dst_indx.shape[0]
167
+ finalize_scatter_idxs = None
168
+ else:
169
+ writeback_idxs, writeback_size, finalize_scatter_idxs = None, None, None
170
+ # preprocess routing information and ptr lookup table
171
+ M = x.shape[1] if gather_indx is None else gather_indx.src_indx.shape[0]
172
+ return x, w, writeback_idxs, writeback_size, finalize_scatter_idxs
173
+
174
+
175
+ # ---------------------
176
+ # Postprocessing
177
+ # ---------------------
178
+
179
+
180
+ @dataclass(frozen=True)
181
+ class PostprocessingFeatures:
182
+ finalize: bool
183
+
184
+ def init_postprocessing_features(routing_data, scatter_indx, opt_flags):
185
+ finalize = (scatter_indx is not None and routing_data.n_expts_act > 1) or opt_flags.split_k > 1
186
+ return PostprocessingFeatures(finalize)
187
+
188
+ def apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_offs, num_indx, precision_config, routing_data,
189
+ postprocess_features, memory, fused_activation, epilogue):
190
+ out = memory["output"]
191
+ flex_ctx = precision_config.flex_ctx
192
+ if postprocess_features.finalize:
193
+ has_fused_scatter_scratchpad = opt_flags.fused_scatter and routing_data.n_expts_act > 1
194
+ if has_fused_scatter_scratchpad:
195
+ inp = memory["output"]
196
+ else:
197
+ inp = memory["scratchpad"]["matmul"]
198
+ if scatter_indx is not None:
199
+ assert inp.shape[1] == 1, "batched finalize scatter not supported"
200
+ n_final_rows = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act
201
+ scatter_src_indx = scatter_indx.src_indx
202
+ EXPT_PER_TOK = routing_data.n_expts_act
203
+ num_rows = None
204
+ else:
205
+ n_final_rows = inp.shape[1] * inp.shape[2]
206
+ scatter_src_indx = None
207
+ EXPT_PER_TOK = 1
208
+ num_rows = num_indx or (None if expt_offs is None else expt_offs[-1])
209
+
210
+ if inp.dtype == torch.float32:
211
+ inp_flex = OutFlexData()
212
+ else:
213
+ inp_flex = precision_config.flex_ctx.out_data
214
+
215
+ out_scatter = memory["output"]
216
+ out_scatter_flex = precision_config.flex_ctx.out_data
217
+
218
+ N = inp.shape[3]
219
+ M = n_final_rows
220
+ warps_per_sm = 32 if target_info.is_hip() else 128
221
+
222
+ def compute_grid(BLOCK_N, num_warps):
223
+ num_pid = target_info.num_sms() * (warps_per_sm // num_warps)
224
+ if M < num_pid or target_info.is_hip():
225
+ grid_n = triton.cdiv(N, BLOCK_N)
226
+ grid_m = min(M, max(1, triton.cdiv(num_pid, grid_n)))
227
+ else:
228
+ grid_m = min(M, num_pid)
229
+ grid_n = 1
230
+ return (grid_m, grid_n)
231
+
232
+ if inp.dtype.itemsize == 1:
233
+ candidates = [(1024, 1)]
234
+ else:
235
+ if target_info.is_hip():
236
+ candidates = [(4096 // inp.dtype.itemsize, 2)]
237
+ else:
238
+ if inp.dtype.itemsize == 2:
239
+ candidates = [
240
+ (4096 // inp.dtype.itemsize, 4),
241
+ (1024 // inp.dtype.itemsize, 1),
242
+ ]
243
+ else:
244
+ candidates = [
245
+ (2048 // inp.dtype.itemsize, 4),
246
+ (1024 // inp.dtype.itemsize, 1),
247
+ ]
248
+ if precision_config.enforce_bitwise_invariance:
249
+ candidates = [candidates[0]]
250
+
251
+ # sort by smallest grid_n so we share compute across a row
252
+ grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
253
+ STAGES = 1 if num_warps == 1 else min(triton.cdiv(triton.cdiv(N, BLOCK_N), grid[1]), 5)
254
+
255
+ out_scale = precision_config.out_scale
256
+ out_has_mx = out_scale is not None
257
+ out_scale_strides = (None, None) if out_scale is None else out_scale.stride()[-2:]
258
+ mx_a_scale = memory["scratchpad"].get("mx_out_scale", None)
259
+ if mx_a_scale is not None:
260
+ mx_a_scale_stride_k, mx_a_scale_stride_m = [mx_a_scale.stride(i) for i in (0, 2)]
261
+ else:
262
+ mx_a_scale_stride_k, mx_a_scale_stride_m = None, None
263
+
264
+ kernels = get_kernels(epilogue.specs, fused_activation.specs)
265
+ kernels._finalize_matmul[grid](
266
+ flex_ctx.out_data.reinterpret(out_scatter),
267
+ *((None, out_scale, None) if out_has_mx else out_scatter_flex),
268
+ *out_scale_strides,
269
+ flex_ctx.out_data.reinterpret(inp), inp.stride(0), inp.stride(2),
270
+ inp_flex.expected_scale if mx_a_scale is None else mx_a_scale,
271
+ mx_a_scale_stride_k, mx_a_scale_stride_m,
272
+ scatter_src_indx, finalize_scatter_idxs,
273
+ inp.shape[0], M, N, num_rows,
274
+ *fused_activation.fn_args, fused_activation.reduction_n,
275
+ *epilogue.fn_arg_values_finalize,
276
+ EXPT_PER_TOK=EXPT_PER_TOK,
277
+ BLOCK_N=BLOCK_N,
278
+ STAGES=STAGES,
279
+ num_warps=num_warps,
280
+ flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf,
281
+ HAS_FUSED_SCRATCHPAD=has_fused_scatter_scratchpad,
282
+ )
283
+ out = out_scatter
284
+ # trim unnecessary part of output
285
+ if has_fused_scatter_scratchpad:
286
+ # Discard scratchpad part.
287
+ # This still gives a contiguous tensor, because shape[0] > 1 only when
288
+ # batch mode is enabled, in which case this is a no-op (there's no scratchpad).
289
+ out = out[:, :, :n_final_rows, :]
290
+ return out
291
+
292
+
293
+ # ---------------------
294
+ # Allocation
295
+ # ---------------------
296
+
297
+ @dataclass
298
+ class MatmulAllocation:
299
+ device: str
300
+ output: tuple[tuple[int], torch.dtype]
301
+ scratchpads: dict[str, tuple]
302
+
303
+ def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags,
304
+ preprocessing_features, postprocessing_features):
305
+ # ---- output ------
306
+ N = w.shape[-1]
307
+ # by default - M is number of rows in the activations
308
+ M = x.shape[-2]
309
+ # if the activations are gathered, then M is number of gather indices
310
+ if gather_indx is not None:
311
+ M = gather_indx.src_indx.shape[0]
312
+ # final output
313
+ if routing_data.n_expts_act == 1 or scatter_indx is None:
314
+ y_rows = M
315
+ elif opt_flags.fused_scatter:
316
+ # we need the scratchpad and the output to be contiguous in memory
317
+ Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows
318
+ y_rows = M + Mc
319
+ else:
320
+ Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows
321
+ y_rows = Mc
322
+ batch_dim = x.shape[0] if x.ndim == 3 else 1
323
+ y_shape = (batch_dim, y_rows, N // fused_activation.reduction_n)
324
+ out_dtype = precision_config.out_dtype or x.dtype
325
+ output = (y_shape, out_dtype)
326
+ # ---- scratchpad -----#
327
+ scratchpad = dict()
328
+ # if we need either standalone scatter or split-k, the matmul output will need post-processing
329
+ if postprocessing_features.finalize:
330
+ if opt_flags.split_k > 1 or not opt_flags.fused_scatter:
331
+ dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype
332
+ scratchpad["matmul"] = ((opt_flags.split_k, 1, M, N), dtype)
333
+ if precision_config.out_scale is not None and not (scratchpad.get("matmul", None) is not None and scratchpad["matmul"][1].itemsize > 1):
334
+ scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE)), torch.uint8)
335
+ return MatmulAllocation(x.device, output, scratchpad)
336
+
337
+ def apply_allocation(allocation: MatmulAllocation, output):
338
+ ret = dict()
339
+ if output is None:
340
+ output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
341
+ else:
342
+ assert output.shape == allocation.output[0]
343
+ ret["output"] = output[None, :, :]
344
+ ret["scratchpad"] = {
345
+ k: torch.empty(v[0], device=allocation.device, dtype=v[1])
346
+ for k, v in allocation.scratchpads.items()
347
+ }
348
+ return ret
349
+
350
+ # -----------------------------------------------------------------------------
351
+ # Canonicalize
352
+ # -----------------------------------------------------------------------------
353
+ # the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used
354
+ # we can canonicalize storages to make the implementation more uniform
355
+
356
+ def _canonicalize_storage(storage, out_ndim, flex_data):
357
+ assert out_ndim >= storage.data.ndim
358
+ new_storage_shape = [1] * (out_ndim - storage.data.ndim) + list(storage.data.shape)
359
+ new_storage_data = storage.data.view(new_storage_shape)
360
+ if flex_data is not None:
361
+ new_storage_data = flex_data.reinterpret(new_storage_data)
362
+ return Storage(new_storage_data, storage.layout)
363
+
364
+
365
+ # -----------------------------------------------------------------------------
366
+ # Triton Implementation
367
+ # -----------------------------------------------------------------------------
368
+
369
+ def matmul_ogs_set_idle_sms(num_idle_sms):
370
+ """
371
+ persistent kernels will leave `num_idle_sms` idle
372
+ """
373
+ update_opt_flags_constraints({"idle_sms": num_idle_sms})
374
+
375
+ def matmul_ogs(x, w, bias,
376
+ routing_data: RoutingData | None = None,
377
+ gather_indx: GatherIndx | None = None,
378
+ scatter_indx: ScatterIndx | None = None,
379
+ precision_config: PrecisionConfig | None = None,
380
+ betas: torch.Tensor | None = None,
381
+ gammas: torch.Tensor | None = None,
382
+ out_alpha: float | None = None,
383
+ y: torch.Tensor | None = None,
384
+ fused_activation: FusedActivation | None = None,
385
+ epilogue: Epilogue | None = None,
386
+ ):
387
+ """
388
+ Y[:, :] = 0.
389
+ for e in num_experts:
390
+ Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])
391
+ """
392
+ is_input_batched = x.ndim == 3
393
+ if is_input_batched:
394
+ assert gather_indx is None, "gather not supported in batched mode"
395
+ assert scatter_indx is None, "scatter not supported in batched mode"
396
+ assert routing_data is None, "routing not supported in batched mode"
397
+ assert w.ndim == 3 and w.shape[0] == x.shape[0]
398
+ # canonicalize inputs
399
+ if precision_config is None:
400
+ precision_config = PrecisionConfig()
401
+ if fused_activation is None:
402
+ fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
403
+ if epilogue is None:
404
+ epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
405
+ if routing_data is None:
406
+ routing_data = RoutingData(None, None, max(1, w.shape[0]), 1)
407
+ # unpack scales
408
+ w_scale = precision_config.weight_scale
409
+ w_has_mx = w_scale is not None
410
+ is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
411
+ if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
412
+ if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
413
+ if not isinstance(w, Tensor):
414
+ # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
415
+ dtype = FP4 if w.dtype == torch.uint8 else w.dtype
416
+ w = wrap_torch_tensor(w, dtype=dtype)
417
+ if w_scale is not None and not isinstance(w_scale, Tensor):
418
+ w_scale = Tensor(w_scale)
419
+ if w_scale is not None:
420
+ w_scale.storage.data = w_scale.data.view(torch.uint8)
421
+ w_scale.dtype = torch.uint8
422
+ x_scale = precision_config.act_scale
423
+ x_has_mx = x_scale is not None
424
+ if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp"
425
+ if x_scale is not None and not isinstance(x_scale, Tensor):
426
+ x_scale = Tensor(x_scale)
427
+ if not isinstance(x, Tensor):
428
+ x = Tensor(x, dtype=x.dtype)
429
+ # determine shapes
430
+ M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
431
+ batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
432
+ K, N = w.shape[-2:]
433
+ assert K == x.shape[-1]
434
+ if x.ndim == 3 and w.ndim == 3:
435
+ assert x.shape[0] == w.shape[0]
436
+ # compute optimization flags
437
+ out_dtype = precision_config.out_dtype or x.dtype
438
+ can_use_tma = x.storage.is_tma_compliant() and \
439
+ w.storage.is_tma_compliant() and \
440
+ (w_scale is None or w_scale.storage.is_tma_compliant())
441
+ # hopper w/ mxfp4 doesn't support TMA
442
+ can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
443
+ can_use_fused_scatter = scatter_indx is not None and fused_activation.specs.fn is None
444
+ opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
445
+ M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
446
+ )
447
+ if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp():
448
+ raise NotImplementedError("Must use non-persistent kernel for simulated MXFP")
449
+ if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp():
450
+ raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP")
451
+ # determine necessary pre/post processing
452
+ preprocessing_features = init_preprocessing_features(w, precision_config, opt_flags)
453
+ postprocessing_features = init_postprocessing_features(routing_data, scatter_indx, opt_flags)
454
+ # allocate output/scratchpad memory
455
+ allocation = init_allocation(x, w, precision_config, fused_activation,
456
+ routing_data, gather_indx, scatter_indx,
457
+ opt_flags, preprocessing_features, postprocessing_features
458
+ )
459
+ memory = apply_allocation(allocation, y)
460
+ # TMA descriptors require a global memory allocation
461
+ if opt_flags.is_persistent:
462
+ triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
463
+ # Intermediate tensors and postprocess kernels for each situation
464
+ out0, out0_flex = memory["output"], precision_config.flex_ctx.out_data
465
+ fused_postprocess_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
466
+ out_scale = None if precision_config.out_scale is None else precision_config.out_scale.data.view(torch.uint8)
467
+ if postprocessing_features.finalize:
468
+ if opt_flags.fused_scatter:
469
+ out0 = memory["output"]
470
+ else:
471
+ out0 = memory["scratchpad"]["matmul"]
472
+ if "mx_out_scale" in memory["scratchpad"]:
473
+ assert out_scale is not None
474
+ out_scale = memory["scratchpad"]["mx_out_scale"]
475
+ out0_flex = OutFlexData() if out0.dtype == torch.float32 else precision_config.flex_ctx.out_data
476
+
477
+ fused_activation, fused_postprocess_activation = fused_postprocess_activation, fused_activation
478
+ out_has_mx = out_scale is not None and out0.element_size() == 1
479
+ if out_has_mx:
480
+ if isinstance(out_scale, Tensor):
481
+ out_scale = Tensor(out_scale)
482
+ else:
483
+ out_scale = None
484
+ # pre-processing
485
+ x, w, writeback_idxs, writeback_size, finalize_scatter_idxs = apply_preprocessing_features(
486
+ x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features
487
+ )
488
+ # matrix multiplication
489
+ flex = precision_config.flex_ctx
490
+ bias_stride = None if bias is None else bias.stride(0)
491
+ num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
492
+ # moe metadata
493
+ expt_data = routing_data.expt_data
494
+ block_m = opt_flags.block_m
495
+ expt_hist = None if expt_data is None else expt_data.hist
496
+ expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
497
+ expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
498
+ expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
499
+ # spmd grid
500
+ grid_m = triton.cdiv(M, opt_flags.block_m)
501
+ if expt_block_pid_map is not None:
502
+ grid_m = routing_data.n_blocks(M, opt_flags.block_m)
503
+ grid_n = triton.cdiv(N, opt_flags.block_n)
504
+ max_grid = batch_size * grid_m * grid_n * opt_flags.split_k
505
+ grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
506
+ # canonicalize storage
507
+ has_gather = gather_indx is not None
508
+ x_storage = _canonicalize_storage(x.storage, 2 if has_gather else 3, flex.lhs_data)
509
+ w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
510
+ # create tma descriptor for x
511
+ x_has_tma = ((not has_gather) or (has_gather and target_info.has_tma_gather())) and opt_flags.is_persistent
512
+ x_block_tma = ([1] if has_gather else [1, opt_flags.block_m]) + [opt_flags.block_k]
513
+ x_tensor_or_tma = x_storage.make_tma(x_block_tma) if x_has_tma else x_storage.data
514
+ # create tma descriptor for w
515
+ w_has_tma = opt_flags.is_persistent
516
+ w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n]) if w_has_tma else w_storage.data
517
+ # create tma descriptor for w_scale
518
+ w_scale_tensor_or_tma = w_scale
519
+ w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
520
+ w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k]) if w_scale_has_tma else w_scale
521
+ # canonicalize strides
522
+ x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
523
+ x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
524
+ x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides
525
+ w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None)
526
+ w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides
527
+ out_scale_strides = out_scale.stride() if out_has_mx else (None, None, None, None)
528
+ out_scale_strides = (0, ) * (3 - len(out_scale_strides)) + out_scale_strides
529
+ # launch kernel
530
+ kernels = get_kernels(epilogue.specs, fused_activation.specs)
531
+ (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
532
+ flex.out_data.reinterpret(memory["output"]),
533
+ flex.out_data.reinterpret(out0), *out0.stride(),
534
+ *((None, out_scale, None) if out_has_mx else out0_flex),
535
+ *out_scale_strides[-3:],
536
+ x_tensor_or_tma, x_storage.data, *x_strides,
537
+ flex.lhs_data.scale,
538
+ None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
539
+ w_tensor_or_tma, *w_storage.data.stride(), w_storage.data.stride()[-1] != 1,
540
+ flex.rhs_data.scale,
541
+ w_scale_tensor_or_tma, *w_scale_strides,
542
+ bias, bias_stride,
543
+ x.shape[-2],
544
+ x.shape[-2] if routing_data.expt_hist is None else None,
545
+ N, K,
546
+ betas, gammas,
547
+ None if gather_indx is None else gather_indx.src_indx,
548
+ None if scatter_indx is None else scatter_indx.src_indx,
549
+ num_indx,
550
+ writeback_idxs, writeback_size,
551
+ expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map,
552
+ batch_size, grid_m, grid_n,
553
+ out_alpha,
554
+ *fused_activation.fn_args, fused_activation.reduction_n,
555
+ *epilogue.fn_arg_values_matmul,
556
+ routing_data.n_expts_tot, routing_data.n_expts_act,
557
+ precision_config.max_num_imprecise_acc,
558
+ precision_config.allow_tf32,
559
+ precision_config.flexpoint_saturate_inf,
560
+ flex.rhs_data.is_per_batch,
561
+ opt_flags.block_m,
562
+ opt_flags.block_n,
563
+ opt_flags.block_k,
564
+ opt_flags.group_m,
565
+ XCD_SWIZZLE=opt_flags.xcd_swizzle,
566
+ SWIZZLE_MX_VALUE=w.storage.layout.name,
567
+ SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name,
568
+ EPILOGUE_SUBTILE=opt_flags.epilogue_subtile,
569
+ SPLIT_K=opt_flags.split_k,
570
+ EVEN_K=K % opt_flags.block_k == 0,
571
+ W_CACHE_MODIFIER=opt_flags.w_cache_modifier,
572
+ TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt,
573
+ num_warps=opt_flags.num_warps,
574
+ num_stages=opt_flags.num_stages,
575
+ arch=opt_flags.arch,
576
+ UPCAST_INDICES=should_upcast_indices(x, w, out0),
577
+ DISABLE_Y_TMA=out0.stride(-2) * out0.dtype.itemsize % 16 != 0,
578
+ SWAP_XW=preprocessing_features.swap_xw,
579
+ IS_EPILOGUE_DEQUANT_MXFP8=epilogue.specs.name == FnName.DEQUANTIZE_MXFP8.name,
580
+ NUM_SMS = grid if opt_flags.is_persistent else 0,
581
+ **opt_flags.target_kernel_kwargs)
582
+ # post-processing
583
+ out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
584
+ num_indx, precision_config, routing_data,
585
+ postprocessing_features, memory, fused_postprocess_activation, epilogue)
586
+ # remove split-k
587
+ out = out.squeeze(0)
588
+ if not is_input_batched:
589
+ out = out.view(out.shape[-2], out.shape[-1])
590
+ return out
591
+
592
+
593
+ # -----------------------------------------------------------------------------
594
+ # Reference Implementation
595
+ # -----------------------------------------------------------------------------
596
+
597
+ def matmul_ogs_torch(x, w, bias,
598
+ routing_data: RoutingData = None,
599
+ gather_indx: GatherIndx = None,
600
+ scatter_indx: ScatterIndx = None,
601
+ precision_config: PrecisionConfig = None,
602
+ betas = None,
603
+ gammas = None,
604
+ round_x = None, round_y = None,
605
+ ):
606
+ is_input_batched = x.ndim == 3
607
+ assert x.dtype.itemsize > 1
608
+ assert w.dtype.itemsize > 1
609
+ if is_input_batched:
610
+ assert gather_indx is None, "gather not supported in batched mode"
611
+ assert scatter_indx is None, "scatter not supported in batched mode"
612
+ assert routing_data is None, "routing not supported in batched mode"
613
+ assert w.ndim == 3 and w.shape[0] == x.shape[0]
614
+ if round_x is None:
615
+ round_x = lambda x: x
616
+ if round_y is None:
617
+ round_y = lambda x: x
618
+ if bias.ndim == 1:
619
+ bias = bias.view(1, *bias.shape)
620
+ if w.ndim == 2:
621
+ w = w.view(1, *w.shape)
622
+ if x.ndim == 2:
623
+ x = x.view(1, *x.shape)
624
+ if routing_data is None:
625
+ routing_data = RoutingData(None, None, w.shape[0], 1)
626
+ n_expts_act = routing_data.n_expts_act
627
+ # memory offsets
628
+ if routing_data.n_expts_tot > 1 and not is_input_batched:
629
+ sizes = routing_data.expt_hist
630
+ off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32)
631
+ off[1:] = torch.cumsum(sizes, 0)
632
+ offs = list(itertools.pairwise(off))
633
+ else:
634
+ offs = [[0, x.shape[1]] for _ in range(w.shape[0])]
635
+ # compute
636
+ n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0]
637
+ y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype)
638
+ for i, (lo, hi) in enumerate(offs):
639
+ if gather_indx is None:
640
+ idx = torch.arange(lo, hi, device=x.device)
641
+ else:
642
+ idx = gather_indx.src_indx[lo:hi] // n_expts_act
643
+ batch = i if is_input_batched else 0
644
+ out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
645
+ w[i].float())
646
+ if bias is not None:
647
+ out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
648
+ if gammas is not None:
649
+ out *= gammas[lo:hi, None]
650
+ y[batch, lo:hi, :] = round_y(out)
651
+ if not is_input_batched:
652
+ y = y.view(y.shape[1], y.shape[2])
653
+ if scatter_indx is None:
654
+ return y
655
+ # accumulate output from all experts
656
+ n_rows = y.shape[0] // n_expts_act
657
+ out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device)
658
+ for i, (lo, hi) in enumerate(offs):
659
+ dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act
660
+ msk = dst_idx != -1
661
+ out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float()
662
+ return out
build/torch-universal/triton_kernels/matmul_ogs_details/_common.py ADDED
@@ -0,0 +1,165 @@
+ import torch
+
+ import triton
+ import triton.language as tl
+ from triton.tools.tensor_descriptor import TensorDescriptor
+
+ # -----------------------------------------------------------------------------
+ # Utilities
+ # -----------------------------------------------------------------------------
+
+
+ @tl.constexpr_function
+ def get_scaled_dot_format_string(dtype: tl.dtype):
+     mapping = {
+         tl.float16: "fp16",
+         tl.bfloat16: "bf16",
+         tl.uint8: "e2m1",
+         tl.float8e4nv: "e4m3",
+         tl.float8e5: "e5m2",
+     }
+     return mapping[dtype]
+
+
+ @triton.jit
+ def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr):
+     """
+     Swizzle the program id based on integer XCD_SWIZZLE.
+     This is useful for reordering how blocks are ordered. A scheduler may, for example,
+     assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10.. to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2.
+     This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment
+     becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to
+     the same hardware unit.
+     """
+     # Number of pids per group in the new arrangement
+     pids_per_group = domain_size // XCD_SWIZZLE
+     extra_pid_groups = domain_size % XCD_SWIZZLE
+
+     # Compute current group and local pid within the group
+     group = pid % XCD_SWIZZLE
+     local_pid = pid // XCD_SWIZZLE
+
+     # Calculate new pid based on the new grouping
+     new_pid = group * pids_per_group + min(group, extra_pid_groups) + local_pid
+     return new_pid
+
+
+ @triton.jit
+ def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr):
+     width = GROUP_M * grid_n
+     group_id = pid // width
+     group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+     pid_m = group_id * GROUP_M + (pid % group_size)
+     pid_n = (pid % width) // (group_size)
+     return pid_m, pid_n
+
+
+ def make_matmul_repr(base_name, order):
+
+     def matmul_repr(specialization):
+         signature = specialization.signature
+         constants = specialization.constants
+         reorder = lambda L: [L[i] for i in order]
+         layout = lambda stride: "N" if stride in constants else "T"
+
+         def convert_dtype(dtype):
+             if "tensordesc" in dtype:
+                 ret = convert_dtype(dtype.split("<")[1].split("[")[0])
+                 return ret
+             elif "u8" in dtype:
+                 return "mxfp4"
+             elif dtype[0] == "*":
+                 return dtype[1:]
+             else:
+                 return dtype
+
+         dtypes = "x".join([convert_dtype(f"{signature[i]}") for i in reorder(["Y", "X", "W"])])
+         layouts = "".join([f"{layout(i)}" for i in reorder(["stride_y_n", "stride_x_k", "stride_w_n"])])
+         blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N", "BLOCK_K", "SPLIT_K"]])
+         # mode = []
+         # if "GatherIndx" not in constants:
+         #     mode += ['g']
+         # if "ScatterSrcIndx" not in constants:
+         #     mode += ['s']
+         # suffix = "" if not mode else "_o" + (''.join(mode))
+         # if base_name.startswith("_p"):
+         #     suffix += "_ptma"
+         return f"{base_name}_{layouts}_{dtypes}_{blocks}"
+
+     return matmul_repr
+
+
+ def matmul_launch_metadata(grid, kernel, args):
+     from ..proton_opts import launch_metadata_allow_sync
+
+     ret = dict()
+     M, N, K = args["M"], args["N"], args["K"]
+     Y, X, W = [t.base if isinstance(t, TensorDescriptor) else t for t in [args["Y"], args["X"], args["W"]]]
+     tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
+     hist = args["ExptHist"]
+     if hist is not None:
+         # If annotation is given, use that to generate name for profiling.
+         if tokens_per_expt is not None:
+             n_rows = f"{tokens_per_expt}*"
+         elif launch_metadata_allow_sync():
+             n_rows = int(hist.float().mean())
+         else:
+             n_rows = "unknown"
+
+         if launch_metadata_allow_sync():
+             n_tokens = float(hist.sum())
+             n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
+         elif tokens_per_expt is not None:
+             n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
+             # This may not be totally correct (e.g., we might not be using all experts)
+             # but it's better than nothing.
+             n_w_bytes = W.numel() * W.element_size()
+         else:
+             n_tokens = None
+             n_w_bytes = 0
+
+         # If annotation is given, use that to generate name for profiling.
+         tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
+         n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows
+     else:
+         n_tokens = None
+         n_w_bytes = W.numel() * W.element_size()
+     repr = lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}"
+     nbits = X.dtype.itemsize * 8
+     batch_repr = ""
+     if "batch_size" in args and args["batch_size"] > 1:
+         batch_repr = repr("B", args["batch_size"]) + ", "
+     ret["name"] = f"{kernel.name} [{batch_repr}{repr('M', M)}, {repr('N', N)}, {repr('K', K)}] stg{kernel.num_stages}"
+     ep_subtile = args["EPILOGUE_SUBTILE"]
+     if ep_subtile is not None and ep_subtile > 1:
+         ret["name"] += f" ep/{ep_subtile}"
+
+     if hist is not None and n_tokens is None:
+         return ret  # Don't fill metadata because we can't compute them properly.
+
+     fM = M if M is not None else n_tokens
+     fK = K if K is not None else n_tokens
+     ret[f"flops{nbits}"] = 2.0 * fM * N * fK
+
+     gindx = args.get("GatherIndx", None)
+     # sindx = args.get("WriteBackIndx", None)
+     n_x_bytes = X.numel() * X.element_size()
+     n_y_bytes = Y.numel() * Y.element_size()
+     if hist is not None:
+         assert n_tokens is not None
+         n_expts_act = args["N_EXPTS_ACT"]
+
+         if (gindx is not None) and launch_metadata_allow_sync():
+             # recreate inverse GatherIndx.
+             dst = torch.full_like(gindx, -1)
+             idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
+             mask = (gindx != -1)
+             dst[gindx[mask]] = idx[mask]
+             n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum()
+         else:
+             n_read_rows = n_tokens
+         n_x_bytes = n_read_rows * X.shape[-1] * X.element_size()
+         n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size()
+     ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes)
+
+     return ret
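For intuition, the `xcd_swizzle` mapping documented above can be reproduced in plain Python (an illustrative sketch only; the real helper runs as Triton device code):

    def xcd_swizzle_py(pid: int, domain_size: int, xcd_swizzle: int) -> int:
        # Same arithmetic as the @triton.jit helper above.
        pids_per_group = domain_size // xcd_swizzle
        extra_pid_groups = domain_size % xcd_swizzle
        group = pid % xcd_swizzle
        local_pid = pid // xcd_swizzle
        return group * pids_per_group + min(group, extra_pid_groups) + local_pid

    print([xcd_swizzle_py(pid, domain_size=8, xcd_swizzle=4) for pid in range(8)])
    # [0, 2, 4, 6, 1, 3, 5, 7]: pids that share a hardware unit (same pid % 4)
    # are remapped to a contiguous range of swizzled ids.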
build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py ADDED
@@ -0,0 +1,377 @@
1
+ import triton
2
+ import triton.language as tl
3
+ from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale, update_scale
4
+ from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
5
+ from triton_kernels.target_info import cuda_capability_geq as _cuda_capability_geq
6
+ from triton_kernels.target_info import is_hip as _is_hip
7
+
8
+
9
+ # fmt: off
10
+ @tl.constexpr_function
11
+ def is_hip():
12
+ return _is_hip()
13
+
14
+
15
+ @tl.constexpr_function
16
+ def cuda_capability_geq(x, y):
17
+ return _cuda_capability_geq(x, y)
18
+
19
+
20
+ @tl.constexpr_function
21
+ def log2(n):
22
+ return len(bin(n)) - 3
23
+
24
+
25
+ @tl.constexpr_function
26
+ def _permute_to_end_order(n: int, axis: int):
27
+ """
28
+ Returns the order of the axes of a tensor to permute `axis` to the end.
29
+ """
30
+ order = tuple(range(n))
31
+ return order[:axis] + order[(axis + 1):] + (axis, )
32
+
33
+
34
+ @triton.jit
35
+ def permute_to_end(x, axis: tl.constexpr):
36
+ """
37
+ Permutes `x` so that `axis` is the last axis.
38
+ """
39
+ N: tl.constexpr = len(x.shape)
40
+ return tl.permute(x, _permute_to_end_order(N, axis).value)
41
+
42
+
43
+ @triton.jit
44
+ def split_n(x, N: tl.constexpr):
45
+ """
46
+ Given `x`, a tensor of shape AxB...x2x2...x2, split it N times.
47
+ Return a tuple of the results.
48
+ """
49
+ xs = (x, )
50
+ for i in tl.static_range(N):
51
+ next = tl.split(xs[0])
52
+ for j in tl.static_range(2**i - 1):
53
+ next = next + tl.split(xs[j + 1])
54
+ xs = next
55
+ return xs
56
+
57
+
58
+ @triton.jit
59
+ def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr = None, NUM_THREADS: tl.constexpr = None):
60
+ N: tl.constexpr = tl.extra.cuda.num_threads() if NUM_THREADS is None else NUM_THREADS
61
+ BS: tl.constexpr = x.numel if BLOCK_SIZE is None else BLOCK_SIZE
62
+ tl.static_assert(BS % N == 0, "BLOCK_SIZE must be divisible by NUM_THREADS")
63
+ return tl.max(tl.reshape(tl.abs(x), [N, BS // N], can_reorder=True), axis=1)
64
+
65
+
66
+ def _finalize_matmul_launch_metadata(grid, kernel, args):
67
+ ret = dict()
68
+ Out, A, ScatterSrcIndx, FinalizeScatterIdxs, K, M, N, EXPT_PER_TOK, NumRows = [
69
+ args[name]
70
+ for name in ["Out", "A", "ScatterSrcIndx", "FinalizeScatterIdxs", "K", "M", "N", "EXPT_PER_TOK", "NumRows"]
71
+ ]
72
+ ret["name"] = f"{kernel.name} [M={M}x{EXPT_PER_TOK} {N=} {K=}]"
73
+
74
+ if FinalizeScatterIdxs is not None:
75
+ M = FinalizeScatterIdxs[-1].item()
76
+
77
+ if ScatterSrcIndx is not None:
78
+ is_active = (ScatterSrcIndx != -1).view((-1, EXPT_PER_TOK))
79
+ n_active = is_active.sum(dim=1)
80
+ need_accum = n_active >= (1 if K > 1 else 2)
81
+ is_active &= need_accum[:, None]
82
+ active_input_rows = is_active.sum()
83
+ active_output_rows = need_accum.sum()
84
+ if EXPT_PER_TOK > 1:
85
+ # Masked rows are set to zero.
86
+ active_output_rows += (n_active == 0).sum()
87
+ else:
88
+ if NumRows is not None:
89
+ if isinstance(NumRows, int):
90
+ active_input_rows = NumRows
91
+ else:
92
+ active_input_rows = NumRows.item()
93
+ else:
94
+ active_input_rows = M
95
+ active_output_rows = M
96
+
97
+ ret["bytes"] = (active_input_rows * K * A.shape[-1] * A.element_size() +
98
+ active_output_rows * Out.shape[-1] * Out.element_size())
99
+ if FinalizeScatterIdxs is not None:
100
+ ret["bytes"] += FinalizeScatterIdxs.numel() * FinalizeScatterIdxs.element_size()
101
+ elif ScatterSrcIndx is not None and EXPT_PER_TOK > 1:
102
+ ret["bytes"] += ScatterSrcIndx.numel() * ScatterSrcIndx.element_size()
103
+ nbits = Out.dtype.itemsize * 8
104
+ ret[f"flops{nbits}"] = active_input_rows * K * A.shape[-1]
105
+ return ret
106
+
107
+
108
+ @tl.constexpr_function
109
+ def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
110
+ """
111
+ Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
112
+ adds. If `absmax_reg_name` is provided, the absolute maximum value seen so far is tracked inside
113
+ that register.
114
+
115
+ Generates code something like:
116
+
117
+ add.f32.f16 $0, $2, $1;
118
+ add.f32.f16 $0, $3, $0;
119
+ add.f32.f16 $0, $4, $0;
120
+ add.f32.f16 $0, $5, $0;
121
+
122
+ .reg .f32 b;
123
+ abs.f32 b, $0;
124
+ max.f32 my_abs_max, my_abs_max, b;
125
+ """
126
+ # Add the first f16 value to the input $1, store into the output $0.
127
+ ptx = f"\nadd.f32.{src_type} $0, $2, $1;"
128
+ # Accumulate the rest of the inputs into the output $0.
129
+ for i in range(1, n_inputs):
130
+ ptx += f"\nadd.f32.{src_type} $0, ${2 + i}, $0;"
131
+ if absmax_reg_name is not None:
132
+ # Update `absmax_reg_name` with the absolute maximum value seen so far.
133
+ ptx += f"""
134
+ .reg .f32 b;
135
+ abs.f32 b, $0;
136
+ max.f32 {absmax_reg_name}, {absmax_reg_name}, b;
137
+ """
138
+ # Return the PTX snippet, brace-enclosed so we don't pollute the global namespace.
139
+ return f"{{{ptx}}}"
140
+
141
+
142
+ @triton.jit
143
+ def _mixed_precision_accumulate_and_track_absmax(acc, x, axis: tl.constexpr, absmax_reg_name: tl.constexpr = None):
144
+ """Given an fp8/bf16/fp16 tensor, accumulate into `acc` along `axis`.
145
+ Values are first converted to bf16/fp16, packed into 32-bit registers, and then accumulated using
146
+ mixed-precision adds.
147
+
148
+ If `absmax_reg_name` is provided, the absolute maximum value seen so far is tracked inside that
149
+ register.
150
+ """
151
+ REDUCTION_SIZE: tl.constexpr = x.shape[axis]
152
+ tl.static_assert(2**log2(REDUCTION_SIZE) == REDUCTION_SIZE,
153
+ f"Reduction size must be a power of 2, was {REDUCTION_SIZE}")
154
+ # move `axis` to the last axis and reshape for iterative splitting.
155
+ x = permute_to_end(x, axis)
156
+ x = tl.reshape(x, x.shape[:-1] + (2, ) * log2(REDUCTION_SIZE))
157
+ # Split into a tuple of AxB tensors.
158
+ xs = split_n(x, log2(REDUCTION_SIZE))
159
+ if (tl.constexpr(x.dtype == tl.float8e4nv) or tl.constexpr(x.dtype == tl.float8e5)):
160
+ # Convert fp8 to fp16.
161
+ fp16_xs = ()
162
+ for i in tl.static_range(len(xs)):
163
+ fp16_xs += (xs[i].to(tl.float16), )
164
+ xs = fp16_xs
165
+ src_type: tl.constexpr = "f16"
166
+ elif x.dtype == tl.float16:
167
+ src_type: tl.constexpr = "f16"
168
+ elif x.dtype == tl.bfloat16:
169
+ src_type: tl.constexpr = "bf16"
170
+ else:
171
+ tl.static_assert(False, f"Unsupported dtype: {x.dtype}")
172
+ return tl.inline_asm_elementwise(
173
+ _accumulate_f16_into_f32_and_track_absmax_ptx(REDUCTION_SIZE, src_type, absmax_reg_name),
174
+ "=r,r" + (",h" * len(xs)),
175
+ (acc, ) + xs,
176
+ dtype=tl.float32,
177
+ is_pure=True,
178
+ pack=1,
179
+ )
180
+
181
+
182
+ def _finalize_matmul_repr(specialization):
183
+ signature = specialization.signature
184
+ suffix = "" if "ScatterSrcIndx" in specialization.constants else "_scatter"
185
+ return f"_finalize_matmul{suffix}_{signature['A'][1:]}"
186
+
187
+
188
+ @triton.jit(repr=_finalize_matmul_repr, launch_metadata=_finalize_matmul_launch_metadata)
189
+ def _finalize_matmul(
190
+ Out,
191
+ OutExpectedScale,
192
+ OutActualScale,
193
+ OutChecksumScale,
194
+ stride_out_mx_m, stride_out_mx_n,
195
+ A,
196
+ stride_a_k,
197
+ stride_a_m,
198
+ AScale,
199
+ stride_a_mx_k,
200
+ stride_a_mx_m,
201
+ ScatterSrcIndx,
202
+ FinalizeScatterIdxs,
203
+ K: tl.constexpr,
204
+ M,
205
+ N,
206
+ NumRows,
207
+ # fused activation function
208
+ ACTIVATION_FN: tl.constexpr,
209
+ activation_fn_args,
210
+ ACTIVATION_REDUCTION_N: tl.constexpr,
211
+ # epilogue transform
212
+ EPILOGUE_FN: tl.constexpr,
213
+ epilogue_fn_args,
214
+ EXPT_PER_TOK: tl.constexpr,
215
+ flexpoint_saturate_inf: tl.constexpr,
216
+ BLOCK_N: tl.constexpr,
217
+ STAGES: tl.constexpr,
218
+ HAS_FUSED_SCRATCHPAD: tl.constexpr,
219
+ ):
220
+ IN_MXFP8: tl.constexpr = stride_a_mx_k is not None
221
+ OUT_MXFP8: tl.constexpr = stride_out_mx_m is not None
222
+ if HAS_FUSED_SCRATCHPAD:
223
+ # Bump A to the scratchpad region.
224
+ A += tl.cast(M, tl.int64) * stride_a_m
225
+
226
+ USE_FUSED_MIXED_PREC_ACC: tl.constexpr = (cuda_capability_geq(10, 0)
227
+ and tl.constexpr(A.dtype.element_ty != tl.float32))
228
+ USE_FUSED_ABSMAX: tl.constexpr = (USE_FUSED_MIXED_PREC_ACC and OutActualScale is not None) and ACTIVATION_FN is None
229
+
230
+ THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
231
+ local_max = tl.full([THREADS_PER_BLOCK], 0.0, tl.float32)
232
+ if USE_FUSED_ABSMAX:
233
+ local_max = tl.inline_asm_elementwise(
234
+ """
235
+ .reg .f32 my_abs_max;
236
+ mov.b32 my_abs_max, 0;
237
+ mov.b32 $0, 0;
238
+ """, "=r,r", [local_max], dtype=tl.float32, is_pure=False, pack=1)
239
+
240
+ out_scale = load_scale(OutExpectedScale)
241
+ a_scale = load_scale(AScale)
242
+
243
+ if FinalizeScatterIdxs is not None:
244
+ MBound = tl.load(FinalizeScatterIdxs + M + M * EXPT_PER_TOK)
245
+ if tl.program_id(0) >= MBound:
246
+ return
247
+ else:
248
+ MBound = M
249
+
250
+ if NumRows is not None:
251
+ NumRows = NumRows # remove constexpr
252
+ if NumRows.dtype.is_ptr():
253
+ NumRows = tl.load(NumRows)
254
+
255
+ if FinalizeScatterIdxs is not None or (ScatterSrcIndx is not None and EXPT_PER_TOK > 1):
256
+ n_active_experts = 0
257
+ else:
258
+ n_active_experts: tl.constexpr = EXPT_PER_TOK
259
+
260
+ OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
261
+ outN = N // ACTIVATION_REDUCTION_N
262
+
263
+ for pid_m in tl.range(tl.program_id(0), MBound, tl.num_programs(0)):
264
+ src_offs = pid_m * EXPT_PER_TOK + tl.arange(0, EXPT_PER_TOK)
265
+ if FinalizeScatterIdxs is not None:
266
+ row = tl.load(FinalizeScatterIdxs + pid_m)
267
+ src_idxs = tl.load(FinalizeScatterIdxs + M + src_offs)
268
+ n_active_experts = tl.sum((src_idxs != -1).to(tl.int32))
269
+ elif ScatterSrcIndx is not None and EXPT_PER_TOK > 1:
270
+ row = pid_m
271
+ src_idxs = tl.load(ScatterSrcIndx + src_offs)
272
+ n_active_experts = tl.sum((src_idxs != -1).to(tl.int32))
273
+ else:
274
+ row = pid_m
275
+ src_idxs = src_offs
276
+ if NumRows is not None:
277
+ src_idxs = tl.where(src_idxs < NumRows, src_idxs, -1)
278
+
279
+ if n_active_experts == 0:
280
+ for off_n in tl.range(tl.program_id(1) * OUT_BLOCK_N, outN, tl.num_programs(1) * OUT_BLOCK_N):
281
+ offs_n = off_n + tl.arange(0, OUT_BLOCK_N)
282
+ n_mask = offs_n < outN
283
+ tl.store(Out + row * outN + offs_n, tl.zeros([OUT_BLOCK_N], dtype=Out.dtype.element_ty), mask=n_mask)
284
+ else:
285
+ for off_n in tl.range(tl.program_id(1) * BLOCK_N, N, tl.num_programs(1) * BLOCK_N, num_stages=STAGES):
286
+ offs_n = off_n + tl.arange(0, BLOCK_N)
287
+ n_mask = offs_n < N
288
+ if IN_MXFP8:
289
+ MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE
290
+ N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
291
+ offs_n_scale = off_n // BLOCK_N * MX_SCALE_BLOCK_N + tl.arange(0, MX_SCALE_BLOCK_N)[None, :]
292
+ n_mask_scale = offs_n_scale < N_MX_BLOCK
293
+
294
+ acc = tl.zeros([BLOCK_N], dtype=tl.float32)
295
+ if is_hip():
296
+ if EXPT_PER_TOK > 1:
297
+ src_idxs_tup = split_n(tl.reshape(src_idxs, (2, ) * log2(EXPT_PER_TOK)), log2(EXPT_PER_TOK))
298
+ else:
299
+ # Convert 1D tensor to 1D tuple.
300
+ src_idxs_tup = tl.split(tl.reshape(tl.join(src_idxs, src_idxs), (2, )))[:1]
301
+ for i in tl.static_range(0, EXPT_PER_TOK, 1):
302
+ src_idx = src_idxs_tup[i]
303
+ if src_idx != -1:
304
+ As = A + src_idx.to(tl.int64) * stride_a_m + offs_n
305
+ for ki in tl.static_range(K):
306
+ acc += tl.load(As, mask=n_mask, other=0.0)
307
+ As += stride_a_k
308
+ else:
309
+ As = A + src_idxs.to(tl.int64)[:, None] * stride_a_m + offs_n[None, :]
310
+ if IN_MXFP8:
311
+ AScales = AScale + src_idxs.to(tl.int64)[:, None] * stride_a_mx_m + offs_n_scale[None, :]
312
+ for ki in tl.static_range(K):
313
+ a = tl.load(As, mask=(src_idxs != -1)[:, None] & n_mask[None, :], other=0.0)
314
+ As += stride_a_k
315
+ if IN_MXFP8:
316
+ a_mx_scale = tl.load(AScales, mask=(src_idxs != -1)[:, None] & n_mask_scale[None, :])
317
+ AScales += stride_a_mx_k
318
+ a_mx_scale = (a_mx_scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
319
+ a_mx_scale = a_mx_scale.reshape([EXPT_PER_TOK, MX_SCALE_BLOCK_N, 1])
320
+ a = a.to(tl.float32).reshape([EXPT_PER_TOK, MX_SCALE_BLOCK_N, MXFP_BLOCK_SIZE])
321
+ a = (a_mx_scale * a).reshape([EXPT_PER_TOK, BLOCK_N])
322
+ acc += tl.sum(a, dtype=tl.float32, axis=0)
323
+ elif USE_FUSED_MIXED_PREC_ACC:
324
+ acc = _mixed_precision_accumulate_and_track_absmax(
325
+ acc, a, axis=0,
326
+ absmax_reg_name="my_abs_max" if USE_FUSED_ABSMAX and ki == K - 1 else None)
327
+ else:
328
+ acc += tl.sum(a, dtype=tl.float32, axis=0)
329
+ if not IN_MXFP8:
330
+ acc = acc * a_scale
331
+ if ACTIVATION_FN is not None:
332
+ out = ACTIVATION_FN(tl.reshape(acc, (1, BLOCK_N)), *activation_fn_args)
333
+ out = tl.reshape(out, (OUT_BLOCK_N, ))
334
+ else:
335
+ tl.static_assert(ACTIVATION_REDUCTION_N == 1,
336
+ "Activation reduction must be 1 if no activation fn is provided")
337
+ out = acc
338
+ if not USE_FUSED_ABSMAX and OutActualScale is not None:
339
+ local_max = tl.maximum(local_max, thread_local_absmax(out))
340
+ if OUT_MXFP8:
341
+ OUT_MX_SCALE_BLOCK_N: tl.constexpr = OUT_BLOCK_N // MXFP_BLOCK_SIZE
342
+ OUT_N_MX_BLOCK: tl.constexpr = (outN + MXFP_BLOCK_SIZE - 1) // MXFP_BLOCK_SIZE
343
+ offs_n_scale = off_n // BLOCK_N * OUT_MX_SCALE_BLOCK_N + tl.arange(0, OUT_MX_SCALE_BLOCK_N)[None, :]
344
+ n_mask_scale = offs_n_scale < OUT_N_MX_BLOCK
345
+ acc, acc_scale = EPILOGUE_FN(acc[None, :], n_mask[None, :], *epilogue_fn_args,
346
+ pid=row * tl.num_programs(1) + tl.program_id(1))
347
+ tl.static_assert(OUT_BLOCK_N % OUT_MX_SCALE_BLOCK_N == 0, "")
348
+ tl.store(OutActualScale + row * stride_out_mx_m + offs_n_scale * stride_out_mx_n, acc_scale, mask=n_mask_scale)
349
+ tl.store(Out + row * outN + offs_n[None, :], acc, mask=n_mask[None, :])
350
+ else:
351
+ out = float_to_flex(out, out_scale if OutExpectedScale is not None else None, None, OutChecksumScale,
352
+ None, Out, flexpoint_saturate_inf)
353
+ if EPILOGUE_FN is not None:
354
+ out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=Out.dtype.element_ty,
355
+ pid=row * tl.num_programs(1) + tl.program_id(1))
356
+ offs_n = off_n // ACTIVATION_REDUCTION_N + tl.arange(0, OUT_BLOCK_N)
357
+ n_mask = offs_n < outN
358
+ tl.store(Out + row * outN + offs_n, out, mask=n_mask)
359
+
360
+ persistent_m = tl.num_programs(0) < MBound
361
+ if not persistent_m and n_active_experts == 0:
362
+ # Skip updating the scale if there were no active experts and this is a non-persistent launch.
363
+ # The loop ran only once, and inside it we only stored zeros.
364
+ return
365
+
366
+ if USE_FUSED_ABSMAX:
367
+ local_max = tl.inline_asm_elementwise(
368
+ "mov.b32 $0, my_abs_max;",
369
+ "=r,r",
370
+ [local_max],
371
+ dtype=tl.float32,
372
+ is_pure=True,
373
+ pack=1,
374
+ )
375
+ local_max *= a_scale
376
+ if not OUT_MXFP8:
377
+ update_scale(local_max, OutActualScale, Out)
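For reference, the fused-absmax path above (zero a per-thread absmax register, fold each accumulated tile into it, then read it back and scale before `update_scale`) is roughly equivalent to the host-side sketch below. This is a conceptual reference only; the function name, tensor layout, and example shapes are hypothetical, not part of the kernel.

import torch

def finalize_rows_reference(a, src_idxs, a_scale=1.0):
    # Conceptual reference for one output row of the finalize kernel: sum the
    # K partial results of every non-masked source row in fp32, apply the
    # input scale, and track the absolute max that feeds the flexpoint
    # output-scale update.
    k = a.shape[0]
    out = torch.zeros(a.shape[-1], dtype=torch.float32)
    for idx in src_idxs:
        if idx != -1:                      # -1 marks an inactive expert slot
            for ki in range(k):
                out += a[ki, idx].float()
    out *= a_scale
    return out, out.abs().max()            # (finalized row, running absmax)

a = torch.randn(2, 8, 16)                  # (K split-k partials, rows, N)
row, absmax = finalize_rows_reference(a, src_idxs=[3, -1, 5])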
build/torch-universal/triton_kernels/matmul_ogs_details/_matmul_ogs.py ADDED
@@ -0,0 +1,464 @@
1
+ # isort: off
2
+ # fmt: off
3
+ import triton
4
+ import triton.language as tl
5
+ from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
6
+ from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
7
+ from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
8
+ from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
9
+ from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
10
+ from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
11
+
12
+
13
+ @triton.jit
14
+ def _zero_masked_rows(
15
+ pid_m, pid_n,
16
+ Y, stride_y_m, stride_y_n,
17
+ N,
18
+ ScatterSrcIndx, num_idxs,
19
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
20
+ offs_m = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M)
21
+ offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
22
+ src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
23
+ YPtrs = Y + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
24
+ mask_n = offs_n < N
25
+ mask = (src_idx == -1)[:, None] & mask_n[None, :]
26
+ tl.store(YPtrs, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask)
27
+
28
+
29
+ _matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2])
30
+ @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
31
+ repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
32
+ def _matmul_ogs(
33
+ Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
34
+ YExpectedScale, YActualScale, YChecksumScale,
35
+ stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
36
+ X, XPtr, stride_x_z, stride_x_m, stride_x_k,
37
+ XScale,
38
+ XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
39
+ W, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
40
+ WScale,
41
+ WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n,
42
+ B, stride_b_e, # Bias
43
+ NRows, M, N, K, # shapes
44
+ # expt data
45
+ Betas, Gammas,
46
+ GatherIndx,
47
+ ScatterSrcIndx, num_idxs,
48
+ WriteBackIndx, writeback_size,
49
+ ExptHist, ExptOffs, ExptOffsSum, ExptData,
50
+ # true grid size
51
+ batch_size, grid_m, grid_n,
52
+ # Out scale
53
+ out_alpha,
54
+ # fused activation function
55
+ ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
56
+ # epilogue transform
57
+ EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
58
+ # MoE config
59
+ N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
60
+ # precision config
61
+ MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
62
+ FLEXPOINT_SATURATE_INF: tl.constexpr,
63
+ PER_BATCH_SCALE: tl.constexpr,
64
+ # optimization config
65
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
66
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
67
+ # One of ["HOPPER", "BLACKWELL", None]
68
+ SWIZZLE_MX_VALUE: tl.constexpr,
69
+ # One of ["HOPPER", "BLACKWELL", None]
70
+ SWIZZLE_MX_SCALE: tl.constexpr,
71
+ EPILOGUE_SUBTILE: tl.constexpr,
72
+ EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
73
+ W_CACHE_MODIFIER: tl.constexpr,
74
+ NUM_SMS: tl.constexpr,
75
+ TOKENS_PER_EXPT_FOR_ANNOTATION=None,
76
+ UPCAST_INDICES: tl.constexpr = False,
77
+ DISABLE_Y_TMA: tl.constexpr = True,
78
+ SWAP_XW: tl.constexpr = False,
79
+ IS_EPILOGUE_DEQUANT_MXFP8: tl.constexpr = False):
80
+
81
+ Y = Out # Y is passed for the purposes of annotation; replace it with Out
82
+ is_w_microscaled: tl.constexpr = WMxScale is not None
83
+ MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
84
+ if is_w_microscaled:
85
+ w_type: tl.constexpr = W.dtype.element_ty
86
+ is_mxfp4: tl.constexpr = w_type == tl.uint8
87
+ tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
88
+ "mx_weight_ptr must be uint8 or fp8")
89
+ tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
90
+ tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
91
+ tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values")
92
+ else:
93
+ tl.static_assert(SWIZZLE_MX_VALUE is None)
94
+ tl.static_assert(SWIZZLE_MX_SCALE is None)
95
+ is_x_microscaled: tl.constexpr = XMxScale is not None
96
+ if is_x_microscaled:
97
+ x_type: tl.constexpr = X.dtype.element_ty
98
+ tl.static_assert(is_w_microscaled)
99
+ tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
100
+ tl.static_assert(XMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
101
+ tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
102
+ is_out_microscaled: tl.constexpr = stride_y_mx_z is not None
103
+
104
+ OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
105
+ yN = N // ACTIVATION_REDUCTION_N
106
+
107
+ pid = tl.program_id(0)
108
+ if ExptOffsSum is not None and XCD_SWIZZLE > 1:
109
+ # Determine how much padding there is on the expert data. This allows us to
110
+ # know the true grid size and avoid processing padding tiles.
111
+ padding_m = grid_m - tl.load(ExptOffsSum)
112
+ else:
113
+ padding_m: tl.constexpr = 0
114
+
115
+ HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
116
+ index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32
117
+
118
+ total_actual_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K
119
+ if padding_m > 0 and pid >= total_actual_tiles:
120
+ tl.device_assert(batch_size == 0)
121
+ pid_mn = pid - total_actual_tiles
122
+ if pid_mn < padding_m * grid_n:
123
+ pid_m, pid_n = swizzle2d(pid_mn, padding_m, grid_n, GROUP_M)
124
+
125
+ # set masked out rows to 0
126
+ if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
127
+ _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
128
+ return
129
+
130
+ # swizzle program ids
131
+ pid_emnk = pid
132
+ if XCD_SWIZZLE != 1:
133
+ pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE)
134
+ pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K)
135
+ pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K)
136
+ pid_k = pid_mnk % SPLIT_K
137
+ pid_mn = pid_mnk // SPLIT_K
138
+ pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M)
139
+ # For split-k, advance to the output k slice
140
+ if SPLIT_K > 1:
141
+ Y += pid_k.to(index_type) * stride_y_k
142
+ if is_out_microscaled:
143
+ YActualScale += pid_k.to(index_type) * stride_x_mx_k
144
+ # set masked out rows to 0
145
+ if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
146
+ _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
147
+ # unpack expert data
148
+ if ExptData is None:
149
+ tl.static_assert(M is not None)
150
+ expt_id, start_z, start_m, block_id = pid_e, pid_e, 0, pid_m
151
+ else:
152
+ tl.static_assert(M is None)
153
+ expt_data = tl.load(ExptData + pid_m)
154
+ if expt_data == -1:
155
+ return
156
+ expt_id = expt_data & 0x0000FFFF
157
+ block_id = expt_data >> 16
158
+ M = tl.load(ExptHist + expt_id)
159
+ start_m = tl.load(ExptOffs + expt_id)
160
+ start_z = 0
161
+ expt_id, block_id = expt_id.to(index_type), block_id.to(index_type)
162
+ start_m, start_z = start_m.to(index_type), start_z.to(index_type)
163
+ pid_n, pid_k = pid_n.to(index_type), pid_k.to(index_type)
164
+ # A (i.e. X) pointers
165
+ offs_x_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
166
+ offs_x_m = tl.max_contiguous(tl.multiple_of(offs_x_m % M, BLOCK_M), BLOCK_M)
167
+ X += start_z * stride_x_z
168
+ if GatherIndx is None:
169
+ X += start_m * stride_x_m
170
+ else:
171
+ GatherIndx += start_m
172
+ # no need to bounds-check here because `offs_x_m` wraps around the M dim
173
+ offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT
174
+ offs_k = BLOCK_K * pid_k + tl.arange(0, BLOCK_K)
175
+ XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k
176
+
177
+ # TODO: refactor if/else when triton front end improves
178
+ if is_w_microscaled:
179
+ if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
180
+ tl.static_assert(is_mxfp4, "Only mxfp4 is supported for HOPPER swizzling")
181
+ tl.static_assert(not is_x_microscaled)
182
+ # We pack 2 fp4 values into a byte, but we divide the dimension by 2
183
+ # when swizzling
184
+ W_K_DIVISOR: tl.constexpr = 1
185
+ W_K_MULTIPLIER: tl.constexpr = 2
186
+ W_N_DIVISOR: tl.constexpr = 4
187
+ else:
188
+ # We pack 2 fp4 values into a byte
189
+ W_K_DIVISOR: tl.constexpr = 2 if is_mxfp4 else 1
190
+ W_K_MULTIPLIER: tl.constexpr = 1
191
+ W_N_DIVISOR: tl.constexpr = 1
192
+
193
+ PACKED_BLOCK_K_W: tl.constexpr = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER
194
+ PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR
195
+ MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
196
+
197
+ WMxScale += expt_id * stride_w_mx_e
198
+
199
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
200
+ tl.static_assert(BLOCK_N % 128 == 0)
201
+ tl.static_assert(MX_SCALE_BLOCK_K % 4 == 0)
202
+ PACKED_MX_BLOCK: tl.constexpr = (MX_SCALE_BLOCK_K // 4) * 32 * 4 * 4
203
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 128
204
+ stride_scale_k: tl.constexpr = 1
205
+ elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
206
+ n_warps: tl.constexpr = tl.extra.cuda.num_warps()
207
+ tl.static_assert(BLOCK_N % (2 * n_warps * 2 * 8) == 0)
208
+ tl.static_assert(MX_SCALE_BLOCK_K % 2 == 0)
209
+ PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32
210
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 32
211
+ stride_scale_k = stride_w_mx_k
212
+ else:
213
+ PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K
214
+ SCALE_BLOCK_N: tl.constexpr = BLOCK_N
215
+ stride_scale_k = stride_w_mx_k
216
+ offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N
217
+ offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N)
218
+ # K dimension must be the last dimension for the scales
219
+ offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK)
220
+ WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n
221
+ else:
222
+ WMxScalePtrs = None
223
+ offs_k_scale = None
224
+ W_K_DIVISOR: tl.constexpr = 1
225
+ W_K_MULTIPLIER: tl.constexpr = 1
226
+ W_N_DIVISOR: tl.constexpr = 1
227
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
228
+ PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N
229
+
230
+ # B (i.e. W) pointers
231
+ offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W)
232
+ offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)
233
+
234
+ if is_x_microscaled:
235
+ XMxScale += start_z.to(index_type) * stride_x_mx_z
236
+ if GatherIndx is None:
237
+ XMxScale += start_m * stride_x_mx_m
238
+ offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
239
+ XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k
240
+ else:
241
+ XMxScalePtrs = None
242
+
243
+ offs_w_k = PACKED_BLOCK_K_W * pid_k + tl.arange(0, PACKED_BLOCK_K_W)
244
+ W += expt_id * stride_w_e
245
+ WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n)
246
+ # compute output
247
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
248
+ for k in range(K, BLOCK_K * pid_k, -(BLOCK_K * SPLIT_K)):
249
+ if EVEN_K:
250
+ mask_k = tl.full([BLOCK_K], True, dtype=tl.int1)
251
+ mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1)
252
+ if is_w_microscaled and SWIZZLE_MX_SCALE is None:
253
+ mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1)
254
+ if is_x_microscaled:
255
+ mask_x_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
256
+ else:
257
+ mask_k = offs_k < k
258
+ mask_k_w = offs_w_k < ((k // W_K_DIVISOR) * W_K_MULTIPLIER)
259
+ if is_w_microscaled and SWIZZLE_MX_SCALE is None:
260
+ mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < k
261
+ if is_x_microscaled:
262
+ mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < k
263
+
264
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
265
+ w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER)
266
+ if is_w_microscaled:
267
+ x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
268
+ w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
269
+
270
+ if is_x_microscaled:
271
+ x_scales = tl.load(XMxScalePtrs, mask=mask_x_k_scale[None, :])
272
+ elif x_format == "fp16" or x_format == "bf16":
273
+ x_scales: tl.constexpr = None
274
+ else:
275
+ # Scale of 1 in E8M0 format
276
+ x_scales = tl.full((BLOCK_M, MX_SCALE_BLOCK_K), 127, dtype=tl.uint8)
277
+
278
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
279
+ w_scales = unswizzle_mx_scale_bw(tl.load(WMxScalePtrs))
280
+ elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
281
+ # Handshake with the swizzling code
282
+ num_warps: tl.constexpr = tl.extra.cuda.num_warps()
283
+ w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps)
284
+ else:
285
+ w_scales = tl.load(WMxScalePtrs, mask=mask_k_scale[None, :])
286
+
287
+ if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
288
+ # Handshake with the swizzling code
289
+ tl.static_assert(x_format == "bf16")
290
+ tl.static_assert(w_format == "e2m1")
291
+ w = mxfp4_to_bf16_triton(w.trans(), w_scales, 1)
292
+ tl.static_assert(w.dtype == tl.bfloat16)
293
+ acc = acc.trans()
294
+ x = x.trans()
295
+ # w = w.trans()
296
+ acc = tl.dot(w, x, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
297
+ acc = acc.trans()
298
+ else:
299
+ acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True)
300
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
301
+ WMxScalePtrs += (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_w_mx_k
302
+ else:
303
+ WMxScalePtrs += (PACKED_MX_BLOCK * SPLIT_K) * stride_w_mx_k
304
+ if is_x_microscaled:
305
+ XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k
306
+ else:
307
+ acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
308
+ XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k
309
+ WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k
310
+ # bias + scale
311
+ offs_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
312
+ offs_y_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
313
+ mask_m = offs_m < M
314
+ mask_n = offs_y_n < N
315
+ if B is not None:
316
+ BPtrs = B + expt_id * stride_b_e + offs_y_n
317
+ if pid_k == 0:
318
+ bias = tl.load(BPtrs, mask=mask_n, other=0)
319
+ else:
320
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
321
+ else:
322
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
323
+ if Betas is not None:
324
+ betas = tl.load(Betas + start_m + offs_m, mask=mask_m, other=0.0)
325
+ else:
326
+ betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
327
+ if Gammas is not None:
328
+ gammas = tl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0)
329
+ else:
330
+ gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
331
+ # flexpoint
332
+ x_scale = load_scale(XScale)
333
+ if PER_BATCH_SCALE:
334
+ w_scale = load_scale(WScale + expt_id)
335
+ else:
336
+ w_scale = load_scale(WScale)
337
+ acc *= x_scale * w_scale
338
+ acc = acc + bias[None, :] * betas[:, None]
339
+ if out_alpha is not None:
340
+ acc *= out_alpha
341
+ if ACTIVATION_FN is not None:
342
+ out = ACTIVATION_FN(acc, *activation_fn_args)
343
+ tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
344
+ offs_y_n = OUT_BLOCK_N * pid_n + tl.arange(0, OUT_BLOCK_N)
345
+ mask_n = offs_y_n < yN
346
+ else:
347
+ tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
348
+ out = acc
349
+ out *= gammas[:, None]
350
+ # write-back
351
+ Y += start_z.to(index_type) * stride_y_z
352
+ if WriteBackIndx is not None:
353
+ WriteBackIndx += start_m
354
+ dst_idx = tl.load(WriteBackIndx + offs_m, mask=start_m + offs_m < writeback_size, other=-1)
355
+ mask_m = mask_m & (dst_idx != -1)
356
+ offs_y_m = dst_idx
357
+ else:
358
+ Y += start_m * stride_y_m
359
+ offs_y_m = offs_m
360
+
361
+ YPtrs = Y + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n
362
+ mask = mask_m[:, None] & mask_n[None, :]
363
+ if is_out_microscaled:
364
+ MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE
365
+ N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
366
+ tl.static_assert(EPILOGUE_FN is not None)
367
+ out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args)
368
+ tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "")
369
+ offs_y_n_scale = MX_SCALE_BLOCK_N * pid_n + tl.arange(0, MX_SCALE_BLOCK_N)
370
+ mask_n_scale = offs_y_n_scale < N_MX_BLOCK
371
+ YActualScale += start_z.to(index_type) * stride_y_mx_z
372
+ if WriteBackIndx is None:
373
+ YActualScale += start_m * stride_y_mx_m
374
+ YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
375
+ else:
376
+ YActualScalePtrs = YActualScale + (offs_y_m - NRows).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
377
+ tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
378
+ else:
379
+ out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF)
380
+ if EPILOGUE_FN is not None and not IS_EPILOGUE_DEQUANT_MXFP8:
381
+ out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty)
382
+ tl.store(YPtrs, out, mask=mask)
383
+
384
+
385
+ # Imagine N_EXPTS_ACT = 4, n_final_rows = 5, and n_scratchpad_rows = 8.
386
+ # Also imagine scatter_indx.src_indx is:
387
+ # (number of active experts per final row)
388
+ # -1 -1 0 -1 1
389
+ # -1 2 -1 -1 1
390
+ # 1 3 -1 -1 2
391
+ # -1 4 5 6 3
392
+ # -1 -1 -1 -1 0 (this row is unused)
393
+ #
394
+ # Then, final rows 0 and 1 can be written directly to the final tensor
+ # (each of them has exactly one active expert).
395
+ # In this case, WriteBackIndx looks like:
396
+ # [0] = 0 : intermediate row 0 is written directly to final row 0
397
+ # [1] = 5+1=6 : scratchpad starts at offset 5
398
+ # [2] = 1 : intermediate row 2 is written directly to final row 1
399
+ # [3] = 5+3=8
400
+ # [4] = 5+4=9
401
+ # [5] = 5+5=10
402
+ # [6] = 5+6=11
403
+ # [7] = -1 : unused (there are only seven intermediate rows)
404
+ @triton.jit
405
+ def _compute_writeback_idx(
406
+ WriteBackIndx,
407
+ FinalizeScatterIdxs,
408
+ ScatterDstIndx, ScatterSrcIndx,
409
+ n_final_rows, n_scratchpad_rows,
410
+ BLOCK_M: tl.constexpr,
411
+ N_EXPTS_ACT: tl.constexpr,
412
+ ):
413
+ tl.static_assert(N_EXPTS_ACT > 1)
414
+
415
+ pid_m = tl.program_id(0)
416
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
417
+ mask_m = offs_m < n_scratchpad_rows
418
+ dst_idxs = tl.load(ScatterDstIndx + offs_m, mask=mask_m, other=-1)
419
+ # Load corresponding rows in ScatterSrcIndx.
420
+ mask = dst_idxs != -1
421
+ src_offs = (dst_idxs // N_EXPTS_ACT) * N_EXPTS_ACT
422
+ src_offs = src_offs[:, None] + tl.arange(0, N_EXPTS_ACT)[None, :]
423
+ src_idxs = tl.load(ScatterSrcIndx + src_offs, mask=mask[:, None], other=-1)
424
+ # Compute the number of actually active experts.
425
+ is_src_active = (src_idxs != -1).to(tl.int32)
426
+ has_one_active = tl.sum(is_src_active, axis=1) == 1
427
+ # Compute the writeback index.
428
+ wb_idx = tl.where(has_one_active, dst_idxs // N_EXPTS_ACT, n_final_rows + offs_m)
429
+ wb_idx = tl.where(mask, wb_idx, -1)
430
+ tl.store(WriteBackIndx + offs_m, wb_idx, mask=mask_m)
431
+
432
+ if pid_m >= ((n_final_rows + BLOCK_M - 1) // BLOCK_M):
433
+ return
434
+
435
+ mask_m = offs_m < n_final_rows
436
+ src_offs = offs_m[:, None] * N_EXPTS_ACT + tl.arange(0, N_EXPTS_ACT)[None, :]
437
+ src_idxs = tl.load(ScatterSrcIndx + src_offs, mask=mask_m[:, None], other=-1)
438
+ is_src_active = (src_idxs != -1).to(tl.int32)
439
+ num_src_active = tl.sum(is_src_active, axis=1)
440
+
441
+ need_finalize_scatter = mask_m & (num_src_active != 1)
442
+ finalize_scatter_count = tl.sum(need_finalize_scatter.to(tl.int32))
443
+ if finalize_scatter_count == 0:
444
+ return
445
+ pp_off = tl.atomic_add(FinalizeScatterIdxs + n_final_rows + n_scratchpad_rows, finalize_scatter_count)
446
+
447
+ # need_finalize_scatter = [1, 0, 0, 1, 1, 0, 1, 0, 1]
448
+ # arange = [0, 1, 2, 3, 4, 5, 6, 7, 8]
449
+ arange = tl.arange(0, BLOCK_M)
450
+ # idxs = [0, _, _, 3, 4, _, 6, _, 8]
451
+ last = BLOCK_M - 1
452
+ idxs = tl.where(need_finalize_scatter, arange, last)
453
+ # idxs = [0, 3, 4, 6, 8, _, _, _, _]
454
+ idxs = tl.sort(idxs)
455
+ # r = offs_m
456
+ # d = [r[0], r[3], r[4], r[6], r[8], r[-1], r[-1], r[-1], r[-1]]
457
+ d = tl.gather(offs_m, idxs, axis=0)
458
+ s = tl.gather(src_idxs, idxs.expand_dims(1).broadcast_to(src_idxs.shape), axis=0)
459
+ # store destination indices
460
+ Ptr = FinalizeScatterIdxs + pp_off
461
+ tl.store(Ptr + arange, d, mask=arange < finalize_scatter_count)
462
+ # store src indices
463
+ Ptr = FinalizeScatterIdxs + n_final_rows + pp_off * N_EXPTS_ACT
464
+ tl.store(Ptr + N_EXPTS_ACT * arange[:, None] + tl.arange(0, N_EXPTS_ACT)[None, :], s, mask=(arange < finalize_scatter_count)[:, None])
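The worked example in the comment above can be cross-checked with a small host-side sketch of the same mapping. The `ScatterDstIndx` values below are inferred from that example (the flat position of each intermediate row inside `ScatterSrcIndx`), and the helper name is hypothetical; this mirrors what `_compute_writeback_idx` produces but is not the kernel itself.

import numpy as np

def writeback_idx_reference(scatter_src_indx, scatter_dst_indx, n_final_rows, n_expts_act):
    # An intermediate row goes straight to its final row when that final row
    # has exactly one active expert; otherwise it is staged in the scratchpad
    # region that starts right after the n_final_rows final rows.
    src = np.asarray(scatter_src_indx).reshape(n_final_rows, n_expts_act)
    wb = np.full(len(scatter_dst_indx), -1, dtype=np.int64)
    for i, dst in enumerate(scatter_dst_indx):
        if dst == -1:
            continue  # unused intermediate row
        final_row = dst // n_expts_act
        n_active = int((src[final_row] != -1).sum())
        wb[i] = final_row if n_active == 1 else n_final_rows + i
    return wb

src = [[-1, -1,  0, -1],
       [-1,  2, -1, -1],
       [ 1,  3, -1, -1],
       [-1,  4,  5,  6],
       [-1, -1, -1, -1]]
dst = [2, 8, 5, 9, 13, 14, 15, -1]     # inferred from the comment's example
print(writeback_idx_reference(src, dst, n_final_rows=5, n_expts_act=4))
# -> [ 0  6  1  8  9 10 11 -1]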
build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py ADDED
@@ -0,0 +1,505 @@
1
+ # isort: off
2
+ # fmt: off
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+ from triton_kernels import target_info
7
+ from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
8
+ from triton_kernels.numerics_details.flexpoint import (
9
+ float_to_flex,
10
+ load_scale,
11
+ nan_propagating_absmax_reduce,
12
+ compute_scale,
13
+ )
14
+ from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
15
+ from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
16
+
17
+
18
+ @tl.constexpr_function
19
+ def cuda_capability_geq(major, minor):
20
+ return target_info.cuda_capability_geq(major, minor)
21
+
22
+ @tl.constexpr_function
23
+ def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
24
+ if isinstance(tensor_or_desc, tl.tensor):
25
+ return tensor_or_desc.dtype.element_ty
26
+ elif isinstance(tensor_or_desc, tl.tensor_descriptor):
27
+ return tensor_or_desc.dtype
28
+ else:
29
+ raise ValueError(f"Invalid type: {type(tensor_or_desc)}")
30
+
31
+
32
+ @triton.jit
33
+ def _tma_load_2d(desc, offs, transpose: tl.constexpr = False):
34
+ if len(desc.shape) == 2 and len(offs) == 3:
35
+ tl.device_assert(offs[0] == 0, "2D TMA load requires Z offset to be 0")
36
+ offs = offs[1:]
37
+ if transpose:
38
+ offs = offs[:-2] + [offs[-1], offs[-2]]
39
+ res = desc.load(offs)
40
+ res = tl.reshape(res, desc.block_shape[-2:])
41
+ if transpose:
42
+ res = tl.trans(res)
43
+ return res
44
+
45
+
46
+ # Helper function to recreate a TMA desc with the same fields, but with a new pointer and optional new shape.
47
+ @triton.jit
48
+ def _update_tensor_desc(desc, ptr, shape=None):
49
+ return tl.make_tensor_descriptor(
50
+ ptr,
51
+ shape=shape or desc.shape,
52
+ # last dim must be constexpr 1; reflecting the old descriptor drops the constexpr
53
+ strides=desc.strides[:-1] + [tl.constexpr(1)],
54
+ block_shape=desc.block_shape,
55
+ )
56
+
57
+
58
+ @triton.jit
59
+ def _load_tile_attrs(
60
+ tile_id, num_tiles, grid_m, grid_n, padding_m,
61
+ M, ExptData, ExptHist, ExptOffs,
62
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, SPLIT_K: tl.constexpr,
63
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr):
64
+ # unpack and swizzle program ids
65
+ pid_emnk = tile_id
66
+ if XCD_SWIZZLE != 1:
67
+ pid_emnk = xcd_swizzle(pid_emnk, num_tiles // SPLIT_K, XCD_SWIZZLE)
68
+ pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K)
69
+ pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K)
70
+ if SPLIT_K > 1:
71
+ pid_k = pid_mnk % SPLIT_K
72
+ pid_mn = pid_mnk // SPLIT_K
73
+ else:
74
+ pid_k: tl.constexpr = 0
75
+ pid_mn = pid_mnk
76
+ pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M)
77
+
78
+ # unpack expert data
79
+ if ExptData is None:
80
+ tl.static_assert(M is not None)
81
+ expt_id, start_z, start_m, block_id, eM = pid_e, pid_e, 0, pid_m, -1
82
+ else:
83
+ tl.static_assert(M is None)
84
+ expt_data = tl.load(ExptData + pid_m)
85
+ expt_id = expt_data & 0x0000FFFF
86
+ block_id = expt_data >> 16
87
+ eM = tl.load(ExptHist + expt_id)
88
+ start_m = tl.load(ExptOffs + expt_id)
89
+ start_z = 0
90
+
91
+ off_m = BLOCK_M * block_id
92
+ off_n = BLOCK_N * pid_n
93
+
94
+ return expt_id, start_z, start_m, eM, off_m, off_n, pid_k
95
+
96
+
97
+ @triton.jit
98
+ def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
99
+ mask = mask & (offs < writeback_size)
100
+ offs = tl.load(WriteBackIndx + offs, mask=mask, other=-1)
101
+ mask = offs != -1
102
+ return (offs, mask)
103
+
104
+
105
+ _matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2])
106
+ @triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
107
+ repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
108
+ def _p_matmul_ogs(
109
+ Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
110
+ YExpectedScale, YActualScale, YChecksumScale,
111
+ stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
112
+ X, XPtr, stride_x_z, stride_x_m, stride_x_k,
113
+ XScale,
114
+ XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
115
+ W, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
116
+ WScale,
117
+ MxScale, stride_mx_e, stride_mx_k, stride_mx_n,
118
+ B, stride_b_e, # Bias
119
+ NRows, M, N, K, # shapes
120
+ # expt data
121
+ Betas, Gammas,
122
+ GatherIndx,
123
+ ScatterSrcIndx, num_idxs,
124
+ WriteBackIndx, writeback_size,
125
+ ExptHist, ExptOffs, ExptOffsSum, ExptData,
126
+ # true grid size
127
+ batch_size, grid_m, grid_n,
128
+ # Out scale
129
+ out_alpha,
130
+ # fused activation function
131
+ ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
132
+ # epilogue transform
133
+ EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
134
+ # MoE config
135
+ N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
136
+ # precision config
137
+ MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
138
+ FLEXPOINT_SATURATE_INF: tl.constexpr,
139
+ PER_BATCH_SCALE: tl.constexpr,
140
+ # optimization config
141
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
142
+ GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
143
+ # NYI: Must be None
144
+ SWIZZLE_MX_VALUE: tl.constexpr,
145
+ # One of ["BLACKWELL", None]
146
+ SWIZZLE_MX_SCALE: tl.constexpr,
147
+ EPILOGUE_SUBTILE: tl.constexpr,
148
+ EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
149
+ W_CACHE_MODIFIER: tl.constexpr,
150
+ NUM_SMS: tl.constexpr,
151
+ TOKENS_PER_EXPT_FOR_ANNOTATION=None,
152
+ UPCAST_INDICES: tl.constexpr = False,
153
+ DISABLE_Y_TMA: tl.constexpr = False,
154
+ SWAP_XW: tl.constexpr = False,
155
+ IS_EPILOGUE_DEQUANT_MXFP8: tl.constexpr = False):
156
+ tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling")
157
+ Y = Out # Y is passed for the purposes of annotation; replace it with Out
158
+
159
+ is_microscaled_format: tl.constexpr = MxScale is not None
160
+ MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
161
+ if is_microscaled_format:
162
+ w_type: tl.constexpr = get_dtype(W)
163
+ tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
164
+ "mx_weight_ptr must be uint8")
165
+ tl.static_assert(get_dtype(MxScale) == tl.uint8, "mx_scale_ptr must be uint8")
166
+ tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
167
+ tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")
168
+
169
+ # We pack 2 fp4 values into a byte
170
+ W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1
171
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR
172
+ MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
173
+ else:
174
+ W_PACK_DIVISOR: tl.constexpr = 1
175
+ MX_SCALE_BLOCK_K: tl.constexpr = 1
176
+ PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
177
+ tl.static_assert(SWIZZLE_MX_SCALE is None)
178
+
179
+ if ExptOffsSum is not None:
180
+ # Determine how much padding there is on the expert data. This allows us to
181
+ # know the true grid size and avoid processing padding tiles.
182
+ padding_m = grid_m - tl.load(ExptOffsSum)
183
+ else:
184
+ padding_m: tl.constexpr = 0
185
+
186
+ HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
187
+ index_type: tl.constexpr = tl.int64
188
+
189
+ if EPILOGUE_SUBTILE is None:
190
+ SUBTILE_FACTOR: tl.constexpr = 1
191
+ else:
192
+ SUBTILE_FACTOR: tl.constexpr = EPILOGUE_SUBTILE
193
+ EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // SUBTILE_FACTOR
194
+ OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N
195
+ yN = N // ACTIVATION_REDUCTION_N
196
+
197
+ # set masked out rows to 0
198
+ if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
199
+ # Iterate with reversed pids so that later pids will get more tiles if the number of
200
+ # tiles isn't evenly divisible by the number of SMs.
201
+ # The main loop after this iterates in the forward direction such that earlier
202
+ # pids get more tiles if the number of tiles isn't evenly divisible.
203
+ # This helps balance the work across the SMs.
204
+ for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS):
205
+ pid_k = pid_mnk % SPLIT_K
206
+ pid_mn = pid_mnk // SPLIT_K
207
+ pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M)
208
+
209
+ z = tl.zeros([BLOCK_M, BLOCK_N // ACTIVATION_REDUCTION_N], dtype=tl.float32)
210
+ offs_m = z.shape[0] * pid_m + tl.arange(0, z.shape[0])
211
+ offs_n = z.shape[1] * pid_n + tl.arange(0, z.shape[1])
212
+ src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
213
+ YPtrs = Y + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
214
+ mask_n = offs_n < yN
215
+ mask = (src_idx == -1)[:, None] & mask_n[None, :]
216
+ tl.store(YPtrs + pid_k * stride_y_k, z, mask=mask)
217
+
218
+ USE_FLEXPOINT_SCALE: tl.constexpr = YActualScale is not None or YChecksumScale is not None
219
+
220
+ USE_GATHER_TMA: tl.constexpr = GatherIndx is not None and cuda_capability_geq(10, 0)
221
+ X_USE_LOAD_TMA: tl.constexpr = GatherIndx is None and isinstance(X, tl.tensor_descriptor)
222
+ USE_SCATTER_TMA: tl.constexpr = (cuda_capability_geq(10, 0) and HAS_FUSED_SCATTER) and not DISABLE_Y_TMA
223
+ INT_MAX: tl.constexpr = 2147483647
224
+
225
+ if USE_SCATTER_TMA:
226
+ y_desc = tl.make_tensor_descriptor(
227
+ Y,
228
+ # No masking on the M dimension because we manually mask by setting indices to INT_MAX
229
+ shape=[INT_MAX - 1, yN],
230
+ strides=[stride_y_m, stride_y_n],
231
+ block_shape=[1, OUT_BLOCK_N],
232
+ )
233
+
234
+ k_tiles = tl.cdiv(K, BLOCK_K * SPLIT_K)
235
+ num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K
236
+
237
+ # If true, do not share loop-carried variables between the prologue and the
238
+ # epilogue to enable better pipelining with mmav5
239
+ INDEPENDENT_EPILOGUE: tl.constexpr = cuda_capability_geq(10, 0)
240
+
241
+ # start negative; will be incremented at the top of the loop
242
+ if INDEPENDENT_EPILOGUE:
243
+ tile_id1 = tl.program_id(0) - NUM_SMS
244
+
245
+ # Keep track of local max for updating flexpoint scales.
246
+ THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
247
+ local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32)
248
+
249
+ DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256
250
+ # Enable warp specialization when all loads are TMA loads.
251
+ WARP_SPECIALIZE: tl.constexpr = (USE_GATHER_TMA or X_USE_LOAD_TMA)
252
+
253
+ for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=WARP_SPECIALIZE):
254
+ expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs(
255
+ tile_id, num_tiles, grid_m, grid_n, padding_m,
256
+ M, ExptData, ExptHist, ExptOffs,
257
+ BLOCK_M, BLOCK_N, SPLIT_K,
258
+ GROUP_M, XCD_SWIZZLE)
259
+
260
+ # Base pointers and offsets.
261
+ if not USE_GATHER_TMA and not X_USE_LOAD_TMA:
262
+ XBase = X + start_z.to(index_type) * stride_x_z
263
+ offs_x_k = tl.arange(0, BLOCK_K)[None, :] * stride_x_k
264
+ if SPLIT_K > 1:
265
+ offs_x_k += pid_k.to(index_type) * BLOCK_K * stride_x_k
266
+
267
+ if not X_USE_LOAD_TMA:
268
+ offs_m = off_m + tl.arange(0, BLOCK_M)
269
+ mask_m = offs_m < (M if M is not None else eM)
270
+ if USE_GATHER_TMA:
271
+ # Mask the gather indices and load -1 instead. TMA will handle OOB accesses.
272
+ if ExptData is None:
273
+ offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m)
274
+ # Bump rows to account for the Z offset.
275
+ offs_x_m += start_z * (stride_x_z // stride_x_m)
276
+ offs_x_m = tl.where(mask_m, offs_x_m, -1)
277
+ else:
278
+ offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m,
279
+ mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT
280
+ else:
281
+ if M is not None:
282
+ offs_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M)
283
+ else:
284
+ offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M)
285
+ # no need to bounds-check here because `offs_m` wraps around the M dim
286
+ offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT
287
+ offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m
288
+
289
+ acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32)
290
+ for ki in tl.range(k_tiles, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER):
291
+ off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K
292
+ off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K
293
+ off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K
294
+
295
+ if USE_GATHER_TMA:
296
+ x = X.gather(offs_x_m, off_k)
297
+ elif X_USE_LOAD_TMA:
298
+ x = _tma_load_2d(X, [start_z, start_m + off_m, off_k])
299
+ else:
300
+ XPtrs = XBase + offs_x_m + offs_x_k
301
+ XBase += BLOCK_K * SPLIT_K * stride_x_k
302
+ mask_k = tl.arange(0, BLOCK_K) < K - off_k
303
+ if EVEN_K:
304
+ if SPLIT_K > 1:
305
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
306
+ else:
307
+ x = tl.load(XPtrs)
308
+ else:
309
+ x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
310
+
311
+ w = _tma_load_2d(W, [expt_id, off_k_w, off_n], transpose=W_TRANSPOSE)
312
+
313
+ if is_microscaled_format:
314
+ x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
315
+ mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
316
+ if x_format == "fp16" or x_format == "bf16":
317
+ x_scales: tl.constexpr = None
318
+ else:
319
+ x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8)
320
+ if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
321
+ flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128)
322
+ w_scales = MxScale.load([0, flattened_expt_n_idx, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
323
+ w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
324
+ w_scales = unswizzle_mx_scale_bw(w_scales)
325
+ else:
326
+ w_scales = _tma_load_2d(MxScale, [expt_id, off_k_mx, off_n]).T
327
+ if SWAP_XW:
328
+ acc = tl.dot_scaled(w.T, w_scales, mx_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
329
+ else:
330
+ acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True)
331
+ else:
332
+ if SWAP_XW:
333
+ acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
334
+ else:
335
+ acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
336
+
337
+ if INDEPENDENT_EPILOGUE:
338
+ tile_id1 += NUM_SMS
339
+ expt_id1, start_z1, start_m1, eM1, off_m1, off_n1, pid_k1 = _load_tile_attrs(
340
+ tile_id1, num_tiles, grid_m, grid_n, padding_m,
341
+ M, ExptData, ExptHist, ExptOffs,
342
+ BLOCK_M, BLOCK_N, SPLIT_K,
343
+ GROUP_M, XCD_SWIZZLE)
344
+ else:
345
+ tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z, start_m, eM
346
+ off_m1, off_n1, pid_k1 = off_m, off_n, pid_k
347
+
348
+ # Determine output row offsets and mask
349
+ offs_m = off_m1 + tl.arange(0, BLOCK_M)
350
+ mask_m = offs_m < M if M is not None else offs_m < eM1
351
+ if HAS_FUSED_SCATTER:
352
+ offs_y_m, mask_m = _load_writeback_idx_and_mask(
353
+ WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
354
+ # Later, mask out the acc for computing flexpoint scales.
355
+ MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
356
+
357
+ if USE_SCATTER_TMA and SPLIT_K > 1:
358
+ # Compute the split k offset in number of rows, and add it to offs_y_m.
359
+ # This allows us to write to the correct slice in the output tensor while using
360
+ # a 2D TMA scatter.
361
+ tl.device_assert(stride_y_k // stride_y_m == tl.cdiv(stride_y_k, stride_y_m))
362
+ split_k_row_offs = pid_k1 * (stride_y_k // stride_y_m)
363
+ offs_y_m = tl.where(mask_m, offs_y_m + split_k_row_offs, offs_y_m)
364
+ else:
365
+ offs_y_m = start_m1 + offs_m
366
+
367
+ if USE_GATHER_TMA:
368
+ MASK_ACC: tl.constexpr = False
369
+ else:
370
+ # Later, mask out the acc for computing flexpoint scales.
371
+ MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
372
+
373
+ # TMA is faster on Blackwell if a SWAP_XW transpose is not needed, or when we need registers to mask out the acc.
374
+ # Contrary to the SWAP_XW case, having a fused activation function tends to make TMA faster again.
375
+ # For the ideal optimization, this would depend on what the activation function is doing.
376
+ Y_USE_TMA: tl.constexpr = (MASK_ACC or cuda_capability_geq(10, 0)) and not (
377
+ DISABLE_Y_TMA or (SWAP_XW and ACTIVATION_FN is None))
378
+
379
+ YBase = Y + start_z1.to(index_type) * stride_y_z + start_m1.to(index_type) * stride_y_m
380
+ if USE_SCATTER_TMA:
381
+ if ExptData is None: # start_z1 may change; update the descriptor
382
+ y_desc = _update_tensor_desc(y_desc, YBase)
383
+ elif not HAS_FUSED_SCATTER and Y_USE_TMA:
384
+ y_desc = tl.make_tensor_descriptor(
385
+ YBase + pid_k1.to(index_type) * stride_y_k,
386
+ shape=[M if M is not None else eM1, yN],
387
+ strides=[stride_y_m, stride_y_n],
388
+ block_shape=[BLOCK_M, OUT_BLOCK_N],
389
+ )
390
+
391
+ # bias + scale
392
+ offs_y_n = off_n1 + tl.arange(0, BLOCK_N)
393
+ mask_n = offs_y_n < N
394
+ if B is not None:
395
+ BPtrs = B + expt_id1 * stride_b_e + offs_y_n
396
+ if pid_k1 == 0:
397
+ bias = tl.load(BPtrs, mask=mask_n, other=0)
398
+ else:
399
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
400
+ else:
401
+ bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
402
+ if Betas is not None:
403
+ betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0)
404
+ else:
405
+ betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
406
+ if Gammas is not None:
407
+ gammas = tl.load(Gammas + start_m1 + offs_m, mask=mask_m, other=0.0)
408
+ else:
409
+ gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
410
+ x_scale = load_scale(XScale)
411
+ if PER_BATCH_SCALE:
412
+ w_scale = load_scale(WScale + expt_id1)
413
+ else:
414
+ w_scale = load_scale(WScale)
415
+
416
+ accs = (acc,)
417
+ biases = (bias,)
418
+
419
+ if SUBTILE_FACTOR >= 2:
420
+ acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
421
+ accs = (acc0, acc1)
422
+ bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
423
+ biases = (bias0, bias1)
424
+
425
+ if SUBTILE_FACTOR >= 4:
426
+ acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
427
+ acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
428
+ accs = (acc00, acc01, acc10, acc11)
429
+ bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
430
+ bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
431
+ biases = (bias00, bias01, bias10, bias11)
432
+
433
+ tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
434
+ tl.static_assert(len(accs) == SUBTILE_FACTOR)
435
+
436
+ for a_i in tl.static_range(len(accs)):
437
+ acc_tile = accs[a_i]
438
+ acc_tile *= x_scale * w_scale
439
+
440
+ if SWAP_XW:
441
+ acc_tile = acc_tile.T
442
+
443
+ acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
444
+ if out_alpha is not None:
445
+ acc_tile *= out_alpha
446
+
447
+ if ACTIVATION_FN is not None:
448
+ out = ACTIVATION_FN(acc_tile, *activation_fn_args)
449
+ tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
450
+ else:
451
+ tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
452
+ out = acc_tile
453
+
454
+ out *= gammas[:, None]
455
+
456
+ if MASK_ACC:
457
+ out = tl.where(mask_m[:, None], out, 0.0)
458
+ # Flexpoint
459
+ out_view = tl.reshape(
460
+ out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True)
461
+ local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0))
462
+ out = float_to_flex(
463
+ out, YExpectedScale,
464
+ None, # ActualScale: local absmax is tracked and updated after the loop
465
+ YChecksumScale,
466
+ None, # mask: out is manually masked to 0
467
+ Y, FLEXPOINT_SATURATE_INF)
468
+ if EPILOGUE_FN is not None:
469
+ out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=Y.dtype.element_ty, pid=len(accs)*tile_id1 + a_i)
470
+
471
+ out_off_n = off_n1 // ACTIVATION_REDUCTION_N + a_i * OUT_BLOCK_N
472
+ if USE_SCATTER_TMA:
473
+ # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
474
+ # there shouldn't be any other negative values.
475
+ offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True)
476
+ y_desc.scatter(out.to(Y.dtype.element_ty), offs_y_m, out_off_n)
477
+ elif not HAS_FUSED_SCATTER and Y_USE_TMA:
478
+ y_desc.store([off_m1, out_off_n], out.to(Y.dtype.element_ty))
479
+ else:
480
+ offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N)
481
+ mask_n = offs_y_n < yN
482
+
483
+ YPtrs = Y + pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n
484
+ mask = mask_m[:, None] & mask_n[None, :]
485
+ tl.store(YPtrs, out, mask=mask)
486
+
487
+
488
+ # Update the flexpoint scales
489
+ if YActualScale is not None:
490
+ tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), Y), sem="relaxed")
491
+
492
+
493
+ _per_device_alloc_fns = {}
494
+ def get_per_device_per_stream_alloc_fn(device):
495
+ if device not in _per_device_alloc_fns:
496
+ _per_stream_tensors = {}
497
+ def alloc_fn(size: int, alignment: int, stream):
498
+ assert alignment == 128
499
+ if stream not in _per_stream_tensors or _per_stream_tensors[stream].numel() < size:
500
+ _per_stream_tensors[stream] = torch.empty(size, device=device, dtype=torch.int8)
501
+ _per_stream_tensors[stream].__hibernate__ = {"type": "ignore"}
502
+ return _per_stream_tensors[stream]
503
+
504
+ _per_device_alloc_fns[device] = alloc_fn
505
+ return _per_device_alloc_fns[device]
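The allocator returned above has the `(size, alignment, stream) -> tensor` shape that recent Triton releases accept as a host-side workspace hook for tensor descriptors. The registration call (`triton.set_allocator`) and the device string in the sketch below are assumptions about the runtime environment, shown only as a usage example.

import torch
import triton
from triton_kernels.matmul_ogs_details._p_matmul_ogs import get_per_device_per_stream_alloc_fn

if torch.cuda.is_available():
    # Register the cached per-device, per-stream scratch allocator so that
    # persistent kernels using tensor descriptors can obtain workspace buffers.
    # (Assumes a Triton version exposing `triton.set_allocator`.)
    triton.set_allocator(get_per_device_per_stream_alloc_fn("cuda:0"))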
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py ADDED
@@ -0,0 +1,298 @@
1
+ # isort: off
2
+ # fmt: off
3
+ from dataclasses import dataclass
4
+ import triton
5
+ from triton_kernels.target_info import get_cdna_version
6
+ import torch
7
+ from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
8
+
9
+
10
+ @dataclass
11
+ class OptFlags:
12
+ block_m: int
13
+ block_n: int
14
+ block_k: int
15
+ num_warps: int
16
+ num_stages: int
17
+ group_m: int
18
+ xcd_swizzle: int
19
+ w_cache_modifier: str
20
+ split_k: int
21
+ fused_scatter: bool
22
+ is_persistent: bool
23
+ idle_sms: int
24
+ epilogue_subtile: int | None
25
+ arch: str
26
+ target_kernel_kwargs: dict
27
+
28
+ def __post_init__(self):
29
+ if self.fused_scatter and self.split_k != 1:
30
+ raise ValueError("Not supported")
31
+
32
+
33
+
34
+ def make_default_opt_flags_amd(
35
+ out_dtype,
36
+ lhs_dtype,
37
+ rhs_dtype,
38
+ precision_config,
39
+ m,
40
+ n,
41
+ k,
42
+ routing_data,
43
+ can_use_persistent_tma,
44
+ can_use_fused_scatter,
45
+ enforce_bitwise_invariance,
46
+ epilogue_effective_itemsize,
47
+ constraints,
48
+ ):
49
+ constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
50
+ assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
51
+ # tokens per expert
52
+ if routing_data is None:
53
+ tokens_per_expt = m
54
+ elif routing_data.expected_tokens_per_expt is None:
55
+ tokens_per_expt = max(1, m // routing_data.n_expts_tot)
56
+ else:
57
+ tokens_per_expt = routing_data.expected_tokens_per_expt
58
+
59
+ is_cdna4 = get_cdna_version() == 4
60
+ # block_m
61
+ if constraints.get("block_m", None):
62
+ block_m = constraints["block_m"]
63
+ elif enforce_bitwise_invariance:
64
+ block_m = 256 if is_cdna4 else 128
65
+ elif tokens_per_expt >= 512 and n >= 2048:
66
+ block_m = 256 if is_cdna4 else 128
67
+ elif is_cdna4 and m >= 512:
68
+ block_m = 128
69
+ else:
70
+ block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))
71
+
72
+ if routing_data is not None:
73
+ grid_m = routing_data.n_blocks(m, block_m)
74
+ else:
75
+ grid_m = triton.cdiv(m, block_m)
76
+ # group_m:
77
+ group_m = 4
78
+ # number of xcds
79
+ num_xcds = 8
80
+ xcd_swizzle = num_xcds
81
+ # block_nk:
82
+ block_n, block_k = opt_flags_amd.compute_block_nk(
83
+ n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
84
+ )
85
+ # Replace block_k if provided in constraints.
86
+ # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
87
+ if constraints.get("block_k", None) is not None:
88
+ block_k = constraints["block_k"]
89
+ is_persistent = constraints.get("is_persistent", False)
90
+ # split_k:
91
+ if constraints.get("split_k", None) is not None:
92
+ split_k = constraints["split_k"]
93
+ elif is_persistent or enforce_bitwise_invariance:
94
+ split_k = 1
95
+ else:
96
+ grid_size = grid_m * ((n + block_n - 1) // block_n)
97
+ n_cu = torch.cuda.get_device_properties(0).multi_processor_count
98
+ split_k = max(1, n_cu // grid_size)
99
+ # w_cache_modifier:
100
+ w_cache_modifier = ".cg" if block_m <= 32 else None
101
+ # num_warps, num_stages
102
+ num_warps = 2 if (m is not None and m <= 16) else 8
103
+ num_stages = 2
104
+ # AMD-specific
105
+ target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
106
+ ret = OptFlags(
107
+ block_m=block_m,
108
+ block_n=block_n,
109
+ block_k=block_k,
110
+ num_warps=num_warps,
111
+ num_stages=num_stages,
112
+ group_m=group_m,
113
+ xcd_swizzle=xcd_swizzle,
114
+ w_cache_modifier=w_cache_modifier,
115
+ split_k=split_k,
116
+ fused_scatter=constraints.get('fused_scatter', False),
117
+ is_persistent=is_persistent,
118
+ idle_sms=0,
119
+ epilogue_subtile=constraints.get('epilogue_subtile', None),
120
+ arch=None,
121
+ target_kernel_kwargs=target_kernel_kwargs,
122
+ )
123
+ # check constraints
124
+ assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
125
+ return ret
126
+
127
+ def make_default_opt_flags_nvidia(
128
+ out_dtype,
129
+ lhs_dtype,
130
+ rhs_dtype,
131
+ precision_config,
132
+ m,
133
+ n,
134
+ k,
135
+ routing_data,
136
+ can_use_persistent_tma,
137
+ can_use_fused_scatter,
138
+ enforce_bitwise_invariance,
139
+ epilogue_effective_itemsize,
140
+ constraints,
141
+ ):
142
+ constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms"]
143
+ assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
144
+ # tokens per expert
145
+ if routing_data is None:
146
+ tokens_per_expt = m
147
+ elif routing_data.expected_tokens_per_expt is None:
148
+ tokens_per_expt = max(1, m // routing_data.n_expts_tot)
149
+ else:
150
+ tokens_per_expt = routing_data.expected_tokens_per_expt
151
+ # pid swizzling
152
+ group_m = 8
153
+ xcd_swizzle = 1
154
+ # block_m
155
+ if constraints.get("block_m", None):
156
+ block_m = constraints["block_m"]
157
+ elif enforce_bitwise_invariance:
158
+ block_m = 128
159
+ else:
160
+ min_block_m = 64 if torch.cuda.get_device_capability()[0] == 10 else 16
161
+ block_m = max(min_block_m, min(triton.next_power_of_2(tokens_per_expt), 128))
162
+ # block n
163
+ arch = None
164
+ block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
165
+ # is_persistent
166
+ grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
167
+ n_sms = torch.cuda.get_device_properties(0).multi_processor_count
168
+ tiles_per_sm = grid_size / n_sms
169
+ supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
170
+ if constraints.get("is_persistent", None) is not None:
171
+ is_persistent = constraints["is_persistent"]
172
+ else:
173
+ has_simple_epilogue = precision_config.max_num_imprecise_acc is None
174
+ is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
175
+ # TEMP CHANGE
176
+ if precision_config.act_scale is not None or precision_config.out_scale is not None:
177
+ is_persistent = False
178
+ # block k
179
+ if constraints.get("block_k", None) is not None:
180
+ block_k = constraints["block_k"]
181
+ else:
182
+ block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config)
183
+ # split_k
184
+ if constraints.get("split_k", None) is not None:
185
+ split_k = constraints["split_k"]
186
+ elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
187
+ split_k = 1
188
+ else:
189
+ estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n)
190
+ split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
191
+ if split_k > 1:
192
+ # With split_k, results are written in f32. Use that for the following computations.
193
+ out_dtype = torch.float32
194
+ compute_num_stages_args = (
195
+ precision_config,
196
+ is_persistent,
197
+ block_m,
198
+ block_n,
199
+ block_k,
200
+ out_dtype,
201
+ lhs_dtype,
202
+ rhs_dtype,
203
+ )
204
+
205
+ if constraints.get("epilogue_subtile", None) is not None:
206
+ subtiles_to_check = [constraints["epilogue_subtile"]]
207
+ else:
208
+ subtiles_to_check = [1, 2, 4]
209
+ num_stages = -1
210
+ for ep in subtiles_to_check:
211
+ ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, ep, epilogue_effective_itemsize)
212
+ if ns > num_stages:
213
+ epilogue_subtile, num_stages = ep, ns
214
+ assert num_stages >= 1
215
+ if constraints.get("num_stages", None):
216
+ num_stages = constraints["num_stages"]
217
+
218
+ # fused scatter scratchpad
219
+ if constraints.get("fused_scatter", None) is not None:
220
+ fused_scatter = constraints["fused_scatter"]
221
+ else:
222
+ fused_scatter = can_use_fused_scatter and split_k == 1
223
+ # Handshake with the HBM swizzling
224
+ num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)
225
+ ret = OptFlags(
226
+ block_m=block_m,
227
+ block_n=block_n,
228
+ block_k=block_k,
229
+ num_warps=num_warps,
230
+ num_stages=num_stages,
231
+ group_m=group_m,
232
+ xcd_swizzle=xcd_swizzle,
233
+ w_cache_modifier=None,
234
+ split_k=split_k,
235
+ fused_scatter=fused_scatter,
236
+ is_persistent=is_persistent,
237
+ epilogue_subtile=epilogue_subtile,
238
+ arch=arch,
239
+ target_kernel_kwargs=dict(),
240
+ idle_sms=constraints.get("idle_sms", 0),
241
+ )
242
+ # check constraints
243
+ assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
244
+ return ret
245
+
246
+ # --------------
247
+ # User Interface
248
+ # --------------
249
+
250
+ _opt_flags_constraints: dict = dict()
251
+ _opt_flags: OptFlags | None = None
252
+
253
+ def update_opt_flags_constraints(constraints: dict[str, int]):
254
+ global _opt_flags_constraints
255
+ _opt_flags_constraints.update(constraints)
256
+
257
+ def reset_opt_flags_constraints():
258
+ global _opt_flags_constraints
259
+ _opt_flags_constraints = dict()
260
+
261
+ def set_opt_flags(opt_flags: OptFlags):
262
+ global _opt_flags
263
+ assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"
264
+ assert not _opt_flags, "opt_flags already set; please reset to None first"
265
+ _opt_flags = opt_flags
266
+
267
+ class InapplicableConstraint(Exception):
268
+ pass
269
+
270
+ def make_opt_flags(
271
+ out_dtype,
272
+ lhs_dtype,
273
+ rhs_dtype,
274
+ precision_config,
275
+ m,
276
+ n,
277
+ k,
278
+ routing_data,
279
+ can_use_persistent_tma,
280
+ can_use_fused_scatter,
281
+ epilogue_effective_itemsize,
282
+ ):
283
+ if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
284
+ raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
285
+ enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
286
+ if _opt_flags is not None:
287
+ assert not _opt_flags_constraints
288
+ return _opt_flags
289
+ args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k,
290
+ routing_data, can_use_persistent_tma, can_use_fused_scatter,
291
+ enforce_bitwise_invariance, epilogue_effective_itemsize,
292
+ _opt_flags_constraints]
293
+ backend = triton.runtime.driver.active.get_current_target().backend
294
+ if backend == "hip":
295
+ return make_default_opt_flags_amd(*args)
296
+ if backend == "cuda":
297
+ return make_default_opt_flags_nvidia(*args)
298
+ assert False
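The user-facing hooks above (`update_opt_flags_constraints` / `reset_opt_flags_constraints` / `make_opt_flags`) are the intended way to pin individual tuning knobs. A minimal usage sketch with a hypothetical call site (the constraint keys come from the `constraints_supported` lists above):

```python
# Hypothetical usage sketch: pin a couple of tuning knobs, run the matmul that
# consumes the flags, then restore the defaults. make_opt_flags() asserts that
# the flags it derives agree with every non-None constraint.
from triton_kernels.matmul_ogs_details import opt_flags

opt_flags.update_opt_flags_constraints({"block_m": 128, "split_k": 1})
try:
    pass  # ... launch the matmul that consumes these flags (e.g. matmul_ogs) ...
finally:
    opt_flags.reset_opt_flags_constraints()
```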
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch
2
+ import triton
3
+ from triton_kernels.target_info import get_cdna_version
4
+ from triton_kernels.tensor import bitwidth
5
+
6
+
7
+ def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config):
8
+ lhs_width = bitwidth(lhs_dtype) / 8
9
+ rhs_width = bitwidth(rhs_dtype) / 8
10
+
11
+ # block_n:
12
+ n_cu = torch.cuda.get_device_properties(0).multi_processor_count
13
+ if n is not None:
14
+ if n <= 128 and (n & (n - 1)) == 0:
15
+ block_n = n
16
+ else:
17
+ block_n = max(32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu)))
18
+ elif block_m > 64:
19
+ block_n = 256
20
+ else:
21
+ block_n = 128
22
+
23
+ if get_cdna_version() == 4 and block_m == 128:
24
+ block_n = 512
25
+
26
+ # block_k needs to match the cacheline size (128B)
27
+ block_k = int(128 // min(lhs_width, rhs_width))
28
+
29
+ # TODO: block_k = 128 seems to work better for now.
30
+ # perhaps due to increased number of k loops to pipeline
31
+ if precision_config.weight_scale is not None and get_cdna_version() != 4:
32
+ block_k = 128
33
+ return block_n, block_k
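For intuition, the 128-byte cacheline rule above works out as follows; this is a standalone arithmetic sketch, not library code:

```python
# block_k is chosen so a tile row of the narrower operand fills a 128-byte
# cacheline: 128 bytes divided by the smaller element width in bytes.
def block_k_for(lhs_bytes: float, rhs_bytes: float) -> int:
    return int(128 // min(lhs_bytes, rhs_bytes))

assert block_k_for(2.0, 2.0) == 64    # fp16 x fp16
assert block_k_for(2.0, 1.0) == 128   # fp16 x fp8
assert block_k_for(2.0, 0.5) == 256   # fp16 x packed fp4 (half a byte per value)
```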
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+ import triton
3
+ from triton_kernels import target_info
4
+ from triton_kernels.tensor import get_layout, bitwidth, FP4
5
+ from triton_kernels.tensor_details.layout import HopperMXScaleLayout
6
+ from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
7
+
8
+
9
+ def compute_grid_size(routing_data, m, n, block_m, block_n):
10
+ if routing_data is not None:
11
+ grid_m = routing_data.n_blocks(m, block_m)
12
+ else:
13
+ grid_m = triton.cdiv(m, block_m)
14
+ grid_n = (n + block_n - 1) // block_n
15
+ return grid_m * grid_n
16
+
17
+
18
+ def compute_block_n(n: int, arch, precision_config):
19
+ # block_n:
20
+ layout = get_layout(precision_config.weight_scale)
21
+ if isinstance(layout, HopperMXScaleLayout) and layout.num_warps == 4:
22
+ return 128
23
+ elif precision_config.max_num_imprecise_acc is None and n > 128:
24
+ return 256
25
+ else:
26
+ return max(16, min(128, triton.next_power_of_2(n)))
27
+
28
+
29
+ def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config):
30
+ lhs_width = bitwidth(lhs_dtype)
31
+ rhs_width = bitwidth(rhs_dtype)
32
+ # block_k needs to match the cacheline size (1024 bits)
33
+ block_k = int(1024 // min(lhs_width, rhs_width))
34
+ has_native_mxfp = target_info.cuda_capability_geq(10, 0)
35
+ if rhs_width == 4 and not has_native_mxfp:
36
+ block_k = 128
37
+ elif k is not None:
38
+ block_k = max(32, min(triton.next_power_of_2(k), block_k))
39
+ has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
40
+ if has_native_mxfp and is_persistent and has_mx_weight_scale:
41
+ block_k = min(block_k, 128)
42
+ return block_k
43
+
44
+
45
+ def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
46
+ device_props = torch.cuda.get_device_properties(0)
47
+ n_sms = device_props.multi_processor_count
48
+ split_k = n_sms // grid_size
49
+ if k is not None:
50
+ # avoid split_k for small k
51
+ num_block_k = triton.cdiv(k, block_k)
52
+ split_k = min(split_k, num_block_k // 4)
53
+ split_k = max(split_k, 1)
54
+ return split_k
55
+
56
+
57
+ def compute_num_warps(block_m, block_n, precision_config):
58
+ layout = get_layout(precision_config.weight_scale)
59
+ if isinstance(layout, HopperMXScaleLayout):
60
+ return layout.num_warps
61
+ return max(block_m * block_n // 4096, 4)
62
+
63
+
64
+ def compute_num_stages(
65
+ precision_config,
66
+ is_persistent,
67
+ block_m,
68
+ block_n,
69
+ block_k,
70
+ out_dtype,
71
+ lhs_dtype,
72
+ rhs_dtype,
73
+ epilogue_subtile,
74
+ epilogue_effective_itemsize,
75
+ ):
76
+ if precision_config.max_num_imprecise_acc is not None:
77
+ return 3
78
+ weight_size = bitwidth(rhs_dtype) / 8
79
+ stage_size = block_m * block_k * lhs_dtype.itemsize + block_k * block_n * weight_size
80
+ device_props = torch.cuda.get_device_properties(0)
81
+ smem_capacity = device_props.shared_memory_per_block_optin
82
+ has_native_mxfp = target_info.cuda_capability_geq(10, 0)
83
+ if has_native_mxfp and getattr(precision_config, "weight_scale", None) is not None:
84
+ if rhs_dtype == FP4:
85
+ # 4-bit e2m1 weights are padded 2x
86
+ # https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
87
+ stage_size += block_k * block_n * weight_size
88
+
89
+ if is_persistent:
90
+ # Per-stage wait barrier
91
+ stage_size += 8
92
+ if target_info.cuda_capability_geq(10, 0):
93
+ acc_size = epilogue_effective_itemsize or out_dtype.itemsize
94
+ else:
95
+ acc_size = out_dtype.itemsize
96
+ if target_info.cuda_capability_geq(10, 0) and epilogue_subtile is not None:
97
+ acc_block_n = block_n // epilogue_subtile
98
+ else:
99
+ acc_block_n = block_n
100
+ # pipelined TMA store local to global, or
101
+ # pipelined layout conversion before store of the accumulator
102
+ # note: layout conversion has some padding
103
+ smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
104
+ if precision_config.weight_scale is not None:
105
+ # mx scales
106
+ stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
107
+ elif has_native_mxfp:
108
+ # mx scales
109
+ stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
110
+ num_stages = min(4, smem_capacity // int(stage_size))
111
+ return num_stages
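As a rough numeric illustration of the shared-memory budget in `compute_num_stages` (the device capacity and tile sizes below are hypothetical, and the persistent-epilogue carve-out and MX-scale terms are ignored):

```python
# Back-of-the-envelope version of the pipelining budget above: how many
# (lhs tile + rhs tile) stages fit in opt-in shared memory, capped at 4.
smem_capacity = 228 * 1024                                   # ~228 KiB, an H100-class figure
block_m, block_n, block_k = 128, 256, 64
stage_size = block_m * block_k * 2 + block_k * block_n * 1   # fp16 lhs, fp8 rhs
num_stages = min(4, smem_capacity // stage_size)
print(stage_size, num_stages)                                # 32768 bytes per stage -> 4 stages
```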
build/torch-universal/triton_kernels/numerics.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch
2
+ from dataclasses import dataclass
3
+
4
+ MAX_FINITE_FLOAT8E5 = 57344.0
5
+ MAX_FINITE_FLOAT8E4NV = 448.0
6
+ MAX_FINITE_FLOAT8E4B8 = 240.0
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class BaseFlexData:
11
+ dtype: torch.dtype | None = None
12
+
13
+ def view(self, x: torch.Tensor):
14
+ if self.dtype is None:
15
+ return x
16
+ return x.view(self.dtype)
17
+
18
+ def reinterpret(self, x):
19
+ if self.dtype is None or x.dtype.itemsize > 1:
20
+ return x
21
+ return x.view(self.dtype)
22
+
23
+
24
+ @dataclass(frozen=True)
25
+ class InFlexData(BaseFlexData):
26
+ scale: torch.Tensor | None = None
27
+
28
+ @property
29
+ def is_per_batch(self):
30
+ return False if self.scale is None else len(self.scale) > 1
31
+
32
+
33
+ @dataclass(frozen=True)
34
+ class OutFlexData(BaseFlexData):
35
+ expected_scale: torch.Tensor | None = None
36
+ actual_scale: torch.Tensor | None = None
37
+ checksum_scale: torch.Tensor | None = None
38
+
39
+ def __iter__(self):
40
+ yield self.expected_scale
41
+ yield self.actual_scale
42
+ yield self.checksum_scale
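A minimal construction sketch for these containers; the scale shapes and dtypes below are illustrative assumptions rather than requirements enforced by the dataclasses:

```python
# Building the flex-data containers defined above and exercising their API.
import torch
from triton_kernels.numerics import InFlexData, OutFlexData

in_flex = InFlexData(dtype=torch.float8_e4m3fn, scale=torch.ones(1))
out_flex = OutFlexData(
    dtype=torch.float8_e4m3fn,
    expected_scale=torch.ones(1),
    actual_scale=torch.zeros(1),
    checksum_scale=torch.zeros(1, dtype=torch.int32),
)
expected, actual, checksum = out_flex     # __iter__ yields the three scales in order
print(in_flex.is_per_batch)               # False: a single-element scale
print(in_flex.reinterpret(torch.zeros(4, dtype=torch.uint8)).dtype)  # viewed as float8_e4m3fn
```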
build/torch-universal/triton_kernels/numerics_details/__init__.py ADDED
File without changes
build/torch-universal/triton_kernels/numerics_details/flexpoint.py ADDED
@@ -0,0 +1,195 @@
1
+ from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
2
+ from triton_kernels import target_info
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ # -------------------------------
7
+ # Kernels stuff
8
+ # -------------------------------
9
+
10
+ TL_MAX_FINITE_FLOAT8E5 = tl.constexpr(MAX_FINITE_FLOAT8E5)
11
+ TL_MAX_FINITE_FLOAT8E4NV = tl.constexpr(MAX_FINITE_FLOAT8E4NV)
12
+ TL_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(MAX_FINITE_FLOAT8E4B8)
13
+ TL_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(1.750)
14
+ TL_MAX_FINITE_FLOAT16 = tl.constexpr(65472.0)
15
+
16
+ TL_RCP_MAX_FINITE_FLOAT8E5 = tl.constexpr(0x37924925) # 0x1.24924Ap-16
17
+ TL_RCP_MAX_FINITE_FLOAT8E4NV = tl.constexpr(0x3B124925) # 0x1.24924Ap-9
18
+ TL_RCP_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(0x3B888889) # 0x1.111112p-8
19
+ TL_RCP_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(0x3F124925) # 0x1.24924Ap-1
20
+ TL_RCP_MAX_FINITE_FLOAT16 = tl.constexpr(0x37802008) # 0x1.004010p-16
21
+
22
+
23
+ @triton.jit
24
+ def max_finite(dtype):
25
+ if dtype == tl.constexpr(tl.float8e5):
26
+ return TL_MAX_FINITE_FLOAT8E5
27
+ elif dtype == tl.constexpr(tl.float8e4nv):
28
+ return TL_MAX_FINITE_FLOAT8E4NV
29
+ elif dtype == tl.constexpr(tl.float8e4b8):
30
+ return TL_MAX_FINITE_FLOAT8E4B8
31
+ elif dtype == tl.constexpr(tl.float8e4b15):
32
+ return TL_MAX_FINITE_FLOAT8E4B15
33
+ elif dtype == tl.constexpr(tl.float16):
34
+ return TL_MAX_FINITE_FLOAT16
35
+ else:
36
+ tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
37
+
38
+
39
+ @triton.jit
40
+ def rcp_max_finite(dtype):
41
+ if dtype == tl.constexpr(tl.float8e5):
42
+ return TL_RCP_MAX_FINITE_FLOAT8E5
43
+ elif dtype == tl.constexpr(tl.float8e4nv):
44
+ return TL_RCP_MAX_FINITE_FLOAT8E4NV
45
+ elif dtype == tl.constexpr(tl.float8e4b8):
46
+ return TL_RCP_MAX_FINITE_FLOAT8E4B8
47
+ elif dtype == tl.constexpr(tl.float8e4b15):
48
+ return TL_RCP_MAX_FINITE_FLOAT8E4B15
49
+ elif dtype == tl.constexpr(tl.float16):
50
+ return TL_RCP_MAX_FINITE_FLOAT16
51
+ else:
52
+ tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")
53
+
54
+
55
+ @tl.constexpr_function
56
+ def cuda_capability_geq(major, minor):
57
+ return target_info.cuda_capability_geq(major, minor)
58
+
59
+
60
+ @triton.jit
61
+ def sm86_min_nan_xorsign_abs_f32(a, b):
62
+ """Wrapper for min.NaN.xorsign.abs.f32 PTX instruction.
63
+
64
+ Computes the minimum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
65
+ NaN inputs are propagated to the output.
66
+
67
+ Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
68
+ """
69
+ tl.static_assert(cuda_capability_geq(8, 6), "min.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+")
70
+ tl.static_assert(a.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs")
71
+ tl.static_assert(b.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs")
72
+
73
+ return tl.inline_asm_elementwise(
74
+ """{
75
+ min.NaN.xorsign.abs.f32 $0, $1, $2;
76
+ }""",
77
+ "=r,r,r",
78
+ [a, b],
79
+ dtype=tl.float32,
80
+ is_pure=True,
81
+ pack=1,
82
+ )
83
+
84
+
85
+ @triton.jit
86
+ def sm86_max_nan_xorsign_abs_f32(a, b):
87
+ """Wrapper for max.NaN.xorsign.abs.f32 PTX instruction.
88
+
89
+ Computes the maximum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
90
+ NaN inputs are propagated to the output.
91
+
92
+ Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
93
+ """
94
+ tl.static_assert(cuda_capability_geq(8, 6), "max.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+")
95
+ tl.static_assert(a.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs")
96
+ tl.static_assert(b.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs")
97
+
98
+ return tl.inline_asm_elementwise(
99
+ """{
100
+ max.NaN.xorsign.abs.f32 $0, $1, $2;
101
+ }""",
102
+ "=r,r,r",
103
+ [a, b],
104
+ dtype=tl.float32,
105
+ is_pure=True,
106
+ pack=1,
107
+ )
108
+
109
+
110
+ @triton.jit
111
+ def load_scale(scale_ptr):
112
+ return 1.0 if scale_ptr is None else tl.load(scale_ptr)
113
+
114
+
115
+ @triton.jit
116
+ def flex_to_float(x, scale_ptr):
117
+ scale = load_scale(scale_ptr)
118
+ return x.to(tl.float32) * scale
119
+
120
+
121
+ @triton.jit
122
+ def clip(x, limit):
123
+ res = tl.minimum(x, limit)
124
+ res = tl.maximum(-limit, res)
125
+ return res
126
+
127
+
128
+ @triton.jit
129
+ def nan_propagating_absmax_reduce(x, axis=None):
130
+ if cuda_capability_geq(8, 6):
131
+ # abs-max-reduce as floating-point if `max.NaN.xorsign.abs.f32` is supported.
132
+ x_absmax = tl.reduce(x, axis, sm86_max_nan_xorsign_abs_f32)
133
+ # Note: sign of reduction result is the xor of signs of all inputs, explicitly clear the sign bit to fix it.
134
+ x_absmax = x_absmax.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
135
+ else:
136
+ # Clear the sign bit, max-reduce as integer (same as NaN-propagating max-reduce as float)
137
+ masked_abs_x = x.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
138
+ x_absmax = tl.max(masked_abs_x, axis)
139
+
140
+ return x_absmax
141
+
142
+
143
+ @triton.jit
144
+ def compute_scale(x, Out):
145
+ x_absmax = nan_propagating_absmax_reduce(tl.ravel(x, can_reorder=True))
146
+
147
+ # atomic_max does not propagate NaNs, so we replace them with +inf (0x7f800000).
148
+ # We use integer minimum because NaNs are above +inf in integer representation.
149
+ x_absmax = tl.minimum(x_absmax, 0x7F800000).to(tl.float32, bitcast=True)
150
+ RCP_MAX_VALUE = rcp_max_finite(Out.dtype.element_ty)
151
+ return tl.fma(x_absmax, RCP_MAX_VALUE.to(tl.float32, bitcast=True), 1.0e-30)
152
+
153
+
154
+ @triton.jit
155
+ def update_scale(x, scale_ptr, Out) -> None:
156
+ if scale_ptr is not None:
157
+ scale = compute_scale(x, Out)
158
+ tl.atomic_max(scale_ptr, scale, sem="relaxed")
159
+
160
+
161
+ @triton.jit
162
+ def float_to_flex(
163
+ x,
164
+ expected_scale_ptr_or_val,
165
+ actual_scale_ptr,
166
+ checksum_scale_ptr,
167
+ mask,
168
+ Out,
169
+ saturate_infs: tl.constexpr,
170
+ ):
171
+ if expected_scale_ptr_or_val is not None:
172
+ if expected_scale_ptr_or_val.dtype.is_ptr():
173
+ invscale = 1.0 / tl.load(expected_scale_ptr_or_val)
174
+ else:
175
+ invscale = 1.0 / expected_scale_ptr_or_val
176
+ else:
177
+ invscale = 1.0
178
+ if checksum_scale_ptr is not None:
179
+ x_int32 = x.to(tl.int32, bitcast=True)
180
+ zero = tl.cast(0.0, tl.int32)
181
+ if mask is not None:
182
+ x_int32 = tl.where(mask, x_int32, zero)
183
+ checksum_local = tl.xor_sum(tl.ravel(x_int32, can_reorder=True), 0)
184
+ tl.atomic_add(checksum_scale_ptr, checksum_local)
185
+ if mask is not None:
186
+ if actual_scale_ptr is not None:
187
+ x = tl.where(mask, x, 0.0)
188
+ update_scale(x, actual_scale_ptr, Out)
189
+ x = x * invscale
190
+ # if expected_scale_ptr_or_val is not None, we applied the flexpoint scale; we only want to clip in that case.
191
+ if expected_scale_ptr_or_val is not None:
192
+ if saturate_infs:
193
+ CLIP_VALUE = max_finite(Out.dtype.element_ty)
194
+ x = clip(x, CLIP_VALUE)
195
+ return x
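Everything in this file is device-side Triton code, but the scale it maintains has a simple host-side analogue: `absmax(x) / max_finite(out_dtype)` plus a tiny epsilon, ignoring the NaN-propagation handled by the PTX helpers. A reference sketch, for intuition or testing only:

```python
# Host-side analogue of compute_scale above (illustrative only).
import torch
from triton_kernels.numerics import MAX_FINITE_FLOAT8E4NV

def reference_flex_scale(x: torch.Tensor, max_finite: float = MAX_FINITE_FLOAT8E4NV) -> torch.Tensor:
    absmax = x.float().abs().max()
    return absmax / max_finite + 1.0e-30   # epsilon keeps an all-zero tensor from producing a zero scale

x = torch.randn(1024) * 100.0
print(reference_flex_scale(x))             # scale that maps absmax onto the fp8e4nv range
```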
build/torch-universal/triton_kernels/numerics_details/mxfp.py ADDED
@@ -0,0 +1,303 @@
1
+ # isort: off
2
+ # fmt: off
3
+ from enum import Enum
4
+ import triton
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from .mxfp_details._upcast_from_mxfp import _upcast_from_mxfp
8
+ from .mxfp_details._downcast_to_mxfp import _downcast_to_mxfp, _dequantize_mxfp8_fn, MXFP_BLOCK_SIZE
9
+
10
+ # -----------------------------------------------------------------------------
11
+ # Dequantization / Quantization Utilities
12
+ # -----------------------------------------------------------------------------
13
+
14
+
15
+ class DequantScaleRoundingMode(Enum):
16
+ ROUND_UP = 0
17
+ ROUND_DOWN = 1
18
+
19
+
20
+ def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
21
+ DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
22
+ """
23
+ Convert the src weights to mx format. The src weight is quantized along the axis dimension.
24
+
25
+ If out_quant_type is torch.uint8, we output mxfp4, where two e2m1 values are packed into a single byte.
26
+ Note that this means the k_dim of the tensor will be half of the logical k_dim.
27
+
28
+ If out_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8, with the float8 values stored
29
+ in their respective formats.
30
+ """
31
+ ndim = src_tensor.ndim
32
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
33
+ axis = axis if axis >= 0 else axis + ndim
34
+ # downcast
35
+ src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1)
36
+ is_fp4 = out_quant_type == torch.uint8
37
+ is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
38
+ assert is_fp4 or is_fp8
39
+ divisor = 2 if is_fp4 else 1
40
+ L = src_tensor.shape[-1]
41
+ if is_fp4:
42
+ assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
43
+ out_shape = src_tensor.shape[:-1] + (L // divisor, )
44
+ out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )
45
+
46
+ out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
47
+ out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)
48
+
49
+ if src_tensor.numel() > 0:
50
+ kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
51
+ kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
52
+ kernel_scale = out_scale.view(-1, out_scale.shape[-1])
53
+
54
+ BLOCK_OUT_DIM = 128
55
+ BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
56
+ grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
57
+ grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)
58
+
59
+ _downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
60
+ *kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
61
+ *kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
62
+ DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)
63
+
64
+ out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
65
+ out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)
66
+ return out_quant_tensor, out_scale
67
+
68
+
69
+ def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int):
70
+ """
71
+ Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16.
72
+
73
+ The function assumes that the tensors were quantized along the given axis.
74
+ It permutes the tensor so that the quantized axis is last, reshapes to 2D,
75
+ launches the Triton upcast kernel, and then unpermutes back to the original order.
76
+ """
77
+ ndim = tensor.ndim
78
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
79
+ axis = axis if axis >= 0 else axis + ndim
80
+ assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
81
+ f"Got {tensor.ndim=} and {scale.ndim=}")
82
+ # dtype checks
83
+ assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
84
+ f"Invalid tensor dtype {tensor.dtype=}"
85
+ assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
86
+ assert dtype in (torch.float16, torch.bfloat16), f"Invalid output dtype {dtype=}"
87
+ # upcast
88
+ logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
89
+ tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
90
+ scale = scale.transpose(axis, scale.ndim - 1).contiguous()
91
+ out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device)
92
+ reshaped_out = out.view(-1, out.shape[-1])
93
+ reshaped_tensor = tensor.view(-1, tensor.shape[-1])
94
+ reshaped_scale = scale.view(-1, scale.shape[-1])
95
+ BLOCK_OUT_DIM = 128
96
+ BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
97
+ blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
98
+ blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
99
+ _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
100
+ *reshaped_scale.stride(), reshaped_tensor,
101
+ *reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM,
102
+ BLOCK_QUANT_DIM, num_warps=8)
103
+ out = out.transpose(axis, scale.ndim - 1).contiguous()
104
+ return out
105
+
106
+
107
+ # ------------
108
+
109
+
110
+ def right_shift_unsigned(x, shift):
111
+ # CUDA torch does not support bit ops on uint32, so we need to mask to get unsigned right shift
112
+ return (x >> shift) & ((1 << (32 - shift)) - 1)
113
+
114
+
115
+ def get_max_quant_val(dtype: torch.dtype):
116
+ d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0}
117
+ assert dtype in d
118
+ return d[dtype]
119
+
120
+
121
+ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
122
+ DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
123
+ """
124
+ Converts the src tensor to the output format specified by out_quant_type.
125
+ axis: The axis along which the tensors are contiguous and quantization is applied.
126
+ DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN.
127
+
128
+ Returns:
129
+ out_quant_tensor: Quantized tensor in mx format.
130
+ • For mxfp8, the output has the same shape as src_tensor.
131
+ • For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
132
+ scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
133
+ Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
134
+ where L is the original length along that axis.
135
+ """
136
+ # This should probably be packed into its own tiny class
137
+ ndim = src_tensor.ndim
138
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
139
+ assert src_tensor.dtype in {torch.float32, torch.bfloat16,
140
+ torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}"
141
+
142
+ axis = axis if axis >= 0 else axis + ndim
143
+ is_fp4 = out_quant_type == torch.uint8
144
+ is_fp8 = "float8" in str(out_quant_type)
145
+ assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"
146
+
147
+ device = src_tensor.device
148
+
149
+ # For mxfp4 conversion, we assume the contiguous axis length is even.
150
+ if is_fp4:
151
+ axis_shape = src_tensor.size(axis)
152
+ assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even."
153
+
154
+ # Permute the tensor so that the contiguous axis becomes the last dimension.
155
+ src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
156
+ axis_shape = src.shape[-1]
157
+
158
+ # Pad the axis to be divisible by 32, in case it is not.
159
+ next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
160
+ pad_amount = next_multiple - axis_shape
161
+ padded_src = F.pad(src, (0, pad_amount))
162
+ valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
163
+ padded_axis_shape = padded_src.size(-1) # now divisible by 32
164
+
165
+ # --- Compute per-group maximums for scale ---
166
+ # Set padded entries to -1 so they don’t affect the max.
167
+ abs_f = torch.abs(padded_src)
168
+ abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
169
+ # Reshape the last dimension into groups of 32.
170
+ new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
171
+ abs_groups = abs_f.view(*new_shape)
172
+ # Compute maximum along the group dimension (of size 32).
173
+ max_val, _ = abs_groups.max(dim=-1, keepdim=True)
174
+
175
+ # Choose a max quantization value depending on type.
176
+ max_quant_val = get_max_quant_val(out_quant_type)
177
+ dequant_scale = max_val / max_quant_val # shape: (..., padded_axis_shape//32, 1)
178
+
179
+ # Convert to int to round the FP32 scale, prior to quantization!
180
+ ds_int = dequant_scale.view(torch.int32)
181
+ if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
182
+ ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000
183
+ else:
184
+ ds_int_rounded = ds_int & 0x7F800000
185
+ # Reinterpret back as float32.
186
+ dequant_scale_rounded = ds_int_rounded.view(torch.float32)
187
+
188
+ # Compute the quantization scale.
189
+ quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded)
190
+
191
+ # Quantize the tensor
192
+ orig_padded_shape = padded_src.shape
193
+ padded_src_groups = padded_src.view(*new_shape)
194
+ quant_tensor = padded_src_groups * quant_scale
195
+ # Reshape back to the original shape and trim padding
196
+ quant_tensor = quant_tensor.view(orig_padded_shape)
197
+ quant_tensor = quant_tensor[..., :axis_shape]
198
+
199
+ # Finally, convert the quantized tensor to the target format
200
+ if is_fp8:
201
+ # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
202
+ quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val)
203
+ out_weight = quant_tensor.to(out_quant_type)
204
+ else:
205
+ assert is_fp4, f"Invalid output quantization type {out_quant_type}"
206
+ # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
207
+ # First, reinterpret the quantized tensor bits.
208
+ q_int = quant_tensor.contiguous().view(torch.int32)
209
+ # Extract sign, exponent, and mantissa.
210
+ signs = q_int & 0x80000000
211
+ exponents = right_shift_unsigned(q_int, 23) & 0xFF
212
+ mantissas = q_int & 0x7FFFFF
213
+
214
+ E8_BIAS = 127
215
+ E2_BIAS = 1
216
+ # Adjust mantissas for subnormals.
217
+ mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
218
+ (E8_BIAS - exponents - 1), mantissas)
219
+ exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
220
+ e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1)
221
+ e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))
222
+ e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8) # shape: (..., even_axis_shape)
223
+
224
+ # Pack pairs of 4-bit values along the last dimension.
225
+ e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2)
226
+ evens = e2m1_value[..., 0]
227
+ odds = e2m1_value[..., 1]
228
+ out_weight = evens | (odds << 4) # shape: (..., axis_shape//2)
229
+
230
+ # --- Process and output the scale ---
231
+ dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8) # shape: (..., axis_shape//32, 1)
232
+ dq_scale = dq_scale.squeeze(-1)
233
+ out_weight = out_weight.transpose(axis, src_tensor.ndim - 1)
234
+ dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1)
235
+ return out_weight, dq_scale
236
+
237
+
238
+ def cvt_e2m1_to_fp32(input_tensor):
239
+ assert input_tensor.dtype == torch.uint8
240
+
241
+ input_tensor = input_tensor.to(torch.int32)
242
+ evens = input_tensor & 0xF
243
+ odds = (input_tensor >> 4) & 0xF
244
+
245
+ vals = [0.0, 0.5, 1, 1.5, 2, 3, 4, 6]
246
+ outputs = torch.tensor(vals, dtype=torch.float32, device=input_tensor.device)
247
+ outputs = torch.cat([outputs, -outputs])
248
+
249
+ even_floats = outputs[evens]
250
+ odd_floats = outputs[odds]
251
+ output_tensor = torch.stack([even_floats, odd_floats], dim=-1)
252
+ output_tensor = output_tensor.view(*input_tensor.shape[:-1], -1)
253
+ return output_tensor
254
+
255
+
256
+ def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
257
+ """
258
+ Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
259
+ axis: The axis along which dequantization is applied.
260
+
261
+ Returns:
262
+ out_weight: Tensor in the target format.
263
+ """
264
+
265
+ ndim = tensor.ndim
266
+ assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
267
+ is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2
268
+ assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}"
269
+
270
+ # Permute the tensor and scale so that the quantization axis becomes the last dimension
271
+ axis = axis if axis >= 0 else axis + ndim
272
+ scale = scale.transpose(axis, scale.ndim - 1)
273
+ tensor = tensor.transpose(axis, tensor.ndim - 1)
274
+
275
+ dq_scale = (scale.to(torch.int32) << 23).view(torch.float32) # Shift to the exponent and bitcast to fp32
276
+ if tensor.dtype == torch.uint8:
277
+ fp32_tensor = cvt_e2m1_to_fp32(tensor)
278
+ else:
279
+ fp32_tensor = tensor.to(torch.float32)
280
+
281
+ logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1)
282
+ axis_shape = fp32_tensor.size(-1)
283
+ padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
284
+ pad_size = padded_axis_shape - axis_shape
285
+ padded_tensor = F.pad(fp32_tensor, (0, pad_size))
286
+
287
+ new_axis_shape = padded_tensor.shape[-1]
288
+ new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
289
+ padded_tensor = padded_tensor.view(*new_shape)
290
+ dq_scale_padded = dq_scale.unsqueeze(-1) # shape: [..., ceil(axis_shape/32), 1]
291
+ out_padded = padded_tensor * dq_scale_padded
292
+
293
+ # Flatten back and remove the padded tail
294
+ out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
295
+ out_tensor = out_padded[..., :axis_shape]
296
+
297
+ out_tensor = out_tensor.to(target_dtype).contiguous()
298
+ out_tensor = out_tensor.transpose(axis, tensor.ndim - 1)
299
+
300
+ return out_tensor
301
+
302
+
303
+ dequantize_mxfp8_fn = _dequantize_mxfp8_fn
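A minimal round-trip sketch of the two public entry points above. Shapes are illustrative; the fast path launches Triton kernels, so a CUDA device is assumed (the `*_torch` variants defined above run anywhere):

```python
# Quantize a bf16 weight to mxfp4 along the last axis, then dequantize it back
# and inspect the reconstruction error.
import torch
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp

w = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
w_q, w_scale = downcast_to_mxfp(w, torch.uint8, axis=-1)        # packed e2m1 values + uint8 scales
w_hat = upcast_from_mxfp(w_q, w_scale, torch.bfloat16, axis=-1)

print(w_q.shape, w_scale.shape)                   # (256, 256) packed pairs, (256, 16): one scale per 32 values
print((w.float() - w_hat.float()).abs().mean())   # coarse 4-bit quantization error
```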
build/torch-universal/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py ADDED
@@ -0,0 +1,158 @@
1
+ import triton
2
+ import triton.language as tl
3
+
4
+ # fmt: off
5
+
6
+
7
+ MXFP_BLOCK_SIZE = tl.constexpr(32)
8
+
9
+
10
+ @triton.jit
11
+ def _get_max_quant_val(dtype: tl.constexpr):
12
+ if dtype == tl.uint8:
13
+ return 6.0
14
+ elif dtype == tl.float8e5:
15
+ return 57344.0
16
+ elif dtype == tl.float8e4nv:
17
+ return 448.0
18
+ else:
19
+ tl.static_assert(False, f"Invalid {dtype=}")
20
+
21
+ @triton.jit
22
+ def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
23
+ DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
24
+ is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
25
+ BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
26
+ BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
27
+ BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE
28
+
29
+ # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
30
+ f32_tensor = src_tensor.to(tl.float32)
31
+ abs_tensor = tl.abs(f32_tensor)
32
+ abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation
33
+ abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
34
+ max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
35
+ dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
36
+ if DEQUANT_SCALE_ROUNDING_MODE == 0:
37
+ # DequantScaleRoundingMode.ROUND_UP
38
+ # compute 2 ** ceil(log2(dequant_scale))
39
+ # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
40
+ # A corner case: exponent is 0xFF that will overflow but that's already
41
+ # NaN so assume we don't care.
42
+ dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
43
+ else:
44
+ # DequantScaleRoundingMode.ROUND_DOWN
45
+ # compute 2 ** floor(log2(dequant_scale))
46
+ assert DEQUANT_SCALE_ROUNDING_MODE == 1
47
+ dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
48
+ dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
49
+ quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)
50
+
51
+ f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
52
+ quant_tensor = f32_tensor * quant_scale
53
+
54
+ # Reshape the tensors after scaling
55
+ quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
56
+ # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
57
+ quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
58
+ dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])
59
+
60
+ # First, we simply extract the exponent part of the scales and store the result
61
+ dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
62
+ # Now we must convert the tensors to the mx format.
63
+ if is_fp8:
64
+ out_tensor = quant_tensor.to(mx_tensor_dtype)
65
+ else:
66
+ quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
67
+ signs = quant_tensor & 0x80000000
68
+ exponents = (quant_tensor >> 23) & 0xFF
69
+ mantissas = (quant_tensor & 0x7FFFFF)
70
+
71
+ # 0.25 <= x < 0.75 maps to 0.5, a denormal number
72
+ E8_BIAS = 127
73
+ E2_BIAS = 1
74
+ # Move implicit bit 1 at the beginning to mantissa for denormals
75
+ adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
76
+ mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)
77
+
78
+ # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
79
+ exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
80
+
81
+ # Combine sign, exponent, and mantissa, while saturating
82
+ # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
83
+ e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
84
+ e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)
85
+
86
+ e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
87
+ evens, odds = tl.split(e2m1_value)
88
+ out_tensor = evens | (odds << 4)
89
+
90
+ return out_tensor, dequant_scale_exponent
91
+
92
+ @triton.jit
93
+ def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
94
+ mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
95
+ src_ptr, stride_src_outer, stride_src_quant,
96
+ outer_dim, quant_dim,
97
+ BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
98
+ DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
99
+
100
+ tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
101
+ tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")
102
+
103
+ # uint8 signifies two fp4 e2m1 values packed into a single byte
104
+ mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
105
+ tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
106
+ f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")
107
+
108
+ src_dtype: tl.constexpr = src_ptr.dtype.element_ty
109
+ tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
110
+ tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16), f"{src_dtype=} must be bfloat16 or float16")
111
+ is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
112
+
113
+ outer_block = tl.program_id(0).to(tl.int64)
114
+ quant_block = tl.program_id(1).to(tl.int64)
115
+
116
+ K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
117
+ BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
118
+ BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
119
+
120
+ start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
121
+ start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
122
+ start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
123
+ start_out = outer_block * BLOCK_SIZE_OUT_DIM
124
+
125
+ src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
126
+ mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
127
+ mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer
128
+
129
+ offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
130
+ offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
131
+ offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
132
+ offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
133
+
134
+ mask_src_quant = start_src_quant + offs_src_quant < quant_dim
135
+ mask_n = start_out + offs_outer < outer_dim
136
+ full_mask_src = mask_src_quant & mask_n
137
+
138
+ mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
139
+ full_mask_mxt = mask_mxt_quant & mask_n
140
+
141
+ scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
142
+ full_scale_mask = scale_mask_k & mask_n
143
+
144
+ src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
145
+ mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
146
+ mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
147
+ src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)
148
+
149
+ out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
150
+ DEQUANT_SCALE_ROUNDING_MODE)
151
+
152
+ tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
153
+ tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
154
+
155
+
156
+ @triton.jit(repr=lambda _: "_dequantize_mxfp8")
157
+ def _dequantize_mxfp8_fn(input, mask, pid=None):
158
+ return _compute_quant_and_scale(input, mask, tl.float8e4nv)
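The `ROUND_UP` branch in `_compute_quant_and_scale` relies on an IEEE-754 bit trick: adding `0x007FFFFF` to a positive float's bit pattern carries into the exponent unless the mantissa is already zero, so masking with `0x7F800000` yields `2**ceil(log2(x))`, while masking alone yields `2**floor(log2(x))`. A standalone illustration in plain Python:

```python
# Round a positive float up or down to the nearest power of two by manipulating
# its IEEE-754 single-precision bit pattern, mirroring DEQUANT_SCALE_ROUNDING_MODE.
import struct

def round_pow2(x: float, round_up: bool) -> float:
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if round_up:
        bits = (bits + 0x007FFFFF) & 0x7F800000   # 2**ceil(log2(x))
    else:
        bits = bits & 0x7F800000                  # 2**floor(log2(x))
    return struct.unpack("<f", struct.pack("<I", bits))[0]

assert round_pow2(3.0, round_up=True) == 4.0
assert round_pow2(3.0, round_up=False) == 2.0
assert round_pow2(4.0, round_up=True) == 4.0      # exact powers of two are unchanged
```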
build/torch-universal/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py ADDED
@@ -0,0 +1,122 @@
1
+ import triton
2
+ import triton.language as tl
3
+ from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
4
+
5
+
6
+ # fmt: off
7
+ @triton.jit
8
+ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
9
+ stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
10
+ outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
11
+
12
+ tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
13
+ tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
14
+ # uint8 signifies two fp4 e2m1 values packed into a single byte
15
+ mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
16
+ dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
17
+ tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16)
18
+ tl.static_assert(
19
+ mx_tensor_dtype == tl.uint8
20
+ or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
21
+ "mx_tensor_ptr must be uint8 or float8 or dst_dtype")
22
+ tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
23
+
24
+ # Determine if we are dealing with fp8 types.
25
+ is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
26
+ is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
27
+ K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
28
+ BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
29
+ BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
30
+
31
+ # Compute starting indices for the quantized (packed) dimension and the outer dimension.
32
+ outer_block = tl.program_id(0).to(tl.int64)
33
+ quant_block = tl.program_id(1).to(tl.int64)
34
+
35
+ start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
36
+ start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
37
+ start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
38
+ start_out = outer_block * BLOCK_SIZE_OUT_DIM
39
+
40
+ mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
41
+ mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
42
+ out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant
43
+
44
+ # Compute offsets and masks.
45
+ offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
46
+ offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
47
+ offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
48
+ offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
49
+
50
+ mask_outer = start_out + offs_outer < outer_dim
51
+ mask_out_quant = start_out_quant + offs_out_quant < quant_dim
52
+ full_mask_out = mask_out_quant & mask_outer
53
+
54
+ mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
55
+ full_mask_src = mask_src_quant & mask_outer
56
+
57
+ mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
58
+ full_scale_mask = mask_scale & mask_outer
59
+
60
+ tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
61
+ scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
62
+ out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer
63
+
64
+ # Load the packed tensor and scale.
65
+ tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
66
+ scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)
67
+
68
+ # Upcast the scale to the destination type.
69
+ if dst_dtype == tl.bfloat16:
70
+ dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
71
+ else:
72
+ tl.static_assert(dst_dtype == tl.float16)
73
+ dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
74
+ dst_scale = dst_scale.to(tl.float16)
75
+
76
+ # Now upcast the tensor.
77
+ if is_fp8:
78
+ dst_tensor = tensor.to(dst_dtype)
79
+ if tensor.dtype == tl.float8e5:
80
+ from_e_bits: tl.constexpr = 5
81
+ from_m_bits: tl.constexpr = 2
82
+ to_e_bits: tl.constexpr = 8 if dst_dtype == tl.bfloat16 else 5
83
+ to_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10
84
+
85
+ # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
86
+ non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
87
+ non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
88
+ dst_tensor = tl.where(
89
+ (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
90
+ (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(dst_dtype, bitcast=True),
91
+ dst_tensor,
92
+ )
93
+ else:
94
+ assert is_fp4
95
+ dst_bias: tl.constexpr = 127 if dst_dtype == tl.bfloat16 else 15
96
+ dst_0p5: tl.constexpr = 16128 if dst_dtype == tl.bfloat16 else 0x3800
97
+ dst_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10
98
+ # e2m1
99
+ em0 = tensor & 0x07
100
+ em1 = tensor & 0x70
101
+ x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
102
+ x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
103
+ # Three cases:
104
+ # 1) x is normal and non-zero: Correct bias
105
+ x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
106
+ x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
107
+ # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
108
+ x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
109
+ x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
110
+ # 3) x is zero, do nothing
111
+ dst_tensor = tl.interleave(x0, x1).to(dst_dtype, bitcast=True)
112
+
113
+ # Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping.
114
+ dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
115
+ dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
116
+ scale = scale.reshape(dst_scale.shape)
117
+
118
+ out_tensor = dst_tensor * dst_scale
119
+ # Correct any NaNs encoded via the scale.
120
+ out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
121
+ out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
122
+ tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)
build/torch-universal/triton_kernels/proton_opts.py ADDED
@@ -0,0 +1,17 @@
1
+ # proton options
2
+
3
+ import os
4
+
5
+ _launch_metadata_allow_sync = None
6
+
7
+
8
+ def launch_metadata_allow_sync():
9
+ global _launch_metadata_allow_sync
10
+ if _launch_metadata_allow_sync is None:
11
+ _launch_metadata_allow_sync = not (os.getenv("PROTON_LAUNCH_METADATA_NOSYNC") == "1")
12
+ return _launch_metadata_allow_sync
13
+
14
+
15
+ def set_launch_metadata_allow_sync(allow_sync: bool):
16
+ global _launch_metadata_allow_sync
17
+ _launch_metadata_allow_sync = allow_sync
build/torch-universal/triton_kernels/reduction_details/reduce_bitmatrix.py ADDED
@@ -0,0 +1,111 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+ @triton.jit
7
+ def vpopc(x):
8
+ """
9
+ Vertical popcount
10
+ Input x : uint32[..., N]
11
+ Output y : uint32[..., 32]
12
+ semantics : y[..., i] = sum_j((x[..., j] >> i) & 1)
13
+ credits: @apgoucher
14
+ """
15
+
16
+ tl.static_assert(x.dtype == tl.uint32, "x should consist of 32-bit unsigned integers")
17
+
18
+ BLOCK_N: tl.constexpr = x.shape[-1] # summation axis
19
+ BATCHES: tl.constexpr = x.numel // BLOCK_N # number of batches
20
+ if BLOCK_N >= 8:
21
+ sa1: tl.constexpr = 8
22
+ else:
23
+ sa1: tl.constexpr = BLOCK_N
24
+ # create 8-way sums in 4-bit fields:
25
+ y = tl.reshape(x, [BATCHES, BLOCK_N // sa1, sa1, 1])
26
+ y = (y >> tl.arange(0, 4)[None, None, None, :]) & 0x11111111
27
+ y = tl.sum(y, 2) # [BATCHES, BLOCK_N // sa1, 4]
28
+ if BLOCK_N >= 128:
29
+ sa2: tl.constexpr = 16
30
+ else:
31
+ sa2: tl.constexpr = BLOCK_N // sa1
32
+ # create 128-way sums in 8-bit fields:
33
+ y = tl.reshape(y, [BATCHES, BLOCK_N // (sa1 * sa2), sa2, 1, 4])
34
+ y = (y >> (4 * tl.arange(0, 2))[None, None, None, :, None]) & 0x0f0f0f0f
35
+ y = tl.sum(y, 2) # [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4]
36
+ sa3: tl.constexpr = BLOCK_N // (sa1 * sa2)
37
+ # create N-way sums in 32-bit fields:
38
+ y = tl.reshape(y, [BATCHES, 1, sa3, 8])
39
+ y = (y >> (8 * tl.arange(0, 4))[None, :, None, None]) & 0x000000ff
40
+ y = tl.sum(y, 2) # [BATCHES, 4, 8]
41
+ y = tl.reshape(y, x.shape[:-1] + [32])
42
+ return y
43
+
44
+
45
+ @triton.jit
46
+ def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr):
47
+ pid = tl.program_id(0)
48
+ offs = pid * BLOCK + tl.arange(0, BLOCK)
49
+ tl.store(Ret + offs, 0)
50
+
51
+
52
+ @triton.jit
53
+ def _sum_bitmatrix_rows(B, shape_bm, stride_bm: tl.constexpr, stride_bn: tl.constexpr, # input bitmatrix
54
+ Ret, Partials, stride_pm: tl.constexpr, stride_pn, shape_pn, # outputs
55
+ BLOCK_MM: tl.constexpr, BLOCK_M: tl.constexpr):
56
+
57
+ tl.static_assert(BLOCK_MM % BLOCK_M == 0)
58
+ TILE_SIZE: tl.constexpr = BLOCK_MM // BLOCK_M
59
+ if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr():
60
+ shape_bm = tl.load(shape_bm)
61
+ pid_m = tl.program_id(0)
62
+ pid_n = tl.program_id(1)
63
+ offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM)
64
+ offs_n = pid_n * 32 + tl.arange(0, 32)
65
+ n_rows = shape_bm
66
+ bits = tl.load(B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0)
67
+ bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M])
68
+ ret = vpopc(bits) # [TILE_SIZE, 32]
69
+
70
+ offs_t = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
71
+
72
+ tl.atomic_add(Ret + offs_n, tl.sum(ret, 0), sem="relaxed")
73
+ tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret)
74
+
75
+
76
+ def clear_sums(n_cols, device, MEMSET_BLOCK=512):
77
+ cdiv = triton.cdiv
78
+ blocks = cdiv(n_cols, MEMSET_BLOCK)
79
+ out_ret = torch.empty((blocks * MEMSET_BLOCK, ), device=device, dtype=torch.int32)
80
+ _sum_bitmatrix_memset[(blocks, )](out_ret, MEMSET_BLOCK)
81
+ return out_ret
82
+
83
+
84
+ def sum_bitmatrix_rows(x, out_ret, partials_block_size=None):
85
+ assert partials_block_size is not None
86
+ cdiv = triton.cdiv
87
+ PARTIALS_BLOCK_M = partials_block_size
88
+ n_rows, n_cols = x.shape
89
+ n_rows_max = x.shape_max[0]
90
+ assert out_ret.shape == (n_cols, )
91
+
92
+ TILE_SIZE = max(1, 128 // PARTIALS_BLOCK_M)
93
+ BLOCK_MM = PARTIALS_BLOCK_M * TILE_SIZE
94
+
95
+ pids_x = cdiv(n_rows_max, BLOCK_MM)
96
+ pids_y = cdiv(n_cols, 32)
97
+ out_partials = torch.empty((pids_y * 32, pids_x * TILE_SIZE), device=out_ret.device, dtype=torch.int32)
98
+ out_partials = torch.transpose(out_partials, 0, 1)
99
+
100
+ # output tensors
101
+ _sum_bitmatrix_rows[(pids_x, pids_y)](
102
+ x.storage.data, n_rows, x.stride(0), x.stride(1), # input
103
+ out_ret, # output [final reduction]
104
+ out_partials, out_partials.stride(0), out_partials.stride(1),
105
+ out_partials.shape[1], # output [partial reductions]
106
+ BLOCK_M=PARTIALS_BLOCK_M, BLOCK_MM=BLOCK_MM, # constants
107
+ num_warps=8)
108
+
109
+ out_partials = out_partials[:cdiv(n_rows_max, PARTIALS_BLOCK_M), :]
110
+
111
+ return out_ret, out_partials
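The `vpopc` docstring specifies `y[..., i] = sum_j((x[..., j] >> i) & 1)`; a direct (unoptimized) PyTorch reference of that semantics is handy for checking the kernel:

```python
# Straightforward reference for the vertical popcount documented above: for each
# of the 32 bit positions, count how many words along the last axis have that bit set.
import torch

def vpopc_reference(x: torch.Tensor) -> torch.Tensor:
    x = x.to(torch.int64) & 0xFFFFFFFF            # treat int32 storage as unsigned 32-bit words
    shifts = torch.arange(32, device=x.device)
    bits = (x.unsqueeze(-1) >> shifts) & 1        # [..., N, 32]
    return bits.sum(dim=-2).to(torch.int32)       # [..., 32]

x = torch.randint(0, 2**31 - 1, (4, 8), dtype=torch.int32)
print(vpopc_reference(x).shape)                   # torch.Size([4, 32])
```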
build/torch-universal/triton_kernels/routing.py ADDED
@@ -0,0 +1,386 @@
1
+ import torch
2
+ import triton
3
+ from dataclasses import dataclass, field
4
+ from .routing_details._routing_compute import _combined_routing_compute
5
+ from .routing_details._routing_compute import _combined_routing_memset
6
+ from .routing_details._routing_compute import _routing_clear_bitmatrix
7
+ from .routing_details._expt_data import _expt_data_memset
8
+ from .routing_details._expt_data import _expt_data_compute
9
+ from .target_info import is_hip
10
+
11
+
12
+ @dataclass
13
+ class GatherIndx:
14
+ """
15
+ Indices for an operation that performs:
16
+ Y = X[src_idx, :]
17
+ """
18
+ # array such that `dst_idx[src_idx] = arange(0, N)`
19
+ src_indx: torch.Tensor
20
+ dst_indx: torch.Tensor
21
+
22
+
23
+ @dataclass
24
+ class ScatterIndx:
25
+ """
26
+ Indices for an operation that performs:
27
+ Y[dst_idx, :] = X
28
+ """
29
+ # array such that `dst_idx[src_idx] = arange(0, N)`
30
+ src_indx: torch.Tensor
31
+ dst_indx: torch.Tensor
32
+
33
+
34
+ @dataclass
35
+ class ExptData:
36
+ # hist[i] is the number of tokens routed to expert i
37
+ hist: torch.Tensor
38
+ # token_offs_raw[i] is the offset of the first token routed
39
+ # to expert i in an expert-sorted array
40
+ token_offs_raw: torch.Tensor
41
+ # token_offs_pad[block][i] is the offset of the first token routed
42
+ # to expert i in an expert-sorted array, assuming the histogram is
43
+ # rounded to the next multiple of `block`
44
+ token_offs_pad: dict[int, torch.Tensor]
45
+ # block_pid_map[block] contains one value for each `pid` launched by
46
+ # the matrix multiplication kernel launched with BLOCK_M=block:
47
+ # - the value is -1 if the `pid` has no work to do
48
+ # - otherwise, the value is two int16 (packed as an int32) that
49
+ # correspond respectively to (1) the expert assigned to
50
+ # the tokens processed by this pid; (2) the block assigned to the
51
+ # tokens processed by this pid (think `pid_m` in a regular matmul)
52
+ # see `test_routing.py` for a reference implementation and more details
53
+ block_pid_map: dict[int, torch.Tensor]
54
+
55
+ def __post_init__(self):
56
+ if self.hist is not None:
57
+ assert self.hist.dtype == torch.int32
58
+ if self.token_offs_raw is not None:
59
+ assert self.token_offs_raw.dtype == torch.int32
60
+ if self.token_offs_pad is not None:
61
+ for v in self.token_offs_pad.values():
62
+ assert v.dtype == torch.int32
63
+ if self.block_pid_map is not None:
64
+ for v in self.block_pid_map.values():
65
+ assert v.dtype == torch.int32
66
+
67
+
68
+ @dataclass
69
+ class RoutingData:
70
+ gate_scal: torch.Tensor = field()
71
+ expt_hist: torch.Tensor = field()
72
+ n_expts_tot: int = field()
73
+ n_expts_act: int = field()
74
+ expt_data: ExptData = None
75
+
76
+ # Used to make perf annotation cleaner: when we use expert sharding, we can
77
+ # use this to tell the "expected" number of local tokens per expert, because
78
+ # the actual number can vary for each input.
79
+ expected_tokens_per_expt: int = field(default=None)
80
+
81
+ def n_blocks(self, n_rows, block_m):
82
+ if n_rows <= self.n_expts_tot:
83
+ return n_rows
84
+ else:
85
+ return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1
86
+
87
+
88
+ # --------------------------
89
+ # sort tokens by expert
90
+ # --------------------------
91
+
92
+
93
+ class SortTokens(torch.autograd.Function):
94
+
95
+ @staticmethod
96
+ def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
97
+ HIST_BLOCK_M = 32
98
+ INDX_OFFS_BLOCK_M = 512
99
+ MEMSET_BLOCK = 1024
100
+ cdiv = triton.cdiv
101
+
102
+ device = expt_scal.device
103
+ dtype = expt_scal.dtype
104
+ n_tokens_raw, _ = bitmatrix.shape
105
+ n_tokens_pad, n_expts_act = expt_scal.shape
106
+ n_gates_pad = n_tokens_pad * n_expts_act
107
+
108
+ hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)
109
+ hist = hist[:n_expts_tot]
110
+ assert hist.dtype == torch.int32
111
+ # scratchpad
112
+ expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
113
+ combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
114
+ # output
115
+ topk_indx = combined_indx[:n_gates_pad]
116
+ gate_indx = combined_indx[n_gates_pad:]
117
+ gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
118
+
119
+ token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1a, blocks2a, MEMSET_BLOCK_A, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
120
+ hist, n_expts_tot, n_gates_pad)
121
+
122
+ blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
123
+ blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
124
+
125
+ _combined_routing_memset[(blocks1a + blocks1b, )](
126
+ combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist, #
127
+ expt_offs, hist.shape[0], n_expts_tot, partial_hist, # inputs
128
+ partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1), # outputs
129
+ token_offs_combined, token_offs_combined.stride(0), #
130
+ blocks1a, block_pid_map, #
131
+ block_m_log2_start, SIZES=block_m_num, BLOCK_A=MEMSET_BLOCK_A, # optimization parameters
132
+ BLOCK_N=512, BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters
133
+ )
134
+
135
+ indx_offs = partial_hist
136
+
137
+ _combined_routing_compute[(blocks2a + blocks2b, )](
138
+ topk_indx, gate_indx, gate_scal, # outputs
139
+ expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1), # inputs
140
+ expt_offs, n_tokens_raw, # input shape
141
+ HIST_BLOCK_M, n_expts_act, # constants
142
+ hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs
143
+ block_m_log2_start, block_m_num, HIST2_BLOCK_M, blocks2a, # etc.
144
+ )
145
+
146
+ ctx.n_tokens_raw = n_tokens_raw
147
+ ctx.n_tokens_pad = n_tokens_pad
148
+ ctx.n_expts_act = n_expts_act
149
+ ctx.save_for_backward(gate_indx)
150
+ return hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map
151
+
152
+ @staticmethod
153
+ def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
154
+ (gate_indx, ) = ctx.saved_tensors
155
+ dgate_scal = dgate_scal[gate_indx]
156
+ dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
157
+ return dgate_scal, None, None, None
158
+
159
+
160
+ def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix):
161
+ return SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix)
162
+
163
+
164
+ # --------------------------
165
+ # prune routing
166
+ # --------------------------
167
+
168
+
169
+ class PruneRouting(torch.autograd.Function):
170
+
171
+ @staticmethod
172
+ def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
173
+ from .compaction import compaction
174
+ n_tokens_pad = expt_scal.shape[0]
175
+ assert n_expts_tot % simulated_ep == 0
176
+ _routing_clear_bitmatrix[(n_tokens_pad, )](
177
+ bitmatrix.storage.data,
178
+ bitmatrix.storage.data.stride(0),
179
+ bitmatrix.storage.data.stride(1),
180
+ bitmatrix.storage.data.shape[1],
181
+ n_expts_tot // simulated_ep,
182
+ BLOCK_N=512,
183
+ )
184
+ # perform compaction to update expt_scal / expt_indx
185
+ expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix)
186
+ n_expts_tot = n_expts_tot // simulated_ep
187
+ bitmatrix.shape[-1] = n_expts_tot
188
+ return expt_scal, expt_indx, bitmatrix
189
+
190
+
191
+ def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
192
+ return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep)
193
+
194
+
195
+ # --------------------------
196
+ # expt_data
197
+ # --------------------------
198
+
199
+
200
+ def log2_power_of_two(x):
201
+ assert x > 0 and (x & (x - 1)) == 0, "x must be a power of two"
202
+ return x.bit_length() - 1
203
+
204
+
205
+ block_m_log2_start = 4
206
+
207
+
208
+ def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
209
+
210
+ MEMSET_BLOCK = 512
211
+ HIST2_BLOCK_M = 512
212
+ device = expt_hist.device
213
+ n_expts_tot = n_expts_tot
214
+ cdiv = triton.cdiv
215
+ # block_ms are all powers of two between 16 and 128 (inclusive), or up to 256 on HIP
216
+ block_m_log2_end = 9 if is_hip() else 8
217
+ block_m_num = block_m_log2_end - block_m_log2_start
218
+ if n_gates <= n_expts_tot:
219
+ max_n_tiles = n_gates
220
+ else:
221
+ max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start)
222
+ # allocate memory
223
+ pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK
224
+ dtype = torch.int32
225
+
226
+ token_offs_combined = torch.empty((block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device)
227
+
228
+ token_offs_raw = token_offs_combined[0][:n_expts_tot + 1]
229
+ token_offs_pad = token_offs_combined[1:]
230
+
231
+ block_pid_map = torch.empty((block_m_num, pad(max_n_tiles)), dtype=dtype, device=device)
232
+ memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK # exact division
233
+ # compute outputs
234
+ token_offs_pad = token_offs_pad[:, :n_expts_tot + 1]
235
+ block_pid_map = block_pid_map[:, :max_n_tiles]
236
+
237
+ blocks1 = memset_grid + block_m_num + 1
238
+ blocks2 = n_expts_tot * block_m_num
239
+ return token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num
240
+
241
+
242
+ def _unpack_into_dict(x):
243
+
244
+ block_m_log2_end = block_m_log2_start + x.shape[0]
245
+ x = {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
246
+ return x
247
+
248
+
249
+ def compute_expt_data(expt_hist, n_expts_tot, n_gates):
250
+
251
+ if expt_hist is None:
252
+ return ExptData(None, None, None, None)
253
+
254
+ # this just computes the kernel arguments:
255
+ token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
256
+ expt_hist, n_expts_tot, n_gates)
257
+
258
+ _expt_data_memset[(blocks1, )](
259
+ expt_hist, n_expts_tot, #
260
+ token_offs_combined, token_offs_combined.stride(0), #
261
+ block_pid_map, #
262
+ block_m_log2_start, SIZES=block_m_num, BLOCK=MEMSET_BLOCK, # optimization parameters
263
+ num_warps=4)
264
+ _expt_data_compute[(blocks2, )](
265
+ expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs
266
+ block_m_log2_start, SIZES=block_m_num, BLOCK=HIST2_BLOCK_M, # optimization parameters
267
+ num_warps=4)
268
+
269
+ token_offs_pad = _unpack_into_dict(token_offs_pad)
270
+ block_pid_map = _unpack_into_dict(block_pid_map)
271
+ return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)
272
+
273
+
274
+ # --------------------------
275
+ # routing
276
+ # --------------------------
277
+
278
+
279
+ def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
280
+ hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens(
281
+ expt_scal, expt_indx, n_expts_tot, bitmatrix)
282
+ token_offs_pad = _unpack_into_dict(token_offs_pad)
283
+ block_pid_map = _unpack_into_dict(block_pid_map)
284
+ expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
285
+
286
+ # pack the matmul data structure
287
+ gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
288
+ scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
289
+ return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx
290
+
291
+
292
+ def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None):
293
+ from .topk import topk
294
+ if sm_first:
295
+ logits = torch.softmax(logits, dim=-1)
296
+ expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, #
297
+ apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows)
298
+ n_expts_tot = logits.shape[-1] // simulated_ep
299
+ # mutate bitmatrix
300
+ if simulated_ep > 1:
301
+ expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, logits.shape[-1], simulated_ep)
302
+
303
+ return routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act)
304
+
305
+
306
+ # --------------------------
307
+ # torch reference
308
+ # --------------------------
309
+
310
+
311
+ def compute_expt_data_torch(hist, n_expts_tot, n_gates):
312
+ # offset for each expert
313
+ device = hist.device
314
+ token_offs_raw = torch.cumsum(hist, dim=0)
315
+ token_offs_raw = torch.cat((torch.zeros(1, device=device), token_offs_raw))
316
+ token_offs_raw = token_offs_raw.int()
317
+ # maximum number of tiles for all values of `block_m` considered
318
+ block_ms = [16, 32, 64, 128]
319
+ if is_hip():
320
+ block_ms.append(256)
321
+ if n_gates <= n_expts_tot:
322
+ max_n_tiles = n_gates
323
+ else:
324
+ # ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
325
+ # ceil_div(x, y): -(-x // y)
326
+ max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms))
327
+ # fill up tile offset/infos for each block
328
+ token_offs_pad = dict()
329
+ block_pid_map = dict()
330
+ for block_m in block_ms:
331
+ n_tiles = (hist + block_m - 1) // block_m # matmul blocks needed
332
+ token_offs_pad[block_m] = torch.cumsum(n_tiles, dim=0)
333
+ token_offs_pad[block_m] = torch.cat((torch.zeros(1, device=device), token_offs_pad[block_m]))
334
+ token_offs_pad[block_m] = token_offs_pad[block_m].int()
335
+ # compute data required to drive ragged batch matmul
336
+ block_pid_map[block_m] = -torch.ones(max_n_tiles, device=device)
337
+ for e in range(n_expts_tot):
338
+ offset = token_offs_pad[block_m][e]
339
+ for b in range(n_tiles[e]):
340
+ block_pid_map[block_m][offset + b] = (b << 16) + e
341
+ block_pid_map[block_m] = block_pid_map[block_m].int()
342
+ return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
343
+
344
+
345
+ def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):
346
+ has_user_provided_indx = expt_indx is not None
347
+ n_gates_pad = logits.shape[0] * n_expts_act
348
+
349
+ if n_rows is not None:
350
+ logits = logits[:n_rows, :]
351
+
352
+ def topk(vals, k, expt_indx):
353
+ # topk of experts
354
+ if has_user_provided_indx:
355
+ tk_indx = expt_indx
356
+ else:
357
+ tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
358
+ tk_indx = tk_indx.long()
359
+ tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
360
+ tk_indx = tk_indx.int()
361
+ return tk_val, tk_indx
362
+
363
+ _, n_expts_tot = logits.shape
364
+ if sm_first:
365
+ logits = torch.softmax(logits, dim=-1)
366
+ expt_scal, expt_indx = topk(logits, n_expts_act, expt_indx)
367
+ if not sm_first:
368
+ expt_scal = torch.softmax(expt_scal, dim=-1)
369
+ # sort each token's selections by expert
370
+ if not has_user_provided_indx:
371
+ expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
372
+ expt_scal = torch.gather(expt_scal, 1, sort_indices)
373
+ # flatten topk data
374
+ expt_scal = expt_scal.reshape(-1)
375
+ expt_indx = expt_indx.reshape(-1).to(torch.int32)
376
+ # sort by expert_id so experts are contiguous for the matmul
377
+ topk_indx = torch.argsort(expt_indx, stable=True)
378
+ gate_indx = torch.argsort(topk_indx, stable=True)
379
+ gate_scal = expt_scal[topk_indx]
380
+ hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1).int() # histogram of tokens over experts
381
+ # pack the matmul data structure
382
+ gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
383
+ scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
384
+ # compute expt_data
385
+ expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad)
386
+ return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx
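
A quick usage sketch of the reference path defined above (tensor sizes are arbitrary; the assertion only illustrates that the gather and scatter indices are inverse permutations of each other):

```python
import torch

n_tokens, n_expts_tot, n_expts_act = 128, 8, 2
logits = torch.randn(n_tokens, n_expts_tot, device="cuda")
rdata, gather_indx, scatter_indx = routing_torch(logits, n_expts_act)

n_gates = n_tokens * n_expts_act
identity = torch.arange(n_gates, device=logits.device, dtype=torch.int32)
assert torch.equal(gather_indx.src_indx[scatter_indx.src_indx.long()], identity)
```
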
build/torch-universal/triton_kernels/routing_details/_expt_data.py ADDED
@@ -0,0 +1,64 @@
1
+ import triton
2
+ import triton.language as tl
3
+
4
+
5
+ @triton.jit
6
+ def _cdiv_pow2(n, log2_k):
7
+ return (n + ((1 << log2_k) - 1)) >> log2_k
8
+
9
+
10
+ @triton.jit
11
+ def _expt_data_memset(Hist, n_expts_tot, MDStarts, tile_starts_stridem, MDTileInfo, first_tile_dim_log2,
12
+ SIZES: tl.constexpr, BLOCK: tl.constexpr):
13
+
14
+ pid = tl.program_id(0)
15
+
16
+ if pid <= SIZES:
17
+
18
+ MDStarts += pid * tile_starts_stridem
19
+ x_tile = tl.zeros([BLOCK], dtype=MDStarts.dtype.element_ty)
20
+ Tile_ptrs = MDStarts + tl.arange(0, BLOCK)
21
+ tile_dim_log2 = tl.where(pid == 0, 0, pid + first_tile_dim_log2 - 1)
22
+
23
+ for i in range(0, n_expts_tot + 1, BLOCK):
24
+
25
+ offs_n = tl.arange(0, BLOCK) + i
26
+ mask_n0 = offs_n < n_expts_tot
27
+ hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0)
28
+ hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2)
29
+
30
+ tile_starts = tl.cumsum(hist_tile, 0) + x_tile
31
+ x_tile += tl.sum(hist_tile, 0).to(MDStarts.dtype.element_ty)
32
+ tl.store(Tile_ptrs, tile_starts - hist_tile)
33
+ Tile_ptrs += BLOCK
34
+
35
+ else:
36
+
37
+ pid -= (SIZES + 1)
38
+ TileInfoOut = MDTileInfo + pid * BLOCK + tl.arange(0, BLOCK)
39
+ tl.store(TileInfoOut, 0xffffffff)
40
+
41
+
42
+ @triton.jit
43
+ def _expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
44
+ SIZES: tl.constexpr, BLOCK: tl.constexpr):
45
+
46
+ pid = tl.program_id(0)
47
+
48
+ expt_id = pid // SIZES
49
+ buff_id = pid % SIZES
50
+
51
+ MDTileStarts += buff_id * tile_starts_stridem
52
+ MDTileInfo += buff_id * tile_info_stridem
53
+
54
+ n_tokens = tl.load(Hist + expt_id)
55
+ tile_dim_log2 = first_tile_dim_log2 + buff_id
56
+ n_blocks = _cdiv_pow2(n_tokens, tile_dim_log2)
57
+
58
+ tile_off = tl.load(MDTileStarts + expt_id)
59
+ MDTileInfo += tile_off
60
+
61
+ for block_off in range(0, n_blocks, BLOCK):
62
+ block_offs = block_off + tl.arange(0, BLOCK)
63
+ data = (block_offs << 16) + expt_id
64
+ tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks)
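
`_cdiv_pow2(n, log2_k)` is a shift-based ceiling division by `2**log2_k`; a quick host-side check of the identity it relies on (illustration only):

```python
def cdiv_pow2(n: int, log2_k: int) -> int:
    # same arithmetic as the Triton helper above, evaluated on the host
    return (n + ((1 << log2_k) - 1)) >> log2_k

assert cdiv_pow2(100, 4) == -(-100 // 16) == 7   # ceil(100 / 16)
assert cdiv_pow2(128, 4) == 8                    # exact multiples are unchanged
```
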
build/torch-universal/triton_kernels/routing_details/_routing_compute.py ADDED
@@ -0,0 +1,148 @@
1
+ import triton
2
+ import triton.language as tl
3
+
4
+ from ._expt_data import _expt_data_compute, _expt_data_memset
5
+
6
+
7
+ @triton.jit
8
+ def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, # histogram
9
+ BLOCK_N: tl.constexpr):
10
+ loop_iterations = (hist_size + BLOCK_N - 1) // BLOCK_N
11
+ x = tl.zeros([BLOCK_N], ExpertHist.dtype.element_ty)
12
+ for i in range(loop_iterations):
13
+ offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
14
+ mask_n = offs_n < hist_size
15
+ hist2 = tl.load(ExpertHist + offs_n, mask=mask_n)
16
+ tok_starts = tl.cumsum(hist2, 0) - hist2 + x
17
+ x += tl.sum(hist2, 0)
18
+ tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
19
+ offs_n += BLOCK_N
20
+
21
+
22
+ @triton.jit
23
+ def _routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr, expt_id):
24
+ offs_m = tl.arange(0, BLOCK_M)
25
+ # iterate over input data
26
+ curr_sum = 0
27
+ for _ in range(0, shape_pm, BLOCK_M):
28
+ offs = offs_m * stride_pm + expt_id * stride_pn
29
+ curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm)
30
+ out = tl.cumsum(curr, 0) + curr_sum
31
+ curr_sum += tl.sum(curr, 0)
32
+ tl.store(PartialHist + offs, out - curr, mask=offs_m < shape_pm)
33
+ offs_m += BLOCK_M
34
+
35
+
36
+ @triton.jit
37
+ def _keyed_add(x, y):
38
+
39
+ # we keep the key in the upper 16 bits of a uint32:
40
+ key_mask: tl.constexpr = 0xffff0000
41
+
42
+ kx = x & key_mask
43
+ ky = y & key_mask
44
+ z = tl.where(kx == ky, x + y - kx, y)
45
+ return z
46
+
47
+
48
+ @triton.jit
49
+ def _routing_compute_indx(pid_m, GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm,
50
+ stride_pn, TokensStart, n_tokens, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
51
+
52
+ if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
53
+ n_tokens = tl.load(n_tokens)
54
+ n_gates = n_tokens * N_EXPTS_ACT
55
+
56
+ tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)
57
+
58
+ local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
59
+ offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
60
+ expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)
61
+
62
+ # stable-sort by expert ID:
63
+ kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
64
+ kv_pairs = tl.sort(kv_pairs, 0)
65
+ expert = kv_pairs >> 16
66
+ offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xffff)
67
+ mask = expert != 0xffff
68
+ gate_scal = tl.load(ExptScal + offs, mask=mask)
69
+
70
+ # compute run lengths in expert-sorted order:
71
+ x = (kv_pairs & 0xffff0000 | 0x00000001)
72
+ expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
73
+ exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xffff
74
+
75
+ gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=mask)
76
+ gates += tl.load(TokensStart + expert, mask=mask)
77
+ gates += exclusive_run_lengths
78
+
79
+ tl.store(ScatterIndx + offs, gates, mask=mask)
80
+ tl.store(GatherIndx + gates, offs, mask=mask)
81
+ tl.store(GateScal + gates, gate_scal, mask=mask)
82
+
83
+
84
+ @triton.jit
85
+ def _combined_routing_compute(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, stride_pn,
86
+ TokensStart, n_tokens, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr, Hist,
87
+ MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
88
+ SIZES: tl.constexpr, BLOCK: tl.constexpr, blocks2a):
89
+
90
+ pid = tl.program_id(0)
91
+ if pid < blocks2a:
92
+ _expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
93
+ SIZES, BLOCK)
94
+ else:
95
+ pid -= blocks2a
96
+ _routing_compute_indx(pid, GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm,
97
+ stride_pn, TokensStart, n_tokens, BLOCK_M, N_EXPTS_ACT)
98
+
99
+
100
+ @triton.jit
101
+ def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr):
102
+ pid_m = tl.program_id(0)
103
+ cutoff_word = cutoff // 32
104
+ cutoff_bit = cutoff % 32
105
+ cutoff_mask = (1 << (cutoff_bit)) - 1
106
+ for start_n in range(0, shape_bn, BLOCK_N):
107
+ offs_n = start_n + tl.arange(0, BLOCK_N)
108
+ values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn)
109
+ values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values)
110
+ values = tl.where(offs_n > cutoff_word, 0, values)
111
+ tl.store(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, values, mask=offs_n < shape_bn)
112
+
113
+
114
+ @triton.jit
115
+ def _combined_routing_memset(Indx, size, sentinel, BLOCK: tl.constexpr, ExpertHist, FinalExpertOffs, hist_size,
116
+ n_expts_tot, PartialHist, shape_pm, stride_pm, stride_pn, MDStarts, tile_starts_stridem,
117
+ blocks1a, MDTileInfo, first_tile_dim_log2, SIZES: tl.constexpr, BLOCK_A: tl.constexpr,
118
+ BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr):
119
+ """
120
+ This kernel essentially combines 6 different pieces of functionality,
121
+ statically branching on the value of tl.program_id(0) to decide which
122
+ codepath to take.
123
+
124
+ pid == 0: create the token cumsum
125
+ 1 <= pid <= SIZES: create a tile cumsum
126
+ SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff
127
+ blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs
128
+ pid == blocks1a + n_expts_tot: compute_expt_offs
129
+ pid > blocks1a + n_expts_tot: initialise Indx to sentinel
130
+
131
+ As each of these is a relatively trivial workload, launching them from
132
+ this single trampoline is beneficial as they can execute on different
133
+ streaming multiprocessors in parallel.
134
+ """
135
+
136
+ pid = tl.program_id(0)
137
+
138
+ if pid < blocks1a:
139
+ _expt_data_memset(ExpertHist, n_expts_tot, MDStarts, tile_starts_stridem, MDTileInfo, first_tile_dim_log2,
140
+ SIZES, BLOCK_A)
141
+ elif pid == n_expts_tot + blocks1a:
142
+ _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
143
+ elif pid < n_expts_tot + blocks1a:
144
+ _routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M, pid - blocks1a)
145
+ else:
146
+ offs = (pid - n_expts_tot - blocks1a - 1) * BLOCK + tl.arange(0, BLOCK)
147
+ mask = offs < size
148
+ tl.store(Indx + offs, sentinel, mask=mask)
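
The docstring of `_combined_routing_memset` describes how a single 1D grid is partitioned across unrelated sub-tasks; the same mapping, written as a host-side sketch that mirrors the kernel's branches (illustration only, not library code):

```python
def memset_trampoline_task(pid: int, blocks1a: int, n_expts_tot: int) -> str:
    # mirrors the branch structure of _combined_routing_memset
    if pid < blocks1a:
        # pid == 0: token cumsum; 1..SIZES: tile cumsums; rest: fill MDTileInfo with 0xffffffff
        return "_expt_data_memset"
    if pid == blocks1a + n_expts_tot:
        return "_routing_compute_expt_offs"
    if pid < blocks1a + n_expts_tot:
        return f"_routing_compute_indx_offs (expert {pid - blocks1a})"
    return "memset Indx to sentinel"
```
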
build/torch-universal/triton_kernels/specialize.py ADDED
@@ -0,0 +1,132 @@
1
+ import inspect
2
+ import re
3
+ import textwrap
4
+ import types
5
+ import triton
6
+
7
+
8
+ def cacheable(f):
9
+ """
10
+ A decorator that allow you to write something of the form:
11
+
12
+ @cacheable
13
+ def my_kernel(): return (expression dynamically defining a kernel)
14
+
15
+ such that it interacts gracefully with triton cache and preload.
16
+ """
17
+
18
+ g = f()
19
+ g.fn.__name__ = f.__name__
20
+ g.fn.__module__ = f.__module__
21
+ g.fn.__qualname__ = f.__qualname__
22
+ g._fn_name = f"{f.__module__}.{f.__qualname__}"
23
+ return g
24
+
25
+
26
+ def define_kernel(src, module, attrs=None, **extra_globals):
27
+ """
28
+ Dynamically create a Triton function or kernel from a src string,
29
+ linking any symbols in the kernel to objects specified by extra_globals.
30
+ """
31
+
32
+ # create templace function
33
+ def _empty_fn():
34
+ pass
35
+
36
+ gdict = dict(**(_empty_fn.__globals__))
37
+ gdict.update(extra_globals)
38
+ f = types.FunctionType(_empty_fn.__code__, gdict)
39
+ f.__module__ = module.__name__
40
+
41
+ src = textwrap.dedent(src)
42
+ src = src[src.find("def "):]
43
+
44
+ stored_functions = []
45
+ function_name = src[4:].split("(")[0].strip()
46
+
47
+ exec_globals = gdict
48
+ exec_globals.update({"stored_functions": stored_functions})
49
+ exec(src + "\n\nstored_functions.append(" + function_name + ")\n", exec_globals)
50
+
51
+ f.__signature__ = inspect.signature(stored_functions[0])
52
+ f.__name__ = function_name
53
+ f.__doc__ = stored_functions[0].__doc__
54
+
55
+ if attrs is None:
56
+ attrs = dict()
57
+ f = triton.JITFunction(f, **attrs)
58
+ f._unsafe_update_src(src)
59
+ return f
60
+
61
+
62
+ def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):
63
+ assert isinstance(fn, triton.runtime.jit.JITFunction)
64
+ if name is None:
65
+ name = f"{fn.__name__}"
66
+ # Get original source code
67
+ src = inspect.getsource(fn.fn)
68
+ src = textwrap.dedent(src)
69
+ lines = src.split("\n")
70
+ # Skip decorator and def line
71
+ def_idx = next(i for i, line in enumerate(lines) if line.strip().startswith("def"))
72
+ # separate header vs body LOC
73
+ header_end = def_idx
74
+ while not lines[header_end].rstrip().endswith(":"):
75
+ header_end += 1
76
+ body_lines = lines[header_end + 1:]
77
+ header_lines = lines[def_idx:header_end + 1]
78
+ # clean-up header
79
+ header_clean = [
80
+ l.split("#", 1)[0].strip() # keep code, discard comment
81
+ for l in header_lines
82
+ if l.split("#", 1)[0].strip() # skip blank‑after‑comment lines
83
+ ]
84
+ # decompose arguments
85
+ header_src = " ".join(header_clean) # turn it into a single line
86
+ m = re.search(r"\((.*)\)\s*:", header_src)
87
+ if not m:
88
+ raise ValueError("Could not parse function header")
89
+ args_str = m.group(1)
90
+ args = [arg.strip() for arg in args_str.split(",") if arg.strip()]
91
+ non_specialized_args = []
92
+ for arg in args:
93
+ arg_key = arg.split(":")[0].split("=")[0].strip()
94
+ new_args = tuples.get(arg_key, [arg])
95
+ if arg_key not in constants:
96
+ non_specialized_args += new_args
97
+ # add global symbols
98
+ spec_fns = {v.__name__: v for k, v in constants.items() if isinstance(v, triton.runtime.jit.JITFunction)}
99
+ globals = spec_fns | fn.get_capture_scope()
100
+ # build new source code and define kernel dynamically
101
+ new_signature = f"def {name}({', '.join(non_specialized_args)}):"
102
+ constexpr_lines = [
103
+ f" {key}: tl.constexpr = {value.__name__ if callable(value) else value}" for key, value in constants.items()
104
+ ]
105
+ tuple_lines = [
106
+ f" {key} = {'(' + ','.join(value) + (',' if len(value)>=1 else '') + ')'}" for key, value in tuples.items()
107
+ ]
108
+ new_src = "\n".join(["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines)
109
+ # find function parameters
110
+ sig = inspect.signature(triton.runtime.jit.JITFunction.__init__)
111
+ params = list(sig.parameters.values())[2:]
112
+ attrs = {param.name: getattr(fn, param.name, param.default) for param in params}
113
+
114
+ # make a new repr which appends the repr of the specialized functions.
115
+ base_repr = attrs["repr"]
116
+
117
+ def new_repr(specialization):
118
+ ret = base_repr(specialization)
119
+ for spec_fn in spec_fns.values():
120
+ spec_repr = spec_fn.repr(None)
121
+ if spec_repr:
122
+ spec_repr = spec_repr.strip("_")
123
+ if spec_repr:
124
+ ret += f"_{spec_repr}"
125
+ return ret
126
+
127
+ attrs["repr"] = new_repr
128
+
129
+ if do_not_specialize:
130
+ attrs["do_not_specialize"] = do_not_specialize
131
+ ret = define_kernel(new_src, module, attrs, **globals)
132
+ return ret
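
A minimal sketch of `define_kernel` in isolation, assuming a toy copy kernel (the source string and names below are illustrative, not part of the library); `tl` is passed as an extra global so the generated function can resolve it:

```python
import sys
import triton.language as tl

_copy_src = """
def _copy(X, Y, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    tl.store(Y + offs, tl.load(X + offs, mask=mask), mask=mask)
"""
copy_kernel = define_kernel(_copy_src, sys.modules[__name__], tl=tl)
```
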
build/torch-universal/triton_kernels/swiglu.py ADDED
@@ -0,0 +1,100 @@
1
+ from dataclasses import dataclass
2
+ from triton_kernels.numerics import InFlexData, OutFlexData
3
+ import torch
4
+ import triton
5
+ from .swiglu_details._swiglu import _swiglu, _swiglu_fn
6
+ from triton_kernels import target_info
7
+
8
+
9
+ @dataclass(frozen=True)
10
+ class FlexCtx:
11
+ out_data: OutFlexData = OutFlexData()
12
+ inp_data: InFlexData = InFlexData()
13
+ saturate_inf: bool = False
14
+
15
+
16
+ @dataclass(frozen=True)
17
+ class PrecisionConfig:
18
+ limit: float
19
+ flex_ctx: FlexCtx = FlexCtx()
20
+
21
+
22
+ swiglu_fn = _swiglu_fn
23
+
24
+
25
+ class SwiGLU(torch.autograd.Function):
26
+
27
+ @staticmethod
28
+ def forward(ctx, a, alpha, precision_config, routing_data):
29
+ N = a.shape[-1]
30
+ M = a.numel() // N
31
+ assert a.stride()[-1] == 1
32
+ assert a.shape[-1] % 2 == 0
33
+ out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device)
34
+ flex_ctx = precision_config.flex_ctx
35
+ # optimization hyperparameters
36
+ BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
37
+ num_warps = 4
38
+ kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
39
+ # launch semi-persistent kernel
40
+ N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
41
+ num_sms = target_info.num_sms()
42
+ if routing_data is not None:
43
+ waves_per_sm = 32 if target_info.is_hip() else 128
44
+ num_pid = num_sms * (waves_per_sm // num_warps)
45
+ M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
46
+ grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
47
+ else:
48
+ M_BLOCKS = triton.cdiv(M, BLOCK_M)
49
+ if M_BLOCKS * N_BLOCKS >= 8 * num_sms:
50
+ grid = (8 * num_sms, )
51
+ else:
52
+ grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
53
+ n_tokens = None
54
+ if routing_data is not None:
55
+ n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot]
56
+ _swiglu[grid](
57
+ flex_ctx.out_data.reinterpret(out),
58
+ flex_ctx.out_data.expected_scale,
59
+ flex_ctx.out_data.actual_scale,
60
+ flex_ctx.out_data.checksum_scale,
61
+ flex_ctx.inp_data.reinterpret(a),
62
+ flex_ctx.inp_data.scale,
63
+ alpha,
64
+ M,
65
+ N // 2,
66
+ a.shape[-1],
67
+ 1,
68
+ out.shape[-1],
69
+ 1,
70
+ precision_config.limit,
71
+ n_tokens,
72
+ BLOCK_M=BLOCK_M,
73
+ BLOCK_N=BLOCK_N,
74
+ EVEN_N=(N // 2) % BLOCK_N == 0,
75
+ M_BLOCKS=M_BLOCKS,
76
+ N_BLOCKS=N_BLOCKS,
77
+ flexpoint_saturate_inf=flex_ctx.saturate_inf,
78
+ num_warps=num_warps,
79
+ **kwargs,
80
+ )
81
+ out = out.view(a.shape[:-1] + out.shape[-1:])
82
+ return out
83
+
84
+
85
+ def swiglu(a, alpha, precision_config, routing_data=None):
86
+ return SwiGLU.apply(a, alpha, precision_config, routing_data)
87
+
88
+
89
+ def swiglu_torch(a, alpha, precision_config):
90
+ limit = precision_config.limit
91
+ a_gelu = a[..., ::2]
92
+ if limit is not None:
93
+ a_gelu = a_gelu.clamp(max=limit)
94
+ a_linear = a[..., 1::2]
95
+ if limit is not None:
96
+ a_linear = a_linear.clamp(min=-limit, max=limit)
97
+
98
+ out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu)
99
+ out = out_gelu * (a_linear + 1)
100
+ return out
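
For intuition, the reference path amounts to an interleaved-gate SwiGLU with optional clamping; a small worked call (shapes and values are illustrative):

```python
import torch

cfg = PrecisionConfig(limit=7.0)   # clamp the gate; default FlexCtx means no flexpoint scaling
a = torch.randn(4, 16)             # last dim interleaves (gate, linear) pairs
y = swiglu_torch(a, alpha=1.702, precision_config=cfg)
assert y.shape == (4, 8)           # output halves the last dimension
```
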
build/torch-universal/triton_kernels/swiglu_details/_swiglu.py ADDED
@@ -0,0 +1,102 @@
1
+ from triton_kernels.numerics_details.flexpoint import load_scale, float_to_flex, update_scale
2
+ import triton
3
+ import triton.language as tl
4
+
5
+
6
+ @triton.jit
7
+ def clip(x, limit, clip_lower: tl.constexpr):
8
+ res = tl.minimum(x, limit)
9
+ if clip_lower:
10
+ res = tl.maximum(-limit, res)
11
+ return res
12
+
13
+
14
+ @triton.jit
15
+ def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr):
16
+ return tl.max(tl.reshape(tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True), axis=1)
17
+
18
+
19
+ def swiglu_repr(specialization):
20
+ signature = specialization.signature
21
+ constants = specialization.constants
22
+ convert_dtype = lambda dtype: "mxfp4" if "u8" in dtype else dtype
23
+ dtypes = "x".join([convert_dtype(f"{signature[i][1:]}") for i in ["Out", "A"]])
24
+ blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N"]])
25
+ return f"_swiglu_{dtypes}_{blocks}"
26
+
27
+
28
+ def swiglu_launch_metadata(grid, kernel, args):
29
+ M, N = args["M"], args["N"]
30
+ ret = dict()
31
+ ret["name"] = f"{kernel.name} [M = {M}, N = {N}]"
32
+ A, Out = args["A"], args["Out"]
33
+ ret["bytes"] = Out.numel() * Out.element_size() + A.numel() * A.element_size()
34
+ return ret
35
+
36
+
37
+ @triton.jit
38
+ def compute_swiglu(gelu, linear, scale, alpha, limit):
39
+ gelu = gelu.to(tl.float32) * scale
40
+ if limit is not None:
41
+ gelu = clip(gelu, limit, clip_lower=False)
42
+ linear = linear.to(tl.float32) * scale
43
+ if limit is not None:
44
+ linear = clip(linear, limit, clip_lower=True)
45
+ s = gelu / (1 + tl.exp(-alpha * gelu))
46
+ return tl.fma(s, linear, s) # (s * (linear + 1))
47
+
48
+
49
+ @triton.jit(repr=lambda _: "_swiglu")
50
+ def _swiglu_fn(input, alpha, limit):
51
+ gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
52
+ return compute_swiglu(gelu, linear, 1.0, alpha, limit)
53
+
54
+
55
+ @triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
56
+ def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale, alpha, M, N, stride_am, stride_an,
57
+ stride_outm, stride_outn, limit: tl.constexpr, NTokens, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
58
+ EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr):
59
+ if NTokens is not None:
60
+ M = tl.load(NTokens)
61
+ M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M
62
+
63
+ local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32)
64
+
65
+ a_scale = load_scale(AScale)
66
+ out_expected_scale = load_scale(OutExpectedScale)
67
+
68
+ for pid in tl.range(tl.program_id(0), M_BLOCKS * N_BLOCKS, tl.num_programs(0), num_stages=2):
69
+ pid_m = (pid // N_BLOCKS)
70
+ pid_n = (pid % N_BLOCKS)
71
+ off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
72
+ off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
73
+ mask_m = off_m < M
74
+ mask_n = off_n < N
75
+ packed_off_n = pid_n * BLOCK_N + tl.arange(0, 2 * BLOCK_N) // 2
76
+ packed_mask_n = packed_off_n < N
77
+ packed_mask_n = tl.max_constancy(packed_mask_n, [16])
78
+ # load a
79
+ packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N)
80
+ packed_offs = off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an
81
+ if EVEN_N:
82
+ a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.)
83
+ else:
84
+ if pid_n * BLOCK_N + BLOCK_N <= N:
85
+ a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.)
86
+ else:
87
+ packed_mask = mask_m[:, None] & packed_mask_n[None, :]
88
+ a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.)
89
+ a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))
90
+ out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit)
91
+ # update flexpoint stats and divide by scale
92
+ # we don't need masking because of the `other` when loading `A`
93
+ if OutActualScale is not None:
94
+ absmax = thread_local_absmax(out, out.numel, tl.extra.cuda.num_threads())
95
+ local_max = tl.maximum(local_max, absmax)
96
+ out = float_to_flex(out, out_expected_scale,
97
+ None, # ActualScale: local absmax is tracked and updated after the loop
98
+ OutChecksumScale, None, Out, flexpoint_saturate_inf)
99
+ mask = mask_m[:, None] if EVEN_N else mask_m[:, None] & mask_n[None, :]
100
+ tl.store(Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask)
101
+
102
+ update_scale(local_max, OutActualScale, Out)
build/torch-universal/triton_kernels/target_info.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import triton
3
+
4
+ cached_capabilities = {}
5
+
6
+
7
+ def is_cuda():
8
+ if "is_cuda" not in cached_capabilities:
9
+ target = triton.runtime.driver.active.get_current_target()
10
+ cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
11
+ return cached_capabilities["is_cuda"]
12
+
13
+
14
+ def is_hip():
15
+ if "is_hip" not in cached_capabilities:
16
+ cached_capabilities["is_hip"] = torch.cuda.is_available() and bool(torch.version.hip)
17
+ return cached_capabilities["is_hip"]
18
+
19
+
20
+ def is_hip_cdna3():
21
+ if "is_hip_cdna3" not in cached_capabilities:
22
+ target = triton.runtime.driver.active.get_current_target()
23
+ cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
24
+ and target.arch == 'gfx942')
25
+ return cached_capabilities["is_hip_cdna3"]
26
+
27
+
28
+ def is_hip_cdna4():
29
+ if "is_hip_cdna4" not in cached_capabilities:
30
+ target = triton.runtime.driver.active.get_current_target()
31
+ cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
32
+ and target.arch == 'gfx950')
33
+ return cached_capabilities["is_hip_cdna4"]
34
+
35
+
36
+ def cuda_capability_geq(major, minor=0):
37
+ """
38
+ Determines whether we have compute capability >= (major, minor) and
39
+ returns this as a constexpr boolean. This can be used for guarding
40
+ inline asm implementations that require a certain compute capability.
41
+ """
42
+ if is_hip():
43
+ return False
44
+ if "cuda" not in cached_capabilities:
45
+ if torch.cuda.is_available():
46
+ cached_capabilities["cuda"] = torch.cuda.get_device_capability()
47
+ else:
48
+ cached_capabilities["cuda"] = (0, 0)
49
+ return cached_capabilities["cuda"] >= (major, minor)
50
+
51
+
52
+ def get_cdna_version():
53
+ """
54
+ Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
55
+ only supports 3 (gfx942) or 4 (gfx950). Returns -1 if it is not AMD
56
+ hardware or unsupported architecture
57
+ """
58
+ target = triton.runtime.driver.active.get_current_target()
59
+ if target.backend != 'hip':
60
+ return -1
61
+ if target.arch == 'gfx942':
62
+ return 3
63
+ if target.arch == 'gfx950':
64
+ return 4
65
+ return -1
66
+
67
+
68
+ def has_tma_gather():
69
+ return cuda_capability_geq(10, 0)
70
+
71
+
72
+ def has_native_mxfp():
73
+ return cuda_capability_geq(10, 0)
74
+
75
+
76
+ def num_sms():
77
+ return torch.cuda.get_device_properties(0).multi_processor_count
build/torch-universal/triton_kernels/tensor.py ADDED
@@ -0,0 +1,211 @@
1
+ from dataclasses import dataclass, fields
2
+ from typing import Type
3
+
4
+ import torch
5
+ from triton.tools.tensor_descriptor import TensorDescriptor
6
+
7
+ from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
8
+ from .target_info import cuda_capability_geq
9
+ from .tensor_details.layout import Layout, StridedLayout
10
+
11
+
12
+ @dataclass
13
+ class Storage:
14
+ data: torch.Tensor
15
+ layout: Layout = None
16
+
17
+ def __post_init__(self):
18
+ assert isinstance(self.data, torch.Tensor)
19
+ if self.layout is None:
20
+ self.layout = StridedLayout(self.data.shape)
21
+
22
+ @property
23
+ def device(self):
24
+ return self.data.device
25
+
26
+ def is_tma_compliant(self):
27
+ # TMAs didn't exist until Hopper
28
+ if not cuda_capability_geq(9, 0):
29
+ return False
30
+ # TMAs only exist for 2D, 3D, 5D inputs
31
+ if len(self.data.shape) not in [2, 3, 5]:
32
+ return False
33
+ # TMAs need at most one stride equal to 1
34
+ # and all other strides divisble by 16
35
+ strides = list(self.data.stride())
36
+ try:
37
+ major_dim = strides.index(1)
38
+ except ValueError:
39
+ major_dim = -1
40
+ ndim = self.data.ndim
41
+ bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
42
+ compliant = [strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim]
43
+ return all(compliant)
44
+
45
+ def make_tma(self, block_shape, transpose=False):
46
+ strides = list(self.data.stride())
47
+ shape = list(self.data.shape)
48
+ # TODO
49
+ # there is an issue w/ column-major TMA; we transpose instead
50
+ transpose = self.data.stride()[-1] != 1
51
+ if transpose:
52
+ block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
53
+ shape = shape[:-2] + [shape[-1], shape[-2]]
54
+ strides = strides[:-2] + [strides[-1], strides[-2]]
55
+ if self.data.dtype == torch.uint8 and self.layout.name is None:
56
+ # physical block size is half logical block size along packed dimension
57
+ indx = strides.index(1)
58
+ block_shape[indx] = block_shape[indx] // 2
59
+ # Pad the inner shape to 128 for mxfp4 weights; TMA requires this when the compiler uses
60
+ # CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B.
61
+ pad = 128
62
+ shape[-1] = (shape[-1] + pad - 1) // pad * pad
63
+ block_shape = self.layout.swizzle_block_shape(block_shape)
64
+ return TensorDescriptor(self.data, shape, strides, block_shape)
65
+
66
+
67
+ @dataclass
68
+ class IntegerType:
69
+ bitwidth: int
70
+
71
+
72
+ @dataclass
73
+ class FloatType:
74
+ bitwidth_exponent: int
75
+ bitwidth_mantissa: int
76
+ is_signed: bool
77
+
78
+ def __post_init__(self):
79
+ self.bitwidth = int(self.is_signed) + self.bitwidth_exponent + self.bitwidth_mantissa
80
+
81
+
82
+ BIT = IntegerType(1)
83
+ FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True)
84
+
85
+
86
+ def bitwidth(type: IntegerType | FloatType | torch.dtype):
87
+ if isinstance(type, torch.dtype):
88
+ return type.itemsize * 8
89
+ return type.bitwidth
90
+
91
+
92
+ @dataclass
93
+ class Tensor:
94
+ storage: Storage | torch.Tensor
95
+ dtype: IntegerType | FloatType | torch.dtype = None
96
+ shape: list[int] | None = None
97
+ shape_max: list[int] | None = None
98
+
99
+ def __post_init__(self):
100
+ # set storage
101
+ if isinstance(self.storage, torch.Tensor):
102
+ self.storage = Storage(self.storage)
103
+ # initialize dtype
104
+ if self.dtype is None:
105
+ self.dtype = self.storage.data.dtype
106
+ if bitwidth(self.dtype) < 8 and self.shape is None:
107
+ raise ValueError("shape must be provided for sub-byte types")
108
+ # initialize shape
109
+ if self.shape is None:
110
+ self.shape = list(self.storage.data.shape)
111
+ # validate shape: all elements must be `int` or numel-1 `torch.Tensor`
112
+ is_int = lambda s: isinstance(s, int)
113
+ is_item = lambda s: hasattr(s, "numel") and s.numel() == 1
114
+ assert all(map(lambda s: is_int(s) or is_item(s), self.shape))
115
+ # initialize shape_max
116
+ if self.shape_max is None:
117
+ self.shape_max = [None] * len(self.shape)
118
+ for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
119
+ if smax is not None and not is_int(smax):
120
+ raise ValueError(f"shape_max[{i}] must be `int` or `None`; got {type(smax)}")
121
+ if smax is None:
122
+ self.shape_max[i] = s
123
+ # validate shape_max: all elements must be `int`
124
+ assert all(map(is_int, self.shape_max))
125
+
126
+ # torch compatibility layer
127
+ @property
128
+ def ndim(self):
129
+ return len(self.shape)
130
+
131
+ @property
132
+ def device(self):
133
+ return self.storage.device
134
+
135
+ def stride(self, i=None):
136
+ return self.storage.data.stride() if i is None else self.storage.data.stride(i)
137
+
138
+ def data_ptr(self):
139
+ return self.storage.data.data_ptr()
140
+
141
+ def numel(self):
142
+ return self.storage.data.numel()
143
+
144
+ def element_size(self):
145
+ return bitwidth(self.dtype) // 8
146
+
147
+ @property
148
+ def data(self):
149
+ t = self.storage
150
+ return t.data if isinstance(t, Storage) else t
151
+
152
+ def dim(self):
153
+ return self.ndim
154
+
155
+ def size(self, i=None):
156
+ if i is None:
157
+ return self.shape
158
+ return self.shape[i]
159
+
160
+
161
+ @dataclass
162
+ class Bitmatrix(Tensor):
163
+ """
164
+ Represents a boolean matrix in a packed format where each element occupies
165
+ a single bit of memory.
166
+
167
+ _scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along
168
+ with the actual bitmatrix to avoid having to launch a separate memset
169
+ kernel when we call Bitmatrix::sum().
170
+ """
171
+
172
+ scratchpad: torch.Tensor = None
173
+
174
+ def __init__(self, storage, shape, shape_max=None, scratchpad=None):
175
+ super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
176
+ self.scratchpad = scratchpad
177
+
178
+ def sum(self, partials_block_size):
179
+ _, n_cols = self.shape
180
+ dev = self.device
181
+ if self.scratchpad is None:
182
+ self.scratchpad = clear_sums(n_cols, dev)
183
+ out_ret = self.scratchpad[:n_cols]
184
+ self.scratchpad = None # throw error if we try to sum again
185
+ return sum_bitmatrix_rows(self, out_ret, partials_block_size)
186
+
187
+
188
+ def get_layout(tensor: torch.Tensor | Tensor | None):
189
+ if tensor is None:
190
+ return None
191
+ if isinstance(tensor, Tensor):
192
+ return tensor.storage.layout
193
+ return StridedLayout
194
+
195
+
196
+ def wrap_torch_tensor(torch_tensor, dtype=None):
197
+ if dtype is None:
198
+ dtype = torch_tensor.dtype
199
+ shape = list(torch_tensor.shape)
200
+ shape[torch_tensor.stride().index(1)] *= bitwidth(torch_tensor.dtype) // bitwidth(dtype)
201
+ return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)
202
+
203
+
204
+ def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
205
+ assert isinstance(tensor, Tensor)
206
+ old_storage = tensor.storage
207
+ old_data = old_storage.layout.unswizzle_data(old_storage.data)
208
+ new_layout = layout_cls(old_data.shape, **layout_kwargs)
209
+ new_data = new_layout.swizzle_data(old_data)
210
+ attrs = {k.name: getattr(tensor, k.name) for k in fields(tensor) if k.name != "storage"}
211
+ return Tensor(Storage(new_data, new_layout), **attrs)
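
`wrap_torch_tensor` re-derives the logical shape when the logical element type is narrower than the storage dtype; for example, wrapping packed `uint8` storage as `FP4` doubles the contiguous dimension. A small sketch (illustrative):

```python
import torch

packed = torch.randint(0, 256, (16, 32), dtype=torch.uint8)  # two fp4 values per byte
t = wrap_torch_tensor(packed, dtype=FP4)
assert t.shape == [16, 64]                # logical shape doubles along the contiguous axis
assert t.storage.data.shape == (16, 32)   # storage is untouched
```
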
build/torch-universal/triton_kernels/tensor_details/layout.py ADDED
@@ -0,0 +1,32 @@
1
+ from .layout_details.base import Layout
2
+ from .layout_details.blackwell_scale import BlackwellMXScaleLayout
3
+ from .layout_details.hopper_scale import HopperMXScaleLayout
4
+ from .layout_details.hopper_value import HopperMXValueLayout
5
+ from .layout_details.strided import StridedLayout
6
+ from ..target_info import cuda_capability_geq
7
+
8
+ __all__ = [
9
+ "Layout",
10
+ "BlackwellMXScaleLayout",
11
+ "HopperMXScaleLayout",
12
+ "HopperMXValueLayout",
13
+ "StridedLayout",
14
+ ]
15
+
16
+
17
+ def make_default_matmul_mxfp4_w_layout(mx_axis: int):
18
+ if cuda_capability_geq(10):
19
+ return StridedLayout, dict()
20
+ elif cuda_capability_geq(9):
21
+ return HopperMXValueLayout, {"mx_axis": mx_axis}
22
+ else:
23
+ return StridedLayout, dict()
24
+
25
+
26
+ def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
27
+ if cuda_capability_geq(10):
28
+ return BlackwellMXScaleLayout, dict()
29
+ elif cuda_capability_geq(9):
30
+ return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
31
+ else:
32
+ return StridedLayout, dict()
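
These selectors are meant to be paired with `convert_layout` from `tensor.py`; a hedged sketch of applying the default mxfp4 weight layout (sizes are illustrative and must satisfy the alignment requirements of whichever layout is selected for the current GPU):

```python
import torch
from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4

w_packed = torch.randint(0, 256, (256, 128), dtype=torch.uint8, device="cuda")
w = wrap_torch_tensor(w_packed, dtype=FP4)
layout_cls, layout_kwargs = make_default_matmul_mxfp4_w_layout(mx_axis=0)
w = convert_layout(w, layout_cls, **layout_kwargs)
```
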
build/torch-universal/triton_kernels/tensor_details/layout_details/base.py ADDED
@@ -0,0 +1,19 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class Layout(ABC):
5
+
6
+ def __init__(self, shape) -> None:
7
+ self.initial_shape = shape
8
+
9
+ @abstractmethod
10
+ def swizzle_data(self, data):
11
+ pass
12
+
13
+ @abstractmethod
14
+ def unswizzle_data(self, data):
15
+ pass
16
+
17
+ @abstractmethod
18
+ def swizzle_block_shape(self, block_shape):
19
+ pass
build/torch-universal/triton_kernels/tensor_details/layout_details/blackwell_scale.py ADDED
@@ -0,0 +1,58 @@
1
+ import math
2
+ import triton
3
+ import triton.language as tl
4
+ import torch
5
+ from .base import Layout
6
+
7
+ SWIZZLE_ALIGN_INNER = 8
8
+ SWIZZLE_SIZE_INNER = 4
9
+ SWIZZLE_SIZE_OUTER = 128
10
+
11
+
12
+ class BlackwellMXScaleLayout(Layout):
13
+ name: str = "BLACKWELL_SCALE"
14
+
15
+ def __init__(self, shape) -> None:
16
+ super().__init__(shape)
17
+ *self.leading_shape, self.K, self.N, = shape
18
+ self.B = math.prod(self.leading_shape)
19
+ self.ALIGN_K = 8
20
+ self.ALIGN_N = 128
21
+ self.SWIZZLE_K = 4
22
+ self.K_pad = (self.K + self.ALIGN_K - 1) // self.ALIGN_K * self.ALIGN_K
23
+ self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N
24
+
25
+ def swizzle_data(self, data):
26
+ data = torch.nn.functional.pad(data, (0, self.N_pad - self.N, 0, self.K_pad - self.K))
27
+ data = data.transpose(-1, -2).contiguous()
28
+ data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.ALIGN_N // 32, 32, self.K_pad // self.SWIZZLE_K,
29
+ self.SWIZZLE_K)
30
+ data = data.transpose(2, 4).contiguous()
31
+ data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256)
32
+ return data
33
+
34
+ def unswizzle_data(self, data):
35
+ data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.K_pad // self.SWIZZLE_K, 32, self.ALIGN_N // 32,
36
+ self.SWIZZLE_K)
37
+ data = data.transpose(2, 4)
38
+ data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad)
39
+ data = data.transpose(-1, -2)
40
+ return data[..., :self.K, :self.N]
41
+
42
+ def swizzle_block_shape(self, block_shape):
43
+ MX_PACK_DIVISOR = 32
44
+ MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR
45
+ return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256]
46
+
47
+
48
+ @triton.jit
49
+ def unswizzle_mx_scale_bw(x, SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER,
50
+ SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER,
51
+ ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER):
52
+ shape_0: tl.constexpr = x.shape[0]
53
+ shape_1: tl.constexpr = x.shape[1]
54
+ tl.static_assert(shape_1 % SIZE_OUTER == 0)
55
+ tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER)
56
+ x = x.reshape(shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER)
57
+ x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER)
58
+ return x
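
Because the scale swizzle is a pure permutation plus zero padding, a `swizzle_data`/`unswizzle_data` round trip recovers the original tensor; a small check (shapes are illustrative, chosen to satisfy the K/N alignment above):

```python
import torch

K, N = 256, 64
layout = BlackwellMXScaleLayout((K, N))
scales = torch.randint(0, 256, (K, N), dtype=torch.uint8)
swizzled = layout.swizzle_data(scales)
assert swizzled.shape == (1, 1, K // 4, 2, 256)
assert torch.equal(layout.unswizzle_data(swizzled), scales)
```
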
build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_scale.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+ from .base import Layout
5
+
6
+
7
+ class HopperMXScaleLayout(Layout):
8
+ name: str = "HOPPER_SCALE"
9
+
10
+ def __init__(self, shape, mx_axis, num_warps=8) -> None:
11
+ assert num_warps & (num_warps - 1) == 0, "num_warps must be a power of 2"
12
+ super().__init__(shape)
13
+ self.mx_axis = mx_axis
14
+ self.num_warps = num_warps
15
+ *self.leading_shape, _, _ = shape
16
+
17
+ def _maybe_mT(self, data):
18
+ if self.mx_axis == len(self.leading_shape):
19
+ return data.contiguous().mT
20
+ return data
21
+
22
+ def swizzle_data(self, data):
23
+ data = self._maybe_mT(data).contiguous()
24
+ *batch, M, K = data.shape
25
+ SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 8
26
+ SWIZZLE_ALIGN_K = 2
27
+ pad_m = (SWIZZLE_ALIGN_M - (M % SWIZZLE_ALIGN_M)) % SWIZZLE_ALIGN_M
28
+ pad_k = (SWIZZLE_ALIGN_K - (K % SWIZZLE_ALIGN_K)) % SWIZZLE_ALIGN_K
29
+ data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m))
30
+ *batch, M, K = data.shape
31
+ assert data.is_contiguous()
32
+ assert M % (
33
+ 2 * self.num_warps * 2 *
34
+ 8) == 0 and K % 2 == 0, f"Input tensor must have a subtile of shape (..., {2 * self.num_warps * 2 * 8}, 2)"
35
+ b = len(batch)
36
+ data = data.reshape(*batch, M // (2 * self.num_warps * 2 * 8), 2, self.num_warps, 2, 8, K // 2, 2)
37
+ perm = [0, 2, 5, 1, 4, 6, 3]
38
+ perm = list(range(b)) + [b + p for p in perm]
39
+ data = data.permute(*perm)
40
+ data = data.flatten(-5, -1)
41
+ data = data.flatten(-3, -2)
42
+ assert data.shape[-2] == M // 32
43
+ assert data.shape[-1] == K * 32
44
+ data = self._maybe_mT(data)
45
+ return data
46
+
47
+ def unswizzle_data(self, data):
48
+ data = self._maybe_mT(data)
49
+ *batch, M, K = data.shape
50
+ b = len(batch)
51
+ data = data.reshape(*batch, M // self.num_warps, self.num_warps, K // 64, 2, 8, 2, 2)
52
+ perm = [0, 3, 1, 6, 4, 2, 5]
53
+ perm = list(range(b)) + [b + p for p in perm]
54
+ data = data.permute(*perm)
55
+ data = data.reshape(*batch, M * 32, K // 32)
56
+ data = self._maybe_mT(data)
57
+ return data
58
+
59
+ def swizzle_block_shape(self, block_shape):
60
+ return block_shape
61
+
62
+
63
+ @triton.jit
64
+ def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constexpr):
65
+ """
66
+ Triton inverse of swizzle_mxfp4_scale_hopper
67
+ """
68
+ tl.static_assert(len(x.shape) == 2, "NYI")
69
+ # implementation assumes mxfp data is packed along the last dimension
70
+ x = x.trans() if mx_axis == 0 else x
71
+ M: tl.constexpr = x.shape[0]
72
+ K: tl.constexpr = x.shape[1]
73
+ tl.static_assert(M % num_warps == 0, f"M must be divisible by {num_warps}. Got {M}")
74
+ tl.static_assert(K % 64 == 0, f"K must be divisible by 64. Got {K}")
75
+ x = x.reshape(M // num_warps, num_warps, K // 64, 2, 8, 2, 2)
76
+ x = x.trans(0, 3, 1, 6, 4, 2, 5)
77
+ x = x.reshape(M * 32, K // 32)
78
+ # implementation assumes mxfp data is packed along the last dimension
79
+ x = x.trans() if mx_axis == 0 else x
80
+ return x
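
The same round-trip property holds for the Hopper scale layout when the input already satisfies the alignment constraints (so no padding is introduced); a small check with illustrative shapes and `mx_axis=1`:

```python
import torch

M, K = 256, 64   # M a multiple of 2 * num_warps * 2 * 8 = 256, K even -> no padding
layout = HopperMXScaleLayout((M, K), mx_axis=1, num_warps=8)
scales = torch.randint(0, 256, (M, K), dtype=torch.uint8)
swizzled = layout.swizzle_data(scales)
assert swizzled.shape == (M // 32, K * 32)
assert torch.equal(layout.unswizzle_data(swizzled), scales)
```
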
build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_value.py ADDED
@@ -0,0 +1,323 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+ from .base import Layout
5
+
6
+
7
+ def right_shift_unsigned(x, shift):
8
+ return (x >> shift) & ((1 << (32 - shift)) - 1)
9
+
10
+
11
+ # -----------------------------------------------------------------------
12
+ # Interleave the bits of four consecutive fp4 values (i.e. 16-bits) as:
13
+ # 1000000111000000 (first fp4)
14
+ # 1000000111000000 (second fp4)
15
+ # 1000000111000000 (third fp4)
16
+ # 0110110000000000 (fourth fp4)
17
+ # This is done so that dequantization can be done in 14 SASS instructions
18
+ # -----------------------------------------------------------------------
19
+
20
+
21
+ def _compress_fp4(x):
22
+ x = x.to(torch.int32)
23
+ return ((x & 0x8) << 12) | ((x & 0x7) << 6)
24
+
25
+
26
+ def _compress_fourth(x):
27
+ x = x.to(torch.int32)
28
+ return ((x & 0x8) << 11) | ((x & 0x6) << 9) | ((x & 0x1) << 13)
29
+
30
+
31
+ def _pack_bits(x: torch.Tensor, mx_axis: int):
32
+ x = x.contiguous()
33
+ assert x.shape[-1] % 4 == 0, "Input tensor must have a last dimension divisible by 4"
34
+ x = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4))
35
+ first = _compress_fp4(x[..., 0]) | (_compress_fp4(x[..., 0] >> 4) << 16)
36
+ second = _compress_fp4(x[..., 1]) | (_compress_fp4(x[..., 1] >> 4) << 16)
37
+ third = _compress_fp4(x[..., 2]) | (_compress_fp4(x[..., 2] >> 4) << 16)
38
+ fourth = _compress_fourth(x[..., 3]) | (_compress_fourth(x[..., 3] >> 4) << 16)
39
+ x = first | right_shift_unsigned(second, 3) | right_shift_unsigned(third, 6) | fourth
40
+ assert x.is_contiguous()
41
+ x = x.view(torch.uint8)
42
+ return x
43
+
44
+
45
+ # -----------------------------------------------------------------------
46
+ # inverse operation of _pack_bits
47
+ # -----------------------------------------------------------------------
48
+
49
+
50
+ def _bf16_to_fp4e2m1(x):
51
+ # 0bAxxxxxxBCDxxxxxx (int16) -> 0b0000ABCD (uint8)
52
+ assert x.dtype == torch.int16
53
+ s = (right_shift_unsigned(x, 15) & 0x1) << 3
54
+ em = right_shift_unsigned(x, 6) & 0x7
55
+ return (s | em).to(torch.uint8)
56
+
57
+
58
+ def _bf16x2_to_fp4e2m1x2(x):
59
+ # 0bAxxxxxxBCDxxxxxx_0bExxxxxxFGHxxxxxx (int32) -> 0bABCD_EFGH (uint8)
60
+ assert x.dtype == torch.int32
61
+ lo = (x & 0xFFFF).to(torch.int16)
62
+ hi = (right_shift_unsigned(x, 16) & 0xFFFF).to(torch.int16)
63
+ ret_lo = _bf16_to_fp4e2m1(lo)
64
+ ret_hi = _bf16_to_fp4e2m1(hi)
65
+ return ret_lo | (ret_hi << 4)
66
+
67
+
68
+ def _unpack_bits(x, mx_axis: int):
69
+ x = x.view(torch.int32)
70
+ m = 0b10000001110000001000000111000000
71
+ a = (x << 1) & 0b10000000000000001000000000000000
72
+ b = right_shift_unsigned(x, 3) & 0b00000001100000000000000110000000
73
+ c = right_shift_unsigned(x, 7) & 0b00000000010000000000000001000000
74
+ unpacked = [x & m, (x << 3) & m, (x << 6) & m, (a | b) | c]
75
+ x = torch.stack(unpacked, dim=-1)
76
+ x = x.flatten(-2, -1)
77
+ x = _bf16x2_to_fp4e2m1x2(x)
78
+ return x
79
+
80
+
81
+ # -----------------------------------------------------------------------
82
+
83
+
84
+ class HopperMXValueLayout(Layout):
85
+ name: str = "HOPPER_VALUE"
86
+
87
+ def __init__(self, shape, mx_axis, mma_version=3):
88
+ super().__init__(shape)
89
+ assert mx_axis in range(len(shape))
90
+ self.mx_axis = mx_axis
91
+ self.mma_version = mma_version
92
+ *self.leading_shape, self.K, self.N, = shape
93
+
94
+ def _maybe_mT(self, data):
95
+ if self.mx_axis == len(self.leading_shape):
96
+ return data.mT
97
+ return data
98
+
99
+ def swizzle_data(self, data):
100
+ """
101
+ Given a uint8 tensor of shape (*, M, K), returns a tensor of shape
102
+ (*, M // 4, K * 4) such that:
103
+
104
+ 1) Groups contiguously all the elements owned by the same thread of 4
105
+ mma tiles along the K axis. The following animation shows a similar
106
+ grouping for 2 tiles along M and 2 tiles along K rather than 4 along K
107
+ as done here:
108
+ https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif
109
+
110
+ 2) Moves the elements belonging to thread 4-7 to be contiguous with those
111
+ from thread 0-3. This is done to get a full cache line when loading them
112
+ from HBM.
113
+
114
+ mx_axis selects the lhs or rhs of the matmul.
115
+
116
+ WARNING: Assumes that the matmul will be done in bf16 or fp16!
117
+ Implementing it for fp8 is as easy as making the tile size (8, 8)
118
+ """
119
+ batch = data.ndim - 2
120
+ assert batch >= 0
121
+ assert self.mma_version in (2, 3)
122
+ data = self._maybe_mT(data)
123
+ init_shape = data.shape
124
+
125
+ # We are loading 8 bf16 elements per thread to use ld.global.v4
126
+ # Every u8 represents 2 mxfp4 elements
127
+ u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
128
+
129
+ # Pack the 4 // u8_kwidth subtiles of an mma into a u4x8
130
+ contig = (1, u8_kwidth)
131
+ scott_trick = (2, 1)
132
+ threads = (4, 4)
133
+ warp_tile = (2, 2)
134
+ k_tile = (1, 4 // u8_kwidth)
135
+
136
+ sizes = list(data.shape[:-2])
137
+ pads = []
138
+ # [rest, K, tile, threads] per dimension
139
+ for i, (a, b, c, s, d) in enumerate(zip(k_tile, warp_tile, threads, scott_trick, contig)):
140
+ pack = a * b * c * s * d
141
+ size = data.shape[batch + i]
142
+ pad = (pack - size % pack) % pack
143
+ pads += [(0, pad)]
144
+ sizes.append((size + pad) // pack)
145
+ sizes += [a, b, c, s, d]
146
+
147
+ pads = tuple(x for t in pads[::-1] for x in t)
148
+ data = torch.nn.functional.pad(data, pads)
149
+ init_shape = data.shape
150
+ # 0: rest[0]
151
+ # 1: k_tile[0]
152
+ # 2: warp_tile[0]
153
+ # 3: threads[0]
154
+ # 4: scott_trick[0]
155
+ # 5: contig[0]
156
+ # 6: rest[1]
157
+ # 7: k_tile[1]
158
+ # 8: warp_tile[1]
159
+ # 9: threads[1]
160
+ # 10: scott_trick[1]
161
+ # 11: contig[1]
162
+ data = data.view(*sizes)
163
+        # Want [rest[0], threads[0], rest[1], scott_trick[1], scott_trick[0], threads[1], k_tile[1], k_tile[0], warp_tile[1], warp_tile[0], contig[0], contig[1]]
164
+ perm = [0, 3, 6, 10, 4, 9, 7, 1, 8, 2, 5, 11]
165
+ perm = list(range(batch)) + [batch + p for p in perm]
166
+ data = data.permute(*perm).contiguous()
167
+ # These are views
168
+ data = data.flatten(-10, -1)
169
+ data = data.flatten(-3, -2)
170
+ assert data.is_contiguous()
171
+ assert data.shape[-2] == init_shape[-2] // 4
172
+ assert data.shape[-1] == init_shape[-1] * 4
173
+ # twiddle the bits
174
+ data = _pack_bits(data, self.mx_axis)
175
+ data = self._maybe_mT(data)
176
+ return data
177
+
178
+ def unswizzle_data(self, data):
179
+ data = self._maybe_mT(data)
180
+ data = _unpack_bits(data, self.mx_axis)
181
+ *batch, M, K = data.shape
182
+ # We have two times the elements if we already upcasted to bfloat16
183
+ mult = 2 if data.dtype == torch.bfloat16 else 1
184
+ assert M % 4 == 0, "M must be divisible by 4"
185
+ assert K % (4 * 8 * 2 * 2 * mult) == 0, f"K must be divisible by {4 * 8 * 2 * 2 * mult}"
186
+ # We are loading 8 bf16 elements per thread to use ld.global.v4
187
+ # Every u8 represents 2 mxfp4 elements
188
+ u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
189
+ data = data.reshape(*batch, M // 4, 4, K // (4 * 8 * 2 * 2 * mult), 2, 4, 8 // u8_kwidth, 2, u8_kwidth * mult)
190
+ b = len(batch)
191
+ perm = [0, 6, 1, 3, 2, 5, 4, 7]
192
+ perm = list(range(b)) + [b + p for p in perm]
193
+ data = data.permute(*perm)
194
+ data = data.reshape(*batch, M * 4, K // 4)
195
+ data = self._maybe_mT(data)
196
+ return data[..., :self.K, :self.N]
197
+
198
+ def swizzle_block_shape(self, block_shape):
199
+ return block_shape
200
+
201
+
202
+ @triton.jit
203
+ def _unshuffle_triton(x, mma_version: tl.constexpr):
204
+ """
205
+ Triton inverse of swizzle_mxfp4_value_hopper
206
+ """
207
+ tl.static_assert(mma_version == 2 or mma_version == 3, "mma_version must be 2 or 3")
208
+ # if mx_axis == 0:
209
+ # x = x.trans()
210
+
211
+ # We have two times the elements if we already upcasted to bfloat16
212
+ mult: tl.constexpr = 2 if x.dtype == tl.bfloat16 else 1
213
+ M: tl.constexpr = x.shape[0]
214
+ K: tl.constexpr = x.shape[1]
215
+ tl.static_assert(M % 4 == 0, "M must be divisible by 4")
216
+ tl.static_assert(K % (4 * 8 * 2 * 2 * mult) == 0, f"K must be divisible by {4 * 8 * 2 * 2 * mult}")
217
+
218
+ # We are loading 8 bf16 elements per thread to use ld.global.v4
219
+ # Every u8 represents 2 mxfp4 elements
220
+ u8_kwidth: tl.constexpr = 8 // 2 if mma_version == 2 else 1
221
+ x = x.reshape(M // 4, 4, K // (4 * 8 * 2 * 2 * mult), 2, 4, 8 // u8_kwidth, 2, u8_kwidth * mult)
222
+ x = x.trans(0, 6, 1, 3, 2, 5, 4, 7)
223
+ x = x.reshape(M * 4, K // 4)
224
+ # if mx_axis == 0:
225
+ # x = x.trans()
226
+ return x
227
+
228
+
229
+ @triton.jit
230
+ def _unpack_fp4_to_bf16_triton(x):
231
+ # For now we implement just H100 support (mul.bf16x2)
232
+ # A100 support is possible via fma
233
+ r0, r1 = tl.inline_asm_elementwise(
234
+ r"""
235
+ {
236
+ .reg .b32 b, c, d<7>, scale;
237
+ .reg .b32 bias;
238
+            mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp4)
239
+ // We add the missing bias to the scale directly
240
+ and.b32 $0, $4, 0b10000001110000001000000111000000;
241
+ mul.bf16x2 $0, $0, bias;
242
+ shl.b32 b, $4, 3;
243
+ and.b32 $1, b, 0b10000001110000001000000111000000;
244
+ mul.bf16x2 $1, $1, bias;
245
+ shl.b32 c, $4, 6;
246
+ and.b32 $2, c, 0b10000001110000001000000111000000;
247
+ mul.bf16x2 $2, $2, bias;
248
+ // Unpack last two elements
249
+ shl.b32 d0, $4, 1;
250
+ and.b32 d1, d0, 0b10000000000000001000000000000000;
251
+ shr.b32 d2, $4, 3;
252
+ and.b32 d3, d2, 0b00000001100000000000000110000000;
253
+ or.b32 d4, d1, d3;
254
+ shr.b32 d5, $4, 7;
255
+ and.b32 d6, d5, 0b00000000010000000000000001000000;
256
+ or.b32 $3, d4, d6;
257
+ mul.bf16x2 $3, $3, bias;
258
+ }
259
+ """,
260
+ constraints="=r,=r,=r,=r,r",
261
+ args=[x],
262
+ dtype=(tl.bfloat16, tl.bfloat16),
263
+ is_pure=True,
264
+ pack=4,
265
+ )
266
+ # Concat each pack of 4
267
+ x = tl.join(r0, r1)
268
+ x = x.reshape(x.shape[0], x.shape[1] // 4, 4, x.shape[2])
269
+ x = x.trans(0, 1, 3, 2)
270
+ x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
271
+ return x
272
+
273
+
274
+ @triton.jit
275
+ def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
276
+ """
277
+ Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements):
278
+ (x << 0) & 0b1000000111000000
279
+ (x << 3) & 0b1000000111000000
280
+ (x << 6) & 0b1000000111000000
281
+ ((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000)
282
+ """
283
+ # upcast values to bfloat16
284
+ tl.static_assert(len(x.shape) == 2)
285
+ tl.static_assert(mx_axis == 0 or mx_axis == 1, "mx_axis must be 0 or 1")
286
+ tl.static_assert(x.shape[1] % 4 == 0)
287
+ tl.static_assert(x.dtype == tl.uint8)
288
+ if mx_axis == 0:
289
+ x = x.trans()
290
+ x = _unpack_fp4_to_bf16_triton(x)
291
+ x = _unshuffle_triton(x, mma_version=3)
292
+ if mx_axis == 0:
293
+ x = x.trans()
294
+
295
+ # upcast scale to bfloat16
296
+ # Add bias missing from the bf16 upcasting sequence
297
+ # triton / LLVM generates terrible code for this sequence
298
+ # scale = scale.to(tl.uint16)
299
+ # scale = scale << 7
300
+ # scale = scale.to(tl.bfloat16, bitcast=True)
301
+ scale = tl.inline_asm_elementwise(
302
+ r"""
303
+ {
304
+ prmt.b32 $0, $2, 0, 0x5140;
305
+ shl.b32 $0, $0, 7;
306
+ prmt.b32 $1, $2, 0, 0x7362;
307
+ shl.b32 $1, $1, 7;
308
+ }
309
+ """,
310
+ constraints="=r,=r,r",
311
+ args=[scale],
312
+ dtype=tl.bfloat16,
313
+ is_pure=True,
314
+ pack=4,
315
+ )
316
+ # Broadcast scale
317
+ scale = scale.expand_dims(mx_axis + 1)
318
+ scale = scale.broadcast_to(scale.shape[:mx_axis + 1] + [32] + scale.shape[mx_axis + 2:])
319
+ scale = scale.reshape(x.shape)
320
+
321
+ # Combine scale and x
322
+ x = x * scale
323
+ return x
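
As a worked illustration of the bit interleave documented in the comment near the top of this file, the plain-Python sketch below re-implements the two compress helpers on ints (for illustration only) and prints where each of the four fp4 nibbles lands inside the low 16 bits of the packed word:

# mirrors _compress_fp4 / _compress_fourth above on plain Python ints
def compress_fp4(v):
    return ((v & 0x8) << 12) | ((v & 0x7) << 6)

def compress_fourth(v):
    return ((v & 0x8) << 11) | ((v & 0x6) << 9) | ((v & 0x1) << 13)

def pack_low_nibbles(n0, n1, n2, n3):
    # low 16 bits of the packed word; the high nibbles use the same layout shifted left by 16
    return compress_fp4(n0) | (compress_fp4(n1) >> 3) | (compress_fp4(n2) >> 6) | compress_fourth(n3)

for which in range(4):
    nibbles = [0, 0, 0, 0]
    nibbles[which] = 0xF
    print(which, f"{pack_low_nibbles(*nibbles):016b}")
# 0 1000000111000000
# 1 0001000000111000
# 2 0000001000000111
# 3 0110110000000000
# the four 4-bit slots tile the 16 bits without overlapping
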
build/torch-universal/triton_kernels/tensor_details/layout_details/strided.py ADDED
@@ -0,0 +1,17 @@
1
+ from .base import Layout
2
+
3
+
4
+ class StridedLayout(Layout):
5
+ name: str = None
6
+
7
+ def __init__(self, shape) -> None:
8
+ super().__init__(shape)
9
+
10
+ def swizzle_data(self, data):
11
+ return data
12
+
13
+ def unswizzle_data(self, data):
14
+ return data
15
+
16
+ def swizzle_block_shape(self, block_shape):
17
+ return block_shape
build/torch-universal/triton_kernels/testing.py ADDED
@@ -0,0 +1,192 @@
1
+ import enum
2
+ import functools
3
+ import os
4
+ import subprocess
5
+ import sys
6
+ import torch
7
+ from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
8
+
9
+
10
+ def assert_equal(ref, tri):
11
+ if isinstance(ref, torch.Tensor):
12
+ assert torch.all(ref == tri)
13
+ else:
14
+ assert ref == tri
15
+
16
+
17
+ def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
18
+ if tri.dtype.itemsize == 1:
19
+ ref_as_type = ref.to(tri.dtype)
20
+ if ref.dtype == tri.dtype:
21
+ assert torch.all(ref_as_type == tri)
22
+ return
23
+ ref = ref_as_type
24
+
25
+ if maxtol is None:
26
+ maxtol = 2e-2
27
+ if rmstol is None:
28
+ rmstol = 4e-3
29
+ """
30
+ Compare reference values against obtained values.
31
+ """
32
+
33
+ # cast to float32:
34
+ ref = ref.to(torch.float32).detach()
35
+ tri = tri.to(torch.float32).detach()
36
+ assert ref.shape == tri.shape, f"Tensors must have same size {ref.shape=} {tri.shape=}"
37
+
38
+ # deal with infinite elements:
39
+ inf_mask_ref = torch.isinf(ref)
40
+ inf_mask_tri = torch.isinf(tri)
41
+ assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensor must have same infinite elements"
42
+ refn = torch.where(inf_mask_ref, 0, ref)
43
+ trin = torch.where(inf_mask_tri, 0, tri)
44
+
45
+ # normalise so that RMS calculation doesn't overflow:
46
+ eps = 1.0e-30
47
+ multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
48
+ refn *= multiplier
49
+ trin *= multiplier
50
+
51
+ ref_rms = torch.sqrt(torch.square(refn).mean()) + eps
52
+
53
+ rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
54
+ max_err = torch.max(rel_err).item()
55
+ rms_err = torch.sqrt(torch.square(rel_err).mean()).item()
56
+
57
+ if verbose:
58
+ print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol))
59
+ print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol))
60
+
61
+ if max_err > maxtol:
62
+ bad_idxs = torch.nonzero(rel_err > maxtol)
63
+ num_nonzero = bad_idxs.size(0)
64
+ bad_idxs = bad_idxs[:1000]
65
+ print("%d / %d mismatched elements (shape = %s) at coords %s" %
66
+ (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist()))
67
+
68
+ bad_idxs = bad_idxs.unbind(-1)
69
+ print("ref values: ", ref[tuple(bad_idxs)].cpu())
70
+ print("tri values: ", tri[tuple(bad_idxs)].cpu())
71
+
72
+ assert max_err <= maxtol
73
+ assert rms_err <= rmstol
74
+
75
+
76
+ class ComputeSanitizerTool(enum.Enum):
77
+ MEMCHECK = "memcheck"
78
+ RACECHECK = "racecheck"
79
+ SYNCCHECK = "synccheck"
80
+ INITCHECK = "initcheck"
81
+
82
+
83
+ def compute_sanitizer(**target_kwargs):
84
+ """
85
+ Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled,
86
+ to expose potential memory access errors.
87
+ This decorator requires the `request` fixture to be present.
88
+ If `run_sanitizer` argument is present and set to False, the sanitizer is not run.
89
+ Running tests under compute sanitizer requires launching subprocess and is slow,
90
+ so use sparingly
91
+ """
92
+
93
+ def decorator(test_fn):
94
+
95
+ @functools.wraps(test_fn)
96
+ def wrapper(*args, **kwargs):
97
+ if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
98
+ test_fn(*args, **kwargs)
99
+ return
100
+
101
+ import psutil
102
+
103
+ if target_kwargs.pop("clear_torch_cache", False):
104
+ # If we don't pop clear_torch_cache, it won't pass
105
+ # target_kwargs.items() <= kwargs.items() condition below.
106
+ torch.cuda.empty_cache()
107
+ tools_to_check = target_kwargs.pop("tools_to_check", [ComputeSanitizerTool.MEMCHECK])
108
+ assert isinstance(tools_to_check, list), f"{tools_to_check=}"
109
+ assert all(tool in ComputeSanitizerTool for tool in tools_to_check), (
110
+ f"{(tool for tool in tools_to_check if tool not in ComputeSanitizerTool)=}")
111
+
112
+ ppid_name = psutil.Process(os.getppid()).exe()
113
+ run_compute_sanitizer = target_kwargs.items() <= kwargs.items()
114
+ if "run_sanitizer" in kwargs:
115
+ run_compute_sanitizer &= kwargs["run_sanitizer"]
116
+ if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
117
+ for tool in tools_to_check:
118
+ path = os.path.realpath(test_fn.__globals__["__file__"])
119
+ # get path of current file
120
+ env = {
121
+ "PATH": os.environ["PATH"],
122
+ "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
123
+ "TORCH_SHOW_CPP_STACKTRACES": "1",
124
+ "CUDA_LAUNCH_BLOCKING": "1",
125
+ }
126
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
127
+ env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
128
+ assert "request_fixture" in kwargs, (
129
+ "memcheck'ed test must have a (possibly unused) `request` fixture")
130
+ test_id = kwargs["request_fixture"].node.callspec.id
131
+ cmd = f"{path}::{test_fn.__name__}[{test_id}]"
132
+ cmd = [
133
+ "compute-sanitizer",
134
+ "--target-processes=application-only",
135
+ "--destroy-on-device-error=context",
136
+ f"--tool={tool.value}",
137
+ sys.executable,
138
+ "-m",
139
+ "pytest",
140
+ "-vsx",
141
+ cmd,
142
+ ]
143
+ for opt in ["--update_checksum", "--ignore_checksum_error"]:
144
+ if opt in sys.argv:
145
+ cmd.append(opt)
146
+ out = subprocess.run(
147
+ cmd,
148
+ stdout=subprocess.PIPE,
149
+ stderr=subprocess.STDOUT,
150
+ env=env,
151
+ )
152
+ sanitizer_ok = "ERROR SUMMARY: 0 errors" in str(
153
+ out.stdout) or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout)
154
+ test_output = out.stdout
155
+ if type(test_output) is bytes:
156
+ test_output = test_output.decode()
157
+
158
+ fail = False
159
+ if not sanitizer_ok:
160
+ print("compute-sanitizer returned an error")
161
+ fail = True
162
+ elif out.returncode != 0:
163
+ print(
164
+ "The test failed due to some other reason: consider running without compute-sanitizer to verify."
165
+ )
166
+ print(f"{out.returncode=}")
167
+ fail = True
168
+
169
+ if fail:
170
+ print("*****************************************************")
171
+ print("******************** TEST OUTPUT ********************")
172
+ print("*****************************************************")
173
+ print(test_output)
174
+ print("*****************************************************")
175
+ print("****************** TEST OUTPUT END ******************")
176
+ print("*****************************************************")
177
+ assert None
178
+ else:
179
+ test_fn(*args, **kwargs)
180
+
181
+ return wrapper
182
+
183
+ return decorator
184
+
185
+
186
+ def compute_actual_scale(x, dtype):
187
+ max_finite = {
188
+ torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
189
+ torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
190
+ torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
191
+ }[dtype]
192
+ return x.abs().max() / max_finite
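
A small usage sketch for the helpers above (assuming the module is importable as triton_kernels.testing; the shapes and perturbation size are illustrative):

import torch
from triton_kernels.testing import assert_close, compute_actual_scale  # assumed import path

ref = torch.randn(128, 128)
tri = ref + 1e-4 * torch.randn_like(ref)           # well inside the default maxtol=2e-2 / rmstol=4e-3
assert_close(ref, tri, description="smoke test")   # prints max / RMS relative error

scale = compute_actual_scale(ref, torch.float8_e4m3fn)  # |ref|.max() / MAX_FINITE_FLOAT8E4NV
print(scale.item())
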
build/torch-universal/triton_kernels/topk.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ import triton
3
+ from triton_kernels.topk_details._topk_forward import _topk_forward
4
+ from triton_kernels.topk_details._topk_backward import _topk_backward
5
+ from triton_kernels.tensor import Tensor, Bitmatrix
6
+
7
+
8
+ def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
9
+ if not isinstance(x, Tensor):
10
+ x_shape = [x.shape[0] if n_rows is None else n_rows, x.shape[1]]
11
+ x_shape_max = [x.shape[0], x.shape[1]]
12
+ x = Tensor(x, shape=x_shape, shape_max=x_shape_max)
13
+ cdiv = lambda a, b: (a + b - 1) // b
14
+ BLOCK_M = 32
15
+ BLOCK_N = 32
16
+ BLOCK_S = 128
17
+ assert len(x.shape) == 2
18
+ assert x.shape_max[-1] < 32768
19
+ assert dim == 1
20
+ assert return_bitmatrix
21
+ n_rows, n_cols = x.shape
22
+ n_rows_max, _ = x.shape_max
23
+ dev = x.device
24
+ # scratchpad tensors
25
+ # NOTE: these are not returned
26
+ y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev)
27
+ if y_indx is not None:
28
+ use_provided_indx = True
29
+ else:
30
+ y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev)
31
+ use_provided_indx = False
32
+ # create bitmatrix in transposed memory layout:
33
+ n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N
34
+ n_cols_words = n_cols_pad // 32
35
+ bitmatrix = torch.empty((n_cols_words, cdiv(n_rows_max, 32) * 32), dtype=torch.uint32, device=dev)
36
+ bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_max]
37
+ s_blocks = cdiv(n_cols, BLOCK_S)
38
+ s_cols = s_blocks * BLOCK_S
39
+ scratchpad = torch.empty((s_cols, ), dtype=torch.int32, device=dev)
40
+ pids = max(cdiv(n_rows_max, BLOCK_M), s_blocks)
41
+ _topk_forward[(pids, )](
42
+ x, x.stride(0), # inputs
43
+ y_vals, y_indx, y_vals.stride(0), use_provided_indx, # output [topk]
44
+ bitmatrix, bitmatrix.stride(0), bitmatrix.stride(1), # output [bitmatrix]
45
+ n_rows, n_cols, # shapes
46
+ scratchpad, BLOCK_S, s_blocks, # thing to memset to zero
47
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, # tunable parameter
48
+ APPLY_SOFTMAX=apply_softmax, N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k, # constants
49
+ )
50
+ bitmatrix_shape = [n_rows, n_cols_words * 32]
51
+ bitmatrix_shape_max = [n_rows_max, None]
52
+ bitmatrix = Bitmatrix(bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=scratchpad)
53
+ return y_vals, y_indx, bitmatrix
54
+
55
+
56
+ def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):
57
+ assert dy_vals.shape[-1] == k
58
+ n_expts_pad = triton.next_power_of_2(x.shape[-1])
59
+ dx = torch.empty_like(x)
60
+ _topk_backward[(dy_vals.shape[0], )](
61
+ y_indx, y_indx.stride(0), dy_vals, dy_vals.stride(0), x, x.stride(0), # inputs
62
+ dx, # outputs
63
+ dx.stride(0), x.shape[0], n_rows, x.shape[-1], APPLY_SOFTMAX=apply_softmax, N_EXPTS_ACT=k,
64
+ N_EXPTS_PAD=n_expts_pad)
65
+ return dx
66
+
67
+
68
+ class TopK(torch.autograd.Function):
69
+
70
+ @staticmethod
71
+ def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows):
72
+ y_vals, y_indx, bitmatrix = topk_forward(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
73
+ ctx.save_for_backward(x, y_indx)
74
+ ctx.apply_softmax = apply_softmax
75
+ ctx.k = k
76
+ ctx.n_rows = n_rows
77
+ return y_vals, y_indx, bitmatrix
78
+
79
+ @staticmethod
80
+ def backward(ctx, dy_vals, _0, _1):
81
+ x, y_indx = ctx.saved_tensors
82
+ dx = topk_backward(x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax)
83
+ return dx, None, None, None, None, None, None
84
+
85
+
86
+ def topk(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
87
+ ret = TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
88
+ return ret
89
+
90
+
91
+ # x = torch.randn((32, 32), dtype=torch.float16, device="cuda")
92
+ # print(topk(x, 4))
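
Expanding the commented-out snippet above into a hedged end-to-end sketch (requires a CUDA device and the compiled Triton kernels; the shapes and float16 dtype are illustrative, not the only supported configuration):

import torch
from triton_kernels.topk import topk  # assumed import path

x = torch.randn((32, 128), dtype=torch.float16, device="cuda", requires_grad=True)
y_vals, y_indx, bitmatrix = topk(x, 4)   # softmax over the top-4 logits of each row
print(y_vals.shape, y_indx.shape)        # torch.Size([32, 4]) torch.Size([32, 4])
y_vals.sum().backward()                  # gradients flow through TopK.backward
print(x.grad.shape)                      # torch.Size([32, 128])
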
build/torch-universal/triton_kernels/topk_details/__init__.py ADDED
File without changes
build/torch-universal/triton_kernels/topk_details/_topk_backward.py ADDED
@@ -0,0 +1,51 @@
1
+ import triton
2
+ import triton.language as tl
3
+
4
+
5
+ @triton.jit
6
+ def _topk_backward(
7
+ Yi,
8
+ stride_ym, # topk indices
9
+ DY,
10
+ stride_dym, # output gradient values
11
+ X,
12
+ stride_xm, # input values
13
+ DX,
14
+ stride_dxm, # input gradient values
15
+ n_rows,
16
+ NRows,
17
+ n_expts_tot,
18
+ APPLY_SOFTMAX: tl.constexpr,
19
+ N_EXPTS_ACT: tl.constexpr,
20
+ N_EXPTS_PAD: tl.constexpr,
21
+ ):
22
+ pid_m = tl.program_id(0)
23
+ if NRows is not None:
24
+ n_rows = tl.load(NRows)
25
+ if pid_m >= n_rows:
26
+ return
27
+ Yi += pid_m * stride_ym
28
+ DY += pid_m * stride_dym
29
+ X += pid_m * stride_xm
30
+ DX += pid_m * stride_dxm
31
+ # --
32
+ offs_xn = tl.arange(0, N_EXPTS_PAD)
33
+ offs_yn = tl.arange(0, N_EXPTS_ACT)
34
+ mask_xn = offs_xn < n_expts_tot
35
+ # recompute softmax
36
+ y_indx = tl.load(Yi + offs_yn)
37
+ x = tl.load(X + y_indx)
38
+ x = x.to(tl.float32)
39
+ y = tl.softmax(x)
40
+ # compute input-gradient
41
+ dy = tl.load(DY + offs_yn)
42
+ dy = dy.to(tl.float32)
43
+ s = tl.sum(y * dy, 0)
44
+ # write-back input gradient
45
+ tl.store(DX + offs_xn, 0, mask=mask_xn)
46
+ tl.debug_barrier()
47
+ if APPLY_SOFTMAX:
48
+ dx = y * (dy - s)
49
+ else:
50
+ dx = dy
51
+ tl.store(DX + y_indx, dx)
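
For intuition, the sketch below is a hedged PyTorch re-derivation of what one program instance of this kernel computes for a single row: softmax over the k selected logits only, with the gradient scattered back to the selected columns and zeros elsewhere. It is a sketch of the math, not a drop-in replacement for the kernel.

import torch

def topk_softmax_backward_ref(x_row, idx_row, dy_row, apply_softmax=True):
    # x_row: (n_expts_tot,), idx_row / dy_row: (k,)
    dx_row = torch.zeros(x_row.shape, dtype=torch.float32)
    if apply_softmax:
        y = torch.softmax(x_row[idx_row].float(), dim=0)
        s = (y * dy_row.float()).sum()
        dx_row[idx_row] = y * (dy_row.float() - s)   # dx = y * (dy - sum(y * dy))
    else:
        dx_row[idx_row] = dy_row.float()
    return dx_row

x = torch.randn(16)
idx = torch.tensor([3, 7, 11, 0])
dy = torch.randn(4)
print(topk_softmax_backward_ref(x, idx, dy))
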
build/torch-universal/triton_kernels/topk_details/_topk_forward.py ADDED
@@ -0,0 +1,146 @@
1
+ import triton
2
+ import triton.language as tl
3
+
4
+
5
+ @triton.jit
6
+ def get_topmask_and_fullmask(x):
7
+ tl.static_assert(x.dtype.is_int_unsigned(), "floating-point value must be passed as bits")
8
+ tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
9
+ fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
10
+ tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
11
+ fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
12
+ return tm_arr, fm_arr
13
+
14
+
15
+ @triton.jit
16
+ def fpval_to_key(x):
17
+ tm, fm = get_topmask_and_fullmask(x)
18
+ return x ^ tl.where((x & tm) != 0, fm, tm)
19
+
20
+
21
+ @triton.jit
22
+ def key_to_fpval(x):
23
+ tm, fm = get_topmask_and_fullmask(x)
24
+ return x ^ tl.where((x & tm) == 0, fm, tm)
25
+
26
+
27
+ # stable top-k tie-breaks to value with smaller index
28
+ @triton.jit
29
+ def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr):
30
+ return N_EXPTS_PAD - indx
31
+
32
+
33
+ @triton.jit
34
+ def key_to_indx(indx, N_EXPTS_PAD: tl.constexpr):
35
+ return N_EXPTS_PAD - indx
36
+
37
+
38
+ @triton.jit
39
+ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
40
+ BLOCK_N: tl.constexpr):
41
+ x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
42
+ x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
43
+ if x_nbits < 16:
44
+ # this ensures that we leave at least 16 bits for expert index
45
+ # even if the input dtype is smaller than 16 bits:
46
+ y_nbits: tl.constexpr = 32
47
+ else:
48
+ y_nbits: tl.constexpr = x_nbits * 2
49
+ x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}")
50
+ x_dtype: tl.constexpr = X.dtype.element_ty
51
+
52
+ # subtract 1 from loop iterations because we peel the first (masked) iteration:
53
+ loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
54
+ offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N)
55
+ mask_n = offs_x_n[None, :] < n_expts_tot
56
+
57
+ # first iteration:
58
+ X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
59
+ x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
60
+ x = fpval_to_key(x.to(x_utype, bitcast=True))
61
+ x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
62
+ acc = tl.topk(x, N_EXPTS_ACT, dim=1)
63
+
64
+ # subsequent iterations:
65
+ for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
66
+ acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge
67
+ X_ptrs -= BLOCK_N
68
+ offs_x_n -= BLOCK_N
69
+ x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
70
+ x = fpval_to_key(x.to(x_utype, bitcast=True))
71
+ x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
72
+ acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))
73
+
74
+ # rotate expert index into upper 16 bits:
75
+ # 0000vvvvvvvviiii --> iiii0000vvvvvvvv
76
+ acc = (acc << (y_nbits - 16)) | (acc >> 16)
77
+ # sort in ascending order of expert (descending order of key)
78
+ acc = tl.sort(acc, dim=1, descending=True)
79
+ # iiii0000vvvvvvvv --> 0000iiii:
80
+ y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32)
81
+ y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD)
82
+ # iiii0000vvvvvvvv --> vvvvvvvv:
83
+ y_values_raw = acc.to(x_utype)
84
+ y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)
85
+
86
+ return y_values, y_indices
87
+
88
+
89
+ @triton.jit
90
+ def _topk_forward(X, stride_xm, # inputs
91
+ Yv, Yi, stride_ym, # topk values/indices
92
+ USE_PROVIDED_INDX: tl.constexpr, Bits, stride_rm: tl.constexpr, stride_rn: tl.constexpr, # bitmatrix
93
+ n_rows, n_expts_tot, # shape
94
+ S, BLOCK_S: tl.constexpr, s_blocks, # thing to memset
95
+ APPLY_SOFTMAX: tl.constexpr, # constant
96
+ BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr):
97
+
98
+ pid = tl.program_id(0)
99
+ if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr():
100
+ n_rows = tl.load(n_rows)
101
+
102
+ if pid < s_blocks:
103
+ tl.store(S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32))
104
+
105
+ if pid * BLOCK_M >= n_rows:
106
+ # early exit:
107
+ return
108
+
109
+ tl.static_assert(BLOCK_N % 32 == 0)
110
+ tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
111
+ x_dtype: tl.constexpr = X.dtype.element_ty
112
+
113
+ # load logits
114
+ offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
115
+ offs_y_n = tl.arange(0, N_EXPTS_ACT)
116
+ mask_m = offs_m[:, None] < n_rows
117
+ if USE_PROVIDED_INDX:
118
+ Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
119
+ y_indices = tl.load(Yi_ptrs, mask=mask_m)
120
+ Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices
121
+ y_values = tl.load(Xv_ptrs, mask=mask_m)
122
+ else:
123
+ y_values, y_indices = streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, #
124
+ N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N)
125
+
126
+ # normalize selected values
127
+ if APPLY_SOFTMAX:
128
+ y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)
129
+
130
+ # write back
131
+ Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
132
+ tl.store(Yv_ptrs, y_values, mask=mask_m)
133
+ if not USE_PROVIDED_INDX:
134
+ Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
135
+ tl.store(Yi_ptrs, y_indices, mask=mask_m)
136
+
137
+ # pack into bitmatrix
138
+ y_div = y_indices // 32
139
+ y_rem = y_indices % 32
140
+ loop_iterations = N_EXPTS_PAD // BLOCK_N
141
+ for i in range(loop_iterations):
142
+ offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32)
143
+ y2 = tl.where(y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0)
144
+ r = tl.reduce_or(y2, axis=1)
145
+ BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
146
+ tl.store(BitsPtrs, r, mask=mask_m)
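
The key trick above (fpval_to_key / key_to_fpval) maps floating-point bit patterns to unsigned keys whose integer ordering matches the floating-point ordering, which is what lets streaming_topk pack value and expert index into one sortable key. A self-contained sketch mirroring it on plain int64 PyTorch tensors (the kernel itself operates on unsigned Triton tensors):

import torch

def fpval_to_key(bits):                      # bits: raw uint32 pattern stored in int64
    tm = torch.full_like(bits, 1 << 31)
    fm = torch.full_like(bits, (1 << 32) - 1)
    return bits ^ torch.where((bits & tm) != 0, fm, tm)

def key_to_fpval(key):
    tm = torch.full_like(key, 1 << 31)
    fm = torch.full_like(key, (1 << 32) - 1)
    return key ^ torch.where((key & tm) == 0, fm, tm)

x = torch.tensor([-3.5, -1e-3, 0.0, 1.25, 7.0, -float("inf")])
bits = x.view(torch.int32).to(torch.int64) & 0xFFFFFFFF   # bit pattern as an unsigned value
keys = fpval_to_key(bits)
assert torch.equal(x[keys.argsort()], x.sort().values)    # ascending key order == ascending float order
assert torch.equal(key_to_fpval(keys), bits)              # the mapping is invertible
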
flake.lock ADDED
@@ -0,0 +1,168 @@
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1733328505,
21
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1751968576,
77
+ "narHash": "sha256-cmKrlWpNTG/hq1bCaHXfbdm9T+Y6V+5//EHAVc1TLBE=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "3fcd1e1b46da91b6691261640ffd6b7123d0cb9e",
81
+ "type": "github"
82
+ },
83
+ "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
86
+ "type": "github"
87
+ }
88
+ },
89
+ "kernel-builder": {
90
+ "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
+ "nixpkgs": [
95
+ "kernel-builder",
96
+ "hf-nix",
97
+ "nixpkgs"
98
+ ]
99
+ },
100
+ "locked": {
101
+ "lastModified": 1754384793,
102
+ "narHash": "sha256-pmsetd7e4HtJxduQlHOuQNwdnZYjUtcEDguZ0VGHhoE=",
103
+ "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "ea0b13b8a53e53c900181242840640e27f169484",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "huggingface",
110
+ "repo": "kernel-builder",
111
+ "type": "github"
112
+ }
113
+ },
114
+ "nixpkgs": {
115
+ "locked": {
116
+ "lastModified": 1747820358,
117
+ "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
118
+ "owner": "danieldk",
119
+ "repo": "nixpkgs",
120
+ "rev": "d3c1681180717528068082103bf323147de6ab0b",
121
+ "type": "github"
122
+ },
123
+ "original": {
124
+ "owner": "danieldk",
125
+ "ref": "cudatoolkit-12.9-kernel-builder",
126
+ "repo": "nixpkgs",
127
+ "type": "github"
128
+ }
129
+ },
130
+ "root": {
131
+ "inputs": {
132
+ "kernel-builder": "kernel-builder"
133
+ }
134
+ },
135
+ "systems": {
136
+ "locked": {
137
+ "lastModified": 1681028828,
138
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
+ "owner": "nix-systems",
140
+ "repo": "default",
141
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
+ "type": "github"
143
+ },
144
+ "original": {
145
+ "owner": "nix-systems",
146
+ "repo": "default",
147
+ "type": "github"
148
+ }
149
+ },
150
+ "systems_2": {
151
+ "locked": {
152
+ "lastModified": 1681028828,
153
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
+ "owner": "nix-systems",
155
+ "repo": "default",
156
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
+ "type": "github"
158
+ },
159
+ "original": {
160
+ "owner": "nix-systems",
161
+ "repo": "default",
162
+ "type": "github"
163
+ }
164
+ }
165
+ },
166
+ "root": "root",
167
+ "version": 7
168
+ }