Upload folder using huggingface_hub
Browse files- build/torch-universal/triton_kernels/__init__.py +0 -0
- build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc +0 -0
- build/torch-universal/triton_kernels/_ops.py +8 -0
- build/torch-universal/triton_kernels/compaction.py +69 -0
- build/torch-universal/triton_kernels/compaction_details/_masked_compaction.py +20 -0
- build/torch-universal/triton_kernels/matmul_ogs.py +662 -0
- build/torch-universal/triton_kernels/matmul_ogs_details/_common.py +165 -0
- build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py +377 -0
- build/torch-universal/triton_kernels/matmul_ogs_details/_matmul_ogs.py +464 -0
- build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py +505 -0
- build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py +298 -0
- build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py +33 -0
- build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py +111 -0
- build/torch-universal/triton_kernels/numerics.py +42 -0
- build/torch-universal/triton_kernels/numerics_details/__init__.py +0 -0
- build/torch-universal/triton_kernels/numerics_details/flexpoint.py +195 -0
- build/torch-universal/triton_kernels/numerics_details/mxfp.py +303 -0
- build/torch-universal/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py +158 -0
- build/torch-universal/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py +122 -0
- build/torch-universal/triton_kernels/proton_opts.py +17 -0
- build/torch-universal/triton_kernels/reduction_details/reduce_bitmatrix.py +111 -0
- build/torch-universal/triton_kernels/routing.py +386 -0
- build/torch-universal/triton_kernels/routing_details/_expt_data.py +64 -0
- build/torch-universal/triton_kernels/routing_details/_routing_compute.py +148 -0
- build/torch-universal/triton_kernels/specialize.py +132 -0
- build/torch-universal/triton_kernels/swiglu.py +100 -0
- build/torch-universal/triton_kernels/swiglu_details/_swiglu.py +102 -0
- build/torch-universal/triton_kernels/target_info.py +77 -0
- build/torch-universal/triton_kernels/tensor.py +211 -0
- build/torch-universal/triton_kernels/tensor_details/layout.py +32 -0
- build/torch-universal/triton_kernels/tensor_details/layout_details/base.py +19 -0
- build/torch-universal/triton_kernels/tensor_details/layout_details/blackwell_scale.py +58 -0
- build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_scale.py +80 -0
- build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_value.py +323 -0
- build/torch-universal/triton_kernels/tensor_details/layout_details/strided.py +17 -0
- build/torch-universal/triton_kernels/testing.py +192 -0
- build/torch-universal/triton_kernels/topk.py +92 -0
- build/torch-universal/triton_kernels/topk_details/__init__.py +0 -0
- build/torch-universal/triton_kernels/topk_details/_topk_backward.py +51 -0
- build/torch-universal/triton_kernels/topk_details/_topk_forward.py +146 -0
- flake.lock +168 -0
build/torch-universal/triton_kernels/__init__.py
ADDED
File without changes
|
build/torch-universal/triton_kernels/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (220 Bytes). View file
|
|
build/torch-universal/triton_kernels/_ops.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
ops = torch.ops._triton_kernels_de70d68_dirty
|
3 |
+
|
4 |
+
def add_op_namespace_prefix(op_name: str):
|
5 |
+
"""
|
6 |
+
Prefix op by namespace.
|
7 |
+
"""
|
8 |
+
return f"_triton_kernels_de70d68_dirty::{op_name}"
|
build/torch-universal/triton_kernels/compaction.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .compaction_details._masked_compaction import _masked_compaction
|
3 |
+
from .tensor import Bitmatrix
|
4 |
+
|
5 |
+
|
6 |
+
def compaction(yv, yi, bitmask, sentinel=-1):
|
7 |
+
"""
|
8 |
+
Return compacted copies of *yv* and *yi* based on a per-row bitmask.
|
9 |
+
|
10 |
+
Only the elements whose index appears among the active bits of *bitmask*
|
11 |
+
are kept; the rest are replaced by *sentinel*. Kept elements preserve
|
12 |
+
their original left-to-right order.
|
13 |
+
|
14 |
+
Parameters
|
15 |
+
----------
|
16 |
+
yv : torch.Tensor, shape (B, K)
|
17 |
+
Values tensor.
|
18 |
+
yi : torch.Tensor, shape (B, K), dtype torch.long
|
19 |
+
Integer indices (0 ≤ index < 32) associated with *yv*.
|
20 |
+
bitmask : torch.Tensor, shape (B,) **or** (B, 32)
|
21 |
+
Per-row mask of active indices. See the in-place version for details.
|
22 |
+
sentinel : int, default -1
|
23 |
+
Value written into dropped positions of the returned tensors.
|
24 |
+
|
25 |
+
Returns
|
26 |
+
-------
|
27 |
+
(yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
|
28 |
+
New tensors with the same dtype/device as the inputs.
|
29 |
+
|
30 |
+
"""
|
31 |
+
|
32 |
+
n_rows, n_cols = yi.shape
|
33 |
+
ret_yv = torch.empty_like(yv)
|
34 |
+
ret_yi = torch.empty_like(yi)
|
35 |
+
if isinstance(bitmask, Bitmatrix):
|
36 |
+
bitmask = bitmask.storage.data
|
37 |
+
|
38 |
+
_masked_compaction[(n_rows, )](
|
39 |
+
yv, yi, bitmask, bitmask.stride(0), bitmask.stride(1), # inputs
|
40 |
+
ret_yv, ret_yi, # outputs
|
41 |
+
sentinel, # sentinel
|
42 |
+
K=n_cols # constants
|
43 |
+
)
|
44 |
+
return ret_yv, ret_yi
|
45 |
+
|
46 |
+
|
47 |
+
def compaction_torch(yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1):
|
48 |
+
"""
|
49 |
+
reference implementation of `masked_compact`
|
50 |
+
"""
|
51 |
+
B, K = yi.shape
|
52 |
+
device = yi.device
|
53 |
+
# Expand bitmask to a boolean matrix of active bits (B, 32)
|
54 |
+
w = (1 << torch.arange(32, device=device, dtype=bitmask.dtype))
|
55 |
+
bits = (bitmask.unsqueeze(-1) & w) != 0
|
56 |
+
mask = bits.flatten(start_dim=-2) # or bits.reshape(B, -1)
|
57 |
+
# For every yi element decide whether it should be kept
|
58 |
+
keep = mask.gather(1, yi.long())
|
59 |
+
# Build a stable permutation that brings all "keep" items forward
|
60 |
+
# False→0, True→1 ==> invert so kept==0, dropped==1, then argsort
|
61 |
+
order = (~keep).to(torch.int).argsort(dim=1, stable=True)
|
62 |
+
# Re‑order tensors according to above permutation
|
63 |
+
yi_sorted = yi.gather(1, order)
|
64 |
+
yv_sorted = yv.gather(1, order)
|
65 |
+
# fill relevant positions with sentinel
|
66 |
+
keep_sorted = keep.gather(1, order)
|
67 |
+
yi_sorted[~keep_sorted] = sentinel
|
68 |
+
yv_sorted[~keep_sorted] = sentinel
|
69 |
+
return yv_sorted, yi_sorted
|
build/torch-universal/triton_kernels/compaction_details/_masked_compaction.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import triton
|
2 |
+
import triton.language as tl
|
3 |
+
|
4 |
+
|
5 |
+
@triton.jit
|
6 |
+
def _masked_compaction(Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr):
|
7 |
+
pid_m = tl.program_id(0)
|
8 |
+
yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
|
9 |
+
yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
|
10 |
+
div = yi // 32
|
11 |
+
rem = yi % 32
|
12 |
+
active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1
|
13 |
+
exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
|
14 |
+
active_flags = active_bits.to(tl.int1)
|
15 |
+
rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K))
|
16 |
+
write_indx = exc_cumsum + rev_arange
|
17 |
+
yv = tl.where(active_flags, yv, sentinel)
|
18 |
+
yi = tl.where(active_flags, yi, sentinel)
|
19 |
+
tl.store(RetYv + pid_m * K + write_indx, yv)
|
20 |
+
tl.store(RetYi + pid_m * K + write_indx, yi)
|
build/torch-universal/triton_kernels/matmul_ogs.py
ADDED
@@ -0,0 +1,662 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# isort: off
|
2 |
+
# fmt: off
|
3 |
+
from dataclasses import dataclass
|
4 |
+
import itertools
|
5 |
+
import sys
|
6 |
+
import torch
|
7 |
+
import triton
|
8 |
+
from enum import Enum, auto
|
9 |
+
# utilities
|
10 |
+
from triton_kernels import target_info
|
11 |
+
from triton_kernels.numerics import InFlexData, OutFlexData
|
12 |
+
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
|
13 |
+
from triton_kernels.target_info import is_cuda
|
14 |
+
# details
|
15 |
+
from .matmul_ogs_details._matmul_ogs import _compute_writeback_idx
|
16 |
+
from .matmul_ogs_details._matmul_ogs import _matmul_ogs
|
17 |
+
from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
|
18 |
+
from .matmul_ogs_details._finalize_matmul import _finalize_matmul
|
19 |
+
from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints
|
20 |
+
from .numerics_details.mxfp import MXFP_BLOCK_SIZE
|
21 |
+
from .specialize import specialize
|
22 |
+
from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass(frozen=True)
|
26 |
+
class FnSpecs:
|
27 |
+
name: str
|
28 |
+
fn: "triton.runtime.jit.JITFunction"
|
29 |
+
fn_arg_names: tuple[str]
|
30 |
+
fn_arg_do_not_specialize: tuple[str] = tuple()
|
31 |
+
|
32 |
+
@staticmethod
|
33 |
+
def default():
|
34 |
+
return FnSpecs("dflt", None, tuple())
|
35 |
+
|
36 |
+
|
37 |
+
@dataclass(frozen=True)
|
38 |
+
class FusedActivation:
|
39 |
+
specs: FnSpecs = FnSpecs.default()
|
40 |
+
fn_args: tuple[object] = tuple()
|
41 |
+
reduction_n: int = 1
|
42 |
+
|
43 |
+
|
44 |
+
@dataclass(frozen=True)
|
45 |
+
class Epilogue:
|
46 |
+
specs: FnSpecs = FnSpecs.default()
|
47 |
+
fn_arg_values_matmul: tuple[object] = tuple()
|
48 |
+
fn_arg_values_finalize: tuple[object] = tuple()
|
49 |
+
effective_itemsize: float = None
|
50 |
+
|
51 |
+
class FnName(Enum):
|
52 |
+
DEQUANTIZE_MXFP8 = auto()
|
53 |
+
|
54 |
+
|
55 |
+
EpilogueSpecs = FnSpecs # TODO: remove this alias when callers are updated
|
56 |
+
|
57 |
+
_kernels = dict()
|
58 |
+
|
59 |
+
|
60 |
+
def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()):
|
61 |
+
global _kernels
|
62 |
+
key = (fused_activation.name, epilogue.name)
|
63 |
+
if key in _kernels:
|
64 |
+
return _kernels[key]
|
65 |
+
spec_constants = {
|
66 |
+
"ACTIVATION_FN": fused_activation.fn,
|
67 |
+
"EPILOGUE_FN": epilogue.fn,
|
68 |
+
}
|
69 |
+
spec_tuples = {
|
70 |
+
"activation_fn_args": fused_activation.fn_arg_names,
|
71 |
+
"epilogue_fn_args": epilogue.fn_arg_names,
|
72 |
+
}
|
73 |
+
do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize
|
74 |
+
import types
|
75 |
+
|
76 |
+
module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}")
|
77 |
+
sys.modules[module.__name__] = module
|
78 |
+
module._finalize_matmul = specialize(_finalize_matmul, module, spec_constants, spec_tuples,
|
79 |
+
do_not_specialize=do_not_specialize)
|
80 |
+
module._matmul_ogs = specialize(_matmul_ogs, module, spec_constants, spec_tuples,
|
81 |
+
do_not_specialize=do_not_specialize)
|
82 |
+
module._p_matmul_ogs = specialize(_p_matmul_ogs, module, spec_constants, spec_tuples,
|
83 |
+
do_not_specialize=do_not_specialize)
|
84 |
+
_kernels[key] = module
|
85 |
+
return module
|
86 |
+
|
87 |
+
|
88 |
+
# -----------------------------------------------------------------------------
|
89 |
+
# Matrix Multiplication + Outer Gather/Scatter
|
90 |
+
# -----------------------------------------------------------------------------
|
91 |
+
|
92 |
+
|
93 |
+
def can_overflow_int32(tensor: torch.Tensor):
|
94 |
+
max_int32 = (1 << 31) - 1
|
95 |
+
offset = 0
|
96 |
+
for i in range(tensor.ndim):
|
97 |
+
offset += (tensor.shape[i] - 1) * tensor.stride(i)
|
98 |
+
return offset > max_int32
|
99 |
+
|
100 |
+
|
101 |
+
def should_upcast_indices(*args):
|
102 |
+
return any(tensor is not None and can_overflow_int32(tensor) for tensor in args)
|
103 |
+
|
104 |
+
|
105 |
+
# ---------------------
|
106 |
+
# Numerics
|
107 |
+
# ---------------------
|
108 |
+
|
109 |
+
# fmt: off
|
110 |
+
|
111 |
+
@dataclass(frozen=True)
|
112 |
+
class FlexCtx:
|
113 |
+
lhs_data: InFlexData = InFlexData()
|
114 |
+
rhs_data: InFlexData = InFlexData()
|
115 |
+
out_data: OutFlexData = OutFlexData()
|
116 |
+
|
117 |
+
@dataclass
|
118 |
+
class PrecisionConfig:
|
119 |
+
max_num_imprecise_acc: int = None
|
120 |
+
allow_tf32: bool = True
|
121 |
+
flex_ctx: FlexCtx = FlexCtx()
|
122 |
+
acc_scale: int = 1.0
|
123 |
+
flexpoint_saturate_inf: bool = False
|
124 |
+
report_quantization_err_fn: callable = None
|
125 |
+
act_scale: Tensor | None = None
|
126 |
+
weight_scale: Tensor| None = None
|
127 |
+
out_scale: Tensor | None = None
|
128 |
+
out_dtype: torch.dtype = None
|
129 |
+
enforce_bitwise_invariance: bool = False
|
130 |
+
|
131 |
+
# ---------------------
|
132 |
+
# Preprocessing
|
133 |
+
# ---------------------
|
134 |
+
|
135 |
+
@dataclass(frozen=True)
|
136 |
+
class PreprocessingFeatures:
|
137 |
+
swap_xw: bool
|
138 |
+
|
139 |
+
|
140 |
+
def init_preprocessing_features(w, precision_config, opt_flags):
|
141 |
+
swap_xw = False # Whether or not to swap X and W operands to the tl.dot
|
142 |
+
if target_info.cuda_capability_geq(10, 0):
|
143 |
+
swap_xw = precision_config.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent
|
144 |
+
return PreprocessingFeatures(swap_xw)
|
145 |
+
|
146 |
+
def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features):
|
147 |
+
has_fused_scatter_scratchpad = opt_flags.fused_scatter and routing_data.n_expts_act > 1
|
148 |
+
if has_fused_scatter_scratchpad:
|
149 |
+
M = scatter_indx.src_indx.shape[0]
|
150 |
+
writeback_idxs = torch.zeros((M,), dtype=torch.int32, device=x.device)
|
151 |
+
writeback_size = writeback_idxs.shape[0]
|
152 |
+
finalize_scatter_idxs = torch.zeros((M // routing_data.n_expts_act + M + 1,), dtype=torch.int32, device=x.device)
|
153 |
+
BLOCK_M=256
|
154 |
+
_compute_writeback_idx[(triton.cdiv(M, BLOCK_M),)](
|
155 |
+
writeback_idxs,
|
156 |
+
finalize_scatter_idxs,
|
157 |
+
scatter_indx.dst_indx,
|
158 |
+
scatter_indx.src_indx,
|
159 |
+
M // routing_data.n_expts_act,
|
160 |
+
M,
|
161 |
+
BLOCK_M=BLOCK_M,
|
162 |
+
N_EXPTS_ACT=routing_data.n_expts_act,
|
163 |
+
)
|
164 |
+
elif scatter_indx is not None and routing_data.n_expts_act == 1:
|
165 |
+
writeback_idxs = scatter_indx.dst_indx
|
166 |
+
writeback_size = scatter_indx.dst_indx.shape[0]
|
167 |
+
finalize_scatter_idxs = None
|
168 |
+
else:
|
169 |
+
writeback_idxs, writeback_size, finalize_scatter_idxs = None, None, None
|
170 |
+
# preprocess routing information and ptr lookup table
|
171 |
+
M = x.shape[1] if gather_indx is None else gather_indx.src_indx.shape[0]
|
172 |
+
return x, w, writeback_idxs, writeback_size, finalize_scatter_idxs
|
173 |
+
|
174 |
+
|
175 |
+
# ---------------------
|
176 |
+
# Postprocessing
|
177 |
+
# ---------------------
|
178 |
+
|
179 |
+
|
180 |
+
@dataclass(frozen=True)
|
181 |
+
class PostprocessingFeatures:
|
182 |
+
finalize: bool
|
183 |
+
|
184 |
+
def init_postprocessing_features(routing_data, scatter_indx, opt_flags):
|
185 |
+
finalize = (scatter_indx is not None and routing_data.n_expts_act > 1) or opt_flags.split_k > 1
|
186 |
+
return PostprocessingFeatures(finalize)
|
187 |
+
|
188 |
+
def apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_offs, num_indx, precision_config, routing_data,
|
189 |
+
postprocess_features, memory, fused_activation, epilogue):
|
190 |
+
out = memory["output"]
|
191 |
+
flex_ctx = precision_config.flex_ctx
|
192 |
+
if postprocess_features.finalize:
|
193 |
+
has_fused_scatter_scratchpad = opt_flags.fused_scatter and routing_data.n_expts_act > 1
|
194 |
+
if has_fused_scatter_scratchpad:
|
195 |
+
inp = memory["output"]
|
196 |
+
else:
|
197 |
+
inp = memory["scratchpad"]["matmul"]
|
198 |
+
if scatter_indx is not None:
|
199 |
+
assert inp.shape[1] == 1, "batched finalize scatter not supported"
|
200 |
+
n_final_rows = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act
|
201 |
+
scatter_src_indx = scatter_indx.src_indx
|
202 |
+
EXPT_PER_TOK = routing_data.n_expts_act
|
203 |
+
num_rows = None
|
204 |
+
else:
|
205 |
+
n_final_rows = inp.shape[1] * inp.shape[2]
|
206 |
+
scatter_src_indx = None
|
207 |
+
EXPT_PER_TOK = 1
|
208 |
+
num_rows = num_indx or (None if expt_offs is None else expt_offs[-1])
|
209 |
+
|
210 |
+
if inp.dtype == torch.float32:
|
211 |
+
inp_flex = OutFlexData()
|
212 |
+
else:
|
213 |
+
inp_flex = precision_config.flex_ctx.out_data
|
214 |
+
|
215 |
+
out_scatter = memory["output"]
|
216 |
+
out_scatter_flex = precision_config.flex_ctx.out_data
|
217 |
+
|
218 |
+
N = inp.shape[3]
|
219 |
+
M = n_final_rows
|
220 |
+
warps_per_sm = 32 if target_info.is_hip() else 128
|
221 |
+
|
222 |
+
def compute_grid(BLOCK_N, num_warps):
|
223 |
+
num_pid = target_info.num_sms() * (warps_per_sm // num_warps)
|
224 |
+
if M < num_pid or target_info.is_hip():
|
225 |
+
grid_n = triton.cdiv(N, BLOCK_N)
|
226 |
+
grid_m = min(M, max(1, triton.cdiv(num_pid, grid_n)))
|
227 |
+
else:
|
228 |
+
grid_m = min(M, num_pid)
|
229 |
+
grid_n = 1
|
230 |
+
return (grid_m, grid_n)
|
231 |
+
|
232 |
+
if inp.dtype.itemsize == 1:
|
233 |
+
candidates = [(1024, 1)]
|
234 |
+
else:
|
235 |
+
if target_info.is_hip():
|
236 |
+
candidates = [(4096 // inp.dtype.itemsize, 2)]
|
237 |
+
else:
|
238 |
+
if inp.dtype.itemsize == 2:
|
239 |
+
candidates = [
|
240 |
+
(4096 // inp.dtype.itemsize, 4),
|
241 |
+
(1024 // inp.dtype.itemsize, 1),
|
242 |
+
]
|
243 |
+
else:
|
244 |
+
candidates = [
|
245 |
+
(2048 // inp.dtype.itemsize, 4),
|
246 |
+
(1024 // inp.dtype.itemsize, 1),
|
247 |
+
]
|
248 |
+
if precision_config.enforce_bitwise_invariance:
|
249 |
+
candidates = [candidates[0]]
|
250 |
+
|
251 |
+
# sort by smallest grid_n so we share compute across a row
|
252 |
+
grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
|
253 |
+
STAGES = 1 if num_warps == 1 else min(triton.cdiv(triton.cdiv(N, BLOCK_N), grid[1]), 5)
|
254 |
+
|
255 |
+
out_scale = precision_config.out_scale
|
256 |
+
out_has_mx = out_scale is not None
|
257 |
+
out_scale_strides = (None, None) if out_scale is None else out_scale.stride()[-2:]
|
258 |
+
mx_a_scale = memory["scratchpad"].get("mx_out_scale", None)
|
259 |
+
if mx_a_scale is not None:
|
260 |
+
mx_a_scale_stride_k, mx_a_scale_stride_m = [mx_a_scale.stride(i) for i in (0, 2)]
|
261 |
+
else:
|
262 |
+
mx_a_scale_stride_k, mx_a_scale_stride_m = None, None
|
263 |
+
|
264 |
+
kernels = get_kernels(epilogue.specs, fused_activation.specs)
|
265 |
+
kernels._finalize_matmul[grid](
|
266 |
+
flex_ctx.out_data.reinterpret(out_scatter),
|
267 |
+
*((None, out_scale, None) if out_has_mx else out_scatter_flex),
|
268 |
+
*out_scale_strides,
|
269 |
+
flex_ctx.out_data.reinterpret(inp), inp.stride(0), inp.stride(2),
|
270 |
+
inp_flex.expected_scale if mx_a_scale is None else mx_a_scale,
|
271 |
+
mx_a_scale_stride_k, mx_a_scale_stride_m,
|
272 |
+
scatter_src_indx, finalize_scatter_idxs,
|
273 |
+
inp.shape[0], M, N, num_rows,
|
274 |
+
*fused_activation.fn_args, fused_activation.reduction_n,
|
275 |
+
*epilogue.fn_arg_values_finalize,
|
276 |
+
EXPT_PER_TOK=EXPT_PER_TOK,
|
277 |
+
BLOCK_N=BLOCK_N,
|
278 |
+
STAGES=STAGES,
|
279 |
+
num_warps=num_warps,
|
280 |
+
flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf,
|
281 |
+
HAS_FUSED_SCRATCHPAD=has_fused_scatter_scratchpad,
|
282 |
+
)
|
283 |
+
out = out_scatter
|
284 |
+
# trim unnecessary part of output
|
285 |
+
if has_fused_scatter_scratchpad:
|
286 |
+
# Discard scratchpad part.
|
287 |
+
# This still gives a contiguous tensor, because shape[0] > 1 only when
|
288 |
+
# batch mode is enabled, in which case this is a no-op (there's no scratchpad).
|
289 |
+
out = out[:, :, :n_final_rows, :]
|
290 |
+
return out
|
291 |
+
|
292 |
+
|
293 |
+
# ---------------------
|
294 |
+
# Allocation
|
295 |
+
# ---------------------
|
296 |
+
|
297 |
+
@dataclass
|
298 |
+
class MatmulAllocation:
|
299 |
+
device: str
|
300 |
+
output: tuple[tuple[int], torch.dtype]
|
301 |
+
scratchpads: dict[str, tuple]
|
302 |
+
|
303 |
+
def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags,
|
304 |
+
preprocessing_features, postprocessing_features):
|
305 |
+
# ---- output ------
|
306 |
+
N = w.shape[-1]
|
307 |
+
# by default - M is number of rows in the activations
|
308 |
+
M = x.shape[-2]
|
309 |
+
# if the activations are gathered, then M is number of gather indices
|
310 |
+
if gather_indx is not None:
|
311 |
+
M = gather_indx.src_indx.shape[0]
|
312 |
+
# final output
|
313 |
+
if routing_data.n_expts_act == 1 or scatter_indx is None:
|
314 |
+
y_rows = M
|
315 |
+
elif opt_flags.fused_scatter:
|
316 |
+
# we need the scratchpad and the output to be contiguous in memory
|
317 |
+
Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows
|
318 |
+
y_rows = M + Mc
|
319 |
+
else:
|
320 |
+
Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows
|
321 |
+
y_rows = Mc
|
322 |
+
batch_dim = x.shape[0] if x.ndim == 3 else 1
|
323 |
+
y_shape = (batch_dim, y_rows, N // fused_activation.reduction_n)
|
324 |
+
out_dtype = precision_config.out_dtype or x.dtype
|
325 |
+
output = (y_shape, out_dtype)
|
326 |
+
# ---- scratchpad -----#
|
327 |
+
scratchpad = dict()
|
328 |
+
# if we need either standalone scatter or split-k, the matmul output will need post-processing
|
329 |
+
if postprocessing_features.finalize:
|
330 |
+
if opt_flags.split_k > 1 or not opt_flags.fused_scatter:
|
331 |
+
dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype
|
332 |
+
scratchpad["matmul"] = ((opt_flags.split_k, 1, M, N), dtype)
|
333 |
+
if precision_config.out_scale is not None and not (scratchpad.get("matmul", None) is not None and scratchpad["matmul"][1].itemsize > 1):
|
334 |
+
scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE)), torch.uint8)
|
335 |
+
return MatmulAllocation(x.device, output, scratchpad)
|
336 |
+
|
337 |
+
def apply_allocation(allocation: MatmulAllocation, output):
|
338 |
+
ret = dict()
|
339 |
+
if output is None:
|
340 |
+
output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
|
341 |
+
else:
|
342 |
+
assert output.shape == allocation.output[0]
|
343 |
+
ret["output"] = output[None, :, :]
|
344 |
+
ret["scratchpad"] = {
|
345 |
+
k: torch.empty(v[0], device=allocation.device, dtype=v[1])
|
346 |
+
for k, v in allocation.scratchpads.items()
|
347 |
+
}
|
348 |
+
return ret
|
349 |
+
|
350 |
+
# -----------------------------------------------------------------------------
|
351 |
+
# Canonicalize
|
352 |
+
# -----------------------------------------------------------------------------
|
353 |
+
# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used
|
354 |
+
# we can canonicalize storages to make the implementation more uniform
|
355 |
+
|
356 |
+
def _canonicalize_storage(storage, out_ndim, flex_data):
|
357 |
+
assert out_ndim >= storage.data.ndim
|
358 |
+
new_storage_shape = [1] * (out_ndim - storage.data.ndim) + list(storage.data.shape)
|
359 |
+
new_storage_data = storage.data.view(new_storage_shape)
|
360 |
+
if flex_data is not None:
|
361 |
+
new_storage_data = flex_data.reinterpret(new_storage_data)
|
362 |
+
return Storage(new_storage_data, storage.layout)
|
363 |
+
|
364 |
+
|
365 |
+
# -----------------------------------------------------------------------------
|
366 |
+
# Triton Implementation
|
367 |
+
# -----------------------------------------------------------------------------
|
368 |
+
|
369 |
+
def matmul_ogs_set_idle_sms(num_idle_sms):
|
370 |
+
"""
|
371 |
+
persistent kernels will leave `num_idle_sms` idle
|
372 |
+
"""
|
373 |
+
update_opt_flags_constraints({"idle_sms": num_idle_sms})
|
374 |
+
|
375 |
+
def matmul_ogs(x, w, bias,
|
376 |
+
routing_data: RoutingData | None = None,
|
377 |
+
gather_indx: GatherIndx | None = None,
|
378 |
+
scatter_indx: ScatterIndx | None = None,
|
379 |
+
precision_config: PrecisionConfig | None = None,
|
380 |
+
betas: torch.Tensor | None = None,
|
381 |
+
gammas: torch.Tensor | None = None,
|
382 |
+
out_alpha: float | None = None,
|
383 |
+
y: torch.Tensor | None = None,
|
384 |
+
fused_activation: FusedActivation | None = None,
|
385 |
+
epilogue: Epilogue | None = None,
|
386 |
+
):
|
387 |
+
"""
|
388 |
+
Y[:, :] = 0.
|
389 |
+
for e in num_experts:
|
390 |
+
Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])
|
391 |
+
"""
|
392 |
+
is_input_batched = x.ndim == 3
|
393 |
+
if is_input_batched:
|
394 |
+
assert gather_indx is None, "gather not supported in batched mode"
|
395 |
+
assert scatter_indx is None, "scatter not supported in batched mode"
|
396 |
+
assert routing_data is None, "routing not supported in batched mode"
|
397 |
+
assert w.ndim == 3 and w.shape[0] == x.shape[0]
|
398 |
+
# canonicalize inputs
|
399 |
+
if precision_config is None:
|
400 |
+
precision_config = PrecisionConfig()
|
401 |
+
if fused_activation is None:
|
402 |
+
fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
|
403 |
+
if epilogue is None:
|
404 |
+
epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
|
405 |
+
if routing_data is None:
|
406 |
+
routing_data = RoutingData(None, None, max(1, w.shape[0]), 1)
|
407 |
+
# unpack scales
|
408 |
+
w_scale = precision_config.weight_scale
|
409 |
+
w_has_mx = w_scale is not None
|
410 |
+
is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8
|
411 |
+
if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"
|
412 |
+
if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10"
|
413 |
+
if not isinstance(w, Tensor):
|
414 |
+
# TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
|
415 |
+
dtype = FP4 if w.dtype == torch.uint8 else w.dtype
|
416 |
+
w = wrap_torch_tensor(w, dtype=dtype)
|
417 |
+
if w_scale is not None and not isinstance(w_scale, Tensor):
|
418 |
+
w_scale = Tensor(w_scale)
|
419 |
+
if w_scale is not None:
|
420 |
+
w_scale.storage.data = w_scale.data.view(torch.uint8)
|
421 |
+
w_scale.dtype = torch.uint8
|
422 |
+
x_scale = precision_config.act_scale
|
423 |
+
x_has_mx = x_scale is not None
|
424 |
+
if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp"
|
425 |
+
if x_scale is not None and not isinstance(x_scale, Tensor):
|
426 |
+
x_scale = Tensor(x_scale)
|
427 |
+
if not isinstance(x, Tensor):
|
428 |
+
x = Tensor(x, dtype=x.dtype)
|
429 |
+
# determine shapes
|
430 |
+
M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0]
|
431 |
+
batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1
|
432 |
+
K, N = w.shape[-2:]
|
433 |
+
assert K == x.shape[-1]
|
434 |
+
if x.ndim == 3 and w.ndim == 3:
|
435 |
+
assert x.shape[0] == w.shape[0]
|
436 |
+
# compute optimization flags
|
437 |
+
out_dtype = precision_config.out_dtype or x.dtype
|
438 |
+
can_use_tma = x.storage.is_tma_compliant() and \
|
439 |
+
w.storage.is_tma_compliant() and \
|
440 |
+
(w_scale is None or w_scale.storage.is_tma_compliant())
|
441 |
+
# hopper w/ mxfp4 doesn't support TMA
|
442 |
+
can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
|
443 |
+
can_use_fused_scatter = scatter_indx is not None and fused_activation.specs.fn is None
|
444 |
+
opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
|
445 |
+
M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
|
446 |
+
)
|
447 |
+
if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp():
|
448 |
+
raise NotImplementedError("Must use non-persistent kernel for simulated MXFP")
|
449 |
+
if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and target_info.has_native_mxfp():
|
450 |
+
raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP")
|
451 |
+
# determine necessary pre/post processing
|
452 |
+
preprocessing_features = init_preprocessing_features(w, precision_config, opt_flags)
|
453 |
+
postprocessing_features = init_postprocessing_features(routing_data, scatter_indx, opt_flags)
|
454 |
+
# allocate output/scratchpad memory
|
455 |
+
allocation = init_allocation(x, w, precision_config, fused_activation,
|
456 |
+
routing_data, gather_indx, scatter_indx,
|
457 |
+
opt_flags, preprocessing_features, postprocessing_features
|
458 |
+
)
|
459 |
+
memory = apply_allocation(allocation, y)
|
460 |
+
# TMA descriptors require a global memory allocation
|
461 |
+
if opt_flags.is_persistent:
|
462 |
+
triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
|
463 |
+
# Intermediate tensors and postprocess kernels for each situation
|
464 |
+
out0, out0_flex = memory["output"], precision_config.flex_ctx.out_data
|
465 |
+
fused_postprocess_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
|
466 |
+
out_scale = None if precision_config.out_scale is None else precision_config.out_scale.data.view(torch.uint8)
|
467 |
+
if postprocessing_features.finalize:
|
468 |
+
if opt_flags.fused_scatter:
|
469 |
+
out0 = memory["output"]
|
470 |
+
else:
|
471 |
+
out0 = memory["scratchpad"]["matmul"]
|
472 |
+
if "mx_out_scale" in memory["scratchpad"]:
|
473 |
+
assert out_scale is not None
|
474 |
+
out_scale = memory["scratchpad"]["mx_out_scale"]
|
475 |
+
out0_flex = OutFlexData() if out0.dtype == torch.float32 else precision_config.flex_ctx.out_data
|
476 |
+
|
477 |
+
fused_activation, fused_postprocess_activation = fused_postprocess_activation, fused_activation
|
478 |
+
out_has_mx = out_scale is not None and out0.element_size() == 1
|
479 |
+
if out_has_mx:
|
480 |
+
if isinstance(out_scale, Tensor):
|
481 |
+
out_scale = Tensor(out_scale)
|
482 |
+
else:
|
483 |
+
out_scale = None
|
484 |
+
# pre-processing
|
485 |
+
x, w, writeback_idxs, writeback_size, finalize_scatter_idxs = apply_preprocessing_features(
|
486 |
+
x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features
|
487 |
+
)
|
488 |
+
# matrix multiplication
|
489 |
+
flex = precision_config.flex_ctx
|
490 |
+
bias_stride = None if bias is None else bias.stride(0)
|
491 |
+
num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
|
492 |
+
# moe metadata
|
493 |
+
expt_data = routing_data.expt_data
|
494 |
+
block_m = opt_flags.block_m
|
495 |
+
expt_hist = None if expt_data is None else expt_data.hist
|
496 |
+
expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
|
497 |
+
expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
|
498 |
+
expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
|
499 |
+
# spmd grid
|
500 |
+
grid_m = triton.cdiv(M, opt_flags.block_m)
|
501 |
+
if expt_block_pid_map is not None:
|
502 |
+
grid_m = routing_data.n_blocks(M, opt_flags.block_m)
|
503 |
+
grid_n = triton.cdiv(N, opt_flags.block_n)
|
504 |
+
max_grid = batch_size * grid_m * grid_n * opt_flags.split_k
|
505 |
+
grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid
|
506 |
+
# canonicalize storage
|
507 |
+
has_gather = gather_indx is not None
|
508 |
+
x_storage = _canonicalize_storage(x.storage, 2 if has_gather else 3, flex.lhs_data)
|
509 |
+
w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data)
|
510 |
+
# create tma descriptor for x
|
511 |
+
x_has_tma = ((not has_gather) or (has_gather and target_info.has_tma_gather())) and opt_flags.is_persistent
|
512 |
+
x_block_tma = ([1] if has_gather else [1, opt_flags.block_m]) + [opt_flags.block_k]
|
513 |
+
x_tensor_or_tma = x_storage.make_tma(x_block_tma) if x_has_tma else x_storage.data
|
514 |
+
# create tma descriptor for w
|
515 |
+
w_has_tma = opt_flags.is_persistent
|
516 |
+
w_tensor_or_tma = w_storage.make_tma([1, opt_flags.block_k, opt_flags.block_n]) if w_has_tma else w_storage.data
|
517 |
+
# create tma descriptor for w_scale
|
518 |
+
w_scale_tensor_or_tma = w_scale
|
519 |
+
w_scale_has_tma = opt_flags.is_persistent and w_scale is not None
|
520 |
+
w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k]) if w_scale_has_tma else w_scale
|
521 |
+
# canonicalize strides
|
522 |
+
x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride())
|
523 |
+
x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None)
|
524 |
+
x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides
|
525 |
+
w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None)
|
526 |
+
w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides
|
527 |
+
out_scale_strides = out_scale.stride() if out_has_mx else (None, None, None, None)
|
528 |
+
out_scale_strides = (0, ) * (3 - len(out_scale_strides)) + out_scale_strides
|
529 |
+
# launch kernel
|
530 |
+
kernels = get_kernels(epilogue.specs, fused_activation.specs)
|
531 |
+
(kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)](
|
532 |
+
flex.out_data.reinterpret(memory["output"]),
|
533 |
+
flex.out_data.reinterpret(out0), *out0.stride(),
|
534 |
+
*((None, out_scale, None) if out_has_mx else out0_flex),
|
535 |
+
*out_scale_strides[-3:],
|
536 |
+
x_tensor_or_tma, x_storage.data, *x_strides,
|
537 |
+
flex.lhs_data.scale,
|
538 |
+
None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides,
|
539 |
+
w_tensor_or_tma, *w_storage.data.stride(), w_storage.data.stride()[-1] != 1,
|
540 |
+
flex.rhs_data.scale,
|
541 |
+
w_scale_tensor_or_tma, *w_scale_strides,
|
542 |
+
bias, bias_stride,
|
543 |
+
x.shape[-2],
|
544 |
+
x.shape[-2] if routing_data.expt_hist is None else None,
|
545 |
+
N, K,
|
546 |
+
betas, gammas,
|
547 |
+
None if gather_indx is None else gather_indx.src_indx,
|
548 |
+
None if scatter_indx is None else scatter_indx.src_indx,
|
549 |
+
num_indx,
|
550 |
+
writeback_idxs, writeback_size,
|
551 |
+
expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map,
|
552 |
+
batch_size, grid_m, grid_n,
|
553 |
+
out_alpha,
|
554 |
+
*fused_activation.fn_args, fused_activation.reduction_n,
|
555 |
+
*epilogue.fn_arg_values_matmul,
|
556 |
+
routing_data.n_expts_tot, routing_data.n_expts_act,
|
557 |
+
precision_config.max_num_imprecise_acc,
|
558 |
+
precision_config.allow_tf32,
|
559 |
+
precision_config.flexpoint_saturate_inf,
|
560 |
+
flex.rhs_data.is_per_batch,
|
561 |
+
opt_flags.block_m,
|
562 |
+
opt_flags.block_n,
|
563 |
+
opt_flags.block_k,
|
564 |
+
opt_flags.group_m,
|
565 |
+
XCD_SWIZZLE=opt_flags.xcd_swizzle,
|
566 |
+
SWIZZLE_MX_VALUE=w.storage.layout.name,
|
567 |
+
SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name,
|
568 |
+
EPILOGUE_SUBTILE=opt_flags.epilogue_subtile,
|
569 |
+
SPLIT_K=opt_flags.split_k,
|
570 |
+
EVEN_K=K % opt_flags.block_k == 0,
|
571 |
+
W_CACHE_MODIFIER=opt_flags.w_cache_modifier,
|
572 |
+
TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt,
|
573 |
+
num_warps=opt_flags.num_warps,
|
574 |
+
num_stages=opt_flags.num_stages,
|
575 |
+
arch=opt_flags.arch,
|
576 |
+
UPCAST_INDICES=should_upcast_indices(x, w, out0),
|
577 |
+
DISABLE_Y_TMA=out0.stride(-2) * out0.dtype.itemsize % 16 != 0,
|
578 |
+
SWAP_XW=preprocessing_features.swap_xw,
|
579 |
+
IS_EPILOGUE_DEQUANT_MXFP8=epilogue.specs.name == FnName.DEQUANTIZE_MXFP8.name,
|
580 |
+
NUM_SMS = grid if opt_flags.is_persistent else 0,
|
581 |
+
**opt_flags.target_kernel_kwargs)
|
582 |
+
# post-processing
|
583 |
+
out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
|
584 |
+
num_indx, precision_config, routing_data,
|
585 |
+
postprocessing_features, memory, fused_postprocess_activation, epilogue)
|
586 |
+
# remove split-k
|
587 |
+
out = out.squeeze(0)
|
588 |
+
if not is_input_batched:
|
589 |
+
out = out.view(out.shape[-2], out.shape[-1])
|
590 |
+
return out
|
591 |
+
|
592 |
+
|
593 |
+
# -----------------------------------------------------------------------------
|
594 |
+
# Reference Implementation
|
595 |
+
# -----------------------------------------------------------------------------
|
596 |
+
|
597 |
+
def matmul_ogs_torch(x, w, bias,
|
598 |
+
routing_data: RoutingData = None,
|
599 |
+
gather_indx: GatherIndx = None,
|
600 |
+
scatter_indx: ScatterIndx = None,
|
601 |
+
precision_config: PrecisionConfig = None,
|
602 |
+
betas = None,
|
603 |
+
gammas = None,
|
604 |
+
round_x = None, round_y = None,
|
605 |
+
):
|
606 |
+
is_input_batched = x.ndim == 3
|
607 |
+
assert x.dtype.itemsize > 1
|
608 |
+
assert w.dtype.itemsize > 1
|
609 |
+
if is_input_batched:
|
610 |
+
assert gather_indx is None, "gather not supported in batched mode"
|
611 |
+
assert scatter_indx is None, "scatter not supported in batched mode"
|
612 |
+
assert routing_data is None, "routing not supported in batched mode"
|
613 |
+
assert w.ndim == 3 and w.shape[0] == x.shape[0]
|
614 |
+
if round_x is None:
|
615 |
+
round_x = lambda x: x
|
616 |
+
if round_y is None:
|
617 |
+
round_y = lambda x: x
|
618 |
+
if bias.ndim == 1:
|
619 |
+
bias = bias.view(1, *bias.shape)
|
620 |
+
if w.ndim == 2:
|
621 |
+
w = w.view(1, *w.shape)
|
622 |
+
if x.ndim == 2:
|
623 |
+
x = x.view(1, *x.shape)
|
624 |
+
if routing_data is None:
|
625 |
+
routing_data = RoutingData(None, None, w.shape[0], 1)
|
626 |
+
n_expts_act = routing_data.n_expts_act
|
627 |
+
# memory offsets
|
628 |
+
if routing_data.n_expts_tot > 1 and not is_input_batched:
|
629 |
+
sizes = routing_data.expt_hist
|
630 |
+
off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32)
|
631 |
+
off[1:] = torch.cumsum(sizes, 0)
|
632 |
+
offs = list(itertools.pairwise(off))
|
633 |
+
else:
|
634 |
+
offs = [[0, x.shape[1]] for _ in range(w.shape[0])]
|
635 |
+
# compute
|
636 |
+
n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0]
|
637 |
+
y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype)
|
638 |
+
for i, (lo, hi) in enumerate(offs):
|
639 |
+
if gather_indx is None:
|
640 |
+
idx = torch.arange(lo, hi, device=x.device)
|
641 |
+
else:
|
642 |
+
idx = gather_indx.src_indx[lo:hi] // n_expts_act
|
643 |
+
batch = i if is_input_batched else 0
|
644 |
+
out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(),
|
645 |
+
w[i].float())
|
646 |
+
if bias is not None:
|
647 |
+
out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None]
|
648 |
+
if gammas is not None:
|
649 |
+
out *= gammas[lo:hi, None]
|
650 |
+
y[batch, lo:hi, :] = round_y(out)
|
651 |
+
if not is_input_batched:
|
652 |
+
y = y.view(y.shape[1], y.shape[2])
|
653 |
+
if scatter_indx is None:
|
654 |
+
return y
|
655 |
+
# accumulate output from all experts
|
656 |
+
n_rows = y.shape[0] // n_expts_act
|
657 |
+
out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device)
|
658 |
+
for i, (lo, hi) in enumerate(offs):
|
659 |
+
dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act
|
660 |
+
msk = dst_idx != -1
|
661 |
+
out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float()
|
662 |
+
return out
|
build/torch-universal/triton_kernels/matmul_ogs_details/_common.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
import triton
|
4 |
+
import triton.language as tl
|
5 |
+
from triton.tools.tensor_descriptor import TensorDescriptor
|
6 |
+
|
7 |
+
# -----------------------------------------------------------------------------
|
8 |
+
# Utilities
|
9 |
+
# -----------------------------------------------------------------------------
|
10 |
+
|
11 |
+
|
12 |
+
@tl.constexpr_function
|
13 |
+
def get_scaled_dot_format_string(dtype: tl.dtype):
|
14 |
+
mapping = {
|
15 |
+
tl.float16: "fp16",
|
16 |
+
tl.bfloat16: "bf16",
|
17 |
+
tl.uint8: "e2m1",
|
18 |
+
tl.float8e4nv: "e4m3",
|
19 |
+
tl.float8e5: "e5m2",
|
20 |
+
}
|
21 |
+
return mapping[dtype]
|
22 |
+
|
23 |
+
|
24 |
+
@triton.jit
|
25 |
+
def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr):
|
26 |
+
"""
|
27 |
+
Swizzle the program id based on integer XCD_SWIZZLE.
|
28 |
+
This is useful for reording how blocks are ordered. A scheduler may, for example,
|
29 |
+
assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10.. to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2.
|
30 |
+
This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment
|
31 |
+
becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to
|
32 |
+
the same hardware unit.
|
33 |
+
"""
|
34 |
+
# Number of pids per group in the new arrangement
|
35 |
+
pids_per_group = domain_size // XCD_SWIZZLE
|
36 |
+
extra_pid_groups = domain_size % XCD_SWIZZLE
|
37 |
+
|
38 |
+
# Compute current current and local pid within the group
|
39 |
+
group = pid % XCD_SWIZZLE
|
40 |
+
local_pid = pid // XCD_SWIZZLE
|
41 |
+
|
42 |
+
# Calculate new pid based on the new grouping
|
43 |
+
new_pid = group * pids_per_group + min(group, extra_pid_groups) + local_pid
|
44 |
+
return new_pid
|
45 |
+
|
46 |
+
|
47 |
+
@triton.jit
|
48 |
+
def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr):
|
49 |
+
width = GROUP_M * grid_n
|
50 |
+
group_id = pid // width
|
51 |
+
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
|
52 |
+
pid_m = group_id * GROUP_M + (pid % group_size)
|
53 |
+
pid_n = (pid % width) // (group_size)
|
54 |
+
return pid_m, pid_n
|
55 |
+
|
56 |
+
|
57 |
+
def make_matmul_repr(base_name, order):
|
58 |
+
|
59 |
+
def matmul_repr(specialization):
|
60 |
+
signature = specialization.signature
|
61 |
+
constants = specialization.constants
|
62 |
+
reorder = lambda L: [L[i] for i in order]
|
63 |
+
layout = lambda stride: "N" if stride in constants else "T"
|
64 |
+
|
65 |
+
def convert_dtype(dtype):
|
66 |
+
if "tensordesc" in dtype:
|
67 |
+
ret = convert_dtype(dtype.split("<")[1].split("[")[0])
|
68 |
+
return ret
|
69 |
+
elif "u8" in dtype:
|
70 |
+
return "mxfp4"
|
71 |
+
elif dtype[0] == "*":
|
72 |
+
return dtype[1:]
|
73 |
+
else:
|
74 |
+
return dtype
|
75 |
+
|
76 |
+
dtypes = "x".join([convert_dtype(f"{signature[i]}") for i in reorder(["Y", "X", "W"])])
|
77 |
+
layouts = "".join([f"{layout(i)}" for i in reorder(["stride_y_n", "stride_x_k", "stride_w_n"])])
|
78 |
+
blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N", "BLOCK_K", "SPLIT_K"]])
|
79 |
+
# mode = []
|
80 |
+
# if "GatherIndx" not in constants:
|
81 |
+
# mode += ['g']
|
82 |
+
# if "ScatterSrcIndx" not in constants:
|
83 |
+
# mode += ['s']
|
84 |
+
# suffix = "" if not mode else "_o" + (''.join(mode))
|
85 |
+
# if base_name.startswith("_p"):
|
86 |
+
# suffix += "_ptma"
|
87 |
+
return f"{base_name}_{layouts}_{dtypes}_{blocks}"
|
88 |
+
|
89 |
+
return matmul_repr
|
90 |
+
|
91 |
+
|
92 |
+
def matmul_launch_metadata(grid, kernel, args):
|
93 |
+
from ..proton_opts import launch_metadata_allow_sync
|
94 |
+
|
95 |
+
ret = dict()
|
96 |
+
M, N, K = args["M"], args["N"], args["K"]
|
97 |
+
Y, X, W = [t.base if isinstance(t, TensorDescriptor) else t for t in [args["Y"], args["X"], args["W"]]]
|
98 |
+
tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
|
99 |
+
hist = args["ExptHist"]
|
100 |
+
if hist is not None:
|
101 |
+
# If annotation is given, use that to generate name for profiling.
|
102 |
+
if tokens_per_expt is not None:
|
103 |
+
n_rows = f"{tokens_per_expt}*"
|
104 |
+
elif launch_metadata_allow_sync():
|
105 |
+
n_rows = int(hist.float().mean())
|
106 |
+
else:
|
107 |
+
n_rows = "unknown"
|
108 |
+
|
109 |
+
if launch_metadata_allow_sync():
|
110 |
+
n_tokens = float(hist.sum())
|
111 |
+
n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
|
112 |
+
elif tokens_per_expt is not None:
|
113 |
+
n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
|
114 |
+
# This may not be totally correct (e.g., we might not be using all experts)
|
115 |
+
# but it's better than nothing.
|
116 |
+
n_w_bytes = W.numel() * W.element_size()
|
117 |
+
else:
|
118 |
+
n_tokens = None
|
119 |
+
n_w_bytes = 0
|
120 |
+
|
121 |
+
# If annotation is given, use that to generate name for profiling.
|
122 |
+
tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
|
123 |
+
n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows
|
124 |
+
else:
|
125 |
+
n_tokens = None
|
126 |
+
n_w_bytes = W.numel() * W.element_size()
|
127 |
+
repr = lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}"
|
128 |
+
nbits = X.dtype.itemsize * 8
|
129 |
+
batch_repr = ""
|
130 |
+
if "batch_size" in args and args["batch_size"] > 1:
|
131 |
+
batch_repr = repr("B", args["batch_size"]) + ", "
|
132 |
+
ret["name"] = f"{kernel.name} [{batch_repr}{repr('M', M)}, {repr('N', N)}, {repr('K', K)}] stg{kernel.num_stages}"
|
133 |
+
ep_subtile = args["EPILOGUE_SUBTILE"]
|
134 |
+
if ep_subtile is not None and ep_subtile > 1:
|
135 |
+
ret["name"] += f" ep/{ep_subtile}"
|
136 |
+
|
137 |
+
if hist is not None and n_tokens is None:
|
138 |
+
return ret # Don't fill metadata because we can't compute them properly.
|
139 |
+
|
140 |
+
fM = M if M is not None else n_tokens
|
141 |
+
fK = K if K is not None else n_tokens
|
142 |
+
ret[f"flops{nbits}"] = 2.0 * fM * N * fK
|
143 |
+
|
144 |
+
gindx = args.get("GatherIndx", None)
|
145 |
+
# sindx = args.get("WriteBackIndx", None)
|
146 |
+
n_x_bytes = X.numel() * X.element_size()
|
147 |
+
n_y_bytes = Y.numel() * Y.element_size()
|
148 |
+
if hist is not None:
|
149 |
+
assert n_tokens is not None
|
150 |
+
n_expts_act = args["N_EXPTS_ACT"]
|
151 |
+
|
152 |
+
if (gindx is not None) and launch_metadata_allow_sync():
|
153 |
+
# recreate inverse GatherIndx.
|
154 |
+
dst = torch.full_like(gindx, -1)
|
155 |
+
idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
|
156 |
+
mask = (gindx != -1)
|
157 |
+
dst[gindx[mask]] = idx[mask]
|
158 |
+
n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum()
|
159 |
+
else:
|
160 |
+
n_read_rows = n_tokens
|
161 |
+
n_x_bytes = n_read_rows * X.shape[-1] * X.element_size()
|
162 |
+
n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size()
|
163 |
+
ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes)
|
164 |
+
|
165 |
+
return ret
|
build/torch-universal/triton_kernels/matmul_ogs_details/_finalize_matmul.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import triton
|
2 |
+
import triton.language as tl
|
3 |
+
from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale, update_scale
|
4 |
+
from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
|
5 |
+
from triton_kernels.target_info import cuda_capability_geq as _cuda_capability_geq
|
6 |
+
from triton_kernels.target_info import is_hip as _is_hip
|
7 |
+
|
8 |
+
|
9 |
+
# fmt: off
|
10 |
+
@tl.constexpr_function
|
11 |
+
def is_hip():
|
12 |
+
return _is_hip()
|
13 |
+
|
14 |
+
|
15 |
+
@tl.constexpr_function
|
16 |
+
def cuda_capability_geq(x, y):
|
17 |
+
return _cuda_capability_geq(x, y)
|
18 |
+
|
19 |
+
|
20 |
+
@tl.constexpr_function
|
21 |
+
def log2(n):
|
22 |
+
return len(bin(n)) - 3
|
23 |
+
|
24 |
+
|
25 |
+
@tl.constexpr_function
|
26 |
+
def _permute_to_end_order(n: int, axis: int):
|
27 |
+
"""
|
28 |
+
Returns the order of the axes of a tensor to permute `axis` to the end.
|
29 |
+
"""
|
30 |
+
order = tuple(range(n))
|
31 |
+
return order[:axis] + order[(axis + 1):] + (axis, )
|
32 |
+
|
33 |
+
|
34 |
+
@triton.jit
|
35 |
+
def permute_to_end(x, axis: tl.constexpr):
|
36 |
+
"""
|
37 |
+
Permutes `x` so that `axis` is the last axis.
|
38 |
+
"""
|
39 |
+
N: tl.constexpr = len(x.shape)
|
40 |
+
return tl.permute(x, _permute_to_end_order(N, axis).value)
|
41 |
+
|
42 |
+
|
43 |
+
@triton.jit
|
44 |
+
def split_n(x, N: tl.constexpr):
|
45 |
+
"""
|
46 |
+
Given `x`, a tensor of shape AxB...x2x2...x2, split it N times.
|
47 |
+
Return a tuple of the results.
|
48 |
+
"""
|
49 |
+
xs = (x, )
|
50 |
+
for i in tl.static_range(N):
|
51 |
+
next = tl.split(xs[0])
|
52 |
+
for j in tl.static_range(2**i - 1):
|
53 |
+
next = next + tl.split(xs[j + 1])
|
54 |
+
xs = next
|
55 |
+
return xs
|
56 |
+
|
57 |
+
|
58 |
+
@triton.jit
|
59 |
+
def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr = None, NUM_THREADS: tl.constexpr = None):
|
60 |
+
N: tl.constexpr = tl.extra.cuda.num_threads() if NUM_THREADS is None else NUM_THREADS
|
61 |
+
BS: tl.constexpr = x.numel if BLOCK_SIZE is None else BLOCK_SIZE
|
62 |
+
tl.static_assert(BS % N == 0, "BLOCK_SIZE must be divisible by NUM_THREADS")
|
63 |
+
return tl.max(tl.reshape(tl.abs(x), [N, BS // N], can_reorder=True), axis=1)
|
64 |
+
|
65 |
+
|
66 |
+
def _finalize_matmul_launch_metadata(grid, kernel, args):
|
67 |
+
ret = dict()
|
68 |
+
Out, A, ScatterSrcIndx, FinalizeScatterIdxs, K, M, N, EXPT_PER_TOK, NumRows = [
|
69 |
+
args[name]
|
70 |
+
for name in ["Out", "A", "ScatterSrcIndx", "FinalizeScatterIdxs", "K", "M", "N", "EXPT_PER_TOK", "NumRows"]
|
71 |
+
]
|
72 |
+
ret["name"] = f"{kernel.name} [M={M}x{EXPT_PER_TOK} {N=} {K=}]"
|
73 |
+
|
74 |
+
if FinalizeScatterIdxs is not None:
|
75 |
+
M = FinalizeScatterIdxs[-1].item()
|
76 |
+
|
77 |
+
if ScatterSrcIndx is not None:
|
78 |
+
is_active = (ScatterSrcIndx != -1).view((-1, EXPT_PER_TOK))
|
79 |
+
n_active = is_active.sum(dim=1)
|
80 |
+
need_accum = n_active >= (1 if K > 1 else 2)
|
81 |
+
is_active &= need_accum[:, None]
|
82 |
+
active_input_rows = is_active.sum()
|
83 |
+
active_output_rows = need_accum.sum()
|
84 |
+
if EXPT_PER_TOK > 1:
|
85 |
+
# Masked rows are set to zero.
|
86 |
+
active_output_rows += (n_active == 0).sum()
|
87 |
+
else:
|
88 |
+
if NumRows is not None:
|
89 |
+
if isinstance(NumRows, int):
|
90 |
+
active_input_rows = NumRows
|
91 |
+
else:
|
92 |
+
active_input_rows = NumRows.item()
|
93 |
+
else:
|
94 |
+
active_input_rows = M
|
95 |
+
active_output_rows = M
|
96 |
+
|
97 |
+
ret["bytes"] = (active_input_rows * K * A.shape[-1] * A.element_size() +
|
98 |
+
active_output_rows * Out.shape[-1] * Out.element_size())
|
99 |
+
if FinalizeScatterIdxs is not None:
|
100 |
+
ret["bytes"] += FinalizeScatterIdxs.numel() * FinalizeScatterIdxs.element_size()
|
101 |
+
elif ScatterSrcIndx is not None and EXPT_PER_TOK > 1:
|
102 |
+
ret["bytes"] += ScatterSrcIndx.numel() * ScatterSrcIndx.element_size()
|
103 |
+
nbits = Out.dtype.itemsize * 8
|
104 |
+
ret[f"flops{nbits}"] = active_input_rows * K * A.shape[-1]
|
105 |
+
return ret
|
106 |
+
|
107 |
+
|
108 |
+
@tl.constexpr_function
|
109 |
+
def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None):
|
110 |
+
"""
|
111 |
+
Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision
|
112 |
+
adds. If `absmax_reg_name` is provided, the absolute maximum value seen so far is tracked inside
|
113 |
+
that register.
|
114 |
+
|
115 |
+
Generates code something like:
|
116 |
+
|
117 |
+
add.f32.f16 $0, $2, $1;
|
118 |
+
add.f32.f16 $0, $3, $0;
|
119 |
+
add.f32.f16 $0, $4, $0;
|
120 |
+
add.f32.f16 $0, $5, $0;
|
121 |
+
|
122 |
+
.reg .f32 b;
|
123 |
+
abs.f32 b, $0;
|
124 |
+
max.f32 my_abs_max, my_abs_max, b;
|
125 |
+
"""
|
126 |
+
# Add the first f16 value to the input $1, store into the output $0.
|
127 |
+
ptx = f"\nadd.f32.{src_type} $0, $2, $1;"
|
128 |
+
# Accumulate the rest of the inputs into the output $0.
|
129 |
+
for i in range(1, n_inputs):
|
130 |
+
ptx += f"\nadd.f32.{src_type} $0, ${2 + i}, $0;"
|
131 |
+
if absmax_reg_name is not None:
|
132 |
+
# Update `absmax_reg_name` with the absolute maximum value seen so far.
|
133 |
+
ptx += f"""
|
134 |
+
.reg .f32 b;
|
135 |
+
abs.f32 b, $0;
|
136 |
+
max.f32 {absmax_reg_name}, {absmax_reg_name}, b;
|
137 |
+
"""
|
138 |
+
# Return the PTX snippet, brace-enclosed so we don't pollute the global namespace.
|
139 |
+
return f"{{{ptx}}}"
|
140 |
+
|
141 |
+
|
142 |
+
@triton.jit
|
143 |
+
def _mixed_precision_accumulate_and_track_absmax(acc, x, axis: tl.constexpr, absmax_reg_name: tl.constexpr = None):
|
144 |
+
"""Given an fp8/bf16/fp16 tensor, accumulate into `acc` along `axis`.
|
145 |
+
Values are first converted to bf16/fp16, packed into 32-bit registers, and then accumulated using
|
146 |
+
mixed-precision adds.
|
147 |
+
|
148 |
+
If `absmax_reg_name` is provided, the absolute maximum value seen so far is tracked inside that
|
149 |
+
register.
|
150 |
+
"""
|
151 |
+
REDUCTION_SIZE: tl.constexpr = x.shape[axis]
|
152 |
+
tl.static_assert(2**log2(REDUCTION_SIZE) == REDUCTION_SIZE,
|
153 |
+
f"Reduction size must be a power of 2, was {REDUCTION_SIZE}")
|
154 |
+
# move `axis` to the last axis and reshape for iterative splitting.
|
155 |
+
x = permute_to_end(x, axis)
|
156 |
+
x = tl.reshape(x, x.shape[:-1] + (2, ) * log2(REDUCTION_SIZE))
|
157 |
+
# Split into a tuple of AxB tensors.
|
158 |
+
xs = split_n(x, log2(REDUCTION_SIZE))
|
159 |
+
if (tl.constexpr(x.dtype == tl.float8e4nv) or tl.constexpr(x.dtype == tl.float8e5)):
|
160 |
+
# Convert fp8 to fp16.
|
161 |
+
fp16_xs = ()
|
162 |
+
for i in tl.static_range(len(xs)):
|
163 |
+
fp16_xs += (xs[i].to(tl.float16), )
|
164 |
+
xs = fp16_xs
|
165 |
+
src_type: tl.constexpr = "f16"
|
166 |
+
elif x.dtype == tl.float16:
|
167 |
+
src_type: tl.constexpr = "f16"
|
168 |
+
elif x.dtype == tl.bfloat16:
|
169 |
+
src_type: tl.constexpr = "bf16"
|
170 |
+
else:
|
171 |
+
tl.static_assert(False, f"Unsupported dtype: {x.dtype}")
|
172 |
+
return tl.inline_asm_elementwise(
|
173 |
+
_accumulate_f16_into_f32_and_track_absmax_ptx(REDUCTION_SIZE, src_type, absmax_reg_name),
|
174 |
+
"=r,r" + (",h" * len(xs)),
|
175 |
+
(acc, ) + xs,
|
176 |
+
dtype=tl.float32,
|
177 |
+
is_pure=True,
|
178 |
+
pack=1,
|
179 |
+
)
|
180 |
+
|
181 |
+
|
182 |
+
def _finalize_matmul_repr(specialization):
|
183 |
+
signature = specialization.signature
|
184 |
+
suffix = "" if "ScatterSrcIndx" in specialization.constants else "_scatter"
|
185 |
+
return f"_finalize_matmul{suffix}_{signature['A'][1:]}"
|
186 |
+
|
187 |
+
|
188 |
+
@triton.jit(repr=_finalize_matmul_repr, launch_metadata=_finalize_matmul_launch_metadata)
|
189 |
+
def _finalize_matmul(
|
190 |
+
Out,
|
191 |
+
OutExpectedScale,
|
192 |
+
OutActualScale,
|
193 |
+
OutChecksumScale,
|
194 |
+
stride_out_mx_m, stride_out_mx_n,
|
195 |
+
A,
|
196 |
+
stride_a_k,
|
197 |
+
stride_a_m,
|
198 |
+
AScale,
|
199 |
+
stride_a_mx_k,
|
200 |
+
stride_a_mx_m,
|
201 |
+
ScatterSrcIndx,
|
202 |
+
FinalizeScatterIdxs,
|
203 |
+
K: tl.constexpr,
|
204 |
+
M,
|
205 |
+
N,
|
206 |
+
NumRows,
|
207 |
+
# fused activation function
|
208 |
+
ACTIVATION_FN: tl.constexpr,
|
209 |
+
activation_fn_args,
|
210 |
+
ACTIVATION_REDUCTION_N: tl.constexpr,
|
211 |
+
# epilogue transform
|
212 |
+
EPILOGUE_FN: tl.constexpr,
|
213 |
+
epilogue_fn_args,
|
214 |
+
EXPT_PER_TOK: tl.constexpr,
|
215 |
+
flexpoint_saturate_inf: tl.constexpr,
|
216 |
+
BLOCK_N: tl.constexpr,
|
217 |
+
STAGES: tl.constexpr,
|
218 |
+
HAS_FUSED_SCRATCHPAD: tl.constexpr,
|
219 |
+
):
|
220 |
+
IN_MXFP8: tl.constexpr = stride_a_mx_k is not None
|
221 |
+
OUT_MXFP8: tl.constexpr = stride_out_mx_m is not None
|
222 |
+
if HAS_FUSED_SCRATCHPAD:
|
223 |
+
# Bump A to the scratchpad region.
|
224 |
+
A += tl.cast(M, tl.int64) * stride_a_m
|
225 |
+
|
226 |
+
USE_FUSED_MIXED_PREC_ACC: tl.constexpr = (cuda_capability_geq(10, 0)
|
227 |
+
and tl.constexpr(A.dtype.element_ty != tl.float32))
|
228 |
+
USE_FUSED_ABSMAX: tl.constexpr = (USE_FUSED_MIXED_PREC_ACC and OutActualScale is not None) and ACTIVATION_FN is None
|
229 |
+
|
230 |
+
THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
|
231 |
+
local_max = tl.full([THREADS_PER_BLOCK], 0.0, tl.float32)
|
232 |
+
if USE_FUSED_ABSMAX:
|
233 |
+
local_max = tl.inline_asm_elementwise(
|
234 |
+
"""
|
235 |
+
.reg .f32 my_abs_max;
|
236 |
+
mov.b32 my_abs_max, 0;
|
237 |
+
mov.b32 $0, 0;
|
238 |
+
""", "=r,r", [local_max], dtype=tl.float32, is_pure=False, pack=1)
|
239 |
+
|
240 |
+
out_scale = load_scale(OutExpectedScale)
|
241 |
+
a_scale = load_scale(AScale)
|
242 |
+
|
243 |
+
if FinalizeScatterIdxs is not None:
|
244 |
+
MBound = tl.load(FinalizeScatterIdxs + M + M * EXPT_PER_TOK)
|
245 |
+
if tl.program_id(0) >= MBound:
|
246 |
+
return
|
247 |
+
else:
|
248 |
+
MBound = M
|
249 |
+
|
250 |
+
if NumRows is not None:
|
251 |
+
NumRows = NumRows # remove constexpr
|
252 |
+
if NumRows.dtype.is_ptr():
|
253 |
+
NumRows = tl.load(NumRows)
|
254 |
+
|
255 |
+
if FinalizeScatterIdxs is not None or (ScatterSrcIndx is not None and EXPT_PER_TOK > 1):
|
256 |
+
n_active_experts = 0
|
257 |
+
else:
|
258 |
+
n_active_experts: tl.constexpr = EXPT_PER_TOK
|
259 |
+
|
260 |
+
OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
|
261 |
+
outN = N // ACTIVATION_REDUCTION_N
|
262 |
+
|
263 |
+
for pid_m in tl.range(tl.program_id(0), MBound, tl.num_programs(0)):
|
264 |
+
src_offs = pid_m * EXPT_PER_TOK + tl.arange(0, EXPT_PER_TOK)
|
265 |
+
if FinalizeScatterIdxs is not None:
|
266 |
+
row = tl.load(FinalizeScatterIdxs + pid_m)
|
267 |
+
src_idxs = tl.load(FinalizeScatterIdxs + M + src_offs)
|
268 |
+
n_active_experts = tl.sum((src_idxs != -1).to(tl.int32))
|
269 |
+
elif ScatterSrcIndx is not None and EXPT_PER_TOK > 1:
|
270 |
+
row = pid_m
|
271 |
+
src_idxs = tl.load(ScatterSrcIndx + src_offs)
|
272 |
+
n_active_experts = tl.sum((src_idxs != -1).to(tl.int32))
|
273 |
+
else:
|
274 |
+
row = pid_m
|
275 |
+
src_idxs = src_offs
|
276 |
+
if NumRows is not None:
|
277 |
+
src_idxs = tl.where(src_idxs < NumRows, src_idxs, -1)
|
278 |
+
|
279 |
+
if n_active_experts == 0:
|
280 |
+
for off_n in tl.range(tl.program_id(1) * OUT_BLOCK_N, outN, tl.num_programs(1) * OUT_BLOCK_N):
|
281 |
+
offs_n = off_n + tl.arange(0, OUT_BLOCK_N)
|
282 |
+
n_mask = offs_n < outN
|
283 |
+
tl.store(Out + row * outN + offs_n, tl.zeros([OUT_BLOCK_N], dtype=Out.dtype.element_ty), mask=n_mask)
|
284 |
+
else:
|
285 |
+
for off_n in tl.range(tl.program_id(1) * BLOCK_N, N, tl.num_programs(1) * BLOCK_N, num_stages=STAGES):
|
286 |
+
offs_n = off_n + tl.arange(0, BLOCK_N)
|
287 |
+
n_mask = offs_n < N
|
288 |
+
if IN_MXFP8:
|
289 |
+
MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE
|
290 |
+
N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
|
291 |
+
offs_n_scale = off_n // BLOCK_N * MX_SCALE_BLOCK_N + tl.arange(0, MX_SCALE_BLOCK_N)[None, :]
|
292 |
+
n_mask_scale = offs_n_scale < N_MX_BLOCK
|
293 |
+
|
294 |
+
acc = tl.zeros([BLOCK_N], dtype=tl.float32)
|
295 |
+
if is_hip():
|
296 |
+
if EXPT_PER_TOK > 1:
|
297 |
+
src_idxs_tup = split_n(tl.reshape(src_idxs, (2, ) * log2(EXPT_PER_TOK)), log2(EXPT_PER_TOK))
|
298 |
+
else:
|
299 |
+
# Convert 1D tensor to 1D tuple.
|
300 |
+
src_idxs_tup = tl.split(tl.reshape(tl.join(src_idxs, src_idxs), (2, )))[:1]
|
301 |
+
for i in tl.static_range(0, EXPT_PER_TOK, 1):
|
302 |
+
src_idx = src_idxs_tup[i]
|
303 |
+
if src_idx != -1:
|
304 |
+
As = A + src_idx.to(tl.int64) * stride_a_m + offs_n
|
305 |
+
for ki in tl.static_range(K):
|
306 |
+
acc += tl.load(As, mask=n_mask, other=0.0)
|
307 |
+
As += stride_a_k
|
308 |
+
else:
|
309 |
+
As = A + src_idxs.to(tl.int64)[:, None] * stride_a_m + offs_n[None, :]
|
310 |
+
if IN_MXFP8:
|
311 |
+
AScales = AScale + src_idxs.to(tl.int64)[:, None] * stride_a_mx_m + offs_n_scale[None, :]
|
312 |
+
for ki in tl.static_range(K):
|
313 |
+
a = tl.load(As, mask=(src_idxs != -1)[:, None] & n_mask[None, :], other=0.0)
|
314 |
+
As += stride_a_k
|
315 |
+
if IN_MXFP8:
|
316 |
+
a_mx_scale = tl.load(AScales, mask=(src_idxs != -1)[:, None] & n_mask_scale[None, :])
|
317 |
+
AScales += stride_a_mx_k
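# The MX scale loaded above is E8M0, i.e. a bare exponent with bias 127;
# shifting the byte into the float32 exponent field (bits 30-23) and bitcasting
# reinterprets it as 2**(e - 127) for exponents in the normal range.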
|
318 |
+
a_mx_scale = (a_mx_scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
|
319 |
+
a_mx_scale = a_mx_scale.reshape([EXPT_PER_TOK, MX_SCALE_BLOCK_N, 1])
|
320 |
+
a = a.to(tl.float32).reshape([EXPT_PER_TOK, MX_SCALE_BLOCK_N, MXFP_BLOCK_SIZE])
|
321 |
+
a = (a_mx_scale * a).reshape([EXPT_PER_TOK, BLOCK_N])
|
322 |
+
acc += tl.sum(a, dtype=tl.float32, axis=0)
|
323 |
+
elif USE_FUSED_MIXED_PREC_ACC:
|
324 |
+
acc = _mixed_precision_accumulate_and_track_absmax(
|
325 |
+
acc, a, axis=0,
|
326 |
+
absmax_reg_name="my_abs_max" if USE_FUSED_ABSMAX and ki == K - 1 else None)
|
327 |
+
else:
|
328 |
+
acc += tl.sum(a, dtype=tl.float32, axis=0)
|
329 |
+
if not IN_MXFP8:
|
330 |
+
acc = acc * a_scale
|
331 |
+
if ACTIVATION_FN is not None:
|
332 |
+
out = ACTIVATION_FN(tl.reshape(acc, (1, BLOCK_N)), *activation_fn_args)
|
333 |
+
out = tl.reshape(out, (OUT_BLOCK_N, ))
|
334 |
+
else:
|
335 |
+
tl.static_assert(ACTIVATION_REDUCTION_N == 1,
|
336 |
+
"Activation reduction must be 1 if no activation fn is provided")
|
337 |
+
out = acc
|
338 |
+
if not USE_FUSED_ABSMAX and OutActualScale is not None:
|
339 |
+
local_max = tl.maximum(local_max, thread_local_absmax(out))
|
340 |
+
if OUT_MXFP8:
|
341 |
+
OUT_MX_SCALE_BLOCK_N: tl.constexpr = OUT_BLOCK_N // MXFP_BLOCK_SIZE
|
342 |
+
OUT_N_MX_BLOCK: tl.constexpr = (outN + MXFP_BLOCK_SIZE - 1) // MXFP_BLOCK_SIZE
|
343 |
+
offs_n_scale = off_n // BLOCK_N * OUT_MX_SCALE_BLOCK_N + tl.arange(0, OUT_MX_SCALE_BLOCK_N)[None, :]
|
344 |
+
n_mask_scale = offs_n_scale < OUT_N_MX_BLOCK
|
345 |
+
acc, acc_scale = EPILOGUE_FN(acc[None, :], n_mask[None, :], *epilogue_fn_args,
|
346 |
+
pid=row * tl.num_programs(1) + tl.program_id(1))
|
347 |
+
tl.static_assert(OUT_BLOCK_N % OUT_MX_SCALE_BLOCK_N == 0, "")
|
348 |
+
tl.store(OutActualScale + row * stride_out_mx_m + offs_n_scale * stride_out_mx_n, acc_scale, mask=n_mask_scale)
|
349 |
+
tl.store(Out + row * outN + offs_n[None, :], acc, mask=n_mask[None, :])
|
350 |
+
else:
|
351 |
+
out = float_to_flex(out, out_scale if OutExpectedScale is not None else None, None, OutChecksumScale,
|
352 |
+
None, Out, flexpoint_saturate_inf)
|
353 |
+
if EPILOGUE_FN is not None:
|
354 |
+
out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=Out.dtype.element_ty,
|
355 |
+
pid=row * tl.num_programs(1) + tl.program_id(1))
|
356 |
+
offs_n = off_n // ACTIVATION_REDUCTION_N + tl.arange(0, OUT_BLOCK_N)
|
357 |
+
n_mask = offs_n < outN
|
358 |
+
tl.store(Out + row * outN + offs_n, out, mask=n_mask)
|
359 |
+
|
360 |
+
persistent_m = tl.num_programs(0) < MBound
|
361 |
+
if not persistent_m and n_active_experts == 0:
|
362 |
+
# Skip updating the scale if there were no active experts and this is a non-persistent launch.
|
363 |
+
# The loop ran only once, and inside it we only stored zeros.
|
364 |
+
return
|
365 |
+
|
366 |
+
if USE_FUSED_ABSMAX:
|
367 |
+
local_max = tl.inline_asm_elementwise(
|
368 |
+
"mov.b32 $0, my_abs_max;",
|
369 |
+
"=r,r",
|
370 |
+
[local_max],
|
371 |
+
dtype=tl.float32,
|
372 |
+
is_pure=True,
|
373 |
+
pack=1,
|
374 |
+
)
|
375 |
+
local_max *= a_scale
|
376 |
+
if not OUT_MXFP8:
|
377 |
+
update_scale(local_max, OutActualScale, Out)
|
build/torch-universal/triton_kernels/matmul_ogs_details/_matmul_ogs.py
ADDED
@@ -0,0 +1,464 @@
1 |
+
# isort: off
|
2 |
+
# fmt: off
|
3 |
+
import triton
|
4 |
+
import triton.language as tl
|
5 |
+
from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
|
6 |
+
from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
|
7 |
+
from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
|
8 |
+
from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
|
9 |
+
from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
|
10 |
+
from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
|
11 |
+
|
12 |
+
|
13 |
+
@triton.jit
|
14 |
+
def _zero_masked_rows(
|
15 |
+
pid_m, pid_n,
|
16 |
+
Y, stride_y_m, stride_y_n,
|
17 |
+
N,
|
18 |
+
ScatterSrcIndx, num_idxs,
|
19 |
+
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
|
20 |
+
offs_m = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M)
|
21 |
+
offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
|
22 |
+
src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
|
23 |
+
YPtrs = Y + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
|
24 |
+
mask_n = offs_n < N
|
25 |
+
mask = (src_idx == -1)[:, None] & mask_n[None, :]
|
26 |
+
tl.store(YPtrs, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask)
|
27 |
+
|
28 |
+
|
29 |
+
_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2])
|
30 |
+
@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
|
31 |
+
repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
|
32 |
+
def _matmul_ogs(
|
33 |
+
Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
|
34 |
+
YExpectedScale, YActualScale, YChecksumScale,
|
35 |
+
stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
|
36 |
+
X, XPtr, stride_x_z, stride_x_m, stride_x_k,
|
37 |
+
XScale,
|
38 |
+
XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
|
39 |
+
W, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
|
40 |
+
WScale,
|
41 |
+
WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n,
|
42 |
+
B, stride_b_e, # Bias
|
43 |
+
NRows, M, N, K, # shapes
|
44 |
+
# expt data
|
45 |
+
Betas, Gammas,
|
46 |
+
GatherIndx,
|
47 |
+
ScatterSrcIndx, num_idxs,
|
48 |
+
WriteBackIndx, writeback_size,
|
49 |
+
ExptHist, ExptOffs, ExptOffsSum, ExptData,
|
50 |
+
# true grid size
|
51 |
+
batch_size, grid_m, grid_n,
|
52 |
+
# Out scale
|
53 |
+
out_alpha,
|
54 |
+
# fused activation function
|
55 |
+
ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
|
56 |
+
# epilogue transform
|
57 |
+
EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
|
58 |
+
# MoE config
|
59 |
+
N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
|
60 |
+
# precision config
|
61 |
+
MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
|
62 |
+
FLEXPOINT_SATURATE_INF: tl.constexpr,
|
63 |
+
PER_BATCH_SCALE: tl.constexpr,
|
64 |
+
# optimization config
|
65 |
+
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
66 |
+
GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
|
67 |
+
# One of ["HOPPER", "BLACKWELL", None]
|
68 |
+
SWIZZLE_MX_VALUE: tl.constexpr,
|
69 |
+
# One of ["HOPPER", "BLACKWELL", None]
|
70 |
+
SWIZZLE_MX_SCALE: tl.constexpr,
|
71 |
+
EPILOGUE_SUBTILE: tl.constexpr,
|
72 |
+
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
|
73 |
+
W_CACHE_MODIFIER: tl.constexpr,
|
74 |
+
NUM_SMS: tl.constexpr,
|
75 |
+
TOKENS_PER_EXPT_FOR_ANNOTATION=None,
|
76 |
+
UPCAST_INDICES: tl.constexpr = False,
|
77 |
+
DISABLE_Y_TMA: tl.constexpr = True,
|
78 |
+
SWAP_XW: tl.constexpr = False,
|
79 |
+
IS_EPILOGUE_DEQUANT_MXFP8: tl.constexpr = False):
|
80 |
+
|
81 |
+
Y = Out # Y is passed for the purposes of annotation; replace it with Out
|
82 |
+
is_w_microscaled: tl.constexpr = WMxScale is not None
|
83 |
+
MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
|
84 |
+
if is_w_microscaled:
|
85 |
+
w_type: tl.constexpr = W.dtype.element_ty
|
86 |
+
is_mxfp4: tl.constexpr = w_type == tl.uint8
|
87 |
+
tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
|
88 |
+
"mx_weight_ptr must be uint8 or fp8")
|
89 |
+
tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
|
90 |
+
tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
|
91 |
+
tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values")
|
92 |
+
else:
|
93 |
+
tl.static_assert(SWIZZLE_MX_VALUE is None)
|
94 |
+
tl.static_assert(SWIZZLE_MX_SCALE is None)
|
95 |
+
is_x_microscaled: tl.constexpr = XMxScale is not None
|
96 |
+
if is_x_microscaled:
|
97 |
+
x_type: tl.constexpr = X.dtype.element_ty
|
98 |
+
tl.static_assert(is_w_microscaled)
|
99 |
+
tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
|
100 |
+
tl.static_assert(XMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
|
101 |
+
tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
|
102 |
+
is_out_microscaled: tl.constexpr = stride_y_mx_z is not None
|
103 |
+
|
104 |
+
OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N
|
105 |
+
yN = N // ACTIVATION_REDUCTION_N
|
106 |
+
|
107 |
+
pid = tl.program_id(0)
|
108 |
+
if ExptOffsSum is not None and XCD_SWIZZLE > 1:
|
109 |
+
# Determine how much padding there is on the expert data. This allows us to
|
110 |
+
# know the true grid size and avoid processing padding tiles.
|
111 |
+
padding_m = grid_m - tl.load(ExptOffsSum)
|
112 |
+
else:
|
113 |
+
padding_m: tl.constexpr = 0
|
114 |
+
|
115 |
+
HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
|
116 |
+
index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32
|
117 |
+
|
118 |
+
total_actual_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K
|
119 |
+
if padding_m > 0 and pid >= total_actual_tiles:
|
120 |
+
tl.device_assert(batch_size == 0)
|
121 |
+
pid_mn = pid - total_actual_tiles
|
122 |
+
if pid_mn < padding_m * grid_n:
|
123 |
+
pid_m, pid_n = swizzle2d(pid_mn, padding_m, grid_n, GROUP_M)
|
124 |
+
|
125 |
+
# set masked out rows to 0
|
126 |
+
if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
|
127 |
+
_zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
|
128 |
+
return
|
129 |
+
|
130 |
+
# swizzle program ids
|
131 |
+
pid_emnk = pid
|
132 |
+
if XCD_SWIZZLE != 1:
|
133 |
+
pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE)
|
134 |
+
pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K)
|
135 |
+
pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K)
|
136 |
+
pid_k = pid_mnk % SPLIT_K
|
137 |
+
pid_mn = pid_mnk // SPLIT_K
|
138 |
+
pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M)
|
139 |
+
# For split-k, advance to the output k slice
|
140 |
+
if SPLIT_K > 1:
|
141 |
+
Y += pid_k.to( index_type) * stride_y_k
|
142 |
+
if is_out_microscaled:
|
143 |
+
YActualScale += pid_k.to(index_type) * stride_x_mx_k
|
144 |
+
# set masked out rows to 0
|
145 |
+
if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
|
146 |
+
_zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N)
|
147 |
+
# unpack expert data
|
148 |
+
if ExptData is None:
|
149 |
+
tl.static_assert(M is not None)
|
150 |
+
expt_id, start_z, start_m, block_id = pid_e, pid_e, 0, pid_m
|
151 |
+
else:
|
152 |
+
tl.static_assert(M is None)
|
153 |
+
expt_data = tl.load(ExptData + pid_m)
|
154 |
+
if expt_data == -1:
|
155 |
+
return
|
156 |
+
expt_id = expt_data & 0x0000FFFF
|
157 |
+
block_id = expt_data >> 16
|
158 |
+
M = tl.load(ExptHist + expt_id)
|
159 |
+
start_m = tl.load(ExptOffs + expt_id)
|
160 |
+
start_z = 0
|
161 |
+
expt_id, block_id = expt_id.to(index_type), block_id.to(index_type)
|
162 |
+
start_m, start_z = start_m.to(index_type), start_z.to(index_type)
|
163 |
+
pid_n, pid_k = pid_n.to(index_type), pid_k.to(index_type)
|
164 |
+
# A pointers
|
165 |
+
offs_x_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
|
166 |
+
offs_x_m = tl.max_contiguous(tl.multiple_of(offs_x_m % M, BLOCK_M), BLOCK_M)
|
167 |
+
X += start_z * stride_x_z
|
168 |
+
if GatherIndx is None:
|
169 |
+
X += start_m * stride_x_m
|
170 |
+
else:
|
171 |
+
GatherIndx += start_m
|
172 |
+
# no need to bounds-check here because `offs_x_m` wraps around the M dim
|
173 |
+
offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT
|
174 |
+
offs_k = BLOCK_K * pid_k + tl.arange(0, BLOCK_K)
|
175 |
+
XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k
|
176 |
+
|
177 |
+
# TODO: refactor if/else when triton front end improves
|
178 |
+
if is_w_microscaled:
|
179 |
+
if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
|
180 |
+
tl.static_assert(is_mxfp4, "Only mxfp4 is supported for HOPPER swizzling")
|
181 |
+
tl.static_assert(not is_x_microscaled)
|
182 |
+
# We pack 2 fp4 values into a byte, but we divide the dimension by 2
|
183 |
+
# when swizzling
|
184 |
+
W_K_DIVISOR: tl.constexpr = 1
|
185 |
+
W_K_MULTIPLIER: tl.constexpr = 2
|
186 |
+
W_N_DIVISOR: tl.constexpr = 4
|
187 |
+
else:
|
188 |
+
# We pack 2 fp4 values into a byte
|
189 |
+
W_K_DIVISOR: tl.constexpr = 2 if is_mxfp4 else 1
|
190 |
+
W_K_MULTIPLIER: tl.constexpr = 1
|
191 |
+
W_N_DIVISOR: tl.constexpr = 1
|
192 |
+
|
193 |
+
PACKED_BLOCK_K_W: tl.constexpr = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER
|
194 |
+
PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR
|
195 |
+
MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
|
196 |
+
|
197 |
+
WMxScale += expt_id * stride_w_mx_e
|
198 |
+
|
199 |
+
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
|
200 |
+
tl.static_assert(BLOCK_N % 128 == 0)
|
201 |
+
tl.static_assert(MX_SCALE_BLOCK_K % 4 == 0)
|
202 |
+
PACKED_MX_BLOCK: tl.constexpr = (MX_SCALE_BLOCK_K // 4) * 32 * 4 * 4
|
203 |
+
SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 128
|
204 |
+
stride_scale_k: tl.constexpr = 1
|
205 |
+
elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
|
206 |
+
n_warps: tl.constexpr = tl.extra.cuda.num_warps()
|
207 |
+
tl.static_assert(BLOCK_N % (2 * n_warps * 2 * 8) == 0)
|
208 |
+
tl.static_assert(MX_SCALE_BLOCK_K % 2 == 0)
|
209 |
+
PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32
|
210 |
+
SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 32
|
211 |
+
stride_scale_k = stride_w_mx_k
|
212 |
+
else:
|
213 |
+
PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K
|
214 |
+
SCALE_BLOCK_N: tl.constexpr = BLOCK_N
|
215 |
+
stride_scale_k = stride_w_mx_k
|
216 |
+
offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N
|
217 |
+
offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N)
|
218 |
+
# K dimension must be the last dimension for the scales
|
219 |
+
offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK)
|
220 |
+
WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n
|
221 |
+
else:
|
222 |
+
WMxScalePtrs = None
|
223 |
+
offs_k_scale = None
|
224 |
+
W_K_DIVISOR: tl.constexpr = 1
|
225 |
+
W_K_MULTIPLIER: tl.constexpr = 1
|
226 |
+
W_N_DIVISOR: tl.constexpr = 1
|
227 |
+
PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
|
228 |
+
PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N
|
229 |
+
|
230 |
+
# B pointers
|
231 |
+
offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W)
|
232 |
+
offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W)
|
233 |
+
|
234 |
+
if is_x_microscaled:
|
235 |
+
XMxScale += start_z.to(index_type) * stride_x_mx_z
|
236 |
+
if GatherIndx is None:
|
237 |
+
XMxScale += start_m * stride_x_mx_m
|
238 |
+
offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K)
|
239 |
+
XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k
|
240 |
+
else:
|
241 |
+
XMxScalePtrs = None
|
242 |
+
|
243 |
+
offs_w_k = PACKED_BLOCK_K_W * pid_k + tl.arange(0, PACKED_BLOCK_K_W)
|
244 |
+
W += expt_id * stride_w_e
|
245 |
+
WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n)
|
246 |
+
# compute output
|
247 |
+
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
248 |
+
for k in range(K, BLOCK_K * pid_k, -(BLOCK_K * SPLIT_K)):
|
249 |
+
if EVEN_K:
|
250 |
+
mask_k = tl.full([BLOCK_K], True, dtype=tl.int1)
|
251 |
+
mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1)
|
252 |
+
if is_w_microscaled and SWIZZLE_MX_SCALE is None:
|
253 |
+
mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1)
|
254 |
+
if is_x_microscaled:
|
255 |
+
mask_x_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
|
256 |
+
else:
|
257 |
+
mask_k = offs_k < k
|
258 |
+
mask_k_w = offs_w_k < ((k // W_K_DIVISOR) * W_K_MULTIPLIER)
|
259 |
+
if is_w_microscaled and SWIZZLE_MX_SCALE is None:
|
260 |
+
mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < k
|
261 |
+
if is_x_microscaled:
|
262 |
+
mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < k
|
263 |
+
|
264 |
+
x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
|
265 |
+
w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER)
|
266 |
+
if is_w_microscaled:
|
267 |
+
x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
|
268 |
+
w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
|
269 |
+
|
270 |
+
if is_x_microscaled:
|
271 |
+
x_scales = tl.load(XMxScalePtrs, mask=mask_x_k_scale[None, :])
|
272 |
+
elif x_format == "fp16" or x_format == "bf16":
|
273 |
+
x_scales: tl.constexpr = None
|
274 |
+
else:
|
275 |
+
# Scale of 1 in E8M0 format
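# (E8M0 is exponent-only with bias 127, so the byte value 127 decodes to 2**0 = 1.0.)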
|
276 |
+
x_scales = tl.full((BLOCK_M, MX_SCALE_BLOCK_K), 127, dtype=tl.uint8)
|
277 |
+
|
278 |
+
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
|
279 |
+
w_scales = unswizzle_mx_scale_bw(tl.load(WMxScalePtrs))
|
280 |
+
elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
|
281 |
+
# Handshake with the swizzling code
|
282 |
+
num_warps: tl.constexpr = tl.extra.cuda.num_warps()
|
283 |
+
w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps)
|
284 |
+
else:
|
285 |
+
w_scales = tl.load(WMxScalePtrs, mask=mask_k_scale[None, :])
|
286 |
+
|
287 |
+
if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
|
288 |
+
# Handshake with the swizzling code
|
289 |
+
tl.static_assert(x_format == "bf16")
|
290 |
+
tl.static_assert(w_format == "e2m1")
|
291 |
+
w = mxfp4_to_bf16_triton(w.trans(), w_scales, 1)
|
292 |
+
tl.static_assert(w.dtype == tl.bfloat16)
|
293 |
+
acc = acc.trans()
|
294 |
+
x = x.trans()
|
295 |
+
# w = w.trans()
|
296 |
+
acc = tl.dot(w, x, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
|
297 |
+
acc = acc.trans()
|
298 |
+
else:
|
299 |
+
acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True)
|
300 |
+
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
|
301 |
+
WMxScalePtrs += (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_w_mx_k
|
302 |
+
else:
|
303 |
+
WMxScalePtrs += (PACKED_MX_BLOCK * SPLIT_K) * stride_w_mx_k
|
304 |
+
if is_x_microscaled:
|
305 |
+
XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k
|
306 |
+
else:
|
307 |
+
acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
|
308 |
+
XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k
|
309 |
+
WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k
|
310 |
+
# bias + scale
|
311 |
+
offs_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M)
|
312 |
+
offs_y_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N)
|
313 |
+
mask_m = offs_m < M
|
314 |
+
mask_n = offs_y_n < N
|
315 |
+
if B is not None:
|
316 |
+
BPtrs = B + expt_id * stride_b_e + offs_y_n
|
317 |
+
if pid_k == 0:
|
318 |
+
bias = tl.load(BPtrs, mask=mask_n, other=0)
|
319 |
+
else:
|
320 |
+
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
|
321 |
+
else:
|
322 |
+
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
|
323 |
+
if Betas is not None:
|
324 |
+
betas = tl.load(Betas + start_m + offs_m, mask=mask_m, other=0.0)
|
325 |
+
else:
|
326 |
+
betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
|
327 |
+
if Gammas is not None:
|
328 |
+
gammas = tl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0)
|
329 |
+
else:
|
330 |
+
gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
|
331 |
+
# flexpoint
|
332 |
+
x_scale = load_scale(XScale)
|
333 |
+
if PER_BATCH_SCALE:
|
334 |
+
w_scale = load_scale(WScale + expt_id)
|
335 |
+
else:
|
336 |
+
w_scale = load_scale(WScale)
|
337 |
+
acc *= x_scale * w_scale
|
338 |
+
acc = acc + bias[None, :] * betas[:, None]
|
339 |
+
if out_alpha is not None:
|
340 |
+
acc *= out_alpha
|
341 |
+
if ACTIVATION_FN is not None:
|
342 |
+
out = ACTIVATION_FN(acc, *activation_fn_args)
|
343 |
+
tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
|
344 |
+
offs_y_n = OUT_BLOCK_N * pid_n + tl.arange(0, OUT_BLOCK_N)
|
345 |
+
mask_n = offs_y_n < yN
|
346 |
+
else:
|
347 |
+
tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
|
348 |
+
out = acc
|
349 |
+
out *= gammas[:, None]
|
350 |
+
# write-back
|
351 |
+
Y += start_z.to(index_type) * stride_y_z
|
352 |
+
if WriteBackIndx is not None:
|
353 |
+
WriteBackIndx += start_m
|
354 |
+
dst_idx = tl.load(WriteBackIndx + offs_m, mask=start_m + offs_m < writeback_size, other=-1)
|
355 |
+
mask_m = mask_m & (dst_idx != -1)
|
356 |
+
offs_y_m = dst_idx
|
357 |
+
else:
|
358 |
+
Y += start_m * stride_y_m
|
359 |
+
offs_y_m = offs_m
|
360 |
+
|
361 |
+
YPtrs = Y + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n
|
362 |
+
mask = mask_m[:, None] & mask_n[None, :]
|
363 |
+
if is_out_microscaled:
|
364 |
+
MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE
|
365 |
+
N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE)
|
366 |
+
tl.static_assert(EPILOGUE_FN is not None)
|
367 |
+
out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args)
|
368 |
+
tl.static_assert(BLOCK_N % MX_SCALE_BLOCK_N == 0, "")
|
369 |
+
offs_y_n_scale = MX_SCALE_BLOCK_N * pid_n + tl.arange(0, MX_SCALE_BLOCK_N)
|
370 |
+
mask_n_scale = offs_y_n_scale < N_MX_BLOCK
|
371 |
+
YActualScale += start_z.to(index_type) * stride_y_mx_z
|
372 |
+
if WriteBackIndx is None:
|
373 |
+
YActualScale += start_m * stride_y_mx_m
|
374 |
+
YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
|
375 |
+
else:
|
376 |
+
YActualScalePtrs = YActualScale + (offs_y_m - NRows).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n
|
377 |
+
tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :])
|
378 |
+
else:
|
379 |
+
out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF)
|
380 |
+
if EPILOGUE_FN is not None and not IS_EPILOGUE_DEQUANT_MXFP8:
|
381 |
+
out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty)
|
382 |
+
tl.store(YPtrs, out, mask=mask)
|
383 |
+
|
384 |
+
|
385 |
+
# Imagine N_EXPTS_ACT = 4, n_final_rows = 5, and n_scratchpad_rows = 8.
|
386 |
+
# Also imagine scatter_indx.src_indx is:
|
387 |
+
# (number of active experts per final row)
|
388 |
+
# -1 -1 0 -1 1
|
389 |
+
# -1 2 -1 -1 1
|
390 |
+
# 1 3 -1 -1 2
|
391 |
+
# -1 4 5 6 3
|
392 |
+
# -1 -1 -1 -1 0 (this row is unused)
|
393 |
+
#
|
394 |
+
# Then, row 0 and 1 can be written directly to the final tensor.
|
395 |
+
# In this case, WriteBackIndx looks like:
|
396 |
+
# [0] = 0 : intermediate row 0 is written directly to final row 0
|
397 |
+
# [1] = 5+1=6 : scratchpad starts at offset 5
|
398 |
+
# [2] = 1 : intermediate row 2 is written directly to final row 1
|
399 |
+
# [3] = 5+3=8
|
400 |
+
# [4] = 5+4=9
|
401 |
+
# [5] = 5+5=10
|
402 |
+
# [6] = 5+6=11
|
403 |
+
# [7] = -1 : unused (there are only seven intermediate rows)
|
404 |
+
@triton.jit
|
405 |
+
def _compute_writeback_idx(
|
406 |
+
WriteBackIndx,
|
407 |
+
FinalizeScatterIdxs,
|
408 |
+
ScatterDstIndx, ScatterSrcIndx,
|
409 |
+
n_final_rows, n_scratchpad_rows,
|
410 |
+
BLOCK_M: tl.constexpr,
|
411 |
+
N_EXPTS_ACT: tl.constexpr,
|
412 |
+
):
|
413 |
+
tl.static_assert(N_EXPTS_ACT > 1)
|
414 |
+
|
415 |
+
pid_m = tl.program_id(0)
|
416 |
+
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
417 |
+
mask_m = offs_m < n_scratchpad_rows
|
418 |
+
dst_idxs = tl.load(ScatterDstIndx + offs_m, mask=mask_m, other=-1)
|
419 |
+
# Load corresponding rows in ScatterSrcIndx.
|
420 |
+
mask = dst_idxs != -1
|
421 |
+
src_offs = (dst_idxs // N_EXPTS_ACT) * N_EXPTS_ACT
|
422 |
+
src_offs = src_offs[:, None] + tl.arange(0, N_EXPTS_ACT)[None, :]
|
423 |
+
src_idxs = tl.load(ScatterSrcIndx + src_offs, mask=mask[:, None], other=-1)
|
424 |
+
# Compute the number of actually active experts.
|
425 |
+
is_src_active = (src_idxs != -1).to(tl.int32)
|
426 |
+
has_one_active = tl.sum(is_src_active, axis=1) == 1
|
427 |
+
# Compute the writeback index.
|
428 |
+
wb_idx = tl.where(has_one_active, dst_idxs // N_EXPTS_ACT, n_final_rows + offs_m)
|
429 |
+
wb_idx = tl.where(mask, wb_idx, -1)
|
430 |
+
tl.store(WriteBackIndx + offs_m, wb_idx, mask=mask_m)
|
431 |
+
|
432 |
+
if pid_m >= ((n_final_rows + BLOCK_M - 1) // BLOCK_M):
|
433 |
+
return
|
434 |
+
|
435 |
+
mask_m = offs_m < n_final_rows
|
436 |
+
src_offs = offs_m[:, None] * N_EXPTS_ACT + tl.arange(0, N_EXPTS_ACT)[None, :]
|
437 |
+
src_idxs = tl.load(ScatterSrcIndx + src_offs, mask=mask_m[:, None], other=-1)
|
438 |
+
is_src_active = (src_idxs != -1).to(tl.int32)
|
439 |
+
num_src_active = tl.sum(is_src_active, axis=1)
|
440 |
+
|
441 |
+
need_finalize_scatter = mask_m & (num_src_active != 1)
|
442 |
+
finalize_scatter_count = tl.sum(need_finalize_scatter.to(tl.int32))
|
443 |
+
if finalize_scatter_count == 0:
|
444 |
+
return
|
445 |
+
pp_off = tl.atomic_add(FinalizeScatterIdxs + n_final_rows + n_scratchpad_rows, finalize_scatter_count)
|
446 |
+
|
447 |
+
# need_finalize_scatter = [1, 0, 0, 1, 1, 0, 1, 0, 1]
|
448 |
+
# arange = [0, 1, 2, 3, 4, 5, 6, 7, 8]
|
449 |
+
arange = tl.arange(0, BLOCK_M)
|
450 |
+
# idxs = [0, _, _, 3, 4, _, 6, _, 8]
|
451 |
+
last = BLOCK_M - 1
|
452 |
+
idxs = tl.where(need_finalize_scatter, arange, last)
|
453 |
+
# idxs = [0, 3, 4, 6, 8, _, _, _, _]
|
454 |
+
idxs = tl.sort(idxs)
|
455 |
+
# r = offs_m
|
456 |
+
# d = [r[0], r[3], r[4], r[6], r[8], r[-1], r[-1], r[-1], r[-1]]
|
457 |
+
d = tl.gather(offs_m, idxs, axis=0)
|
458 |
+
s = tl.gather(src_idxs, idxs.expand_dims(1).broadcast_to(src_idxs.shape), axis=0)
|
459 |
+
# store destination indices
|
460 |
+
Ptr = FinalizeScatterIdxs + pp_off
|
461 |
+
tl.store(Ptr + arange, d, mask=arange < finalize_scatter_count)
|
462 |
+
# store src indices
|
463 |
+
Ptr = FinalizeScatterIdxs + n_final_rows + pp_off * N_EXPTS_ACT
|
464 |
+
tl.store(Ptr + N_EXPTS_ACT * arange[:, None] + tl.arange(0, N_EXPTS_ACT)[None, :], s, mask=(arange < finalize_scatter_count)[:, None])
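# Illustrative sketch (not part of this file): a host-side reference, under a
# hypothetical name, that reproduces the WriteBackIndx values from the worked
# example in the comment above. Rows with exactly one active expert are written
# straight to the final tensor; everything else is redirected to the scratchpad
# region starting at offset n_final_rows.
def compute_writeback_idx_reference(scatter_src_indx, n_final_rows, n_scratchpad_rows):
    # scatter_src_indx: n_final_rows rows of N_EXPTS_ACT intermediate-row
    # indices, with -1 meaning "no expert".
    wb = [-1] * n_scratchpad_rows
    for final_row, src in enumerate(scatter_src_indx):
        active = [s for s in src if s != -1]
        for s in active:
            wb[s] = final_row if len(active) == 1 else n_final_rows + s
    return wb

scatter_src_indx = [
    [-1, -1,  0, -1],
    [-1,  2, -1, -1],
    [ 1,  3, -1, -1],
    [-1,  4,  5,  6],
    [-1, -1, -1, -1],
]
assert compute_writeback_idx_reference(scatter_src_indx, 5, 8) == [0, 6, 1, 8, 9, 10, 11, -1]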
|
build/torch-universal/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
ADDED
@@ -0,0 +1,505 @@
1 |
+
# isort: off
|
2 |
+
# fmt: off
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
import triton.language as tl
|
6 |
+
from triton_kernels import target_info
|
7 |
+
from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
|
8 |
+
from triton_kernels.numerics_details.flexpoint import (
|
9 |
+
float_to_flex,
|
10 |
+
load_scale,
|
11 |
+
nan_propagating_absmax_reduce,
|
12 |
+
compute_scale,
|
13 |
+
)
|
14 |
+
from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
|
15 |
+
from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string
|
16 |
+
|
17 |
+
|
18 |
+
@tl.constexpr_function
|
19 |
+
def cuda_capability_geq(major, minor):
|
20 |
+
return target_info.cuda_capability_geq(major, minor)
|
21 |
+
|
22 |
+
@tl.constexpr_function
|
23 |
+
def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype:
|
24 |
+
if isinstance(tensor_or_desc, tl.tensor):
|
25 |
+
return tensor_or_desc.dtype.element_ty
|
26 |
+
elif isinstance(tensor_or_desc, tl.tensor_descriptor):
|
27 |
+
return tensor_or_desc.dtype
|
28 |
+
else:
|
29 |
+
raise ValueError(f"Invalid type: {type(tensor_or_desc)}")
|
30 |
+
|
31 |
+
|
32 |
+
@triton.jit
|
33 |
+
def _tma_load_2d(desc, offs, transpose: tl.constexpr = False):
|
34 |
+
if len(desc.shape) == 2 and len(offs) == 3:
|
35 |
+
tl.device_assert(offs[0] == 0, "2D TMA load requires Z offset to be 0")
|
36 |
+
offs = offs[1:]
|
37 |
+
if transpose:
|
38 |
+
offs = offs[:-2] + [offs[-1], offs[-2]]
|
39 |
+
res = desc.load(offs)
|
40 |
+
res = tl.reshape(res, desc.block_shape[-2:])
|
41 |
+
if transpose:
|
42 |
+
res = tl.trans(res)
|
43 |
+
return res
|
44 |
+
|
45 |
+
|
46 |
+
# Helper function to recreate a TMA desc with the same fields, but with a new pointer and optional new shape.
|
47 |
+
@triton.jit
|
48 |
+
def _update_tensor_desc(desc, ptr, shape=None):
|
49 |
+
return tl.make_tensor_descriptor(
|
50 |
+
ptr,
|
51 |
+
shape=shape or desc.shape,
|
52 |
+
# last dim must be constexpr 1; reflecting the old descriptor drops the constexpr
|
53 |
+
strides=desc.strides[:-1] + [tl.constexpr(1)],
|
54 |
+
block_shape=desc.block_shape,
|
55 |
+
)
|
56 |
+
|
57 |
+
|
58 |
+
@triton.jit
|
59 |
+
def _load_tile_attrs(
|
60 |
+
tile_id, num_tiles, grid_m, grid_n, padding_m,
|
61 |
+
M, ExptData, ExptHist, ExptOffs,
|
62 |
+
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, SPLIT_K: tl.constexpr,
|
63 |
+
GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr):
|
64 |
+
# unpack and swizzle program ids
|
65 |
+
pid_emnk = tile_id
|
66 |
+
if XCD_SWIZZLE != 1:
|
67 |
+
pid_emnk = xcd_swizzle(pid_emnk, num_tiles // SPLIT_K, XCD_SWIZZLE)
|
68 |
+
pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K)
|
69 |
+
pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K)
|
70 |
+
if SPLIT_K > 1:
|
71 |
+
pid_k = pid_mnk % SPLIT_K
|
72 |
+
pid_mn = pid_mnk // SPLIT_K
|
73 |
+
else:
|
74 |
+
pid_k: tl.constexpr = 0
|
75 |
+
pid_mn = pid_mnk
|
76 |
+
pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M)
|
77 |
+
|
78 |
+
# unpack expert data
|
79 |
+
if ExptData is None:
|
80 |
+
tl.static_assert(M is not None)
|
81 |
+
expt_id, start_z, start_m, block_id, eM = pid_e, pid_e, 0, pid_m, -1
|
82 |
+
else:
|
83 |
+
tl.static_assert(M is None)
|
84 |
+
expt_data = tl.load(ExptData + pid_m)
|
85 |
+
expt_id = expt_data & 0x0000FFFF
|
86 |
+
block_id = expt_data >> 16
|
87 |
+
eM = tl.load(ExptHist + expt_id)
|
88 |
+
start_m = tl.load(ExptOffs + expt_id)
|
89 |
+
start_z = 0
|
90 |
+
|
91 |
+
off_m = BLOCK_M * block_id
|
92 |
+
off_n = BLOCK_N * pid_n
|
93 |
+
|
94 |
+
return expt_id, start_z, start_m, eM, off_m, off_n, pid_k
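# Illustrative sketch (not part of this file): each ExptData entry packs the
# expert id into the low 16 bits and the block id into the high bits (with -1
# reserved as a padding sentinel), so the two fields round-trip like this
# (hypothetical helper names):
def pack_expt_data(expt_id, block_id):
    return (block_id << 16) | (expt_id & 0xFFFF)

def unpack_expt_data(expt_data):
    return expt_data & 0x0000FFFF, expt_data >> 16  # (expt_id, block_id)

assert unpack_expt_data(pack_expt_data(expt_id=3, block_id=42)) == (3, 42)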
|
95 |
+
|
96 |
+
|
97 |
+
@triton.jit
|
98 |
+
def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
|
99 |
+
mask = mask & (offs < writeback_size)
|
100 |
+
offs = tl.load(WriteBackIndx + offs, mask=mask, other=-1)
|
101 |
+
mask = offs != -1
|
102 |
+
return (offs, mask)
|
103 |
+
|
104 |
+
|
105 |
+
_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2])
|
106 |
+
@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
|
107 |
+
repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
|
108 |
+
def _p_matmul_ogs(
|
109 |
+
Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
|
110 |
+
YExpectedScale, YActualScale, YChecksumScale,
|
111 |
+
stride_y_mx_z, stride_y_mx_m, stride_y_mx_n,
|
112 |
+
X, XPtr, stride_x_z, stride_x_m, stride_x_k,
|
113 |
+
XScale,
|
114 |
+
XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k,
|
115 |
+
W, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr,
|
116 |
+
WScale,
|
117 |
+
MxScale, stride_mx_e, stride_mx_k, stride_mx_n,
|
118 |
+
B, stride_b_e, # Bias
|
119 |
+
NRows, M, N, K, # shapes
|
120 |
+
# expt data
|
121 |
+
Betas, Gammas,
|
122 |
+
GatherIndx,
|
123 |
+
ScatterSrcIndx, num_idxs,
|
124 |
+
WriteBackIndx, writeback_size,
|
125 |
+
ExptHist, ExptOffs, ExptOffsSum, ExptData,
|
126 |
+
# true grid size
|
127 |
+
batch_size, grid_m, grid_n,
|
128 |
+
# Out scale
|
129 |
+
out_alpha,
|
130 |
+
# fused activation function
|
131 |
+
ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr,
|
132 |
+
# epilogue transform
|
133 |
+
EPILOGUE_FN: tl.constexpr, epilogue_fn_args,
|
134 |
+
# MoE config
|
135 |
+
N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
|
136 |
+
# precision config
|
137 |
+
MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr,
|
138 |
+
FLEXPOINT_SATURATE_INF: tl.constexpr,
|
139 |
+
PER_BATCH_SCALE: tl.constexpr,
|
140 |
+
# optimization config
|
141 |
+
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
|
142 |
+
GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr,
|
143 |
+
# NYI: Must be None
|
144 |
+
SWIZZLE_MX_VALUE: tl.constexpr,
|
145 |
+
# One of ["BLACKWELL", None]
|
146 |
+
SWIZZLE_MX_SCALE: tl.constexpr,
|
147 |
+
EPILOGUE_SUBTILE: tl.constexpr,
|
148 |
+
EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr,
|
149 |
+
W_CACHE_MODIFIER: tl.constexpr,
|
150 |
+
NUM_SMS: tl.constexpr,
|
151 |
+
TOKENS_PER_EXPT_FOR_ANNOTATION=None,
|
152 |
+
UPCAST_INDICES:tl.constexpr=False,
|
153 |
+
DISABLE_Y_TMA: tl.constexpr=False,
|
154 |
+
SWAP_XW: tl.constexpr = False,
|
155 |
+
IS_EPILOGUE_DEQUANT_MXFP8: tl.constexpr = False):
|
156 |
+
tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling")
|
157 |
+
Y = Out # Y is passed for the purposes of annotation; replace it with Out
|
158 |
+
|
159 |
+
is_microscaled_format: tl.constexpr = MxScale is not None
|
160 |
+
MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
|
161 |
+
if is_microscaled_format:
|
162 |
+
w_type: tl.constexpr = get_dtype(W)
|
163 |
+
tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
|
164 |
+
"mx_weight_ptr must be uint8")
|
165 |
+
tl.static_assert(get_dtype(MxScale) == tl.uint8, "mx_scale_ptr must be uint8")
|
166 |
+
tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
|
167 |
+
tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")
|
168 |
+
|
169 |
+
# We pack 2 fp4 values into a byte
|
170 |
+
W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1
|
171 |
+
PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR
|
172 |
+
MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
|
173 |
+
else:
|
174 |
+
W_PACK_DIVISOR: tl.constexpr = 1
|
175 |
+
MX_SCALE_BLOCK_K: tl.constexpr = 1
|
176 |
+
PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
|
177 |
+
tl.static_assert(SWIZZLE_MX_SCALE is None)
|
178 |
+
|
179 |
+
if ExptOffsSum is not None:
|
180 |
+
# Determine how much padding there is on the expert data. This allows us to
|
181 |
+
# know the true grid size and avoid processing padding tiles.
|
182 |
+
padding_m = grid_m - tl.load(ExptOffsSum)
|
183 |
+
else:
|
184 |
+
padding_m: tl.constexpr = 0
|
185 |
+
|
186 |
+
HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
|
187 |
+
index_type: tl.constexpr = tl.int64
|
188 |
+
|
189 |
+
if EPILOGUE_SUBTILE is None:
|
190 |
+
SUBTILE_FACTOR: tl.constexpr = 1
|
191 |
+
else:
|
192 |
+
SUBTILE_FACTOR: tl.constexpr = EPILOGUE_SUBTILE
|
193 |
+
EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // SUBTILE_FACTOR
|
194 |
+
OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N
|
195 |
+
yN = N // ACTIVATION_REDUCTION_N
|
196 |
+
|
197 |
+
# set masked out rows to 0
|
198 |
+
if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1:
|
199 |
+
# Iterate with reversed pids so that later pids will get more tiles if the number of
|
200 |
+
# tiles isn't evenly divisible by the number of SMs.
|
201 |
+
# The main loop after this iterates in the forward direction such that earlier
|
202 |
+
# pids get more tiles if the number of tiles isn't evenly divisible.
|
203 |
+
# This helps balance the work across the SMs.
|
204 |
+
for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS):
|
205 |
+
pid_k = pid_mnk % SPLIT_K
|
206 |
+
pid_mn = pid_mnk // SPLIT_K
|
207 |
+
pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M)
|
208 |
+
|
209 |
+
z = tl.zeros([BLOCK_M, BLOCK_N // ACTIVATION_REDUCTION_N], dtype=tl.float32)
|
210 |
+
offs_m = z.shape[0] * pid_m + tl.arange(0, z.shape[0])
|
211 |
+
offs_n = z.shape[1] * pid_n + tl.arange(0, z.shape[1])
|
212 |
+
src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0)
|
213 |
+
YPtrs = Y + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n
|
214 |
+
mask_n = offs_n < yN
|
215 |
+
mask = (src_idx == -1)[:, None] & mask_n[None, :]
|
216 |
+
tl.store(YPtrs + pid_k * stride_y_k, z, mask=mask)
|
217 |
+
|
218 |
+
USE_FLEXPOINT_SCALE: tl.constexpr = YActualScale is not None or YChecksumScale is not None
|
219 |
+
|
220 |
+
USE_GATHER_TMA: tl.constexpr = GatherIndx is not None and cuda_capability_geq(10, 0)
|
221 |
+
X_USE_LOAD_TMA: tl.constexpr = GatherIndx is None and isinstance(X, tl.tensor_descriptor)
|
222 |
+
USE_SCATTER_TMA: tl.constexpr = (cuda_capability_geq(10, 0) and HAS_FUSED_SCATTER) and not DISABLE_Y_TMA
|
223 |
+
INT_MAX: tl.constexpr = 2147483647
|
224 |
+
|
225 |
+
if USE_SCATTER_TMA:
|
226 |
+
y_desc = tl.make_tensor_descriptor(
|
227 |
+
Y,
|
228 |
+
# No masking on the M dimension because we manually mask by setting indices to INT_MAX
|
229 |
+
shape=[INT_MAX - 1, yN],
|
230 |
+
strides=[stride_y_m, stride_y_n],
|
231 |
+
block_shape=[1, OUT_BLOCK_N],
|
232 |
+
)
|
233 |
+
|
234 |
+
k_tiles = tl.cdiv(K, BLOCK_K * SPLIT_K)
|
235 |
+
num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K
|
236 |
+
|
237 |
+
# If true, do not share loop-carried variables between the prologue and the
|
238 |
+
# epilogue to enable better pipelining with mmav5
|
239 |
+
INDEPENDENT_EPILOGUE: tl.constexpr = cuda_capability_geq(10, 0)
|
240 |
+
|
241 |
+
# start negative; will be incremented at the top of the loop
|
242 |
+
if INDEPENDENT_EPILOGUE:
|
243 |
+
tile_id1 = tl.program_id(0) - NUM_SMS
|
244 |
+
|
245 |
+
# Keep track of local max for updating flexpoint scales.
|
246 |
+
THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads()
|
247 |
+
local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32)
|
248 |
+
|
249 |
+
DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256
|
250 |
+
# Enable warp specialization when all loads are TMA loads.
|
251 |
+
WARP_SPECIALIZE: tl.constexpr = (USE_GATHER_TMA or X_USE_LOAD_TMA)
|
252 |
+
|
253 |
+
for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=WARP_SPECIALIZE):
|
254 |
+
expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs(
|
255 |
+
tile_id, num_tiles, grid_m, grid_n, padding_m,
|
256 |
+
M, ExptData, ExptHist, ExptOffs,
|
257 |
+
BLOCK_M, BLOCK_N, SPLIT_K,
|
258 |
+
GROUP_M, XCD_SWIZZLE)
|
259 |
+
|
260 |
+
# Base pointers and offsets.
|
261 |
+
if not USE_GATHER_TMA and not X_USE_LOAD_TMA:
|
262 |
+
XBase = X + start_z.to(index_type) * stride_x_z
|
263 |
+
offs_x_k = tl.arange(0, BLOCK_K)[None, :] * stride_x_k
|
264 |
+
if SPLIT_K > 1:
|
265 |
+
offs_x_k += pid_k.to(index_type) * BLOCK_K * stride_x_k
|
266 |
+
|
267 |
+
if not X_USE_LOAD_TMA:
|
268 |
+
offs_m = off_m + tl.arange(0, BLOCK_M)
|
269 |
+
mask_m = offs_m < (M if M is not None else eM)
|
270 |
+
if USE_GATHER_TMA:
|
271 |
+
# Mask the gather indices and load -1 instead. TMA will handle OOB accesses.
|
272 |
+
if ExptData is None:
|
273 |
+
offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m)
|
274 |
+
# Bump rows to account for the Z offset.
|
275 |
+
offs_x_m += start_z * (stride_x_z // stride_x_m)
|
276 |
+
offs_x_m = tl.where(mask_m, offs_x_m, -1)
|
277 |
+
else:
|
278 |
+
offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m,
|
279 |
+
mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT
|
280 |
+
else:
|
281 |
+
if M is not None:
|
282 |
+
offs_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M)
|
283 |
+
else:
|
284 |
+
offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M)
|
285 |
+
# no need to bounds-check here because `offs_m` wraps around the M dim
|
286 |
+
offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT
|
287 |
+
offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m
|
288 |
+
|
289 |
+
acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32)
|
290 |
+
for ki in tl.range(k_tiles, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER):
|
291 |
+
off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K
|
292 |
+
off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K
|
293 |
+
off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K
|
294 |
+
|
295 |
+
if USE_GATHER_TMA:
|
296 |
+
x = X.gather(offs_x_m, off_k)
|
297 |
+
elif X_USE_LOAD_TMA:
|
298 |
+
x = _tma_load_2d(X, [start_z, start_m + off_m, off_k])
|
299 |
+
else:
|
300 |
+
XPtrs = XBase + offs_x_m + offs_x_k
|
301 |
+
XBase += BLOCK_K * SPLIT_K * stride_x_k
|
302 |
+
mask_k = tl.arange(0, BLOCK_K) < K - off_k
|
303 |
+
if EVEN_K:
|
304 |
+
if SPLIT_K > 1:
|
305 |
+
x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
|
306 |
+
else:
|
307 |
+
x = tl.load(XPtrs)
|
308 |
+
else:
|
309 |
+
x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0)
|
310 |
+
|
311 |
+
w = _tma_load_2d(W, [expt_id, off_k_w, off_n], transpose=W_TRANSPOSE)
|
312 |
+
|
313 |
+
if is_microscaled_format:
|
314 |
+
x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
|
315 |
+
mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
|
316 |
+
if x_format == "fp16" or x_format == "bf16":
|
317 |
+
x_scales: tl.constexpr = None
|
318 |
+
else:
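# Scale of 1 in E8M0 format (exponent-only, bias 127: byte 127 decodes to 1.0)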
|
319 |
+
x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8)
|
320 |
+
if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
|
321 |
+
flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128)
|
322 |
+
w_scales = MxScale.load([0, flattened_expt_n_idx, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0])
|
323 |
+
w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
|
324 |
+
w_scales = unswizzle_mx_scale_bw(w_scales)
|
325 |
+
else:
|
326 |
+
w_scales = _tma_load_2d(MxScale, [expt_id, off_k_mx, off_n]).T
|
327 |
+
if SWAP_XW:
|
328 |
+
acc = tl.dot_scaled(w.T, w_scales, mx_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
|
329 |
+
else:
|
330 |
+
acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True)
|
331 |
+
else:
|
332 |
+
if SWAP_XW:
|
333 |
+
acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
|
334 |
+
else:
|
335 |
+
acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
|
336 |
+
|
337 |
+
if INDEPENDENT_EPILOGUE:
|
338 |
+
tile_id1 += NUM_SMS
|
339 |
+
expt_id1, start_z1, start_m1, eM1, off_m1, off_n1, pid_k1 = _load_tile_attrs(
|
340 |
+
tile_id1, num_tiles, grid_m, grid_n, padding_m,
|
341 |
+
M, ExptData, ExptHist, ExptOffs,
|
342 |
+
BLOCK_M, BLOCK_N, SPLIT_K,
|
343 |
+
GROUP_M, XCD_SWIZZLE)
|
344 |
+
else:
|
345 |
+
tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z, start_m, eM
|
346 |
+
off_m1, off_n1, pid_k1 = off_m, off_n, pid_k
|
347 |
+
|
348 |
+
# Determine output row offsets and mask
|
349 |
+
offs_m = off_m1 + tl.arange(0, BLOCK_M)
|
350 |
+
mask_m = offs_m < M if M is not None else offs_m < eM1
|
351 |
+
if HAS_FUSED_SCATTER:
|
352 |
+
offs_y_m, mask_m = _load_writeback_idx_and_mask(
|
353 |
+
WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m)
|
354 |
+
# Later, mask out the acc for computing flexpoint scales.
|
355 |
+
MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
|
356 |
+
|
357 |
+
if USE_SCATTER_TMA and SPLIT_K > 1:
|
358 |
+
# Compute the split k offset in number of rows, and add it to offs_y_m.
|
359 |
+
# This allows us to write to the correct slice in the output tensor while using
|
360 |
+
# a 2D TMA scatter.
|
361 |
+
tl.device_assert(stride_y_k // stride_y_m == tl.cdiv(stride_y_k, stride_y_m))
|
362 |
+
split_k_row_offs = pid_k1 * (stride_y_k // stride_y_m)
|
363 |
+
offs_y_m = tl.where(mask_m, offs_y_m + split_k_row_offs, offs_y_m)
|
364 |
+
else:
|
365 |
+
offs_y_m = start_m1 + offs_m
|
366 |
+
|
367 |
+
if USE_GATHER_TMA:
|
368 |
+
MASK_ACC: tl.constexpr = False
|
369 |
+
else:
|
370 |
+
# Later, mask out the acc for computing flexpoint scales.
|
371 |
+
MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE
|
372 |
+
|
373 |
+
# TMA is faster on Blackwell if a SWAP_XW transpose is not needed, or when we need registers to mask out the acc.
|
374 |
+
# Contrary to the SWAP_XW case, having a fused activation function tends to make TMA faster again.
|
375 |
+
# For the ideal optimization, this would depend on what the activation function is doing.
|
376 |
+
Y_USE_TMA: tl.constexpr = (MASK_ACC or cuda_capability_geq(10, 0)) and not (
|
377 |
+
DISABLE_Y_TMA or (SWAP_XW and ACTIVATION_FN is None))
|
378 |
+
|
379 |
+
YBase = Y + start_z1.to(index_type) * stride_y_z + start_m1.to(index_type) * stride_y_m
|
380 |
+
if USE_SCATTER_TMA:
|
381 |
+
if ExptData is None: # start_z1 may change; update the descriptor
|
382 |
+
y_desc = _update_tensor_desc(y_desc, YBase)
|
383 |
+
elif not HAS_FUSED_SCATTER and Y_USE_TMA:
|
384 |
+
y_desc = tl.make_tensor_descriptor(
|
385 |
+
YBase + pid_k1.to(index_type) * stride_y_k,
|
386 |
+
shape=[M if M is not None else eM1, yN],
|
387 |
+
strides=[stride_y_m, stride_y_n],
|
388 |
+
block_shape=[BLOCK_M, OUT_BLOCK_N],
|
389 |
+
)
|
390 |
+
|
391 |
+
# bias + scale
|
392 |
+
offs_y_n = off_n1 + tl.arange(0, BLOCK_N)
|
393 |
+
mask_n = offs_y_n < N
|
394 |
+
if B is not None:
|
395 |
+
BPtrs = B + expt_id1 * stride_b_e + offs_y_n
|
396 |
+
if pid_k1 == 0:
|
397 |
+
bias = tl.load(BPtrs, mask=mask_n, other=0)
|
398 |
+
else:
|
399 |
+
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
|
400 |
+
else:
|
401 |
+
bias = tl.full([BLOCK_N], 0, dtype=tl.float32)
|
402 |
+
if Betas is not None:
|
403 |
+
betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0)
|
404 |
+
else:
|
405 |
+
betas = tl.full([BLOCK_M], 1, dtype=tl.float32)
|
406 |
+
if Gammas is not None:
|
407 |
+
gammas = tl.load(Gammas + start_m1 + offs_m, mask=mask_m, other=0.0)
|
408 |
+
else:
|
409 |
+
gammas = tl.full([BLOCK_M], 1, dtype=tl.float32)
|
410 |
+
x_scale = load_scale(XScale)
|
411 |
+
if PER_BATCH_SCALE:
|
412 |
+
w_scale = load_scale(WScale + expt_id1)
|
413 |
+
else:
|
414 |
+
w_scale = load_scale(WScale)
|
415 |
+
|
416 |
+
accs = (acc,)
|
417 |
+
biases = (bias,)
|
418 |
+
|
419 |
+
if SUBTILE_FACTOR >= 2:
|
420 |
+
acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
|
421 |
+
accs = (acc0, acc1)
|
422 |
+
bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
|
423 |
+
biases = (bias0, bias1)
|
424 |
+
|
425 |
+
if SUBTILE_FACTOR >= 4:
|
426 |
+
acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
|
427 |
+
acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
|
428 |
+
accs = (acc00, acc01, acc10, acc11)
|
429 |
+
bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
|
430 |
+
bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
|
431 |
+
biases = (bias00, bias01, bias10, bias11)
|
432 |
+
|
433 |
+
tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
|
434 |
+
tl.static_assert(len(accs) == SUBTILE_FACTOR)
|
435 |
+
|
436 |
+
for a_i in tl.static_range(len(accs)):
|
437 |
+
acc_tile = accs[a_i]
|
438 |
+
acc_tile *= x_scale * w_scale
|
439 |
+
|
440 |
+
if SWAP_XW:
|
441 |
+
acc_tile = acc_tile.T
|
442 |
+
|
443 |
+
acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None]
|
444 |
+
if out_alpha is not None:
|
445 |
+
acc_tile *= out_alpha
|
446 |
+
|
447 |
+
if ACTIVATION_FN is not None:
|
448 |
+
out = ACTIVATION_FN(acc_tile, *activation_fn_args)
|
449 |
+
tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})")
|
450 |
+
else:
|
451 |
+
tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided")
|
452 |
+
out = acc_tile
|
453 |
+
|
454 |
+
out *= gammas[:, None]
|
455 |
+
|
456 |
+
if MASK_ACC:
|
457 |
+
out = tl.where(mask_m[:, None], out, 0.0)
|
458 |
+
# Flexpoint
|
459 |
+
out_view = tl.reshape(
|
460 |
+
out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True)
|
461 |
+
local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0))
|
462 |
+
out = float_to_flex(
|
463 |
+
out, YExpectedScale,
|
464 |
+
None, # ActualScale: local absmax is tracked and updated after the loop
|
465 |
+
YChecksumScale,
|
466 |
+
None, # mask: out is manually masked to 0
|
467 |
+
Y, FLEXPOINT_SATURATE_INF)
|
468 |
+
if EPILOGUE_FN is not None:
|
469 |
+
out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=Y.dtype.element_ty, pid=len(accs)*tile_id1 + a_i)
|
470 |
+
|
471 |
+
out_off_n = off_n1 // ACTIVATION_REDUCTION_N + a_i * OUT_BLOCK_N
|
472 |
+
if USE_SCATTER_TMA:
|
473 |
+
# Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
|
474 |
+
# there shouldn't be any other negative values.
|
475 |
+
offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True)
|
476 |
+
y_desc.scatter(out.to(Y.dtype.element_ty), offs_y_m, out_off_n)
|
477 |
+
elif not HAS_FUSED_SCATTER and Y_USE_TMA:
|
478 |
+
y_desc.store([off_m1, out_off_n], out.to(Y.dtype.element_ty))
|
479 |
+
else:
|
480 |
+
offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N)
|
481 |
+
mask_n = offs_y_n < yN
|
482 |
+
|
483 |
+
YPtrs = Y + pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n
|
484 |
+
mask = mask_m[:, None] & mask_n[None, :]
|
485 |
+
tl.store(YPtrs, out, mask=mask)
|
486 |
+
|
487 |
+
|
488 |
+
# Update the flexpoint scales
|
489 |
+
if YActualScale is not None:
|
490 |
+
tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), Y), sem="relaxed")
|
491 |
+
|
492 |
+
|
493 |
+
_per_device_alloc_fns = {}
|
494 |
+
def get_per_device_per_stream_alloc_fn(device):
|
495 |
+
if device not in _per_device_alloc_fns:
|
496 |
+
_per_stream_tensors = {}
|
497 |
+
def alloc_fn(size: int, alignment: int, stream):
|
498 |
+
assert alignment == 128
|
499 |
+
if stream not in _per_stream_tensors or _per_stream_tensors[stream].numel() < size:
|
500 |
+
_per_stream_tensors[stream] = torch.empty(size, device=device, dtype=torch.int8)
|
501 |
+
_per_stream_tensors[stream].__hibernate__ = {"type": "ignore"}
|
502 |
+
return _per_stream_tensors[stream]
|
503 |
+
|
504 |
+
_per_device_alloc_fns[device] = alloc_fn
|
505 |
+
return _per_device_alloc_fns[device]
|
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags.py
ADDED
@@ -0,0 +1,298 @@
# isort: off
# fmt: off
from dataclasses import dataclass
import triton
from triton_kernels.target_info import get_cdna_version
import torch
from .opt_flags_details import opt_flags_amd, opt_flags_nvidia


@dataclass
class OptFlags:
    block_m: int
    block_n: int
    block_k: int
    num_warps: int
    num_stages: int
    group_m: int
    xcd_swizzle: int
    w_cache_modifier: str
    split_k: int
    fused_scatter: bool
    is_persistent: bool
    idle_sms: int
    epilogue_subtile: int | None
    arch: str
    target_kernel_kwargs: dict

    def __post_init__(self):
        if self.fused_scatter and self.split_k != 1:
            raise ValueError("Not supported")


def make_default_opt_flags_amd(
    out_dtype,
    lhs_dtype,
    rhs_dtype,
    precision_config,
    m,
    n,
    k,
    routing_data,
    can_use_persistent_tma,
    can_use_fused_scatter,
    enforce_bitwise_invariance,
    epilogue_effective_itemsize,
    constraints,
):
    constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
    # tokens per expert
    if routing_data is None:
        tokens_per_expt = m
    elif routing_data.expected_tokens_per_expt is None:
        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
    else:
        tokens_per_expt = routing_data.expected_tokens_per_expt

    is_cdna4 = get_cdna_version() == 4
    # block_m
    if constraints.get("block_m", None):
        block_m = constraints["block_m"]
    elif enforce_bitwise_invariance:
        block_m = 256 if is_cdna4 else 128
    elif tokens_per_expt >= 512 and n >= 2048:
        block_m = 256 if is_cdna4 else 128
    elif is_cdna4 and m >= 512:
        block_m = 128
    else:
        block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64))

    if routing_data is not None:
        grid_m = routing_data.n_blocks(m, block_m)
    else:
        grid_m = triton.cdiv(m, block_m)
    # group_m:
    group_m = 4
    # number of xcds
    num_xcds = 8
    xcd_swizzle = num_xcds
    # block_nk:
    block_n, block_k = opt_flags_amd.compute_block_nk(
        n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config
    )
    # Replace block_k if provided in constraints.
    # TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
    if constraints.get("block_k", None) is not None:
        block_k = constraints["block_k"]
    is_persistent = constraints.get("is_persistent", False)
    # split_k:
    if constraints.get("split_k", None) is not None:
        split_k = constraints["split_k"]
    elif is_persistent or enforce_bitwise_invariance:
        split_k = 1
    else:
        grid_size = grid_m * ((n + block_n - 1) // block_n)
        n_cu = torch.cuda.get_device_properties(0).multi_processor_count
        split_k = max(1, n_cu // grid_size)
    # w_cache_modifier:
    w_cache_modifier = ".cg" if block_m <= 32 else None
    # num_warps, num_stages
    num_warps = 2 if (m is not None and m <= 16) else 8
    num_stages = 2
    # AMD-specific
    target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
    ret = OptFlags(
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        num_warps=num_warps,
        num_stages=num_stages,
        group_m=group_m,
        xcd_swizzle=xcd_swizzle,
        w_cache_modifier=w_cache_modifier,
        split_k=split_k,
        fused_scatter=constraints.get('fused_scatter', False),
        is_persistent=is_persistent,
        idle_sms=0,
        epilogue_subtile=constraints.get('epilogue_subtile', None),
        arch=None,
        target_kernel_kwargs=target_kernel_kwargs,
    )
    # check constraints
    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
    return ret


def make_default_opt_flags_nvidia(
    out_dtype,
    lhs_dtype,
    rhs_dtype,
    precision_config,
    m,
    n,
    k,
    routing_data,
    can_use_persistent_tma,
    can_use_fused_scatter,
    enforce_bitwise_invariance,
    epilogue_effective_itemsize,
    constraints,
):
    constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms"]
    assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
    # tokens per expert
    if routing_data is None:
        tokens_per_expt = m
    elif routing_data.expected_tokens_per_expt is None:
        tokens_per_expt = max(1, m // routing_data.n_expts_tot)
    else:
        tokens_per_expt = routing_data.expected_tokens_per_expt
    # pid swizzling
    group_m = 8
    xcd_swizzle = 1
    # block_m
    if constraints.get("block_m", None):
        block_m = constraints["block_m"]
    elif enforce_bitwise_invariance:
        block_m = 128
    else:
        min_block_m = 64 if torch.cuda.get_device_capability()[0] == 10 else 16
        block_m = max(min_block_m, min(triton.next_power_of_2(tokens_per_expt), 128))
    # block n
    arch = None
    block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
    # is_persistent
    grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
    n_sms = torch.cuda.get_device_properties(0).multi_processor_count
    tiles_per_sm = grid_size / n_sms
    supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
    if constraints.get("is_persistent", None) is not None:
        is_persistent = constraints["is_persistent"]
    else:
        has_simple_epilogue = precision_config.max_num_imprecise_acc is None
        is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
        # TEMP CHANGE
        if precision_config.act_scale is not None or precision_config.out_scale is not None:
            is_persistent = False
    # block k
    if constraints.get("block_k", None) is not None:
        block_k = constraints["block_k"]
    else:
        block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config)
    # split_k
    if constraints.get("split_k", None) is not None:
        split_k = constraints["split_k"]
    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
        split_k = 1
    else:
        estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n)
        split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
    if split_k > 1:
        # With split_k, results are written in f32. Use that for the following computations.
        out_dtype = torch.float32
    compute_num_stages_args = (
        precision_config,
        is_persistent,
        block_m,
        block_n,
        block_k,
        out_dtype,
        lhs_dtype,
        rhs_dtype,
    )

    if constraints.get("epilogue_subtile", None) is not None:
        subtiles_to_check = [constraints["epilogue_subtile"]]
    else:
        subtiles_to_check = [1, 2, 4]
    num_stages = -1
    for ep in subtiles_to_check:
        ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, ep, epilogue_effective_itemsize)
        if ns > num_stages:
            epilogue_subtile, num_stages = ep, ns
    assert num_stages >= 1
    if constraints.get("num_stages", None):
        num_stages = constraints["num_stages"]

    # fused scatter scratchpad
    if constraints.get("fused_scatter", None) is not None:
        fused_scatter = constraints["fused_scatter"]
    else:
        fused_scatter = can_use_fused_scatter and split_k == 1
    # Handshake with the HBM swizzling
    num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)
    ret = OptFlags(
        block_m=block_m,
        block_n=block_n,
        block_k=block_k,
        num_warps=num_warps,
        num_stages=num_stages,
        group_m=group_m,
        xcd_swizzle=xcd_swizzle,
        w_cache_modifier=None,
        split_k=split_k,
        fused_scatter=fused_scatter,
        is_persistent=is_persistent,
        epilogue_subtile=epilogue_subtile,
        arch=arch,
        target_kernel_kwargs=dict(),
        idle_sms=constraints.get("idle_sms", 0),
    )
    # check constraints
    assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}"
    return ret

# --------------
# User Interface
# --------------

_opt_flags_constraints: dict = dict()
_opt_flags: OptFlags | None = None

def update_opt_flags_constraints(constraints: dict[str, int]):
    global _opt_flags_constraints
    _opt_flags_constraints.update(constraints)

def reset_opt_flags_constraints():
    global _opt_flags_constraints
    _opt_flags_constraints = dict()

def set_opt_flags(opt_flags: OptFlags):
    global _opt_flags
    assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override"
    assert not _opt_flags, "opt_flags already set; please reset to None first"
    _opt_flags = opt_flags

class InapplicableConstraint(Exception):
    pass

def make_opt_flags(
    out_dtype,
    lhs_dtype,
    rhs_dtype,
    precision_config,
    m,
    n,
    k,
    routing_data,
    can_use_persistent_tma,
    can_use_fused_scatter,
    epilogue_effective_itemsize,
):
    if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
        raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
    enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
    if _opt_flags is not None:
        assert not _opt_flags_constraints
        return _opt_flags
    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k,
            routing_data, can_use_persistent_tma, can_use_fused_scatter,
            enforce_bitwise_invariance, epilogue_effective_itemsize,
            _opt_flags_constraints]
    backend = triton.runtime.driver.active.get_current_target().backend
    if backend == "hip":
        return make_default_opt_flags_amd(*args)
    if backend == "cuda":
        return make_default_opt_flags_nvidia(*args)
    assert False
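
For reference, a minimal sketch of the constraint-override interface defined above (not part of the committed file). The import path is assumed from the directory layout of this package.

from triton_kernels.matmul_ogs_details import opt_flags

# Pin the tile shape and disable split-K for subsequent launches; the heuristics
# above assert that every pinned value is actually honored (or raise
# InapplicableConstraint when is_persistent is forced without TMA support).
opt_flags.update_opt_flags_constraints({"block_m": 64, "block_k": 128, "split_k": 1})

# Drop the overrides to return to the pure heuristics.
opt_flags.reset_opt_flags_constraints()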
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
ADDED
@@ -0,0 +1,33 @@
import torch
import triton
from triton_kernels.target_info import get_cdna_version
from triton_kernels.tensor import bitwidth


def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config):
    lhs_width = bitwidth(lhs_dtype) / 8
    rhs_width = bitwidth(rhs_dtype) / 8

    # block_n:
    n_cu = torch.cuda.get_device_properties(0).multi_processor_count
    if n is not None:
        if n <= 128 and (n & (n - 1)) == 0:
            block_n = n
        else:
            block_n = max(32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu)))
    elif block_m > 64:
        block_n = 256
    else:
        block_n = 128

    if get_cdna_version() == 4 and block_m == 128:
        block_n = 512

    # block_k needs to match the cacheline size (128B)
    block_k = int(128 // min(lhs_width, rhs_width))

    # TODO: block_k = 128 seems to work better for now.
    # perhaps due to increased number of k loops to pipeline
    if precision_config.weight_scale is not None and get_cdna_version() != 4:
        block_k = 128
    return block_n, block_k
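
For reference, a small standalone illustration of the cacheline-derived block_k formula above (not part of the committed file); the byte widths in the asserts are assumed example pairings.

# Illustration only: block_k = 128-byte cacheline / narrower operand width.
def cacheline_block_k(lhs_bytes: float, rhs_bytes: float) -> int:
    return int(128 // min(lhs_bytes, rhs_bytes))

assert cacheline_block_k(2.0, 2.0) == 64    # fp16 x fp16
assert cacheline_block_k(1.0, 2.0) == 128   # fp8  x fp16
assert cacheline_block_k(2.0, 0.5) == 256   # fp16 x packed fp4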
build/torch-universal/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
ADDED
@@ -0,0 +1,111 @@
import torch
import triton
from triton_kernels import target_info
from triton_kernels.tensor import get_layout, bitwidth, FP4
from triton_kernels.tensor_details.layout import HopperMXScaleLayout
from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE


def compute_grid_size(routing_data, m, n, block_m, block_n):
    if routing_data is not None:
        grid_m = routing_data.n_blocks(m, block_m)
    else:
        grid_m = triton.cdiv(m, block_m)
    grid_n = (n + block_n - 1) // block_n
    return grid_m * grid_n


def compute_block_n(n: int, arch, precision_config):
    # block_n:
    layout = get_layout(precision_config.weight_scale)
    if isinstance(layout, HopperMXScaleLayout) and layout.num_warps == 4:
        return 128
    elif precision_config.max_num_imprecise_acc is None and n > 128:
        return 256
    else:
        return max(16, min(128, triton.next_power_of_2(n)))


def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config):
    lhs_width = bitwidth(lhs_dtype)
    rhs_width = bitwidth(rhs_dtype)
    # block_k needs to match the cacheline size (1024 bits)
    block_k = int(1024 // min(lhs_width, rhs_width))
    has_native_mxfp = target_info.cuda_capability_geq(10, 0)
    if rhs_width == 4 and not has_native_mxfp:
        block_k = 128
    elif k is not None:
        block_k = max(32, min(triton.next_power_of_2(k), block_k))
    has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
    if has_native_mxfp and is_persistent and has_mx_weight_scale:
        block_k = min(block_k, 128)
    return block_k


def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
    device_props = torch.cuda.get_device_properties(0)
    n_sms = device_props.multi_processor_count
    split_k = n_sms // grid_size
    if k is not None:
        # avoid split_k for small k
        num_block_k = triton.cdiv(k, block_k)
        split_k = min(split_k, num_block_k // 4)
    split_k = max(split_k, 1)
    return split_k


def compute_num_warps(block_m, block_n, precision_config):
    layout = get_layout(precision_config.weight_scale)
    if isinstance(layout, HopperMXScaleLayout):
        return layout.num_warps
    return max(block_m * block_n // 4096, 4)


def compute_num_stages(
    precision_config,
    is_persistent,
    block_m,
    block_n,
    block_k,
    out_dtype,
    lhs_dtype,
    rhs_dtype,
    epilogue_subtile,
    epilogue_effective_itemsize,
):
    if precision_config.max_num_imprecise_acc is not None:
        return 3
    weight_size = bitwidth(rhs_dtype) / 8
    stage_size = block_m * block_k * lhs_dtype.itemsize + block_k * block_n * weight_size
    device_props = torch.cuda.get_device_properties(0)
    smem_capacity = device_props.shared_memory_per_block_optin
    has_native_mxfp = target_info.cuda_capability_geq(10, 0)
    if has_native_mxfp and getattr(precision_config, "weight_scale", None) is not None:
        if rhs_dtype == FP4:
            # 4-bit e2m1 weights are padded 2x
            # https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
            stage_size += block_k * block_n * weight_size

    if is_persistent:
        # Per-stage wait barrier
        stage_size += 8
        if target_info.cuda_capability_geq(10, 0):
            acc_size = epilogue_effective_itemsize or out_dtype.itemsize
        else:
            acc_size = out_dtype.itemsize
        if target_info.cuda_capability_geq(10, 0) and epilogue_subtile is not None:
            acc_block_n = block_n // epilogue_subtile
        else:
            acc_block_n = block_n
        # pipelined TMA store local to global, or
        # pipelined layout conversion before store of the accumulator
        # note: layout conversion has some padding
        smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
    if precision_config.weight_scale is not None:
        # mx scales
        stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
    elif has_native_mxfp:
        # mx scales
        stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
    num_stages = min(4, smem_capacity // int(stage_size))
    return num_stages
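
For reference, the stage-count arithmetic from compute_num_stages above, restated standalone (not part of the committed file); the 228 KiB opt-in shared-memory budget is an assumed example value, the real one comes from torch.cuda.get_device_properties.

# Illustration only: a non-persistent fp16 x fp16 tile with no mx scales.
block_m, block_n, block_k = 128, 256, 64
lhs_bytes = rhs_bytes = 2
stage_size = block_m * block_k * lhs_bytes + block_k * block_n * rhs_bytes  # 49152 bytes per pipeline stage
smem_capacity = 228 * 1024                                                  # assumed opt-in smem budget
num_stages = min(4, smem_capacity // stage_size)                            # -> 4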
build/torch-universal/triton_kernels/numerics.py
ADDED
@@ -0,0 +1,42 @@
import torch
from dataclasses import dataclass

MAX_FINITE_FLOAT8E5 = 57344.0
MAX_FINITE_FLOAT8E4NV = 448.0
MAX_FINITE_FLOAT8E4B8 = 240.0


@dataclass(frozen=True)
class BaseFlexData:
    dtype: torch.dtype | None = None

    def view(self, x: torch.Tensor):
        if self.dtype is None:
            return x
        return x.view(self.dtype)

    def reinterpret(self, x):
        if self.dtype is None or x.dtype.itemsize > 1:
            return x
        return x.view(self.dtype)


@dataclass(frozen=True)
class InFlexData(BaseFlexData):
    scale: torch.Tensor | None = None

    @property
    def is_per_batch(self):
        return False if self.scale is None else len(self.scale) > 1


@dataclass(frozen=True)
class OutFlexData(BaseFlexData):
    expected_scale: torch.Tensor | None = None
    actual_scale: torch.Tensor | None = None
    checksum_scale: torch.Tensor | None = None

    def __iter__(self):
        yield self.expected_scale
        yield self.actual_scale
        yield self.checksum_scale
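
For reference, a hypothetical usage sketch of the flex-data containers above (not part of the committed file); the scale tensors are caller-allocated placeholders invented for the example.

import torch

out_flex = OutFlexData(
    dtype=torch.float8_e4m3fn,
    expected_scale=torch.ones(1),
    actual_scale=torch.zeros(1),
    checksum_scale=None,
)
# __iter__ yields the three scales in a fixed order.
expected, actual, checksum = out_flex
# view() reinterprets 1-byte storage as the configured fp8 dtype.
y_fp8 = out_flex.view(torch.empty(16, 16, dtype=torch.uint8))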
build/torch-universal/triton_kernels/numerics_details/__init__.py
ADDED
File without changes
|
build/torch-universal/triton_kernels/numerics_details/flexpoint.py
ADDED
@@ -0,0 +1,195 @@
from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
from triton_kernels import target_info
import triton
import triton.language as tl

# -------------------------------
# Kernels stuff
# -------------------------------

TL_MAX_FINITE_FLOAT8E5 = tl.constexpr(MAX_FINITE_FLOAT8E5)
TL_MAX_FINITE_FLOAT8E4NV = tl.constexpr(MAX_FINITE_FLOAT8E4NV)
TL_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(MAX_FINITE_FLOAT8E4B8)
TL_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(1.750)
TL_MAX_FINITE_FLOAT16 = tl.constexpr(65472.0)

TL_RCP_MAX_FINITE_FLOAT8E5 = tl.constexpr(0x37924925)  # 0x1.24924Ap-16
TL_RCP_MAX_FINITE_FLOAT8E4NV = tl.constexpr(0x3B124925)  # 0x1.24924Ap-9
TL_RCP_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(0x3B888889)  # 0x1.111112p-8
TL_RCP_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(0x3F124925)  # 0x1.24924Ap-1
TL_RCP_MAX_FINITE_FLOAT16 = tl.constexpr(0x37802008)  # 0x1.004010p-16


@triton.jit
def max_finite(dtype):
    if dtype == tl.constexpr(tl.float8e5):
        return TL_MAX_FINITE_FLOAT8E5
    elif dtype == tl.constexpr(tl.float8e4nv):
        return TL_MAX_FINITE_FLOAT8E4NV
    elif dtype == tl.constexpr(tl.float8e4b8):
        return TL_MAX_FINITE_FLOAT8E4B8
    elif dtype == tl.constexpr(tl.float8e4b15):
        return TL_MAX_FINITE_FLOAT8E4B15
    elif dtype == tl.constexpr(tl.float16):
        return TL_MAX_FINITE_FLOAT16
    else:
        tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")


@triton.jit
def rcp_max_finite(dtype):
    if dtype == tl.constexpr(tl.float8e5):
        return TL_RCP_MAX_FINITE_FLOAT8E5
    elif dtype == tl.constexpr(tl.float8e4nv):
        return TL_RCP_MAX_FINITE_FLOAT8E4NV
    elif dtype == tl.constexpr(tl.float8e4b8):
        return TL_RCP_MAX_FINITE_FLOAT8E4B8
    elif dtype == tl.constexpr(tl.float8e4b15):
        return TL_RCP_MAX_FINITE_FLOAT8E4B15
    elif dtype == tl.constexpr(tl.float16):
        return TL_RCP_MAX_FINITE_FLOAT16
    else:
        tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint")


@tl.constexpr_function
def cuda_capability_geq(major, minor):
    return target_info.cuda_capability_geq(major, minor)


@triton.jit
def sm86_min_nan_xorsign_abs_f32(a, b):
    """Wrapper for min.NaN.xorsign.abs.f32 PTX instruction.

    Computes the minimum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
    NaN inputs are propagated to the output.

    Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
    """
    tl.static_assert(cuda_capability_geq(8, 6), "min.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+")
    tl.static_assert(a.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs")
    tl.static_assert(b.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs")

    return tl.inline_asm_elementwise(
        """{
        min.NaN.xorsign.abs.f32 $0, $1, $2;
        }""",
        "=r,r,r",
        [a, b],
        dtype=tl.float32,
        is_pure=True,
        pack=1,
    )


@triton.jit
def sm86_max_nan_xorsign_abs_f32(a, b):
    """Wrapper for max.NaN.xorsign.abs.f32 PTX instruction.

    Computes the maximum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
    NaN inputs are propagated to the output.

    Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
    """
    tl.static_assert(cuda_capability_geq(8, 6), "max.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+")
    tl.static_assert(a.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs")
    tl.static_assert(b.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs")

    return tl.inline_asm_elementwise(
        """{
        max.NaN.xorsign.abs.f32 $0, $1, $2;
        }""",
        "=r,r,r",
        [a, b],
        dtype=tl.float32,
        is_pure=True,
        pack=1,
    )


@triton.jit
def load_scale(scale_ptr):
    return 1.0 if scale_ptr is None else tl.load(scale_ptr)


@triton.jit
def flex_to_float(x, scale_ptr):
    scale = load_scale(scale_ptr)
    return x.to(tl.float32) * scale


@triton.jit
def clip(x, limit):
    res = tl.minimum(x, limit)
    res = tl.maximum(-limit, res)
    return res


@triton.jit
def nan_propagating_absmax_reduce(x, axis=None):
    if cuda_capability_geq(8, 6):
        # abs-max-reduce as floating-point if `max.NaN.xorsign.abs.f32` is supported.
        x_absmax = tl.reduce(x, axis, sm86_max_nan_xorsign_abs_f32)
        # Note: sign of reduction result is the xor of signs of all inputs, explicitly clear the sign bit to fix it.
        x_absmax = x_absmax.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
    else:
        # Clear the sign bit, max-reduce as integer (same as NaN-propagating max-reduce as float)
        masked_abs_x = x.to(tl.uint32, bitcast=True) & 0x7FFFFFFF
        x_absmax = tl.max(masked_abs_x, axis)

    return x_absmax


@triton.jit
def compute_scale(x, Out):
    x_absmax = nan_propagating_absmax_reduce(tl.ravel(x, can_reorder=True))

    # atomic_max does not propagate NaNs, so we replace them with +inf (0x7f800000).
    # We use integer minimum because NaNs are above +inf in integer representation.
    x_absmax = tl.minimum(x_absmax, 0x7F800000).to(tl.float32, bitcast=True)
    RCP_MAX_VALUE = rcp_max_finite(Out.dtype.element_ty)
    return tl.fma(x_absmax, RCP_MAX_VALUE.to(tl.float32, bitcast=True), 1.0e-30)


@triton.jit
def update_scale(x, scale_ptr, Out) -> None:
    if scale_ptr is not None:
        scale = compute_scale(x, Out)
        tl.atomic_max(scale_ptr, scale, sem="relaxed")


@triton.jit
def float_to_flex(
    x,
    expected_scale_ptr_or_val,
    actual_scale_ptr,
    checksum_scale_ptr,
    mask,
    Out,
    saturate_infs: tl.constexpr,
):
    if expected_scale_ptr_or_val is not None:
        if expected_scale_ptr_or_val.dtype.is_ptr():
            invscale = 1.0 / tl.load(expected_scale_ptr_or_val)
        else:
            invscale = 1.0 / expected_scale_ptr_or_val
    else:
        invscale = 1.0
    if checksum_scale_ptr is not None:
        x_int32 = x.to(tl.int32, bitcast=True)
        zero = tl.cast(0.0, tl.int32)
        if mask is not None:
            x_int32 = tl.where(mask, x_int32, zero)
        checksum_local = tl.xor_sum(tl.ravel(x_int32, can_reorder=True), 0)
        tl.atomic_add(checksum_scale_ptr, checksum_local)
    if mask is not None:
        if actual_scale_ptr is not None:
            x = tl.where(mask, x, 0.0)
    update_scale(x, actual_scale_ptr, Out)
    x = x * invscale
    # if expected_scale_ptr is not None, we applied flexpoint scale. We only want to clip in this case.
    if expected_scale_ptr_or_val is not None:
        if saturate_infs:
            CLIP_VALUE = max_finite(Out.dtype.element_ty)
            x = clip(x, CLIP_VALUE)
    return x
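
For reference, a plain-PyTorch restatement of the flexpoint bookkeeping above (not part of the committed file, and not the Triton kernel path): compute_scale is effectively absmax(x) / max_finite(out_dtype) plus a tiny epsilon, and float_to_flex divides by the expected scale before the cast, clipping when saturation is requested. The numbers below are assumed example values.

import torch

MAX_FINITE_FLOAT8E4NV = 448.0

x = torch.randn(1024) * 3.0
# What the kernel accumulates into YActualScale via atomic_max.
actual_scale = x.abs().amax() / MAX_FINITE_FLOAT8E4NV + 1e-30
# The expected scale would normally come from a previous iteration.
expected_scale = torch.tensor(0.01)
# float_to_flex analogue with FLEXPOINT_SATURATE_INF enabled.
y = (x / expected_scale).clamp(-MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E4NV)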
build/torch-universal/triton_kernels/numerics_details/mxfp.py
ADDED
@@ -0,0 +1,303 @@
# isort: off
# fmt: off
from enum import Enum
import triton
import torch
import torch.nn.functional as F
from .mxfp_details._upcast_from_mxfp import _upcast_from_mxfp
from .mxfp_details._downcast_to_mxfp import _downcast_to_mxfp, _dequantize_mxfp8_fn, MXFP_BLOCK_SIZE

# -----------------------------------------------------------------------------
# Dequantization / Quantization Utilities
# -----------------------------------------------------------------------------


class DequantScaleRoundingMode(Enum):
    ROUND_UP = 0
    ROUND_DOWN = 1


def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
                     DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
    """
    Convert the src weights to mx format. The src weight is quantized along the axis dimension.

    If weight_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte.
    Note that this means the k_dim of the tensor will be half of the logical k_dim.

    If weight_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8s are stored
    in their respective formats.
    """
    ndim = src_tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + ndim
    # downcast
    src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1)
    is_fp4 = out_quant_type == torch.uint8
    is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2)
    assert is_fp4 or is_fp8
    divisor = 2 if is_fp4 else 1
    L = src_tensor.shape[-1]
    if is_fp4:
        assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
    out_shape = src_tensor.shape[:-1] + (L // divisor, )
    out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )

    out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
    out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)

    if src_tensor.numel() > 0:
        kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
        kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
        kernel_scale = out_scale.view(-1, out_scale.shape[-1])

        BLOCK_OUT_DIM = 128
        BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
        grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
        grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)

        _downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
                                                  *kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
                                                  *kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
                                                  DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)

    out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
    out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)
    return out_quant_tensor, out_scale


def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int):
    """
    Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16.

    The function assumes that the tensors were quantized along the given axis.
    It permutes the tensor so that the quantized axis is last, reshapes to 2D,
    launches the Triton upcast kernel, and then unpermutes back to the original order.
    """
    ndim = tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + ndim
    assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. "
                                       f"Got {tensor.ndim=} and {scale.ndim=}")
    # dtype checks
    assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \
        f"Invalid tensor dtype {tensor.dtype=}"
    assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
    assert dtype in (torch.float16, torch.bfloat16), f"Invalid output dtype {dtype=}"
    # upcast
    logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
    tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
    scale = scale.transpose(axis, scale.ndim - 1).contiguous()
    out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device)
    reshaped_out = out.view(-1, out.shape[-1])
    reshaped_tensor = tensor.view(-1, tensor.shape[-1])
    reshaped_scale = scale.view(-1, scale.shape[-1])
    BLOCK_OUT_DIM = 128
    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
    blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
    blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
    _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
                                                          *reshaped_scale.stride(), reshaped_tensor,
                                                          *reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM,
                                                          BLOCK_QUANT_DIM, num_warps=8)
    out = out.transpose(axis, scale.ndim - 1).contiguous()
    return out


# ------------


def right_shift_unsigned(x, shift):
    # CUDA torch does not support bit ops on uint32, so we need to mask to get unsigned right shift
    return (x >> shift) & ((1 << (32 - shift)) - 1)


def get_max_quant_val(dtype: torch.dtype):
    d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0}
    assert dtype in d
    return d[dtype]


def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
                           DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
    """
    Converts the src tensor to the output format specified by out_quant_type.
    axis: The axis along which the tensors are contiguous and quantization is applied.
    DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN.

    Returns:
        out_quant_tensor: Quantized tensor in mx format.
           • For mxfp8, the output has the same shape as src_tensor.
           • For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
        scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
           Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
           where L is the original length along that axis.
    """
    # This should probably be packed into its own tiny class
    ndim = src_tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    assert src_tensor.dtype in {torch.float32, torch.bfloat16,
                                torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}"

    axis = axis if axis >= 0 else axis + ndim
    is_fp4 = out_quant_type == torch.uint8
    is_fp8 = "float8" in str(out_quant_type)
    assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}"

    device = src_tensor.device

    # For mxfp4 conversion, we assume the contiguous axis length is even.
    if is_fp4:
        axis_shape = src_tensor.size(axis)
        assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even."

    # Permute the tensor so that the contiguous axis becomes the last dimension.
    src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32)
    axis_shape = src.shape[-1]

    # Pad the axis to be divisible by 32, in case it is not.
    next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
    pad_amount = next_multiple - axis_shape
    padded_src = F.pad(src, (0, pad_amount))
    valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
    padded_axis_shape = padded_src.size(-1)  # now divisible by 32

    # --- Compute per-group maximums for scale ---
    # Set padded entries to -1 so they don’t affect the max.
    abs_f = torch.abs(padded_src)
    abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
    # Reshape the last dimension into groups of 32.
    new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
    abs_groups = abs_f.view(*new_shape)
    # Compute maximum along the group dimension (of size 32).
    max_val, _ = abs_groups.max(dim=-1, keepdim=True)

    # Choose a max quantization value depending on type.
    max_quant_val = get_max_quant_val(out_quant_type)
    dequant_scale = max_val / max_quant_val  # shape: (..., padded_axis_shape//32, 1)

    # Convert to int to round the FP32 scale, prior to quantization!
    ds_int = dequant_scale.view(torch.int32)
    if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
        ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000
    else:
        ds_int_rounded = ds_int & 0x7F800000
    # Reinterpret back as float32.
    dequant_scale_rounded = ds_int_rounded.view(torch.float32)

    # Compute the quantization scale.
    quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded)

    # Quantize the tensor
    orig_padded_shape = padded_src.shape
    padded_src_groups = padded_src.view(*new_shape)
    quant_tensor = padded_src_groups * quant_scale
    # Reshape back to the original shape and trim padding
    quant_tensor = quant_tensor.view(orig_padded_shape)
    quant_tensor = quant_tensor[..., :axis_shape]

    # Finally, convert the quantized tensor to the target format
    if is_fp8:
        # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
        quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val)
        out_weight = quant_tensor.to(out_quant_type)
    else:
        assert is_fp4, f"Invalid output quantization type {out_quant_type}"
        # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
        # First, reinterpret the quantized tensor bits.
        q_int = quant_tensor.contiguous().view(torch.int32)
        # Extract sign, exponent, and mantissa.
        signs = q_int & 0x80000000
        exponents = right_shift_unsigned(q_int, 23) & 0xFF
        mantissas = q_int & 0x7FFFFF

        E8_BIAS = 127
        E2_BIAS = 1
        # Adjust mantissas for subnormals.
        mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
                                (E8_BIAS - exponents - 1), mantissas)
        exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
        e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1)
        e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))
        e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8)  # shape: (..., even_axis_shape)

        # Pack pairs of 4-bit values along the last dimension.
        e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2)
        evens = e2m1_value[..., 0]
        odds = e2m1_value[..., 1]
        out_weight = evens | (odds << 4)  # shape: (..., axis_shape//2)

    # --- Process and output the scale ---
    dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8)  # shape: (..., axis_shape//32, 1)
    dq_scale = dq_scale.squeeze(-1)
    out_weight = out_weight.transpose(axis, src_tensor.ndim - 1)
    dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1)
    return out_weight, dq_scale


def cvt_e2m1_to_fp32(input_tensor):
    assert input_tensor.dtype == torch.uint8

    input_tensor = input_tensor.to(torch.int32)
    evens = input_tensor & 0xF
    odds = (input_tensor >> 4) & 0xF

    vals = [0.0, 0.5, 1, 1.5, 2, 3, 4, 6]
    outputs = torch.tensor(vals, dtype=torch.float32, device=input_tensor.device)
    outputs = torch.cat([outputs, -outputs])

    even_floats = outputs[evens]
    odd_floats = outputs[odds]
    output_tensor = torch.stack([even_floats, odd_floats], dim=-1)
    output_tensor = output_tensor.view(*input_tensor.shape[:-1], -1)
    return output_tensor


def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int):
    """
    Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
    axis: The axis along which dequantization is applied.

    Returns:
        out_weight: Tensor in the target format.
    """

    ndim = tensor.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2
    assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}"

    # Permute the tensor and scale so that the quantization axis becomes the last dimension
    axis = axis if axis >= 0 else axis + ndim
    scale = scale.transpose(axis, scale.ndim - 1)
    tensor = tensor.transpose(axis, tensor.ndim - 1)

    dq_scale = (scale.to(torch.int32) << 23).view(torch.float32)  # Shift to the exponent and bitcast to fp32
    if tensor.dtype == torch.uint8:
        fp32_tensor = cvt_e2m1_to_fp32(tensor)
    else:
        fp32_tensor = tensor.to(torch.float32)

    logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1)
    axis_shape = fp32_tensor.size(-1)
    padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
    pad_size = padded_axis_shape - axis_shape
    padded_tensor = F.pad(fp32_tensor, (0, pad_size))

    new_axis_shape = padded_tensor.shape[-1]
    new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
    padded_tensor = padded_tensor.view(*new_shape)
    dq_scale_padded = dq_scale.unsqueeze(-1)  # shape: [..., ceil(axis_shape/32), 1]
    out_padded = padded_tensor * dq_scale_padded

    # Flatten back and remove the padded tail
    out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape)
    out_tensor = out_padded[..., :axis_shape]

    out_tensor = out_tensor.to(target_dtype).contiguous()
    out_tensor = out_tensor.transpose(axis, tensor.ndim - 1)

    return out_tensor


dequantize_mxfp8_fn = _dequantize_mxfp8_fn
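
For reference, a round-trip sketch using the torch reference paths defined above (not part of the committed file; no GPU kernel launch is involved). The shapes are arbitrary example values.

import torch

w = torch.randn(4, 64, dtype=torch.bfloat16)
# mxfp4: two e2m1 values per byte, one uint8 scale per group of 32 along the axis.
w_q, w_scale = downcast_to_mxfp_torch(w, torch.uint8, axis=-1)
assert w_q.shape == (4, 32) and w_scale.shape == (4, 2)
# Lossy reconstruction back to bfloat16.
w_hat = upcast_from_mxfp_torch(w_q, w_scale, torch.bfloat16, axis=-1)
assert w_hat.shape == w.shape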
build/torch-universal/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
ADDED
@@ -0,0 +1,158 @@
import triton
import triton.language as tl

# fmt: off


MXFP_BLOCK_SIZE = tl.constexpr(32)


@triton.jit
def _get_max_quant_val(dtype: tl.constexpr):
    if dtype == tl.uint8:
        return 6.0
    elif dtype == tl.float8e5:
        return 57344.0
    elif dtype == tl.float8e4nv:
        return 448.0
    else:
        tl.static_assert(False, f"Invalid {dtype=}")


@triton.jit
def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
                             DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
    is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
    BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
    BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE

    # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
    f32_tensor = src_tensor.to(tl.float32)
    abs_tensor = tl.abs(f32_tensor)
    abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation
    abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
    dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
    if DEQUANT_SCALE_ROUNDING_MODE == 0:
        # DequantScaleRoundingMode.ROUND_UP
        # compute 2 ** ceil(log2(dequant_scale))
        # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
        # A corner case: exponent is 0xFF that will overflow but that's already
        # NaN so assume we don't care.
        dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
    else:
        # DequantScaleRoundingMode.ROUND_DOWN
        # compute 2 ** floor(log2(dequant_scale))
        assert DEQUANT_SCALE_ROUNDING_MODE == 1
        dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
    dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
    quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)

    f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
    quant_tensor = f32_tensor * quant_scale

    # Reshape the tensors after scaling
    quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
    # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
    quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)
    dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE])

    # First, we simply extract the exponent part of the scales and store the result
    dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
    # Now we must convert the tensors to the mx format.
    if is_fp8:
        out_tensor = quant_tensor.to(mx_tensor_dtype)
    else:
        quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
        signs = quant_tensor & 0x80000000
        exponents = (quant_tensor >> 23) & 0xFF
        mantissas = (quant_tensor & 0x7FFFFF)

        # 0.25 <= x < 0.75 maps to 0.5, a denormal number
        E8_BIAS = 127
        E2_BIAS = 1
        # Move implicit bit 1 at the beginning to mantissa for denormals
        adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
        mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)

        # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
        exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)

        # Combine sign, exponent, and mantissa, while saturating
        # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
        e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7)
        e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)

        e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
        evens, odds = tl.split(e2m1_value)
        out_tensor = evens | (odds << 4)

    return out_tensor, dequant_scale_exponent


@triton.jit
def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
                      mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
                      src_ptr, stride_src_outer, stride_src_quant,
                      outer_dim, quant_dim,
                      BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
                      DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):

    tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")

    # uint8 signifies two fp4 e2m1 values packed into a single byte
    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
    tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5),
                     f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.")

    src_dtype: tl.constexpr = src_ptr.dtype.element_ty
    tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8")
    tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16), f"{src_dtype=} must be bfloat16 or float16")
    is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
|
112 |
+
|
113 |
+
outer_block = tl.program_id(0).to(tl.int64)
|
114 |
+
quant_block = tl.program_id(1).to(tl.int64)
|
115 |
+
|
116 |
+
K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
|
117 |
+
BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
|
118 |
+
BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
|
119 |
+
|
120 |
+
start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
|
121 |
+
start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
|
122 |
+
start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
|
123 |
+
start_out = outer_block * BLOCK_SIZE_OUT_DIM
|
124 |
+
|
125 |
+
src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
|
126 |
+
mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
|
127 |
+
mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer
|
128 |
+
|
129 |
+
offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
|
130 |
+
offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
|
131 |
+
offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
|
132 |
+
offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
|
133 |
+
|
134 |
+
mask_src_quant = start_src_quant + offs_src_quant < quant_dim
|
135 |
+
mask_n = start_out + offs_outer < outer_dim
|
136 |
+
full_mask_src = mask_src_quant & mask_n
|
137 |
+
|
138 |
+
mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
|
139 |
+
full_mask_mxt = mask_mxt_quant & mask_n
|
140 |
+
|
141 |
+
scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
|
142 |
+
full_scale_mask = scale_mask_k & mask_n
|
143 |
+
|
144 |
+
src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
|
145 |
+
mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
|
146 |
+
mx_tensor_offsets = offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
|
147 |
+
src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)
|
148 |
+
|
149 |
+
out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype,
|
150 |
+
DEQUANT_SCALE_ROUNDING_MODE)
|
151 |
+
|
152 |
+
tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
|
153 |
+
tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
|
154 |
+
|
155 |
+
|
156 |
+
@triton.jit(repr=lambda _: "_dequantize_mxfp8")
|
157 |
+
def _dequantize_mxfp8_fn(input, mask, pid=None):
|
158 |
+
return _compute_quant_and_scale(input, mask, tl.float8e4nv)
|
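Not part of this commit: a minimal PyTorch sketch of the per-block scale computation that _compute_quant_and_scale performs for the float8e4nv target in ROUND_UP mode, to make the bit manipulation above easier to follow. The function name and shapes below are illustrative only.

import torch

def mxfp8_scale_round_up(src: torch.Tensor) -> torch.Tensor:
    # src: [rows, k] float16/bfloat16 with k a multiple of 32; 448.0 is the
    # largest finite float8e4nv value, matching _get_max_quant_val above.
    blocks = src.float().abs().reshape(src.shape[0], -1, 32)
    dequant_scale = blocks.amax(dim=-1) / 448.0
    bits = dequant_scale.view(torch.int32)
    # round up to the next power of two by bumping the exponent field,
    # mirroring `(x + 0x007FFFFF) & 0x7F800000` in the kernel
    bits = (bits + 0x007FFFFF) & 0x7F800000
    return bits.view(torch.float32)
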
build/torch-universal/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
ADDED
@@ -0,0 +1,122 @@
1 |
+
import triton
|
2 |
+
import triton.language as tl
|
3 |
+
from ._downcast_to_mxfp import MXFP_BLOCK_SIZE
|
4 |
+
|
5 |
+
|
6 |
+
# fmt: off
|
7 |
+
@triton.jit
|
8 |
+
def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
|
9 |
+
stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
|
10 |
+
outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):
|
11 |
+
|
12 |
+
tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
|
13 |
+
tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
|
14 |
+
# uint8 signifies two fp4 e2m1 values packed into a single byte
|
15 |
+
mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
|
16 |
+
dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
|
17 |
+
tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16)
|
18 |
+
tl.static_assert(
|
19 |
+
mx_tensor_dtype == tl.uint8
|
20 |
+
or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype),
|
21 |
+
"mx_tensor_ptr must be uint8 or float8 or dst_dtype")
|
22 |
+
tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8")
|
23 |
+
|
24 |
+
# Determine if we are dealing with fp8 types.
|
25 |
+
is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
|
26 |
+
is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
|
27 |
+
K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
|
28 |
+
BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
|
29 |
+
BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR
|
30 |
+
|
31 |
+
# Compute starting indices for the quantized (packed) dimension and the outer dimension.
|
32 |
+
outer_block = tl.program_id(0).to(tl.int64)
|
33 |
+
quant_block = tl.program_id(1).to(tl.int64)
|
34 |
+
|
35 |
+
start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
|
36 |
+
start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM
|
37 |
+
start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
|
38 |
+
start_out = outer_block * BLOCK_SIZE_OUT_DIM
|
39 |
+
|
40 |
+
mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer
|
41 |
+
mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer
|
42 |
+
out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant
|
43 |
+
|
44 |
+
# Compute offsets and masks.
|
45 |
+
offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
|
46 |
+
offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
|
47 |
+
offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)
|
48 |
+
offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
|
49 |
+
|
50 |
+
mask_outer = start_out + offs_outer < outer_dim
|
51 |
+
mask_out_quant = start_out_quant + offs_out_quant < quant_dim
|
52 |
+
full_mask_out = mask_out_quant & mask_outer
|
53 |
+
|
54 |
+
mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
|
55 |
+
full_mask_src = mask_src_quant & mask_outer
|
56 |
+
|
57 |
+
mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
|
58 |
+
full_scale_mask = mask_scale & mask_outer
|
59 |
+
|
60 |
+
tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
|
61 |
+
scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer
|
62 |
+
out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer
|
63 |
+
|
64 |
+
# Load the packed tensor and scale.
|
65 |
+
tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src)
|
66 |
+
scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask)
|
67 |
+
|
68 |
+
# Upcast the scale to the destination type.
|
69 |
+
if dst_dtype == tl.bfloat16:
|
70 |
+
dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True)
|
71 |
+
else:
|
72 |
+
tl.static_assert(dst_dtype == tl.float16)
|
73 |
+
dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
|
74 |
+
dst_scale = dst_scale.to(tl.float16)
|
75 |
+
|
76 |
+
# Now upcast the tensor.
|
77 |
+
if is_fp8:
|
78 |
+
dst_tensor = tensor.to(dst_dtype)
|
79 |
+
if tensor.dtype == tl.float8e5:
|
80 |
+
from_e_bits: tl.constexpr = 5
|
81 |
+
from_m_bits: tl.constexpr = 2
|
82 |
+
to_e_bits: tl.constexpr = 8 if dst_dtype == tl.bfloat16 else 5
|
83 |
+
to_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10
|
84 |
+
|
85 |
+
# Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
|
86 |
+
non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits
|
87 |
+
non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits
|
88 |
+
dst_tensor = tl.where(
|
89 |
+
(tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src,
|
90 |
+
(dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(dst_dtype, bitcast=True),
|
91 |
+
dst_tensor,
|
92 |
+
)
|
93 |
+
else:
|
94 |
+
assert is_fp4
|
95 |
+
dst_bias: tl.constexpr = 127 if dst_dtype == tl.bfloat16 else 15
|
96 |
+
dst_0p5: tl.constexpr = 16128 if dst_dtype == tl.bfloat16 else 0x3800
|
97 |
+
dst_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10
|
98 |
+
# e2m1
|
99 |
+
em0 = tensor & 0x07
|
100 |
+
em1 = tensor & 0x70
|
101 |
+
x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12)
|
102 |
+
x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8)
|
103 |
+
# Three cases:
|
104 |
+
# 1) x is normal and non-zero: Correct bias
|
105 |
+
x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0)
|
106 |
+
x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1)
|
107 |
+
# 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
|
108 |
+
x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0)
|
109 |
+
x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1)
|
110 |
+
# 3) x is zero, do nothing
|
111 |
+
dst_tensor = tl.interleave(x0, x1).to(dst_dtype, bitcast=True)
|
112 |
+
|
113 |
+
# Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping.
|
114 |
+
dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
|
115 |
+
dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
|
116 |
+
scale = scale.reshape(dst_scale.shape)
|
117 |
+
|
118 |
+
out_tensor = dst_tensor * dst_scale
|
119 |
+
# Correct any NaNs encoded via the scale.
|
120 |
+
out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor)
|
121 |
+
out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM])
|
122 |
+
tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out)
|
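Not part of this commit: a small numeric illustration of the scale upcast above, where the uint8 e8m0 scale byte is shifted into the bfloat16 exponent field.

import torch

scale = torch.tensor([126, 127, 128], dtype=torch.uint8)      # e8m0 exponent bytes
bf16 = (scale.to(torch.int16) << 7).view(torch.bfloat16)      # same idea as `scale.to(tl.uint16) << 7`
print(bf16)                                                   # 0.5, 1.0, 2.0
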
build/torch-universal/triton_kernels/proton_opts.py
ADDED
@@ -0,0 +1,17 @@
1 |
+
# proton options
|
2 |
+
|
3 |
+
import os
|
4 |
+
|
5 |
+
_launch_metadata_allow_sync = None
|
6 |
+
|
7 |
+
|
8 |
+
def launch_metadata_allow_sync():
|
9 |
+
global _launch_metadata_allow_sync
|
10 |
+
if _launch_metadata_allow_sync is None:
|
11 |
+
_launch_metadata_allow_sync = not (os.getenv("PROTON_LAUNCH_METADATA_NOSYNC") == "1")
|
12 |
+
return _launch_metadata_allow_sync
|
13 |
+
|
14 |
+
|
15 |
+
def set_launch_metadata_allow_sync(allow_sync: bool):
|
16 |
+
global _launch_metadata_allow_sync
|
17 |
+
_launch_metadata_allow_sync = allow_sync
|
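Not part of this commit: an illustrative usage sketch, assuming the built package is importable as `triton_kernels`.

from triton_kernels import proton_opts

# Lazily derived from PROTON_LAUNCH_METADATA_NOSYNC on the first call.
print(proton_opts.launch_metadata_allow_sync())
# A programmatic override takes precedence afterwards.
proton_opts.set_launch_metadata_allow_sync(False)
assert proton_opts.launch_metadata_allow_sync() is False
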
build/torch-universal/triton_kernels/reduction_details/reduce_bitmatrix.py
ADDED
@@ -0,0 +1,111 @@
1 |
+
import torch
|
2 |
+
import triton
|
3 |
+
import triton.language as tl
|
4 |
+
|
5 |
+
|
6 |
+
@triton.jit
|
7 |
+
def vpopc(x):
|
8 |
+
"""
|
9 |
+
Vertical popcount
|
10 |
+
Input x : uint32[..., N]
|
11 |
+
Output y : uint32[..., 32]
|
12 |
+
semantics : y[..., i] = sum_j((x[..., j] >> i) & 1)
|
13 |
+
credits: @apgoucher
|
14 |
+
"""
|
15 |
+
|
16 |
+
tl.static_assert(x.dtype == tl.uint32, "x should consist of 32-bit unsigned integers")
|
17 |
+
|
18 |
+
BLOCK_N: tl.constexpr = x.shape[-1] # summation axis
|
19 |
+
BATCHES: tl.constexpr = x.numel // BLOCK_N # number of batches
|
20 |
+
if BLOCK_N >= 8:
|
21 |
+
sa1: tl.constexpr = 8
|
22 |
+
else:
|
23 |
+
sa1: tl.constexpr = BLOCK_N
|
24 |
+
# create 8-way sums in 4-bit fields:
|
25 |
+
y = tl.reshape(x, [BATCHES, BLOCK_N // sa1, sa1, 1])
|
26 |
+
y = (y >> tl.arange(0, 4)[None, None, None, :]) & 0x11111111
|
27 |
+
y = tl.sum(y, 2) # [BATCHES, BLOCK_N // sa1, 4]
|
28 |
+
if BLOCK_N >= 128:
|
29 |
+
sa2: tl.constexpr = 16
|
30 |
+
else:
|
31 |
+
sa2: tl.constexpr = BLOCK_N // sa1
|
32 |
+
# create 128-way sums in 8-bit fields:
|
33 |
+
y = tl.reshape(y, [BATCHES, BLOCK_N // (sa1 * sa2), sa2, 1, 4])
|
34 |
+
y = (y >> (4 * tl.arange(0, 2))[None, None, None, :, None]) & 0x0f0f0f0f
|
35 |
+
y = tl.sum(y, 2) # [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4]
|
36 |
+
sa3: tl.constexpr = BLOCK_N // (sa1 * sa2)
|
37 |
+
# create N-way sums in 32-bit fields:
|
38 |
+
y = tl.reshape(y, [BATCHES, 1, sa3, 8])
|
39 |
+
y = (y >> (8 * tl.arange(0, 4))[None, :, None, None]) & 0x000000ff
|
40 |
+
y = tl.sum(y, 2) # [BATCHES, 4, 8]
|
41 |
+
y = tl.reshape(y, x.shape[:-1] + [32])
|
42 |
+
return y
|
43 |
+
|
44 |
+
|
45 |
+
@triton.jit
|
46 |
+
def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr):
|
47 |
+
pid = tl.program_id(0)
|
48 |
+
offs = pid * BLOCK + tl.arange(0, BLOCK)
|
49 |
+
tl.store(Ret + offs, 0)
|
50 |
+
|
51 |
+
|
52 |
+
@triton.jit
|
53 |
+
def _sum_bitmatrix_rows(B, shape_bm, stride_bm: tl.constexpr, stride_bn: tl.constexpr, # input bitmatrix
|
54 |
+
Ret, Partials, stride_pm: tl.constexpr, stride_pn, shape_pn, # outputs
|
55 |
+
BLOCK_MM: tl.constexpr, BLOCK_M: tl.constexpr):
|
56 |
+
|
57 |
+
tl.static_assert(BLOCK_MM % BLOCK_M == 0)
|
58 |
+
TILE_SIZE: tl.constexpr = BLOCK_MM // BLOCK_M
|
59 |
+
if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr():
|
60 |
+
shape_bm = tl.load(shape_bm)
|
61 |
+
pid_m = tl.program_id(0)
|
62 |
+
pid_n = tl.program_id(1)
|
63 |
+
offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM)
|
64 |
+
offs_n = pid_n * 32 + tl.arange(0, 32)
|
65 |
+
n_rows = shape_bm
|
66 |
+
bits = tl.load(B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0)
|
67 |
+
bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M])
|
68 |
+
ret = vpopc(bits) # [TILE_SIZE, 32]
|
69 |
+
|
70 |
+
offs_t = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
|
71 |
+
|
72 |
+
tl.atomic_add(Ret + offs_n, tl.sum(ret, 0), sem="relaxed")
|
73 |
+
tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret)
|
74 |
+
|
75 |
+
|
76 |
+
def clear_sums(n_cols, device, MEMSET_BLOCK=512):
|
77 |
+
cdiv = triton.cdiv
|
78 |
+
blocks = cdiv(n_cols, MEMSET_BLOCK)
|
79 |
+
out_ret = torch.empty((blocks * MEMSET_BLOCK, ), device=device, dtype=torch.int32)
|
80 |
+
_sum_bitmatrix_memset[(blocks, )](out_ret, MEMSET_BLOCK)
|
81 |
+
return out_ret
|
82 |
+
|
83 |
+
|
84 |
+
def sum_bitmatrix_rows(x, out_ret, partials_block_size=None):
|
85 |
+
assert partials_block_size is not None
|
86 |
+
cdiv = triton.cdiv
|
87 |
+
PARTIALS_BLOCK_M = partials_block_size
|
88 |
+
n_rows, n_cols = x.shape
|
89 |
+
n_rows_max = x.shape_max[0]
|
90 |
+
assert out_ret.shape == (n_cols, )
|
91 |
+
|
92 |
+
TILE_SIZE = max(1, 128 // PARTIALS_BLOCK_M)
|
93 |
+
BLOCK_MM = PARTIALS_BLOCK_M * TILE_SIZE
|
94 |
+
|
95 |
+
pids_x = cdiv(n_rows_max, BLOCK_MM)
|
96 |
+
pids_y = cdiv(n_cols, 32)
|
97 |
+
out_partials = torch.empty((pids_y * 32, pids_x * TILE_SIZE), device=out_ret.device, dtype=torch.int32)
|
98 |
+
out_partials = torch.transpose(out_partials, 0, 1)
|
99 |
+
|
100 |
+
# output tensors
|
101 |
+
_sum_bitmatrix_rows[(pids_x, pids_y)](
|
102 |
+
x.storage.data, n_rows, x.stride(0), x.stride(1), # input
|
103 |
+
out_ret, # output [final reduction]
|
104 |
+
out_partials, out_partials.stride(0), out_partials.stride(1),
|
105 |
+
out_partials.shape[1], # output [partial reductions]
|
106 |
+
BLOCK_M=PARTIALS_BLOCK_M, BLOCK_MM=BLOCK_MM, # constants
|
107 |
+
num_warps=8)
|
108 |
+
|
109 |
+
out_partials = out_partials[:cdiv(n_rows_max, PARTIALS_BLOCK_M), :]
|
110 |
+
|
111 |
+
return out_ret, out_partials
|
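Not part of this commit: an eager PyTorch reference for the semantics of vpopc (y[..., i] = sum_j((x[..., j] >> i) & 1)), handy when checking the partial sums; the uint32 values are held in an int64 tensor since PyTorch's uint32 support is limited.

import torch

def vpopc_ref(x: torch.Tensor) -> torch.Tensor:
    # x: [..., N] holding uint32 values; returns [..., 32] per-bit counts
    bits = (x[..., None] >> torch.arange(32, device=x.device)) & 1
    return bits.sum(dim=-2)
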
build/torch-universal/triton_kernels/routing.py
ADDED
@@ -0,0 +1,386 @@
1 |
+
import torch
|
2 |
+
import triton
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from .routing_details._routing_compute import _combined_routing_compute
|
5 |
+
from .routing_details._routing_compute import _combined_routing_memset
|
6 |
+
from .routing_details._routing_compute import _routing_clear_bitmatrix
|
7 |
+
from .routing_details._expt_data import _expt_data_memset
|
8 |
+
from .routing_details._expt_data import _expt_data_compute
|
9 |
+
from .target_info import is_hip
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class GatherIndx:
|
14 |
+
"""
|
15 |
+
Indices for an operation that performs:
|
16 |
+
Y = X[src_idx, :]
|
17 |
+
"""
|
18 |
+
# array such that `dst_idx[src_idx] = arange(0, N)`
|
19 |
+
src_indx: torch.Tensor
|
20 |
+
dst_indx: torch.Tensor
|
21 |
+
|
22 |
+
|
23 |
+
@dataclass
|
24 |
+
class ScatterIndx:
|
25 |
+
"""
|
26 |
+
Indices for an operation that performs:
|
27 |
+
Y[dst_idx, :] = X
|
28 |
+
"""
|
29 |
+
# array such that `dst_idx[src_idx] = arange(0, N)`
|
30 |
+
src_indx: torch.Tensor
|
31 |
+
dst_indx: torch.Tensor
|
32 |
+
|
33 |
+
|
34 |
+
@dataclass
|
35 |
+
class ExptData:
|
36 |
+
# hist[i] is the number of tokens routed to expert i
|
37 |
+
hist: torch.Tensor
|
38 |
+
# token_offs_raw[i] is the offset of the first token routed
|
39 |
+
# to expert i in an expert-sorted array
|
40 |
+
token_offs_raw: torch.Tensor
|
41 |
+
# token_offs_pad[block][i] is the offset of the first token routed
|
42 |
+
# to expert i in an expert-sorted array, assuming histogram
|
43 |
+
# rounded to the next multiple of `block`
|
44 |
+
token_offs_pad: dict[int, torch.Tensor]
|
45 |
+
# block_pid_map[block] contains one value for each `pid` launched by
|
46 |
+
# the matrix multiplication kernel launched with BLOCK_M=block:
|
47 |
+
# - the value is -1 if the `pid` has no work to do
|
48 |
+
# - otherwise, the value is two int16 (packed as an int32) that
|
49 |
+
# correspond respectively to (1) the expert assigned to
|
50 |
+
# the tokens processed by this pid; (2) the block assigned to the
|
51 |
+
# tokens processed by this pid (think `pid_m` in a regular matmul)
|
52 |
+
# see `test_routing.py` for a reference implementation and more details
|
53 |
+
block_pid_map: dict[int, torch.Tensor]
|
54 |
+
|
55 |
+
def __post_init__(self):
|
56 |
+
if self.hist is not None:
|
57 |
+
assert self.hist.dtype == torch.int32
|
58 |
+
if self.token_offs_raw is not None:
|
59 |
+
assert self.token_offs_raw.dtype == torch.int32
|
60 |
+
if self.token_offs_pad is not None:
|
61 |
+
for v in self.token_offs_pad.values():
|
62 |
+
assert v.dtype == torch.int32
|
63 |
+
if self.block_pid_map is not None:
|
64 |
+
for v in self.block_pid_map.values():
|
65 |
+
assert v.dtype == torch.int32
|
66 |
+
|
67 |
+
|
68 |
+
@dataclass
|
69 |
+
class RoutingData:
|
70 |
+
gate_scal: torch.Tensor = field()
|
71 |
+
expt_hist: torch.Tensor = field()
|
72 |
+
n_expts_tot: int = field()
|
73 |
+
n_expts_act: int = field()
|
74 |
+
expt_data: ExptData = None
|
75 |
+
|
76 |
+
# Used to make perf annotation cleaner: when we use expert sharding, we can
|
77 |
+
# use this to tell the "expected" number of local tokens per expert, because
|
78 |
+
# the actual number can vary for each input.
|
79 |
+
expected_tokens_per_expt: int = field(default=None)
|
80 |
+
|
81 |
+
def n_blocks(self, n_rows, block_m):
|
82 |
+
if n_rows <= self.n_expts_tot:
|
83 |
+
return n_rows
|
84 |
+
else:
|
85 |
+
return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1
|
86 |
+
|
87 |
+
|
88 |
+
# --------------------------
|
89 |
+
# sort tokens by expert
|
90 |
+
# --------------------------
|
91 |
+
|
92 |
+
|
93 |
+
class SortTokens(torch.autograd.Function):
|
94 |
+
|
95 |
+
@staticmethod
|
96 |
+
def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
|
97 |
+
HIST_BLOCK_M = 32
|
98 |
+
INDX_OFFS_BLOCK_M = 512
|
99 |
+
MEMSET_BLOCK = 1024
|
100 |
+
cdiv = triton.cdiv
|
101 |
+
|
102 |
+
device = expt_scal.device
|
103 |
+
dtype = expt_scal.dtype
|
104 |
+
n_tokens_raw, _ = bitmatrix.shape
|
105 |
+
n_tokens_pad, n_expts_act = expt_scal.shape
|
106 |
+
n_gates_pad = n_tokens_pad * n_expts_act
|
107 |
+
|
108 |
+
hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)
|
109 |
+
hist = hist[:n_expts_tot]
|
110 |
+
assert hist.dtype == torch.int32
|
111 |
+
# scratchpad
|
112 |
+
expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
|
113 |
+
combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
|
114 |
+
# output
|
115 |
+
topk_indx = combined_indx[:n_gates_pad]
|
116 |
+
gate_indx = combined_indx[n_gates_pad:]
|
117 |
+
gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
|
118 |
+
|
119 |
+
token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1a, blocks2a, MEMSET_BLOCK_A, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
|
120 |
+
hist, n_expts_tot, n_gates_pad)
|
121 |
+
|
122 |
+
blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
|
123 |
+
blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
|
124 |
+
|
125 |
+
_combined_routing_memset[(blocks1a + blocks1b, )](
|
126 |
+
combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist, #
|
127 |
+
expt_offs, hist.shape[0], n_expts_tot, partial_hist, # inputs
|
128 |
+
partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1), # outputs
|
129 |
+
token_offs_combined, token_offs_combined.stride(0), #
|
130 |
+
blocks1a, block_pid_map, #
|
131 |
+
block_m_log2_start, SIZES=block_m_num, BLOCK_A=MEMSET_BLOCK_A, # optimization parameters
|
132 |
+
BLOCK_N=512, BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters
|
133 |
+
)
|
134 |
+
|
135 |
+
indx_offs = partial_hist
|
136 |
+
|
137 |
+
_combined_routing_compute[(blocks2a + blocks2b, )](
|
138 |
+
topk_indx, gate_indx, gate_scal, # outputs
|
139 |
+
expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1), # inputs
|
140 |
+
expt_offs, n_tokens_raw, # input shape
|
141 |
+
HIST_BLOCK_M, n_expts_act, # constants
|
142 |
+
hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs
|
143 |
+
block_m_log2_start, block_m_num, HIST2_BLOCK_M, blocks2a, # etc.
|
144 |
+
)
|
145 |
+
|
146 |
+
ctx.n_tokens_raw = n_tokens_raw
|
147 |
+
ctx.n_tokens_pad = n_tokens_pad
|
148 |
+
ctx.n_expts_act = n_expts_act
|
149 |
+
ctx.save_for_backward(gate_indx)
|
150 |
+
return hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map
|
151 |
+
|
152 |
+
@staticmethod
|
153 |
+
def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
|
154 |
+
(gate_indx, ) = ctx.saved_tensors
|
155 |
+
dgate_scal = dgate_scal[gate_indx]
|
156 |
+
dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
|
157 |
+
return dgate_scal, None, None, None
|
158 |
+
|
159 |
+
|
160 |
+
def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix):
|
161 |
+
return SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix)
|
162 |
+
|
163 |
+
|
164 |
+
# --------------------------
|
165 |
+
# prune routing
|
166 |
+
# --------------------------
|
167 |
+
|
168 |
+
|
169 |
+
class PruneRouting(torch.autograd.Function):
|
170 |
+
|
171 |
+
@staticmethod
|
172 |
+
def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
|
173 |
+
from .compaction import compaction
|
174 |
+
n_tokens_pad = expt_scal.shape[0]
|
175 |
+
assert n_expts_tot % simulated_ep == 0
|
176 |
+
_routing_clear_bitmatrix[(n_tokens_pad, )](
|
177 |
+
bitmatrix.storage.data,
|
178 |
+
bitmatrix.storage.data.stride(0),
|
179 |
+
bitmatrix.storage.data.stride(1),
|
180 |
+
bitmatrix.storage.data.shape[1],
|
181 |
+
n_expts_tot // simulated_ep,
|
182 |
+
BLOCK_N=512,
|
183 |
+
)
|
184 |
+
# perform compaction to update expt_scal / expt_indx
|
185 |
+
expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix)
|
186 |
+
n_expts_tot = n_expts_tot // simulated_ep
|
187 |
+
bitmatrix.shape[-1] = n_expts_tot
|
188 |
+
return expt_scal, expt_indx, bitmatrix
|
189 |
+
|
190 |
+
|
191 |
+
def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
|
192 |
+
return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep)
|
193 |
+
|
194 |
+
|
195 |
+
# --------------------------
|
196 |
+
# expt_data
|
197 |
+
# --------------------------
|
198 |
+
|
199 |
+
|
200 |
+
def log2_power_of_two(x):
|
201 |
+
assert x > 0 and (x & (x - 1)) == 0, "x must be a power of two"
|
202 |
+
return x.bit_length() - 1
|
203 |
+
|
204 |
+
|
205 |
+
block_m_log2_start = 4
|
206 |
+
|
207 |
+
|
208 |
+
def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
|
209 |
+
|
210 |
+
MEMSET_BLOCK = 512
|
211 |
+
HIST2_BLOCK_M = 512
|
212 |
+
device = expt_hist.device
|
213 |
+
n_expts_tot = n_expts_tot
|
214 |
+
cdiv = triton.cdiv
|
215 |
+
# block_ms are all powers-of-two between 16 and 128 (inclusive)
|
216 |
+
block_m_log2_end = 9 if is_hip() else 8
|
217 |
+
block_m_num = block_m_log2_end - block_m_log2_start
|
218 |
+
if n_gates <= n_expts_tot:
|
219 |
+
max_n_tiles = n_gates
|
220 |
+
else:
|
221 |
+
max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start)
|
222 |
+
# allocate memory
|
223 |
+
pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK
|
224 |
+
dtype = torch.int32
|
225 |
+
|
226 |
+
token_offs_combined = torch.empty((block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device)
|
227 |
+
|
228 |
+
token_offs_raw = token_offs_combined[0][:n_expts_tot + 1]
|
229 |
+
token_offs_pad = token_offs_combined[1:]
|
230 |
+
|
231 |
+
block_pid_map = torch.empty((block_m_num, pad(max_n_tiles)), dtype=dtype, device=device)
|
232 |
+
memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK # exact division
|
233 |
+
# compute outputs
|
234 |
+
token_offs_pad = token_offs_pad[:, :n_expts_tot + 1]
|
235 |
+
block_pid_map = block_pid_map[:, :max_n_tiles]
|
236 |
+
|
237 |
+
blocks1 = memset_grid + block_m_num + 1
|
238 |
+
blocks2 = n_expts_tot * block_m_num
|
239 |
+
return token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num
|
240 |
+
|
241 |
+
|
242 |
+
def _unpack_into_dict(x):
|
243 |
+
|
244 |
+
block_m_log2_end = block_m_log2_start + x.shape[0]
|
245 |
+
x = {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
|
246 |
+
return x
|
247 |
+
|
248 |
+
|
249 |
+
def compute_expt_data(expt_hist, n_expts_tot, n_gates):
|
250 |
+
|
251 |
+
if expt_hist is None:
|
252 |
+
return ExptData(None, None, None, None)
|
253 |
+
|
254 |
+
# this just computes the kernel arguments:
|
255 |
+
token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
|
256 |
+
expt_hist, n_expts_tot, n_gates)
|
257 |
+
|
258 |
+
_expt_data_memset[(blocks1, )](
|
259 |
+
expt_hist, n_expts_tot, #
|
260 |
+
token_offs_combined, token_offs_combined.stride(0), #
|
261 |
+
block_pid_map, #
|
262 |
+
block_m_log2_start, SIZES=block_m_num, BLOCK=MEMSET_BLOCK, # optimization parameters
|
263 |
+
num_warps=4)
|
264 |
+
_expt_data_compute[(blocks2, )](
|
265 |
+
expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs
|
266 |
+
block_m_log2_start, SIZES=block_m_num, BLOCK=HIST2_BLOCK_M, # optimization parameters
|
267 |
+
num_warps=4)
|
268 |
+
|
269 |
+
token_offs_pad = _unpack_into_dict(token_offs_pad)
|
270 |
+
block_pid_map = _unpack_into_dict(block_pid_map)
|
271 |
+
return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)
|
272 |
+
|
273 |
+
|
274 |
+
# --------------------------
|
275 |
+
# routing
|
276 |
+
# --------------------------
|
277 |
+
|
278 |
+
|
279 |
+
def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
|
280 |
+
hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens(
|
281 |
+
expt_scal, expt_indx, n_expts_tot, bitmatrix)
|
282 |
+
token_offs_pad = _unpack_into_dict(token_offs_pad)
|
283 |
+
block_pid_map = _unpack_into_dict(block_pid_map)
|
284 |
+
expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
|
285 |
+
|
286 |
+
# pack the matmul data structure
|
287 |
+
gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
|
288 |
+
scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
|
289 |
+
return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx
|
290 |
+
|
291 |
+
|
292 |
+
def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None):
|
293 |
+
from .topk import topk
|
294 |
+
if sm_first:
|
295 |
+
logits = torch.softmax(logits, dim=-1)
|
296 |
+
expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, #
|
297 |
+
apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows)
|
298 |
+
n_expts_tot = logits.shape[-1] // simulated_ep
|
299 |
+
# mutate bitmatrix
|
300 |
+
if simulated_ep > 1:
|
301 |
+
expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, logits.shape[-1], simulated_ep)
|
302 |
+
|
303 |
+
return routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act)
|
304 |
+
|
305 |
+
|
306 |
+
# --------------------------
|
307 |
+
# torch reference
|
308 |
+
# --------------------------
|
309 |
+
|
310 |
+
|
311 |
+
def compute_expt_data_torch(hist, n_expts_tot, n_gates):
|
312 |
+
# offset for each expert
|
313 |
+
device = hist.device
|
314 |
+
token_offs_raw = torch.cumsum(hist, dim=0)
|
315 |
+
token_offs_raw = torch.cat((torch.zeros(1, device=device), token_offs_raw))
|
316 |
+
token_offs_raw = token_offs_raw.int()
|
317 |
+
# maximum number of tiles for all values of `block_m` considered
|
318 |
+
block_ms = [16, 32, 64, 128]
|
319 |
+
if is_hip():
|
320 |
+
block_ms.append(256)
|
321 |
+
if n_gates <= n_expts_tot:
|
322 |
+
max_n_tiles = n_gates
|
323 |
+
else:
|
324 |
+
# ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
|
325 |
+
# ceil_div(x, y): -(-x // y)
|
326 |
+
max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms))
|
327 |
+
# fill up tile offset/infos for each block
|
328 |
+
token_offs_pad = dict()
|
329 |
+
block_pid_map = dict()
|
330 |
+
for block_m in block_ms:
|
331 |
+
n_tiles = (hist + block_m - 1) // block_m # matmul blocks needed
|
332 |
+
token_offs_pad[block_m] = torch.cumsum(n_tiles, dim=0)
|
333 |
+
token_offs_pad[block_m] = torch.cat((torch.zeros(1, device=device), token_offs_pad[block_m]))
|
334 |
+
token_offs_pad[block_m] = token_offs_pad[block_m].int()
|
335 |
+
# compute data required to drive ragged batch matmul
|
336 |
+
block_pid_map[block_m] = -torch.ones(max_n_tiles, device=device)
|
337 |
+
for e in range(n_expts_tot):
|
338 |
+
offset = token_offs_pad[block_m][e]
|
339 |
+
for b in range(n_tiles[e]):
|
340 |
+
block_pid_map[block_m][offset + b] = (b << 16) + e
|
341 |
+
block_pid_map[block_m] = block_pid_map[block_m].int()
|
342 |
+
return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
|
343 |
+
|
344 |
+
|
345 |
+
def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):
|
346 |
+
has_user_provided_indx = expt_indx is not None
|
347 |
+
n_gates_pad = logits.shape[0] * n_expts_act
|
348 |
+
|
349 |
+
if n_rows is not None:
|
350 |
+
logits = logits[:n_rows, :]
|
351 |
+
|
352 |
+
def topk(vals, k, expt_indx):
|
353 |
+
# topk of experts
|
354 |
+
if has_user_provided_indx:
|
355 |
+
tk_indx = expt_indx
|
356 |
+
else:
|
357 |
+
tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
|
358 |
+
tk_indx = tk_indx.long()
|
359 |
+
tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
|
360 |
+
tk_indx = tk_indx.int()
|
361 |
+
return tk_val, tk_indx
|
362 |
+
|
363 |
+
_, n_expts_tot = logits.shape
|
364 |
+
if sm_first:
|
365 |
+
logits = torch.softmax(logits, dim=-1)
|
366 |
+
expt_scal, expt_indx = topk(logits, n_expts_act, expt_indx)
|
367 |
+
if not sm_first:
|
368 |
+
expt_scal = torch.softmax(expt_scal, dim=-1)
|
369 |
+
# sort each token's selections by expert
|
370 |
+
if not has_user_provided_indx:
|
371 |
+
expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
|
372 |
+
expt_scal = torch.gather(expt_scal, 1, sort_indices)
|
373 |
+
# flatten topk data
|
374 |
+
expt_scal = expt_scal.reshape(-1)
|
375 |
+
expt_indx = expt_indx.reshape(-1).to(torch.int32)
|
376 |
+
# sort by expert_id so experts are contiguous for the matmul
|
377 |
+
topk_indx = torch.argsort(expt_indx, stable=True)
|
378 |
+
gate_indx = torch.argsort(topk_indx, stable=True)
|
379 |
+
gate_scal = expt_scal[topk_indx]
|
380 |
+
hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1).int() # histogram of tokens over experts
|
381 |
+
# pack the matmul data structure
|
382 |
+
gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
|
383 |
+
scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
|
384 |
+
# compute expt_data
|
385 |
+
expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad)
|
386 |
+
return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx
|
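Not part of this commit: a minimal usage sketch of the fused routing path, assuming a CUDA device and that the built package is importable as `triton_kernels`.

import torch
from triton_kernels.routing import routing

logits = torch.randn(128, 16, device="cuda", dtype=torch.float16)   # [n_tokens, n_experts]
rdata, gather_indx, scatter_indx = routing(logits, n_expts_act=4)
assert rdata.expt_hist.sum().item() == 128 * 4        # every (token, expert) gate counted once
assert gather_indx.src_indx.numel() == 128 * 4        # flattened gates, sorted by expert
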
build/torch-universal/triton_kernels/routing_details/_expt_data.py
ADDED
@@ -0,0 +1,64 @@
1 |
+
import triton
|
2 |
+
import triton.language as tl
|
3 |
+
|
4 |
+
|
5 |
+
@triton.jit
|
6 |
+
def _cdiv_pow2(n, log2_k):
|
7 |
+
return (n + ((1 << log2_k) - 1)) >> log2_k
|
8 |
+
|
9 |
+
|
10 |
+
@triton.jit
|
11 |
+
def _expt_data_memset(Hist, n_expts_tot, MDStarts, tile_starts_stridem, MDTileInfo, first_tile_dim_log2,
|
12 |
+
SIZES: tl.constexpr, BLOCK: tl.constexpr):
|
13 |
+
|
14 |
+
pid = tl.program_id(0)
|
15 |
+
|
16 |
+
if pid <= SIZES:
|
17 |
+
|
18 |
+
MDStarts += pid * tile_starts_stridem
|
19 |
+
x_tile = tl.zeros([BLOCK], dtype=MDStarts.dtype.element_ty)
|
20 |
+
Tile_ptrs = MDStarts + tl.arange(0, BLOCK)
|
21 |
+
tile_dim_log2 = tl.where(pid == 0, 0, pid + first_tile_dim_log2 - 1)
|
22 |
+
|
23 |
+
for i in range(0, n_expts_tot + 1, BLOCK):
|
24 |
+
|
25 |
+
offs_n = tl.arange(0, BLOCK) + i
|
26 |
+
mask_n0 = offs_n < n_expts_tot
|
27 |
+
hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0)
|
28 |
+
hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2)
|
29 |
+
|
30 |
+
tile_starts = tl.cumsum(hist_tile, 0) + x_tile
|
31 |
+
x_tile += tl.sum(hist_tile, 0).to(MDStarts.dtype.element_ty)
|
32 |
+
tl.store(Tile_ptrs, tile_starts - hist_tile)
|
33 |
+
Tile_ptrs += BLOCK
|
34 |
+
|
35 |
+
else:
|
36 |
+
|
37 |
+
pid -= (SIZES + 1)
|
38 |
+
TileInfoOut = MDTileInfo + pid * BLOCK + tl.arange(0, BLOCK)
|
39 |
+
tl.store(TileInfoOut, 0xffffffff)
|
40 |
+
|
41 |
+
|
42 |
+
@triton.jit
|
43 |
+
def _expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
|
44 |
+
SIZES: tl.constexpr, BLOCK: tl.constexpr):
|
45 |
+
|
46 |
+
pid = tl.program_id(0)
|
47 |
+
|
48 |
+
expt_id = pid // SIZES
|
49 |
+
buff_id = pid % SIZES
|
50 |
+
|
51 |
+
MDTileStarts += buff_id * tile_starts_stridem
|
52 |
+
MDTileInfo += buff_id * tile_info_stridem
|
53 |
+
|
54 |
+
n_tokens = tl.load(Hist + expt_id)
|
55 |
+
tile_dim_log2 = first_tile_dim_log2 + buff_id
|
56 |
+
n_blocks = _cdiv_pow2(n_tokens, tile_dim_log2)
|
57 |
+
|
58 |
+
tile_off = tl.load(MDTileStarts + expt_id)
|
59 |
+
MDTileInfo += tile_off
|
60 |
+
|
61 |
+
for block_off in range(0, n_blocks, BLOCK):
|
62 |
+
block_offs = block_off + tl.arange(0, BLOCK)
|
63 |
+
data = (block_offs << 16) + expt_id
|
64 |
+
tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks)
|
build/torch-universal/triton_kernels/routing_details/_routing_compute.py
ADDED
@@ -0,0 +1,148 @@
1 |
+
import triton
|
2 |
+
import triton.language as tl
|
3 |
+
|
4 |
+
from ._expt_data import _expt_data_compute, _expt_data_memset
|
5 |
+
|
6 |
+
|
7 |
+
@triton.jit
|
8 |
+
def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, # histogram
|
9 |
+
BLOCK_N: tl.constexpr):
|
10 |
+
loop_iterations = (hist_size + BLOCK_N - 1) // BLOCK_N
|
11 |
+
x = tl.zeros([BLOCK_N], ExpertHist.dtype.element_ty)
|
12 |
+
for i in range(loop_iterations):
|
13 |
+
offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
|
14 |
+
mask_n = offs_n < hist_size
|
15 |
+
hist2 = tl.load(ExpertHist + offs_n, mask=mask_n)
|
16 |
+
tok_starts = tl.cumsum(hist2, 0) - hist2 + x
|
17 |
+
x += tl.sum(hist2, 0)
|
18 |
+
tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
|
19 |
+
offs_n += BLOCK_N
|
20 |
+
|
21 |
+
|
22 |
+
@triton.jit
|
23 |
+
def _routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr, expt_id):
|
24 |
+
offs_m = tl.arange(0, BLOCK_M)
|
25 |
+
# iterate over input data
|
26 |
+
curr_sum = 0
|
27 |
+
for _ in range(0, shape_pm, BLOCK_M):
|
28 |
+
offs = offs_m * stride_pm + expt_id * stride_pn
|
29 |
+
curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm)
|
30 |
+
out = tl.cumsum(curr, 0) + curr_sum
|
31 |
+
curr_sum += tl.sum(curr, 0)
|
32 |
+
tl.store(PartialHist + offs, out - curr, mask=offs_m < shape_pm)
|
33 |
+
offs_m += BLOCK_M
|
34 |
+
|
35 |
+
|
36 |
+
@triton.jit
|
37 |
+
def _keyed_add(x, y):
|
38 |
+
|
39 |
+
# we keep the key in the upper 16 bits of a uint32:
|
40 |
+
key_mask: tl.constexpr = 0xffff0000
|
41 |
+
|
42 |
+
kx = x & key_mask
|
43 |
+
ky = y & key_mask
|
44 |
+
z = tl.where(kx == ky, x + y - kx, y)
|
45 |
+
return z
|
46 |
+
|
47 |
+
|
48 |
+
@triton.jit
|
49 |
+
def _routing_compute_indx(pid_m, GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm,
|
50 |
+
stride_pn, TokensStart, n_tokens, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
|
51 |
+
|
52 |
+
if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
|
53 |
+
n_tokens = tl.load(n_tokens)
|
54 |
+
n_gates = n_tokens * N_EXPTS_ACT
|
55 |
+
|
56 |
+
tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)
|
57 |
+
|
58 |
+
local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
|
59 |
+
offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
|
60 |
+
expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)
|
61 |
+
|
62 |
+
# stable-sort by expert ID:
|
63 |
+
kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
|
64 |
+
kv_pairs = tl.sort(kv_pairs, 0)
|
65 |
+
expert = kv_pairs >> 16
|
66 |
+
offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xffff)
|
67 |
+
mask = expert != 0xffff
|
68 |
+
gate_scal = tl.load(ExptScal + offs, mask=mask)
|
69 |
+
|
70 |
+
# compute run lengths in expert-sorted order:
|
71 |
+
x = (kv_pairs & 0xffff0000 | 0x00000001)
|
72 |
+
expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
|
73 |
+
exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xffff
|
74 |
+
|
75 |
+
gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=mask)
|
76 |
+
gates += tl.load(TokensStart + expert, mask=mask)
|
77 |
+
gates += exclusive_run_lengths
|
78 |
+
|
79 |
+
tl.store(ScatterIndx + offs, gates, mask=mask)
|
80 |
+
tl.store(GatherIndx + gates, offs, mask=mask)
|
81 |
+
tl.store(GateScal + gates, gate_scal, mask=mask)
|
82 |
+
|
83 |
+
|
84 |
+
@triton.jit
|
85 |
+
def _combined_routing_compute(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, stride_pn,
|
86 |
+
TokensStart, n_tokens, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr, Hist,
|
87 |
+
MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
|
88 |
+
SIZES: tl.constexpr, BLOCK: tl.constexpr, blocks2a):
|
89 |
+
|
90 |
+
pid = tl.program_id(0)
|
91 |
+
if pid < blocks2a:
|
92 |
+
_expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
|
93 |
+
SIZES, BLOCK)
|
94 |
+
else:
|
95 |
+
pid -= blocks2a
|
96 |
+
_routing_compute_indx(pid, GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm,
|
97 |
+
stride_pn, TokensStart, n_tokens, BLOCK_M, N_EXPTS_ACT)
|
98 |
+
|
99 |
+
|
100 |
+
@triton.jit
|
101 |
+
def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr):
|
102 |
+
pid_m = tl.program_id(0)
|
103 |
+
cutoff_word = cutoff // 32
|
104 |
+
cutoff_bit = cutoff % 32
|
105 |
+
cutoff_mask = (1 << (cutoff_bit)) - 1
|
106 |
+
for start_n in range(0, shape_bn, BLOCK_N):
|
107 |
+
offs_n = start_n + tl.arange(0, BLOCK_N)
|
108 |
+
values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn)
|
109 |
+
values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values)
|
110 |
+
values = tl.where(offs_n > cutoff_word, 0, values)
|
111 |
+
tl.store(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, values, mask=offs_n < shape_bn)
|
112 |
+
|
113 |
+
|
114 |
+
@triton.jit
|
115 |
+
def _combined_routing_memset(Indx, size, sentinel, BLOCK: tl.constexpr, ExpertHist, FinalExpertOffs, hist_size,
|
116 |
+
n_expts_tot, PartialHist, shape_pm, stride_pm, stride_pn, MDStarts, tile_starts_stridem,
|
117 |
+
blocks1a, MDTileInfo, first_tile_dim_log2, SIZES: tl.constexpr, BLOCK_A: tl.constexpr,
|
118 |
+
BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr):
|
119 |
+
"""
|
120 |
+
This kernel essentially combines 6 different pieces of functionality,
|
121 |
+
statically branching on the value of tl.program_id(0) to decide which
|
122 |
+
codepath to take.
|
123 |
+
|
124 |
+
pid == 0: create the token cumsum
|
125 |
+
1 <= pid <= SIZES: create a tile cumsum
|
126 |
+
SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff
|
127 |
+
blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs
|
128 |
+
pid == blocks1a + n_expts_tot: compute_expt_offs
|
129 |
+
pid > blocks1a + n_expts_tot: initialise Indx to sentinel
|
130 |
+
|
131 |
+
As each of these is a relatively trivial workload, launching them from
|
132 |
+
this single trampoline is beneficial as they can execute on different
|
133 |
+
streaming multiprocessors in parallel.
|
134 |
+
"""
|
135 |
+
|
136 |
+
pid = tl.program_id(0)
|
137 |
+
|
138 |
+
if pid < blocks1a:
|
139 |
+
_expt_data_memset(ExpertHist, n_expts_tot, MDStarts, tile_starts_stridem, MDTileInfo, first_tile_dim_log2,
|
140 |
+
SIZES, BLOCK_A)
|
141 |
+
elif pid == n_expts_tot + blocks1a:
|
142 |
+
_routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
|
143 |
+
elif pid < n_expts_tot + blocks1a:
|
144 |
+
_routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M, pid - blocks1a)
|
145 |
+
else:
|
146 |
+
offs = (pid - n_expts_tot - blocks1a - 1) * BLOCK + tl.arange(0, BLOCK)
|
147 |
+
mask = offs < size
|
148 |
+
tl.store(Indx + offs, sentinel, mask=mask)
|
build/torch-universal/triton_kernels/specialize.py
ADDED
@@ -0,0 +1,132 @@
1 |
+
import inspect
|
2 |
+
import re
|
3 |
+
import textwrap
|
4 |
+
import types
|
5 |
+
import triton
|
6 |
+
|
7 |
+
|
8 |
+
def cacheable(f):
|
9 |
+
"""
|
10 |
+
A decorator that allows you to write something of the form:
|
11 |
+
|
12 |
+
@cacheable
|
13 |
+
def my_kernel(): return (expression dynamically defining a kernel)
|
14 |
+
|
15 |
+
such that it interacts gracefully with triton cache and preload.
|
16 |
+
"""
|
17 |
+
|
18 |
+
g = f()
|
19 |
+
g.fn.__name__ = f.__name__
|
20 |
+
g.fn.__module__ = f.__module__
|
21 |
+
g.fn.__qualname__ = f.__qualname__
|
22 |
+
g._fn_name = f"{f.__module__}.{f.__qualname__}"
|
23 |
+
return g
|
24 |
+
|
25 |
+
|
26 |
+
def define_kernel(src, module, attrs=None, **extra_globals):
|
27 |
+
"""
|
28 |
+
Dynamically create a Triton function or kernel from a src string,
|
29 |
+
linking any symbols in the kernel to objects specified by extra_globals.
|
30 |
+
"""
|
31 |
+
|
32 |
+
# create template function
|
33 |
+
def _empty_fn():
|
34 |
+
pass
|
35 |
+
|
36 |
+
gdict = dict(**(_empty_fn.__globals__))
|
37 |
+
gdict.update(extra_globals)
|
38 |
+
f = types.FunctionType(_empty_fn.__code__, gdict)
|
39 |
+
f.__module__ = module.__name__
|
40 |
+
|
41 |
+
src = textwrap.dedent(src)
|
42 |
+
src = src[src.find("def "):]
|
43 |
+
|
44 |
+
stored_functions = []
|
45 |
+
function_name = src[4:].split("(")[0].strip()
|
46 |
+
|
47 |
+
exec_globals = gdict
|
48 |
+
exec_globals.update({"stored_functions": stored_functions})
|
49 |
+
exec(src + "\n\nstored_functions.append(" + function_name + ")\n", exec_globals)
|
50 |
+
|
51 |
+
f.__signature__ = inspect.signature(stored_functions[0])
|
52 |
+
f.__name__ = function_name
|
53 |
+
f.__doc__ = stored_functions[0].__doc__
|
54 |
+
|
55 |
+
if attrs is None:
|
56 |
+
attrs = dict()
|
57 |
+
f = triton.JITFunction(f, **attrs)
|
58 |
+
f._unsafe_update_src(src)
|
59 |
+
return f
|
60 |
+
|
61 |
+
|
62 |
+
def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):
|
63 |
+
assert isinstance(fn, triton.runtime.jit.JITFunction)
|
64 |
+
if name is None:
|
65 |
+
name = f"{fn.__name__}"
|
66 |
+
# Get original source code
|
67 |
+
src = inspect.getsource(fn.fn)
|
68 |
+
src = textwrap.dedent(src)
|
69 |
+
lines = src.split("\n")
|
70 |
+
# Skip decorator and def line
|
71 |
+
def_idx = next(i for i, line in enumerate(lines) if line.strip().startswith("def"))
|
72 |
+
# separate header vs body LOC
|
73 |
+
header_end = def_idx
|
74 |
+
while not lines[header_end].rstrip().endswith(":"):
|
75 |
+
header_end += 1
|
76 |
+
body_lines = lines[header_end + 1:]
|
77 |
+
header_lines = lines[def_idx:header_end + 1]
|
78 |
+
# clean-up header
|
79 |
+
header_clean = [
|
80 |
+
l.split("#", 1)[0].strip() # keep code, discard comment
|
81 |
+
for l in header_lines
|
82 |
+
if l.split("#", 1)[0].strip() # skip blank‑after‑comment lines
|
83 |
+
]
|
84 |
+
# decompose arguments
|
85 |
+
header_src = " ".join(header_clean) # turn it into a single line
|
86 |
+
m = re.search(r"\((.*)\)\s*:", header_src)
|
87 |
+
if not m:
|
88 |
+
raise ValueError("Could not parse function header")
|
89 |
+
args_str = m.group(1)
|
90 |
+
args = [arg.strip() for arg in args_str.split(",") if arg.strip()]
|
91 |
+
non_specialized_args = []
|
92 |
+
for arg in args:
|
93 |
+
arg_key = arg.split(":")[0].split("=")[0].strip()
|
94 |
+
new_args = tuples.get(arg_key, [arg])
|
95 |
+
if arg_key not in constants:
|
96 |
+
non_specialized_args += new_args
|
97 |
+
# add global symbols
|
98 |
+
spec_fns = {v.__name__: v for k, v in constants.items() if isinstance(v, triton.runtime.jit.JITFunction)}
|
99 |
+
globals = spec_fns | fn.get_capture_scope()
|
100 |
+
# build new source code and define kernel dynamically
|
101 |
+
new_signature = f"def {name}({', '.join(non_specialized_args)}):"
|
102 |
+
constexpr_lines = [
|
103 |
+
f" {key}: tl.constexpr = {value.__name__ if callable(value) else value}" for key, value in constants.items()
|
104 |
+
]
|
105 |
+
tuple_lines = [
|
106 |
+
f" {key} = {'(' + ','.join(value) + (',' if len(value)>=1 else '') + ')'}" for key, value in tuples.items()
|
107 |
+
]
|
108 |
+
new_src = "\n".join(["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines)
|
109 |
+
# find function parameters
|
110 |
+
sig = inspect.signature(triton.runtime.jit.JITFunction.__init__)
|
111 |
+
params = list(sig.parameters.values())[2:]
|
112 |
+
attrs = {param.name: getattr(fn, param.name, param.default) for param in params}
|
113 |
+
|
114 |
+
# make a new repr which appends the repr of the specialized functions.
|
115 |
+
base_repr = attrs["repr"]
|
116 |
+
|
117 |
+
def new_repr(specialization):
|
118 |
+
ret = base_repr(specialization)
|
119 |
+
for spec_fn in spec_fns.values():
|
120 |
+
spec_repr = spec_fn.repr(None)
|
121 |
+
if spec_repr:
|
122 |
+
spec_repr = spec_repr.strip("_")
|
123 |
+
if spec_repr:
|
124 |
+
ret += f"_{spec_repr}"
|
125 |
+
return ret
|
126 |
+
|
127 |
+
attrs["repr"] = new_repr
|
128 |
+
|
129 |
+
if do_not_specialize:
|
130 |
+
attrs["do_not_specialize"] = do_not_specialize
|
131 |
+
ret = define_kernel(new_src, module, attrs, **globals)
|
132 |
+
return ret
|
build/torch-universal/triton_kernels/swiglu.py
ADDED
@@ -0,0 +1,100 @@
1 |
+
from dataclasses import dataclass
|
2 |
+
from triton_kernels.numerics import InFlexData, OutFlexData
|
3 |
+
import torch
|
4 |
+
import triton
|
5 |
+
from .swiglu_details._swiglu import _swiglu, _swiglu_fn
|
6 |
+
from triton_kernels import target_info
|
7 |
+
|
8 |
+
|
9 |
+
@dataclass(frozen=True)
|
10 |
+
class FlexCtx:
|
11 |
+
out_data: OutFlexData = OutFlexData()
|
12 |
+
inp_data: InFlexData = InFlexData()
|
13 |
+
saturate_inf: bool = False
|
14 |
+
|
15 |
+
|
16 |
+
@dataclass(frozen=True)
|
17 |
+
class PrecisionConfig:
|
18 |
+
limit: float
|
19 |
+
flex_ctx: FlexCtx = FlexCtx()
|
20 |
+
|
21 |
+
|
22 |
+
swiglu_fn = _swiglu_fn
|
23 |
+
|
24 |
+
|
25 |
+
class SwiGLU(torch.autograd.Function):
|
26 |
+
|
27 |
+
@staticmethod
|
28 |
+
def forward(ctx, a, alpha, precision_config, routing_data):
|
29 |
+
N = a.shape[-1]
|
30 |
+
M = a.numel() // N
|
31 |
+
assert a.stride()[-1] == 1
|
32 |
+
assert a.shape[-1] % 2 == 0
|
33 |
+
out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device)
|
34 |
+
flex_ctx = precision_config.flex_ctx
|
35 |
+
# optimization hyperparameters
|
36 |
+
BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
|
37 |
+
num_warps = 4
|
38 |
+
kwargs = {'maxnreg': 64} if not target_info.is_hip() else {}
|
39 |
+
# launch semi-persistent kernel
|
40 |
+
N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
|
41 |
+
num_sms = target_info.num_sms()
|
42 |
+
if routing_data is not None:
|
43 |
+
waves_per_sm = 32 if target_info.is_hip() else 128
|
44 |
+
num_pid = num_sms * (waves_per_sm // num_warps)
|
45 |
+
M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
|
46 |
+
grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
|
47 |
+
else:
|
48 |
+
M_BLOCKS = triton.cdiv(M, BLOCK_M)
|
49 |
+
if M_BLOCKS * N_BLOCKS >= 8 * num_sms:
|
50 |
+
grid = (8 * num_sms, )
|
51 |
+
else:
|
52 |
+
grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
|
53 |
+
n_tokens = None
|
54 |
+
if routing_data is not None:
|
55 |
+
n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot]
|
56 |
+
_swiglu[grid](
|
57 |
+
flex_ctx.out_data.reinterpret(out),
|
58 |
+
flex_ctx.out_data.expected_scale,
|
59 |
+
flex_ctx.out_data.actual_scale,
|
60 |
+
flex_ctx.out_data.checksum_scale,
|
61 |
+
flex_ctx.inp_data.reinterpret(a),
|
62 |
+
flex_ctx.inp_data.scale,
|
63 |
+
alpha,
|
64 |
+
M,
|
65 |
+
N // 2,
|
66 |
+
a.shape[-1],
|
67 |
+
1,
|
68 |
+
out.shape[-1],
|
69 |
+
1,
|
70 |
+
precision_config.limit,
|
71 |
+
n_tokens,
|
72 |
+
BLOCK_M=BLOCK_M,
|
73 |
+
BLOCK_N=BLOCK_N,
|
74 |
+
EVEN_N=(N // 2) % BLOCK_N == 0,
|
75 |
+
M_BLOCKS=M_BLOCKS,
|
76 |
+
N_BLOCKS=N_BLOCKS,
|
77 |
+
flexpoint_saturate_inf=flex_ctx.saturate_inf,
|
78 |
+
num_warps=num_warps,
|
79 |
+
**kwargs,
|
80 |
+
)
|
81 |
+
out = out.view(a.shape[:-1] + out.shape[-1:])
|
82 |
+
return out
|
83 |
+
|
84 |
+
|
85 |
+
def swiglu(a, alpha, precision_config, routing_data=None):
|
86 |
+
return SwiGLU.apply(a, alpha, precision_config, routing_data)
|
87 |
+
|
88 |
+
|
89 |
+
def swiglu_torch(a, alpha, precision_config):
|
90 |
+
limit = precision_config.limit
|
91 |
+
a_gelu = a[..., ::2]
|
92 |
+
if limit is not None:
|
93 |
+
a_gelu = a_gelu.clamp(max=limit)
|
94 |
+
a_linear = a[..., 1::2]
|
95 |
+
if limit is not None:
|
96 |
+
a_linear = a_linear.clamp(min=-limit, max=limit)
|
97 |
+
|
98 |
+
out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu)
|
99 |
+
out = out_gelu * (a_linear + 1)
|
100 |
+
return out
|
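A minimal usage sketch (not part of this commit; assumes a CUDA device and that the package is importable as `triton_kernels`): it runs the fused kernel and compares it against the eager `swiglu_torch` reference defined above.

import torch
from triton_kernels.swiglu import swiglu, swiglu_torch, PrecisionConfig

x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")  # gate/linear interleaved; last dim must be even
cfg = PrecisionConfig(limit=10.0)  # clamp activations before the gate
y_fused = swiglu(x, 1.702, cfg)
y_ref = swiglu_torch(x, 1.702, cfg)
print(torch.allclose(y_fused.float(), y_ref.float(), atol=1e-2, rtol=1e-2))
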
build/torch-universal/triton_kernels/swiglu_details/_swiglu.py
ADDED
@@ -0,0 +1,102 @@
from triton_kernels.numerics_details.flexpoint import load_scale, float_to_flex, update_scale
import triton
import triton.language as tl


@triton.jit
def clip(x, limit, clip_lower: tl.constexpr):
    res = tl.minimum(x, limit)
    if clip_lower:
        res = tl.maximum(-limit, res)
    return res


@triton.jit
def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr):
    return tl.max(tl.reshape(tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True), axis=1)


def swiglu_repr(specialization):
    signature = specialization.signature
    constants = specialization.constants
    convert_dtype = lambda dtype: "mxfp4" if "u8" in dtype else dtype
    dtypes = "x".join([convert_dtype(f"{signature[i][1:]}") for i in ["Out", "A"]])
    blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N"]])
    return f"_swiglu_{dtypes}_{blocks}"


def swiglu_launch_metadata(grid, kernel, args):
    M, N = args["M"], args["N"]
    ret = dict()
    ret["name"] = f"{kernel.name} [M = {M}, N = {N}]"
    A, Out = args["A"], args["Out"]
    ret["bytes"] = Out.numel() * Out.element_size() + A.numel() * A.element_size()
    return ret


@triton.jit
def compute_swiglu(gelu, linear, scale, alpha, limit):
    gelu = gelu.to(tl.float32) * scale
    if limit is not None:
        gelu = clip(gelu, limit, clip_lower=False)
    linear = linear.to(tl.float32) * scale
    if limit is not None:
        linear = clip(linear, limit, clip_lower=True)
    s = gelu / (1 + tl.exp(-alpha * gelu))
    return tl.fma(s, linear, s)  # (s * (linear + 1))


@triton.jit(repr=lambda _: "_swiglu")
def _swiglu_fn(input, alpha, limit):
    gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
    return compute_swiglu(gelu, linear, 1.0, alpha, limit)


@triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale, alpha, M, N, stride_am, stride_an,
            stride_outm, stride_outn, limit: tl.constexpr, NTokens, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
            EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr):
    if NTokens is not None:
        M = tl.load(NTokens)
        M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M

    local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32)

    a_scale = load_scale(AScale)
    out_expected_scale = load_scale(OutExpectedScale)

    for pid in tl.range(tl.program_id(0), M_BLOCKS * N_BLOCKS, tl.num_programs(0), num_stages=2):
        pid_m = (pid // N_BLOCKS)
        pid_n = (pid % N_BLOCKS)
        off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        mask_m = off_m < M
        mask_n = off_n < N
        packed_off_n = pid_n * BLOCK_N + tl.arange(0, 2 * BLOCK_N) // 2
        packed_mask_n = packed_off_n < N
        packed_mask_n = tl.max_constancy(packed_mask_n, [16])
        # load a
        packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N)
        packed_offs = off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an
        if EVEN_N:
            a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.)
        else:
            if pid_n * BLOCK_N + BLOCK_N <= N:
                a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.)
            else:
                packed_mask = mask_m[:, None] & packed_mask_n[None, :]
                a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.)
        a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))
        out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit)
        # update flexpoint stats and divide by scale
        # we don't need masking because of the `other` when loading `A`
        if OutActualScale is not None:
            absmax = thread_local_absmax(out, out.numel, tl.extra.cuda.num_threads())
            local_max = tl.maximum(local_max, absmax)
        out = float_to_flex(out, out_expected_scale,
                            None,  # ActualScale: local absmax is tracked and updated after the loop
                            OutChecksumScale, None, Out, flexpoint_saturate_inf)
        mask = mask_m[:, None] if EVEN_N else mask_m[:, None] & mask_n[None, :]
        tl.store(Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask)

    update_scale(local_max, OutActualScale, Out)

build/torch-universal/triton_kernels/target_info.py
ADDED
@@ -0,0 +1,77 @@
import torch
import triton

cached_capabilities = {}


def is_cuda():
    if "is_cuda" not in cached_capabilities:
        target = triton.runtime.driver.active.get_current_target()
        cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda"
    return cached_capabilities["is_cuda"]


def is_hip():
    if "is_hip" not in cached_capabilities:
        cached_capabilities["is_hip"] = torch.cuda.is_available() and bool(torch.version.hip)
    return cached_capabilities["is_hip"]


def is_hip_cdna3():
    if "is_hip_cdna3" not in cached_capabilities:
        target = triton.runtime.driver.active.get_current_target()
        cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip'
                                               and target.arch == 'gfx942')
    return cached_capabilities["is_hip_cdna3"]


def is_hip_cdna4():
    if "is_hip_cdna4" not in cached_capabilities:
        target = triton.runtime.driver.active.get_current_target()
        cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip'
                                               and target.arch == 'gfx950')
    return cached_capabilities["is_hip_cdna4"]


def cuda_capability_geq(major, minor=0):
    """
    Determines whether we have compute capability >= (major, minor) and
    returns this as a constexpr boolean. This can be used for guarding
    inline asm implementations that require a certain compute capability.
    """
    if is_hip():
        return False
    if "cuda" not in cached_capabilities:
        if torch.cuda.is_available():
            cached_capabilities["cuda"] = torch.cuda.get_device_capability()
        else:
            cached_capabilities["cuda"] = (0, 0)
    return cached_capabilities["cuda"] >= (major, minor)


def get_cdna_version():
    """
    Gets the AMD architecture version, i.e. CDNA3 or CDNA4; currently
    only supports 3 (gfx942) or 4 (gfx950). Returns -1 if it is not AMD
    hardware or an unsupported architecture.
    """
    target = triton.runtime.driver.active.get_current_target()
    if target.backend != 'hip':
        return -1
    if target.arch == 'gfx942':
        return 3
    if target.arch == 'gfx950':
        return 4
    return -1


def has_tma_gather():
    return cuda_capability_geq(10, 0)


def has_native_mxfp():
    return cuda_capability_geq(10, 0)


def num_sms():
    return torch.cuda.get_device_properties(0).multi_processor_count

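A small sketch (illustrative, not part of the diff) of how these capability checks are typically consumed; the printed values depend on the machine it runs on.

from triton_kernels import target_info

print("HIP backend:", target_info.is_hip())
print("Hopper or newer (>= sm90):", target_info.cuda_capability_geq(9, 0))
print("Native mxfp support (>= sm100):", target_info.has_native_mxfp())
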
build/torch-universal/triton_kernels/tensor.py
ADDED
@@ -0,0 +1,211 @@
from dataclasses import dataclass, fields
from typing import Type

import torch
from triton.tools.tensor_descriptor import TensorDescriptor

from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows
from .target_info import cuda_capability_geq
from .tensor_details.layout import Layout, StridedLayout


@dataclass
class Storage:
    data: torch.Tensor
    layout: Layout = None

    def __post_init__(self):
        assert isinstance(self.data, torch.Tensor)
        if self.layout is None:
            self.layout = StridedLayout(self.data.shape)

    @property
    def device(self):
        return self.data.device

    def is_tma_compliant(self):
        # TMAs didn't exist until Hopper
        if not cuda_capability_geq(9, 0):
            return False
        # TMAs only exist for 2D, 3D, 5D inputs
        if len(self.data.shape) not in [2, 3, 5]:
            return False
        # TMAs need at most one stride equal to 1
        # and all other strides divisible by 16
        strides = list(self.data.stride())
        try:
            major_dim = strides.index(1)
        except ValueError:
            major_dim = -1
        ndim = self.data.ndim
        bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
        compliant = [strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim]
        return all(compliant)

    def make_tma(self, block_shape, transpose=False):
        strides = list(self.data.stride())
        shape = list(self.data.shape)
        # TODO
        # there is an issue w/ column-major TMA; we transpose instead
        transpose = self.data.stride()[-1] != 1
        if transpose:
            block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
            shape = shape[:-2] + [shape[-1], shape[-2]]
            strides = strides[:-2] + [strides[-1], strides[-2]]
        if self.data.dtype == torch.uint8 and self.layout.name is None:
            # physical block size is half logical block size along packed dimension
            indx = strides.index(1)
            block_shape[indx] = block_shape[indx] // 2
            # Pad the inner shape to 128 for mxfp4 weights; TMA requires this when the compiler uses
            # CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B.
            pad = 128
            shape[-1] = (shape[-1] + pad - 1) // pad * pad
        block_shape = self.layout.swizzle_block_shape(block_shape)
        return TensorDescriptor(self.data, shape, strides, block_shape)


@dataclass
class IntegerType:
    bitwidth: int


@dataclass
class FloatType:
    bitwidth_exponent: int
    bitwidth_mantissa: int
    is_signed: bool

    def __post_init__(self):
        self.bitwidth = int(self.is_signed) + self.bitwidth_exponent + self.bitwidth_mantissa


BIT = IntegerType(1)
FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True)


def bitwidth(type: IntegerType | FloatType | torch.dtype):
    if isinstance(type, torch.dtype):
        return type.itemsize * 8
    return type.bitwidth


@dataclass
class Tensor:
    storage: Storage | torch.Tensor
    dtype: IntegerType | FloatType | torch.dtype = None
    shape: list[int] | None = None
    shape_max: list[int] | None = None

    def __post_init__(self):
        # set storage
        if isinstance(self.storage, torch.Tensor):
            self.storage = Storage(self.storage)
        # initialize dtype
        if self.dtype is None:
            self.dtype = self.storage.data.dtype
        if bitwidth(self.dtype) < 8 and self.shape is None:
            raise ValueError("shape must be provided for sub-byte types")
        # initialize shape
        if self.shape is None:
            self.shape = list(self.storage.data.shape)
        # validate shape: all elements must be `int` or numel-1 `torch.Tensor`
        is_int = lambda s: isinstance(s, int)
        is_item = lambda s: hasattr(s, "numel") and s.numel() == 1
        assert all(map(lambda s: is_int(s) or is_item(s), self.shape))
        # initialize shape_max
        if self.shape_max is None:
            self.shape_max = [None] * len(self.shape)
        for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
            if smax is not None and not is_int(smax):
                raise ValueError(f"shape_max[{i}] must be `int` or `None`; got {type(smax)}")
            if smax is None:
                self.shape_max[i] = s
        # validate shape_max: all elements must be `int`
        assert all(map(is_int, self.shape_max))

    # torch compatibility layer
    @property
    def ndim(self):
        return len(self.shape)

    @property
    def device(self):
        return self.storage.device

    def stride(self, i=None):
        return self.storage.data.stride() if i is None else self.storage.data.stride(i)

    def data_ptr(self):
        return self.storage.data.data_ptr()

    def numel(self):
        return self.storage.data.numel()

    def element_size(self):
        return bitwidth(self.dtype) // 8

    @property
    def data(self):
        t = self.storage
        return t.data if isinstance(t, Storage) else t

    def dim(self):
        return self.ndim

    def size(self, i=None):
        if i is None:
            return self.shape
        return self.shape[i]


@dataclass
class Bitmatrix(Tensor):
    """
    Represents a boolean matrix in a packed format where each element occupies
    a single bit of memory.

    `scratchpad` is either None or an all-zero array of size >= shape[-1]; we pass it along
    with the actual bitmatrix to avoid having to launch a separate memset
    kernel when we call Bitmatrix::sum().
    """

    scratchpad: torch.Tensor = None

    def __init__(self, storage, shape, shape_max=None, scratchpad=None):
        super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
        self.scratchpad = scratchpad

    def sum(self, partials_block_size):
        _, n_cols = self.shape
        dev = self.device
        if self.scratchpad is None:
            self.scratchpad = clear_sums(n_cols, dev)
        out_ret = self.scratchpad[:n_cols]
        self.scratchpad = None  # throw error if we try to sum again
        return sum_bitmatrix_rows(self, out_ret, partials_block_size)


def get_layout(tensor: torch.Tensor | Tensor | None):
    if tensor is None:
        return None
    if isinstance(tensor, Tensor):
        return tensor.storage.layout
    return StridedLayout


def wrap_torch_tensor(torch_tensor, dtype=None):
    if dtype is None:
        dtype = torch_tensor.dtype
    shape = list(torch_tensor.shape)
    shape[torch_tensor.stride().index(1)] *= bitwidth(torch_tensor.dtype) // bitwidth(dtype)
    return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)


def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
    assert isinstance(tensor, Tensor)
    old_storage = tensor.storage
    old_data = old_storage.layout.unswizzle_data(old_storage.data)
    new_layout = layout_cls(old_data.shape, **layout_kwargs)
    new_data = new_layout.swizzle_data(old_data)
    attrs = {k.name: getattr(tensor, k.name) for k in fields(tensor) if k.name != "storage"}
    return Tensor(Storage(new_data, new_layout), **attrs)

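A quick sketch (illustrative, not part of the commit) of wrapping a packed mxfp4 weight with `wrap_torch_tensor`: the logical shape doubles along the contiguous axis because each stored byte holds two fp4 values.

import torch
from triton_kernels.tensor import wrap_torch_tensor, FP4, bitwidth

packed = torch.randint(0, 256, (8, 32), dtype=torch.uint8)  # 8 x 64 fp4 values, two per byte
t = wrap_torch_tensor(packed, dtype=FP4)
print(t.shape)            # [8, 64] -- logical element count
print(bitwidth(t.dtype))  # 4
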
build/torch-universal/triton_kernels/tensor_details/layout.py
ADDED
@@ -0,0 +1,32 @@
from .layout_details.base import Layout
from .layout_details.blackwell_scale import BlackwellMXScaleLayout
from .layout_details.hopper_scale import HopperMXScaleLayout
from .layout_details.hopper_value import HopperMXValueLayout
from .layout_details.strided import StridedLayout
from ..target_info import cuda_capability_geq

__all__ = [
    "Layout",
    "BlackwellMXScaleLayout",
    "HopperMXScaleLayout",
    "HopperMXValueLayout",
    "StridedLayout",
]


def make_default_matmul_mxfp4_w_layout(mx_axis: int):
    if cuda_capability_geq(10):
        return StridedLayout, dict()
    elif cuda_capability_geq(9):
        return HopperMXValueLayout, {"mx_axis": mx_axis}
    else:
        return StridedLayout, dict()


def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
    if cuda_capability_geq(10):
        return BlackwellMXScaleLayout, dict()
    elif cuda_capability_geq(9):
        return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
    else:
        return StridedLayout, dict()

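The two helpers above pick a weight or scale layout from the detected GPU. A hedged sketch of the intended call pattern, combined with `convert_layout` and `wrap_torch_tensor` from tensor.py (the weight tensor name is a hypothetical placeholder):

from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4
from triton_kernels.tensor_details.layout import make_default_matmul_mxfp4_w_layout

layout_cls, layout_kwargs = make_default_matmul_mxfp4_w_layout(mx_axis=1)
print(layout_cls.__name__, layout_kwargs)  # e.g. HopperMXValueLayout on sm90, StridedLayout elsewhere
# w_packed_u8 (hypothetical) would be a uint8 tensor of packed mxfp4 weights:
# w = convert_layout(wrap_torch_tensor(w_packed_u8, dtype=FP4), layout_cls, **layout_kwargs)
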
build/torch-universal/triton_kernels/tensor_details/layout_details/base.py
ADDED
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod


class Layout(ABC):

    def __init__(self, shape) -> None:
        self.initial_shape = shape

    @abstractmethod
    def swizzle_data(self, data):
        pass

    @abstractmethod
    def unswizzle_data(self, data):
        pass

    @abstractmethod
    def swizzle_block_shape(self, block_shape):
        pass

build/torch-universal/triton_kernels/tensor_details/layout_details/blackwell_scale.py
ADDED
@@ -0,0 +1,58 @@
import math
import triton
import triton.language as tl
import torch
from .base import Layout

SWIZZLE_ALIGN_INNER = 8
SWIZZLE_SIZE_INNER = 4
SWIZZLE_SIZE_OUTER = 128


class BlackwellMXScaleLayout(Layout):
    name: str = "BLACKWELL_SCALE"

    def __init__(self, shape) -> None:
        super().__init__(shape)
        *self.leading_shape, self.K, self.N, = shape
        self.B = math.prod(self.leading_shape)
        self.ALIGN_K = 8
        self.ALIGN_N = 128
        self.SWIZZLE_K = 4
        self.K_pad = (self.K + self.ALIGN_K - 1) // self.ALIGN_K * self.ALIGN_K
        self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N

    def swizzle_data(self, data):
        data = torch.nn.functional.pad(data, (0, self.N_pad - self.N, 0, self.K_pad - self.K))
        data = data.transpose(-1, -2).contiguous()
        data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.ALIGN_N // 32, 32, self.K_pad // self.SWIZZLE_K,
                            self.SWIZZLE_K)
        data = data.transpose(2, 4).contiguous()
        data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256)
        return data

    def unswizzle_data(self, data):
        data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.K_pad // self.SWIZZLE_K, 32, self.ALIGN_N // 32,
                            self.SWIZZLE_K)
        data = data.transpose(2, 4)
        data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad)
        data = data.transpose(-1, -2)
        return data[..., :self.K, :self.N]

    def swizzle_block_shape(self, block_shape):
        MX_PACK_DIVISOR = 32
        MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR
        return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256]


@triton.jit
def unswizzle_mx_scale_bw(x, SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER,
                          SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER,
                          ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER):
    shape_0: tl.constexpr = x.shape[0]
    shape_1: tl.constexpr = x.shape[1]
    tl.static_assert(shape_1 % SIZE_OUTER == 0)
    tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER)
    x = x.reshape(shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER)
    x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER)
    return x

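A round-trip sanity sketch (an assumption about intended usage, not part of the commit): `unswizzle_data` should invert `swizzle_data` once the alignment padding is cropped away.

import torch
from triton_kernels.tensor_details.layout_details.blackwell_scale import BlackwellMXScaleLayout

scales = torch.randint(0, 256, (96, 160), dtype=torch.uint8)  # (K, N) microscaling factors
layout = BlackwellMXScaleLayout(scales.shape)
swizzled = layout.swizzle_data(scales)
assert torch.equal(layout.unswizzle_data(swizzled), scales)
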
build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_scale.py
ADDED
@@ -0,0 +1,80 @@
import torch
import triton
import triton.language as tl
from .base import Layout


class HopperMXScaleLayout(Layout):
    name: str = "HOPPER_SCALE"

    def __init__(self, shape, mx_axis, num_warps=8) -> None:
        assert num_warps & (num_warps - 1) == 0, "num_warps must be a power of 2"
        super().__init__(shape)
        self.mx_axis = mx_axis
        self.num_warps = num_warps
        *self.leading_shape, _, _ = shape

    def _maybe_mT(self, data):
        if self.mx_axis == len(self.leading_shape):
            return data.contiguous().mT
        return data

    def swizzle_data(self, data):
        data = self._maybe_mT(data).contiguous()
        *batch, M, K = data.shape
        SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 8
        SWIZZLE_ALIGN_K = 2
        pad_m = (SWIZZLE_ALIGN_M - (M % SWIZZLE_ALIGN_M)) % SWIZZLE_ALIGN_M
        pad_k = (SWIZZLE_ALIGN_K - (K % SWIZZLE_ALIGN_K)) % SWIZZLE_ALIGN_K
        data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m))
        *batch, M, K = data.shape
        assert data.is_contiguous()
        assert M % (
            2 * self.num_warps * 2 *
            8) == 0 and K % 2 == 0, f"Input tensor must have a subtile of shape (..., {2 * self.num_warps * 2 * 8}, 2)"
        b = len(batch)
        data = data.reshape(*batch, M // (2 * self.num_warps * 2 * 8), 2, self.num_warps, 2, 8, K // 2, 2)
        perm = [0, 2, 5, 1, 4, 6, 3]
        perm = list(range(b)) + [b + p for p in perm]
        data = data.permute(*perm)
        data = data.flatten(-5, -1)
        data = data.flatten(-3, -2)
        assert data.shape[-2] == M // 32
        assert data.shape[-1] == K * 32
        data = self._maybe_mT(data)
        return data

    def unswizzle_data(self, data):
        data = self._maybe_mT(data)
        *batch, M, K = data.shape
        b = len(batch)
        data = data.reshape(*batch, M // self.num_warps, self.num_warps, K // 64, 2, 8, 2, 2)
        perm = [0, 3, 1, 6, 4, 2, 5]
        perm = list(range(b)) + [b + p for p in perm]
        data = data.permute(*perm)
        data = data.reshape(*batch, M * 32, K // 32)
        data = self._maybe_mT(data)
        return data

    def swizzle_block_shape(self, block_shape):
        return block_shape


@triton.jit
def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constexpr):
    """
    Triton inverse of swizzle_mxfp4_scale_hopper
    """
    tl.static_assert(len(x.shape) == 2, "NYI")
    # implementation assumes mxfp data is packed along the last dimension
    x = x.trans() if mx_axis == 0 else x
    M: tl.constexpr = x.shape[0]
    K: tl.constexpr = x.shape[1]
    tl.static_assert(M % num_warps == 0, f"M must be divisible by {num_warps}. Got {M}")
    tl.static_assert(K % 64 == 0, f"K must be divisible by 64. Got {K}")
    x = x.reshape(M // num_warps, num_warps, K // 64, 2, 8, 2, 2)
    x = x.trans(0, 3, 1, 6, 4, 2, 5)
    x = x.reshape(M * 32, K // 32)
    # implementation assumes mxfp data is packed along the last dimension
    x = x.trans() if mx_axis == 0 else x
    return x

build/torch-universal/triton_kernels/tensor_details/layout_details/hopper_value.py
ADDED
@@ -0,0 +1,323 @@
import torch
import triton
import triton.language as tl
from .base import Layout


def right_shift_unsigned(x, shift):
    return (x >> shift) & ((1 << (32 - shift)) - 1)


# -----------------------------------------------------------------------
# Interleave the bits of four consecutive fp4 values (i.e. 16 bits) as:
#   1000000111000000 (first fp4)
#   1000000111000000 (second fp4)
#   1000000111000000 (third fp4)
#   0110110000000000 (fourth fp4)
# This is done so that dequantization can be done in 14 SASS instructions
# -----------------------------------------------------------------------


def _compress_fp4(x):
    x = x.to(torch.int32)
    return ((x & 0x8) << 12) | ((x & 0x7) << 6)


def _compress_fourth(x):
    x = x.to(torch.int32)
    return ((x & 0x8) << 11) | ((x & 0x6) << 9) | ((x & 0x1) << 13)


def _pack_bits(x: torch.Tensor, mx_axis: int):
    x = x.contiguous()
    assert x.shape[-1] % 4 == 0, "Input tensor must have a last dimension divisible by 4"
    x = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4))
    first = _compress_fp4(x[..., 0]) | (_compress_fp4(x[..., 0] >> 4) << 16)
    second = _compress_fp4(x[..., 1]) | (_compress_fp4(x[..., 1] >> 4) << 16)
    third = _compress_fp4(x[..., 2]) | (_compress_fp4(x[..., 2] >> 4) << 16)
    fourth = _compress_fourth(x[..., 3]) | (_compress_fourth(x[..., 3] >> 4) << 16)
    x = first | right_shift_unsigned(second, 3) | right_shift_unsigned(third, 6) | fourth
    assert x.is_contiguous()
    x = x.view(torch.uint8)
    return x


# -----------------------------------------------------------------------
# inverse operation of _pack_bits
# -----------------------------------------------------------------------


def _bf16_to_fp4e2m1(x):
    # 0bAxxxxxxBCDxxxxxx (int16) -> 0b0000ABCD (uint8)
    assert x.dtype == torch.int16
    s = (right_shift_unsigned(x, 15) & 0x1) << 3
    em = right_shift_unsigned(x, 6) & 0x7
    return (s | em).to(torch.uint8)


def _bf16x2_to_fp4e2m1x2(x):
    # 0bAxxxxxxBCDxxxxxx_0bExxxxxxFGHxxxxxx (int32) -> 0bABCD_EFGH (uint8)
    assert x.dtype == torch.int32
    lo = (x & 0xFFFF).to(torch.int16)
    hi = (right_shift_unsigned(x, 16) & 0xFFFF).to(torch.int16)
    ret_lo = _bf16_to_fp4e2m1(lo)
    ret_hi = _bf16_to_fp4e2m1(hi)
    return ret_lo | (ret_hi << 4)


def _unpack_bits(x, mx_axis: int):
    x = x.view(torch.int32)
    m = 0b10000001110000001000000111000000
    a = (x << 1) & 0b10000000000000001000000000000000
    b = right_shift_unsigned(x, 3) & 0b00000001100000000000000110000000
    c = right_shift_unsigned(x, 7) & 0b00000000010000000000000001000000
    unpacked = [x & m, (x << 3) & m, (x << 6) & m, (a | b) | c]
    x = torch.stack(unpacked, dim=-1)
    x = x.flatten(-2, -1)
    x = _bf16x2_to_fp4e2m1x2(x)
    return x


# -----------------------------------------------------------------------


class HopperMXValueLayout(Layout):
    name: str = "HOPPER_VALUE"

    def __init__(self, shape, mx_axis, mma_version=3):
        super().__init__(shape)
        assert mx_axis in range(len(shape))
        self.mx_axis = mx_axis
        self.mma_version = mma_version
        *self.leading_shape, self.K, self.N, = shape

    def _maybe_mT(self, data):
        if self.mx_axis == len(self.leading_shape):
            return data.mT
        return data

    def swizzle_data(self, data):
        """
        Given a uint8 tensor of shape (*, M, K), returns a tensor of shape
        (*, M // 4, K * 4) such that:

        1) Groups contiguously all the elements owned by the same thread of 4
           mma tiles along the K axis. The following animation shows a similar
           grouping for 2 tiles along M and 2 tiles along K rather than 4 along K
           as done here:
           https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif

        2) Moves the elements belonging to thread 4-7 to be contiguous with those
           from thread 0-3. This is done to get a full cache line when loading them
           from HBM.

        mx_axis selects the lhs or rhs of the matmul.

        WARNING: Assumes that the matmul will be done in bf16 or fp16!
                 Implementing it for fp8 is as easy as making the tile size (8, 8)
        """
        batch = data.ndim - 2
        assert batch >= 0
        assert self.mma_version in (2, 3)
        data = self._maybe_mT(data)
        init_shape = data.shape

        # We are loading 8 bf16 elements per thread to use ld.global.v4
        # Every u8 represents 2 mxfp4 elements
        u8_kwidth = 8 // 2 if self.mma_version == 2 else 1

        # Pack the 4 // u8_kwidth subtiles of an mma into a u4x8
        contig = (1, u8_kwidth)
        scott_trick = (2, 1)
        threads = (4, 4)
        warp_tile = (2, 2)
        k_tile = (1, 4 // u8_kwidth)

        sizes = list(data.shape[:-2])
        pads = []
        # [rest, K, tile, threads] per dimension
        for i, (a, b, c, s, d) in enumerate(zip(k_tile, warp_tile, threads, scott_trick, contig)):
            pack = a * b * c * s * d
            size = data.shape[batch + i]
            pad = (pack - size % pack) % pack
            pads += [(0, pad)]
            sizes.append((size + pad) // pack)
            sizes += [a, b, c, s, d]

        pads = tuple(x for t in pads[::-1] for x in t)
        data = torch.nn.functional.pad(data, pads)
        init_shape = data.shape
        # 0: rest[0]
        # 1: k_tile[0]
        # 2: warp_tile[0]
        # 3: threads[0]
        # 4: scott_trick[0]
        # 5: contig[0]
        # 6: rest[1]
        # 7: k_tile[1]
        # 8: warp_tile[1]
        # 9: threads[1]
        # 10: scott_trick[1]
        # 11: contig[1]
        data = data.view(*sizes)
        # Want [rest[0], threads[0], rest[1], scott_trick[0], scott_trick[0], threads[1], contig[1], contig[0], k_tile[1], k_tile[0], warp_tile[1], warp_tile[0]]
        perm = [0, 3, 6, 10, 4, 9, 7, 1, 8, 2, 5, 11]
        perm = list(range(batch)) + [batch + p for p in perm]
        data = data.permute(*perm).contiguous()
        # These are views
        data = data.flatten(-10, -1)
        data = data.flatten(-3, -2)
        assert data.is_contiguous()
        assert data.shape[-2] == init_shape[-2] // 4
        assert data.shape[-1] == init_shape[-1] * 4
        # twiddle the bits
        data = _pack_bits(data, self.mx_axis)
        data = self._maybe_mT(data)
        return data

    def unswizzle_data(self, data):
        data = self._maybe_mT(data)
        data = _unpack_bits(data, self.mx_axis)
        *batch, M, K = data.shape
        # We have two times the elements if we already upcasted to bfloat16
        mult = 2 if data.dtype == torch.bfloat16 else 1
        assert M % 4 == 0, "M must be divisible by 4"
        assert K % (4 * 8 * 2 * 2 * mult) == 0, f"K must be divisible by {4 * 8 * 2 * 2 * mult}"
        # We are loading 8 bf16 elements per thread to use ld.global.v4
        # Every u8 represents 2 mxfp4 elements
        u8_kwidth = 8 // 2 if self.mma_version == 2 else 1
        data = data.reshape(*batch, M // 4, 4, K // (4 * 8 * 2 * 2 * mult), 2, 4, 8 // u8_kwidth, 2, u8_kwidth * mult)
        b = len(batch)
        perm = [0, 6, 1, 3, 2, 5, 4, 7]
        perm = list(range(b)) + [b + p for p in perm]
        data = data.permute(*perm)
        data = data.reshape(*batch, M * 4, K // 4)
        data = self._maybe_mT(data)
        return data[..., :self.K, :self.N]

    def swizzle_block_shape(self, block_shape):
        return block_shape


@triton.jit
def _unshuffle_triton(x, mma_version: tl.constexpr):
    """
    Triton inverse of swizzle_mxfp4_value_hopper
    """
    tl.static_assert(mma_version == 2 or mma_version == 3, "mma_version must be 2 or 3")
    # if mx_axis == 0:
    #     x = x.trans()

    # We have two times the elements if we already upcasted to bfloat16
    mult: tl.constexpr = 2 if x.dtype == tl.bfloat16 else 1
    M: tl.constexpr = x.shape[0]
    K: tl.constexpr = x.shape[1]
    tl.static_assert(M % 4 == 0, "M must be divisible by 4")
    tl.static_assert(K % (4 * 8 * 2 * 2 * mult) == 0, f"K must be divisible by {4 * 8 * 2 * 2 * mult}")

    # We are loading 8 bf16 elements per thread to use ld.global.v4
    # Every u8 represents 2 mxfp4 elements
    u8_kwidth: tl.constexpr = 8 // 2 if mma_version == 2 else 1
    x = x.reshape(M // 4, 4, K // (4 * 8 * 2 * 2 * mult), 2, 4, 8 // u8_kwidth, 2, u8_kwidth * mult)
    x = x.trans(0, 6, 1, 3, 2, 5, 4, 7)
    x = x.reshape(M * 4, K // 4)
    # if mx_axis == 0:
    #     x = x.trans()
    return x


@triton.jit
def _unpack_fp4_to_bf16_triton(x):
    # For now we implement just H100 support (mul.bf16x2)
    # A100 support is possible via fma
    r0, r1 = tl.inline_asm_elementwise(
        r"""
        {
            .reg .b32 b, c, d<7>, scale;
            .reg .b32 bias;
            mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
            // We add the missing bias to the scale directly
            and.b32 $0, $4, 0b10000001110000001000000111000000;
            mul.bf16x2 $0, $0, bias;
            shl.b32 b, $4, 3;
            and.b32 $1, b, 0b10000001110000001000000111000000;
            mul.bf16x2 $1, $1, bias;
            shl.b32 c, $4, 6;
            and.b32 $2, c, 0b10000001110000001000000111000000;
            mul.bf16x2 $2, $2, bias;
            // Unpack last two elements
            shl.b32 d0, $4, 1;
            and.b32 d1, d0, 0b10000000000000001000000000000000;
            shr.b32 d2, $4, 3;
            and.b32 d3, d2, 0b00000001100000000000000110000000;
            or.b32 d4, d1, d3;
            shr.b32 d5, $4, 7;
            and.b32 d6, d5, 0b00000000010000000000000001000000;
            or.b32 $3, d4, d6;
            mul.bf16x2 $3, $3, bias;
        }
        """,
        constraints="=r,=r,=r,=r,r",
        args=[x],
        dtype=(tl.bfloat16, tl.bfloat16),
        is_pure=True,
        pack=4,
    )
    # Concat each pack of 4
    x = tl.join(r0, r1)
    x = x.reshape(x.shape[0], x.shape[1] // 4, 4, x.shape[2])
    x = x.trans(0, 1, 3, 2)
    x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])
    return x


@triton.jit
def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
    """
    Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements):
      (x << 0) & 0b1000000111000000
      (x << 3) & 0b1000000111000000
      (x << 6) & 0b1000000111000000
      ((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000)
    """
    # upcast values to bfloat16
    tl.static_assert(len(x.shape) == 2)
    tl.static_assert(mx_axis == 0 or mx_axis == 1, "mx_axis must be 0 or 1")
    tl.static_assert(x.shape[1] % 4 == 0)
    tl.static_assert(x.dtype == tl.uint8)
    if mx_axis == 0:
        x = x.trans()
    x = _unpack_fp4_to_bf16_triton(x)
    x = _unshuffle_triton(x, mma_version=3)
    if mx_axis == 0:
        x = x.trans()

    # upcast scale to bfloat16
    # Add bias missing from the bf16 upcasting sequence
    # triton / LLVM generates terrible code for this sequence
    # scale = scale.to(tl.uint16)
    # scale = scale << 7
    # scale = scale.to(tl.bfloat16, bitcast=True)
    scale = tl.inline_asm_elementwise(
        r"""
        {
            prmt.b32 $0, $2, 0, 0x5140;
            shl.b32 $0, $0, 7;
            prmt.b32 $1, $2, 0, 0x7362;
            shl.b32 $1, $1, 7;
        }
        """,
        constraints="=r,=r,r",
        args=[scale],
        dtype=tl.bfloat16,
        is_pure=True,
        pack=4,
    )
    # Broadcast scale
    scale = scale.expand_dims(mx_axis + 1)
    scale = scale.broadcast_to(scale.shape[:mx_axis + 1] + [32] + scale.shape[mx_axis + 2:])
    scale = scale.reshape(x.shape)

    # Combine scale and x
    x = x * scale
    return x

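A small host-side check (illustrative, not part of the commit) that the bit-twiddling helpers invert each other; this is independent of the surrounding tile shuffling and runs on CPU.

import torch
from triton_kernels.tensor_details.layout_details.hopper_value import _pack_bits, _unpack_bits

x = torch.randint(0, 256, (16, 64), dtype=torch.uint8)  # each byte holds two fp4 values
packed = _pack_bits(x, mx_axis=1)
assert packed.shape == x.shape and packed.dtype == torch.uint8
assert torch.equal(_unpack_bits(packed, mx_axis=1), x)
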
build/torch-universal/triton_kernels/tensor_details/layout_details/strided.py
ADDED
@@ -0,0 +1,17 @@
from .base import Layout


class StridedLayout(Layout):
    name: str = None

    def __init__(self, shape) -> None:
        super().__init__(shape)

    def swizzle_data(self, data):
        return data

    def unswizzle_data(self, data):
        return data

    def swizzle_block_shape(self, block_shape):
        return block_shape

build/torch-universal/triton_kernels/testing.py
ADDED
@@ -0,0 +1,192 @@
import enum
import functools
import os
import subprocess
import sys
import torch
from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5


def assert_equal(ref, tri):
    if isinstance(ref, torch.Tensor):
        assert torch.all(ref == tri)
    else:
        assert ref == tri


def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True):
    """
    Compare reference values against obtained values.
    """
    if tri.dtype.itemsize == 1:
        ref_as_type = ref.to(tri.dtype)
        if ref.dtype == tri.dtype:
            assert torch.all(ref_as_type == tri)
            return
        ref = ref_as_type

    if maxtol is None:
        maxtol = 2e-2
    if rmstol is None:
        rmstol = 4e-3

    # cast to float32:
    ref = ref.to(torch.float32).detach()
    tri = tri.to(torch.float32).detach()
    assert ref.shape == tri.shape, f"Tensors must have same size {ref.shape=} {tri.shape=}"

    # deal with infinite elements:
    inf_mask_ref = torch.isinf(ref)
    inf_mask_tri = torch.isinf(tri)
    assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensor must have same infinite elements"
    refn = torch.where(inf_mask_ref, 0, ref)
    trin = torch.where(inf_mask_tri, 0, tri)

    # normalise so that RMS calculation doesn't overflow:
    eps = 1.0e-30
    multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps)
    refn *= multiplier
    trin *= multiplier

    ref_rms = torch.sqrt(torch.square(refn).mean()) + eps

    rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn))
    max_err = torch.max(rel_err).item()
    rms_err = torch.sqrt(torch.square(rel_err).mean()).item()

    if verbose:
        print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol))
        print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol))

    if max_err > maxtol:
        bad_idxs = torch.nonzero(rel_err > maxtol)
        num_nonzero = bad_idxs.size(0)
        bad_idxs = bad_idxs[:1000]
        print("%d / %d mismatched elements (shape = %s) at coords %s" %
              (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist()))

        bad_idxs = bad_idxs.unbind(-1)
        print("ref values: ", ref[tuple(bad_idxs)].cpu())
        print("tri values: ", tri[tuple(bad_idxs)].cpu())

    assert max_err <= maxtol
    assert rms_err <= rmstol


class ComputeSanitizerTool(enum.Enum):
    MEMCHECK = "memcheck"
    RACECHECK = "racecheck"
    SYNCCHECK = "synccheck"
    INITCHECK = "initcheck"


def compute_sanitizer(**target_kwargs):
    """
    Decorator to run a test with compute sanitizer enabled and the pytorch caching allocator disabled,
    to expose potential memory access errors.
    This decorator requires the `request` fixture to be present.
    If the `run_sanitizer` argument is present and set to False, the sanitizer is not run.
    Running tests under compute sanitizer requires launching a subprocess and is slow,
    so use sparingly.
    """

    def decorator(test_fn):

        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1":
                test_fn(*args, **kwargs)
                return

            import psutil

            if target_kwargs.pop("clear_torch_cache", False):
                # If we don't pop clear_torch_cache, it won't pass the
                # target_kwargs.items() <= kwargs.items() condition below.
                torch.cuda.empty_cache()
            tools_to_check = target_kwargs.pop("tools_to_check", [ComputeSanitizerTool.MEMCHECK])
            assert isinstance(tools_to_check, list), f"{tools_to_check=}"
            assert all(tool in ComputeSanitizerTool for tool in tools_to_check), (
                f"{(tool for tool in tools_to_check if tool not in ComputeSanitizerTool)=}")

            ppid_name = psutil.Process(os.getppid()).exe()
            run_compute_sanitizer = target_kwargs.items() <= kwargs.items()
            if "run_sanitizer" in kwargs:
                run_compute_sanitizer &= kwargs["run_sanitizer"]
            if run_compute_sanitizer and "compute-sanitizer" not in ppid_name:
                for tool in tools_to_check:
                    # get path of current file
                    path = os.path.realpath(test_fn.__globals__["__file__"])
                    env = {
                        "PATH": os.environ["PATH"],
                        "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
                        "TORCH_SHOW_CPP_STACKTRACES": "1",
                        "CUDA_LAUNCH_BLOCKING": "1",
                    }
                    if "CUDA_VISIBLE_DEVICES" in os.environ:
                        env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"]
                    assert "request_fixture" in kwargs, (
                        "memcheck'ed test must have a (possibly unused) `request` fixture")
                    test_id = kwargs["request_fixture"].node.callspec.id
                    cmd = f"{path}::{test_fn.__name__}[{test_id}]"
                    cmd = [
                        "compute-sanitizer",
                        "--target-processes=application-only",
                        "--destroy-on-device-error=context",
                        f"--tool={tool.value}",
                        sys.executable,
                        "-m",
                        "pytest",
                        "-vsx",
                        cmd,
                    ]
                    for opt in ["--update_checksum", "--ignore_checksum_error"]:
                        if opt in sys.argv:
                            cmd.append(opt)
                    out = subprocess.run(
                        cmd,
                        stdout=subprocess.PIPE,
                        stderr=subprocess.STDOUT,
                        env=env,
                    )
                    sanitizer_ok = "ERROR SUMMARY: 0 errors" in str(
                        out.stdout) or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout)
                    test_output = out.stdout
                    if type(test_output) is bytes:
                        test_output = test_output.decode()

                    fail = False
                    if not sanitizer_ok:
                        print("compute-sanitizer returned an error")
                        fail = True
                    elif out.returncode != 0:
                        print(
                            "The test failed due to some other reason: consider running without compute-sanitizer to verify."
                        )
                        print(f"{out.returncode=}")
                        fail = True

                    if fail:
                        print("*****************************************************")
                        print("******************** TEST OUTPUT ********************")
                        print("*****************************************************")
                        print(test_output)
                        print("*****************************************************")
                        print("****************** TEST OUTPUT END ******************")
                        print("*****************************************************")
                        assert None
            else:
                test_fn(*args, **kwargs)

        return wrapper

    return decorator


def compute_actual_scale(x, dtype):
    max_finite = {
        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
    }[dtype]
    return x.abs().max() / max_finite

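A sketch of how `assert_close` is meant to be used in a test (illustrative; the tensors here are stand-ins for a Triton kernel output and its PyTorch reference, and the tolerances shown are just its defaults):

import torch
from triton_kernels.testing import assert_close

ref = torch.randn(4096)
tri = ref + 1e-4 * torch.randn_like(ref)  # pretend this came from a Triton kernel
assert_close(ref, tri, maxtol=2e-2, rmstol=4e-3, description="demo")
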
build/torch-universal/triton_kernels/topk.py
ADDED
@@ -0,0 +1,92 @@
import torch
import triton
from triton_kernels.topk_details._topk_forward import _topk_forward
from triton_kernels.topk_details._topk_backward import _topk_backward
from triton_kernels.tensor import Tensor, Bitmatrix


def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
    if not isinstance(x, Tensor):
        x_shape = [x.shape[0] if n_rows is None else n_rows, x.shape[1]]
        x_shape_max = [x.shape[0], x.shape[1]]
        x = Tensor(x, shape=x_shape, shape_max=x_shape_max)
    cdiv = lambda a, b: (a + b - 1) // b
    BLOCK_M = 32
    BLOCK_N = 32
    BLOCK_S = 128
    assert len(x.shape) == 2
    assert x.shape_max[-1] < 32768
    assert dim == 1
    assert return_bitmatrix
    n_rows, n_cols = x.shape
    n_rows_max, _ = x.shape_max
    dev = x.device
    # scratchpad tensors
    # NOTE: these are not returned
    y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev)
    if y_indx is not None:
        use_provided_indx = True
    else:
        y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev)
        use_provided_indx = False
    # create bitmatrix in transposed memory layout:
    n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N
    n_cols_words = n_cols_pad // 32
    bitmatrix = torch.empty((n_cols_words, cdiv(n_rows_max, 32) * 32), dtype=torch.uint32, device=dev)
    bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_max]
    s_blocks = cdiv(n_cols, BLOCK_S)
    s_cols = s_blocks * BLOCK_S
    scratchpad = torch.empty((s_cols, ), dtype=torch.int32, device=dev)
    pids = max(cdiv(n_rows_max, BLOCK_M), s_blocks)
    _topk_forward[(pids, )](
        x, x.stride(0),  # inputs
        y_vals, y_indx, y_vals.stride(0), use_provided_indx,  # output [topk]
        bitmatrix, bitmatrix.stride(0), bitmatrix.stride(1),  # output [bitmatrix]
        n_rows, n_cols,  # shapes
        scratchpad, BLOCK_S, s_blocks,  # thing to memset to zero
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,  # tunable parameter
        APPLY_SOFTMAX=apply_softmax, N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k,  # constants
    )
    bitmatrix_shape = [n_rows, n_cols_words * 32]
    bitmatrix_shape_max = [n_rows_max, None]
    bitmatrix = Bitmatrix(bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=scratchpad)
    return y_vals, y_indx, bitmatrix


def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):
    assert dy_vals.shape[-1] == k
    n_expts_pad = triton.next_power_of_2(x.shape[-1])
    dx = torch.empty_like(x)
    _topk_backward[(dy_vals.shape[0], )](
        y_indx, y_indx.stride(0), dy_vals, dy_vals.stride(0), x, x.stride(0),  # inputs
        dx,  # outputs
        dx.stride(0), x.shape[0], n_rows, x.shape[-1], APPLY_SOFTMAX=apply_softmax, N_EXPTS_ACT=k,
        N_EXPTS_PAD=n_expts_pad)
    return dx


class TopK(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows):
        y_vals, y_indx, bitmatrix = topk_forward(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
        ctx.save_for_backward(x, y_indx)
        ctx.apply_softmax = apply_softmax
        ctx.k = k
        ctx.n_rows = n_rows
        return y_vals, y_indx, bitmatrix

    @staticmethod
    def backward(ctx, dy_vals, _0, _1):
        x, y_indx = ctx.saved_tensors
        dx = topk_backward(x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax)
        return dx, None, None, None, None, None, None


def topk(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
    ret = TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
    return ret


# x = torch.randn((32, 32), dtype=torch.float16, device="cuda")
# print(topk(x, 4))

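A usage sketch (an assumption about intended use, mirroring the commented example above; requires a CUDA device and a recent PyTorch for torch.uint32): pick the top-4 experts per token from router logits.

import torch
from triton_kernels.topk import topk

logits = torch.randn(64, 128, dtype=torch.float16, device="cuda")  # 64 tokens, 128 experts
y_vals, y_indx, bitmatrix = topk(logits, 4)  # softmax over each row's top-4 logits
print(y_vals.shape, y_indx.shape)  # torch.Size([64, 4]) torch.Size([64, 4]); y_indx is int16
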
build/torch-universal/triton_kernels/topk_details/__init__.py
ADDED
File without changes
|
build/torch-universal/triton_kernels/topk_details/_topk_backward.py
ADDED
@@ -0,0 +1,51 @@
import triton
import triton.language as tl


@triton.jit
def _topk_backward(
    Yi,
    stride_ym,  # topk indices
    DY,
    stride_dym,  # output gradient values
    X,
    stride_xm,  # input values
    DX,
    stride_dxm,  # input gradient values
    n_rows,
    NRows,
    n_expts_tot,
    APPLY_SOFTMAX: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
    N_EXPTS_PAD: tl.constexpr,
):
    pid_m = tl.program_id(0)
    if NRows is not None:
        n_rows = tl.load(NRows)
    if pid_m >= n_rows:
        return
    Yi += pid_m * stride_ym
    DY += pid_m * stride_dym
    X += pid_m * stride_xm
    DX += pid_m * stride_dxm
    # --
    offs_xn = tl.arange(0, N_EXPTS_PAD)
    offs_yn = tl.arange(0, N_EXPTS_ACT)
    mask_xn = offs_xn < n_expts_tot
    # recompute softmax
    y_indx = tl.load(Yi + offs_yn)
    x = tl.load(X + y_indx)
    x = x.to(tl.float32)
    y = tl.softmax(x)
    # compute input-gradient
    dy = tl.load(DY + offs_yn)
    dy = dy.to(tl.float32)
    s = tl.sum(y * dy, 0)
    # write-back input gradient
    tl.store(DX + offs_xn, 0, mask=mask_xn)
    tl.debug_barrier()
    if APPLY_SOFTMAX:
        dx = y * (dy - s)
    else:
        dx = dy
    tl.store(DX + y_indx, dx)
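The backward kernel above recomputes the softmax over the k selected logits of each row and applies the usual softmax Jacobian-vector product, dx_i = y_i * (dy_i - sum_j y_j * dy_j), writing zeros to all non-selected positions. A minimal single-row PyTorch reference for checking this behaviour (not part of the upload; the function name and tensors are illustrative):

# Hedged reference check: recomputes the same gradient with plain PyTorch for one row.
import torch

def topk_softmax_backward_ref(x_row, y_indx, dy, apply_softmax=True):
    # x_row: (n_experts,) logits; y_indx: (k,) selected indices; dy: (k,) upstream grads
    dx = torch.zeros_like(x_row, dtype=torch.float32)
    xs = x_row[y_indx].float()
    if apply_softmax:
        y = torch.softmax(xs, dim=0)
        s = (y * dy.float()).sum()
        dx[y_indx] = y * (dy.float() - s)  # softmax Jacobian-vector product
    else:
        dx[y_indx] = dy.float()            # pass-through when values were not normalized
    return dx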
build/torch-universal/triton_kernels/topk_details/_topk_forward.py
ADDED
@@ -0,0 +1,146 @@
import triton
import triton.language as tl


@triton.jit
def get_topmask_and_fullmask(x):
    tl.static_assert(x.dtype.is_int_unsigned(), "floating-point value must be passed as bits")
    tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth)
    fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1
    tm_arr = tl.full(x.shape, tm, dtype=x.dtype)
    fm_arr = tl.full(x.shape, fm, dtype=x.dtype)
    return tm_arr, fm_arr


@triton.jit
def fpval_to_key(x):
    tm, fm = get_topmask_and_fullmask(x)
    return x ^ tl.where((x & tm) != 0, fm, tm)


@triton.jit
def key_to_fpval(x):
    tm, fm = get_topmask_and_fullmask(x)
    return x ^ tl.where((x & tm) == 0, fm, tm)


# stable top-k tie-breaks to value with smaller index
@triton.jit
def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr):
    return N_EXPTS_PAD - indx


@triton.jit
def key_to_indx(indx, N_EXPTS_PAD: tl.constexpr):
    return N_EXPTS_PAD - indx


@triton.jit
def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
                   BLOCK_N: tl.constexpr):
    x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth
    x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}")
    if x_nbits < 16:
        # this ensures that we leave at least 16 bits for expert index
        # even if the input dtype is smaller than 16 bits:
        y_nbits: tl.constexpr = 32
    else:
        y_nbits: tl.constexpr = x_nbits * 2
    x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}")
    x_dtype: tl.constexpr = X.dtype.element_ty

    # subtract 1 from loop iterations because we peel the first (masked) iteration:
    loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1
    offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n = offs_x_n[None, :] < n_expts_tot

    # first iteration:
    X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :]
    x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf"))
    x = fpval_to_key(x.to(x_utype, bitcast=True))
    x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
    acc = tl.topk(x, N_EXPTS_ACT, dim=1)

    # subsequent iterations:
    for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations):
        acc = tl.bitonic_merge(acc)  # ensure sorted ascending for the merge
        X_ptrs -= BLOCK_N
        offs_x_n -= BLOCK_N
        x = tl.load(X_ptrs, mask=mask_m, other=float("-inf"))
        x = fpval_to_key(x.to(x_utype, bitcast=True))
        x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
        acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))

    # rotate expert index into upper 16 bits:
    # 0000vvvvvvvviiii --> iiii0000vvvvvvvv
    acc = (acc << (y_nbits - 16)) | (acc >> 16)
    # sort in ascending order of expert (descending order of key)
    acc = tl.sort(acc, dim=1, descending=True)
    # iiii0000vvvvvvvv --> 0000iiii:
    y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32)
    y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD)
    # iiii0000vvvvvvvv --> vvvvvvvv:
    y_values_raw = acc.to(x_utype)
    y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)

    return y_values, y_indices


@triton.jit
def _topk_forward(X, stride_xm,  # inputs
                  Yv, Yi, stride_ym,  # topk values/indices
                  USE_PROVIDED_INDX: tl.constexpr, Bits, stride_rm: tl.constexpr, stride_rn: tl.constexpr,  # bitmatrix
                  n_rows, n_expts_tot,  # shape
                  S, BLOCK_S: tl.constexpr, s_blocks,  # thing to memset
                  APPLY_SOFTMAX: tl.constexpr,  # constant
                  BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr):

    pid = tl.program_id(0)
    if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr():
        n_rows = tl.load(n_rows)

    if pid < s_blocks:
        tl.store(S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32))

    if pid * BLOCK_M >= n_rows:
        # early exit:
        return

    tl.static_assert(BLOCK_N % 32 == 0)
    tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0)
    x_dtype: tl.constexpr = X.dtype.element_ty

    # load logits
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_y_n = tl.arange(0, N_EXPTS_ACT)
    mask_m = offs_m[:, None] < n_rows
    if USE_PROVIDED_INDX:
        Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
        y_indices = tl.load(Yi_ptrs, mask=mask_m)
        Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices
        y_values = tl.load(Xv_ptrs, mask=mask_m)
    else:
        y_values, y_indices = streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m,  #
                                             N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N)

    # normalize selected values
    if APPLY_SOFTMAX:
        y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)

    # write back
    Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
    tl.store(Yv_ptrs, y_values, mask=mask_m)
    if not USE_PROVIDED_INDX:
        Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :]
        tl.store(Yi_ptrs, y_indices, mask=mask_m)

    # pack into bitmatrix
    y_div = y_indices // 32
    y_rem = y_indices % 32
    loop_iterations = N_EXPTS_PAD // BLOCK_N
    for i in range(loop_iterations):
        offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32)
        y2 = tl.where(y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0)
        r = tl.reduce_or(y2, axis=1)
        BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn
        tl.store(BitsPtrs, r, mask=mask_m)
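Besides the top-k values and indices, the forward kernel packs each row's selected column indices into 32-bit words: bit (i % 32) of word (i // 32) is set for every selected index i. (The fpval_to_key / key_to_fpval helpers map float bit patterns to unsigned integers whose ordering matches floating-point ordering, so value and index can be sorted together as a single key.) A minimal PyTorch sketch of the same packing, useful for checking the bitmatrix contents (not part of the upload; the function name is illustrative):

# Hedged reference: packs per-row top-k indices into 32-bit words the same way the kernel does.
import torch

def pack_bitmatrix_ref(y_indices, n_cols_words):
    # y_indices: (n_rows, k) integer tensor of selected column indices (distinct per row)
    n_rows = y_indices.shape[0]
    bits = torch.zeros((n_rows, n_cols_words), dtype=torch.int64)  # int64 to avoid sign issues
    words = (y_indices // 32).long()                     # which 32-bit word holds the bit
    shifts = (y_indices % 32).long()                     # which bit inside that word
    vals = torch.ones_like(words, dtype=torch.int64) << shifts
    # addition equals bitwise OR here because each row's indices are distinct
    bits.scatter_add_(1, words, vals)
    return bits  # compare against the kernel's uint32 words, row by row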
flake.lock
ADDED
@@ -0,0 +1,168 @@
{
  "nodes": {
    "flake-compat": {
      "locked": {
        "lastModified": 1747046372,
        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-compat_2": {
      "locked": {
        "lastModified": 1733328505,
        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "flake-utils_2": {
      "inputs": {
        "systems": "systems_2"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "hf-nix": {
      "inputs": {
        "flake-compat": "flake-compat_2",
        "flake-utils": "flake-utils_2",
        "nixpkgs": "nixpkgs"
      },
      "locked": {
        "lastModified": 1751968576,
        "narHash": "sha256-cmKrlWpNTG/hq1bCaHXfbdm9T+Y6V+5//EHAVc1TLBE=",
        "owner": "huggingface",
        "repo": "hf-nix",
        "rev": "3fcd1e1b46da91b6691261640ffd6b7123d0cb9e",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "hf-nix",
        "type": "github"
      }
    },
    "kernel-builder": {
      "inputs": {
        "flake-compat": "flake-compat",
        "flake-utils": "flake-utils",
        "hf-nix": "hf-nix",
        "nixpkgs": [
          "kernel-builder",
          "hf-nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1754384793,
        "narHash": "sha256-pmsetd7e4HtJxduQlHOuQNwdnZYjUtcEDguZ0VGHhoE=",
        "owner": "huggingface",
        "repo": "kernel-builder",
        "rev": "ea0b13b8a53e53c900181242840640e27f169484",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "kernel-builder",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1747820358,
        "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
        "owner": "danieldk",
        "repo": "nixpkgs",
        "rev": "d3c1681180717528068082103bf323147de6ab0b",
        "type": "github"
      },
      "original": {
        "owner": "danieldk",
        "ref": "cudatoolkit-12.9-kernel-builder",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "kernel-builder": "kernel-builder"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "systems_2": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}