"""Fused MoE utilities for GPTQ."""

import functools
from typing import Any, Dict, Optional

import torch

from .fused_moe import fused_topk, moe_align_block_size, try_get_optimal_moe_config
from .scalar_type import scalar_types
import moe._custom_ops as ops

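
# Map the quantization config to a Marlin scalar type: with zero points only
# 4-bit (uint4) is supported; without zero points the weights use a symmetric
# bias encoding (uint4b8 for 4-bit, uint8b128 for 8-bit).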
def get_scalar_type(num_bits: int, has_zp: bool):
    if has_zp:
        assert num_bits == 4
        return scalar_types.uint4
    else:
        return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128

def single_marlin_moe(
    hidden_states: torch.Tensor,
    w: torch.Tensor,
    scales: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    g_idx: Optional[torch.Tensor] = None,
    sort_indices: Optional[torch.Tensor] = None,
    w_zeros: Optional[torch.Tensor] = None,
    override_config: Optional[Dict[str, Any]] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:
    """
    Computes the multiplication of hidden_states with the expert weights w
    used in Marlin MoE, with a top-k gating mechanism. Its purpose is
    testing and debugging the fused MoE kernel.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the Marlin Mul.
    - w (torch.Tensor): The set of expert weights.
    - scales (torch.Tensor): The quantization scales.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - g_idx (Optional[torch.Tensor]): Optional act_order indices.
    - sort_indices (Optional[torch.Tensor]): Optional act_order input
        permutation.
    - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - num_bits (int): The number of bits used for expert weight quantization.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """

    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch"
    assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w.is_contiguous(), "Expert weights must be contiguous"
    assert hidden_states.dtype == torch.float16
    assert num_bits in [4, 8]

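    # Marlin stores the quantized expert weights in a packed layout: the K
    # dimension is folded by a factor of 16 and the quantized bits are packed
    # along the last dimension, hence the shape arithmetic below.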
    M, K = hidden_states.shape
    E = w.shape[0]
    N = w.shape[2] // (num_bits // 2)

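    # Route each token: pick the top-k experts from the (pre-softmax) router
    # logits and, optionally, renormalize the selected weights.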
    topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize)

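    # Pick a kernel launch configuration (block sizes) for this problem size;
    # an explicit override_config takes precedence.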
    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w.shape,
        w.shape,
        topk_ids.shape[1],
        None,
        override_config=override_config,
        is_marlin=True,
    )
    config = get_config_func(M)

    block_size_m = config["BLOCK_SIZE_M"]

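    # Sort the token -> expert assignments so that each block of
    # block_size_m rows is handled by a single expert.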
    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

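    # Scratch workspace for the Marlin kernel, sized proportionally to the
    # output dimension N.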
    max_workspace_size = (N // 64) * 16
    workspace = torch.zeros(
        max_workspace_size,
        dtype=torch.int,
        device=hidden_states.device,
        requires_grad=False,
    )

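    # The kernel expects concrete tensors, so substitute empty placeholders
    # when zero points or act_order metadata are not used.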
    has_zero_point = w_zeros is not None
    if w_zeros is None:
        w_zeros = torch.empty(
            (0, 0),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
            requires_grad=False,
        )

    if g_idx is None:
        g_idx = torch.empty(
            (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
        )

    if sort_indices is None:
        sort_indices = torch.empty(
            (0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
        )

    scalar_type = get_scalar_type(num_bits, has_zero_point)

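    # Single fused expert GEMM over the selected top-k experts.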
    intermediate_cache = ops.ops.marlin_gemm_moe(
        hidden_states,
        w,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        scales,
        w_zeros,
        g_idx,
        sort_indices,
        workspace,
        scalar_type.id,
        M,
        N,
        K,
        is_k_full,
        E,
        topk,
        block_size_m,
        True,
        False,
    )

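    # Combine the per-token top-k expert outputs by summing over the expert
    # dimension.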
    return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)

def fused_marlin_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    gating_output: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    g_idx1: Optional[torch.Tensor] = None,
    g_idx2: Optional[torch.Tensor] = None,
    sort_indices1: Optional[torch.Tensor] = None,
    sort_indices2: Optional[torch.Tensor] = None,
    w1_zeros: Optional[torch.Tensor] = None,
    w2_zeros: Optional[torch.Tensor] = None,
    override_config: Optional[Dict[str, Any]] = None,
    num_bits: int = 8,
    is_k_full: bool = True,
) -> torch.Tensor:
    """
    Computes a Mixture of Experts (MoE) layer using two sets of expert
    weights, w1 and w2, with a top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - w1_scale (torch.Tensor): Scale to be used for w1.
    - w2_scale (torch.Tensor): Scale to be used for w2.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - topk_weights (torch.Tensor): Top-k weights.
    - topk_ids (torch.Tensor): Indices of the top-k elements.
    - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices.
    - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices.
    - sort_indices1 (Optional[torch.Tensor]): The first act_order input
        permutation.
    - sort_indices2 (Optional[torch.Tensor]): The second act_order input
        permutation.
    - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
    - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - num_bits (int): The number of bits used for expert weight quantization.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """

    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
    assert hidden_states.shape[1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[1] == w2.shape[2] // (
        num_bits // 2
    ), "Hidden size mismatch w2"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype == torch.float16
    assert num_bits in [4, 8]

    has_no_act_order = (
        g_idx1 is None
        and g_idx2 is None
        and sort_indices1 is None
        and sort_indices2 is None
    )
    has_all_act_order = (
        g_idx1 is not None
        and g_idx2 is not None
        and sort_indices1 is not None
        and sort_indices2 is not None
    )
    assert has_no_act_order or has_all_act_order, (
        "g_idx and sort_indices must either all be provided or all be None"
    )

    has_no_zp = w1_zeros is None and w2_zeros is None
    has_all_zp = w1_zeros is not None and w2_zeros is not None
    assert has_no_zp or has_all_zp, (
        "zero points must either both be provided or both be None"
    )

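    # N is the per-expert intermediate (MLP) size, recovered from w2's packed
    # K dimension; w1 projects K -> 2N (gate and up), w2 projects N -> K.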
    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16
    topk = topk_ids.shape[1]

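    # Pick a kernel launch configuration (block sizes) for this problem size;
    # an explicit override_config takes precedence.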
    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        topk_ids.shape[1],
        None,
        override_config=override_config,
        is_marlin=True,
    )
    config = get_config_func(M)

    block_size_m = config["BLOCK_SIZE_M"]

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

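    # The workspace is reused by both GEMMs, so size it for the larger of the
    # two output dimensions (2N for the first GEMM, K for the second).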
    max_workspace_size = (max(2 * N, K) // 64) * 16
    workspace = torch.zeros(
        max_workspace_size, dtype=torch.int, device="cuda", requires_grad=False
    )

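    # The kernel expects concrete tensors, so substitute empty placeholders
    # when zero points or act_order metadata are not used.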
    if has_no_zp:
        w1_zeros = torch.empty(
            (0, 0),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
            requires_grad=False,
        )
        w2_zeros = torch.empty(
            (0, 0),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
            requires_grad=False,
        )

    if has_no_act_order:
        g_idx1 = torch.empty(
            (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
        )
        g_idx2 = torch.empty(
            (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
        )
        sort_indices1 = torch.empty(
            (0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
        )
        sort_indices2 = torch.empty(
            (0, 0), dtype=torch.int32, device=hidden_states.device, requires_grad=False
        )

    scalar_type1 = get_scalar_type(num_bits, has_all_zp)
    scalar_type2 = get_scalar_type(num_bits, has_all_zp)

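    # The MoE MLP runs in two stages: the first GEMM produces the fused gate
    # and up projections (width 2N), silu_and_mul reduces them to width N, and
    # the second GEMM projects back to the hidden size K.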
    intermediate_cache2 = torch.empty(
        (M * topk_ids.shape[1], N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    intermediate_cache1 = ops.ops.marlin_gemm_moe(
        hidden_states,
        w1,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w1_scale,
        w1_zeros,
        g_idx1,
        sort_indices1,
        workspace,
        scalar_type1.id,
        M,
        2 * N,
        K,
        is_k_full,
        E,
        topk,
        block_size_m,
        True,
        False,
    )

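    # SwiGLU activation: silu(gate) * up, halving the width from 2N to N.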
    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

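    # Second expert GEMM: down-project each expert's activations from N back
    # to the hidden size K.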
    intermediate_cache3 = ops.ops.marlin_gemm_moe(
        intermediate_cache2,
        w2,
        sorted_token_ids,
        topk_weights,
        topk_ids,
        w2_scale,
        w2_zeros,
        g_idx2,
        sort_indices2,
        workspace,
        scalar_type2.id,
        M,
        K,
        N,
        is_k_full,
        E,
        topk,
        block_size_m,
        False,
        True,
    )

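    # Combine the per-token top-k expert outputs by summing over the expert
    # dimension.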
    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)