kernel
moe / ext-torch /fused_marlin_moe.py
danieldk's picture
danieldk HF staff
Add MoE kernels from vLLM
29e93ec
raw
history blame
10.7 kB
"""Fused MoE utilities for GPTQ."""
import functools
from typing import Any, Dict, Optional
import torch
from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config
from .scalar_type import scalar_types
import moe._custom_ops as ops
def get_scalar_type(num_bits: int, has_zp: bool):
    """Select the Marlin weight scalar type for a quantization setup.

    Args:
        num_bits: Bit width of the quantized expert weights.
        has_zp: Whether explicit zero points accompany the weights.

    Returns:
        ``scalar_types.uint4`` when zero points are present (only 4-bit
        weights are accepted in that case). Otherwise ``scalar_types.uint4b8``
        for 4-bit weights, and ``scalar_types.uint8b128`` for any other width.
    """
    if not has_zp:
        return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128
    # Zero-point quantization is only wired up for 4-bit weights.
    assert num_bits == 4
    return scalar_types.uint4
def single_marlin_moe(
    hidden_states: torch.Tensor,
    w: torch.Tensor,
    scales: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    g_idx: Optional[torch.Tensor] = None,
    sort_indices: Optional[torch.Tensor] = None,
    w_zeros: Optional[torch.Tensor] = None,
    override_config: Optional[Dict[str, Any]] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:
    """
    Multiply hidden_states with top-k selected Marlin-quantized expert weights.

    This runs a single Marlin MoE GEMM: no activation and no second
    projection. Its purpose is testing and debugging the fused MoE kernel.

    Parameters:
    - hidden_states (torch.Tensor): The 2-D input tensor (M, K). It must be
      contiguous and float16.
    - w (torch.Tensor): The packed expert weights. They must be contiguous,
      with w.shape[0] experts and w.shape[1] * 16 == K.
    - scales (torch.Tensor): The quantization scales for w.
    - gating_output (torch.Tensor): The output of the gating operation
      (before softmax), of shape (M, w.shape[0]).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - g_idx (Optional[torch.Tensor]): Optional act_order indices.
    - sort_indices (Optional[torch.Tensor]): Optional act_order input
      permutation.
    - w_zeros (Optional[torch.Tensor]): Optional zero points for w. They are
      only accepted together with num_bits == 4.
    - override_config (Optional[Dict[str, Any]]): Optional override
      for the kernel configuration.
    - num_bits (int): The number of bits in expert weights quantization
      (4 or 8).
    - is_k_full (bool): Passed through unchanged to the Marlin kernel.

    Returns:
    - torch.Tensor: The kernel output summed over dimension 1 (presumably the
      top-k axis; confirm against the kernel's output layout).
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
    assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w.is_contiguous(), "Expert weights must be contiguous"
    assert hidden_states.dtype == torch.float16
    assert num_bits in [4, 8]

    M, K = hidden_states.shape
    E = w.shape[0]
    # The last weight dimension stores N * (num_bits // 2) packed entries.
    N = w.shape[2] // (num_bits // 2)
    device = hidden_states.device

    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize)

    # This might not be an optimal config for a single MMM: the same weight
    # shape is passed for both weight-shape slots of the config lookup.
    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w.shape,
        w.shape,
        topk_ids.shape[1],
        None,
        override_config=override_config,
        is_marlin=True,
    )
    config = get_config_func(M)
    block_size_m = config["BLOCK_SIZE_M"]

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    max_workspace_size = (N // 64) * 16
    workspace = torch.zeros(
        max_workspace_size,
        dtype=torch.int,
        device=device,
        requires_grad=False,
    )

    # The kernel takes tensors for every optional input; empty tensors stand
    # in for "not provided".
    has_zero_point = w_zeros is not None
    if w_zeros is None:
        w_zeros = torch.empty(
            (0, 0), dtype=hidden_states.dtype, device=device, requires_grad=False
        )
    if g_idx is None:
        g_idx = torch.empty(
            (0, 0), dtype=torch.int32, device=device, requires_grad=False
        )
    if sort_indices is None:
        sort_indices = torch.empty(
            0, dtype=torch.int32, device=device, requires_grad=False
        )

    scalar_type = get_scalar_type(num_bits, has_zero_point)

    intermediate_cache = ops.ops.marlin_gemm_moe(
        hidden_states,
        w,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        scales,
        w_zeros,
        g_idx,
        sort_indices,
        workspace,
        scalar_type.id,
        M,
        N,
        K,
        is_k_full,
        E,
        topk,
        block_size_m,
        True,  # NOTE(review): presumably replicate_input; confirm the kernel signature
        False,  # NOTE(review): presumably apply_weights; confirm the kernel signature
    )

    return torch.sum(intermediate_cache, dim=1)
def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    override_config: Optional[Dict[str, Any]] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:
    """
    Compute a Marlin-quantized Mixture of Experts (MoE) layer.

    Two Marlin GEMMs are chained with an activation between them:
    the w1 GEMM produces 2 * N features, ops.silu_and_mul reduces them to
    N features, and the w2 GEMM maps them back to K features.

    Parameters:
    - hidden_states (torch.Tensor): The 2-D input tensor (M, K). It must be
      contiguous and float16.
    - w1 (torch.Tensor): The first set of packed expert weights. They must be
      contiguous, with w1.shape[0] experts and w1.shape[1] * 16 == K.
    - w2 (torch.Tensor): The second set of packed expert weights. They must
      be contiguous, with w2.shape[1] * 16 == N and
      w2.shape[2] // (num_bits // 2) == K.
    - w1_scale (torch.Tensor): Scale to be used for w1.
    - w2_scale (torch.Tensor): Scale to be used for w2.
    - gating_output (torch.Tensor): The output of the gating operation
      (before softmax). It is only used here to validate shapes; routing
      comes from topk_weights / topk_ids.
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of top-k experts; shape[1] is topk.
    - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
    - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
    - sort_indices1 (Optional[torch.Tensor]): The first act_order input
      permutation.
    - sort_indices2 (Optional[torch.Tensor]): The second act_order input
      permutation.
      The four act_order tensors must be either all given or all None.
    - w1_zeros (Optional[torch.Tensor]): Optional zero points for w1.
    - w2_zeros (Optional[torch.Tensor]): Optional zero points for w2.
      The two zero-point tensors must be both given or both None; zero
      points are only accepted with num_bits == 4.
    - override_config (Optional[Dict[str, Any]]): Optional override
      for the kernel configuration.
    - num_bits (int): The number of bits in expert weights quantization
      (4 or 8).
    - is_k_full (bool): Passed through unchanged to both Marlin GEMMs.

    Returns:
    - torch.Tensor: The second GEMM's output summed over dimension 1
      (presumably the top-k axis; confirm against the kernel's output layout).
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // (num_bits // 2), "Hidden size mismatch w2"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype == torch.float16
    assert num_bits in [4, 8]

    act_order_args = (g_idx1, g_idx2, sort_indices1, sort_indices2)
    has_no_act_order = all(t is None for t in act_order_args)
    has_all_act_order = all(t is not None for t in act_order_args)
    assert has_no_act_order or has_all_act_order, (
        "g_idx and sorted_indices must be all not None or must be all None"
    )

    has_no_zp = w1_zeros is None and w2_zeros is None
    has_all_zp = w1_zeros is not None and w2_zeros is not None
    assert has_no_zp or has_all_zp, (
        "zero points must be both not None or must be both None"
    )

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16
    topk = topk_ids.shape[1]
    device = hidden_states.device

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk,
        None,
        override_config=override_config,
        is_marlin=True,
    )
    config = get_config_func(M)
    block_size_m = config["BLOCK_SIZE_M"]

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    # One workspace serves both GEMMs, so size it for the larger output
    # dimension: 2 * N for the w1 GEMM, K for the w2 GEMM.
    max_workspace_size = (max(2 * N, K) // 64) * 16
    # Allocate on the input's device. The previous device="cuda" meant the
    # *current* CUDA device, which can differ from hidden_states.device on
    # multi-GPU hosts.
    workspace = torch.zeros(
        max_workspace_size, dtype=torch.int, device=device, requires_grad=False
    )

    # The kernel takes tensors for every optional input; empty tensors stand
    # in for "not provided".
    if has_no_zp:
        w1_zeros = torch.empty(
            (0, 0), dtype=hidden_states.dtype, device=device, requires_grad=False
        )
        w2_zeros = torch.empty(
            (0, 0), dtype=hidden_states.dtype, device=device, requires_grad=False
        )
    if has_no_act_order:
        g_idx1 = torch.empty(
            (0, 0), dtype=torch.int32, device=device, requires_grad=False
        )
        g_idx2 = torch.empty(
            (0, 0), dtype=torch.int32, device=device, requires_grad=False
        )
        # Both permutations feed the same kernel argument, so use the same
        # 1-D empty placeholder for each (sort_indices2 was previously (0, 0)).
        sort_indices1 = torch.empty(
            0, dtype=torch.int32, device=device, requires_grad=False
        )
        sort_indices2 = torch.empty(
            0, dtype=torch.int32, device=device, requires_grad=False
        )

    # Both GEMMs share the same quantization scheme.
    scalar_type = get_scalar_type(num_bits, has_all_zp)

    # Activation output: one row per (token, selected expert) pair.
    intermediate_cache2 = torch.empty(
        (M * topk, N),
        device=device,
        dtype=hidden_states.dtype,
    )

    # First GEMM: K -> 2 * N features.
    intermediate_cache1 = ops.ops.marlin_gemm_moe(
        hidden_states,
        w1,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w1_scale,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        scalar_type.id,
        M,
        2 * N,
        K,
        is_k_full,
        E,
        topk,
        block_size_m,
        True,  # NOTE(review): presumably replicate_input; confirm the kernel signature
        False,  # NOTE(review): presumably apply_weights; confirm the kernel signature
    )

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

    # Second GEMM: N -> K features.
    intermediate_cache3 = ops.ops.marlin_gemm_moe(
        intermediate_cache2,
        w2,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w2_scale,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        scalar_type.id,
        M,
        K,
        N,
        is_k_full,
        E,
        topk,
        block_size_m,
        False,  # NOTE(review): presumably replicate_input; confirm the kernel signature
        True,  # NOTE(review): presumably apply_weights; confirm the kernel signature
    )

    return torch.sum(intermediate_cache3, dim=1)