diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..76392899799c6acf8ecfaef11ac605ef838a83b0 --- /dev/null +++ b/README.md @@ -0,0 +1,7 @@ +# triton-kernels + +triton-kernels is a set of kernels that enable fast MoE (mixture-of-experts) on different architectures. The kernels support multiple precisions (e.g. bf16, mxfp4). + + +The original code lives at https://github.com/triton-lang/triton/tree/main/python/triton_kernels +The current version corresponds to commit 7d0efaa7231661299284a603512fce4fa255e62c \ No newline at end of file diff --git a/build.toml b/build.toml new file mode 100644 index 0000000000000000000000000000000000000000..611319c81d55eef6daa12d02a387dd7198fe7acf --- /dev/null +++ b/build.toml @@ -0,0 +1,3 @@ +[general] +name = "triton_kernels" +universal = true diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..2a464e2a9d532642d0e6df522226dce4c79b2640 --- /dev/null +++ b/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for triton-kernels kernels"; + + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genFlakeOutputs { + path = ./.; + rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + }; +} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/__pycache__/__init__.cpython-310.pyc b/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14159be36e86c4e5d90ccbac0a437eb0c843400e Binary files /dev/null and b/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/tests/__pycache__/conftest.cpython-310-pytest-8.3.4.pyc b/tests/__pycache__/conftest.cpython-310-pytest-8.3.4.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdca73603d13bf271fc61135dc57ec979e1ec88b Binary files /dev/null and b/tests/__pycache__/conftest.cpython-310-pytest-8.3.4.pyc differ diff --git a/tests/__pycache__/test_mxfp.cpython-310-pytest-8.3.4.pyc b/tests/__pycache__/test_mxfp.cpython-310-pytest-8.3.4.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5bc6bc649d3a005d0dc9e1aa82bb503a8776761 Binary files /dev/null and b/tests/__pycache__/test_mxfp.cpython-310-pytest-8.3.4.pyc differ diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8982884245200f2d7d5aa837b0846518bd53b3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,20 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption("--device", action="store", default="cuda") + + +@pytest.fixture +def device(request): + return request.config.getoption("--device") + + +@pytest.fixture +def fresh_knobs(monkeypatch): + from triton._internal_testing import _fresh_knobs_impl + fresh_function, reset_function = _fresh_knobs_impl(monkeypatch) + try: + yield fresh_function() + finally: + reset_function() diff --git a/tests/test_compaction.py b/tests/test_compaction.py new file mode 100644 index 0000000000000000000000000000000000000000..4e6c31e4278f1be76e27fda8444f6663f2f84ce5 --- /dev/null +++ b/tests/test_compaction.py @@ -0,0 +1,28 @@ +import pytest +import torch +from triton_kernels.compaction import compaction, compaction_torch + + +@pytest.mark.parametrize("n_tokens, n_cols, k, p", [ + (8192, 64, 4, 0.5), + (8192, 64, 4, 1.0), + (131, 128, 16, 
0.6), + (496, 128, 16, 0.), +]) +def test_compaction(n_tokens, n_cols, k, p, device): + yi = torch.rand((n_tokens, n_cols), device=device).argsort(dim=-1) + yi = yi[:, :k].to(torch.int32) + yv = torch.randn((n_tokens, k), dtype=torch.bfloat16, device=device) + # "drop" indices from yi with probability `p` + mask = torch.zeros((n_tokens, n_cols), dtype=torch.int32, device=device) + keep = (torch.rand(yi.shape, device=device) < p) + if keep.any(): + rows = torch.arange(yi.size(0), device=device).unsqueeze(1).expand_as(yi) + mask[rows[keep], yi[keep]] = 1 + chunks = mask.view(*mask.shape[:-1], -1, 32) + weights = (1 << torch.arange(32, dtype=torch.int32, device=device)) + bitmask = (chunks.int() * weights).sum(dim=-1) + yv_ref, yi_ref = compaction_torch(yv, yi, bitmask) + yv_tri, yi_tri = compaction(yv, yi, bitmask) + assert torch.all(yi_ref == yi_tri) + assert torch.all(yv_ref == yv_tri) diff --git a/tests/test_matmul.py b/tests/test_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..7b10f317b363d844adf004bf49925f5b935785aa --- /dev/null +++ b/tests/test_matmul.py @@ -0,0 +1,569 @@ +# isort: off +# fmt: off +from dataclasses import dataclass, fields, replace +import pytest +import torch +from typing import Union +import triton +# routing utilities +from triton_kernels.routing import routing +# matmul utilities +import triton_kernels.matmul_ogs_details.opt_flags as opt_flags +from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, FusedActivation, FnSpecs, FnName, Epilogue +from triton_kernels.matmul_ogs import matmul_ogs_set_idle_sms, matmul_ogs, matmul_ogs_torch +from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig +from triton_kernels.tensor import convert_layout, wrap_torch_tensor, FP4 +from triton_kernels.tensor_details import layout +# numerics utilities +from triton_kernels.numerics import InFlexData, OutFlexData +from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp, dequantize_mxfp8_fn, downcast_to_mxfp_torch, upcast_from_mxfp_torch, MXFP_BLOCK_SIZE +# testing utilities +from triton_kernels.testing import assert_close, compute_actual_scale +# target-specific utilities +from triton_kernels.target_info import is_hip, is_hip_cdna3, is_cuda, is_hip_cdna4 + +# --------------- +# initialize data +# --------------- + + +def alloc_rand(shape, device, dtype, requires_grad=True): + if dtype.itemsize == 1: + tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) + return tmp.to(dtype).requires_grad_(requires_grad) + return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + +def alloc_rand_like(x): + return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad) + + +def mask_indx(idx, n_expts_act): + idx.src_indx[idx.dst_indx[-n_expts_act:]] = -1 + idx.dst_indx[-n_expts_act:] = -1 + return idx + + +def init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, device="cuda"): + logits = torch.randn((m, n_expts_tot), dtype=torch.float16, device=device, requires_grad=True) + routing_data, gather_idx, scatter_idx = routing(logits, n_expts_act, simulated_ep=n_expt_shards) + routing_data.gate_scal = None + gather_idx = gather_idx if do_gather else None + scatter_idx = scatter_idx if do_scatter else None + return m, routing_data, gather_idx, scatter_idx + + +def init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, act_dtype, weight_dtype, + has_y_gammas, requires_grad=True, 
device="cuda"): + torch.manual_seed(0) + assert mode in {'batched', "plain", 'ragged'} + in_m = m * (n_expts_act if gindx is None else 1) + shape_x = (n_expts_tot, in_m, k) if mode == 'batched' else (in_m, k) + shape_batch = tuple() if mode == "plain" else (n_expts_tot // n_expt_shards, ) + x = alloc_rand(shape_x, device=device, dtype=act_dtype, requires_grad=requires_grad) + w = alloc_rand(shape_batch + (k, n), device=device, dtype=weight_dtype, requires_grad=requires_grad) + bias = alloc_rand(shape_batch + (n, ), device=device, dtype=torch.float32, requires_grad=requires_grad) + gs0 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad) + gs1 = 2**torch.randint(-5, 0, (m * n_expts_act, ), device=device, dtype=torch.float32, requires_grad=requires_grad) + gs0 = gs0.detach().requires_grad_(requires_grad) + gs1 = gs1.detach().requires_grad_(requires_grad) + if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2): + gs0 = None + gs1 = None + if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10: + w = w.transpose(-1, -2).contiguous().transpose(-1, -2) + return x, w, bias, gs0, gs1 + + +# --------------- +# numerics stuff +# --------------- + + +def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, device="cuda"): + weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp + # flexpoint + make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) + + ([val0] + if n_expts_tot % 2 else []), dtype=torch.float32, device=device) + make_scalar = lambda val: torch.tensor([val], dtype=torch.float32, device=device) + in_flex_data = lambda scale, use_flex: InFlexData(dtype=out_dtype, scale=make_scalar(scale) + ) if use_flex else InFlexData() + in_flex_edata = lambda scale0, scale1, use_flex: InFlexData(dtype=weight_dtype, scale=make_tensor(scale0, scale1) + ) if use_flex else InFlexData() + out_flex_data = lambda scale, use_flex: OutFlexData(dtype=out_dtype, expected_scale=make_scalar( + scale), actual_scale=make_scalar(0), checksum_scale=make_scalar(0)) if use_flex else OutFlexData() + flex_ctx = FlexCtx( + lhs_data=in_flex_data(1.25, act_use_flexpoint), + rhs_data=in_flex_edata(1.50, 1.25, weight_use_flexpoint), + out_data=out_flex_data(4.00, act_use_flexpoint), + ) + return PrecisionConfig(flex_ctx=flex_ctx, acc_scale=2.0 if act_use_flexpoint or weight_use_flexpoint else 1.0, + out_dtype=out_dtype) + + +def apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_config): + flex_ctx = precision_config.flex_ctx + + def apply(x, scale): + if scale is None: + x = x.clone() + elif scale.numel() == 1: + x = x.float() * scale + else: + assert x.ndim == 3 + assert scale.numel() == x.shape[0] + x = x.float() * scale[:, None, None] + return x.detach().requires_grad_() + + return ( + apply(x_tri, flex_ctx.lhs_data.scale), + apply(w_tri, flex_ctx.rhs_data.scale), + apply(bias_tri, None), + None if gs0_tri is None else apply(gs0_tri, None), + None if gs1_tri is None else apply(gs1_tri, None), + ) + + +def dtype_str_to_torch(dtype_str: str) -> torch.dtype: + return torch.uint8 if dtype_str == "float4_e2m1" else getattr(torch, dtype_str) + + +# Scope to ensure that the opt_flags_constraints are reset after the test +@pytest.fixture +def opt_flags_scope(request): + yield + opt_flags.reset_opt_flags_constraints() + + +# --------------- +# unit tests +# --------------- + + +@dataclass +class Case: + 
m: int + n: int + k: int + mode: str + act_dtype_str: str + weight_dtype_str: str + n_expts_tot: int = 1 + n_expts_act: int = 1 + n_expt_shards: int = 1 + split_k: int = 1 + hbm_swizzling: bool = False + epilogue_subtile: Union[int, None] = None + + +@pytest.mark.parametrize( + ", ".join(f.name for f in fields(Case)), + [ + tuple(getattr(case, f.name) for f in fields(Case)) for case in [ + # Non-mx types: + Case(16, 256, 256, "ragged", "float16", "float16", 128, 4), + Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=2), + Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, n_expt_shards=4), + Case(16, 256, 256, "ragged", "float16", "float16", 4, 1, n_expt_shards=2), + Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), + Case(16, 256, 256, "ragged", "float16", "float16", 128, 4, split_k=3), + Case(300, 400, 400, "batched", "float8_e5m2", "float8_e5m2", 5, 1), + Case(16, 256, 256, "batched", "float16", "float16", 5, 1), + Case(16, 256, 256, "ragged", "float16", "float16", 3, 1), + Case(256, 256, 256, "ragged", "float16", "float16", 4, 1), + Case(256, 256, 256, "ragged", "float16", "float16", 4, 1, split_k=3), + Case(300, 400, 400, "batched", "float16", "float16", 5, 1), + Case(300, 400, 400, "ragged", "float16", "float16"), + Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"), + Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2), + Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, split_k=2), + Case(1000, 400, 400, "ragged", "float16", "float16", 3, 1), + Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2), + Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9), + # mx types: + Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1), + Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), + Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1), + Case(16, 256, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True), + Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), + Case(1000, 700, 700, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), + Case(1000, 700, 700, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9), + Case(1000, 512, 256, "ragged", "bfloat16", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), + Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4), + Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), + Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4), + Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2), + Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4), + Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), + Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), + Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), + Case(1000, 704, 832, "batched", 
"float8_e5m2", "mxfloat4_e2m1", 3, 1), + Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9), + Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), + Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2), + Case(1000, 704, 800, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), + Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4), + Case(300, 400, 400, "ragged", "float8_e5m2", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), + Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4), + Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True), + Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4), + Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True), + Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), + Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False), + Case(16, 256, 256, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True), + Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True), + Case(1000, 704, 800, "batched", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 2, 1), + Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9), + Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, split_k=9, hbm_swizzling=True), + Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2), + Case(1000, 704, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 2, hbm_swizzling=True), + Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), + Case(300, 400, 400, "ragged", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True), + Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4), + Case(300, 400, 800, "ragged", "mxfloat8_e4m3fn", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True), + Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4), + Case(300, 400, 400, "batched", "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True), + # AMD + Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"), + Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1), + Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2), + Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, n_expt_shards=2), + Case(600, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 4, 2, split_k=2), + Case(300, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn"), + Case(1000, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 3, 1), + Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2), + Case(600, 400, 400, "ragged", "float8_e4m3fn", "float8_e4m3fn", 4, 2, n_expt_shards=2), + ] + ], +) +@pytest.mark.parametrize("block_m", [16, 128]) +@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [ + (False, False, False), + (True, False, False), + (False, True, False), + (True, True, False), + (True, True, True), +]) +@pytest.mark.parametrize("has_y_gammas", [False, True]) +@pytest.mark.parametrize("is_persistent", [False, True]) +def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot, + n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile, + device, 
opt_flags_scope, fresh_knobs): + # TODO: remove when Triton FP8 supports proper RTNE + if is_cuda(): + if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("Float8 not tested on A100") + if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10: + pytest.skip("float16 x mx not supported with cuda capability >= 10") + if weight_dtype_str.startswith("mx"): + if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10: + pytest.skip("float8 x mx not supported with cuda capability < 10") + if act_dtype_str == "mxfloat8_e4m3fn": + if is_persistent: + pytest.skip("mx x mx not supported with persistent kernel") + if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("Not enough memory on A100") + + elif is_hip(): + if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4(): + pytest.skip("float8 x mx only supported on CDNA4") + if "float8" in act_dtype_str and "mxfloat8" in weight_dtype_str: + pytest.skip("NYI: float8 x mxfloat8 not tested on AMD GPU") + if act_dtype_str.startswith("mx") and weight_dtype_str.startswith("mx"): + pytest.skip("NYI: mx x mx not tested on AMD GPU") + if is_persistent: + pytest.skip("NYI: Persistent kernel not supported on AMD GPU") + if split_k > 1: + pytest.skip("splitK hasn't been fully tested on AMD GPU.") + + if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3(): + pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform") + + if fused_scatter and split_k > 1: + pytest.skip("fused scatter scratchpad not supported with split_k") + if hbm_swizzling: + if is_hip(): + pytest.skip("NYI. HBM swizzling just implemented for CUDA.") + if torch.cuda.get_device_capability()[0] < 9: + pytest.skip("NYI. Ampere swizzling.") + if torch.cuda.get_device_capability()[0] < 10: + if "mxfloat4" not in weight_dtype_str: + pytest.skip("NYI. Hopper swizzling just implemented for mxfp4.") + if k % 64 != 0 or n % 64 != 0: + # Automatic padding not implemented for Hopper swizzle + pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).") + + # launch metadata for batched / mx types may not work yet. + test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str) + + torch.manual_seed(0) + + block_k = None + if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10: + # Override block_k for testing correctness. The default is temporarily 128 for + # performance reasons which doesn't work with persistent matmul. 
+ # TODO: revisit when Triton is better for H100 + MXFP4 + block_k = 256 + + constraints = { + "block_m": block_m, + "block_k": block_k, + "split_k": split_k, + "fused_scatter": fused_scatter, + "is_persistent": is_persistent, + "epilogue_subtile": epilogue_subtile, + } + opt_flags.update_opt_flags_constraints(constraints) + + weight_mxfp = weight_dtype_str.startswith("mx") + if weight_mxfp: + weight_dtype_str = weight_dtype_str[2:] + act_mxfp8 = act_dtype_str.startswith("mx") + act_is_float8 = act_dtype_str.startswith("float8") + if act_mxfp8: + act_dtype_str = act_dtype_str[2:] + dequantize_mxfp8_spec = FnSpecs( + FnName.DEQUANTIZE_MXFP8.name, dequantize_mxfp8_fn, (), () + ) + + test_bwd = False + weight_dtype = dtype_str_to_torch(weight_dtype_str) + act_dtype = dtype_str_to_torch(act_dtype_str) + precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp, n_expts_tot // n_expt_shards, device=device) + # precision_opt.x_pad_trans_requires_flexpoint = False + if mode == "ragged": + m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, + device=device) + else: + rdata = gindx = sindx = None + x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, + n_expt_shards, mode, torch.bfloat16 if act_mxfp8 else act_dtype, # + torch.bfloat16 if weight_mxfp else weight_dtype, + has_y_gammas, requires_grad=test_bwd, device=device) + x_ref, w_ref, bias_ref, gs0_ref, gs1_ref = apply_precision(x_tri, w_tri, bias_tri, gs0_tri, gs1_tri, precision_opt) + + if w_tri.shape[0] == 1: + # Test the case when weight has dim 2, i.e., shape (K, N). + w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd) + w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd) + + if weight_mxfp: + mx_axis = w_tri.ndim - 2 + # compute layouts + w_layout, w_layout_opts = layout.StridedLayout, dict() + w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict() + if hbm_swizzling and "float4" in weight_dtype_str: + w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis) + w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout( + mx_axis=mx_axis, num_warps=8) + # downcast to mxfp + w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis) + w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis) + w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype + w_tri = wrap_torch_tensor(w_tri, w_tri_dtype) + w_scale_tri = wrap_torch_tensor(w_scale_tri) + # convert layouts + w_tri = convert_layout(w_tri, w_layout, **w_layout_opts) + w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts) + precision_opt.weight_scale = w_scale_tri + epilogue = None + if act_mxfp8: + x_tri, x_mx_scales_tri = downcast_to_mxfp(x_tri, act_dtype, axis=-1) + x_ref = upcast_from_mxfp(x_tri, x_mx_scales_tri, torch.bfloat16, axis=-1) + is_input_batched = x_tri.ndim == 3 + y_shape = x_tri.shape if is_input_batched else (1,) + x_tri.shape + n_rows = y_shape[1] if gindx is None or mode == "batched" else gindx.dst_indx.shape[0] + y_shape = (y_shape[0], n_rows, w_tri.shape[-1]) + if sindx is None or mode == "batched": + if not is_input_batched: + y_shape = (y_shape[1], y_shape[2]) + else: + y_shape = (n_rows // rdata.n_expts_act, y_shape[-1]) + y_scale_shape = y_shape[:-1] + (triton.cdiv(y_shape[-1], MXFP_BLOCK_SIZE),) + y_scale = torch.empty(y_scale_shape, dtype=torch.uint8, device=x_tri.device) + 
precision_opt = replace(precision_opt, act_scale=x_mx_scales_tri, out_scale=y_scale) + epilogue = Epilogue(dequantize_mxfp8_spec, tuple(), tuple(), effective_itemsize=6.0) + else: + y_scale = None + + if test_launch_metadata: + + def _clobber(t, used_mask): + # Fill the unread part of the tensor with garbage, to be sure that + # we don't actually read from the part. + if len(used_mask) == 1: + return + elif t.element_size() == 1: + t.view(torch.int8)[~used_mask] = 127 + else: + t[~used_mask] = torch.inf + + if rdata is not None: + n_tokens = rdata.expt_hist.sum().item() + used_expts = (rdata.expt_hist > 0) + _clobber(w_tri, used_expts) + n_w_bytes = used_expts.sum().item() * n * k * w_tri.element_size() + else: + n_tokens = m + n_w_bytes = w_tri.numel() * w_tri.element_size() + + if gindx is not None: + used_x_rows = (gindx.dst_indx.view(-1, n_expts_act) != -1).any(dim=1) + _clobber(x_tri, used_x_rows) + n_x_bytes = used_x_rows.sum().item() * k * x_tri.element_size() + elif rdata is not None: + n_x_bytes = n_tokens * k * x_tri.element_size() + else: + n_x_bytes = x_tri.numel() * x_tri.element_size() + + nbytes = None + + def _hook(launch_metadata): + nonlocal nbytes + metadata = launch_metadata.get() + if "matmul_ogs" in metadata["name"]: + nbytes = metadata["bytes"] + + triton.knobs.runtime.launch_enter_hook = _hook + + if mode == "batched": + rdata, gindx, sindx = None, None, None + flex = precision_opt.flex_ctx + + # triton + try: + tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref, epilogue=epilogue) + except (opt_flags.InapplicableConstraint, NotImplementedError): + pytest.skip("inapplicable opt_flags constraint") + # If split_k > 1, then the intermediate tensor is fp32. + sep_gather = mode == "ragged" and do_gather and n_expts_act > 1 and split_k == 1 + sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1 + y_scale = flex.out_data.expected_scale if act_is_float8 else 1 + + if test_launch_metadata: + if gindx is not None: + n_y_bytes = (gindx.src_indx != -1).sum().item() * n * tri_y.element_size() + elif rdata is not None: + n_y_bytes = n_tokens * n * tri_y.element_size() + else: + n_y_bytes = tri_y.numel() * tri_y.element_size() + assert nbytes == n_x_bytes + n_y_bytes + n_w_bytes + triton.knobs.runtime.launch_enter_hook = None + + def round_x(x, idx): + return x.to(act_dtype).to(torch.float32) if sep_gather else x + + round_y = lambda y: (y / y_scale).to(act_dtype).to(torch.float32) * y_scale if sep_scatter else y + ref_y = matmul_ogs_torch(x_ref, w_ref, bias_ref, # + rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref) + scale = lambda val, scal: val if scal is None else val / scal + if n_expt_shards > 1: + if do_scatter: + indx = sindx.dst_indx[sindx.dst_indx != -1] + ref_y = ref_y[indx // n_expts_act, :] + if act_is_float8: + tri_y = tri_y.view(torch.int8) + tri_y = tri_y[indx // n_expts_act, :] + if act_is_float8: + tri_y = tri_y.view(act_dtype) + else: + n_rows = rdata.expt_hist.sum() + assert n_rows > 0 + ref_y = ref_y[:n_rows] + tri_y = tri_y[:n_rows] + if act_mxfp8: + tri_y = upcast_from_mxfp(tri_y, precision_opt.out_scale, dtype=torch.bfloat16, axis=-1).to(ref_y.dtype) + ref_y_quant, ref_y_scale = downcast_to_mxfp_torch(ref_y, act_dtype, axis=-1) + ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1) + maxtol = 4e-1 + rmstol = 4e-2 + else: + maxtol = None + rmstol = None + assert_close(scale(ref_y, flex.out_data.expected_scale), tri_y, 
maxtol=maxtol, rmstol=rmstol) + + if act_is_float8: + tri_y_scale = flex.out_data.actual_scale.clone() + ref_y_scale = compute_actual_scale(ref_y, tri_y.dtype) + assert (ref_y_scale - + tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}" + + +def test_set_idle_sms(): + if not is_cuda(): + pytest.skip("Only supported on CUDA") + from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags + num_idle_sms = 24 + matmul_ogs_set_idle_sms(num_idle_sms) + flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \ + 1024, 1024, 1024, None, True, False, 1) + assert flags.idle_sms == num_idle_sms + + +@pytest.mark.parametrize("m, n, k, mode", [ + (1200, 704, 608, "ragged"), + (800, 800, 400, "batched"), +]) +@pytest.mark.parametrize("split_k", [1, 2]) +@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [ + (False, False, False), + (True, False, False), + (False, True, False), + (True, True, False), + (True, True, True), +]) +@pytest.mark.parametrize("is_persistent, epilogue_subtile", [ + (False, None), + (True, 1), + (True, 4), +]) +@pytest.mark.parametrize("swiglu_alpha, swiglu_limit", [ + (1.1, 1.4), + (1.0, 1.2), + (0.7, 1.0), +]) +def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile, + swiglu_alpha, swiglu_limit, device, opt_flags_scope): + if fused_scatter and split_k > 1: + pytest.skip("fused scatter scratchpad not supported with split_k") + torch.manual_seed(0) + constraints = { + "is_persistent": is_persistent, + "epilogue_subtile": epilogue_subtile, + "fused_scatter": fused_scatter, + "split_k": split_k, + } + n_expts_tot, n_expts_act, n_expt_shards = 1, 1, 1 + opt_flags.update_opt_flags_constraints(constraints) + + weight_dtype, act_dtype = torch.float16, torch.float16 + if mode == "ragged": + m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter, + device=device) + else: + rdata = gindx = sindx = None + + precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot // n_expt_shards, device=device) + x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode, + act_dtype, weight_dtype, False, requires_grad=False, device=device) + + if mode == "batched": + rdata, gindx, sindx = None, None, None + + try: + a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha, + precision_config=SwiGLUPrecisionConfig(swiglu_limit)) + b = matmul_ogs( + x, w, bias, rdata, gindx, sindx, precision_opt, + fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), + (swiglu_alpha, swiglu_limit), 2)) + except opt_flags.InapplicableConstraint: + pytest.skip("inapplicable constraint") + assert_close(a, b) diff --git a/tests/test_mxfp.py b/tests/test_mxfp.py new file mode 100644 index 0000000000000000000000000000000000000000..7389e5426112b6b3a39bd0fd0a3bebbf900848f9 --- /dev/null +++ b/tests/test_mxfp.py @@ -0,0 +1,113 @@ +import pytest +import torch + +from triton_kernels.numerics_details.mxfp import ( + DequantScaleRoundingMode, + downcast_to_mxfp, + downcast_to_mxfp_torch, + get_max_quant_val, + upcast_from_mxfp, + upcast_from_mxfp_torch, +) +from triton_kernels.testing import assert_close, assert_equal + + +def dtype_str_to_torch(dtype_str: str) -> torch.dtype: + return torch.uint8 if dtype_str == "float4_e2m1" else getattr(torch, dtype_str) + + 
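+# Quick reference for the API exercised below (all names come from triton_kernels.numerics_details.mxfp): +# quant, scale = downcast_to_mxfp(x, quant_dtype, axis) # quantized values plus e8m0 block scales +# x_hat = upcast_from_mxfp(quant, scale, x.dtype, axis) # approximate inverse of the downcast +# The *_torch variants are the pure-PyTorch references that the Triton kernels are compared against.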
+@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16"]) +def test_mxfp4_rounding_cases(dst_dtype): + dst_dtype = dtype_str_to_torch(dst_dtype) + x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3]).cuda().bfloat16().view(1, -1, 1) + quant, scale = downcast_to_mxfp(x, torch.uint8, axis=1) + dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1) + assert dequant.flatten().tolist() == [6, 0, 0, 0.5, 1.0, 1.0, 1.0, 1.5], f"{dequant=}" + + quant_torch, scale_torch = downcast_to_mxfp_torch(x, torch.uint8, axis=1) + assert_equal(quant_torch, quant) + assert_equal(scale_torch, scale) + + dequant_torch = upcast_from_mxfp_torch(quant_torch, scale_torch, dst_dtype, axis=1) + assert_equal(dequant_torch, dequant) + + +@pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"]) +@pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16"]) +def test_mxfp_quant_dequant(src_dtype, dst_dtype): + if "float8" in src_dtype and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("Float8 not tested on A100") + limit_range = src_dtype == "float8_e5m2" and dst_dtype == "float16" + + # This test checks that quantization and dequantization kernels produce the exact values for some inputs + # that can be represented exactly in the quantized format. + src_dtype = dtype_str_to_torch(src_dtype) + dst_dtype = dtype_str_to_torch(dst_dtype) + max_val = get_max_quant_val(src_dtype) + if limit_range: + # FP16 can't represent the full range of MXFP8, so we limit the max value here + max_val = 128 + + # These are all the valid mxfp4 positive values. + pos_vals = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, max_val], device="cuda", dtype=dst_dtype) + neg_vals = -pos_vals + k_dim = torch.cat([pos_vals, neg_vals]) + k_dim = k_dim.reshape([k_dim.shape[0], 1]) + + # We pick power of 2 scales since both the scales and their inverse only require exponent bits to be exactly + # represented. This means we can store the scales exactly in the e8m0 format. + powers = torch.arange(-8, 8, device="cuda", dtype=dst_dtype) + scales = 2**powers + scales = scales.reshape([1, powers.shape[0]]) + weight = k_dim * scales + weight = weight.repeat((9, 32)) # Repeat the dimensions to test multi block launches. 
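+ # Give the tensor a transposed memory layout below (same shape, non-contiguous last dim) so the kernels are also exercised on strided inputs.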
+ weight = weight.reshape([1, weight.shape[0], weight.shape[1]]) + weight = weight.mT.contiguous().mT + quant, scale = downcast_to_mxfp(weight, src_dtype, axis=1) + dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1) + assert_equal(weight, dequant) + + +# fmt: off +@pytest.mark.parametrize( + "shape, axis, quant_dtype, rounding_mode", + [ + ((3, 4096, 1024), 1, "float4_e2m1", DequantScaleRoundingMode.ROUND_UP), + ((10, 254, 60), 0, "float4_e2m1", DequantScaleRoundingMode.ROUND_DOWN), + ((1, 320, 160), 2, "float8_e5m2", DequantScaleRoundingMode.ROUND_UP), + ((2, 16, 512), -1, "float8_e4m3fn", DequantScaleRoundingMode.ROUND_DOWN), + ], +) +# fmt: on +@pytest.mark.parametrize("dequant_dtype", ["float16", "bfloat16"]) +def test_mxfp_casting( + shape: tuple[int, ...], + axis: int, + quant_dtype: str, + dequant_dtype: str, + rounding_mode: DequantScaleRoundingMode, +): + if "float8" in quant_dtype and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("Float8 not tested on A100") + quant_torch_type = dtype_str_to_torch(quant_dtype) + dequant_torch_type = dtype_str_to_torch(dequant_dtype) + # Generate random input tensor that is contiguous once axis is the last dimension + x = torch.randn(shape, device="cuda", dtype=dequant_torch_type) + + # Quantize and check equivalence + quant, scale = downcast_to_mxfp(x, quant_torch_type, axis, DEQUANT_SCALE_ROUNDING_MODE=rounding_mode) + quant_torch, scale_torch = downcast_to_mxfp_torch(x, quant_torch_type, axis, + DEQUANT_SCALE_ROUNDING_MODE=rounding_mode) + + assert_equal(quant_torch, quant) + assert_equal(scale_torch, scale) + assert_equal(1, quant.stride(axis)) + assert_equal(1, quant_torch.stride(axis)) + + # Dequantize and check equivalence + dequant = upcast_from_mxfp(quant, scale, dequant_torch_type, axis) + dequant_torch = upcast_from_mxfp_torch(quant_torch, scale_torch, dequant_torch_type, axis) + assert_equal(dequant, dequant_torch) + + # Dequantized result should be close to the original, though tolerance is large due to the precision loss. 
+ assert_close(x, dequant, maxtol=0.5, rmstol=0.15) diff --git a/tests/test_routing.py b/tests/test_routing.py new file mode 100644 index 0000000000000000000000000000000000000000..60bb35d26c4d0a7fcac139d2ab7f2203cd55317f --- /dev/null +++ b/tests/test_routing.py @@ -0,0 +1,97 @@ +import pytest +import torch +from triton_kernels.routing import routing, routing_torch +from triton_kernels.testing import assert_close +from triton_kernels.testing import assert_equal + + +def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"): + logits = torch.randn((n_tokens, n_expts_tot), dtype=dtype, device=device, requires_grad=True) + return logits + + +n_tokens = [(x, None) for x in [371, 255, 256, 4096, 1023, 1024]] +n_tokens += [(1152, 911)] + + +@pytest.mark.parametrize("n_tokens_pad, n_tokens_raw", n_tokens) +@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 32), (1500, 8)]) +@pytest.mark.parametrize("use_expt_indx", [False, True]) +@pytest.mark.parametrize("sm_first", [True, False]) +def test_op(n_tokens_pad, n_tokens_raw, n_expts_tot, n_expts_act, sm_first, use_expt_indx, device): + torch.manual_seed(2) + if n_tokens_raw is None: + n_tokens_raw = n_tokens_pad + n_routing_rows = None + else: + n_routing_rows = torch.tensor([n_tokens_raw], dtype=torch.int32, device=device) + n_gates_raw = n_tokens_raw * n_expts_act + tri_logits = init_data(n_tokens_pad, n_expts_tot, device=device, dtype=torch.float32).detach() + tri_logits[n_tokens_raw:, :] = float("inf") # should not be used + tri_logits = tri_logits.requires_grad_(True) + ref_logits = tri_logits.clone().detach().requires_grad_(True) + + if use_expt_indx: + rand_idx = lambda: torch.randperm(n_expts_tot, device="cuda", dtype=torch.int64) + tri_expt_indx = torch.stack([rand_idx()[:n_expts_act] for _ in range(n_tokens_pad)]) + tri_expt_indx, _ = torch.sort(tri_expt_indx, dim=1) + tri_expt_indx[n_tokens_raw:] = -99999 # should not be used + ref_expt_indx = tri_expt_indx[:n_tokens_raw] + else: + tri_expt_indx = ref_expt_indx = None + ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act, sm_first, ref_expt_indx, + n_rows=n_routing_rows) + tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act, sm_first, tri_expt_indx, + n_rows=n_routing_rows) + + def _assert_indx_equal(ref, tri): + assert_equal(ref, tri[:len(ref)]) + assert torch.all(tri[len(ref):] == -1) + + assert_close(ref_routing_data.gate_scal, tri_routing_data.gate_scal[:n_gates_raw], 2e-2, 4e-3) + assert_equal(ref_routing_data.expt_hist, tri_routing_data.expt_hist) + + ref_expt_data = ref_routing_data.expt_data + tri_expt_data = tri_routing_data.expt_data + assert_equal(ref_expt_data.hist, tri_expt_data.hist) + assert_equal(ref_expt_data.token_offs_raw, tri_expt_data.token_offs_raw) + assert len(ref_expt_data.token_offs_pad) == len(tri_expt_data.token_offs_pad) + assert len(ref_expt_data.block_pid_map) == len(tri_expt_data.block_pid_map) + for block_m in ref_expt_data.token_offs_pad.keys(): + assert_equal(ref_expt_data.token_offs_pad[block_m], tri_expt_data.token_offs_pad[block_m]) + assert_equal(ref_expt_data.block_pid_map[block_m], tri_expt_data.block_pid_map[block_m]) + + assert ref_routing_data.n_expts_tot == ref_routing_data.n_expts_tot + assert ref_routing_data.n_expts_act == ref_routing_data.n_expts_act + + _assert_indx_equal(ref_gather.src_indx, tri_gather.src_indx) + _assert_indx_equal(ref_gather.dst_indx, tri_gather.dst_indx) + _assert_indx_equal(ref_scatter.src_indx, tri_scatter.src_indx) + 
_assert_indx_equal(ref_scatter.dst_indx, tri_scatter.dst_indx) + + scales_grad = torch.randn_like(tri_routing_data.gate_scal) + ref_routing_data.gate_scal.backward(scales_grad[:n_gates_raw]) + tri_routing_data.gate_scal.backward(scales_grad) + + assert_close(ref_logits.grad[:n_tokens_raw], tri_logits.grad[:n_tokens_raw]) + + +def bench_routing(): + import triton.profiler as proton + n_tokens = 8192 + n_expts_tot, n_expts_act = 128, 4 + tri_logits = init_data(n_tokens, n_expts_tot) + proton.start("routing") + proton.activate() + for i in range(100): + tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act) + proton.finalize() + try: + import os + os.system("proton-viewer -m time/ms routing.hatchet") + except Exception: + pass + + +if __name__ == "__main__": + bench_routing() diff --git a/tests/test_specialize.py b/tests/test_specialize.py new file mode 100644 index 0000000000000000000000000000000000000000..4bc10b7db1f880e57af8f753b63271f5c0866c04 --- /dev/null +++ b/tests/test_specialize.py @@ -0,0 +1,84 @@ +import torch +import importlib +from triton_kernels.specialize import cacheable, specialize +import triton +import triton.language as tl + + +@triton.jit +def template_kernel(o): + cst = 1.0 + tl.store(o, cst) + + +def retrieve_fn(module, name): + module = importlib.import_module(module) + fn = getattr(module, name) + return fn + + +_specialized_kernel = None + + +def get_specialized_kernel(): + global _specialized_kernel + if _specialized_kernel is not None: + return _specialized_kernel + import types + spec_constants = {} + spec_tuples = {} + module = types.ModuleType("specialized_kernel") + module.specialized = specialize(template_kernel, module, spec_constants, spec_tuples) + _specialized_kernel = module.specialized + return _specialized_kernel + + +@cacheable +def cacheable_kernel(): + return get_specialized_kernel() + + +def test_cacheable(device, fresh_knobs): + specialized_kernel = get_specialized_kernel() + + specialization_data = None + fn_name = None + module_name = None + + def cache_hook(*args, **kwargs): + nonlocal specialization_data + nonlocal fn_name + nonlocal module_name + specialization_data = kwargs["compile"]["specialization_data"] + fn_name = kwargs["fn"].name + module_name = kwargs["fn"].module + + triton.knobs.runtime.jit_cache_hook = cache_hook + o = torch.empty((1, ), dtype=torch.float32, device=device) + k = specialized_kernel[(1, )](o, ) + hash = k.hash + assert o.item() == 1.0 + assert module_name == "tests.test_specialize" + assert fn_name == "cacheable_kernel" + + compile_count = 0 + + def count_hook(*args, **kwargs): + nonlocal compile_count + compile_count += 1 + + triton.knobs.runtime.jit_cache_hook = count_hook + # clear the cache + specialized_kernel.device_caches.clear() + + # retrieve the kernel from name and preload it. + fn = retrieve_fn(module_name, fn_name) + assert fn == specialized_kernel + preload = fn.preload(specialization_data) + assert compile_count == 1 + assert preload.hash == hash + + # verify that we hit the cache. 
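+ # The launch below should be served by the kernel preloaded above, so the jit cache hook must not fire and compile_count stays 0.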
+ compile_count = 0 + specialized_kernel[(1, )](o, ) + assert compile_count == 0 diff --git a/tests/test_swiglu.py b/tests/test_swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ab353774548954f09c9e8e7aef373755fa362a --- /dev/null +++ b/tests/test_swiglu.py @@ -0,0 +1,42 @@ +from triton_kernels.routing import routing_torch +from triton_kernels.swiglu import swiglu, swiglu_torch, PrecisionConfig +from triton_kernels.testing import assert_close +import torch +import pytest + +from .test_routing import init_data as init_routing_data + +# --------------- +# initialize data +# --------------- + + +def alloc_rand(shape, device, dtype, requires_grad=True): + if dtype.itemsize == 1: + tmp = 2**-(torch.randint(4, 8, shape, device=device, dtype=torch.float16)) + return tmp.to(dtype).requires_grad_(requires_grad) + return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad) + + +# --------------- +# unit tests +# --------------- + + +@pytest.mark.parametrize("M, N", [(1311, 4352)]) +@pytest.mark.parametrize("limit", [1e-2, 10]) +def test_op(M, N, limit, device, alpha=0.5): + torch.manual_seed(2) + # initialize expert data + n_expts_tot = 6 + n_expts_act = 2 + logits = init_routing_data(M, n_expts_tot).detach() + routing_data, _, _ = routing_torch(logits, n_expts_act) + n_tokens = routing_data.expt_hist.sum() + + # initialize data + x = alloc_rand([n_tokens, N], device=device, dtype=torch.bfloat16) + precision_config = PrecisionConfig(limit=limit) + tri_y = swiglu(x, alpha, precision_config, routing_data) + ref_y = swiglu_torch(x, alpha, precision_config) + assert_close(tri_y, ref_y) diff --git a/tests/test_tensor.py b/tests/test_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..5f9d0f9d2eaa6853759cd6e91e6189fb139dd995 --- /dev/null +++ b/tests/test_tensor.py @@ -0,0 +1 @@ +# TODO: add tests for non-layout parts of tensor class diff --git a/tests/test_tensor_details/test_layout_blackwell.py b/tests/test_tensor_details/test_layout_blackwell.py new file mode 100644 index 0000000000000000000000000000000000000000..084b7005c01dd7c55c475419a28bc293b9a41f53 --- /dev/null +++ b/tests/test_tensor_details/test_layout_blackwell.py @@ -0,0 +1,24 @@ +import pytest +import torch +from triton_kernels.tensor_details.layout import BlackwellMXScaleLayout + +# ------------------------------------------------------------ +# Torch tests +# ------------------------------------------------------------ + + +@pytest.mark.parametrize( + "shape", + [ + (3, 4096, 1024), + (10, 254, 60), + (1, 320, 160), + (2, 16, 512), + (3, 2, 36), + ], +) +def test_mxfp4_scale_roundtrip(shape): + x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda") + layout = BlackwellMXScaleLayout(x.shape) + res = layout.unswizzle_data(layout.swizzle_data(x)) + assert (res == x).all() diff --git a/tests/test_tensor_details/test_layout_hopper.py b/tests/test_tensor_details/test_layout_hopper.py new file mode 100644 index 0000000000000000000000000000000000000000..e9edac5fbf11fc69634ebe4c0b0e21e30d812689 --- /dev/null +++ b/tests/test_tensor_details/test_layout_hopper.py @@ -0,0 +1,99 @@ +import pytest +from triton._internal_testing import is_cuda +from triton_kernels.tensor import wrap_torch_tensor, convert_layout, FP4 +from triton_kernels.tensor_details.layout import HopperMXScaleLayout, HopperMXValueLayout +from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp +from triton_kernels.tensor_details.layout_details.hopper_value import 
mxfp4_to_bf16_triton +from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper +from triton_kernels.target_info import cuda_capability_geq +import triton.language as tl +import triton +import torch + +# ------------------------------------------------------------ +# Torch tests +# ------------------------------------------------------------ + + +@pytest.mark.parametrize("shape", [(16, 32), (16, 64), (32, 32), (32, 64), (64, 128), (128, 128)]) +@pytest.mark.parametrize("trans", [False, True]) +@pytest.mark.parametrize("mx_axis", [0, 1]) +@pytest.mark.parametrize("mma_version", [2, 3]) +def test_mxfp4_value_roundtrip(shape, trans, mx_axis, mma_version): + x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda") + if trans: + x = x.mT + if x.shape[1 - mx_axis] < 32: + pytest.skip("Not enough elements along non-mx axis") + layout = HopperMXValueLayout(x.shape, mx_axis, mma_version) + res = layout.unswizzle_data(layout.swizzle_data(x)) + assert (res == x).all() + + +@pytest.mark.parametrize("mx_axis", [0, 1]) +@pytest.mark.parametrize("num_warps", [4, 8]) +@pytest.mark.parametrize("shape", [(256, 64), (256, 128), (256, 256)]) +def test_mxfp4_scale_roundtrip(shape, mx_axis, num_warps): + x = torch.randint(0, 256, shape, dtype=torch.uint8, device="cuda") + layout = HopperMXScaleLayout(x.shape, mx_axis=mx_axis, num_warps=num_warps) + res = layout.unswizzle_data(layout.swizzle_data(x)) + assert (res[:shape[0], :shape[1]] == x).all() + + +# ------------------------------------------------------------ +# Triton tests +# ------------------------------------------------------------ + +# ------------------ upcast mxfp4 to bf16 -------------------- + + +@triton.jit +def _upcast_mxfp4_to_bf16(Y, X, XScale, x_stride_m, x_stride_n, x_scale_stride_m, x_scale_stride_n, y_stride_m, + y_stride_n, X_BLOCK_M: tl.constexpr, X_BLOCK_N: tl.constexpr, Y_BLOCK_M: tl.constexpr, + Y_BLOCK_N: tl.constexpr, SCALE_BLOCK_M: tl.constexpr, SCALE_BLOCK_N: tl.constexpr, + mx_axis: tl.constexpr): + offs_m_val = tl.arange(0, X_BLOCK_M) + offs_n_val = tl.arange(0, X_BLOCK_N) + offs_m_scale = tl.arange(0, SCALE_BLOCK_M) + offs_n_scale = tl.arange(0, SCALE_BLOCK_N) + # load values + offs_x = offs_m_val[:, None] * x_stride_m + offs_n_val[None, :] * x_stride_n + x = tl.load(X + offs_x) + # load scales + offs_x_scale = offs_m_scale[:, None] * x_scale_stride_m + offs_n_scale[None, :] * x_scale_stride_n + x_scale = tl.load(XScale + offs_x_scale) + x_scale = unswizzle_mxfp4_scale_hopper(x_scale, mx_axis=mx_axis, num_warps=tl.extra.cuda.num_warps()) + y = mxfp4_to_bf16_triton(x, x_scale, mx_axis=mx_axis) + # write back output + offs_m_val = tl.arange(0, Y_BLOCK_M) + offs_n_val = tl.arange(0, Y_BLOCK_N) + offs_y = offs_m_val[:, None] * y_stride_m + offs_n_val[None, :] * y_stride_n + tl.store(Y + offs_y, y) + + +@pytest.mark.skipif(not is_cuda(), reason="Only supported on cuda") +@pytest.mark.skipif(not cuda_capability_geq(9), reason="Only supported for capability >= 9") +def test_upcast_mxfp4_to_bf16(): + mx_axis = 0 + num_warps = 4 + torch.manual_seed(0) + torch.cuda.manual_seed(0) + shape = (256, 128) + x = torch.randn(shape, dtype=torch.bfloat16, device="cuda") + x_fp4_val, x_fp4_scale = downcast_to_mxfp(x, torch.uint8, axis=mx_axis) + x_bf16 = upcast_from_mxfp(x_fp4_val, x_fp4_scale, x.dtype, axis=mx_axis) + x_fp4_val = wrap_torch_tensor(x_fp4_val, dtype=FP4) + x_fp4_scale = wrap_torch_tensor(x_fp4_scale) + x_fp4_val = convert_layout(x_fp4_val, HopperMXValueLayout, mx_axis=mx_axis) + 
x_fp4_scale = convert_layout(x_fp4_scale, HopperMXScaleLayout, mx_axis=mx_axis, num_warps=num_warps) + y = torch.empty_like(x_bf16) + _upcast_mxfp4_to_bf16[(1, )]( + y, x_fp4_val.storage.data, x_fp4_scale.storage.data, # + x_fp4_val.storage.data.stride(0), x_fp4_val.storage.data.stride(1), # + x_fp4_scale.storage.data.stride(0), x_fp4_scale.storage.data.stride(1), # + y.stride(0), y.stride(1), # + x_fp4_val.storage.data.shape[0], x_fp4_val.storage.data.shape[1], # + shape[0], shape[1], # + x_fp4_scale.storage.data.shape[0], x_fp4_scale.storage.data.shape[1], # + mx_axis=mx_axis, num_warps=num_warps) + assert (y == x_bf16).all() diff --git a/torch-ext/triton_kernels/__init__.py b/torch-ext/triton_kernels/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torch-ext/triton_kernels/__pycache__/__init__.cpython-310.pyc b/torch-ext/triton_kernels/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..987bd65eccffc34607e5fa9bc3bfb5f823b314d9 Binary files /dev/null and b/torch-ext/triton_kernels/__pycache__/__init__.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/__pycache__/compaction.cpython-310.pyc b/torch-ext/triton_kernels/__pycache__/compaction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..342aab728f7a8bd9993bb581784d2a1661bb09a4 Binary files /dev/null and b/torch-ext/triton_kernels/__pycache__/compaction.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/__pycache__/datastruct.cpython-310.pyc b/torch-ext/triton_kernels/__pycache__/datastruct.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb2c77f595ef2d49f10fc08cb25ffb4686d69dbf Binary files /dev/null and b/torch-ext/triton_kernels/__pycache__/datastruct.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/__pycache__/matmul_ogs.cpython-310.pyc b/torch-ext/triton_kernels/__pycache__/matmul_ogs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb5d06725f4b0d82cdcdd9d4b12e9ea217799c57 Binary files /dev/null and b/torch-ext/triton_kernels/__pycache__/matmul_ogs.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/__pycache__/numerics.cpython-310.pyc b/torch-ext/triton_kernels/__pycache__/numerics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..842d2308a081bfa5c2d06da8de03b9726760fa32 Binary files /dev/null and b/torch-ext/triton_kernels/__pycache__/numerics.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/__pycache__/routing.cpython-310.pyc b/torch-ext/triton_kernels/__pycache__/routing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..149061b6e5ddf14e64d2922b27c4769e47592dc1 Binary files /dev/null and b/torch-ext/triton_kernels/__pycache__/routing.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/__pycache__/specialize.cpython-310.pyc b/torch-ext/triton_kernels/__pycache__/specialize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ded18d04a103d526fe73739cb354726104b9e5c Binary files /dev/null and b/torch-ext/triton_kernels/__pycache__/specialize.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/__pycache__/swiglu.cpython-310.pyc b/torch-ext/triton_kernels/__pycache__/swiglu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd920969ba9d939334072bc1a8f9238d5c796643 Binary files /dev/null 
and b/torch-ext/triton_kernels/__pycache__/swiglu.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/__pycache__/target_info.cpython-310.pyc b/torch-ext/triton_kernels/__pycache__/target_info.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19355322b0b35e4268f0c6e16a3da91e128509de Binary files /dev/null and b/torch-ext/triton_kernels/__pycache__/target_info.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/__pycache__/topk.cpython-310.pyc b/torch-ext/triton_kernels/__pycache__/topk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f58b257e11e5064e3eae655dc70fe20cc8620920 Binary files /dev/null and b/torch-ext/triton_kernels/__pycache__/topk.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/compaction.py b/torch-ext/triton_kernels/compaction.py new file mode 100644 index 0000000000000000000000000000000000000000..417d63fec07103ca4a4dbd7876f94ee1b79ae6d2 --- /dev/null +++ b/torch-ext/triton_kernels/compaction.py @@ -0,0 +1,69 @@ +import torch +from .compaction_details._masked_compaction import _masked_compaction +from .tensor import Bitmatrix + + +def compaction(yv, yi, bitmask, sentinel=-1): + """ + Return compacted copies of *yv* and *yi* based on a per-row bitmask. + + Only the elements whose index appears among the active bits of *bitmask* + are kept; the rest are replaced by *sentinel*. Kept elements preserve + their original left-to-right order. + + Parameters + ---------- + yv : torch.Tensor, shape (B, K) + Values tensor. + yi : torch.Tensor, shape (B, K), dtype torch.long + Integer indices (0 ≤ index < 32) associated with *yv*. + bitmask : torch.Tensor, shape (B,) **or** (B, 32) + Per-row mask of active indices. See the in-place version for details. + sentinel : int, default -1 + Value written into dropped positions of the returned tensors. + + Returns + ------- + (yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K) + New tensors with the same dtype/device as the inputs. 
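+ Notes + ----- + A minimal illustration (mirroring ``tests/test_compaction.py``): ``yv_out, yi_out = compaction(yv, yi, bitmask)`` keeps the + entries of *yi* whose bit is set in *bitmask* (in their original order) and matches the reference + ``compaction_torch(yv, yi, bitmask)``.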
+ + """ + + n_rows, n_cols = yi.shape + ret_yv = torch.empty_like(yv) + ret_yi = torch.empty_like(yi) + if isinstance(bitmask, Bitmatrix): + bitmask = bitmask.storage.data + + _masked_compaction[(n_rows, )]( + yv, yi, bitmask, bitmask.stride(0), bitmask.stride(1), # inputs + ret_yv, ret_yi, # outputs + sentinel, # sentinel + K=n_cols # constants + ) + return ret_yv, ret_yi + + +def compaction_torch(yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1): + """ + reference implementation of `masked_compact` + """ + B, K = yi.shape + device = yi.device + # Expand bitmask to a boolean matrix of active bits (B, 32) + w = (1 << torch.arange(32, device=device, dtype=bitmask.dtype)) + bits = (bitmask.unsqueeze(-1) & w) != 0 + mask = bits.flatten(start_dim=-2) # or bits.reshape(B, -1) + # For every yi element decide whether it should be kept + keep = mask.gather(1, yi.long()) + # Build a stable permutation that brings all "keep" items forward + # False→0, True→1 ==> invert so kept==0, dropped==1, then argsort + order = (~keep).to(torch.int).argsort(dim=1, stable=True) + # Re‑order tensors according to above permutation + yi_sorted = yi.gather(1, order) + yv_sorted = yv.gather(1, order) + # fill relevant positions with sentinel + keep_sorted = keep.gather(1, order) + yi_sorted[~keep_sorted] = sentinel + yv_sorted[~keep_sorted] = sentinel + return yv_sorted, yi_sorted diff --git a/torch-ext/triton_kernels/compaction_details/__pycache__/_masked_compaction.cpython-310.pyc b/torch-ext/triton_kernels/compaction_details/__pycache__/_masked_compaction.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..770f186d7696f770738d06a2573f3bd0001527a7 Binary files /dev/null and b/torch-ext/triton_kernels/compaction_details/__pycache__/_masked_compaction.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/compaction_details/_masked_compaction.py b/torch-ext/triton_kernels/compaction_details/_masked_compaction.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6728ceb70ef5e4c73318aff7bfcf1d575d11e5 --- /dev/null +++ b/torch-ext/triton_kernels/compaction_details/_masked_compaction.py @@ -0,0 +1,20 @@ +import triton +import triton.language as tl + + +@triton.jit +def _masked_compaction(Yv, Yi, BitMask, stride_bm, stride_bn, RetYv, RetYi, sentinel, K: tl.constexpr): + pid_m = tl.program_id(0) + yv = tl.load(Yv + pid_m * K + tl.arange(0, K)) + yi = tl.load(Yi + pid_m * K + tl.arange(0, K)) + div = yi // 32 + rem = yi % 32 + active_bits = (tl.load(BitMask + pid_m * stride_bm + div * stride_bn) >> rem) & 1 + exc_cumsum = tl.cumsum(active_bits, 0) - active_bits + active_flags = active_bits.to(tl.int1) + rev_arange = tl.where(active_flags, 0, K - 1 - tl.arange(0, K)) + write_indx = exc_cumsum + rev_arange + yv = tl.where(active_flags, yv, sentinel) + yi = tl.where(active_flags, yi, sentinel) + tl.store(RetYv + pid_m * K + write_indx, yv) + tl.store(RetYi + pid_m * K + write_indx, yi) diff --git a/torch-ext/triton_kernels/matmul_ogs.py b/torch-ext/triton_kernels/matmul_ogs.py new file mode 100644 index 0000000000000000000000000000000000000000..57ea4f427199d1c2878c827069d52d6f19b91c43 --- /dev/null +++ b/torch-ext/triton_kernels/matmul_ogs.py @@ -0,0 +1,662 @@ +# isort: off +# fmt: off +from dataclasses import dataclass +import itertools +import sys +import torch +import triton +from enum import Enum, auto +# utilities +from triton_kernels import target_info +from triton_kernels.numerics import InFlexData, OutFlexData +from triton_kernels.routing 
import GatherIndx, RoutingData, ScatterIndx +from triton_kernels.target_info import is_cuda +# details +from .matmul_ogs_details._matmul_ogs import _compute_writeback_idx +from .matmul_ogs_details._matmul_ogs import _matmul_ogs +from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn +from .matmul_ogs_details._finalize_matmul import _finalize_matmul +from .matmul_ogs_details.opt_flags import make_opt_flags, update_opt_flags_constraints +from .numerics_details.mxfp import MXFP_BLOCK_SIZE +from .specialize import specialize +from .tensor import Storage, Tensor, FP4, bitwidth, wrap_torch_tensor + + +@dataclass(frozen=True) +class FnSpecs: + name: str + fn: "triton.runtime.jit.JITFunction" + fn_arg_names: tuple[str] + fn_arg_do_not_specialize: tuple[str] = tuple() + + @staticmethod + def default(): + return FnSpecs("dflt", None, tuple()) + + +@dataclass(frozen=True) +class FusedActivation: + specs: FnSpecs = FnSpecs.default() + fn_args: tuple[object] = tuple() + reduction_n: int = 1 + + +@dataclass(frozen=True) +class Epilogue: + specs: FnSpecs = FnSpecs.default() + fn_arg_values_matmul: tuple[object] = tuple() + fn_arg_values_finalize: tuple[object] = tuple() + effective_itemsize: float = None + +class FnName(Enum): + DEQUANTIZE_MXFP8 = auto() + + +EpilogueSpecs = FnSpecs # TODO: remove this alias when callers are updated + +_kernels = dict() + + +def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()): + global _kernels + key = (fused_activation.name, epilogue.name) + if key in _kernels: + return _kernels[key] + spec_constants = { + "ACTIVATION_FN": fused_activation.fn, + "EPILOGUE_FN": epilogue.fn, + } + spec_tuples = { + "activation_fn_args": fused_activation.fn_arg_names, + "epilogue_fn_args": epilogue.fn_arg_names, + } + do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize + import types + + module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}") + sys.modules[module.__name__] = module + module._finalize_matmul = specialize(_finalize_matmul, module, spec_constants, spec_tuples, + do_not_specialize=do_not_specialize) + module._matmul_ogs = specialize(_matmul_ogs, module, spec_constants, spec_tuples, + do_not_specialize=do_not_specialize) + module._p_matmul_ogs = specialize(_p_matmul_ogs, module, spec_constants, spec_tuples, + do_not_specialize=do_not_specialize) + _kernels[key] = module + return module + + +# ----------------------------------------------------------------------------- +# Matrix Multiplication + Outer Gather/Scatter +# ----------------------------------------------------------------------------- + + +def can_overflow_int32(tensor: torch.Tensor): + max_int32 = (1 << 31) - 1 + offset = 0 + for i in range(tensor.ndim): + offset += (tensor.shape[i] - 1) * tensor.stride(i) + return offset > max_int32 + + +def should_upcast_indices(*args): + return any(tensor is not None and can_overflow_int32(tensor) for tensor in args) + + +# --------------------- +# Numerics +# --------------------- + +# fmt: off + +@dataclass(frozen=True) +class FlexCtx: + lhs_data: InFlexData = InFlexData() + rhs_data: InFlexData = InFlexData() + out_data: OutFlexData = OutFlexData() + +@dataclass +class PrecisionConfig: + max_num_imprecise_acc: int = None + allow_tf32: bool = True + flex_ctx: FlexCtx = FlexCtx() + acc_scale: int = 1.0 + flexpoint_saturate_inf: bool = False + report_quantization_err_fn: callable = None + act_scale: Tensor | None = None + weight_scale: Tensor| 
None = None + out_scale: Tensor | None = None + out_dtype: torch.dtype = None + enforce_bitwise_invariance: bool = False + +# --------------------- +# Preprocessing +# --------------------- + +@dataclass(frozen=True) +class PreprocessingFeatures: + swap_xw: bool + + +def init_preprocessing_features(w, precision_config, opt_flags): + swap_xw = False # Whether or not to swap X and W operands to the tl.dot + if target_info.cuda_capability_geq(10, 0): + swap_xw = precision_config.weight_scale is not None and opt_flags.block_m <= 64 and opt_flags.is_persistent + return PreprocessingFeatures(swap_xw) + +def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features): + has_fused_scatter_scratchpad = opt_flags.fused_scatter and routing_data.n_expts_act > 1 + if has_fused_scatter_scratchpad: + M = scatter_indx.src_indx.shape[0] + writeback_idxs = torch.zeros((M,), dtype=torch.int32, device=x.device) + writeback_size = writeback_idxs.shape[0] + finalize_scatter_idxs = torch.zeros((M // routing_data.n_expts_act + M + 1,), dtype=torch.int32, device=x.device) + BLOCK_M=256 + _compute_writeback_idx[(triton.cdiv(M, BLOCK_M),)]( + writeback_idxs, + finalize_scatter_idxs, + scatter_indx.dst_indx, + scatter_indx.src_indx, + M // routing_data.n_expts_act, + M, + BLOCK_M=BLOCK_M, + N_EXPTS_ACT=routing_data.n_expts_act, + ) + elif scatter_indx is not None and routing_data.n_expts_act == 1: + writeback_idxs = scatter_indx.dst_indx + writeback_size = scatter_indx.dst_indx.shape[0] + finalize_scatter_idxs = None + else: + writeback_idxs, writeback_size, finalize_scatter_idxs = None, None, None + # preprocess routing information and ptr lookup table + M = x.shape[1] if gather_indx is None else gather_indx.src_indx.shape[0] + return x, w, writeback_idxs, writeback_size, finalize_scatter_idxs + + +# --------------------- +# Postprocessing +# --------------------- + + +@dataclass(frozen=True) +class PostprocessingFeatures: + finalize: bool + +def init_postprocessing_features(routing_data, scatter_indx, opt_flags): + finalize = (scatter_indx is not None and routing_data.n_expts_act > 1) or opt_flags.split_k > 1 + return PostprocessingFeatures(finalize) + +def apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_offs, num_indx, precision_config, routing_data, + postprocess_features, memory, fused_activation, epilogue): + out = memory["output"] + flex_ctx = precision_config.flex_ctx + if postprocess_features.finalize: + has_fused_scatter_scratchpad = opt_flags.fused_scatter and routing_data.n_expts_act > 1 + if has_fused_scatter_scratchpad: + inp = memory["output"] + else: + inp = memory["scratchpad"]["matmul"] + if scatter_indx is not None: + assert inp.shape[1] == 1, "batched finalize scatter not supported" + n_final_rows = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act + scatter_src_indx = scatter_indx.src_indx + EXPT_PER_TOK = routing_data.n_expts_act + num_rows = None + else: + n_final_rows = inp.shape[1] * inp.shape[2] + scatter_src_indx = None + EXPT_PER_TOK = 1 + num_rows = num_indx or (None if expt_offs is None else expt_offs[-1]) + + if inp.dtype == torch.float32: + inp_flex = OutFlexData() + else: + inp_flex = precision_config.flex_ctx.out_data + + out_scatter = memory["output"] + out_scatter_flex = precision_config.flex_ctx.out_data + + N = inp.shape[3] + M = n_final_rows + warps_per_sm = 32 if target_info.is_hip() else 128 + + def compute_grid(BLOCK_N, num_warps): + num_pid = target_info.num_sms() * (warps_per_sm 
// num_warps) + if M < num_pid or target_info.is_hip(): + grid_n = triton.cdiv(N, BLOCK_N) + grid_m = min(M, max(1, triton.cdiv(num_pid, grid_n))) + else: + grid_m = min(M, num_pid) + grid_n = 1 + return (grid_m, grid_n) + + if inp.dtype.itemsize == 1: + candidates = [(1024, 1)] + else: + if target_info.is_hip(): + candidates = [(4096 // inp.dtype.itemsize, 2)] + else: + if inp.dtype.itemsize == 2: + candidates = [ + (4096 // inp.dtype.itemsize, 4), + (1024 // inp.dtype.itemsize, 1), + ] + else: + candidates = [ + (2048 // inp.dtype.itemsize, 4), + (1024 // inp.dtype.itemsize, 1), + ] + if precision_config.enforce_bitwise_invariance: + candidates = [candidates[0]] + + # sort by smallest grid_n so we share compute across a row + grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0] + STAGES = 1 if num_warps == 1 else min(triton.cdiv(triton.cdiv(N, BLOCK_N), grid[1]), 5) + + out_scale = precision_config.out_scale + out_has_mx = out_scale is not None + out_scale_strides = (None, None) if out_scale is None else out_scale.stride()[-2:] + mx_a_scale = memory["scratchpad"].get("mx_out_scale", None) + if mx_a_scale is not None: + mx_a_scale_stride_k, mx_a_scale_stride_m = [mx_a_scale.stride(i) for i in (0, 2)] + else: + mx_a_scale_stride_k, mx_a_scale_stride_m = None, None + + kernels = get_kernels(epilogue.specs, fused_activation.specs) + kernels._finalize_matmul[grid]( + flex_ctx.out_data.reinterpret(out_scatter), + *((None, out_scale, None) if out_has_mx else out_scatter_flex), + *out_scale_strides, + flex_ctx.out_data.reinterpret(inp), inp.stride(0), inp.stride(2), + inp_flex.expected_scale if mx_a_scale is None else mx_a_scale, + mx_a_scale_stride_k, mx_a_scale_stride_m, + scatter_src_indx, finalize_scatter_idxs, + inp.shape[0], M, N, num_rows, + *fused_activation.fn_args, fused_activation.reduction_n, + *epilogue.fn_arg_values_finalize, + EXPT_PER_TOK=EXPT_PER_TOK, + BLOCK_N=BLOCK_N, + STAGES=STAGES, + num_warps=num_warps, + flexpoint_saturate_inf=precision_config.flexpoint_saturate_inf, + HAS_FUSED_SCRATCHPAD=has_fused_scatter_scratchpad, + ) + out = out_scatter + # trim unnecessary part of output + if has_fused_scatter_scratchpad: + # Discard scratchpad part. + # This still gives a contiguous tensor, because shape[0] > 1 only when + # batch mode is enabled, in which case this is a no-op (there's no scratchpad). 
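+ # Added note (not part of the original source): with a fused-scatter scratchpad,
+ # `init_allocation` sizes the output as n_final_rows finalized rows followed by M
+ # scratchpad rows. Illustrative example: M = 8 gathered rows with n_expts_act = 2
+ # gives n_final_rows = 4, a 12-row buffer, and only the first 4 rows are kept by
+ # the slice below.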
+ out = out[:, :, :n_final_rows, :] + return out + + +# --------------------- +# Allocation +# --------------------- + +@dataclass +class MatmulAllocation: + device: str + output: tuple[tuple[int], torch.dtype] + scratchpads: dict[str, tuple] + +def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags, + preprocessing_features, postprocessing_features): + # ---- output ------ + N = w.shape[-1] + # by default - M is number of rows in the activations + M = x.shape[-2] + # if the activations are gathered, then M is number of gather indices + if gather_indx is not None: + M = gather_indx.src_indx.shape[0] + # final output + if routing_data.n_expts_act == 1 or scatter_indx is None: + y_rows = M + elif opt_flags.fused_scatter: + # we need the scratchpad and the output to be contiguous in memory + Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows + y_rows = M + Mc + else: + Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act # compressed number of rows + y_rows = Mc + batch_dim = x.shape[0] if x.ndim == 3 else 1 + y_shape = (batch_dim, y_rows, N // fused_activation.reduction_n) + out_dtype = precision_config.out_dtype or x.dtype + output = (y_shape, out_dtype) + # ---- scratchpad -----# + scratchpad = dict() + # if we need either standalone scatter or split-k, the matmul output will need post-processing + if postprocessing_features.finalize: + if opt_flags.split_k > 1 or not opt_flags.fused_scatter: + dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype + scratchpad["matmul"] = ((opt_flags.split_k, 1, M, N), dtype) + if precision_config.out_scale is not None and not (scratchpad.get("matmul", None) is not None and scratchpad["matmul"][1].itemsize > 1): + scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE)), torch.uint8) + return MatmulAllocation(x.device, output, scratchpad) + +def apply_allocation(allocation: MatmulAllocation, output): + ret = dict() + if output is None: + output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1]) + else: + assert output.shape == allocation.output[0] + ret["output"] = output[None, :, :] + ret["scratchpad"] = { + k: torch.empty(v[0], device=allocation.device, dtype=v[1]) + for k, v in allocation.scratchpads.items() + } + return ret + +# ----------------------------------------------------------------------------- +# Canonicalize +# ----------------------------------------------------------------------------- +# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used +# we can canonicalize storages to make the implementation more uniform + +def _canonicalize_storage(storage, out_ndim, flex_data): + assert out_ndim >= storage.data.ndim + new_storage_shape = [1] * (out_ndim - storage.data.ndim) + list(storage.data.shape) + new_storage_data = storage.data.view(new_storage_shape) + if flex_data is not None: + new_storage_data = flex_data.reinterpret(new_storage_data) + return Storage(new_storage_data, storage.layout) + + +# ----------------------------------------------------------------------------- +# Triton Implementation +# ----------------------------------------------------------------------------- + +def matmul_ogs_set_idle_sms(num_idle_sms): + """ + persistent kernels will leave `num_idle_sms` idle + """ + update_opt_flags_constraints({"idle_sms": num_idle_sms}) + +def matmul_ogs(x, w, bias, + routing_data: RoutingData | None = None, + 
gather_indx: GatherIndx | None = None, + scatter_indx: ScatterIndx | None = None, + precision_config: PrecisionConfig | None = None, + betas: torch.Tensor | None = None, + gammas: torch.Tensor | None = None, + out_alpha: float | None = None, + y: torch.Tensor | None = None, + fused_activation: FusedActivation | None = None, + epilogue: Epilogue | None = None, + ): + """ + Y[:, :] = 0. + for e in num_experts: + Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :]) + """ + is_input_batched = x.ndim == 3 + if is_input_batched: + assert gather_indx is None, "gather not supported in batched mode" + assert scatter_indx is None, "scatter not supported in batched mode" + assert routing_data is None, "routing not supported in batched mode" + assert w.ndim == 3 and w.shape[0] == x.shape[0] + # canonicalize inputs + if precision_config is None: + precision_config = PrecisionConfig() + if fused_activation is None: + fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1) + if epilogue is None: + epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False) + if routing_data is None: + routing_data = RoutingData(None, None, max(1, w.shape[0]), 1) + # unpack scales + w_scale = precision_config.weight_scale + w_has_mx = w_scale is not None + is_hopper_fp8 = is_cuda() and not target_info.cuda_capability_geq(10, 0) and bitwidth(w.dtype) == 8 + if w_has_mx: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp" + if is_hopper_fp8: assert w.stride(-2) == 1, "`w` must be column-major when it has data-type FP8 on capability < 10" + if not isinstance(w, Tensor): + # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real + dtype = FP4 if w.dtype == torch.uint8 else w.dtype + w = wrap_torch_tensor(w, dtype=dtype) + if w_scale is not None and not isinstance(w_scale, Tensor): + w_scale = Tensor(w_scale) + if w_scale is not None: + w_scale.storage.data = w_scale.data.view(torch.uint8) + w_scale.dtype = torch.uint8 + x_scale = precision_config.act_scale + x_has_mx = x_scale is not None + if x_has_mx: assert x.stride(-1) == 1, "'x' must be row-major when it has data-type mxfp" + if x_scale is not None and not isinstance(x_scale, Tensor): + x_scale = Tensor(x_scale) + if not isinstance(x, Tensor): + x = Tensor(x, dtype=x.dtype) + # determine shapes + M = x.shape[-2] if gather_indx is None else gather_indx.src_indx.shape[0] + batch_size = w.shape[0] if routing_data.expt_hist is None and w.ndim == 3 else 1 + K, N = w.shape[-2:] + assert K == x.shape[-1] + if x.ndim == 3 and w.ndim == 3: + assert x.shape[0] == w.shape[0] + # compute optimization flags + out_dtype = precision_config.out_dtype or x.dtype + can_use_tma = x.storage.is_tma_compliant() and \ + w.storage.is_tma_compliant() and \ + (w_scale is None or w_scale.storage.is_tma_compliant()) + # hopper w/ mxfp4 doesn't support TMA + can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4) + can_use_fused_scatter = scatter_indx is not None and fused_activation.specs.fn is None + opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config, + M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize, + ) + if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp(): + raise NotImplementedError("Must use non-persistent kernel for simulated MXFP") + if w_scale is not None and w_scale.storage.layout.name is not None and not opt_flags.is_persistent and 
target_info.has_native_mxfp(): + raise NotImplementedError("Must use persistent kernel and be TMA-compliant for native MXFP") + # determine necessary pre/post processing + preprocessing_features = init_preprocessing_features(w, precision_config, opt_flags) + postprocessing_features = init_postprocessing_features(routing_data, scatter_indx, opt_flags) + # allocate output/scratchpad memory + allocation = init_allocation(x, w, precision_config, fused_activation, + routing_data, gather_indx, scatter_indx, + opt_flags, preprocessing_features, postprocessing_features + ) + memory = apply_allocation(allocation, y) + # TMA descriptors require a global memory allocation + if opt_flags.is_persistent: + triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device)) + # Intermediate tensors and postprocess kernels for each situation + out0, out0_flex = memory["output"], precision_config.flex_ctx.out_data + fused_postprocess_activation = FusedActivation(FnSpecs.default(), tuple(), 1) + out_scale = None if precision_config.out_scale is None else precision_config.out_scale.data.view(torch.uint8) + if postprocessing_features.finalize: + if opt_flags.fused_scatter: + out0 = memory["output"] + else: + out0 = memory["scratchpad"]["matmul"] + if "mx_out_scale" in memory["scratchpad"]: + assert out_scale is not None + out_scale = memory["scratchpad"]["mx_out_scale"] + out0_flex = OutFlexData() if out0.dtype == torch.float32 else precision_config.flex_ctx.out_data + + fused_activation, fused_postprocess_activation = fused_postprocess_activation, fused_activation + out_has_mx = out_scale is not None and out0.element_size() == 1 + if out_has_mx: + if isinstance(out_scale, Tensor): + out_scale = Tensor(out_scale) + else: + out_scale = None + # pre-processing + x, w, writeback_idxs, writeback_size, finalize_scatter_idxs = apply_preprocessing_features( + x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features + ) + # matrix multiplication + flex = precision_config.flex_ctx + bias_stride = None if bias is None else bias.stride(0) + num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0] + # moe metadata + expt_data = routing_data.expt_data + block_m = opt_flags.block_m + expt_hist = None if expt_data is None else expt_data.hist + expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1] + expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw + expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m] + # spmd grid + grid_m = triton.cdiv(M, opt_flags.block_m) + if expt_block_pid_map is not None: + grid_m = routing_data.n_blocks(M, opt_flags.block_m) + grid_n = triton.cdiv(N, opt_flags.block_n) + max_grid = batch_size * grid_m * grid_n * opt_flags.split_k + grid = min(target_info.num_sms() - opt_flags.idle_sms, max_grid) if opt_flags.is_persistent else max_grid + # canonicalize storage + has_gather = gather_indx is not None + x_storage = _canonicalize_storage(x.storage, 2 if has_gather else 3, flex.lhs_data) + w_storage = _canonicalize_storage(w.storage, 3, flex.rhs_data) + # create tma descriptor for x + x_has_tma = ((not has_gather) or (has_gather and target_info.has_tma_gather())) and opt_flags.is_persistent + x_block_tma = ([1] if has_gather else [1, opt_flags.block_m]) + [opt_flags.block_k] + x_tensor_or_tma = x_storage.make_tma(x_block_tma) if x_has_tma else x_storage.data + # create tma descriptor for w + w_has_tma = opt_flags.is_persistent + w_tensor_or_tma = w_storage.make_tma([1, 
opt_flags.block_k, opt_flags.block_n]) if w_has_tma else w_storage.data + # create tma descriptor for w_scale + w_scale_tensor_or_tma = w_scale + w_scale_has_tma = opt_flags.is_persistent and w_scale is not None + w_scale_tensor_or_tma = w_scale.storage.make_tma([opt_flags.block_n, opt_flags.block_k]) if w_scale_has_tma else w_scale + # canonicalize strides + x_strides = [0]*(3 - x_storage.data.ndim) + list(x_storage.data.stride()) + x_scale_strides = x_scale.stride() if x_has_mx else (None, None, None) + x_scale_strides = (0, ) * (3 - len(x_scale_strides)) + x_scale_strides + w_scale_strides = w_scale.stride() if w_has_mx and not w_scale_has_tma else (None, None, None) + w_scale_strides = (0, ) * (3 - len(w_scale_strides)) + w_scale_strides + out_scale_strides = out_scale.stride() if out_has_mx else (None, None, None, None) + out_scale_strides = (0, ) * (3 - len(out_scale_strides)) + out_scale_strides + # launch kernel + kernels = get_kernels(epilogue.specs, fused_activation.specs) + (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(grid,)]( + flex.out_data.reinterpret(memory["output"]), + flex.out_data.reinterpret(out0), *out0.stride(), + *((None, out_scale, None) if out_has_mx else out0_flex), + *out_scale_strides[-3:], + x_tensor_or_tma, x_storage.data, *x_strides, + flex.lhs_data.scale, + None if x_scale is None else x_scale.data.view(torch.uint8), *x_scale_strides, + w_tensor_or_tma, *w_storage.data.stride(), w_storage.data.stride()[-1] != 1, + flex.rhs_data.scale, + w_scale_tensor_or_tma, *w_scale_strides, + bias, bias_stride, + x.shape[-2], + x.shape[-2] if routing_data.expt_hist is None else None, + N, K, + betas, gammas, + None if gather_indx is None else gather_indx.src_indx, + None if scatter_indx is None else scatter_indx.src_indx, + num_indx, + writeback_idxs, writeback_size, + expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map, + batch_size, grid_m, grid_n, + out_alpha, + *fused_activation.fn_args, fused_activation.reduction_n, + *epilogue.fn_arg_values_matmul, + routing_data.n_expts_tot, routing_data.n_expts_act, + precision_config.max_num_imprecise_acc, + precision_config.allow_tf32, + precision_config.flexpoint_saturate_inf, + flex.rhs_data.is_per_batch, + opt_flags.block_m, + opt_flags.block_n, + opt_flags.block_k, + opt_flags.group_m, + XCD_SWIZZLE=opt_flags.xcd_swizzle, + SWIZZLE_MX_VALUE=w.storage.layout.name, + SWIZZLE_MX_SCALE=None if w_scale is None else w_scale.storage.layout.name, + EPILOGUE_SUBTILE=opt_flags.epilogue_subtile, + SPLIT_K=opt_flags.split_k, + EVEN_K=K % opt_flags.block_k == 0, + W_CACHE_MODIFIER=opt_flags.w_cache_modifier, + TOKENS_PER_EXPT_FOR_ANNOTATION=routing_data.expected_tokens_per_expt, + num_warps=opt_flags.num_warps, + num_stages=opt_flags.num_stages, + arch=opt_flags.arch, + UPCAST_INDICES=should_upcast_indices(x, w, out0), + DISABLE_Y_TMA=out0.stride(-2) * out0.dtype.itemsize % 16 != 0, + SWAP_XW=preprocessing_features.swap_xw, + IS_EPILOGUE_DEQUANT_MXFP8=epilogue.specs.name == FnName.DEQUANTIZE_MXFP8.name, + NUM_SMS = grid if opt_flags.is_persistent else 0, + **opt_flags.target_kernel_kwargs) + # post-processing + out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw, + num_indx, precision_config, routing_data, + postprocessing_features, memory, fused_postprocess_activation, epilogue) + # remove split-k + out = out.squeeze(0) + if not is_input_batched: + out = out.view(out.shape[-2], out.shape[-1]) + return out + + +# 
----------------------------------------------------------------------------- +# Reference Implementation +# ----------------------------------------------------------------------------- + +def matmul_ogs_torch(x, w, bias, + routing_data: RoutingData = None, + gather_indx: GatherIndx = None, + scatter_indx: ScatterIndx = None, + precision_config: PrecisionConfig = None, + betas = None, + gammas = None, + round_x = None, round_y = None, + ): + is_input_batched = x.ndim == 3 + assert x.dtype.itemsize > 1 + assert w.dtype.itemsize > 1 + if is_input_batched: + assert gather_indx is None, "gather not supported in batched mode" + assert scatter_indx is None, "scatter not supported in batched mode" + assert routing_data is None, "routing not supported in batched mode" + assert w.ndim == 3 and w.shape[0] == x.shape[0] + if round_x is None: + round_x = lambda x: x + if round_y is None: + round_y = lambda x: x + if bias.ndim == 1: + bias = bias.view(1, *bias.shape) + if w.ndim == 2: + w = w.view(1, *w.shape) + if x.ndim == 2: + x = x.view(1, *x.shape) + if routing_data is None: + routing_data = RoutingData(None, None, w.shape[0], 1) + n_expts_act = routing_data.n_expts_act + # memory offsets + if routing_data.n_expts_tot > 1 and not is_input_batched: + sizes = routing_data.expt_hist + off = torch.zeros(sizes.shape[0] + 1, dtype=torch.int32) + off[1:] = torch.cumsum(sizes, 0) + offs = list(itertools.pairwise(off)) + else: + offs = [[0, x.shape[1]] for _ in range(w.shape[0])] + # compute + n_rows = x.shape[1] if gather_indx is None else gather_indx.dst_indx.shape[0] + y = torch.zeros((x.shape[0], n_rows, w.shape[-1]), device=x.device, dtype=x.dtype) + for i, (lo, hi) in enumerate(offs): + if gather_indx is None: + idx = torch.arange(lo, hi, device=x.device) + else: + idx = gather_indx.src_indx[lo:hi] // n_expts_act + batch = i if is_input_batched else 0 + out = torch.matmul(round_x(x[batch, idx, :], torch.arange(lo, hi, device="cuda")).float(), + w[i].float()) + if bias is not None: + out += bias[i, :] if betas is None else bias[i, :] * betas[lo:hi, None] + if gammas is not None: + out *= gammas[lo:hi, None] + y[batch, lo:hi, :] = round_y(out) + if not is_input_batched: + y = y.view(y.shape[1], y.shape[2]) + if scatter_indx is None: + return y + # accumulate output from all experts + n_rows = y.shape[0] // n_expts_act + out = torch.zeros((n_rows, y.shape[-1]), dtype=torch.float32, device=x.device) + for i, (lo, hi) in enumerate(offs): + dst_idx = scatter_indx.dst_indx[lo:hi] // n_expts_act + msk = dst_idx != -1 + out[dst_idx[msk], :] += y[lo:hi, :][msk, :].float() + return out diff --git a/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_common.cpython-310.pyc b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c49b72446cb47a980823ce44dc3c4b00313b1ca Binary files /dev/null and b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_common.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_finalize_matmul.cpython-310.pyc b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_finalize_matmul.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7419a001037b91921e1b685aa7446cac1363d44e Binary files /dev/null and b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_finalize_matmul.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_matmul_ogs.cpython-310.pyc 
b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_matmul_ogs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60dc1c0dc6b77646e6b934b8deecf74c7626ae7b Binary files /dev/null and b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_matmul_ogs.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_p_matmul_ogs.cpython-310.pyc b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_p_matmul_ogs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27ab080b51ca1cf858ed92bbf6b866f1b9f11e2a Binary files /dev/null and b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_p_matmul_ogs.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_weight_transpose.cpython-310.pyc b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_weight_transpose.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a144e0cf0d83c26da580b3c28527eeb56e8fd53 Binary files /dev/null and b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/_weight_transpose.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/fast_contiguous.cpython-310.pyc b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/fast_contiguous.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..633fbd2d05e2fc4d426bd00290ebaea259f241a2 Binary files /dev/null and b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/fast_contiguous.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/opt_flags.cpython-310.pyc b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/opt_flags.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14b1e9d6a7074750b7176a8ff0cc0e613c273e98 Binary files /dev/null and b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/opt_flags.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/opt_flags_amd.cpython-310.pyc b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/opt_flags_amd.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c796a676656642cd05d625e9ac75d583aa027160 Binary files /dev/null and b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/opt_flags_amd.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/opt_flags_nvidia.cpython-310.pyc b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/opt_flags_nvidia.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f585c90c65a81674ef1dae9d97b6bb48edf16d60 Binary files /dev/null and b/torch-ext/triton_kernels/matmul_ogs_details/__pycache__/opt_flags_nvidia.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/matmul_ogs_details/_common.py b/torch-ext/triton_kernels/matmul_ogs_details/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..25755d3105ffe2b119f9fc65c96be858a987fd3c --- /dev/null +++ b/torch-ext/triton_kernels/matmul_ogs_details/_common.py @@ -0,0 +1,165 @@ +import torch + +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +# ----------------------------------------------------------------------------- +# Utilities +# ----------------------------------------------------------------------------- + + +@tl.constexpr_function +def get_scaled_dot_format_string(dtype: tl.dtype): + 
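+ # Added comment (not part of the original source): maps a Triton element type to the
+ # format string expected by `tl.dot_scaled`, e.g. tl.bfloat16 -> "bf16" and uint8-packed
+ # fp4 values -> "e2m1"; any dtype missing from the table raises a KeyError from the lookup.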
mapping = { + tl.float16: "fp16", + tl.bfloat16: "bf16", + tl.uint8: "e2m1", + tl.float8e4nv: "e4m3", + tl.float8e5: "e5m2", + } + return mapping[dtype] + + +@triton.jit +def xcd_swizzle(pid, domain_size, XCD_SWIZZLE: tl.constexpr): + """ + Swizzle the program id based on integer XCD_SWIZZLE. + This is useful for reordering how blocks are assigned to hardware units. A scheduler may, for example, + assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10, ... to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2. + This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment + becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to + the same hardware unit. + """ + # Number of pids per group in the new arrangement + pids_per_group = domain_size // XCD_SWIZZLE + extra_pid_groups = domain_size % XCD_SWIZZLE + + # Compute the current group and the local pid within the group + group = pid % XCD_SWIZZLE + local_pid = pid // XCD_SWIZZLE + + # Calculate new pid based on the new grouping + new_pid = group * pids_per_group + min(group, extra_pid_groups) + local_pid + return new_pid + + +@triton.jit +def swizzle2d(pid, grid_m, grid_n, GROUP_M: tl.constexpr): + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + return pid_m, pid_n + + +def make_matmul_repr(base_name, order): + + def matmul_repr(specialization): + signature = specialization.signature + constants = specialization.constants + reorder = lambda L: [L[i] for i in order] + layout = lambda stride: "N" if stride in constants else "T" + + def convert_dtype(dtype): + if "tensordesc" in dtype: + ret = convert_dtype(dtype.split("<")[1].split("[")[0]) + return ret + elif "u8" in dtype: + return "mxfp4" + elif dtype[0] == "*": + return dtype[1:] + else: + return dtype + + dtypes = "x".join([convert_dtype(f"{signature[i]}") for i in reorder(["Y", "X", "W"])]) + layouts = "".join([f"{layout(i)}" for i in reorder(["stride_y_n", "stride_x_k", "stride_w_n"])]) + blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N", "BLOCK_K", "SPLIT_K"]]) + # mode = [] + # if "GatherIndx" not in constants: + # mode += ['g'] + # if "ScatterSrcIndx" not in constants: + # mode += ['s'] + # suffix = "" if not mode else "_o" + (''.join(mode)) + # if base_name.startswith("_p"): + # suffix += "_ptma" + return f"{base_name}_{layouts}_{dtypes}_{blocks}" + + return matmul_repr + + +def matmul_launch_metadata(grid, kernel, args): + from ..proton_opts import launch_metadata_allow_sync + + ret = dict() + M, N, K = args["M"], args["N"], args["K"] + Y, X, W = [t.base if isinstance(t, TensorDescriptor) else t for t in [args["Y"], args["X"], args["W"]]] + tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION") + hist = args["ExptHist"] + if hist is not None: + # If annotation is given, use that to generate name for profiling. + if tokens_per_expt is not None: + n_rows = f"{tokens_per_expt}*" + elif launch_metadata_allow_sync(): + n_rows = int(hist.float().mean()) + else: + n_rows = "unknown" + + if launch_metadata_allow_sync(): + n_tokens = float(hist.sum()) + n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum() + elif tokens_per_expt is not None: + n_tokens = tokens_per_expt * args["N_EXPTS_TOT"] + # This may not be totally correct (e.g., we might not be using all experts) + # but it's better than nothing.
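+ # Added note (not part of the original source): this branch cannot inspect the histogram
+ # without a device sync, so it counts every expert's weights; e.g. if only 60 of 128
+ # experts receive tokens, real weight traffic is roughly 60/128 of this value, which the
+ # synchronized branch above measures exactly via (hist > 0).sum().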
+ n_w_bytes = W.numel() * W.element_size() + else: + n_tokens = None + n_w_bytes = 0 + + # If annotation is given, use that to generate name for profiling. + tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION") + n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows + else: + n_tokens = None + n_w_bytes = W.numel() * W.element_size() + repr = lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}" + nbits = X.dtype.itemsize * 8 + batch_repr = "" + if "batch_size" in args and args["batch_size"] > 1: + batch_repr = repr("B", args["batch_size"]) + ", " + ret["name"] = f"{kernel.name} [{batch_repr}{repr('M', M)}, {repr('N', N)}, {repr('K', K)}] stg{kernel.num_stages}" + ep_subtile = args["EPILOGUE_SUBTILE"] + if ep_subtile is not None and ep_subtile > 1: + ret["name"] += f" ep/{ep_subtile}" + + if hist is not None and n_tokens is None: + return ret # Don't fill metadata because we can't compute them properly. + + fM = M if M is not None else n_tokens + fK = K if K is not None else n_tokens + ret[f"flops{nbits}"] = 2.0 * fM * N * fK + + gindx = args.get("GatherIndx", None) + # sindx = args.get("WriteBackIndx", None) + n_x_bytes = X.numel() * X.element_size() + n_y_bytes = Y.numel() * Y.element_size() + if hist is not None: + assert n_tokens is not None + n_expts_act = args["N_EXPTS_ACT"] + + if (gindx is not None) and launch_metadata_allow_sync(): + # recreate inverse GatherIndx. + dst = torch.full_like(gindx, -1) + idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32) + mask = (gindx != -1) + dst[gindx[mask]] = idx[mask] + n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum() + else: + n_read_rows = n_tokens + n_x_bytes = n_read_rows * X.shape[-1] * X.element_size() + n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size() + ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes) + + return ret diff --git a/torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py b/torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..1afc1c51728cfc3bee46510ee1e708ebd24fa876 --- /dev/null +++ b/torch-ext/triton_kernels/matmul_ogs_details/_finalize_matmul.py @@ -0,0 +1,377 @@ +import triton +import triton.language as tl +from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale, update_scale +from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE +from triton_kernels.target_info import cuda_capability_geq as _cuda_capability_geq +from triton_kernels.target_info import is_hip as _is_hip + + +# fmt: off +@tl.constexpr_function +def is_hip(): + return _is_hip() + + +@tl.constexpr_function +def cuda_capability_geq(x, y): + return _cuda_capability_geq(x, y) + + +@tl.constexpr_function +def log2(n): + return len(bin(n)) - 3 + + +@tl.constexpr_function +def _permute_to_end_order(n: int, axis: int): + """ + Returns the order of the axes of a tensor to permute `axis` to the end. + """ + order = tuple(range(n)) + return order[:axis] + order[(axis + 1):] + (axis, ) + + +@triton.jit +def permute_to_end(x, axis: tl.constexpr): + """ + Permutes `x` so that `axis` is the last axis. + """ + N: tl.constexpr = len(x.shape) + return tl.permute(x, _permute_to_end_order(N, axis).value) + + +@triton.jit +def split_n(x, N: tl.constexpr): + """ + Given `x`, a tensor of shape AxB...x2x2...x2, split it N times. + Return a tuple of the results. 
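+ Illustrative example (added, not in the original docstring): for `x` of shape
+ (B, 2, 2), `split_n(x, 2)` returns a tuple of four tensors of shape (B,),
+ one per element of the trailing 2x2 block.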
+ """ + xs = (x, ) + for i in tl.static_range(N): + next = tl.split(xs[0]) + for j in tl.static_range(2**i - 1): + next = next + tl.split(xs[j + 1]) + xs = next + return xs + + +@triton.jit +def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr = None, NUM_THREADS: tl.constexpr = None): + N: tl.constexpr = tl.extra.cuda.num_threads() if NUM_THREADS is None else NUM_THREADS + BS: tl.constexpr = x.numel if BLOCK_SIZE is None else BLOCK_SIZE + tl.static_assert(BS % N == 0, "BLOCK_SIZE must be divisible by NUM_THREADS") + return tl.max(tl.reshape(tl.abs(x), [N, BS // N], can_reorder=True), axis=1) + + +def _finalize_matmul_launch_metadata(grid, kernel, args): + ret = dict() + Out, A, ScatterSrcIndx, FinalizeScatterIdxs, K, M, N, EXPT_PER_TOK, NumRows = [ + args[name] + for name in ["Out", "A", "ScatterSrcIndx", "FinalizeScatterIdxs", "K", "M", "N", "EXPT_PER_TOK", "NumRows"] + ] + ret["name"] = f"{kernel.name} [M={M}x{EXPT_PER_TOK} {N=} {K=}]" + + if FinalizeScatterIdxs is not None: + M = FinalizeScatterIdxs[-1].item() + + if ScatterSrcIndx is not None: + is_active = (ScatterSrcIndx != -1).view((-1, EXPT_PER_TOK)) + n_active = is_active.sum(dim=1) + need_accum = n_active >= (1 if K > 1 else 2) + is_active &= need_accum[:, None] + active_input_rows = is_active.sum() + active_output_rows = need_accum.sum() + if EXPT_PER_TOK > 1: + # Masked rows are set to zero. + active_output_rows += (n_active == 0).sum() + else: + if NumRows is not None: + if isinstance(NumRows, int): + active_input_rows = NumRows + else: + active_input_rows = NumRows.item() + else: + active_input_rows = M + active_output_rows = M + + ret["bytes"] = (active_input_rows * K * A.shape[-1] * A.element_size() + + active_output_rows * Out.shape[-1] * Out.element_size()) + if FinalizeScatterIdxs is not None: + ret["bytes"] += FinalizeScatterIdxs.numel() * FinalizeScatterIdxs.element_size() + elif ScatterSrcIndx is not None and EXPT_PER_TOK > 1: + ret["bytes"] += ScatterSrcIndx.numel() * ScatterSrcIndx.element_size() + nbits = Out.dtype.itemsize * 8 + ret[f"flops{nbits}"] = active_input_rows * K * A.shape[-1] + return ret + + +@tl.constexpr_function +def _accumulate_f16_into_f32_and_track_absmax_ptx(n_inputs: int, src_type: str, absmax_reg_name: str | None): + """ + Generate PTX code to take fp16 inputs and sum them into an f32 accumulator using mixed-precision + adds. If `absmax_reg_name` is provided, the absolute maximum value seen so far is tracked inside + that register. + + Generates code something like: + + add.f32.f16 $0, $2, $1; + add.f32.f16 $0, $3, $0; + add.f32.f16 $0, $4, $0; + add.f32.f16 $0, $5, $0; + + .reg .f32 b; + abs.f32 b, $0; + max.f32 my_abs_max, my_abs_max, b; + """ + # Add the first f16 value to the input $1, store into the output $0. + ptx = f"\nadd.f32.{src_type} $0, $2, $1;" + # Accumulate the rest of the inputs into the output $0. + for i in range(1, n_inputs): + ptx += f"\nadd.f32.{src_type} $0, ${2 + i}, $0;" + if absmax_reg_name is not None: + # Update `absmax_reg_name` with the absolute maximum value seen so far. + ptx += f""" + .reg .f32 b; + abs.f32 b, $0; + max.f32 {absmax_reg_name}, {absmax_reg_name}, b; + """ + # Return the PTX snippet, brace-enclosed so we don't pollute the global namespace. + return f"{{{ptx}}}" + + +@triton.jit +def _mixed_precision_accumulate_and_track_absmax(acc, x, axis: tl.constexpr, absmax_reg_name: tl.constexpr = None): + """Given an fp8/bf16/fp16 tensor, accumulate into `acc` along `axis`. 
+ Values are first converted to bf16/fp16, packed into 32-bit registers, and then accumulated using + mixed-precision adds. + + If `absmax_reg_name` is provided, the absolute maximum value seen so far is tracked inside that + register. + """ + REDUCTION_SIZE: tl.constexpr = x.shape[axis] + tl.static_assert(2**log2(REDUCTION_SIZE) == REDUCTION_SIZE, + f"Reduction size must be a power of 2, was {REDUCTION_SIZE}") + # move `axis` to the last axis and reshape for iterative splitting. + x = permute_to_end(x, axis) + x = tl.reshape(x, x.shape[:-1] + (2, ) * log2(REDUCTION_SIZE)) + # Split into a tuple of AxB tensors. + xs = split_n(x, log2(REDUCTION_SIZE)) + if (tl.constexpr(x.dtype == tl.float8e4nv) or tl.constexpr(x.dtype == tl.float8e5)): + # Convert fp8 to fp16. + fp16_xs = () + for i in tl.static_range(len(xs)): + fp16_xs += (xs[i].to(tl.float16), ) + xs = fp16_xs + src_type: tl.constexpr = "f16" + elif x.dtype == tl.float16: + src_type: tl.constexpr = "f16" + elif x.dtype == tl.bfloat16: + src_type: tl.constexpr = "bf16" + else: + tl.static_assert(False, f"Unsupported dtype: {x.dtype}") + return tl.inline_asm_elementwise( + _accumulate_f16_into_f32_and_track_absmax_ptx(REDUCTION_SIZE, src_type, absmax_reg_name), + "=r,r" + (",h" * len(xs)), + (acc, ) + xs, + dtype=tl.float32, + is_pure=True, + pack=1, + ) + + +def _finalize_matmul_repr(specialization): + signature = specialization.signature + suffix = "" if "ScatterSrcIndx" in specialization.constants else "_scatter" + return f"_finalize_matmul{suffix}_{signature['A'][1:]}" + + +@triton.jit(repr=_finalize_matmul_repr, launch_metadata=_finalize_matmul_launch_metadata) +def _finalize_matmul( + Out, + OutExpectedScale, + OutActualScale, + OutChecksumScale, + stride_out_mx_m, stride_out_mx_n, + A, + stride_a_k, + stride_a_m, + AScale, + stride_a_mx_k, + stride_a_mx_m, + ScatterSrcIndx, + FinalizeScatterIdxs, + K: tl.constexpr, + M, + N, + NumRows, + # fused activation function + ACTIVATION_FN: tl.constexpr, + activation_fn_args, + ACTIVATION_REDUCTION_N: tl.constexpr, + # epilogue transform + EPILOGUE_FN: tl.constexpr, + epilogue_fn_args, + EXPT_PER_TOK: tl.constexpr, + flexpoint_saturate_inf: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGES: tl.constexpr, + HAS_FUSED_SCRATCHPAD: tl.constexpr, +): + IN_MXFP8: tl.constexpr = stride_a_mx_k is not None + OUT_MXFP8: tl.constexpr = stride_out_mx_m is not None + if HAS_FUSED_SCRATCHPAD: + # Bump A to the scratchpad region. 
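+ # Added note (not part of the original source): with HAS_FUSED_SCRATCHPAD the buffer is
+ # laid out as [M finalized rows | matmul scratchpad rows], so advancing A by
+ # M * stride_a_m makes every row index computed below address the scratchpad region.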
+ A += tl.cast(M, tl.int64) * stride_a_m + + USE_FUSED_MIXED_PREC_ACC: tl.constexpr = (cuda_capability_geq(10, 0) + and tl.constexpr(A.dtype.element_ty != tl.float32)) + USE_FUSED_ABSMAX: tl.constexpr = (USE_FUSED_MIXED_PREC_ACC and OutActualScale is not None) and ACTIVATION_FN is None + + THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads() + local_max = tl.full([THREADS_PER_BLOCK], 0.0, tl.float32) + if USE_FUSED_ABSMAX: + local_max = tl.inline_asm_elementwise( + """ + .reg .f32 my_abs_max; + mov.b32 my_abs_max, 0; + mov.b32 $0, 0; + """, "=r,r", [local_max], dtype=tl.float32, is_pure=False, pack=1) + + out_scale = load_scale(OutExpectedScale) + a_scale = load_scale(AScale) + + if FinalizeScatterIdxs is not None: + MBound = tl.load(FinalizeScatterIdxs + M + M * EXPT_PER_TOK) + if tl.program_id(0) >= MBound: + return + else: + MBound = M + + if NumRows is not None: + NumRows = NumRows # remove constexpr + if NumRows.dtype.is_ptr(): + NumRows = tl.load(NumRows) + + if FinalizeScatterIdxs is not None or (ScatterSrcIndx is not None and EXPT_PER_TOK > 1): + n_active_experts = 0 + else: + n_active_experts: tl.constexpr = EXPT_PER_TOK + + OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N + outN = N // ACTIVATION_REDUCTION_N + + for pid_m in tl.range(tl.program_id(0), MBound, tl.num_programs(0)): + src_offs = pid_m * EXPT_PER_TOK + tl.arange(0, EXPT_PER_TOK) + if FinalizeScatterIdxs is not None: + row = tl.load(FinalizeScatterIdxs + pid_m) + src_idxs = tl.load(FinalizeScatterIdxs + M + src_offs) + n_active_experts = tl.sum((src_idxs != -1).to(tl.int32)) + elif ScatterSrcIndx is not None and EXPT_PER_TOK > 1: + row = pid_m + src_idxs = tl.load(ScatterSrcIndx + src_offs) + n_active_experts = tl.sum((src_idxs != -1).to(tl.int32)) + else: + row = pid_m + src_idxs = src_offs + if NumRows is not None: + src_idxs = tl.where(src_idxs < NumRows, src_idxs, -1) + + if n_active_experts == 0: + for off_n in tl.range(tl.program_id(1) * OUT_BLOCK_N, outN, tl.num_programs(1) * OUT_BLOCK_N): + offs_n = off_n + tl.arange(0, OUT_BLOCK_N) + n_mask = offs_n < outN + tl.store(Out + row * outN + offs_n, tl.zeros([OUT_BLOCK_N], dtype=Out.dtype.element_ty), mask=n_mask) + else: + for off_n in tl.range(tl.program_id(1) * BLOCK_N, N, tl.num_programs(1) * BLOCK_N, num_stages=STAGES): + offs_n = off_n + tl.arange(0, BLOCK_N) + n_mask = offs_n < N + if IN_MXFP8: + MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE + N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE) + offs_n_scale = off_n // BLOCK_N * MX_SCALE_BLOCK_N + tl.arange(0, MX_SCALE_BLOCK_N)[None, :] + n_mask_scale = offs_n_scale < N_MX_BLOCK + + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + if is_hip(): + if EXPT_PER_TOK > 1: + src_idxs_tup = split_n(tl.reshape(src_idxs, (2, ) * log2(EXPT_PER_TOK)), log2(EXPT_PER_TOK)) + else: + # Convert 1D tensor to 1D tuple. 
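+ # Added note (not part of the original source): `tl.join` duplicates the single index,
+ # the reshape to (2,) lets `tl.split` return a 2-tuple, and `[:1]` keeps one element,
+ # so the static loop below can index a one-element tuple exactly like the EXPT_PER_TOK > 1 case.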
+ src_idxs_tup = tl.split(tl.reshape(tl.join(src_idxs, src_idxs), (2, )))[:1] + for i in tl.static_range(0, EXPT_PER_TOK, 1): + src_idx = src_idxs_tup[i] + if src_idx != -1: + As = A + src_idx.to(tl.int64) * stride_a_m + offs_n + for ki in tl.static_range(K): + acc += tl.load(As, mask=n_mask, other=0.0) + As += stride_a_k + else: + As = A + src_idxs.to(tl.int64)[:, None] * stride_a_m + offs_n[None, :] + if IN_MXFP8: + AScales = AScale + src_idxs.to(tl.int64)[:, None] * stride_a_mx_m + offs_n_scale[None, :] + for ki in tl.static_range(K): + a = tl.load(As, mask=(src_idxs != -1)[:, None] & n_mask[None, :], other=0.0) + As += stride_a_k + if IN_MXFP8: + a_mx_scale = tl.load(AScales, mask=(src_idxs != -1)[:, None] & n_mask_scale[None, :]) + AScales += stride_a_mx_k + a_mx_scale = (a_mx_scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + a_mx_scale = a_mx_scale.reshape([EXPT_PER_TOK, MX_SCALE_BLOCK_N, 1]) + a = a.to(tl.float32).reshape([EXPT_PER_TOK, MX_SCALE_BLOCK_N, MXFP_BLOCK_SIZE]) + a = (a_mx_scale * a).reshape([EXPT_PER_TOK, BLOCK_N]) + acc += tl.sum(a, dtype=tl.float32, axis=0) + elif USE_FUSED_MIXED_PREC_ACC: + acc = _mixed_precision_accumulate_and_track_absmax( + acc, a, axis=0, + absmax_reg_name="my_abs_max" if USE_FUSED_ABSMAX and ki == K - 1 else None) + else: + acc += tl.sum(a, dtype=tl.float32, axis=0) + if not IN_MXFP8: + acc = acc * a_scale + if ACTIVATION_FN is not None: + out = ACTIVATION_FN(tl.reshape(acc, (1, BLOCK_N)), *activation_fn_args) + out = tl.reshape(out, (OUT_BLOCK_N, )) + else: + tl.static_assert(ACTIVATION_REDUCTION_N == 1, + "Activation reduction must be 1 if no activation fn is provided") + out = acc + if not USE_FUSED_ABSMAX and OutActualScale is not None: + local_max = tl.maximum(local_max, thread_local_absmax(out)) + if OUT_MXFP8: + OUT_MX_SCALE_BLOCK_N: tl.constexpr = OUT_BLOCK_N // MXFP_BLOCK_SIZE + OUT_N_MX_BLOCK: tl.constexpr = (outN + MXFP_BLOCK_SIZE - 1) // MXFP_BLOCK_SIZE + offs_n_scale = off_n // BLOCK_N * OUT_MX_SCALE_BLOCK_N + tl.arange(0, OUT_MX_SCALE_BLOCK_N)[None, :] + n_mask_scale = offs_n_scale < OUT_N_MX_BLOCK + acc, acc_scale = EPILOGUE_FN(acc[None, :], n_mask[None, :], *epilogue_fn_args, + pid=row * tl.num_programs(1) + tl.program_id(1)) + tl.static_assert(OUT_BLOCK_N % OUT_MX_SCALE_BLOCK_N == 0, "") + tl.store(OutActualScale + row * stride_out_mx_m + offs_n_scale * stride_out_mx_n, acc_scale, mask=n_mask_scale) + tl.store(Out + row * outN + offs_n[None, :], acc, mask=n_mask[None, :]) + else: + out = float_to_flex(out, out_scale if OutExpectedScale is not None else None, None, OutChecksumScale, + None, Out, flexpoint_saturate_inf) + if EPILOGUE_FN is not None: + out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=Out.dtype.element_ty, + pid=row * tl.num_programs(1) + tl.program_id(1)) + offs_n = off_n // ACTIVATION_REDUCTION_N + tl.arange(0, OUT_BLOCK_N) + n_mask = offs_n < outN + tl.store(Out + row * outN + offs_n, out, mask=n_mask) + + persisent_m = tl.num_programs(0) < MBound + if not persisent_m and n_active_experts == 0: + # Skip updating the scale if there were no active experts and this is a non-persistent launch. + # The loop ran only once, and inside it we only stored zeros. 
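+ # Added note (not part of the original source): in a persistent launch the same program
+ # may already have accumulated rows with active experts in earlier loop iterations, so
+ # it must fall through to the absmax/scale update below rather than returning here.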
+ return + + if USE_FUSED_ABSMAX: + local_max = tl.inline_asm_elementwise( + "mov.b32 $0, my_abs_max;", + "=r,r", + [local_max], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + local_max *= a_scale + if not OUT_MXFP8: + update_scale(local_max, OutActualScale, Out) diff --git a/torch-ext/triton_kernels/matmul_ogs_details/_matmul_ogs.py b/torch-ext/triton_kernels/matmul_ogs_details/_matmul_ogs.py new file mode 100644 index 0000000000000000000000000000000000000000..9fde7d2e92345fe01fee1d94af634e83c38e6bdf --- /dev/null +++ b/torch-ext/triton_kernels/matmul_ogs_details/_matmul_ogs.py @@ -0,0 +1,464 @@ +# isort: off +# fmt: off +import triton +import triton.language as tl +from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw +from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper +from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton +from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale +from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE +from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string + + +@triton.jit +def _zero_masked_rows( + pid_m, pid_n, + Y, stride_y_m, stride_y_n, + N, + ScatterSrcIndx, num_idxs, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + offs_m = BLOCK_M * pid_m.to(tl.int64) + tl.arange(0, BLOCK_M) + offs_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N) + src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0) + YPtrs = Y + offs_m[:, None] * stride_y_m + offs_n[None, :] * stride_y_n + mask_n = offs_n < N + mask = (src_idx == -1)[:, None] & mask_n[None, :] + tl.store(YPtrs, tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32), mask=mask) + + +_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2]) +@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"], + repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata) +def _matmul_ogs( + Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n, + YExpectedScale, YActualScale, YChecksumScale, + stride_y_mx_z, stride_y_mx_m, stride_y_mx_n, + X, XPtr, stride_x_z, stride_x_m, stride_x_k, + XScale, + XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k, + W, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr, + WScale, + WMxScale, stride_w_mx_e, stride_w_mx_k, stride_w_mx_n, + B, stride_b_e, # Bias + NRows, M, N, K, # shapes + # expt data + Betas, Gammas, + GatherIndx, + ScatterSrcIndx, num_idxs, + WriteBackIndx, writeback_size, + ExptHist, ExptOffs, ExptOffsSum, ExptData, + # true grid size + batch_size, grid_m, grid_n, + # Out scale + out_alpha, + # fused activation function + ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr, + # epilogue transform + EPILOGUE_FN: tl.constexpr, epilogue_fn_args, + # MoE config + N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr, + # precision config + MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr, + FLEXPOINT_SATURATE_INF: tl.constexpr, + PER_BATCH_SCALE: tl.constexpr, + # optimization config + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr, + # One of ["HOPPER", "BLACKWELL", None] + SWIZZLE_MX_VALUE: tl.constexpr, + # One of ["HOPPER", "BLACKWELL", None] + SWIZZLE_MX_SCALE: tl.constexpr, + EPILOGUE_SUBTILE: tl.constexpr, + EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, + 
W_CACHE_MODIFIER: tl.constexpr, + NUM_SMS: tl.constexpr, + TOKENS_PER_EXPT_FOR_ANNOTATION=None, + UPCAST_INDICES: tl.constexpr = False, + DISABLE_Y_TMA: tl.constexpr = True, + SWAP_XW: tl.constexpr = False, + IS_EPILOGUE_DEQUANT_MXFP8: tl.constexpr = False): + + Y = Out # Y is passed for the purposes of annotation; replace it with Out + is_w_microscaled: tl.constexpr = WMxScale is not None + MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE + if is_w_microscaled: + w_type: tl.constexpr = W.dtype.element_ty + is_mxfp4: tl.constexpr = w_type == tl.uint8 + tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5), + "mx_weight_ptr must be uint8 or fp8") + tl.static_assert(WMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8") + tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") + tl.static_assert(SWIZZLE_MX_VALUE == "HOPPER_VALUE" or SWIZZLE_MX_VALUE is None, "Only Hopper swizzling is supported for values") + else: + tl.static_assert(SWIZZLE_MX_VALUE is None) + tl.static_assert(SWIZZLE_MX_SCALE is None) + is_x_microscaled: tl.constexpr = XMxScale is not None + if is_x_microscaled: + x_type: tl.constexpr = X.dtype.element_ty + tl.static_assert(is_w_microscaled) + tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv") + tl.static_assert(XMxScale.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8") + tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") + is_out_microscaled: tl.constexpr = stride_y_mx_z is not None + + OUT_BLOCK_N: tl.constexpr = BLOCK_N // ACTIVATION_REDUCTION_N + yN = N // ACTIVATION_REDUCTION_N + + pid = tl.program_id(0) + if ExptOffsSum is not None and XCD_SWIZZLE > 1: + # Determine how much padding there is on the expert data. This allows us to + # know the true grid size and avoid processing padding tiles. 
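+ # Added note (not part of the original source): grid_m is computed on the host as an
+ # upper bound on the number of row tiles, while ExptOffsSum holds the number of tiles
+ # actually populated by the routing; the trailing padding_m tiles do no matmul work and
+ # are only used below to zero masked output rows when fused scatter with a single
+ # active expert is enabled.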
+ padding_m = grid_m - tl.load(ExptOffsSum) + else: + padding_m: tl.constexpr = 0 + + HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None + index_type: tl.constexpr = tl.int64 if UPCAST_INDICES else tl.int32 + + total_actual_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K + if padding_m > 0 and pid >= total_actual_tiles: + tl.device_assert(batch_size == 0) + pid_mn = pid - total_actual_tiles + if pid_mn < padding_m * grid_n: + pid_m, pid_n = swizzle2d(pid_mn, padding_m, grid_n, GROUP_M) + + # set masked out rows to 0 + if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1: + _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N) + return + + # swizzle program ids + pid_emnk = pid + if XCD_SWIZZLE != 1: + pid_emnk = xcd_swizzle(pid_emnk, total_actual_tiles, XCD_SWIZZLE) + pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K) + pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K) + pid_k = pid_mnk % SPLIT_K + pid_mn = pid_mnk // SPLIT_K + pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M) + # For split-k, advance to the output k slice + if SPLIT_K > 1: + Y += pid_k.to( index_type) * stride_y_k + if is_out_microscaled: + YActualScale += pid_k.to(index_type) * stride_x_mx_k + # set masked out rows to 0 + if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1: + _zero_masked_rows(pid_m, pid_n, Y, stride_y_m, stride_y_n, yN, ScatterSrcIndx, num_idxs, BLOCK_M, OUT_BLOCK_N) + # unpack expert data + if ExptData is None: + tl.static_assert(M is not None) + expt_id, start_z, start_m, block_id = pid_e, pid_e, 0, pid_m + else: + tl.static_assert(M is None) + expt_data = tl.load(ExptData + pid_m) + if expt_data == -1: + return + expt_id = expt_data & 0x0000FFFF + block_id = expt_data >> 16 + M = tl.load(ExptHist + expt_id) + start_m = tl.load(ExptOffs + expt_id) + start_z = 0 + expt_id, block_id = expt_id.to(index_type), block_id.to(index_type) + start_m, start_z = start_m.to(index_type), start_z.to(index_type) + pid_n, pid_k = pid_n.to(index_type), pid_k.to(index_type) + # A pointers + offs_x_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M) + offs_x_m = tl.max_contiguous(tl.multiple_of(offs_x_m % M, BLOCK_M), BLOCK_M) + X += start_z * stride_x_z + if GatherIndx is None: + X += start_m * stride_x_m + else: + GatherIndx += start_m + # no needs to bounds-check here because `offs_x_m` wraps around M dim + offs_x_m = tl.load(GatherIndx + offs_x_m) // N_EXPTS_ACT + offs_k = BLOCK_K * pid_k + tl.arange(0, BLOCK_K) + XPtrs = X + offs_x_m.to(index_type)[:, None] * stride_x_m + offs_k.to(index_type)[None, :] * stride_x_k + + # TODO: refactor if/else when triton front end improves + if is_w_microscaled: + if SWIZZLE_MX_VALUE == "HOPPER_VALUE": + tl.static_assert(is_mxfp4, "Only mxfp4 is supported for HOPPER swizzling") + tl.static_assert(not is_x_microscaled) + # We have pack 2 fp4 values in a byte but we divide the dimension by 2 + # when swizzling + W_K_DIVISOR: tl.constexpr = 1 + W_K_MULTIPLIER: tl.constexpr = 2 + W_N_DIVISOR: tl.constexpr = 4 + else: + # We have pack 2 fp4 values in a byte + W_K_DIVISOR: tl.constexpr = 2 if is_mxfp4 else 1 + W_K_MULTIPLIER: tl.constexpr = 1 + W_N_DIVISOR: tl.constexpr = 1 + + PACKED_BLOCK_K_W: tl.constexpr = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER + PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR + MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR + + WMxScale += expt_id * stride_w_mx_e + + if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE": + tl.static_assert(BLOCK_N % 128 == 
0) + tl.static_assert(MX_SCALE_BLOCK_K % 4 == 0) + PACKED_MX_BLOCK: tl.constexpr = (MX_SCALE_BLOCK_K // 4) * 32 * 4 * 4 + SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 128 + stride_scale_k: tl.constexpr = 1 + elif SWIZZLE_MX_SCALE == "HOPPER_SCALE": + n_warps: tl.constexpr = tl.extra.cuda.num_warps() + tl.static_assert(BLOCK_N % (2 * n_warps * 2 * 8) == 0) + tl.static_assert(MX_SCALE_BLOCK_K % 2 == 0) + PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K * 32 + SCALE_BLOCK_N: tl.constexpr = BLOCK_N // 32 + stride_scale_k = stride_w_mx_k + else: + PACKED_MX_BLOCK: tl.constexpr = MX_SCALE_BLOCK_K + SCALE_BLOCK_N: tl.constexpr = BLOCK_N + stride_scale_k = stride_w_mx_k + offs_n_scale = (pid_n * SCALE_BLOCK_N + tl.arange(0, SCALE_BLOCK_N)) % N + offs_n_scale = tl.max_contiguous(tl.multiple_of(offs_n_scale, SCALE_BLOCK_N), SCALE_BLOCK_N) + # K dimension must be the last dimension for the scales + offs_k_scale = PACKED_MX_BLOCK * pid_k + tl.arange(0, PACKED_MX_BLOCK) + WMxScalePtrs = WMxScale + offs_k_scale.to(index_type)[None, :] * stride_scale_k + offs_n_scale.to(index_type)[:, None] * stride_w_mx_n + else: + WMxScalePtrs = None + offs_k_scale = None + W_K_DIVISOR: tl.constexpr = 1 + W_K_MULTIPLIER: tl.constexpr = 1 + W_N_DIVISOR: tl.constexpr = 1 + PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K + PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N + + # B pointers + offs_w_n = pid_n * PACKED_BLOCK_N_W + tl.arange(0, PACKED_BLOCK_N_W) + offs_w_n = tl.max_contiguous(tl.multiple_of(offs_w_n % (N // W_N_DIVISOR), PACKED_BLOCK_N_W), PACKED_BLOCK_N_W) + + if is_x_microscaled: + XMxScale += start_z.to(index_type) * stride_x_mx_z + if GatherIndx is None: + XMxScale += start_m * stride_x_mx_m + offs_x_k_scale = MX_SCALE_BLOCK_K * pid_k + tl.arange(0, MX_SCALE_BLOCK_K) + XMxScalePtrs = XMxScale + offs_x_m.to(index_type)[:, None] * stride_x_mx_m + offs_x_k_scale.to(index_type)[None, :] * stride_x_mx_k + else: + XMxScalePtrs = None + + offs_w_k = PACKED_BLOCK_K_W * pid_k + tl.arange(0, PACKED_BLOCK_K_W) + W += expt_id * stride_w_e + WPtrs = W + (offs_w_k.to(index_type)[:, None] * stride_w_k + offs_w_n.to(index_type)[None, :] * stride_w_n) + # compute output + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(K, BLOCK_K * pid_k, -(BLOCK_K * SPLIT_K)): + if EVEN_K: + mask_k = tl.full([BLOCK_K], True, dtype=tl.int1) + mask_k_w = tl.full([PACKED_BLOCK_K_W], True, dtype=tl.int1) + if is_w_microscaled and SWIZZLE_MX_SCALE is None: + mask_k_scale = tl.full([PACKED_MX_BLOCK], True, dtype=tl.int1) + if is_x_microscaled: + mask_x_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1) + else: + mask_k = offs_k < k + mask_k_w = offs_w_k < ((k // W_K_DIVISOR) * W_K_MULTIPLIER) + if is_w_microscaled and SWIZZLE_MX_SCALE is None: + mask_k_scale = offs_k_scale * MX_PACK_DIVISOR < k + if is_x_microscaled: + mask_x_k_scale = offs_x_k_scale * MX_PACK_DIVISOR < k + + x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0) + w = tl.load(WPtrs, mask=mask_k_w[:, None], other=0.0, cache_modifier=W_CACHE_MODIFIER) + if is_w_microscaled: + x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype) + w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype) + + if is_x_microscaled: + x_scales = tl.load(XMxScalePtrs, mask=mask_x_k_scale[None, :]) + elif x_format == "fp16" or x_format == "bf16": + x_scales: tl.constexpr = None + else: + # Scale of 1 in E8M0 format + x_scales = tl.full((BLOCK_M, MX_SCALE_BLOCK_K), 127, dtype=tl.uint8) + + if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE": + w_scales = 
unswizzle_mx_scale_bw(tl.load(WMxScalePtrs)) + elif SWIZZLE_MX_SCALE == "HOPPER_SCALE": + # Handshake with the swizzling code + num_warps: tl.constexpr = tl.extra.cuda.num_warps() + w_scales = unswizzle_mxfp4_scale_hopper(tl.load(WMxScalePtrs), mx_axis=1, num_warps=num_warps) + else: + w_scales = tl.load(WMxScalePtrs, mask=mask_k_scale[None, :]) + + if SWIZZLE_MX_VALUE == "HOPPER_VALUE": + # Handshake with the swizzling code + tl.static_assert(x_format == "bf16") + tl.static_assert(w_format == "e2m1") + w = mxfp4_to_bf16_triton(w.trans(), w_scales, 1) + tl.static_assert(w.dtype == tl.bfloat16) + acc = acc.trans() + x = x.trans() + # w = w.trans() + acc = tl.dot(w, x, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) + acc = acc.trans() + else: + acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True) + if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE": + WMxScalePtrs += (MX_SCALE_BLOCK_K // 4 * SPLIT_K) * stride_w_mx_k + else: + WMxScalePtrs += (PACKED_MX_BLOCK * SPLIT_K) * stride_w_mx_k + if is_x_microscaled: + XMxScalePtrs += (MX_SCALE_BLOCK_K * SPLIT_K) * stride_x_mx_k + else: + acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) + XPtrs += (BLOCK_K * SPLIT_K) * stride_x_k + WPtrs += (PACKED_BLOCK_K_W * SPLIT_K) * stride_w_k + # bias + scale + offs_m = BLOCK_M * block_id + tl.arange(0, BLOCK_M) + offs_y_n = BLOCK_N * pid_n + tl.arange(0, BLOCK_N) + mask_m = offs_m < M + mask_n = offs_y_n < N + if B is not None: + BPtrs = B + expt_id * stride_b_e + offs_y_n + if pid_k == 0: + bias = tl.load(BPtrs, mask=mask_n, other=0) + else: + bias = tl.full([BLOCK_N], 0, dtype=tl.float32) + else: + bias = tl.full([BLOCK_N], 0, dtype=tl.float32) + if Betas is not None: + betas = tl.load(Betas + start_m + offs_m, mask=mask_m, other=0.0) + else: + betas = tl.full([BLOCK_M], 1, dtype=tl.float32) + if Gammas is not None: + gammas = tl.load(Gammas + start_m + offs_m, mask=mask_m, other=0.0) + else: + gammas = tl.full([BLOCK_M], 1, dtype=tl.float32) + # flexpoint + x_scale = load_scale(XScale) + if PER_BATCH_SCALE: + w_scale = load_scale(WScale + expt_id) + else: + w_scale = load_scale(WScale) + acc *= x_scale * w_scale + acc = acc + bias[None, :] * betas[:, None] + if out_alpha is not None: + acc *= out_alpha + if ACTIVATION_FN is not None: + out = ACTIVATION_FN(acc, *activation_fn_args) + tl.static_assert(out.shape[1] == OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})") + offs_y_n = OUT_BLOCK_N * pid_n + tl.arange(0, OUT_BLOCK_N) + mask_n = offs_y_n < yN + else: + tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided") + out = acc + out *= gammas[:, None] + # write-back + Y += start_z.to(index_type) * stride_y_z + if WriteBackIndx is not None: + WriteBackIndx += start_m + dst_idx = tl.load(WriteBackIndx + offs_m, mask=start_m + offs_m < writeback_size, other=-1) + mask_m = mask_m & (dst_idx != -1) + offs_y_m = dst_idx + else: + Y += start_m * stride_y_m + offs_y_m = offs_m + + YPtrs = Y + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n.to(index_type)[None, :] * stride_y_n + mask = mask_m[:, None] & mask_n[None, :] + if is_out_microscaled: + MX_SCALE_BLOCK_N: tl.constexpr = BLOCK_N // MXFP_BLOCK_SIZE + N_MX_BLOCK: tl.constexpr = tl.cdiv(N, MXFP_BLOCK_SIZE) + tl.static_assert(EPILOGUE_FN is not None) + out, out_scale = EPILOGUE_FN(out, mask, *epilogue_fn_args) + tl.static_assert(BLOCK_N % 
MX_SCALE_BLOCK_N == 0, "") + offs_y_n_scale = MX_SCALE_BLOCK_N * pid_n + tl.arange(0, MX_SCALE_BLOCK_N) + mask_n_scale = offs_y_n_scale < N_MX_BLOCK + YActualScale += start_z.to(index_type) * stride_y_mx_z + if WriteBackIndx is None: + YActualScale += start_m * stride_y_mx_m + YActualScalePtrs = YActualScale + offs_y_m.to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n + else: + YActualScalePtrs = YActualScale + (offs_y_m - NRows).to(index_type)[:, None] * stride_y_mx_m + offs_y_n_scale.to(index_type)[None, :] * stride_y_mx_n + tl.store(YActualScalePtrs, out_scale, mask=mask_m[:, None] & mask_n_scale[None, :]) + else: + out = float_to_flex(out, YExpectedScale, YActualScale, YChecksumScale, mask, Y, FLEXPOINT_SATURATE_INF) + if EPILOGUE_FN is not None and not IS_EPILOGUE_DEQUANT_MXFP8: + out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=YPtrs.dtype.element_ty) + tl.store(YPtrs, out, mask=mask) + + +# Imagine N_EXPTS_ACT = 4, n_final_rows = 5, and n_scratchpad_rows = 8. +# Also imagine scatter_indx.src_indx is: +# (number of active experts per final row) +# -1 -1 0 -1 1 +# -1 2 -1 -1 1 +# 1 3 -1 -1 2 +# -1 4 5 6 3 +# -1 -1 -1 -1 0 (this row is unused) +# +# Then, row 0 and 1 can be written directly to the final tensor. +# In this case, WriteBackIndx looks like: +# [0] = 0 : intermediate row 0 is written directly to final row 0 +# [1] = 5+1=6 : scratchpad starts at offset 5 +# [2] = 1 : intermediate row 2 is written directly to final row 1 +# [3] = 5+3=8 +# [4] = 5+4=9 +# [5] = 5+5=10 +# [6] = 5+6=11 +# [7] = -1 : unused (there are only seven intermediate rows) +@triton.jit +def _compute_writeback_idx( + WriteBackIndx, + FinalizeScatterIdxs, + ScatterDstIndx, ScatterSrcIndx, + n_final_rows, n_scratchpad_rows, + BLOCK_M: tl.constexpr, + N_EXPTS_ACT: tl.constexpr, +): + tl.static_assert(N_EXPTS_ACT > 1) + + pid_m = tl.program_id(0) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + mask_m = offs_m < n_scratchpad_rows + dst_idxs = tl.load(ScatterDstIndx + offs_m, mask=mask_m, other=-1) + # Load corresponding rows in ScatterSrcIndx. + mask = dst_idxs != -1 + src_offs = (dst_idxs // N_EXPTS_ACT) * N_EXPTS_ACT + src_offs = src_offs[:, None] + tl.arange(0, N_EXPTS_ACT)[None, :] + src_idxs = tl.load(ScatterSrcIndx + src_offs, mask=mask[:, None], other=-1) + # Compute the number of actually active experts. + is_src_active = (src_idxs != -1).to(tl.int32) + has_one_active = tl.sum(is_src_active, axis=1) == 1 + # Compute the writeback index. 
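+    # Rows with exactly one active expert are written straight to the final tensor, so their
+    # writeback index is the final row dst_idxs // N_EXPTS_ACT; every other row is redirected to
+    # the scratchpad region, which begins right after the n_final_rows final rows (this is where
+    # the 5+i offsets in the example above come from, since n_final_rows = 5 there).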
+ wb_idx = tl.where(has_one_active, dst_idxs // N_EXPTS_ACT, n_final_rows + offs_m) + wb_idx = tl.where(mask, wb_idx, -1) + tl.store(WriteBackIndx + offs_m, wb_idx, mask=mask_m) + + if pid_m >= ((n_final_rows + BLOCK_M - 1) // BLOCK_M): + return + + mask_m = offs_m < n_final_rows + src_offs = offs_m[:, None] * N_EXPTS_ACT + tl.arange(0, N_EXPTS_ACT)[None, :] + src_idxs = tl.load(ScatterSrcIndx + src_offs, mask=mask_m[:, None], other=-1) + is_src_active = (src_idxs != -1).to(tl.int32) + num_src_active = tl.sum(is_src_active, axis=1) + + need_finalize_scatter = mask_m & (num_src_active != 1) + finalize_scatter_count = tl.sum(need_finalize_scatter.to(tl.int32)) + if finalize_scatter_count == 0: + return + pp_off = tl.atomic_add(FinalizeScatterIdxs + n_final_rows + n_scratchpad_rows, finalize_scatter_count) + + # need_finalize_scatter = [1, 0, 0, 1, 1, 0, 1, 0, 1] + # arange = [0, 1, 2, 3, 4, 5, 6, 7, 8] + arange = tl.arange(0, BLOCK_M) + # idxs = [0, _, _, 3, 4, _, 6, _, 8] + last = BLOCK_M - 1 + idxs = tl.where(need_finalize_scatter, arange, last) + # idxs = [0, 3, 4, 6, 8, _, _, _, _] + idxs = tl.sort(idxs) + # r = offs_m + # d = [r[0], r[3], r[4], r[6], r[8], r[-1], r[-1], r[-1], r[-1]] + d = tl.gather(offs_m, idxs, axis=0) + s = tl.gather(src_idxs, idxs.expand_dims(1).broadcast_to(src_idxs.shape), axis=0) + # store destination indices + Ptr = FinalizeScatterIdxs + pp_off + tl.store(Ptr + arange, d, mask=arange < finalize_scatter_count) + # store src indices + Ptr = FinalizeScatterIdxs + n_final_rows + pp_off * N_EXPTS_ACT + tl.store(Ptr + N_EXPTS_ACT * arange[:, None] + tl.arange(0, N_EXPTS_ACT)[None, :], s, mask=(arange < finalize_scatter_count)[:, None]) diff --git a/torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py b/torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py new file mode 100644 index 0000000000000000000000000000000000000000..fed5402c8ad83149a9b4ba339966d7bb50bf0268 --- /dev/null +++ b/torch-ext/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py @@ -0,0 +1,505 @@ +# isort: off +# fmt: off +import torch +import triton +import triton.language as tl +from triton_kernels import target_info +from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw +from triton_kernels.numerics_details.flexpoint import ( + float_to_flex, + load_scale, + nan_propagating_absmax_reduce, + compute_scale, +) +from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE +from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string + + +@tl.constexpr_function +def cuda_capability_geq(major, minor): + return target_info.cuda_capability_geq(major, minor) + +@tl.constexpr_function +def get_dtype(tensor_or_desc: tl.tensor | tl.tensor_descriptor) -> tl.dtype: + if isinstance(tensor_or_desc, tl.tensor): + return tensor_or_desc.dtype.element_ty + elif isinstance(tensor_or_desc, tl.tensor_descriptor): + return tensor_or_desc.dtype + else: + raise ValueError(f"Invalid type: {type(tensor_or_desc)}") + + +@triton.jit +def _tma_load_2d(desc, offs, transpose: tl.constexpr = False): + if len(desc.shape) == 2 and len(offs) == 3: + tl.device_assert(offs[0] == 0, "2D TMA load requires Z offset to be 0") + offs = offs[1:] + if transpose: + offs = offs[:-2] + [offs[-1], offs[-2]] + res = desc.load(offs) + res = tl.reshape(res, desc.block_shape[-2:]) + if transpose: + res = tl.trans(res) + return res + + +# Helper function to recreate a TMA desc with the same fields, but with a new 
pointer and optional new shape. +@triton.jit +def _update_tensor_desc(desc, ptr, shape=None): + return tl.make_tensor_descriptor( + ptr, + shape=shape or desc.shape, + # last dim must be constexpr 1; reflecting the old descriptor drops the constexpr + strides=desc.strides[:-1] + [tl.constexpr(1)], + block_shape=desc.block_shape, + ) + + +@triton.jit +def _load_tile_attrs( + tile_id, num_tiles, grid_m, grid_n, padding_m, + M, ExptData, ExptHist, ExptOffs, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, SPLIT_K: tl.constexpr, + GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr): + # unpack and swizzle program ids + pid_emnk = tile_id + if XCD_SWIZZLE != 1: + pid_emnk = xcd_swizzle(pid_emnk, num_tiles // SPLIT_K, XCD_SWIZZLE) + pid_e = pid_emnk // ((grid_m - padding_m) * grid_n * SPLIT_K) + pid_mnk = pid_emnk % ((grid_m - padding_m) * grid_n * SPLIT_K) + if SPLIT_K > 1: + pid_k = pid_mnk % SPLIT_K + pid_mn = pid_mnk // SPLIT_K + else: + pid_k: tl.constexpr = 0 + pid_mn = pid_mnk + pid_m, pid_n = swizzle2d(pid_mn, (grid_m - padding_m), grid_n, GROUP_M) + + # unpack expert data + if ExptData is None: + tl.static_assert(M is not None) + expt_id, start_z, start_m, block_id, eM = pid_e, pid_e, 0, pid_m, -1 + else: + tl.static_assert(M is None) + expt_data = tl.load(ExptData + pid_m) + expt_id = expt_data & 0x0000FFFF + block_id = expt_data >> 16 + eM = tl.load(ExptHist + expt_id) + start_m = tl.load(ExptOffs + expt_id) + start_z = 0 + + off_m = BLOCK_M * block_id + off_n = BLOCK_N * pid_n + + return expt_id, start_z, start_m, eM, off_m, off_n, pid_k + + +@triton.jit +def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask): + mask = mask & (offs < writeback_size) + offs = tl.load(WriteBackIndx + offs, mask=mask, other=-1) + mask = offs != -1 + return (offs, mask) + + +_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2]) +@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"], + repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata) +def _p_matmul_ogs( + Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n, + YExpectedScale, YActualScale, YChecksumScale, + stride_y_mx_z, stride_y_mx_m, stride_y_mx_n, + X, XPtr, stride_x_z, stride_x_m, stride_x_k, + XScale, + XMxScale, stride_x_mx_z, stride_x_mx_m, stride_x_mx_k, + W, stride_w_e, stride_w_k, stride_w_n, W_TRANSPOSE: tl.constexpr, + WScale, + MxScale, stride_mx_e, stride_mx_k, stride_mx_n, + B, stride_b_e, # Bias + NRows, M, N, K, # shapes + # expt data + Betas, Gammas, + GatherIndx, + ScatterSrcIndx, num_idxs, + WriteBackIndx, writeback_size, + ExptHist, ExptOffs, ExptOffsSum, ExptData, + # true grid size + batch_size, grid_m, grid_n, + # Out scale + out_alpha, + # fused activation function + ACTIVATION_FN: tl.constexpr, activation_fn_args, ACTIVATION_REDUCTION_N: tl.constexpr, + # epilogue transform + EPILOGUE_FN: tl.constexpr, epilogue_fn_args, + # MoE config + N_EXPTS_TOT: tl.constexpr, N_EXPTS_ACT: tl.constexpr, + # precision config + MAX_NUM_IMPRECISE_ACC: tl.constexpr, ALLOW_TF32: tl.constexpr, + FLEXPOINT_SATURATE_INF: tl.constexpr, + PER_BATCH_SCALE: tl.constexpr, + # optimization config + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, XCD_SWIZZLE: tl.constexpr, + # NYI: Must be None + SWIZZLE_MX_VALUE: tl.constexpr, + # One of ["BLACKWELL", None] + SWIZZLE_MX_SCALE: tl.constexpr, + EPILOGUE_SUBTILE: tl.constexpr, + EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, + W_CACHE_MODIFIER: tl.constexpr, + NUM_SMS: tl.constexpr, + 
TOKENS_PER_EXPT_FOR_ANNOTATION=None, + UPCAST_INDICES:tl.constexpr=False, + DISABLE_Y_TMA: tl.constexpr=False, + SWAP_XW: tl.constexpr = False, + IS_EPILOGUE_DEQUANT_MXFP8: tl.constexpr = False): + tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling") + Y = Out # Y is passed for the purposes of annotation; replace it with Out + + is_microscaled_format: tl.constexpr = MxScale is not None + MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE + if is_microscaled_format: + w_type: tl.constexpr = get_dtype(W) + tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5), + "mx_weight_ptr must be uint8") + tl.static_assert(get_dtype(MxScale) == tl.uint8, "mx_scale_ptr must be uint8") + tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR") + tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales") + + # We have pack 2 fp4 values in a byte + W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1 + PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR + MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR + else: + W_PACK_DIVISOR: tl.constexpr = 1 + MX_SCALE_BLOCK_K: tl.constexpr = 1 + PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K + tl.static_assert(SWIZZLE_MX_SCALE is None) + + if ExptOffsSum is not None: + # Determine how much padding there is on the expert data. This allows us to + # know the true grid size and avoid processing padding tiles. + padding_m = grid_m - tl.load(ExptOffsSum) + else: + padding_m: tl.constexpr = 0 + + HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None + index_type: tl.constexpr = tl.int64 + + if EPILOGUE_SUBTILE is None: + SUBTILE_FACTOR: tl.constexpr = 1 + else: + SUBTILE_FACTOR: tl.constexpr = EPILOGUE_SUBTILE + EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // SUBTILE_FACTOR + OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N + yN = N // ACTIVATION_REDUCTION_N + + # set masked out rows to 0 + if HAS_FUSED_SCATTER and N_EXPTS_ACT == 1: + # Iterate with reversed pids so that later pids will get more tiles if the number of + # tiles isn't evenly divisible by the number of SMs. + # The main loop after this iterates in the forward direction such that earlier + # pids get more tiles if the number of tiles isn't evenly divisible. + # This helps balance the work across the SMs. 
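+        # Hypothetical example (single batch, SPLIT_K = 1, no expert padding): with NUM_SMS = 4
+        # and 10 tiles, this reversed loop visits tiles (3,7), (2,6), (1,5,9), (0,4,8) for pids
+        # 0..3, i.e. 2/2/3/3 tiles, while the forward main loop below visits 3/3/2/2, so every SM
+        # ends up touching 5 tiles in total.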
+ for pid_mnk in range(NUM_SMS - tl.program_id(0) - 1, batch_size * grid_m * grid_n * SPLIT_K, NUM_SMS): + pid_k = pid_mnk % SPLIT_K + pid_mn = pid_mnk // SPLIT_K + pid_m, pid_n = swizzle2d(pid_mn, grid_m, grid_n, GROUP_M) + + z = tl.zeros([BLOCK_M, BLOCK_N // ACTIVATION_REDUCTION_N], dtype=tl.float32) + offs_m = z.shape[0] * pid_m + tl.arange(0, z.shape[0]) + offs_n = z.shape[1] * pid_n + tl.arange(0, z.shape[1]) + src_idx = tl.load(ScatterSrcIndx + offs_m, mask=offs_m < num_idxs, other=0) + YPtrs = Y + offs_m.to(index_type)[:, None] * stride_y_m + offs_n[None, :] * stride_y_n + mask_n = offs_n < yN + mask = (src_idx == -1)[:, None] & mask_n[None, :] + tl.store(YPtrs + pid_k * stride_y_k, z, mask=mask) + + USE_FLEXPOINT_SCALE: tl.constexpr = YActualScale is not None or YChecksumScale is not None + + USE_GATHER_TMA: tl.constexpr = GatherIndx is not None and cuda_capability_geq(10, 0) + X_USE_LOAD_TMA: tl.constexpr = GatherIndx is None and isinstance(X, tl.tensor_descriptor) + USE_SCATTER_TMA: tl.constexpr = (cuda_capability_geq(10, 0) and HAS_FUSED_SCATTER) and not DISABLE_Y_TMA + INT_MAX: tl.constexpr = 2147483647 + + if USE_SCATTER_TMA: + y_desc = tl.make_tensor_descriptor( + Y, + # No masking on the M dimension because we manually mask by setting indices to INT_MAX + shape=[INT_MAX - 1, yN], + strides=[stride_y_m, stride_y_n], + block_shape=[1, OUT_BLOCK_N], + ) + + k_tiles = tl.cdiv(K, BLOCK_K * SPLIT_K) + num_tiles = batch_size * (grid_m - padding_m) * grid_n * SPLIT_K + + # If true, do not share loop-carried variables between the prologue and the + # epilogue to enable better pipelining with mmav5 + INDEPENDENT_EPILOGUE: tl.constexpr = cuda_capability_geq(10, 0) + + # start negative; will be incremented at the top of the loop + if INDEPENDENT_EPILOGUE: + tile_id1 = tl.program_id(0) - NUM_SMS + + # Keep track of local max for updating flexpoint scales. + THREADS_PER_BLOCK: tl.constexpr = tl.extra.cuda.num_threads() + local_absmax = tl.full([THREADS_PER_BLOCK], 0.0, tl.uint32) + + DISALLOW_ACC_MULTI_BUFFER: tl.constexpr = is_microscaled_format and BLOCK_M * BLOCK_N >= 128 * 256 + # Enable warp specialization when all loads are TMA loads. + WARP_SPECIALIZE: tl.constexpr = (USE_GATHER_TMA or X_USE_LOAD_TMA) + + for tile_id in tl.range(tl.program_id(0), num_tiles, NUM_SMS, flatten=True, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER, warp_specialize=WARP_SPECIALIZE): + expt_id, start_z, start_m, eM, off_m, off_n, pid_k = _load_tile_attrs( + tile_id, num_tiles, grid_m, grid_n, padding_m, + M, ExptData, ExptHist, ExptOffs, + BLOCK_M, BLOCK_N, SPLIT_K, + GROUP_M, XCD_SWIZZLE) + + # Base pointers and offsets. + if not USE_GATHER_TMA and not X_USE_LOAD_TMA: + XBase = X + start_z.to(index_type) * stride_x_z + offs_x_k = tl.arange(0, BLOCK_K)[None, :] * stride_x_k + if SPLIT_K > 1: + offs_x_k += pid_k.to(index_type) * BLOCK_K * stride_x_k + + if not X_USE_LOAD_TMA: + offs_m = off_m + tl.arange(0, BLOCK_M) + mask_m = offs_m < (M if M is not None else eM) + if USE_GATHER_TMA: + # Mask the gather indices and load -1 instead. TMA will handle OOB accesses. + if ExptData is None: + offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, mask=mask_m) + # Bump rows to account for the Z offset. 
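+                    # The gather TMA addresses X as a 2-D (rows, K) tensor, so the batch index is
+                    # folded into the row index: stride_x_z // stride_x_m is the number of rows per
+                    # batch slice (assuming stride_x_z is an integer multiple of stride_x_m), and
+                    # start_z of those slices are skipped.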
+ offs_x_m += start_z * (stride_x_z // stride_x_m) + offs_x_m = tl.where(mask_m, offs_x_m, -1) + else: + offs_x_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m, + mask=mask_m, other=-N_EXPTS_ACT) // N_EXPTS_ACT + else: + if M is not None: + offs_m = tl.max_contiguous(tl.multiple_of(offs_m % M, BLOCK_M), BLOCK_M) + else: + offs_m = tl.max_contiguous(tl.multiple_of(offs_m % eM, BLOCK_M), BLOCK_M) + # no needs to bounds-check here because `offs_m` wraps around M dim + offs_m = tl.load(GatherIndx + start_m.to(index_type) + offs_m) // N_EXPTS_ACT + offs_x_m = offs_m.to(index_type)[:, None] * stride_x_m + + acc = tl.zeros((BLOCK_N, BLOCK_M) if SWAP_XW else (BLOCK_M, BLOCK_N), dtype=tl.float32) + for ki in tl.range(k_tiles, disallow_acc_multi_buffer=DISALLOW_ACC_MULTI_BUFFER): + off_k = pid_k * BLOCK_K + ki * BLOCK_K * SPLIT_K + off_k_w = pid_k * PACKED_BLOCK_K_W + ki * PACKED_BLOCK_K_W * SPLIT_K + off_k_mx = pid_k * MX_SCALE_BLOCK_K + ki * MX_SCALE_BLOCK_K * SPLIT_K + + if USE_GATHER_TMA: + x = X.gather(offs_x_m, off_k) + elif X_USE_LOAD_TMA: + x = _tma_load_2d(X, [start_z, start_m + off_m, off_k]) + else: + XPtrs = XBase + offs_x_m + offs_x_k + XBase += BLOCK_K * SPLIT_K * stride_x_k + mask_k = tl.arange(0, BLOCK_K) < K - off_k + if EVEN_K: + if SPLIT_K > 1: + x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0) + else: + x = tl.load(XPtrs) + else: + x = tl.load(XPtrs, mask=mask_k[None, :], other=0.0) + + w = _tma_load_2d(W, [expt_id, off_k_w, off_n], transpose=W_TRANSPOSE) + + if is_microscaled_format: + x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype) + mx_format: tl.constexpr = get_scaled_dot_format_string(w.dtype) + if x_format == "fp16" or x_format == "bf16": + x_scales: tl.constexpr = None + else: + x_scales = tl.full((BLOCK_M, BLOCK_K // MX_PACK_DIVISOR), 127, dtype=tl.uint8) + if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE": + flattened_expt_n_idx = expt_id * ((N + 127) // 128) + (off_n // 128) + w_scales = MxScale.load([0, flattened_expt_n_idx, pid_k * MX_SCALE_BLOCK_K // 4 + ki * (MX_SCALE_BLOCK_K // 4 * SPLIT_K), 0, 0]) + w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1])) + w_scales = unswizzle_mx_scale_bw(w_scales) + else: + w_scales = _tma_load_2d(MxScale, [expt_id, off_k_mx, off_n]).T + if SWAP_XW: + acc = tl.dot_scaled(w.T, w_scales, mx_format, x.T, x_scales, x_format, acc=acc, fast_math=True) + else: + acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, mx_format, acc=acc, fast_math=True) + else: + if SWAP_XW: + acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) + else: + acc = tl.dot(x, w, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32) + + if INDEPENDENT_EPILOGUE: + tile_id1 += NUM_SMS + expt_id1, start_z1, start_m1, eM1, off_m1, off_n1, pid_k1 = _load_tile_attrs( + tile_id1, num_tiles, grid_m, grid_n, padding_m, + M, ExptData, ExptHist, ExptOffs, + BLOCK_M, BLOCK_N, SPLIT_K, + GROUP_M, XCD_SWIZZLE) + else: + tile_id1, expt_id1, start_z1, start_m1, eM1 = tile_id, expt_id, start_z, start_m, eM + off_m1, off_n1, pid_k1 = off_m, off_n, pid_k + + # Determine output row offsets and mask + offs_m = off_m1 + tl.arange(0, BLOCK_M) + mask_m = offs_m < M if M is not None else offs_m < eM1 + if HAS_FUSED_SCATTER: + offs_y_m, mask_m = _load_writeback_idx_and_mask( + WriteBackIndx, writeback_size, start_m1 + offs_m, mask_m) + # Later, mask out the acc for computing flexpoint scales. 
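+            # Rows whose writeback index came back as -1 are never stored, so when flexpoint
+            # scales are tracked the corresponding output rows are zeroed before the absmax
+            # reduction to keep them from inflating the output scale.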
+ MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE + + if USE_SCATTER_TMA and SPLIT_K > 1: + # Compute the split k offset in number of rows, and add it to offs_y_m. + # This allows us to write to the correct slice in the output tensor while using + # a 2D TMA scatter. + tl.device_assert(stride_y_k // stride_y_m == tl.cdiv(stride_y_k, stride_y_m)) + split_k_row_offs = pid_k1 * (stride_y_k // stride_y_m) + offs_y_m = tl.where(mask_m, offs_y_m + split_k_row_offs, offs_y_m) + else: + offs_y_m = start_m1 + offs_m + + if USE_GATHER_TMA: + MASK_ACC: tl.constexpr = False + else: + # Later, mask out the acc for computing flexpoint scales. + MASK_ACC: tl.constexpr = USE_FLEXPOINT_SCALE + + # TMA is faster on Blackwell if a SWAP_XW transpose is not needed, or when we need registers to mask out the acc. + # Contrary to the SWAP_XW case, having a fused activation function tends to make TMA faster again. + # For the ideal optimization, this would depend on what the activation function is doing. + Y_USE_TMA: tl.constexpr = (MASK_ACC or cuda_capability_geq(10, 0)) and not ( + DISABLE_Y_TMA or (SWAP_XW and ACTIVATION_FN is None)) + + YBase = Y + start_z1.to(index_type) * stride_y_z + start_m1.to(index_type) * stride_y_m + if USE_SCATTER_TMA: + if ExptData is None: # start_z1 may change; update the descriptor + y_desc = _update_tensor_desc(y_desc, YBase) + elif not HAS_FUSED_SCATTER and Y_USE_TMA: + y_desc = tl.make_tensor_descriptor( + YBase + pid_k1.to(index_type) * stride_y_k, + shape=[M if M is not None else eM1, yN], + strides=[stride_y_m, stride_y_n], + block_shape=[BLOCK_M, OUT_BLOCK_N], + ) + + # bias + scale + offs_y_n = off_n1 + tl.arange(0, BLOCK_N) + mask_n = offs_y_n < N + if B is not None: + BPtrs = B + expt_id1 * stride_b_e + offs_y_n + if pid_k1 == 0: + bias = tl.load(BPtrs, mask=mask_n, other=0) + else: + bias = tl.full([BLOCK_N], 0, dtype=tl.float32) + else: + bias = tl.full([BLOCK_N], 0, dtype=tl.float32) + if Betas is not None: + betas = tl.load(Betas + start_m1 + offs_m, mask=mask_m, other=0.0) + else: + betas = tl.full([BLOCK_M], 1, dtype=tl.float32) + if Gammas is not None: + gammas = tl.load(Gammas + start_m1 + offs_m, mask=mask_m, other=0.0) + else: + gammas = tl.full([BLOCK_M], 1, dtype=tl.float32) + x_scale = load_scale(XScale) + if PER_BATCH_SCALE: + w_scale = load_scale(WScale + expt_id1) + else: + w_scale = load_scale(WScale) + + accs = (acc,) + biases = (bias,) + + if SUBTILE_FACTOR >= 2: + acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split() + accs = (acc0, acc1) + bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split() + biases = (bias0, bias1) + + if SUBTILE_FACTOR >= 4: + acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split() + acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split() + accs = (acc00, acc01, acc10, acc11) + bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split() + bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split() + biases = (bias00, bias01, bias10, bias11) + + tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR) + tl.static_assert(len(accs) == SUBTILE_FACTOR) + + for a_i in tl.static_range(len(accs)): + acc_tile = accs[a_i] + acc_tile *= x_scale * w_scale + + if SWAP_XW: + acc_tile = acc_tile.T + + acc_tile = acc_tile + biases[a_i][None, :] * betas[:, None] + if out_alpha is not None: + acc_tile *= out_alpha + + if ACTIVATION_FN is not None: + out = ACTIVATION_FN(acc_tile, *activation_fn_args) + tl.static_assert(out.shape[1] == 
OUT_BLOCK_N, f"Activation fn out.shape[1] ({out.shape[1]}) doesn't match computed OUT_BLOCK_N ({OUT_BLOCK_N})") + else: + tl.static_assert(ACTIVATION_REDUCTION_N == 1, "Activation reduction must be 1 if no activation fn is provided") + out = acc_tile + + out *= gammas[:, None] + + if MASK_ACC: + out = tl.where(mask_m[:, None], out, 0.0) + # Flexpoint + out_view = tl.reshape( + out, [out.numel // THREADS_PER_BLOCK, THREADS_PER_BLOCK], can_reorder=True) + local_absmax = tl.maximum(local_absmax, nan_propagating_absmax_reduce(out_view, axis=0)) + out = float_to_flex( + out, YExpectedScale, + None, # ActualScale: local absmax is tracked and updated after the loop + YChecksumScale, + None, # mask: out is manually masked to 0 + Y, FLEXPOINT_SATURATE_INF) + if EPILOGUE_FN is not None: + out = EPILOGUE_FN(out, *epilogue_fn_args, target_dtype=Y.dtype.element_ty, pid=len(accs)*tile_id1 + a_i) + + out_off_n = off_n1 // ACTIVATION_REDUCTION_N + a_i * OUT_BLOCK_N + if USE_SCATTER_TMA: + # Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that + # there shouldn't be any other negative values. + offs_y_m = (offs_y_m.to(tl.uint32, bitcast=True) & 0x7FFFFFFF).to(tl.int32, bitcast=True) + y_desc.scatter(out.to(Y.dtype.element_ty), offs_y_m, out_off_n) + elif not HAS_FUSED_SCATTER and Y_USE_TMA: + y_desc.store([off_m1, out_off_n], out.to(Y.dtype.element_ty)) + else: + offs_y_n = out_off_n + tl.arange(0, OUT_BLOCK_N) + mask_n = offs_y_n < yN + + YPtrs = Y + pid_k1.to(index_type) * stride_y_k + start_z1.to(index_type) * stride_y_z + offs_y_m.to(index_type)[:, None] * stride_y_m + offs_y_n[None, :] * stride_y_n + mask = mask_m[:, None] & mask_n[None, :] + tl.store(YPtrs, out, mask=mask) + + + # Update the flexpoint scales + if YActualScale is not None: + tl.atomic_max(YActualScale, compute_scale(local_absmax.to(tl.float32, bitcast=True), Y), sem="relaxed") + + +_per_device_alloc_fns = {} +def get_per_device_per_stream_alloc_fn(device): + if device not in _per_device_alloc_fns: + _per_stream_tensors = {} + def alloc_fn(size: int, alignment: int, stream): + assert alignment == 128 + if stream not in _per_stream_tensors or _per_stream_tensors[stream].numel() < size: + _per_stream_tensors[stream] = torch.empty(size, device=device, dtype=torch.int8) + _per_stream_tensors[stream].__hibernate__ = {"type": "ignore"} + return _per_stream_tensors[stream] + + _per_device_alloc_fns[device] = alloc_fn + return _per_device_alloc_fns[device] diff --git a/torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py b/torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py new file mode 100644 index 0000000000000000000000000000000000000000..895deb23c54e4ac9d965e8fa517140899377ae32 --- /dev/null +++ b/torch-ext/triton_kernels/matmul_ogs_details/opt_flags.py @@ -0,0 +1,298 @@ +# isort: off +# fmt: off +from dataclasses import dataclass +import triton +from triton_kernels.target_info import get_cdna_version +import torch +from .opt_flags_details import opt_flags_amd, opt_flags_nvidia + + +@dataclass +class OptFlags: + block_m: int + block_n: int + block_k: int + num_warps: int + num_stages: int + group_m: int + xcd_swizzle: int + w_cache_modifier: str + split_k: int + fused_scatter: bool + is_persistent: bool + idle_sms: int + epilogue_subtile: int | None + arch: str + target_kernel_kwargs: dict + + def __post_init__(self): + if self.fused_scatter and self.split_k != 1: + raise ValueError("Not supported") + + + +def make_default_opt_flags_amd( + out_dtype, + lhs_dtype, + rhs_dtype, + precision_config, + m, + 
n, + k, + routing_data, + can_use_persistent_tma, + can_use_fused_scatter, + enforce_bitwise_invariance, + epilogue_effective_itemsize, + constraints, +): + constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"] + assert not any([c not in constraints_supported for c in constraints]), constraints.keys() + # tokens per expert + if routing_data is None: + tokens_per_expt = m + elif routing_data.expected_tokens_per_expt is None: + tokens_per_expt = max(1, m // routing_data.n_expts_tot) + else: + tokens_per_expt = routing_data.expected_tokens_per_expt + + is_cdna4 = get_cdna_version() == 4 + # block_m + if constraints.get("block_m", None): + block_m = constraints["block_m"] + elif enforce_bitwise_invariance: + block_m = 256 if is_cdna4 else 128 + elif tokens_per_expt >= 512 and n >= 2048: + block_m = 256 if is_cdna4 else 128 + elif is_cdna4 and m >= 512: + block_m = 128 + else: + block_m = max(32, min(triton.next_power_of_2(tokens_per_expt), 64)) + + if routing_data is not None: + grid_m = routing_data.n_blocks(m, block_m) + else: + grid_m = triton.cdiv(m, block_m) + # group_m: + group_m = 4 + # number of xcds + num_xcds = 8 + xcd_swizzle = num_xcds + # block_nk: + block_n, block_k = opt_flags_amd.compute_block_nk( + n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config + ) + # Replace block_k if provided in constraints. + # TODO: Does opt_flags_amd.compute_block_nk need to be refactored? + if constraints.get("block_k", None) is not None: + block_k = constraints["block_k"] + is_persistent = constraints.get("is_persistent", False) + # split_k: + if constraints.get("split_k", None) is not None: + split_k = constraints["split_k"] + elif is_persistent or enforce_bitwise_invariance: + split_k = 1 + else: + grid_size = grid_m * ((n + block_n - 1) // block_n) + n_cu = torch.cuda.get_device_properties(0).multi_processor_count + split_k = max(1, n_cu // grid_size) + # w_cache_modifier: + w_cache_modifier = ".cg" if block_m <= 32 else None + # num_warps, num_stages + num_warps = 2 if (m is not None and m <= 16) else 8 + num_stages = 2 + # AMD-specific + target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1} + ret = OptFlags( + block_m=block_m, + block_n=block_n, + block_k=block_k, + num_warps=num_warps, + num_stages=num_stages, + group_m=group_m, + xcd_swizzle=xcd_swizzle, + w_cache_modifier=w_cache_modifier, + split_k=split_k, + fused_scatter=constraints.get('fused_scatter', False), + is_persistent=is_persistent, + idle_sms=0, + epilogue_subtile=constraints.get('epilogue_subtile', None), + arch=None, + target_kernel_kwargs=target_kernel_kwargs, + ) + # check constraints + assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}" + return ret + +def make_default_opt_flags_nvidia( + out_dtype, + lhs_dtype, + rhs_dtype, + precision_config, + m, + n, + k, + routing_data, + can_use_persistent_tma, + can_use_fused_scatter, + enforce_bitwise_invariance, + epilogue_effective_itemsize, + constraints, +): + constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms"] + assert not any([c not in constraints_supported for c in constraints]), constraints.keys() + # tokens per expert + if routing_data is None: + tokens_per_expt = m + elif routing_data.expected_tokens_per_expt is None: + tokens_per_expt = max(1, m // routing_data.n_expts_tot) + else: + tokens_per_expt = 
routing_data.expected_tokens_per_expt + # pid swizzling + group_m = 8 + xcd_swizzle = 1 + # block_m + if constraints.get("block_m", None): + block_m = constraints["block_m"] + elif enforce_bitwise_invariance: + block_m = 128 + else: + min_block_m = 64 if torch.cuda.get_device_capability()[0] == 10 else 16 + block_m = max(min_block_m, min(triton.next_power_of_2(tokens_per_expt), 128)) + # block n + arch = None + block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config) + # is_persistent + grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n) + n_sms = torch.cuda.get_device_properties(0).multi_processor_count + tiles_per_sm = grid_size / n_sms + supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9) + if constraints.get("is_persistent", None) is not None: + is_persistent = constraints["is_persistent"] + else: + has_simple_epilogue = precision_config.max_num_imprecise_acc is None + is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4 + # TEMP CHANGE + if precision_config.act_scale is not None or precision_config.out_scale is not None: + is_persistent = False + # block k + if constraints.get("block_k", None) is not None: + block_k = constraints["block_k"] + else: + block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config) + # split_k + if constraints.get("split_k", None) is not None: + split_k = constraints["split_k"] + elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None: + split_k = 1 + else: + estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n) + split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size) + if split_k > 1: + # With split_k, results are written in f32. Use that for the following computations. 
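+        # With split-K the kernel writes float32 partial results, so the stage-count heuristic
+        # below is fed float32 as the effective output element type rather than the user-visible
+        # out_dtype.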
+ out_dtype = torch.float32 + compute_num_stages_args = ( + precision_config, + is_persistent, + block_m, + block_n, + block_k, + out_dtype, + lhs_dtype, + rhs_dtype, + ) + + if constraints.get("epilogue_subtile", None) is not None: + subtiles_to_check = [constraints["epilogue_subtile"]] + else: + subtiles_to_check = [1, 2, 4] + num_stages = -1 + for ep in subtiles_to_check: + ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, ep, epilogue_effective_itemsize) + if ns > num_stages: + epilogue_subtile, num_stages = ep, ns + assert num_stages >= 1 + if constraints.get("num_stages", None): + num_stages = constraints["num_stages"] + + # fused scatter scratchpad + if constraints.get("fused_scatter", None) is not None: + fused_scatter = constraints["fused_scatter"] + else: + fused_scatter = can_use_fused_scatter and split_k == 1 + # Handshake with the HBM swizzling + num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config) + ret = OptFlags( + block_m=block_m, + block_n=block_n, + block_k=block_k, + num_warps=num_warps, + num_stages=num_stages, + group_m=group_m, + xcd_swizzle=xcd_swizzle, + w_cache_modifier=None, + split_k=split_k, + fused_scatter=fused_scatter, + is_persistent=is_persistent, + epilogue_subtile=epilogue_subtile, + arch=arch, + target_kernel_kwargs=dict(), + idle_sms=constraints.get("idle_sms", 0), + ) + # check constraints + assert all(getattr(ret, ck) == cv for ck, cv in constraints.items() if cv is not None), f"{ret} != {constraints}" + return ret + +# -------------- +# User Interface +# -------------- + +_opt_flags_constraints: dict = dict() +_opt_flags: OptFlags | None = None + +def update_opt_flags_constraints(constraints: dict[str, int]): + global _opt_flags_constraints + _opt_flags_constraints.update(constraints) + +def reset_opt_flags_constraints(): + global _opt_flags_constraints + _opt_flags_constraints = dict() + +def set_opt_flags(opt_flags: OptFlags): + global _opt_flags + assert not _opt_flags_constraints, "setting constraints is incompatible with manual flags override" + assert not _opt_flags, "opt_flags already set; please reset to None first" + _opt_flags = opt_flags + +class InapplicableConstraint(Exception): + pass + +def make_opt_flags( + out_dtype, + lhs_dtype, + rhs_dtype, + precision_config, + m, + n, + k, + routing_data, + can_use_persistent_tma, + can_use_fused_scatter, + epilogue_effective_itemsize, +): + if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma: + raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint") + enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance + if _opt_flags is not None: + assert not _opt_flags_constraints + return _opt_flags + args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k, + routing_data, can_use_persistent_tma, can_use_fused_scatter, + enforce_bitwise_invariance, epilogue_effective_itemsize, + _opt_flags_constraints] + backend = triton.runtime.driver.active.get_current_target().backend + if backend == "hip": + return make_default_opt_flags_amd(*args) + if backend == "cuda": + return make_default_opt_flags_nvidia(*args) + assert False diff --git a/torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py b/torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py new file mode 100644 index 0000000000000000000000000000000000000000..ffe06c333f60191823dd650b84a6ab6225daba46 --- /dev/null +++ 
b/torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py @@ -0,0 +1,33 @@ +import torch +import triton +from triton_kernels.target_info import get_cdna_version +from triton_kernels.tensor import bitwidth + + +def compute_block_nk(n, block_m, grid_m, num_xcds, lhs_dtype, rhs_dtype, precision_config): + lhs_width = bitwidth(lhs_dtype) / 8 + rhs_width = bitwidth(rhs_dtype) / 8 + + # block_n: + n_cu = torch.cuda.get_device_properties(0).multi_processor_count + if n is not None: + if n <= 128 and (n & (n - 1)) == 0: + block_n = n + else: + block_n = max(32, min(256, triton.next_power_of_2(grid_m * n * num_xcds // n_cu))) + elif block_m > 64: + block_n = 256 + else: + block_n = 128 + + if get_cdna_version() == 4 and block_m == 128: + block_n = 512 + + # block_k needs to match the cacheline size (128B) + block_k = int(128 // min(lhs_width, rhs_width)) + + # TODO: block_k = 128 seems to work better for now. + # perhaps due to increased number of k loops to pipeline + if precision_config.weight_scale is not None and get_cdna_version() != 4: + block_k = 128 + return block_n, block_k diff --git a/torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py b/torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd0ffd4750c796077d0b7b5bc2e21c4b8be2f93 --- /dev/null +++ b/torch-ext/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py @@ -0,0 +1,111 @@ +import torch +import triton +from triton_kernels import target_info +from triton_kernels.tensor import get_layout, bitwidth, FP4 +from triton_kernels.tensor_details.layout import HopperMXScaleLayout +from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE + + +def compute_grid_size(routing_data, m, n, block_m, block_n): + if routing_data is not None: + grid_m = routing_data.n_blocks(m, block_m) + else: + grid_m = triton.cdiv(m, block_m) + grid_n = (n + block_n - 1) // block_n + return grid_m * grid_n + + +def compute_block_n(n: int, arch, precision_config): + # block_n: + layout = get_layout(precision_config.weight_scale) + if isinstance(layout, HopperMXScaleLayout) and layout.num_warps == 4: + return 128 + elif precision_config.max_num_imprecise_acc is None and n > 128: + return 256 + else: + return max(16, min(128, triton.next_power_of_2(n))) + + +def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config): + lhs_width = bitwidth(lhs_dtype) + rhs_width = bitwidth(rhs_dtype) + # block_k needs to match the cacheline size (1024 bits) + block_k = int(1024 // min(lhs_width, rhs_width)) + has_native_mxfp = target_info.cuda_capability_geq(10, 0) + if rhs_width == 4 and not has_native_mxfp: + block_k = 128 + elif k is not None: + block_k = max(32, min(triton.next_power_of_2(k), block_k)) + has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None + if has_native_mxfp and is_persistent and has_mx_weight_scale: + block_k = min(block_k, 128) + return block_k + + +def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int: + device_props = torch.cuda.get_device_properties(0) + n_sms = device_props.multi_processor_count + split_k = n_sms // grid_size + if k is not None: + # avoid split_k for small k + num_block_k = triton.cdiv(k, block_k) + split_k = min(split_k, num_block_k // 4) + split_k = max(split_k, 1) + return split_k + + +def compute_num_warps(block_m, block_n, 
precision_config): + layout = get_layout(precision_config.weight_scale) + if isinstance(layout, HopperMXScaleLayout): + return layout.num_warps + return max(block_m * block_n // 4096, 4) + + +def compute_num_stages( + precision_config, + is_persistent, + block_m, + block_n, + block_k, + out_dtype, + lhs_dtype, + rhs_dtype, + epilogue_subtile, + epilogue_effective_itemsize, +): + if precision_config.max_num_imprecise_acc is not None: + return 3 + weight_size = bitwidth(rhs_dtype) / 8 + stage_size = block_m * block_k * lhs_dtype.itemsize + block_k * block_n * weight_size + device_props = torch.cuda.get_device_properties(0) + smem_capacity = device_props.shared_memory_per_block_optin + has_native_mxfp = target_info.cuda_capability_geq(10, 0) + if has_native_mxfp and getattr(precision_config, "weight_scale", None) is not None: + if rhs_dtype == FP4: + # 4-bit e2m1 weights are padded 2x + # https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory + stage_size += block_k * block_n * weight_size + + if is_persistent: + # Per-stage wait barrier + stage_size += 8 + if target_info.cuda_capability_geq(10, 0): + acc_size = epilogue_effective_itemsize or out_dtype.itemsize + else: + acc_size = out_dtype.itemsize + if target_info.cuda_capability_geq(10, 0) and epilogue_subtile is not None: + acc_block_n = block_n // epilogue_subtile + else: + acc_block_n = block_n + # pipelined TMA store local to global, or + # pipelined layout conversion before store of the accumulator + # note: layout conversion has some padding + smem_capacity -= int((block_m + 4) * acc_block_n * acc_size) + if precision_config.weight_scale is not None: + # mx scales + stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE)) + elif has_native_mxfp: + # mx scales + stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE)) + num_stages = min(4, smem_capacity // int(stage_size)) + return num_stages diff --git a/torch-ext/triton_kernels/numerics.py b/torch-ext/triton_kernels/numerics.py new file mode 100644 index 0000000000000000000000000000000000000000..024d3fcf0b819646a485596070b14c7a0a2e17ed --- /dev/null +++ b/torch-ext/triton_kernels/numerics.py @@ -0,0 +1,42 @@ +import torch +from dataclasses import dataclass + +MAX_FINITE_FLOAT8E5 = 57344.0 +MAX_FINITE_FLOAT8E4NV = 448.0 +MAX_FINITE_FLOAT8E4B8 = 240.0 + + +@dataclass(frozen=True) +class BaseFlexData: + dtype: torch.dtype | None = None + + def view(self, x: torch.Tensor): + if self.dtype is None: + return x + return x.view(self.dtype) + + def reinterpret(self, x): + if self.dtype is None or x.dtype.itemsize > 1: + return x + return x.view(self.dtype) + + +@dataclass(frozen=True) +class InFlexData(BaseFlexData): + scale: torch.Tensor | None = None + + @property + def is_per_batch(self): + return False if self.scale is None else len(self.scale) > 1 + + +@dataclass(frozen=True) +class OutFlexData(BaseFlexData): + expected_scale: torch.Tensor | None = None + actual_scale: torch.Tensor | None = None + checksum_scale: torch.Tensor | None = None + + def __iter__(self): + yield self.expected_scale + yield self.actual_scale + yield self.checksum_scale diff --git a/torch-ext/triton_kernels/numerics_details/__init__.py b/torch-ext/triton_kernels/numerics_details/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torch-ext/triton_kernels/numerics_details/__pycache__/__init__.cpython-310.pyc 
b/torch-ext/triton_kernels/numerics_details/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..352addf8b03337594ad0d63a5c931080f6e09602 Binary files /dev/null and b/torch-ext/triton_kernels/numerics_details/__pycache__/__init__.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/numerics_details/__pycache__/flexpoint.cpython-310.pyc b/torch-ext/triton_kernels/numerics_details/__pycache__/flexpoint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be0cf88f2d37415d4bf6ddfe658b0f1e60031413 Binary files /dev/null and b/torch-ext/triton_kernels/numerics_details/__pycache__/flexpoint.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/numerics_details/__pycache__/mxfp.cpython-310.pyc b/torch-ext/triton_kernels/numerics_details/__pycache__/mxfp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad0c7a5489cf3459bc1ddcfd6d0a857ef25a12a6 Binary files /dev/null and b/torch-ext/triton_kernels/numerics_details/__pycache__/mxfp.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/numerics_details/flexpoint.py b/torch-ext/triton_kernels/numerics_details/flexpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..9f9075b65c4ae156c329dfe7f16b6f978fea49c1 --- /dev/null +++ b/torch-ext/triton_kernels/numerics_details/flexpoint.py @@ -0,0 +1,195 @@ +from ..numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 +from triton_kernels import target_info +import triton +import triton.language as tl + +# ------------------------------- +# Kernels stuff +# ------------------------------- + +TL_MAX_FINITE_FLOAT8E5 = tl.constexpr(MAX_FINITE_FLOAT8E5) +TL_MAX_FINITE_FLOAT8E4NV = tl.constexpr(MAX_FINITE_FLOAT8E4NV) +TL_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(MAX_FINITE_FLOAT8E4B8) +TL_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(1.750) +TL_MAX_FINITE_FLOAT16 = tl.constexpr(65472.0) + +TL_RCP_MAX_FINITE_FLOAT8E5 = tl.constexpr(0x37924925) # 0x1.24924Ap-16 +TL_RCP_MAX_FINITE_FLOAT8E4NV = tl.constexpr(0x3B124925) # 0x1.24924Ap-9 +TL_RCP_MAX_FINITE_FLOAT8E4B8 = tl.constexpr(0x3B888889) # 0x1.111112p-8 +TL_RCP_MAX_FINITE_FLOAT8E4B15 = tl.constexpr(0x3F124925) # 0x1.24924Ap-1 +TL_RCP_MAX_FINITE_FLOAT16 = tl.constexpr(0x37802008) # 0x1.004010p-16 + + +@triton.jit +def max_finite(dtype): + if dtype == tl.constexpr(tl.float8e5): + return TL_MAX_FINITE_FLOAT8E5 + elif dtype == tl.constexpr(tl.float8e4nv): + return TL_MAX_FINITE_FLOAT8E4NV + elif dtype == tl.constexpr(tl.float8e4b8): + return TL_MAX_FINITE_FLOAT8E4B8 + elif dtype == tl.constexpr(tl.float8e4b15): + return TL_MAX_FINITE_FLOAT8E4B15 + elif dtype == tl.constexpr(tl.float16): + return TL_MAX_FINITE_FLOAT16 + else: + tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint") + + +@triton.jit +def rcp_max_finite(dtype): + if dtype == tl.constexpr(tl.float8e5): + return TL_RCP_MAX_FINITE_FLOAT8E5 + elif dtype == tl.constexpr(tl.float8e4nv): + return TL_RCP_MAX_FINITE_FLOAT8E4NV + elif dtype == tl.constexpr(tl.float8e4b8): + return TL_RCP_MAX_FINITE_FLOAT8E4B8 + elif dtype == tl.constexpr(tl.float8e4b15): + return TL_RCP_MAX_FINITE_FLOAT8E4B15 + elif dtype == tl.constexpr(tl.float16): + return TL_RCP_MAX_FINITE_FLOAT16 + else: + tl.static_assert(tl.constexpr(False), f"{dtype} not supported in flexpoint") + + +@tl.constexpr_function +def cuda_capability_geq(major, minor): + return target_info.cuda_capability_geq(major, minor) + + +@triton.jit +def 
sm86_min_nan_xorsign_abs_f32(a, b): + """Wrapper for min.NaN.xorsign.abs.f32 PTX instruction. + + Computes the minimum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs. + NaN inputs are propagated to the output. + + Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do). + """ + tl.static_assert(cuda_capability_geq(8, 6), "min.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+") + tl.static_assert(a.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs") + tl.static_assert(b.dtype == tl.float32, "min.NaN.xorsign.abs.f32 requires float32 inputs") + + return tl.inline_asm_elementwise( + """{ + min.NaN.xorsign.abs.f32 $0, $1, $2; + }""", + "=r,r,r", + [a, b], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + + +@triton.jit +def sm86_max_nan_xorsign_abs_f32(a, b): + """Wrapper for max.NaN.xorsign.abs.f32 PTX instruction. + + Computes the maximum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs. + NaN inputs are propagated to the output. + + Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do). + """ + tl.static_assert(cuda_capability_geq(8, 6), "max.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+") + tl.static_assert(a.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs") + tl.static_assert(b.dtype == tl.float32, "max.NaN.xorsign.abs.f32 requires float32 inputs") + + return tl.inline_asm_elementwise( + """{ + max.NaN.xorsign.abs.f32 $0, $1, $2; + }""", + "=r,r,r", + [a, b], + dtype=tl.float32, + is_pure=True, + pack=1, + ) + + +@triton.jit +def load_scale(scale_ptr): + return 1.0 if scale_ptr is None else tl.load(scale_ptr) + + +@triton.jit +def flex_to_float(x, scale_ptr): + scale = load_scale(scale_ptr) + return x.to(tl.float32) * scale + + +@triton.jit +def clip(x, limit): + res = tl.minimum(x, limit) + res = tl.maximum(-limit, res) + return res + + +@triton.jit +def nan_propagating_absmax_reduce(x, axis=None): + if cuda_capability_geq(8, 6): + # abs-max-reduce as floating-point if `max.NaN.xorsign.abs.f32` is supported. + x_absmax = tl.reduce(x, axis, sm86_max_nan_xorsign_abs_f32) + # Note: sign of reduction result is the xor of signs of all inputs, explicitly clear the sign bit to fix it. + x_absmax = x_absmax.to(tl.uint32, bitcast=True) & 0x7FFFFFFF + else: + # Clear the sign bit, max-reduce as integer (same as NaN-propagating max-reduce as float) + masked_abs_x = x.to(tl.uint32, bitcast=True) & 0x7FFFFFFF + x_absmax = tl.max(masked_abs_x, axis) + + return x_absmax + + +@triton.jit +def compute_scale(x, Out): + x_absmax = nan_propagating_absmax_reduce(tl.ravel(x, can_reorder=True)) + + # atomic_max does not propagate NaNs, so we replace them with +inf (0x7f800000). + # We use integer minimum because NaNs are above +inf in integer representation. 
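+    # Bit-level view: 0x7F800000 is the IEEE-754 bit pattern of +inf, and every float32 NaN has
+    # the same exponent bits with a non-zero mantissa, so (with the sign bit already cleared) it
+    # compares strictly greater than 0x7F800000 as an unsigned integer; the integer minimum below
+    # therefore clamps NaNs to +inf while finite magnitudes (<= 0x7F7FFFFF) pass through unchanged.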
+ x_absmax = tl.minimum(x_absmax, 0x7F800000).to(tl.float32, bitcast=True) + RCP_MAX_VALUE = rcp_max_finite(Out.dtype.element_ty) + return tl.fma(x_absmax, RCP_MAX_VALUE.to(tl.float32, bitcast=True), 1.0e-30) + + +@triton.jit +def update_scale(x, scale_ptr, Out) -> None: + if scale_ptr is not None: + scale = compute_scale(x, Out) + tl.atomic_max(scale_ptr, scale, sem="relaxed") + + +@triton.jit +def float_to_flex( + x, + expected_scale_ptr_or_val, + actual_scale_ptr, + checksum_scale_ptr, + mask, + Out, + saturate_infs: tl.constexpr, +): + if expected_scale_ptr_or_val is not None: + if expected_scale_ptr_or_val.dtype.is_ptr(): + invscale = 1.0 / tl.load(expected_scale_ptr_or_val) + else: + invscale = 1.0 / expected_scale_ptr_or_val + else: + invscale = 1.0 + if checksum_scale_ptr is not None: + x_int32 = x.to(tl.int32, bitcast=True) + zero = tl.cast(0.0, tl.int32) + if mask is not None: + x_int32 = tl.where(mask, x_int32, zero) + checksum_local = tl.xor_sum(tl.ravel(x_int32, can_reorder=True), 0) + tl.atomic_add(checksum_scale_ptr, checksum_local) + if mask is not None: + if actual_scale_ptr is not None: + x = tl.where(mask, x, 0.0) + update_scale(x, actual_scale_ptr, Out) + x = x * invscale + # if expected_scale_ptr is not None, we applied flexpoint scale. We only want to clip in this case. + if expected_scale_ptr_or_val is not None: + if saturate_infs: + CLIP_VALUE = max_finite(Out.dtype.element_ty) + x = clip(x, CLIP_VALUE) + return x diff --git a/torch-ext/triton_kernels/numerics_details/mxfp.py b/torch-ext/triton_kernels/numerics_details/mxfp.py new file mode 100644 index 0000000000000000000000000000000000000000..644b9c4491207870c8ff56ba4ef22496c81f0f9e --- /dev/null +++ b/torch-ext/triton_kernels/numerics_details/mxfp.py @@ -0,0 +1,303 @@ +# isort: off +# fmt: off +from enum import Enum +import triton +import torch +import torch.nn.functional as F +from .mxfp_details._upcast_from_mxfp import _upcast_from_mxfp +from .mxfp_details._downcast_to_mxfp import _downcast_to_mxfp, _dequantize_mxfp8_fn, MXFP_BLOCK_SIZE + +# ----------------------------------------------------------------------------- +# Dequantization / Quantization Utilities +# ----------------------------------------------------------------------------- + + +class DequantScaleRoundingMode(Enum): + ROUND_UP = 0 + ROUND_DOWN = 1 + + +def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int, + DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP): + """ + Convert the src weights to mx format. The src weight is quantized along the axis dimension. + + If weight_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte. + Note that this means the k_dim of the tensor will be half of the logical k_dim. + + If weight_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8s are stored + in their respective formats. + """ + ndim = src_tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + # downcast + src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1) + is_fp4 = out_quant_type == torch.uint8 + is_fp8 = out_quant_type in (torch.float8_e4m3fn, torch.float8_e5m2) + assert is_fp4 or is_fp8 + divisor = 2 if is_fp4 else 1 + L = src_tensor.shape[-1] + if is_fp4: + assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. 
Got {L}" + out_shape = src_tensor.shape[:-1] + (L // divisor, ) + out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), ) + + out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type) + out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8) + + if src_tensor.numel() > 0: + kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1]) + kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1]) + kernel_scale = out_scale.view(-1, out_scale.shape[-1]) + + BLOCK_OUT_DIM = 128 + BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value + grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM) + grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM) + + _downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale, + *kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(), + *kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM, + DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8) + + out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1) + out_scale = out_scale.transpose(axis, src_tensor.ndim - 1) + return out_quant_tensor, out_scale + + +def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int): + """ + Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16. + + The function assumes that the tensors were quantized along the given axis. + It permutes the tensor so that the quantized axis is last, reshapes to 2D, + launches the Triton upcast kernel, and then unpermutes back to the original order. + """ + ndim = tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + assert tensor.ndim == scale.ndim, (f"Weight and scale must have the same number of dimensions. 
" + f"Got {tensor.ndim=} and {scale.ndim=}") + # dtype checks + assert tensor.dtype in {torch.uint8, torch.float8_e5m2, torch.float8_e4m3fn}, \ + f"Invalid tensor dtype {tensor.dtype=}" + assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}" + assert dtype in (torch.float16, torch.bfloat16), f"Invalid output dtype {dtype=}" + # upcast + logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1) + tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous() + scale = scale.transpose(axis, scale.ndim - 1).contiguous() + out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device) + reshaped_out = out.view(-1, out.shape[-1]) + reshaped_tensor = tensor.view(-1, tensor.shape[-1]) + reshaped_scale = scale.view(-1, scale.shape[-1]) + BLOCK_OUT_DIM = 128 + BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value + blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM) + blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM) + _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale, + *reshaped_scale.stride(), reshaped_tensor, + *reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM, + BLOCK_QUANT_DIM, num_warps=8) + out = out.transpose(axis, scale.ndim - 1).contiguous() + return out + + +# ------------ + + +def right_shift_unsigned(x, shift): + # CUDA torch does not support bit ops on uint32, so we need to mask to get unsigned right shift + return (x >> shift) & ((1 << (32 - shift)) - 1) + + +def get_max_quant_val(dtype: torch.dtype): + d = {torch.uint8: 6.0, torch.float8_e5m2: 57344.0, torch.float8_e4m3fn: 448.0} + assert dtype in d + return d[dtype] + + +def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int, + DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP): + """ + Converts the src tensor to the output format specified by out_quant_type. + axis: The axis along which the tensors are contiguous and quantization is applied. + DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN. + + Returns: + out_quant_tensor: Quantized tensor in mx format. + • For mxfp8, the output has the same shape as src_tensor. + • For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8. + scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis. + Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32), + where L is the original length along that axis. + """ + # This should probably be packed into its own tiny class + ndim = src_tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + assert src_tensor.dtype in {torch.float32, torch.bfloat16, + torch.float16}, f"Invalid input tensor dtype {src_tensor.dtype}" + + axis = axis if axis >= 0 else axis + ndim + is_fp4 = out_quant_type == torch.uint8 + is_fp8 = "float8" in str(out_quant_type) + assert is_fp4 or is_fp8, f"Invalid input tensor dtype {out_quant_type}" + + device = src_tensor.device + + # For mxfp4 conversion, we assume the contiguous axis length is even. + if is_fp4: + axis_shape = src_tensor.size(axis) + assert axis_shape % 2 == 0, "For mxfp4 conversion the contiguous axis length must be even." + + # Permute the tensor so that the contiguous axis becomes the last dimension. 
+ src = src_tensor.transpose(axis, src_tensor.ndim - 1).to(torch.float32) + axis_shape = src.shape[-1] + + # Pad the axis to be divisible by 32, in case it is not. + next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE + pad_amount = next_multiple - axis_shape + padded_src = F.pad(src, (0, pad_amount)) + valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount)) + padded_axis_shape = padded_src.size(-1) # now divisible by 32 + + # --- Compute per-group maximums for scale --- + # Set padded entries to -1 so they don’t affect the max. + abs_f = torch.abs(padded_src) + abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype)) + # Reshape the last dimension into groups of 32. + new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE) + abs_groups = abs_f.view(*new_shape) + # Compute maximum along the group dimension (of size 32). + max_val, _ = abs_groups.max(dim=-1, keepdim=True) + + # Choose a max quantization value depending on type. + max_quant_val = get_max_quant_val(out_quant_type) + dequant_scale = max_val / max_quant_val # shape: (..., padded_axis_shape//32, 1) + + # Convert to int to round the FP32 scale, prior to quantization! + ds_int = dequant_scale.view(torch.int32) + if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP: + ds_int_rounded = (ds_int + 0x007FFFFF) & 0x7F800000 + else: + ds_int_rounded = ds_int & 0x7F800000 + # Reinterpret back as float32. + dequant_scale_rounded = ds_int_rounded.view(torch.float32) + + # Compute the quantization scale. + quant_scale = torch.where(dequant_scale_rounded == 0, torch.tensor(0.0, device=device), 1.0 / dequant_scale_rounded) + + # Quantize the tensor + orig_padded_shape = padded_src.shape + padded_src_groups = padded_src.view(*new_shape) + quant_tensor = padded_src_groups * quant_scale + # Reshape back to the original shape and trim padding + quant_tensor = quant_tensor.view(orig_padded_shape) + quant_tensor = quant_tensor[..., :axis_shape] + + # Finally, convert the quantized tensor to the target format + if is_fp8: + # Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior + quant_tensor = torch.clamp(quant_tensor, -max_quant_val, max_quant_val) + out_weight = quant_tensor.to(out_quant_type) + else: + assert is_fp4, f"Invalid output quantization type {out_quant_type}" + # For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8. + # First, reinterpret the quantized tensor bits. + q_int = quant_tensor.contiguous().view(torch.int32) + # Extract sign, exponent, and mantissa. + signs = q_int & 0x80000000 + exponents = right_shift_unsigned(q_int, 23) & 0xFF + mantissas = q_int & 0x7FFFFF + + E8_BIAS = 127 + E2_BIAS = 1 + # Adjust mantissas for subnormals. + mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >> + (E8_BIAS - exponents - 1), mantissas) + exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS) + e2m1_tmp = right_shift_unsigned(((exponents << 2) | right_shift_unsigned(mantissas, 21)) + 1, 1) + e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device)) + e2m1_value = (right_shift_unsigned(signs, 28) | e2m1_tmp).to(torch.uint8) # shape: (..., even_axis_shape) + + # Pack pairs of 4-bit values along the last dimension. 
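+        # The even-indexed element of each pair goes into the low nibble and the odd-indexed one
+        # into the high nibble, matching the unpacking order used by cvt_e2m1_to_fp32 below.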
+ e2m1_value = e2m1_value.view(*e2m1_value.shape[:-1], axis_shape // 2, 2) + evens = e2m1_value[..., 0] + odds = e2m1_value[..., 1] + out_weight = evens | (odds << 4) # shape: (..., axis_shape//2) + + # --- Process and output the scale --- + dq_scale = (ds_int_rounded.view(*dequant_scale.shape) >> 23).to(torch.uint8) # shape: (..., axis_shape//32, 1) + dq_scale = dq_scale.squeeze(-1) + out_weight = out_weight.transpose(axis, src_tensor.ndim - 1) + dq_scale = dq_scale.transpose(axis, src_tensor.ndim - 1) + return out_weight, dq_scale + + +def cvt_e2m1_to_fp32(input_tensor): + assert input_tensor.dtype == torch.uint8 + + input_tensor = input_tensor.to(torch.int32) + evens = input_tensor & 0xF + odds = (input_tensor >> 4) & 0xF + + vals = [0.0, 0.5, 1, 1.5, 2, 3, 4, 6] + outputs = torch.tensor(vals, dtype=torch.float32, device=input_tensor.device) + outputs = torch.cat([outputs, -outputs]) + + even_floats = outputs[evens] + odd_floats = outputs[odds] + output_tensor = torch.stack([even_floats, odd_floats], dim=-1) + output_tensor = output_tensor.view(*input_tensor.shape[:-1], -1) + return output_tensor + + +def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: torch.dtype, axis: int): + """ + Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype. + axis: The axis along which dequantization is applied. + + Returns: + out_weight: Tensor in the target format. + """ + + ndim = tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + is_fp8 = tensor.dtype == torch.float8_e4m3fn or tensor.dtype == torch.float8_e5m2 + assert is_fp8 or tensor.dtype == torch.uint8, f"Invalid input quantization type {tensor.dtype}" + + # Permute the tensor and scale so that the quantization axis becomes the last dimension + axis = axis if axis >= 0 else axis + ndim + scale = scale.transpose(axis, scale.ndim - 1) + tensor = tensor.transpose(axis, tensor.ndim - 1) + + dq_scale = (scale.to(torch.int32) << 23).view(torch.float32) # Shift to the exponent and bitcast to fp32 + if tensor.dtype == torch.uint8: + fp32_tensor = cvt_e2m1_to_fp32(tensor) + else: + fp32_tensor = tensor.to(torch.float32) + + logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1) + axis_shape = fp32_tensor.size(-1) + padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE + pad_size = padded_axis_shape - axis_shape + padded_tensor = F.pad(fp32_tensor, (0, pad_size)) + + new_axis_shape = padded_tensor.shape[-1] + new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE) + padded_tensor = padded_tensor.view(*new_shape) + dq_scale_padded = dq_scale.unsqueeze(-1) # shape: [..., ceil(axis_shape/32), 1] + out_padded = padded_tensor * dq_scale_padded + + # Flatten back and remove the padded tail + out_padded = out_padded.view(*fp32_tensor.shape[:-1], new_axis_shape) + out_tensor = out_padded[..., :axis_shape] + + out_tensor = out_tensor.to(target_dtype).contiguous() + out_tensor = out_tensor.transpose(axis, tensor.ndim - 1) + + return out_tensor + + +dequantize_mxfp8_fn = _dequantize_mxfp8_fn diff --git a/torch-ext/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py b/torch-ext/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py new file mode 100644 index 0000000000000000000000000000000000000000..8c0b831a9a1339a016fdebfcf3f4a96e526b884b --- /dev/null +++ b/torch-ext/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py @@ -0,0 +1,158 @@ +import triton +import 
triton.language as tl + +# fmt: off + + +MXFP_BLOCK_SIZE = tl.constexpr(32) + + +@triton.jit +def _get_max_quant_val(dtype: tl.constexpr): + if dtype == tl.uint8: + return 6.0 + elif dtype == tl.float8e5: + return 57344.0 + elif dtype == tl.float8e4nv: + return 448.0 + else: + tl.static_assert(False, f"Invalid {dtype=}") + +@triton.jit +def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr, + DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0): + is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5 + BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0] + BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1] + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE + + # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16 + f32_tensor = src_tensor.to(tl.float32) + abs_tensor = tl.abs(f32_tensor) + abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation + abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE]) + max_val = tl.max(abs_tensor, axis=2, keep_dims=True) + dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype) + if DEQUANT_SCALE_ROUNDING_MODE == 0: + # DequantScaleRoundingMode.ROUND_UP + # compute 2 ** ceil(log2(dequant_scale)) + # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros + # A corner case: exponent is 0xFF that will overflow but that's already + # NaN so assume we don't care. + dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 + else: + # DequantScaleRoundingMode.ROUND_DOWN + # compute 2 ** floor(log2(dequant_scale)) + assert DEQUANT_SCALE_ROUNDING_MODE == 1 + dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded) + + f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE]) + quant_tensor = f32_tensor * quant_scale + + # Reshape the tensors after scaling + quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format. + quant_tensor = tl.where(valid_src_mask, quant_tensor, 0) + dequant_scale_exponent = dequant_scale_exponent.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE]) + + # First, we simply extract the exponent part of the scales and store the result + dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8) + # Now we must convert the tensors to the mx format. + if is_fp8: + out_tensor = quant_tensor.to(mx_tensor_dtype) + else: + quant_tensor = quant_tensor.to(tl.uint32, bitcast=True) + signs = quant_tensor & 0x80000000 + exponents = (quant_tensor >> 23) & 0xFF + mantissas = (quant_tensor & 0x7FFFFF) + + # 0.25 <= x < 0.75 maps to 0.5, a denormal number + E8_BIAS = 127 + E2_BIAS = 1 + # Move implicit bit 1 at the beginning to mantissa for denormals + adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False) + mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas) + + # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0. 
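+        # Worked example: fp32 2.0 has exponent field 128; 128 - (127 - 1) = 2 is the e2m1
+        # exponent field, which under the new bias of 1 again encodes 2.0.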
+ exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS) + + # Combine sign, exponent, and mantissa, while saturating + # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right + e2m1_tmp = tl.minimum((((exponents << 2) | (mantissas >> 21)) + 1) >> 1, 0x7) + e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8) + + e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2]) + evens, odds = tl.split(e2m1_value) + out_tensor = evens | (odds << 4) + + return out_tensor, dequant_scale_exponent + +@triton.jit +def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr, + mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant, + src_ptr, stride_src_outer, stride_src_quant, + outer_dim, quant_dim, + BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr, + DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr): + + tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.") + tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32") + + # uint8 signifies two fp4 e2m1 values packed into a single byte + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + tl.static_assert(mx_tensor_dtype == tl.uint8 or (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5), + f"Invalid {mx_tensor_dtype=}. Must be uint8 or float8.") + + src_dtype: tl.constexpr = src_ptr.dtype.element_ty + tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, f"{mx_scale_ptr.dtype.element_ty=} must be uint8") + tl.static_assert((src_dtype == tl.bfloat16) or (src_dtype == tl.float16), f"{src_dtype=} must be bfloat16 or float16") + is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8 + + outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1 + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE + BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR + + start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer + mx_scale_ptr += start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer + mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer + + offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + + mask_src_quant = start_src_quant + offs_src_quant < quant_dim + mask_n = start_out + offs_outer < outer_dim + full_mask_src = mask_src_quant & mask_n + + mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR) + full_mask_mxt = mask_mxt_quant & mask_n + + scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE) + full_scale_mask = scale_mask_k & mask_n + + src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer + mx_scale_offsets = offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer + mx_tensor_offsets = 
offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer + src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src) + + out_tensor, scale_tensor = _compute_quant_and_scale(src_tensor, full_mask_src, mx_tensor_dtype, + DEQUANT_SCALE_ROUNDING_MODE) + + tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask) + tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt) + + +@triton.jit(repr=lambda _: "_dequantize_mxfp8") +def _dequantize_mxfp8_fn(input, mask, pid=None): + return _compute_quant_and_scale(input, mask, tl.float8e4nv) diff --git a/torch-ext/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py b/torch-ext/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py new file mode 100644 index 0000000000000000000000000000000000000000..21a3ba8128e4cb2e49bb88288de7601362a06124 --- /dev/null +++ b/torch-ext/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py @@ -0,0 +1,122 @@ +import triton +import triton.language as tl +from ._downcast_to_mxfp import MXFP_BLOCK_SIZE + + +# fmt: off +@triton.jit +def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer, + stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr, + outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr): + + tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx") + tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32") + # uint8 signifies two fp4 e2m1 values packed into a single byte + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + dst_dtype: tl.constexpr = out_ptr.dtype.element_ty + tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16) + tl.static_assert( + mx_tensor_dtype == tl.uint8 + or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype), + "mx_tensor_ptr must be uint8 or float8 or dst_dtype") + tl.static_assert(mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8") + + # Determine if we are dealing with fp8 types. + is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8 + is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5 + K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1 + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE + BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR + + # Compute starting indices for the quantized (packed) dimension and the outer dimension. + outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + mx_tensor_ptr += start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer + mx_scale_ptr += start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer + out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant + + # Compute offsets and masks. 
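+    # Three masks are needed because, along the quantized dimension, the packed tensor holds
+    # ceil(quant_dim / K_DIVISOR) elements, the scales hold ceil(quant_dim / 32), and the output
+    # holds quant_dim.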
+ offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + + mask_outer = start_out + offs_outer < outer_dim + mask_out_quant = start_out_quant + offs_out_quant < quant_dim + full_mask_out = mask_out_quant & mask_outer + + mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR) + full_mask_src = mask_src_quant & mask_outer + + mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE) + full_scale_mask = mask_scale & mask_outer + + tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer + scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer + out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer + + # Load the packed tensor and scale. + tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src) + scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask) + + # Upcast the scale to the destination type. + if dst_dtype == tl.bfloat16: + dst_scale = (scale.to(tl.uint16) << 7).to(dst_dtype, bitcast=True) + else: + tl.static_assert(dst_dtype == tl.float16) + dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + dst_scale = dst_scale.to(tl.float16) + + # Now upcast the tensor. + if is_fp8: + dst_tensor = tensor.to(dst_dtype) + if tensor.dtype == tl.float8e5: + from_e_bits: tl.constexpr = 5 + from_m_bits: tl.constexpr = 2 + to_e_bits: tl.constexpr = 8 if dst_dtype == tl.bfloat16 else 5 + to_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10 + + # Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them! + non_finite_mask_src: tl.constexpr = ((1 << from_e_bits) - 1) << from_m_bits + non_finite_mask_dst: tl.constexpr = ((1 << to_e_bits) - 1) << to_m_bits + dst_tensor = tl.where( + (tensor.to(tl.uint8, bitcast=True) & non_finite_mask_src) == non_finite_mask_src, + (dst_tensor.to(tl.uint16, bitcast=True) | non_finite_mask_dst).to(dst_dtype, bitcast=True), + dst_tensor, + ) + else: + assert is_fp4 + dst_bias: tl.constexpr = 127 if dst_dtype == tl.bfloat16 else 15 + dst_0p5: tl.constexpr = 16128 if dst_dtype == tl.bfloat16 else 0x3800 + dst_m_bits: tl.constexpr = 7 if dst_dtype == tl.bfloat16 else 10 + # e2m1 + em0 = tensor & 0x07 + em1 = tensor & 0x70 + x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((tensor & 0x08).to(tl.uint16) << 12) + x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((tensor & 0x80).to(tl.uint16) << 8) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0) + x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + dst_tensor = tl.interleave(x0, x1).to(dst_dtype, bitcast=True) + + # Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping. 
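+    # Each scale entry applies to MXFP_BLOCK_SIZE consecutive upcast values, so viewing the
+    # tensor as (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE) against a
+    # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1) scale broadcasts the multiply group-wise.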
+ dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE]) + dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1]) + scale = scale.reshape(dst_scale.shape) + + out_tensor = dst_tensor * dst_scale + # Correct any NaNs encoded via the scale. + out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor) + out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out) diff --git a/torch-ext/triton_kernels/proton_opts.py b/torch-ext/triton_kernels/proton_opts.py new file mode 100644 index 0000000000000000000000000000000000000000..fc00c64315faa7c440addafdfe24b32760b45f8f --- /dev/null +++ b/torch-ext/triton_kernels/proton_opts.py @@ -0,0 +1,17 @@ +# proton options + +import os + +_launch_metadata_allow_sync = None + + +def launch_metadata_allow_sync(): + global _launch_metadata_allow_sync + if _launch_metadata_allow_sync is None: + _launch_metadata_allow_sync = not (os.getenv("PROTON_LAUNCH_METADATA_NOSYNC") == "1") + return _launch_metadata_allow_sync + + +def set_launch_metadata_allow_sync(allow_sync: bool): + global _launch_metadata_allow_sync + _launch_metadata_allow_sync = allow_sync diff --git a/torch-ext/triton_kernels/reduction_details/__pycache__/reduce_bitmatrix.cpython-310.pyc b/torch-ext/triton_kernels/reduction_details/__pycache__/reduce_bitmatrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..295b7f0618895091a62090b012311ad35b3817dd Binary files /dev/null and b/torch-ext/triton_kernels/reduction_details/__pycache__/reduce_bitmatrix.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/reduction_details/reduce_bitmatrix.py b/torch-ext/triton_kernels/reduction_details/reduce_bitmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e5aba907c3e8155cc3f896d0593950e657a5ae --- /dev/null +++ b/torch-ext/triton_kernels/reduction_details/reduce_bitmatrix.py @@ -0,0 +1,111 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def vpopc(x): + """ + Vertical popcount + Input x : uint32[..., N] + Output y : uint32[..., 32] + semantics : y[..., i] = sum_j((x[..., j] >> i) & 1) + credits: @apgoucher + """ + + tl.static_assert(x.dtype == tl.uint32, "x should consist of 32-bit unsigned integers") + + BLOCK_N: tl.constexpr = x.shape[-1] # summation axis + BATCHES: tl.constexpr = x.numel // BLOCK_N # number of batches + if BLOCK_N >= 8: + sa1: tl.constexpr = 8 + else: + sa1: tl.constexpr = BLOCK_N + # create 8-way sums in 4-bit fields: + y = tl.reshape(x, [BATCHES, BLOCK_N // sa1, sa1, 1]) + y = (y >> tl.arange(0, 4)[None, None, None, :]) & 0x11111111 + y = tl.sum(y, 2) # [BATCHES, BLOCK_N // sa1, 4] + if BLOCK_N >= 128: + sa2: tl.constexpr = 16 + else: + sa2: tl.constexpr = BLOCK_N // sa1 + # create 128-way sums in 8-bit fields: + y = tl.reshape(y, [BATCHES, BLOCK_N // (sa1 * sa2), sa2, 1, 4]) + y = (y >> (4 * tl.arange(0, 2))[None, None, None, :, None]) & 0x0f0f0f0f + y = tl.sum(y, 2) # [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4] + sa3: tl.constexpr = BLOCK_N // (sa1 * sa2) + # create N-way sums in 32-bit fields: + y = tl.reshape(y, [BATCHES, 1, sa3, 8]) + y = (y >> (8 * tl.arange(0, 4))[None, :, None, None]) & 0x000000ff + y = tl.sum(y, 2) # [BATCHES, 4, 8] + y = tl.reshape(y, x.shape[:-1] + [32]) + return y + + +@triton.jit +def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + 
tl.store(Ret + offs, 0) + + +@triton.jit +def _sum_bitmatrix_rows(B, shape_bm, stride_bm: tl.constexpr, stride_bn: tl.constexpr, # input bitmatrix + Ret, Partials, stride_pm: tl.constexpr, stride_pn, shape_pn, # outputs + BLOCK_MM: tl.constexpr, BLOCK_M: tl.constexpr): + + tl.static_assert(BLOCK_MM % BLOCK_M == 0) + TILE_SIZE: tl.constexpr = BLOCK_MM // BLOCK_M + if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr(): + shape_bm = tl.load(shape_bm) + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM) + offs_n = pid_n * 32 + tl.arange(0, 32) + n_rows = shape_bm + bits = tl.load(B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0) + bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M]) + ret = vpopc(bits) # [TILE_SIZE, 32] + + offs_t = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE) + + tl.atomic_add(Ret + offs_n, tl.sum(ret, 0), sem="relaxed") + tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret) + + +def clear_sums(n_cols, device, MEMSET_BLOCK=512): + cdiv = triton.cdiv + blocks = cdiv(n_cols, MEMSET_BLOCK) + out_ret = torch.empty((blocks * MEMSET_BLOCK, ), device=device, dtype=torch.int32) + _sum_bitmatrix_memset[(blocks, )](out_ret, MEMSET_BLOCK) + return out_ret + + +def sum_bitmatrix_rows(x, out_ret, partials_block_size=None): + assert partials_block_size is not None + cdiv = triton.cdiv + PARTIALS_BLOCK_M = partials_block_size + n_rows, n_cols = x.shape + n_rows_max = x.shape_max[0] + assert out_ret.shape == (n_cols, ) + + TILE_SIZE = max(1, 128 // PARTIALS_BLOCK_M) + BLOCK_MM = PARTIALS_BLOCK_M * TILE_SIZE + + pids_x = cdiv(n_rows_max, BLOCK_MM) + pids_y = cdiv(n_cols, 32) + out_partials = torch.empty((pids_y * 32, pids_x * TILE_SIZE), device=out_ret.device, dtype=torch.int32) + out_partials = torch.transpose(out_partials, 0, 1) + + # output tensors + _sum_bitmatrix_rows[(pids_x, pids_y)]( + x.storage.data, n_rows, x.stride(0), x.stride(1), # input + out_ret, # output [final reduction] + out_partials, out_partials.stride(0), out_partials.stride(1), + out_partials.shape[1], # output [partial reductions] + BLOCK_M=PARTIALS_BLOCK_M, BLOCK_MM=BLOCK_MM, # constants + num_warps=8) + + out_partials = out_partials[:cdiv(n_rows_max, PARTIALS_BLOCK_M), :] + + return out_ret, out_partials diff --git a/torch-ext/triton_kernels/routing.py b/torch-ext/triton_kernels/routing.py new file mode 100644 index 0000000000000000000000000000000000000000..d76cb65596b83c4886fc2f0cae3d442e7559ba6e --- /dev/null +++ b/torch-ext/triton_kernels/routing.py @@ -0,0 +1,386 @@ +import torch +import triton +from dataclasses import dataclass, field +from .routing_details._routing_compute import _combined_routing_compute +from .routing_details._routing_compute import _combined_routing_memset +from .routing_details._routing_compute import _routing_clear_bitmatrix +from .routing_details._expt_data import _expt_data_memset +from .routing_details._expt_data import _expt_data_compute +from .target_info import is_hip + + +@dataclass +class GatherIndx: + """ + Indices for an operation that performs: + Y = X[src_idx, :] + """ + # array such that `dst_idx[src_idx] = arange(0, N)` + src_indx: torch.Tensor + dst_indx: torch.Tensor + + +@dataclass +class ScatterIndx: + """ + Indices for an operation that performs: + Y[dst_idx, :] = X + """ + # array such that `dst_idx[src_idx] = arange(0, N)` + src_indx: torch.Tensor + dst_indx: torch.Tensor + + +@dataclass +class ExptData: + # hist[i] is the number of tokens routed to expert i + 
hist: torch.Tensor + # token_offs_raw[i] is the offset of the first token routed + # to expert i in an expert-sorted array + token_offs_raw: torch.Tensor + # token_offs_pad[block][i] is the offset of the first token routed + # to expert i in an expert-sorted array, assuming histogram + # rounded to the next multiple of `block` + token_offs_pad: dict[int, torch.Tensor] + # block_id_map[block] contain one value for each `pid`` launched by + # the matrix multiplication kernel launched with BLOCK_M=block: + # - the value is -1 if the `pid` has no work to do + # - otherwise, the value is two int16 (packed as an int32) that + # correspond respectively to (1) the expert assigned to + # the tokens processed by this pid; (2) the block assigned to the + # tokens processed by this pid (think `pid_m` in a regular matmul) + # see `test_routing.py` for a reference implementation and more details + block_pid_map: dict[int, torch.Tensor] + + def __post_init__(self): + if self.hist is not None: + assert self.hist.dtype == torch.int32 + if self.token_offs_raw is not None: + assert self.token_offs_raw.dtype == torch.int32 + if self.token_offs_pad is not None: + for v in self.token_offs_pad.values(): + assert v.dtype == torch.int32 + if self.block_pid_map is not None: + for v in self.block_pid_map.values(): + assert v.dtype == torch.int32 + + +@dataclass +class RoutingData: + gate_scal: torch.Tensor = field() + expt_hist: torch.Tensor = field() + n_expts_tot: int = field() + n_expts_act: int = field() + expt_data: ExptData = None + + # Used to make perf annotation cleaner: when we use expert sharding, we can + # use this to tell the "expected" number of local tokens per expert, because + # the actual number can vary per each input. + expected_tokens_per_expt: int = field(default=None) + + def n_blocks(self, n_rows, block_m): + if n_rows <= self.n_expts_tot: + return n_rows + else: + return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1 + + +# -------------------------- +# sort tokens by expert +# -------------------------- + + +class SortTokens(torch.autograd.Function): + + @staticmethod + def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix): + HIST_BLOCK_M = 32 + INDX_OFFS_BLOCK_M = 512 + MEMSET_BLOCK = 1024 + cdiv = triton.cdiv + + device = expt_scal.device + dtype = expt_scal.dtype + n_tokens_raw, _ = bitmatrix.shape + n_tokens_pad, n_expts_act = expt_scal.shape + n_gates_pad = n_tokens_pad * n_expts_act + + hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M) + hist = hist[:n_expts_tot] + assert hist.dtype == torch.int32 + # scratchpad + expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device) + combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device) + # output + topk_indx = combined_indx[:n_gates_pad] + gate_indx = combined_indx[n_gates_pad:] + gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device) + + token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1a, blocks2a, MEMSET_BLOCK_A, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal( + hist, n_expts_tot, n_gates_pad) + + blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1 + blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M) + + _combined_routing_memset[(blocks1a + blocks1b, )]( + combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist, # + expt_offs, hist.shape[0], n_expts_tot, partial_hist, # inputs + partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1), # outputs + 
token_offs_combined, token_offs_combined.stride(0), # + blocks1a, block_pid_map, # + block_m_log2_start, SIZES=block_m_num, BLOCK_A=MEMSET_BLOCK_A, # optimization parameters + BLOCK_N=512, BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters + ) + + indx_offs = partial_hist + + _combined_routing_compute[(blocks2a + blocks2b, )]( + topk_indx, gate_indx, gate_scal, # outputs + expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1), # inputs + expt_offs, n_tokens_raw, # input shape + HIST_BLOCK_M, n_expts_act, # constants + hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs + block_m_log2_start, block_m_num, HIST2_BLOCK_M, blocks2a, # etc. + ) + + ctx.n_tokens_raw = n_tokens_raw + ctx.n_tokens_pad = n_tokens_pad + ctx.n_expts_act = n_expts_act + ctx.save_for_backward(gate_indx) + return hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map + + @staticmethod + def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5): + (gate_indx, ) = ctx.saved_tensors + dgate_scal = dgate_scal[gate_indx] + dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act) + return dgate_scal, None, None, None + + +def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix): + return SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix) + + +# -------------------------- +# prune routing +# -------------------------- + + +class PruneRouting(torch.autograd.Function): + + @staticmethod + def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep): + from .compaction import compaction + n_tokens_pad = expt_scal.shape[0] + assert n_expts_tot % simulated_ep == 0 + _routing_clear_bitmatrix[(n_tokens_pad, )]( + bitmatrix.storage.data, + bitmatrix.storage.data.stride(0), + bitmatrix.storage.data.stride(1), + bitmatrix.storage.data.shape[1], + n_expts_tot // simulated_ep, + BLOCK_N=512, + ) + # perform compaction to update expt_scal / expt_indx + expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix) + n_expts_tot = n_expts_tot // simulated_ep + bitmatrix.shape[-1] = n_expts_tot + return expt_scal, expt_indx, bitmatrix + + +def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep): + return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep) + + +# -------------------------- +# expt_data +# -------------------------- + + +def log2_power_of_two(x): + assert x > 0 and (x & (x - 1)) == 0, "x must be a power of two" + return x.bit_length() - 1 + + +block_m_log2_start = 4 + + +def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates): + + MEMSET_BLOCK = 512 + HIST2_BLOCK_M = 512 + device = expt_hist.device + n_expts_tot = n_expts_tot + cdiv = triton.cdiv + # block_ms are all powers-of-two between 16 and 128 (inclusive) + block_m_log2_end = 9 if is_hip() else 8 + block_m_num = block_m_log2_end - block_m_log2_start + if n_gates <= n_expts_tot: + max_n_tiles = n_gates + else: + max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start) + # allocate memory + pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK + dtype = torch.int32 + + token_offs_combined = torch.empty((block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device) + + token_offs_raw = token_offs_combined[0][:n_expts_tot + 1] + token_offs_pad = token_offs_combined[1:] + + block_pid_map = torch.empty((block_m_num, pad(max_n_tiles)), dtype=dtype, device=device) + memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK # exact division + # compute 
outputs + token_offs_pad = token_offs_pad[:, :n_expts_tot + 1] + block_pid_map = block_pid_map[:, :max_n_tiles] + + blocks1 = memset_grid + block_m_num + 1 + blocks2 = n_expts_tot * block_m_num + return token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num + + +def _unpack_into_dict(x): + + block_m_log2_end = block_m_log2_start + x.shape[0] + x = {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))} + return x + + +def compute_expt_data(expt_hist, n_expts_tot, n_gates): + + if expt_hist is None: + return ExptData(None, None, None, None) + + # this just computes the kernel arguments: + token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal( + expt_hist, n_expts_tot, n_gates) + + _expt_data_memset[(blocks1, )]( + expt_hist, n_expts_tot, # + token_offs_combined, token_offs_combined.stride(0), # + block_pid_map, # + block_m_log2_start, SIZES=block_m_num, BLOCK=MEMSET_BLOCK, # optimization parameters + num_warps=4) + _expt_data_compute[(blocks2, )]( + expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs + block_m_log2_start, SIZES=block_m_num, BLOCK=HIST2_BLOCK_M, # optimization parameters + num_warps=4) + + token_offs_pad = _unpack_into_dict(token_offs_pad) + block_pid_map = _unpack_into_dict(block_pid_map) + return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map) + + +# -------------------------- +# routing +# -------------------------- + + +def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act): + hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens( + expt_scal, expt_indx, n_expts_tot, bitmatrix) + token_offs_pad = _unpack_into_dict(token_offs_pad) + block_pid_map = _unpack_into_dict(block_pid_map) + expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) + + # pack the matmul data structure + gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx) + scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx) + return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx + + +def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None): + from .topk import topk + if sm_first: + logits = torch.softmax(logits, dim=-1) + expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, # + apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows) + n_expts_tot = logits.shape[-1] // simulated_ep + # mutate bitmatrix + if simulated_ep > 1: + expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, logits.shape[-1], simulated_ep) + + return routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act) + + +# -------------------------- +# torch reference +# -------------------------- + + +def compute_expt_data_torch(hist, n_expts_tot, n_gates): + # offset for each experts + device = hist.device + token_offs_raw = torch.cumsum(hist, dim=0) + token_offs_raw = torch.cat((torch.zeros(1, device=device), token_offs_raw)) + token_offs_raw = token_offs_raw.int() + # maximum number of tiles for all values of `block_m` considered + block_ms = [16, 32, 64, 128] + if is_hip(): + block_ms.append(256) + if n_gates <= n_expts_tot: + max_n_tiles = n_gates + else: + # ceil_div(n_gates - 
n_experts + 1, d_tile) + n_experts - 1 + # ceil_div(x, y): -(-x // y) + max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms)) + # fill up tile offset/infos for each block + token_offs_pad = dict() + block_pid_map = dict() + for block_m in block_ms: + n_tiles = (hist + block_m - 1) // block_m # matmul blocks needed + token_offs_pad[block_m] = torch.cumsum(n_tiles, dim=0) + token_offs_pad[block_m] = torch.cat((torch.zeros(1, device=device), token_offs_pad[block_m])) + token_offs_pad[block_m] = token_offs_pad[block_m].int() + # compute data required to drive ragged batch matmul + block_pid_map[block_m] = -torch.ones(max_n_tiles, device=device) + for e in range(n_expts_tot): + offset = token_offs_pad[block_m][e] + for b in range(n_tiles[e]): + block_pid_map[block_m][offset + b] = (b << 16) + e + block_pid_map[block_m] = block_pid_map[block_m].int() + return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) + + +def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None): + has_user_provided_indx = expt_indx is not None + n_gates_pad = logits.shape[0] * n_expts_act + + if n_rows is not None: + logits = logits[:n_rows, :] + + def topk(vals, k, expt_indx): + # topk of experts + if has_user_provided_indx: + tk_indx = expt_indx + else: + tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k] + tk_indx = tk_indx.long() + tk_val = torch.take_along_dim(vals, tk_indx, dim=1) + tk_indx = tk_indx.int() + return tk_val, tk_indx + + _, n_expts_tot = logits.shape + if sm_first: + logits = torch.softmax(logits, dim=-1) + expt_scal, expt_indx = topk(logits, n_expts_act, expt_indx) + if not sm_first: + expt_scal = torch.softmax(expt_scal, dim=-1) + # sort each token's selections by expert + if not has_user_provided_indx: + expt_indx, sort_indices = torch.sort(expt_indx, dim=1) + expt_scal = torch.gather(expt_scal, 1, sort_indices) + # flatten topk data + expt_scal = expt_scal.reshape(-1) + expt_indx = expt_indx.reshape(-1).to(torch.int32) + # sort by expert_id so experts are contiguous for the matmul + topk_indx = torch.argsort(expt_indx, stable=True) + gate_indx = torch.argsort(topk_indx, stable=True) + gate_scal = expt_scal[topk_indx] + hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1).int() # histogram of tokens over experts + # pack the matmul data structure + gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) + scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) + # compute expt_data + expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad) + return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx diff --git a/torch-ext/triton_kernels/routing_details/__pycache__/_expt_data.cpython-310.pyc b/torch-ext/triton_kernels/routing_details/__pycache__/_expt_data.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3974a786c4e8b40c429b51ab35f4d725c5099904 Binary files /dev/null and b/torch-ext/triton_kernels/routing_details/__pycache__/_expt_data.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/routing_details/__pycache__/_routing_compute.cpython-310.pyc b/torch-ext/triton_kernels/routing_details/__pycache__/_routing_compute.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5b65f155c6bb1d4794fc62122a2f42ae18e0801 Binary files /dev/null and b/torch-ext/triton_kernels/routing_details/__pycache__/_routing_compute.cpython-310.pyc differ diff --git 
a/torch-ext/triton_kernels/routing_details/_expt_data.py b/torch-ext/triton_kernels/routing_details/_expt_data.py new file mode 100644 index 0000000000000000000000000000000000000000..125afe4d528282e240728633cd381958c3ab01e7 --- /dev/null +++ b/torch-ext/triton_kernels/routing_details/_expt_data.py @@ -0,0 +1,64 @@ +import triton +import triton.language as tl + + +@triton.jit +def _cdiv_pow2(n, log2_k): + return (n + ((1 << log2_k) - 1)) >> log2_k + + +@triton.jit +def _expt_data_memset(Hist, n_expts_tot, MDStarts, tile_starts_stridem, MDTileInfo, first_tile_dim_log2, + SIZES: tl.constexpr, BLOCK: tl.constexpr): + + pid = tl.program_id(0) + + if pid <= SIZES: + + MDStarts += pid * tile_starts_stridem + x_tile = tl.zeros([BLOCK], dtype=MDStarts.dtype.element_ty) + Tile_ptrs = MDStarts + tl.arange(0, BLOCK) + tile_dim_log2 = tl.where(pid == 0, 0, pid + first_tile_dim_log2 - 1) + + for i in range(0, n_expts_tot + 1, BLOCK): + + offs_n = tl.arange(0, BLOCK) + i + mask_n0 = offs_n < n_expts_tot + hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0) + hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2) + + tile_starts = tl.cumsum(hist_tile, 0) + x_tile + x_tile += tl.sum(hist_tile, 0).to(MDStarts.dtype.element_ty) + tl.store(Tile_ptrs, tile_starts - hist_tile) + Tile_ptrs += BLOCK + + else: + + pid -= (SIZES + 1) + TileInfoOut = MDTileInfo + pid * BLOCK + tl.arange(0, BLOCK) + tl.store(TileInfoOut, 0xffffffff) + + +@triton.jit +def _expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2, + SIZES: tl.constexpr, BLOCK: tl.constexpr): + + pid = tl.program_id(0) + + expt_id = pid // SIZES + buff_id = pid % SIZES + + MDTileStarts += buff_id * tile_starts_stridem + MDTileInfo += buff_id * tile_info_stridem + + n_tokens = tl.load(Hist + expt_id) + tile_dim_log2 = first_tile_dim_log2 + buff_id + n_blocks = _cdiv_pow2(n_tokens, tile_dim_log2) + + tile_off = tl.load(MDTileStarts + expt_id) + MDTileInfo += tile_off + + for block_off in range(0, n_blocks, BLOCK): + block_offs = block_off + tl.arange(0, BLOCK) + data = (block_offs << 16) + expt_id + tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks) diff --git a/torch-ext/triton_kernels/routing_details/_routing_compute.py b/torch-ext/triton_kernels/routing_details/_routing_compute.py new file mode 100644 index 0000000000000000000000000000000000000000..a72900030f4a241f3d2c6fd8740020de8ae0145d --- /dev/null +++ b/torch-ext/triton_kernels/routing_details/_routing_compute.py @@ -0,0 +1,148 @@ +import triton +import triton.language as tl + +from ._expt_data import _expt_data_compute, _expt_data_memset + + +@triton.jit +def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, # histogram + BLOCK_N: tl.constexpr): + loop_iterations = (hist_size + BLOCK_N - 1) // BLOCK_N + x = tl.zeros([BLOCK_N], ExpertHist.dtype.element_ty) + for i in range(loop_iterations): + offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < hist_size + hist2 = tl.load(ExpertHist + offs_n, mask=mask_n) + tok_starts = tl.cumsum(hist2, 0) - hist2 + x + x += tl.sum(hist2, 0) + tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n) + offs_n += BLOCK_N + + +@triton.jit +def _routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr, expt_id): + offs_m = tl.arange(0, BLOCK_M) + # iterate over input data + curr_sum = 0 + for _ in range(0, shape_pm, BLOCK_M): + offs = offs_m * stride_pm + expt_id * stride_pn + curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm) 
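+        # `out - curr`, stored below, is an exclusive prefix sum: the count for this expert over
+        # all earlier row blocks, with `curr_sum` carrying the running total across BLOCK_M chunks.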
+ out = tl.cumsum(curr, 0) + curr_sum + curr_sum += tl.sum(curr, 0) + tl.store(PartialHist + offs, out - curr, mask=offs_m < shape_pm) + offs_m += BLOCK_M + + +@triton.jit +def _keyed_add(x, y): + + # we keep the key in the upper 16 bits of a uint32: + key_mask: tl.constexpr = 0xffff0000 + + kx = x & key_mask + ky = y & key_mask + z = tl.where(kx == ky, x + y - kx, y) + return z + + +@triton.jit +def _routing_compute_indx(pid_m, GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, + stride_pn, TokensStart, n_tokens, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr): + + if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr(): + n_tokens = tl.load(n_tokens) + n_gates = n_tokens * N_EXPTS_ACT + + tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768) + + local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M) + offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs + expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32) + + # stable-sort by expert ID: + kv_pairs = ((expert << 16) | local_offs).to(tl.uint32) + kv_pairs = tl.sort(kv_pairs, 0) + expert = kv_pairs >> 16 + offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xffff) + mask = expert != 0xffff + gate_scal = tl.load(ExptScal + offs, mask=mask) + + # compute run lengths in expert-sorted order: + x = (kv_pairs & 0xffff0000 | 0x00000001) + expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add) + exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xffff + + gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=mask) + gates += tl.load(TokensStart + expert, mask=mask) + gates += exclusive_run_lengths + + tl.store(ScatterIndx + offs, gates, mask=mask) + tl.store(GatherIndx + gates, offs, mask=mask) + tl.store(GateScal + gates, gate_scal, mask=mask) + + +@triton.jit +def _combined_routing_compute(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, stride_pn, + TokensStart, n_tokens, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr, Hist, + MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2, + SIZES: tl.constexpr, BLOCK: tl.constexpr, blocks2a): + + pid = tl.program_id(0) + if pid < blocks2a: + _expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2, + SIZES, BLOCK) + else: + pid -= blocks2a + _routing_compute_indx(pid, GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, + stride_pn, TokensStart, n_tokens, BLOCK_M, N_EXPTS_ACT) + + +@triton.jit +def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr): + pid_m = tl.program_id(0) + cutoff_word = cutoff // 32 + cutoff_bit = cutoff % 32 + cutoff_mask = (1 << (cutoff_bit)) - 1 + for start_n in range(0, shape_bn, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn) + values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values) + values = tl.where(offs_n > cutoff_word, 0, values) + tl.store(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, values, mask=offs_n < shape_bn) + + +@triton.jit +def _combined_routing_memset(Indx, size, sentinel, BLOCK: tl.constexpr, ExpertHist, FinalExpertOffs, hist_size, + n_expts_tot, PartialHist, shape_pm, stride_pm, stride_pn, MDStarts, tile_starts_stridem, + blocks1a, MDTileInfo, first_tile_dim_log2, SIZES: tl.constexpr, BLOCK_A: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_M: 
tl.constexpr): + """ + This kernel essentially combines 6 different pieces of functionality, + statically branching on the value of tl.program_id(0) to decide which + codepath to take. + + pid == 0: create the token cumsum + 1 <= pid <= SIZES: create a tile cumsum + SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff + blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs + pid == blocks1a + n_expts_tot: compute_expt_offs + pid > blocks1a + n_expts_tot: initialise Indx to sentinel + + As each of these is a relatively trivial workload, launching them from + this single trampoline is beneficial as they can execute on different + streaming multiprocesses in parallel. + """ + + pid = tl.program_id(0) + + if pid < blocks1a: + _expt_data_memset(ExpertHist, n_expts_tot, MDStarts, tile_starts_stridem, MDTileInfo, first_tile_dim_log2, + SIZES, BLOCK_A) + elif pid == n_expts_tot + blocks1a: + _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N) + elif pid < n_expts_tot + blocks1a: + _routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M, pid - blocks1a) + else: + offs = (pid - n_expts_tot - blocks1a - 1) * BLOCK + tl.arange(0, BLOCK) + mask = offs < size + tl.store(Indx + offs, sentinel, mask=mask) diff --git a/torch-ext/triton_kernels/specialize.py b/torch-ext/triton_kernels/specialize.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e95eebc0a4f2f9e25555bb1bd38a3dee1a96d4 --- /dev/null +++ b/torch-ext/triton_kernels/specialize.py @@ -0,0 +1,132 @@ +import inspect +import re +import textwrap +import types +import triton + + +def cacheable(f): + """ + A decorator that allow you to write something of the form: + + @cacheable + def my_kernel(): return (expression dynamically defining a kernel) + + such that it interacts gracefully with triton cache and preload. + """ + + g = f() + g.fn.__name__ = f.__name__ + g.fn.__module__ = f.__module__ + g.fn.__qualname__ = f.__qualname__ + g._fn_name = f"{f.__module__}.{f.__qualname__}" + return g + + +def define_kernel(src, module, attrs=None, **extra_globals): + """ + Dynamically create a Triton function or kernel from a src string, + linking any symbols in the kernel to objects specified by extra_globals. 
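+
+    A minimal usage sketch (hypothetical kernel; assumes `tl` is triton.language and `mod` is
+    the module object the kernel should live in):
+
+        src = '''
+        def add_one(X, BLOCK: tl.constexpr):
+            offs = tl.arange(0, BLOCK)
+            tl.store(X + offs, tl.load(X + offs) + 1)
+        '''
+        add_one = define_kernel(src, mod, tl=tl)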
+ """ + + # create templace function + def _empty_fn(): + pass + + gdict = dict(**(_empty_fn.__globals__)) + gdict.update(extra_globals) + f = types.FunctionType(_empty_fn.__code__, gdict) + f.__module__ = module.__name__ + + src = textwrap.dedent(src) + src = src[src.find("def "):] + + stored_functions = [] + function_name = src[4:].split("(")[0].strip() + + exec_globals = gdict + exec_globals.update({"stored_functions": stored_functions}) + exec(src + "\n\nstored_functions.append(" + function_name + ")\n", exec_globals) + + f.__signature__ = inspect.signature(stored_functions[0]) + f.__name__ = function_name + f.__doc__ = stored_functions[0].__doc__ + + if attrs is None: + attrs = dict() + f = triton.JITFunction(f, **attrs) + f._unsafe_update_src(src) + return f + + +def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()): + assert isinstance(fn, triton.runtime.jit.JITFunction) + if name is None: + name = f"{fn.__name__}" + # Get original source code + src = inspect.getsource(fn.fn) + src = textwrap.dedent(src) + lines = src.split("\n") + # Skip decorator and def line + def_idx = next(i for i, line in enumerate(lines) if line.strip().startswith("def")) + # separate header vs body LOC + header_end = def_idx + while not lines[header_end].rstrip().endswith(":"): + header_end += 1 + body_lines = lines[header_end + 1:] + header_lines = lines[def_idx:header_end + 1] + # clean-up header + header_clean = [ + l.split("#", 1)[0].strip() # keep code, discard comment + for l in header_lines + if l.split("#", 1)[0].strip() # skip blank‑after‑comment lines + ] + # decompose arguments + header_src = " ".join(header_clean) # turn it into a single line + m = re.search(r"\((.*)\)\s*:", header_src) + if not m: + raise ValueError("Could not parse function header") + args_str = m.group(1) + args = [arg.strip() for arg in args_str.split(",") if arg.strip()] + non_specialized_args = [] + for arg in args: + arg_key = arg.split(":")[0].split("=")[0].strip() + new_args = tuples.get(arg_key, [arg]) + if arg_key not in constants: + non_specialized_args += new_args + # add global symbols + spec_fns = {v.__name__: v for k, v in constants.items() if isinstance(v, triton.runtime.jit.JITFunction)} + globals = spec_fns | fn.get_capture_scope() + # build new source code and define kernel dynamically + new_signature = f"def {name}({', '.join(non_specialized_args)}):" + constexpr_lines = [ + f" {key}: tl.constexpr = {value.__name__ if callable(value) else value}" for key, value in constants.items() + ] + tuple_lines = [ + f" {key} = {'(' + ','.join(value) + (',' if len(value)>=1 else '') + ')'}" for key, value in tuples.items() + ] + new_src = "\n".join(["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines) + # find function parameters + sig = inspect.signature(triton.runtime.jit.JITFunction.__init__) + params = list(sig.parameters.values())[2:] + attrs = {param.name: getattr(fn, param.name, param.default) for param in params} + + # make a new repr which appends the repr of the specialized functions. 
+ base_repr = attrs["repr"] + + def new_repr(specialization): + ret = base_repr(specialization) + for spec_fn in spec_fns.values(): + spec_repr = spec_fn.repr(None) + if spec_repr: + spec_repr = spec_repr.strip("_") + if spec_repr: + ret += f"_{spec_repr}" + return ret + + attrs["repr"] = new_repr + + if do_not_specialize: + attrs["do_not_specialize"] = do_not_specialize + ret = define_kernel(new_src, module, attrs, **globals) + return ret diff --git a/torch-ext/triton_kernels/swiglu.py b/torch-ext/triton_kernels/swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..606a03500a39d6c9063cd6655bf3d513bd47cf23 --- /dev/null +++ b/torch-ext/triton_kernels/swiglu.py @@ -0,0 +1,100 @@ +from dataclasses import dataclass +from triton_kernels.numerics import InFlexData, OutFlexData +import torch +import triton +from .swiglu_details._swiglu import _swiglu, _swiglu_fn +from triton_kernels import target_info + + +@dataclass(frozen=True) +class FlexCtx: + out_data: OutFlexData = OutFlexData() + inp_data: InFlexData = InFlexData() + saturate_inf: bool = False + + +@dataclass(frozen=True) +class PrecisionConfig: + limit: float + flex_ctx: FlexCtx = FlexCtx() + + +swiglu_fn = _swiglu_fn + + +class SwiGLU(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, alpha, precision_config, routing_data): + N = a.shape[-1] + M = a.numel() // N + assert a.stride()[-1] == 1 + assert a.shape[-1] % 2 == 0 + out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device) + flex_ctx = precision_config.flex_ctx + # optimization hyperparameters + BLOCK_M, BLOCK_N = 32 // a.itemsize, 128 + num_warps = 4 + kwargs = {'maxnreg': 64} if not target_info.is_hip() else {} + # launch semi-persistent kernel + N_BLOCKS = triton.cdiv(N // 2, BLOCK_N) + num_sms = target_info.num_sms() + if routing_data is not None: + waves_per_sm = 32 if target_info.is_hip() else 128 + num_pid = num_sms * (waves_per_sm // num_warps) + M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS)) + grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), ) + else: + M_BLOCKS = triton.cdiv(M, BLOCK_M) + if M_BLOCKS * N_BLOCKS >= 8 * num_sms: + grid = (8 * num_sms, ) + else: + grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), ) + n_tokens = None + if routing_data is not None: + n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot] + _swiglu[grid]( + flex_ctx.out_data.reinterpret(out), + flex_ctx.out_data.expected_scale, + flex_ctx.out_data.actual_scale, + flex_ctx.out_data.checksum_scale, + flex_ctx.inp_data.reinterpret(a), + flex_ctx.inp_data.scale, + alpha, + M, + N // 2, + a.shape[-1], + 1, + out.shape[-1], + 1, + precision_config.limit, + n_tokens, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + EVEN_N=(N // 2) % BLOCK_N == 0, + M_BLOCKS=M_BLOCKS, + N_BLOCKS=N_BLOCKS, + flexpoint_saturate_inf=flex_ctx.saturate_inf, + num_warps=num_warps, + **kwargs, + ) + out = out.view(a.shape[:-1] + out.shape[-1:]) + return out + + +def swiglu(a, alpha, precision_config, routing_data=None): + return SwiGLU.apply(a, alpha, precision_config, routing_data) + + +def swiglu_torch(a, alpha, precision_config): + limit = precision_config.limit + a_gelu = a[..., ::2] + if limit is not None: + a_gelu = a_gelu.clamp(max=limit) + a_linear = a[..., 1::2] + if limit is not None: + a_linear = a_linear.clamp(min=-limit, max=limit) + + out_gelu = a_gelu * torch.sigmoid(alpha * a_gelu) + out = out_gelu * (a_linear + 1) + return out diff --git a/torch-ext/triton_kernels/swiglu_details/__pycache__/_swiglu.cpython-310.pyc 
b/torch-ext/triton_kernels/swiglu_details/__pycache__/_swiglu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7558eced4bc8d3c593c81ae3f97a99ced64c52a0 Binary files /dev/null and b/torch-ext/triton_kernels/swiglu_details/__pycache__/_swiglu.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/swiglu_details/_swiglu.py b/torch-ext/triton_kernels/swiglu_details/_swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..fbcea076f7dc7d1975d5a961ebed60bd21a53790 --- /dev/null +++ b/torch-ext/triton_kernels/swiglu_details/_swiglu.py @@ -0,0 +1,102 @@ +from triton_kernels.numerics_details.flexpoint import load_scale, float_to_flex, update_scale +import triton +import triton.language as tl + + +@triton.jit +def clip(x, limit, clip_lower: tl.constexpr): + res = tl.minimum(x, limit) + if clip_lower: + res = tl.maximum(-limit, res) + return res + + +@triton.jit +def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr): + return tl.max(tl.reshape(tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True), axis=1) + + +def swiglu_repr(specialization): + signature = specialization.signature + constants = specialization.constants + convert_dtype = lambda dtype: "mxfp4" if "u8" in dtype else dtype + dtypes = "x".join([convert_dtype(f"{signature[i][1:]}") for i in ["Out", "A"]]) + blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N"]]) + return f"_swiglu_{dtypes}_{blocks}" + + +def swiglu_launch_metadata(grid, kernel, args): + M, N = args["M"], args["N"] + ret = dict() + ret["name"] = f"{kernel.name} [M = {M}, N = {N}]" + A, Out = args["A"], args["Out"] + ret["bytes"] = Out.numel() * Out.element_size() + A.numel() * A.element_size() + return ret + + +@triton.jit +def compute_swiglu(gelu, linear, scale, alpha, limit): + gelu = gelu.to(tl.float32) * scale + if limit is not None: + gelu = clip(gelu, limit, clip_lower=False) + linear = linear.to(tl.float32) * scale + if limit is not None: + linear = clip(linear, limit, clip_lower=True) + s = gelu / (1 + tl.exp(-alpha * gelu)) + return tl.fma(s, linear, s) # (s * (linear + 1)) + + +@triton.jit(repr=lambda _: "_swiglu") +def _swiglu_fn(input, alpha, limit): + gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2))) + return compute_swiglu(gelu, linear, 1.0, alpha, limit) + + +@triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata) +def _swiglu(Out, OutExpectedScale, OutActualScale, OutChecksumScale, A, AScale, alpha, M, N, stride_am, stride_an, + stride_outm, stride_outn, limit: tl.constexpr, NTokens, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + EVEN_N: tl.constexpr, M_BLOCKS, N_BLOCKS, flexpoint_saturate_inf: tl.constexpr): + if NTokens is not None: + M = tl.load(NTokens) + M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M + + local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32) + + a_scale = load_scale(AScale) + out_expected_scale = load_scale(OutExpectedScale) + + for pid in tl.range(tl.program_id(0), M_BLOCKS * N_BLOCKS, tl.num_programs(0), num_stages=2): + pid_m = (pid // N_BLOCKS) + pid_n = (pid % N_BLOCKS) + off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_m = off_m < M + mask_n = off_n < N + packed_off_n = pid_n * BLOCK_N + tl.arange(0, 2 * BLOCK_N) // 2 + packed_mask_n = packed_off_n < N + packed_mask_n = tl.max_constancy(packed_mask_n, [16]) + # load a + packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N) + packed_offs = 
off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an + if EVEN_N: + a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.) + else: + if pid_n * BLOCK_N + BLOCK_N <= N: + a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.) + else: + packed_mask = mask_m[:, None] & packed_mask_n[None, :] + a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.) + a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2))) + out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit) + # update flexpoint stats and divide by scale + # we don't need masking because of the `other` when loading `A` + if OutActualScale is not None: + absmax = thread_local_absmax(out, out.numel, tl.extra.cuda.num_threads()) + local_max = tl.maximum(local_max, absmax) + out = float_to_flex(out, out_expected_scale, + None, # ActualScale: local absmax is tracked and updated after the loop + OutChecksumScale, None, Out, flexpoint_saturate_inf) + mask = mask_m[:, None] if EVEN_N else mask_m[:, None] & mask_n[None, :] + tl.store(Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask) + + update_scale(local_max, OutActualScale, Out) diff --git a/torch-ext/triton_kernels/target_info.py b/torch-ext/triton_kernels/target_info.py new file mode 100644 index 0000000000000000000000000000000000000000..9beae7108e8d94e1bf04747e36c7d01a5650e5ca --- /dev/null +++ b/torch-ext/triton_kernels/target_info.py @@ -0,0 +1,77 @@ +import torch +import triton + +cached_capabilities = {} + + +def is_cuda(): + if "is_cuda" not in cached_capabilities: + target = triton.runtime.driver.active.get_current_target() + cached_capabilities["is_cuda"] = False if target is None else target.backend == "cuda" + return cached_capabilities["is_cuda"] + + +def is_hip(): + if "is_hip" not in cached_capabilities: + cached_capabilities["is_hip"] = torch.cuda.is_available() and bool(torch.version.hip) + return cached_capabilities["is_hip"] + + +def is_hip_cdna3(): + if "is_hip_cdna3" not in cached_capabilities: + target = triton.runtime.driver.active.get_current_target() + cached_capabilities["is_hip_cdna3"] = (target is not None and target.backend == 'hip' + and target.arch == 'gfx942') + return cached_capabilities["is_hip_cdna3"] + + +def is_hip_cdna4(): + if "is_hip_cdna4" not in cached_capabilities: + target = triton.runtime.driver.active.get_current_target() + cached_capabilities["is_hip_cdna4"] = (target is not None and target.backend == 'hip' + and target.arch == 'gfx950') + return cached_capabilities["is_hip_cdna4"] + + +def cuda_capability_geq(major, minor=0): + """ + Determines whether we have compute capability >= (major, minor) and + returns this as a constexpr boolean. This can be used for guarding + inline asm implementations that require a certain compute capability. + """ + if is_hip(): + return False + if "cuda" not in cached_capabilities: + if torch.cuda.is_available(): + cached_capabilities["cuda"] = torch.cuda.get_device_capability() + else: + cached_capabilities["cuda"] = (0, 0) + return cached_capabilities["cuda"] >= (major, minor) + + +def get_cdna_version(): + """ + Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently + only supports 3 (gfx942) or 4 (gfx950). 
Returns -1 if it is not AMD + hardware or unsupported architecture + """ + target = triton.runtime.driver.active.get_current_target() + if target.backend != 'hip': + return -1 + if target.arch == 'gfx942': + return 3 + if target.arch == 'gfx950': + return 4 + return -1 + + +def has_tma_gather(): + return cuda_capability_geq(10, 0) + + +def has_native_mxfp(): + return cuda_capability_geq(10, 0) + + +def num_sms(): + return torch.cuda.get_device_properties(0).multi_processor_count diff --git a/torch-ext/triton_kernels/tensor.py b/torch-ext/triton_kernels/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..1d6d9d9bdbfa7fd20596a0e9fcc95013295769cd --- /dev/null +++ b/torch-ext/triton_kernels/tensor.py @@ -0,0 +1,211 @@ +from dataclasses import dataclass, fields +from typing import Type + +import torch +from triton.tools.tensor_descriptor import TensorDescriptor + +from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows +from .target_info import cuda_capability_geq +from .tensor_details.layout import Layout, StridedLayout + + +@dataclass +class Storage: + data: torch.Tensor + layout: Layout = None + + def __post_init__(self): + assert isinstance(self.data, torch.Tensor) + if self.layout is None: + self.layout = StridedLayout(self.data.shape) + + @property + def device(self): + return self.data.device + + def is_tma_compliant(self): + # TMAs didn't exist until Hopper + if not cuda_capability_geq(9, 0): + return False + # TMAs only exist for 2D, 3D, 5D inputs + if len(self.data.shape) not in [2, 3, 5]: + return False + # TMAs need at most one stride equal to 1 + # and all other strides divisble by 16 + strides = list(self.data.stride()) + try: + major_dim = strides.index(1) + except ValueError: + major_dim = -1 + ndim = self.data.ndim + bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8 + compliant = [strides[i] * bitwidth % 128 == 0 for i in range(ndim) if i != major_dim] + return all(compliant) + + def make_tma(self, block_shape, transpose=False): + strides = list(self.data.stride()) + shape = list(self.data.shape) + # TODO + # there is an issue w/ column-major TMA; we transpose instead + transpose = self.data.stride()[-1] != 1 + if transpose: + block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]] + shape = shape[:-2] + [shape[-1], shape[-2]] + strides = strides[:-2] + [strides[-1], strides[-2]] + if self.data.dtype == torch.uint8 and self.layout.name is None: + # physical block size is half logical block size along packed dimension + indx = strides.index(1) + block_shape[indx] = block_shape[indx] // 2 + # Pad the inner shape to 128 for mxfp4 weights; TMA requires this when the compiler uses + # CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B. 
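            # (illustrative: a logical inner dim of 72 uint8 values is advertised to the
            # TMA descriptor as ceil(72 / 128) * 128 = 128; the block shape along the
            # packed dimension was already halved above)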
+ pad = 128 + shape[-1] = (shape[-1] + pad - 1) // pad * pad + block_shape = self.layout.swizzle_block_shape(block_shape) + return TensorDescriptor(self.data, shape, strides, block_shape) + + +@dataclass +class IntegerType: + bitwidth: int + + +@dataclass +class FloatType: + bitwidth_exponent: int + bitwidth_mantissa: int + is_signed: bool + + def __post_init__(self): + self.bitwidth = int(self.is_signed) + self.bitwidth_exponent + self.bitwidth_mantissa + + +BIT = IntegerType(1) +FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True) + + +def bitwidth(type: IntegerType | FloatType | torch.dtype): + if isinstance(type, torch.dtype): + return type.itemsize * 8 + return type.bitwidth + + +@dataclass +class Tensor: + storage: Storage | torch.Tensor + dtype: IntegerType | FloatType | torch.dtype = None + shape: list[int] | None = None + shape_max: list[int] | None = None + + def __post_init__(self): + # set storage + if isinstance(self.storage, torch.Tensor): + self.storage = Storage(self.storage) + # initialize dtype + if self.dtype is None: + self.dtype = self.storage.data.dtype + if bitwidth(self.dtype) < 8 and self.shape is None: + raise ValueError("shape must be provided for sub-byte types") + # initialize shape + if self.shape is None: + self.shape = list(self.storage.data.shape) + # validate shape: all elements must be `int` or numel-1 `torch.Tensor` + is_int = lambda s: isinstance(s, int) + is_item = lambda s: hasattr(s, "numel") and s.numel() == 1 + assert all(map(lambda s: is_int(s) or is_item(s), self.shape)) + # initialize shape_max + if self.shape_max is None: + self.shape_max = [None] * len(self.shape) + for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)): + if smax is not None and not is_int(smax): + raise ValueError(f"shape_max[{i}] must be `int` or `None`; got {type(smax)}") + if smax is None: + self.shape_max[i] = s + # validate shape_max: all elements must be `int` + assert all(map(is_int, self.shape_max)) + + # torch compatibility layer + @property + def ndim(self): + return len(self.shape) + + @property + def device(self): + return self.storage.device + + def stride(self, i=None): + return self.storage.data.stride() if i is None else self.storage.data.stride(i) + + def data_ptr(self): + return self.storage.data.data_ptr() + + def numel(self): + return self.storage.data.numel() + + def element_size(self): + return bitwidth(self.dtype) // 8 + + @property + def data(self): + t = self.storage + return t.data if isinstance(t, Storage) else t + + def dim(self): + return self.ndim + + def size(self, i=None): + if i is None: + return self.shape + return self.shape[i] + + +@dataclass +class Bitmatrix(Tensor): + """ + Represents a boolean matrix in a packed format where each element occupies + a single bit of memory. + + _scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along + with the actual bitmatrix to avoid having to launch a separate memset + kernel when we call Bitmatrix::sum(). 
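    For example (the block size here is illustrative),

        hist = bitmatrix.sum(partials_block_size=128)

    reuses the pre-zeroed scratchpad as the output buffer, allocating one via
    clear_sums only when none was provided, and then drops the reference to it.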
+ """ + + scratchpad: torch.Tensor = None + + def __init__(self, storage, shape, shape_max=None, scratchpad=None): + super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max) + self.scratchpad = scratchpad + + def sum(self, partials_block_size): + _, n_cols = self.shape + dev = self.device + if self.scratchpad is None: + self.scratchpad = clear_sums(n_cols, dev) + out_ret = self.scratchpad[:n_cols] + self.scratchpad = None # throw error if we try to sum again + return sum_bitmatrix_rows(self, out_ret, partials_block_size) + + +def get_layout(tensor: torch.Tensor | Tensor | None): + if tensor is None: + return None + if isinstance(tensor, Tensor): + return tensor.storage.layout + return StridedLayout + + +def wrap_torch_tensor(torch_tensor, dtype=None): + if dtype is None: + dtype = torch_tensor.dtype + shape = list(torch_tensor.shape) + shape[torch_tensor.stride().index(1)] *= bitwidth(torch_tensor.dtype) // bitwidth(dtype) + return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape) + + +def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs): + assert isinstance(tensor, Tensor) + old_storage = tensor.storage + old_data = old_storage.layout.unswizzle_data(old_storage.data) + new_layout = layout_cls(old_data.shape, **layout_kwargs) + new_data = new_layout.swizzle_data(old_data) + attrs = {k.name: getattr(tensor, k.name) for k in fields(tensor) if k.name != "storage"} + return Tensor(Storage(new_data, new_layout), **attrs) diff --git a/torch-ext/triton_kernels/tensor_details/layout.py b/torch-ext/triton_kernels/tensor_details/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..2f311d1357128272cef13ed1dc8ce1381b9605c0 --- /dev/null +++ b/torch-ext/triton_kernels/tensor_details/layout.py @@ -0,0 +1,32 @@ +from .layout_details.base import Layout +from .layout_details.blackwell_scale import BlackwellMXScaleLayout +from .layout_details.hopper_scale import HopperMXScaleLayout +from .layout_details.hopper_value import HopperMXValueLayout +from .layout_details.strided import StridedLayout +from ..target_info import cuda_capability_geq + +__all__ = [ + "Layout", + "BlackwellMXScaleLayout", + "HopperMXScaleLayout", + "HopperMXValueLayout", + "StridedLayout", +] + + +def make_default_matmul_mxfp4_w_layout(mx_axis: int): + if cuda_capability_geq(10): + return StridedLayout, dict() + elif cuda_capability_geq(9): + return HopperMXValueLayout, {"mx_axis": mx_axis} + else: + return StridedLayout, dict() + + +def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8): + if cuda_capability_geq(10): + return BlackwellMXScaleLayout, dict() + elif cuda_capability_geq(9): + return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps} + else: + return StridedLayout, dict() diff --git a/torch-ext/triton_kernels/tensor_details/layout_details/base.py b/torch-ext/triton_kernels/tensor_details/layout_details/base.py new file mode 100644 index 0000000000000000000000000000000000000000..7ae34ab5440874d5b2ade904edbc9c3dc5b237f4 --- /dev/null +++ b/torch-ext/triton_kernels/tensor_details/layout_details/base.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + + +class Layout(ABC): + + def __init__(self, shape) -> None: + self.initial_shape = shape + + @abstractmethod + def swizzle_data(self, data): + pass + + @abstractmethod + def unswizzle_data(self, data): + pass + + @abstractmethod + def swizzle_block_shape(self, block_shape): + pass diff --git 
a/torch-ext/triton_kernels/tensor_details/layout_details/blackwell_scale.py b/torch-ext/triton_kernels/tensor_details/layout_details/blackwell_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..3817a3c34a97bf780c498a7817e1a8727fcdf42c --- /dev/null +++ b/torch-ext/triton_kernels/tensor_details/layout_details/blackwell_scale.py @@ -0,0 +1,58 @@ +import math +import triton +import triton.language as tl +import torch +from .base import Layout + +SWIZZLE_ALIGN_INNER = 8 +SWIZZLE_SIZE_INNER = 4 +SWIZZLE_SIZE_OUTER = 128 + + +class BlackwellMXScaleLayout(Layout): + name: str = "BLACKWELL_SCALE" + + def __init__(self, shape) -> None: + super().__init__(shape) + *self.leading_shape, self.K, self.N, = shape + self.B = math.prod(self.leading_shape) + self.ALIGN_K = 8 + self.ALIGN_N = 128 + self.SWIZZLE_K = 4 + self.K_pad = (self.K + self.ALIGN_K - 1) // self.ALIGN_K * self.ALIGN_K + self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N + + def swizzle_data(self, data): + data = torch.nn.functional.pad(data, (0, self.N_pad - self.N, 0, self.K_pad - self.K)) + data = data.transpose(-1, -2).contiguous() + data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.ALIGN_N // 32, 32, self.K_pad // self.SWIZZLE_K, + self.SWIZZLE_K) + data = data.transpose(2, 4).contiguous() + data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256) + return data + + def unswizzle_data(self, data): + data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.K_pad // self.SWIZZLE_K, 32, self.ALIGN_N // 32, + self.SWIZZLE_K) + data = data.transpose(2, 4) + data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad) + data = data.transpose(-1, -2) + return data[..., :self.K, :self.N] + + def swizzle_block_shape(self, block_shape): + MX_PACK_DIVISOR = 32 + MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR + return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256] + + +@triton.jit +def unswizzle_mx_scale_bw(x, SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER, + SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER, + ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER): + shape_0: tl.constexpr = x.shape[0] + shape_1: tl.constexpr = x.shape[1] + tl.static_assert(shape_1 % SIZE_OUTER == 0) + tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER) + x = x.reshape(shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER) + x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER) + return x diff --git a/torch-ext/triton_kernels/tensor_details/layout_details/hopper_scale.py b/torch-ext/triton_kernels/tensor_details/layout_details/hopper_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..7211468faa727714e64053de4b64f9ad38a7abdf --- /dev/null +++ b/torch-ext/triton_kernels/tensor_details/layout_details/hopper_scale.py @@ -0,0 +1,80 @@ +import torch +import triton +import triton.language as tl +from .base import Layout + + +class HopperMXScaleLayout(Layout): + name: str = "HOPPER_SCALE" + + def __init__(self, shape, mx_axis, num_warps=8) -> None: + assert num_warps & (num_warps - 1) == 0, "warps_n must be a power of 2" + super().__init__(shape) + self.mx_axis = mx_axis + self.num_warps = num_warps + *self.leading_shape, _, _ = shape + + def _maybe_mT(self, data): + if self.mx_axis == len(self.leading_shape): + return data.contiguous().mT + return data + + def swizzle_data(self, data): + data = self._maybe_mT(data).contiguous() + *batch, M, K = data.shape + SWIZZLE_ALIGN_M = 2 * self.num_warps * 2 * 
8 + SWIZZLE_ALIGN_K = 2 + pad_m = (SWIZZLE_ALIGN_M - (M % SWIZZLE_ALIGN_M)) % SWIZZLE_ALIGN_M + pad_k = (SWIZZLE_ALIGN_K - (K % SWIZZLE_ALIGN_K)) % SWIZZLE_ALIGN_K + data = torch.nn.functional.pad(data, (0, pad_k, 0, pad_m)) + *batch, M, K = data.shape + assert data.is_contiguous() + assert M % ( + 2 * self.num_warps * 2 * + 8) == 0 and K % 2 == 0, f"Input tensor must have a subtile of shape (..., {2 * self.num_warps * 2 * 8}, 2)" + b = len(batch) + data = data.reshape(*batch, M // (2 * self.num_warps * 2 * 8), 2, self.num_warps, 2, 8, K // 2, 2) + perm = [0, 2, 5, 1, 4, 6, 3] + perm = list(range(b)) + [b + p for p in perm] + data = data.permute(*perm) + data = data.flatten(-5, -1) + data = data.flatten(-3, -2) + assert data.shape[-2] == M // 32 + assert data.shape[-1] == K * 32 + data = self._maybe_mT(data) + return data + + def unswizzle_data(self, data): + data = self._maybe_mT(data) + *batch, M, K = data.shape + b = len(batch) + data = data.reshape(*batch, M // self.num_warps, self.num_warps, K // 64, 2, 8, 2, 2) + perm = [0, 3, 1, 6, 4, 2, 5] + perm = list(range(b)) + [b + p for p in perm] + data = data.permute(*perm) + data = data.reshape(*batch, M * 32, K // 32) + data = self._maybe_mT(data) + return data + + def swizzle_block_shape(self, block_shape): + return block_shape + + +@triton.jit +def unswizzle_mxfp4_scale_hopper(x, mx_axis: tl.constexpr, num_warps: tl.constexpr): + """ + Triton inverse of swizzle_mxfp4_scale_hopper + """ + tl.static_assert(len(x.shape) == 2, "NYI") + # implementation assumes mxfp data is packed along the last dimension + x = x.trans() if mx_axis == 0 else x + M: tl.constexpr = x.shape[0] + K: tl.constexpr = x.shape[1] + tl.static_assert(M % num_warps == 0, f"M must be divisible by {num_warps}. Got {M}") + tl.static_assert(K % 64 == 0, f"K must be divisible by 64. Got {K}") + x = x.reshape(M // num_warps, num_warps, K // 64, 2, 8, 2, 2) + x = x.trans(0, 3, 1, 6, 4, 2, 5) + x = x.reshape(M * 32, K // 32) + # implementation assumed mxfp data is packed along the last dimension + x = x.trans() if mx_axis == 0 else x + return x diff --git a/torch-ext/triton_kernels/tensor_details/layout_details/hopper_value.py b/torch-ext/triton_kernels/tensor_details/layout_details/hopper_value.py new file mode 100644 index 0000000000000000000000000000000000000000..ede4995e46418445a2c49bd37fc48fe47e1fbd31 --- /dev/null +++ b/torch-ext/triton_kernels/tensor_details/layout_details/hopper_value.py @@ -0,0 +1,323 @@ +import torch +import triton +import triton.language as tl +from .base import Layout + + +def right_shift_unsigned(x, shift): + return (x >> shift) & ((1 << (32 - shift)) - 1) + + +# ----------------------------------------------------------------------- +# Interleave the bits of four consecutive fp4 values (i.e. 
16-bits) as: +# 1000000111000000 (first fp4) +# 1000000111000000 (second fp4) +# 1000000111000000 (third fp4) +# 0110110000000000 (fourth fp4) +# This is done so that dequantization can be done in 14 SASS instructions +# ----------------------------------------------------------------------- + + +def _compress_fp4(x): + x = x.to(torch.int32) + return ((x & 0x8) << 12) | ((x & 0x7) << 6) + + +def _compress_fourth(x): + x = x.to(torch.int32) + return ((x & 0x8) << 11) | ((x & 0x6) << 9) | ((x & 0x1) << 13) + + +def _pack_bits(x: torch.Tensor, mx_axis: int): + x = x.contiguous() + assert x.shape[-1] % 4 == 0, "Input tensor must have a last dimension divisible by 4" + x = x.reshape(x.shape[:-1] + (x.shape[-1] // 4, 4)) + first = _compress_fp4(x[..., 0]) | (_compress_fp4(x[..., 0] >> 4) << 16) + second = _compress_fp4(x[..., 1]) | (_compress_fp4(x[..., 1] >> 4) << 16) + third = _compress_fp4(x[..., 2]) | (_compress_fp4(x[..., 2] >> 4) << 16) + fourth = _compress_fourth(x[..., 3]) | (_compress_fourth(x[..., 3] >> 4) << 16) + x = first | right_shift_unsigned(second, 3) | right_shift_unsigned(third, 6) | fourth + assert x.is_contiguous() + x = x.view(torch.uint8) + return x + + +# ----------------------------------------------------------------------- +# inverse operation of _pack_bits +# ----------------------------------------------------------------------- + + +def _bf16_to_fp4e2m1(x): + # 0bAxxxxxxBCDxxxxxx (int16) -> 0b0000ABCD (uint8) + assert x.dtype == torch.int16 + s = (right_shift_unsigned(x, 15) & 0x1) << 3 + em = right_shift_unsigned(x, 6) & 0x7 + return (s | em).to(torch.uint8) + + +def _bf16x2_to_fp4e2m1x2(x): + # 0bAxxxxxxBCDxxxxxx_0bExxxxxxFGHxxxxxx (int32) -> 0bABCD_EFGH (uint8) + assert x.dtype == torch.int32 + lo = (x & 0xFFFF).to(torch.int16) + hi = (right_shift_unsigned(x, 16) & 0xFFFF).to(torch.int16) + ret_lo = _bf16_to_fp4e2m1(lo) + ret_hi = _bf16_to_fp4e2m1(hi) + return ret_lo | (ret_hi << 4) + + +def _unpack_bits(x, mx_axis: int): + x = x.view(torch.int32) + m = 0b10000001110000001000000111000000 + a = (x << 1) & 0b10000000000000001000000000000000 + b = right_shift_unsigned(x, 3) & 0b00000001100000000000000110000000 + c = right_shift_unsigned(x, 7) & 0b00000000010000000000000001000000 + unpacked = [x & m, (x << 3) & m, (x << 6) & m, (a | b) | c] + x = torch.stack(unpacked, dim=-1) + x = x.flatten(-2, -1) + x = _bf16x2_to_fp4e2m1x2(x) + return x + + +# ----------------------------------------------------------------------- + + +class HopperMXValueLayout(Layout): + name: str = "HOPPER_VALUE" + + def __init__(self, shape, mx_axis, mma_version=3): + super().__init__(shape) + assert mx_axis in range(len(shape)) + self.mx_axis = mx_axis + self.mma_version = mma_version + *self.leading_shape, self.K, self.N, = shape + + def _maybe_mT(self, data): + if self.mx_axis == len(self.leading_shape): + return data.mT + return data + + def swizzle_data(self, data): + """ + Given a uint8 tensor of shape (*, M, K), returns a tensor of shape + (*, M // 4, K * 4) such that: + + 1) Groups contiguously all the elements owned by the same thread of 4 + mma tiles along the K axis. The following animation shows a similar + grouping for 2 tiles along M and 2 tiles along K rather than 4 along K + as done here: + https://neuralmagic.com/wp-content/uploads/2024/10/animation_4.gif + + 2) Moves the elements belonging to thread 4-7 to be contiguous with those + from thread 0-3. This is done to get a full cache line when loading them + from HBM. + + mx_axis selects the lhs or rhs of the matmul. 
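        As a concrete (illustrative) example, a contiguous (256, 64) uint8 tensor of
        packed fp4 values swizzled with mx_axis=1 comes back with shape (64, 256):

            w = torch.randint(256, (256, 64), dtype=torch.uint8)
            w_swz = HopperMXValueLayout(w.shape, mx_axis=1).swizzle_data(w)  # (64, 256)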
+ + WARNING: Assumes that the matmul will be done in bf16 or fp16! + Implementing it for fp8 is as easy as making the tile size (8, 8) + """ + batch = data.ndim - 2 + assert batch >= 0 + assert self.mma_version in (2, 3) + data = self._maybe_mT(data) + init_shape = data.shape + + # We are loading 8 bf16 elements per thread to use ld.global.v4 + # Every u8 represents 2 mxfp4 elements + u8_kwidth = 8 // 2 if self.mma_version == 2 else 1 + + # Pack the 4 // u8_kwidth subtiles of an mma into a u4x8 + contig = (1, u8_kwidth) + scott_trick = (2, 1) + threads = (4, 4) + warp_tile = (2, 2) + k_tile = (1, 4 // u8_kwidth) + + sizes = list(data.shape[:-2]) + pads = [] + # [rest, K, tile, threads] per dimension + for i, (a, b, c, s, d) in enumerate(zip(k_tile, warp_tile, threads, scott_trick, contig)): + pack = a * b * c * s * d + size = data.shape[batch + i] + pad = (pack - size % pack) % pack + pads += [(0, pad)] + sizes.append((size + pad) // pack) + sizes += [a, b, c, s, d] + + pads = tuple(x for t in pads[::-1] for x in t) + data = torch.nn.functional.pad(data, pads) + init_shape = data.shape + # 0: rest[0] + # 1: k_tile[0] + # 2: warp_tile[0] + # 3: threads[0] + # 4: scott_trick[0] + # 5: contig[0] + # 6: rest[1] + # 7: k_tile[1] + # 8: warp_tile[1] + # 9: threads[1] + # 10: scott_trick[1] + # 11: contig[1] + data = data.view(*sizes) + # Want [rest[0], threads[0], rest[1], scott_trick[0], scott_trick[0], threads[1], contig[1], contig[0], k_tile[1], k_tile[0], warp_tile[1], warp_tile[0]] + perm = [0, 3, 6, 10, 4, 9, 7, 1, 8, 2, 5, 11] + perm = list(range(batch)) + [batch + p for p in perm] + data = data.permute(*perm).contiguous() + # These are views + data = data.flatten(-10, -1) + data = data.flatten(-3, -2) + assert data.is_contiguous() + assert data.shape[-2] == init_shape[-2] // 4 + assert data.shape[-1] == init_shape[-1] * 4 + # twiddle the bits + data = _pack_bits(data, self.mx_axis) + data = self._maybe_mT(data) + return data + + def unswizzle_data(self, data): + data = self._maybe_mT(data) + data = _unpack_bits(data, self.mx_axis) + *batch, M, K = data.shape + # We have two times the elements if we already upcasted to bfloat16 + mult = 2 if data.dtype == torch.bfloat16 else 1 + assert M % 4 == 0, "M must be divisible by 4" + assert K % (4 * 8 * 2 * 2 * mult) == 0, f"K must be divisible by {4 * 8 * 2 * 2 * mult}" + # We are loading 8 bf16 elements per thread to use ld.global.v4 + # Every u8 represents 2 mxfp4 elements + u8_kwidth = 8 // 2 if self.mma_version == 2 else 1 + data = data.reshape(*batch, M // 4, 4, K // (4 * 8 * 2 * 2 * mult), 2, 4, 8 // u8_kwidth, 2, u8_kwidth * mult) + b = len(batch) + perm = [0, 6, 1, 3, 2, 5, 4, 7] + perm = list(range(b)) + [b + p for p in perm] + data = data.permute(*perm) + data = data.reshape(*batch, M * 4, K // 4) + data = self._maybe_mT(data) + return data[..., :self.K, :self.N] + + def swizzle_block_shape(self, block_shape): + return block_shape + + +@triton.jit +def _unshuffle_triton(x, mma_version: tl.constexpr): + """ + Triton inverse of swizzle_mxfp4_value_hopper + """ + tl.static_assert(mma_version == 2 or mma_version == 3, "mma_version must be 2 or 3") + # if mx_axis == 0: + # x = x.trans() + + # We have two times the elements if we already upcasted to bfloat16 + mult: tl.constexpr = 2 if x.dtype == tl.bfloat16 else 1 + M: tl.constexpr = x.shape[0] + K: tl.constexpr = x.shape[1] + tl.static_assert(M % 4 == 0, "M must be divisible by 4") + tl.static_assert(K % (4 * 8 * 2 * 2 * mult) == 0, f"K must be divisible by {4 * 8 * 2 * 2 * mult}") + + # 
We are loading 8 bf16 elements per thread to use ld.global.v4 + # Every u8 represents 2 mxfp4 elements + u8_kwidth: tl.constexpr = 8 // 2 if mma_version == 2 else 1 + x = x.reshape(M // 4, 4, K // (4 * 8 * 2 * 2 * mult), 2, 4, 8 // u8_kwidth, 2, u8_kwidth * mult) + x = x.trans(0, 6, 1, 3, 2, 5, 4, 7) + x = x.reshape(M * 4, K // 4) + # if mx_axis == 0: + # x = x.trans() + return x + + +@triton.jit +def _unpack_fp4_to_bf16_triton(x): + # For now we implement just H100 support (mul.bf16x2) + # A100 support is possible via fma + r0, r1 = tl.inline_asm_elementwise( + r""" + { + .reg .b32 b, c, d<7>, scale; + .reg .b32 bias; + mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2) + // We add the missing bias to the scale directly + and.b32 $0, $4, 0b10000001110000001000000111000000; + mul.bf16x2 $0, $0, bias; + shl.b32 b, $4, 3; + and.b32 $1, b, 0b10000001110000001000000111000000; + mul.bf16x2 $1, $1, bias; + shl.b32 c, $4, 6; + and.b32 $2, c, 0b10000001110000001000000111000000; + mul.bf16x2 $2, $2, bias; + // Unpack last two elements + shl.b32 d0, $4, 1; + and.b32 d1, d0, 0b10000000000000001000000000000000; + shr.b32 d2, $4, 3; + and.b32 d3, d2, 0b00000001100000000000000110000000; + or.b32 d4, d1, d3; + shr.b32 d5, $4, 7; + and.b32 d6, d5, 0b00000000010000000000000001000000; + or.b32 $3, d4, d6; + mul.bf16x2 $3, $3, bias; + } + """, + constraints="=r,=r,=r,=r,r", + args=[x], + dtype=(tl.bfloat16, tl.bfloat16), + is_pure=True, + pack=4, + ) + # Concat each pack of 4 + x = tl.join(r0, r1) + x = x.reshape(x.shape[0], x.shape[1] // 4, 4, x.shape[2]) + x = x.trans(0, 1, 3, 2) + x = x.reshape(x.shape[0], x.shape[1] * x.shape[2] * x.shape[3]) + return x + + +@triton.jit +def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr): + """ + Implements the bit-untwiddling of a 32-bit integer (8 mxfp4 elements): + (x << 0) & 0b1000000111000000 + (x << 3) & 0b1000000111000000 + (x << 6) & 0b1000000111000000 + ((x << 1) & 0b1000000000000000) | ((x >> 3) & 0b0000000110000000) | ((x >> 7) & 0b0000000001000000) + """ + # upcast values to bfloat16 + tl.static_assert(len(x.shape) == 2) + tl.static_assert(mx_axis == 0 or mx_axis == 1, "mx_axis must be 0 or 1") + tl.static_assert(x.shape[1] % 4 == 0) + tl.static_assert(x.dtype == tl.uint8) + if mx_axis == 0: + x = x.trans() + x = _unpack_fp4_to_bf16_triton(x) + x = _unshuffle_triton(x, mma_version=3) + if mx_axis == 0: + x = x.trans() + + # upcast scale to bfloat16 + # Add bias missing from the bf16 upcasting sequence + # triton / LLVM generates terrible code for this sequence + # scale = scale.to(tl.uint16) + # scale = scale << 7 + # scale = scale.to(tl.bfloat16, bitcast=True) + scale = tl.inline_asm_elementwise( + r""" + { + prmt.b32 $0, $2, 0, 0x5140; + shl.b32 $0, $0, 7; + prmt.b32 $1, $2, 0, 0x7362; + shl.b32 $1, $1, 7; + } + """, + constraints="=r,=r,r", + args=[scale], + dtype=tl.bfloat16, + is_pure=True, + pack=4, + ) + # Broadcast scale + scale = scale.expand_dims(mx_axis + 1) + scale = scale.broadcast_to(scale.shape[:mx_axis + 1] + [32] + scale.shape[mx_axis + 2:]) + scale = scale.reshape(x.shape) + + # Combine scale and x + x = x * scale + return x diff --git a/torch-ext/triton_kernels/tensor_details/layout_details/strided.py b/torch-ext/triton_kernels/tensor_details/layout_details/strided.py new file mode 100644 index 0000000000000000000000000000000000000000..cbfd9248fca219eb94dae358cafd7fac6e082cd1 --- /dev/null +++ b/torch-ext/triton_kernels/tensor_details/layout_details/strided.py @@ -0,0 +1,17 @@ +from .base import Layout + + +class 
StridedLayout(Layout): + name: str = None + + def __init__(self, shape) -> None: + super().__init__(shape) + + def swizzle_data(self, data): + return data + + def unswizzle_data(self, data): + return data + + def swizzle_block_shape(self, block_shape): + return block_shape diff --git a/torch-ext/triton_kernels/testing.py b/torch-ext/triton_kernels/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..d905725132def1089dbd2f378019cd54b78c3f3a --- /dev/null +++ b/torch-ext/triton_kernels/testing.py @@ -0,0 +1,192 @@ +import enum +import functools +import os +import subprocess +import sys +import torch +from triton_kernels.numerics import MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 + + +def assert_equal(ref, tri): + if isinstance(ref, torch.Tensor): + assert torch.all(ref == tri) + else: + assert ref == tri + + +def assert_close(ref, tri, maxtol=None, rmstol=None, description="--", verbose=True): + if tri.dtype.itemsize == 1: + ref_as_type = ref.to(tri.dtype) + if ref.dtype == tri.dtype: + assert torch.all(ref_as_type == tri) + return + ref = ref_as_type + + if maxtol is None: + maxtol = 2e-2 + if rmstol is None: + rmstol = 4e-3 + """ + Compare reference values against obtained values. + """ + + # cast to float32: + ref = ref.to(torch.float32).detach() + tri = tri.to(torch.float32).detach() + assert ref.shape == tri.shape, f"Tensors must have same size {ref.shape=} {tri.shape=}" + + # deal with infinite elements: + inf_mask_ref = torch.isinf(ref) + inf_mask_tri = torch.isinf(tri) + assert torch.equal(inf_mask_ref, inf_mask_tri), "Tensor must have same infinite elements" + refn = torch.where(inf_mask_ref, 0, ref) + trin = torch.where(inf_mask_tri, 0, tri) + + # normalise so that RMS calculation doesn't overflow: + eps = 1.0e-30 + multiplier = 1.0 / (torch.max(torch.abs(refn)) + eps) + refn *= multiplier + trin *= multiplier + + ref_rms = torch.sqrt(torch.square(refn).mean()) + eps + + rel_err = torch.abs(refn - trin) / torch.maximum(ref_rms, torch.abs(refn)) + max_err = torch.max(rel_err).item() + rms_err = torch.sqrt(torch.square(rel_err).mean()).item() + + if verbose: + print("%s maximum relative error = %s (threshold = %s)" % (description, max_err, maxtol)) + print("%s RMS relative error = %s (threshold = %s)" % (description, rms_err, rmstol)) + + if max_err > maxtol: + bad_idxs = torch.nonzero(rel_err > maxtol) + num_nonzero = bad_idxs.size(0) + bad_idxs = bad_idxs[:1000] + print("%d / %d mismatched elements (shape = %s) at coords %s" % + (num_nonzero, rel_err.numel(), tuple(rel_err.shape), bad_idxs.tolist())) + + bad_idxs = bad_idxs.unbind(-1) + print("ref values: ", ref[tuple(bad_idxs)].cpu()) + print("tri values: ", tri[tuple(bad_idxs)].cpu()) + + assert max_err <= maxtol + assert rms_err <= rmstol + + +class ComputeSanitizerTool(enum.Enum): + MEMCHECK = "memcheck" + RACECHECK = "racecheck" + SYNCCHECK = "synccheck" + INITCHECK = "initcheck" + + +def compute_sanitizer(**target_kwargs): + """ + Decorator to run a test with compute sanitizer enabled and pytorch caching allocator disabled, + to expose potential memory access errors. + This decorator requires the `request` fixture to be present. + If `run_sanitizer` argument is present and set to False, the sanitizer is not run. 
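    A typical (illustrative) application; note that the wrapper asserts the test
    receives the pytest `request` fixture under the keyword name `request_fixture`:

        @compute_sanitizer(clear_torch_cache=True)
        def test_matmul_oob(request_fixture, device):
            ...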
+ Running tests under compute sanitizer requires launching subprocess and is slow, + so use sparingly + """ + + def decorator(test_fn): + + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + if os.environ.get("SKIP_COMPUTE_SANITIZER") == "1": + test_fn(*args, **kwargs) + return + + import psutil + + if target_kwargs.pop("clear_torch_cache", False): + # If we don't pop clear_torch_cache, it won't pass + # target_kwargs.items() <= kwargs.items() condition below. + torch.cuda.empty_cache() + tools_to_check = target_kwargs.pop("tools_to_check", [ComputeSanitizerTool.MEMCHECK]) + assert isinstance(tools_to_check, list), f"{tools_to_check=}" + assert all(tool in ComputeSanitizerTool for tool in tools_to_check), ( + f"{(tool for tool in tools_to_check if tool not in ComputeSanitizerTool)=}") + + ppid_name = psutil.Process(os.getppid()).exe() + run_compute_sanitizer = target_kwargs.items() <= kwargs.items() + if "run_sanitizer" in kwargs: + run_compute_sanitizer &= kwargs["run_sanitizer"] + if run_compute_sanitizer and "compute-sanitizer" not in ppid_name: + for tool in tools_to_check: + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = { + "PATH": os.environ["PATH"], + "PYTORCH_NO_CUDA_MEMORY_CACHING": "1", + "TORCH_SHOW_CPP_STACKTRACES": "1", + "CUDA_LAUNCH_BLOCKING": "1", + } + if "CUDA_VISIBLE_DEVICES" in os.environ: + env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"] + assert "request_fixture" in kwargs, ( + "memcheck'ed test must have a (possibly unused) `request` fixture") + test_id = kwargs["request_fixture"].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + cmd = [ + "compute-sanitizer", + "--target-processes=application-only", + "--destroy-on-device-error=context", + f"--tool={tool.value}", + sys.executable, + "-m", + "pytest", + "-vsx", + cmd, + ] + for opt in ["--update_checksum", "--ignore_checksum_error"]: + if opt in sys.argv: + cmd.append(opt) + out = subprocess.run( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + ) + sanitizer_ok = "ERROR SUMMARY: 0 errors" in str( + out.stdout) or "RACECHECK SUMMARY: 0 hazards displayed" in str(out.stdout) + test_output = out.stdout + if type(test_output) is bytes: + test_output = test_output.decode() + + fail = False + if not sanitizer_ok: + print("compute-sanitizer returned an error") + fail = True + elif out.returncode != 0: + print( + "The test failed due to some other reason: consider running without compute-sanitizer to verify." 
+ ) + print(f"{out.returncode=}") + fail = True + + if fail: + print("*****************************************************") + print("******************** TEST OUTPUT ********************") + print("*****************************************************") + print(test_output) + print("*****************************************************") + print("****************** TEST OUTPUT END ******************") + print("*****************************************************") + assert None + else: + test_fn(*args, **kwargs) + + return wrapper + + return decorator + + +def compute_actual_scale(x, dtype): + max_finite = { + torch.float8_e5m2: MAX_FINITE_FLOAT8E5, + torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV, + torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8, + }[dtype] + return x.abs().max() / max_finite diff --git a/torch-ext/triton_kernels/topk.py b/torch-ext/triton_kernels/topk.py new file mode 100644 index 0000000000000000000000000000000000000000..b82a84464ff0f856faca36c83dcaa7c4541c6faa --- /dev/null +++ b/torch-ext/triton_kernels/topk.py @@ -0,0 +1,92 @@ +import torch +import triton +from triton_kernels.topk_details._topk_forward import _topk_forward +from triton_kernels.topk_details._topk_backward import _topk_backward +from triton_kernels.tensor import Tensor, Bitmatrix + + +def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None): + if not isinstance(x, Tensor): + x_shape = [x.shape[0] if n_rows is None else n_rows, x.shape[1]] + x_shape_max = [x.shape[0], x.shape[1]] + x = Tensor(x, shape=x_shape, shape_max=x_shape_max) + cdiv = lambda a, b: (a + b - 1) // b + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_S = 128 + assert len(x.shape) == 2 + assert x.shape_max[-1] < 32768 + assert dim == 1 + assert return_bitmatrix + n_rows, n_cols = x.shape + n_rows_max, _ = x.shape_max + dev = x.device + # scratchpad tensors + # NOTE: these are not returned + y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev) + if y_indx is not None: + use_provided_indx = True + else: + y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev) + use_provided_indx = False + # create bitmatrix in transposed memory layout: + n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N + n_cols_words = n_cols_pad // 32 + bitmatrix = torch.empty((n_cols_words, cdiv(n_rows_max, 32) * 32), dtype=torch.uint32, device=dev) + bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_max] + s_blocks = cdiv(n_cols, BLOCK_S) + s_cols = s_blocks * BLOCK_S + scratchpad = torch.empty((s_cols, ), dtype=torch.int32, device=dev) + pids = max(cdiv(n_rows_max, BLOCK_M), s_blocks) + _topk_forward[(pids, )]( + x, x.stride(0), # inputs + y_vals, y_indx, y_vals.stride(0), use_provided_indx, # output [topk] + bitmatrix, bitmatrix.stride(0), bitmatrix.stride(1), # output [bitmatrix] + n_rows, n_cols, # shapes + scratchpad, BLOCK_S, s_blocks, # thing to memset to zero + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, # tunable parameter + APPLY_SOFTMAX=apply_softmax, N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k, # constants + ) + bitmatrix_shape = [n_rows, n_cols_words * 32] + bitmatrix_shape_max = [n_rows_max, None] + bitmatrix = Bitmatrix(bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=scratchpad) + return y_vals, y_indx, bitmatrix + + +def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax): + assert dy_vals.shape[-1] == k + n_expts_pad = triton.next_power_of_2(x.shape[-1]) + dx = torch.empty_like(x) + _topk_backward[(dy_vals.shape[0], )]( + y_indx, y_indx.stride(0), dy_vals, dy_vals.stride(0), x, 
x.stride(0), # inputs + dx, # outputs + dx.stride(0), x.shape[0], n_rows, x.shape[-1], APPLY_SOFTMAX=apply_softmax, N_EXPTS_ACT=k, + N_EXPTS_PAD=n_expts_pad) + return dx + + +class TopK(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows): + y_vals, y_indx, bitmatrix = topk_forward(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows) + ctx.save_for_backward(x, y_indx) + ctx.apply_softmax = apply_softmax + ctx.k = k + ctx.n_rows = n_rows + return y_vals, y_indx, bitmatrix + + @staticmethod + def backward(ctx, dy_vals, _0, _1): + x, y_indx = ctx.saved_tensors + dx = topk_backward(x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax) + return dx, None, None, None, None, None, None + + +def topk(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None): + ret = TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows) + return ret + + +# x = torch.randn((32, 32), dtype=torch.float16, device="cuda") +# print(topk(x, 4)) diff --git a/torch-ext/triton_kernels/topk_details/__init__.py b/torch-ext/triton_kernels/topk_details/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/torch-ext/triton_kernels/topk_details/__pycache__/__init__.cpython-310.pyc b/torch-ext/triton_kernels/topk_details/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..514c8da8ade15cd0fe236f884b094858c068c3e6 Binary files /dev/null and b/torch-ext/triton_kernels/topk_details/__pycache__/__init__.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/topk_details/__pycache__/_topk_backward.cpython-310.pyc b/torch-ext/triton_kernels/topk_details/__pycache__/_topk_backward.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fca762b1122183dc31d55cba6b3c08b1d9fde30f Binary files /dev/null and b/torch-ext/triton_kernels/topk_details/__pycache__/_topk_backward.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/topk_details/__pycache__/_topk_forward.cpython-310.pyc b/torch-ext/triton_kernels/topk_details/__pycache__/_topk_forward.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..711f6ef52c7ff3d34ea1c39e87b8c5b55ccc7b8c Binary files /dev/null and b/torch-ext/triton_kernels/topk_details/__pycache__/_topk_forward.cpython-310.pyc differ diff --git a/torch-ext/triton_kernels/topk_details/_topk_backward.py b/torch-ext/triton_kernels/topk_details/_topk_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..eebe481771543a05cfab5741bf1a0c875248f70d --- /dev/null +++ b/torch-ext/triton_kernels/topk_details/_topk_backward.py @@ -0,0 +1,51 @@ +import triton +import triton.language as tl + + +@triton.jit +def _topk_backward( + Yi, + stride_ym, # topk indices + DY, + stride_dym, # output gradient values + X, + stride_xm, # input values + DX, + stride_dxm, # input gradient values + n_rows, + NRows, + n_expts_tot, + APPLY_SOFTMAX: tl.constexpr, + N_EXPTS_ACT: tl.constexpr, + N_EXPTS_PAD: tl.constexpr, +): + pid_m = tl.program_id(0) + if NRows is not None: + n_rows = tl.load(NRows) + if pid_m >= n_rows: + return + Yi += pid_m * stride_ym + DY += pid_m * stride_dym + X += pid_m * stride_xm + DX += pid_m * stride_dxm + # -- + offs_xn = tl.arange(0, N_EXPTS_PAD) + offs_yn = tl.arange(0, N_EXPTS_ACT) + mask_xn = offs_xn < n_expts_tot + # recompute softmax + y_indx = tl.load(Yi + offs_yn) + x = 
tl.load(X + y_indx) + x = x.to(tl.float32) + y = tl.softmax(x) + # compute input-gradient + dy = tl.load(DY + offs_yn) + dy = dy.to(tl.float32) + s = tl.sum(y * dy, 0) + # write-back input gradient + tl.store(DX + offs_xn, 0, mask=mask_xn) + tl.debug_barrier() + if APPLY_SOFTMAX: + dx = y * (dy - s) + else: + dx = dy + tl.store(DX + y_indx, dx) diff --git a/torch-ext/triton_kernels/topk_details/_topk_forward.py b/torch-ext/triton_kernels/topk_details/_topk_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..26c7a9ccea3356ad30957a5ab543cabac2e3fa23 --- /dev/null +++ b/torch-ext/triton_kernels/topk_details/_topk_forward.py @@ -0,0 +1,146 @@ +import triton +import triton.language as tl + + +@triton.jit +def get_topmask_and_fullmask(x): + tl.static_assert(x.dtype.is_int_unsigned(), "floating-point value must be passed as bits") + tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth) + fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1 + tm_arr = tl.full(x.shape, tm, dtype=x.dtype) + fm_arr = tl.full(x.shape, fm, dtype=x.dtype) + return tm_arr, fm_arr + + +@triton.jit +def fpval_to_key(x): + tm, fm = get_topmask_and_fullmask(x) + return x ^ tl.where((x & tm) != 0, fm, tm) + + +@triton.jit +def key_to_fpval(x): + tm, fm = get_topmask_and_fullmask(x) + return x ^ tl.where((x & tm) == 0, fm, tm) + + +# stable top-k tie-breaks to value with smaller index +@triton.jit +def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr): + return N_EXPTS_PAD - indx + + +@triton.jit +def key_to_indx(indx, N_EXPTS_PAD: tl.constexpr): + return N_EXPTS_PAD - indx + + +@triton.jit +def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, + BLOCK_N: tl.constexpr): + x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth + x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}") + if x_nbits < 16: + # this ensures that we leave at least 16 bits for expert index + # even if the input dtype is smaller than 16 bits: + y_nbits: tl.constexpr = 32 + else: + y_nbits: tl.constexpr = x_nbits * 2 + x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}") + x_dtype: tl.constexpr = X.dtype.element_ty + + # subtract 1 from loop iterations because we peel the first (masked) iteration: + loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1 + offs_x_n = loop_iterations * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_x_n[None, :] < n_expts_tot + + # first iteration: + X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :] + x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf")) + x = fpval_to_key(x.to(x_utype, bitcast=True)) + x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :] + acc = tl.topk(x, N_EXPTS_ACT, dim=1) + + # subsequent iterations: + for _i in (tl.static_range if loop_iterations <= 4 else range)(loop_iterations): + acc = tl.bitonic_merge(acc) # ensure sorted ascending for the merge + X_ptrs -= BLOCK_N + offs_x_n -= BLOCK_N + x = tl.load(X_ptrs, mask=mask_m, other=float("-inf")) + x = fpval_to_key(x.to(x_utype, bitcast=True)) + x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :] + acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1)) + + # rotate expert index into upper 16 bits: + # 0000vvvvvvvviiii --> iiii0000vvvvvvvv + acc = (acc << (y_nbits - 16)) | (acc >> 16) + # sort in ascending order of expert (descending order of key) + acc = tl.sort(acc, dim=1, descending=True) + # iiii0000vvvvvvvv --> 0000iiii: + y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32) + y_indices = 
key_to_indx(y_indices_raw, N_EXPTS_PAD) + # iiii0000vvvvvvvv --> vvvvvvvv: + y_values_raw = acc.to(x_utype) + y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True) + + return y_values, y_indices + + +@triton.jit +def _topk_forward(X, stride_xm, # inputs + Yv, Yi, stride_ym, # topk values/indices + USE_PROVIDED_INDX: tl.constexpr, Bits, stride_rm: tl.constexpr, stride_rn: tl.constexpr, # bitmatrix + n_rows, n_expts_tot, # shape + S, BLOCK_S: tl.constexpr, s_blocks, # thing to memset + APPLY_SOFTMAX: tl.constexpr, # constant + BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr): + + pid = tl.program_id(0) + if isinstance(n_rows, tl.tensor) and n_rows.dtype.is_ptr(): + n_rows = tl.load(n_rows) + + if pid < s_blocks: + tl.store(S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32)) + + if pid * BLOCK_M >= n_rows: + # early exit: + return + + tl.static_assert(BLOCK_N % 32 == 0) + tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0) + x_dtype: tl.constexpr = X.dtype.element_ty + + # load logits + offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M) + offs_y_n = tl.arange(0, N_EXPTS_ACT) + mask_m = offs_m[:, None] < n_rows + if USE_PROVIDED_INDX: + Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :] + y_indices = tl.load(Yi_ptrs, mask=mask_m) + Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices + y_values = tl.load(Xv_ptrs, mask=mask_m) + else: + y_values, y_indices = streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, # + N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N) + + # normalize selected values + if APPLY_SOFTMAX: + y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype) + + # write back + Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :] + tl.store(Yv_ptrs, y_values, mask=mask_m) + if not USE_PROVIDED_INDX: + Yi_ptrs = Yi + offs_m[:, None] * stride_ym + offs_y_n[None, :] + tl.store(Yi_ptrs, y_indices, mask=mask_m) + + # pack into bitmatrix + y_div = y_indices // 32 + y_rem = y_indices % 32 + loop_iterations = N_EXPTS_PAD // BLOCK_N + for i in range(loop_iterations): + offs_r_n = tl.arange(0, BLOCK_N // 32) + i * (BLOCK_N // 32) + y2 = tl.where(y_div[:, :, None] == offs_r_n[None, None, :], (1 << y_rem)[:, :, None], 0) + r = tl.reduce_or(y2, axis=1) + BitsPtrs = Bits + offs_m[:, None] * stride_rm + offs_r_n[None, :] * stride_rn + tl.store(BitsPtrs, r, mask=mask_m)
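For reference, a minimal (illustrative) driver for the kernel above through the public
`topk` wrapper; the shapes, device, and block size here are assumptions, not part of the diff:

    import torch
    from triton_kernels.topk import topk

    logits = torch.randn((128, 64), dtype=torch.float16, device="cuda")
    y_vals, y_indx, bitmatrix = topk(logits, 4)  # softmax-normalized top-4 per row
    # summing the packed bitmatrix yields the per-expert token counts used by routing
    hist = bitmatrix.sum(partials_block_size=128)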