|
"""Fused MoE kernel.""" |
|
|
|
import functools |
|
import json |
|
import os |
|
from typing import Any, Callable, Dict, Optional, Tuple |
|
|
|
import torch |
|
import triton |
|
import triton.language as tl |
|
|
|
from .platforms import current_platform |
|
from .fp8 import scaled_fp8_quant |
|
import moe._custom_ops as ops |
|
|
|
VLLM_FUSED_MOE_CHUNK_SIZE = int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")) |
|
|
|
|
|
@triton.jit |
|
def fused_moe_kernel( |
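    # Pointers to matrices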
|
|
|
a_ptr, |
|
b_ptr, |
|
c_ptr, |
|
a_scale_ptr, |
|
b_scale_ptr, |
|
topk_weights_ptr, |
|
sorted_token_ids_ptr, |
|
expert_ids_ptr, |
|
num_tokens_post_padded_ptr, |
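    # Matrix dimensions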
|
|
|
N, |
|
K, |
|
EM, |
|
num_valid_tokens, |
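    # The stride variables represent how much to increase the pointer by when
    # moving by one element in a particular dimension, e.g. stride_am tells us
    # how much to advance a_ptr to reach the next row of A.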
|
|
|
|
|
|
|
|
|
stride_am, |
|
stride_ak, |
|
stride_be, |
|
stride_bk, |
|
stride_bn, |
|
stride_cm, |
|
stride_cn, |
|
stride_bse, |
|
stride_bsn, |
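    # Meta-parameters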
|
|
|
BLOCK_SIZE_M: tl.constexpr, |
|
BLOCK_SIZE_N: tl.constexpr, |
|
BLOCK_SIZE_K: tl.constexpr, |
|
GROUP_SIZE_M: tl.constexpr, |
|
MUL_ROUTED_WEIGHT: tl.constexpr, |
|
top_k: tl.constexpr, |
|
compute_type: tl.constexpr, |
|
use_fp8_w8a8: tl.constexpr, |
|
use_int8_w8a16: tl.constexpr, |
|
): |
|
""" |
|
Implements the fused computation for a Mixture of Experts (MOE) using |
|
token and expert matrices. |
|
|
|
Key Parameters: |
|
- A: The input tensor representing tokens with shape (*, K), where '*' can |
|
be any shape representing batches and K is the feature dimension of |
|
each token. |
|
- B: The stacked MOE weight tensor with shape (E, N, K), where E is |
|
the number of experts, K is the input feature dimension, and N is |
|
the output feature dimension. |
|
    - C: The output cache tensor with shape (M, topk, N), where M is the
      number of input tokens, topk is the number of experts selected for
      each token, and N is the output feature dimension.
|
- sorted_token_ids: A tensor containing the sorted indices of tokens, |
|
repeated topk times and arranged by the expert index they are |
|
assigned to. |
|
- expert_ids: A tensor containing the indices of the expert for each |
|
block. It determines which expert matrix from B should be used for |
|
each block in A. |
|
This kernel performs the multiplication of a token by its corresponding |
|
expert matrix as determined by `expert_ids`. The sorting of |
|
`sorted_token_ids` by expert index and padding ensures divisibility by |
|
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix |
|
multiplication across different blocks processed by the same expert. |
|
""" |
|
|
|
|
|
|
|
pid = tl.program_id(axis=0) |
|
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) |
|
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) |
|
num_pid_in_group = GROUP_SIZE_M * num_pid_n |
|
group_id = pid // num_pid_in_group |
|
first_pid_m = group_id * GROUP_SIZE_M |
|
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) |
|
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) |
|
pid_n = (pid % num_pid_in_group) // group_size_m |
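    # Load the number of tokens after padding and exit early if this block
    # lies entirely beyond it. Then gather the (padded) token ids for this
    # block and mask out padding entries, which point past num_valid_tokens.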
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) |
|
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: |
|
return |
|
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) |
|
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) |
|
token_mask = offs_token < num_valid_tokens |
|
|
|
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N |
|
offs_k = tl.arange(0, BLOCK_SIZE_K) |
|
a_ptrs = a_ptr + ( |
|
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak |
|
) |
|
|
|
off_experts = tl.load(expert_ids_ptr + pid_m) |
|
b_ptrs = ( |
|
b_ptr |
|
+ off_experts * stride_be |
|
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) |
|
) |
|
if use_int8_w8a16: |
|
b_scale_ptrs = ( |
|
b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn |
|
) |
|
b_scale = tl.load(b_scale_ptrs) |
|
|
|
if use_fp8_w8a8: |
|
a_scale = tl.load(a_scale_ptr) |
|
b_scale = tl.load(b_scale_ptr + off_experts) |
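    # Iterate over the K dimension, accumulating partial products in a float32
    # accumulator. Quantization scales (if any) are applied after the loop,
    # when the accumulator is cast back to compute_type.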
|
|
|
|
|
|
|
|
|
|
|
|
|
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) |
|
|
|
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): |
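        # Load the next BLOCK_SIZE_K slice of A and B, masking out-of-range
        # elements along K (and padded tokens for A).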
|
|
|
|
|
a = tl.load( |
|
a_ptrs, |
|
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), |
|
other=0.0, |
|
) |
|
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) |
|
|
|
if use_int8_w8a16: |
|
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) |
|
elif use_fp8_w8a8: |
|
accumulator = tl.dot(a, b, acc=accumulator) |
|
else: |
|
accumulator += tl.dot(a, b) |
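        # Advance the pointers to the next K block.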
|
|
|
a_ptrs += BLOCK_SIZE_K * stride_ak |
|
b_ptrs += BLOCK_SIZE_K * stride_bk |
|
|
|
if MUL_ROUTED_WEIGHT: |
|
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) |
|
accumulator = accumulator * moe_weight[:, None] |
|
if use_int8_w8a16: |
|
accumulator = (accumulator * b_scale).to(compute_type) |
|
elif use_fp8_w8a8: |
|
accumulator = (accumulator * a_scale * b_scale).to(compute_type) |
|
else: |
|
accumulator = accumulator.to(compute_type) |
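    # Write back the block of the output, masking padded tokens and
    # out-of-range columns.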
|
|
|
|
|
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) |
|
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] |
|
c_mask = token_mask[:, None] & (offs_cn[None, :] < N) |
|
tl.store(c_ptrs, accumulator, mask=c_mask) |
|
|
|
|
|
def moe_align_block_size( |
|
topk_ids: torch.Tensor, block_size: int, num_experts: int |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
|
""" |
|
Aligns the token distribution across experts to be compatible with block |
|
size for matrix multiplication. |
|
|
|
Parameters: |
|
- topk_ids: A tensor of shape [total_tokens, top_k] representing the |
|
top-k expert indices for each token. |
|
- block_size: The block size used in block matrix multiplication. |
|
- num_experts: The total number of experts. |
|
|
|
Returns: |
|
- sorted_token_ids: A tensor containing the sorted token indices according |
|
to their allocated expert. |
|
- expert_ids: A tensor indicating the assigned expert index for each block. |
|
- num_tokens_post_padded: The total number of tokens after padding, |
|
ensuring divisibility by block_size. |
|
|
|
This function pads the number of tokens that each expert needs to process |
|
so that it is divisible by block_size. |
|
Padding ensures that during block matrix multiplication, the dimensions |
|
align correctly. |
|
|
|
Example: |
|
    Given topk_ids = [[1, 2, 3], [0, 1, 3], [0, 2, 3], [0, 1, 2]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
      with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [1, 2, 3, 0, 1, 3, 0, 2, 3, 0, 1, 2].
    - Then append one padding token (index 12) per expert so that each
      expert's token count becomes a multiple of block_size.
    - After sorting by expert index, we obtain token_ids
      [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
      Token index 12 marks padding and is ignored in the subsequent
      matrix multiplication.
|
- The padding ensures that the total number of tokens is now divisible |
|
by block_size for proper block matrix operations. |
|
""" |
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) |
|
sorted_ids = torch.empty( |
|
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device |
|
) |
|
sorted_ids.fill_(topk_ids.numel()) |
|
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) |
|
expert_ids = torch.empty( |
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device |
|
) |
|
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)
|
ops.moe_align_block_size( |
|
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad |
|
) |
|
return sorted_ids, expert_ids, num_tokens_post_pad |
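
# For reference, a minimal pure-Python sketch of what ops.moe_align_block_size
# computes (illustrative only -- the compiled op fills `sorted_ids` and
# `expert_ids` in place on the GPU and may differ in details):
#
#     def _moe_align_block_size_reference(topk_ids, block_size, num_experts):
#         flat = topk_ids.flatten().tolist()
#         pad_id = len(flat)  # padding entries point one past the last token
#         sorted_ids, expert_ids = [], []
#         for e in range(num_experts):
#             idx = [i for i, x in enumerate(flat) if x == e]
#             idx += [pad_id] * (-len(idx) % block_size)
#             sorted_ids += idx
#             expert_ids += [e] * (len(idx) // block_size)
#         return sorted_ids, expert_ids, len(sorted_ids)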
|
|
|
|
|
def invoke_fused_moe_kernel( |
|
A: torch.Tensor, |
|
B: torch.Tensor, |
|
C: torch.Tensor, |
|
A_scale: Optional[torch.Tensor], |
|
B_scale: Optional[torch.Tensor], |
|
topk_weights: torch.Tensor, |
|
topk_ids: torch.Tensor, |
|
sorted_token_ids: torch.Tensor, |
|
expert_ids: torch.Tensor, |
|
num_tokens_post_padded: torch.Tensor, |
|
mul_routed_weight: bool, |
|
top_k: int, |
|
config: Dict[str, Any], |
|
compute_type: tl.dtype, |
|
use_fp8_w8a8: bool, |
|
use_int8_w8a16: bool, |
|
) -> None: |
|
assert topk_weights.stride(1) == 1 |
|
assert sorted_token_ids.stride(0) == 1 |
|
|
|
if use_fp8_w8a8: |
|
A, A_scale = scaled_fp8_quant(A, A_scale) |
|
assert B_scale is not None |
|
elif use_int8_w8a16: |
|
assert B_scale is not None |
|
else: |
|
assert A_scale is None |
|
assert B_scale is None |
|
|
|
grid = lambda META: ( |
|
triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) |
|
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), |
|
) |
|
|
|
fused_moe_kernel[grid]( |
|
A, |
|
B, |
|
C, |
|
A_scale, |
|
B_scale, |
|
topk_weights, |
|
sorted_token_ids, |
|
expert_ids, |
|
num_tokens_post_padded, |
|
B.shape[1], |
|
B.shape[2], |
|
sorted_token_ids.shape[0], |
|
topk_ids.numel(), |
|
A.stride(0), |
|
A.stride(1), |
|
B.stride(0), |
|
B.stride(2), |
|
B.stride(1), |
|
C.stride(1), |
|
C.stride(2), |
|
B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, |
|
B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, |
|
MUL_ROUTED_WEIGHT=mul_routed_weight, |
|
top_k=top_k, |
|
compute_type=compute_type, |
|
use_fp8_w8a8=use_fp8_w8a8, |
|
use_int8_w8a16=use_int8_w8a16, |
|
**config, |
|
) |
|
|
|
|
|
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: |
|
device_name = current_platform.get_device_name().replace(" ", "_") |
|
dtype_selector = "" if not dtype else f",dtype={dtype}" |
|
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" |
|
|
|
|
|
@functools.lru_cache |
|
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: |
|
""" |
|
Return optimized configurations for the fused MoE kernel. |
|
|
|
The return value will be a dictionary that maps an irregular grid of |
|
batch sizes to configurations of the fused_moe kernel. To evaluate the |
|
kernel on a given batch size bs, the closest batch size in the grid should |
|
be picked and the associated configuration chosen to invoke the kernel. |
|
""" |
|
|
|
|
|
|
|
json_file_name = get_config_file_name(E, N, dtype) |
|
|
|
config_file_path = os.path.join( |
|
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name |
|
) |
|
if os.path.exists(config_file_path): |
|
with open(config_file_path) as f: |
|
|
|
return {int(key): val for key, val in json.load(f).items()} |
|
|
|
|
|
|
|
return None |
|
|
|
|
|
def get_default_config( |
|
M: int, |
|
E: int, |
|
N: int, |
|
K: int, |
|
topk: int, |
|
dtype: Optional[str], |
|
is_marlin: bool, |
|
) -> Dict[str, int]: |
|
config = { |
|
"BLOCK_SIZE_M": 64, |
|
"BLOCK_SIZE_N": 64, |
|
"BLOCK_SIZE_K": 32, |
|
"GROUP_SIZE_M": 8, |
|
} |
|
|
|
if M <= E or (is_marlin and M <= 32): |
|
config = { |
|
"BLOCK_SIZE_M": 16, |
|
"BLOCK_SIZE_N": 32, |
|
"BLOCK_SIZE_K": 64, |
|
"GROUP_SIZE_M": 1, |
|
} |
|
return config |
|
|
|
|
|
def try_get_optimal_moe_config( |
|
w1_shape: Tuple[int, ...], |
|
w2_shape: Tuple[int, ...], |
|
top_k: int, |
|
dtype: Optional[str], |
|
M: int, |
|
override_config: Optional[Dict[str, Any]] = None, |
|
is_marlin: bool = False, |
|
): |
|
if override_config: |
|
config = override_config |
|
else: |
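        # No override given: try to load a tuned configuration for this
        # expert/shard shape from disk.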
|
|
|
E, _, N = w2_shape |
|
configs = get_moe_configs(E, N, dtype) |
|
|
|
if configs: |
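            # A tuned configuration map is available: pick the entry whose
            # batch size is closest to M.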
|
|
|
|
|
config = configs[min(configs.keys(), key=lambda x: abs(x - M))] |
|
else: |
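            # Otherwise, fall back to a heuristic default configuration.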
|
|
|
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) |
|
return config |
|
|
|
|
|
def fused_topk( |
|
hidden_states: torch.Tensor, |
|
gating_output: torch.Tensor, |
|
topk: int, |
|
renormalize: bool, |
|
): |
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" |
|
|
|
M, _ = hidden_states.shape |
|
|
|
topk_weights = torch.empty( |
|
M, topk, dtype=torch.float32, device=hidden_states.device |
|
) |
|
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) |
|
    token_expert_indices = torch.empty(
|
M, topk, dtype=torch.int32, device=hidden_states.device |
|
) |
|
|
|
ops.topk_softmax( |
|
topk_weights, |
|
topk_ids, |
|
        token_expert_indices,
|
gating_output.float(), |
|
) |
|
    del token_expert_indices
|
|
|
if renormalize: |
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) |
|
|
|
return topk_weights, topk_ids |
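
# Note: the routing above is roughly equivalent to
#     torch.topk(torch.softmax(gating_output.float(), dim=-1), topk, dim=-1)
# except that ops.topk_softmax fuses the softmax and top-k and also produces
# per-token expert slot indices, which are unused here.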
|
|
|
|
|
|
|
def grouped_topk( |
|
hidden_states: torch.Tensor, |
|
gating_output: torch.Tensor, |
|
topk: int, |
|
renormalize: bool, |
|
num_expert_group: int = 0, |
|
topk_group: int = 0, |
|
): |
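    # DeepSeek-V2-style group-limited routing: score all experts, keep only
    # the top `topk_group` expert groups (ranked by their best expert score),
    # then select the top-k experts from the surviving groups.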
|
|
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" |
|
|
|
scores = torch.softmax(gating_output, dim=-1) |
|
num_token = scores.shape[0] |
|
group_scores = ( |
|
scores.view(num_token, num_expert_group, -1).max(dim=-1).values |
|
) |
|
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ |
|
1 |
|
] |
|
group_mask = torch.zeros_like(group_scores) |
|
group_mask.scatter_(1, group_idx, 1) |
|
score_mask = ( |
|
group_mask.unsqueeze(-1) |
|
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) |
|
.reshape(num_token, -1) |
|
) |
|
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) |
|
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) |
|
|
|
if renormalize: |
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) |
|
|
|
return topk_weights.to(torch.float32), topk_ids.to(torch.int32) |
|
|
|
|
|
def get_config_dtype_str( |
|
dtype: torch.dtype, |
|
use_int8_w8a16: Optional[bool] = False, |
|
use_fp8_w8a8: Optional[bool] = False, |
|
): |
|
if use_fp8_w8a8: |
|
return "fp8_w8a8" |
|
elif use_int8_w8a16: |
|
return "int8_w8a16" |
|
elif dtype == torch.float: |
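        # Use dedicated float32 configs rather than reusing fp16/bf16 ones.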
|
|
|
|
|
return "float32" |
|
return None |
|
|
|
|
|
def fused_experts( |
|
hidden_states: torch.Tensor, |
|
w1: torch.Tensor, |
|
w2: torch.Tensor, |
|
topk_weights: torch.Tensor, |
|
topk_ids: torch.Tensor, |
|
inplace: bool = False, |
|
override_config: Optional[Dict[str, Any]] = None, |
|
use_fp8_w8a8: bool = False, |
|
use_int8_w8a16: bool = False, |
|
w1_scale: Optional[torch.Tensor] = None, |
|
w2_scale: Optional[torch.Tensor] = None, |
|
a1_scale: Optional[torch.Tensor] = None, |
|
a2_scale: Optional[torch.Tensor] = None, |
|
): |
|
|
|
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" |
|
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" |
|
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" |
|
assert w1.is_contiguous(), "Expert weights1 must be contiguous" |
|
assert w2.is_contiguous(), "Expert weights2 must be contiguous" |
|
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] |
|
|
|
num_tokens, _ = hidden_states.shape |
|
E, N, _ = w1.shape |
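    # Process the tokens in chunks of at most CHUNK_SIZE so the intermediate
    # caches stay bounded regardless of the incoming batch size.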
|
|
|
|
|
CHUNK_SIZE = VLLM_FUSED_MOE_CHUNK_SIZE |
|
M = min(num_tokens, CHUNK_SIZE) |
|
config_dtype = get_config_dtype_str( |
|
use_fp8_w8a8=use_fp8_w8a8, |
|
use_int8_w8a16=use_int8_w8a16, |
|
dtype=hidden_states.dtype, |
|
) |
|
|
|
get_config_func = functools.partial( |
|
try_get_optimal_moe_config, |
|
w1.shape, |
|
w2.shape, |
|
topk_ids.shape[1], |
|
config_dtype, |
|
override_config=override_config, |
|
) |
|
|
|
config = get_config_func(M) |
|
|
|
intermediate_cache1 = torch.empty( |
|
(M, topk_ids.shape[1], N), |
|
device=hidden_states.device, |
|
dtype=hidden_states.dtype, |
|
) |
|
intermediate_cache2 = torch.empty( |
|
(M * topk_ids.shape[1], N // 2), |
|
device=hidden_states.device, |
|
dtype=hidden_states.dtype, |
|
) |
|
intermediate_cache3 = torch.empty( |
|
(M, topk_ids.shape[1], w2.shape[1]), |
|
device=hidden_states.device, |
|
dtype=hidden_states.dtype, |
|
) |
|
|
|
    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float32:
        compute_type = tl.float32
    else:
        compute_type = tl.float16
|
|
|
if inplace: |
|
out_hidden_states = hidden_states |
|
else: |
|
out_hidden_states = torch.empty_like(hidden_states) |
|
|
|
for chunk in range((num_tokens // CHUNK_SIZE) + 1): |
|
begin_chunk_idx, end_chunk_idx = ( |
|
chunk * CHUNK_SIZE, |
|
min((chunk + 1) * CHUNK_SIZE, num_tokens), |
|
) |
|
curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] |
|
tokens_in_chunk, _ = curr_hidden_states.shape |
|
|
|
if tokens_in_chunk == 0: |
|
break |
|
|
|
if tokens_in_chunk < CHUNK_SIZE and chunk > 0: |
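            # Adjust the intermediate caches and the kernel config for the
            # smaller final chunk. In the common single-chunk case nothing
            # needs to change.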
|
|
|
|
|
|
|
|
|
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] |
|
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] |
|
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] |
|
config = get_config_func(tokens_in_chunk) |
|
|
|
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] |
|
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] |
|
|
|
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( |
|
curr_topk_ids, config["BLOCK_SIZE_M"], E |
|
) |
|
|
|
invoke_fused_moe_kernel( |
|
curr_hidden_states, |
|
w1, |
|
intermediate_cache1, |
|
a1_scale, |
|
w1_scale, |
|
curr_topk_weights, |
|
curr_topk_ids, |
|
sorted_token_ids, |
|
expert_ids, |
|
num_tokens_post_padded, |
|
False, |
|
topk_ids.shape[1], |
|
config, |
|
compute_type=compute_type, |
|
use_fp8_w8a8=use_fp8_w8a8, |
|
use_int8_w8a16=use_int8_w8a16, |
|
) |
|
|
|
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) |
|
|
|
invoke_fused_moe_kernel( |
|
intermediate_cache2, |
|
w2, |
|
intermediate_cache3, |
|
a2_scale, |
|
w2_scale, |
|
curr_topk_weights, |
|
curr_topk_ids, |
|
sorted_token_ids, |
|
expert_ids, |
|
num_tokens_post_padded, |
|
True, |
|
1, |
|
config, |
|
compute_type=compute_type, |
|
use_fp8_w8a8=use_fp8_w8a8, |
|
use_int8_w8a16=use_int8_w8a16, |
|
) |
|
|
|
ops.moe_sum( |
|
intermediate_cache3.view(*intermediate_cache3.shape), |
|
out_hidden_states[begin_chunk_idx:end_chunk_idx], |
|
) |
|
return out_hidden_states |
|
|
|
|
|
def fused_moe( |
|
hidden_states: torch.Tensor, |
|
w1: torch.Tensor, |
|
w2: torch.Tensor, |
|
gating_output: torch.Tensor, |
|
topk: int, |
|
renormalize: bool, |
|
inplace: bool = False, |
|
override_config: Optional[Dict[str, Any]] = None, |
|
use_grouped_topk: bool = False, |
|
num_expert_group: Optional[int] = None, |
|
topk_group: Optional[int] = None, |
|
custom_routing_function: Optional[Callable] = None, |
|
use_fp8_w8a8: bool = False, |
|
use_int8_w8a16: bool = False, |
|
w1_scale: Optional[torch.Tensor] = None, |
|
w2_scale: Optional[torch.Tensor] = None, |
|
a1_scale: Optional[torch.Tensor] = None, |
|
a2_scale: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
""" |
|
This function computes a Mixture of Experts (MoE) layer using two sets of |
|
weights, w1 and w2, and top-k gating mechanism. |
|
|
|
Parameters: |
|
- hidden_states (torch.Tensor): The input tensor to the MoE layer. |
|
- w1 (torch.Tensor): The first set of expert weights. |
|
- w2 (torch.Tensor): The second set of expert weights. |
|
- gating_output (torch.Tensor): The output of the gating operation |
|
(before softmax). |
|
- topk (int): The number of top-k experts to select. |
|
- renormalize (bool): If True, renormalize the top-k weights to sum to 1. |
|
- inplace (bool): If True, perform the operation in-place. |
|
Defaults to False. |
|
- override_config (Optional[Dict[str, Any]]): Optional override |
|
for the kernel configuration. |
|
- num_expert_group: Optional[int]: additional parameter for grouped_topk |
|
- topk_group: Optional[int]: additional parameter for grouped_topk |
|
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk.
      Note: the DeepSeek-V2 model uses grouped_topk.
    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic (fp8 weights and fp8
      activations) to compute the inner products for w1 and w2.
      Defaults to False.
    - use_int8_w8a16 (bool): If True, use int8 weights with 16-bit
      activations (fp16/bf16) to compute the inner products for w1 and w2.
      Defaults to False.
|
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for |
|
w1. |
|
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for |
|
w2. |
|
|
|
Returns: |
|
- torch.Tensor: The output tensor after applying the MoE layer. |
|
""" |
|
|
|
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" |
|
|
|
if use_grouped_topk: |
|
assert num_expert_group is not None and topk_group is not None |
|
topk_weights, topk_ids = grouped_topk( |
|
hidden_states, |
|
gating_output, |
|
topk, |
|
renormalize, |
|
num_expert_group, |
|
topk_group, |
|
) |
|
elif custom_routing_function is None: |
|
topk_weights, topk_ids = fused_topk( |
|
hidden_states, gating_output, topk, renormalize |
|
) |
|
else: |
|
topk_weights, topk_ids = custom_routing_function( |
|
hidden_states, gating_output, topk, renormalize |
|
) |
|
|
|
return fused_experts( |
|
hidden_states, |
|
w1, |
|
w2, |
|
topk_weights, |
|
topk_ids, |
|
inplace=inplace, |
|
override_config=override_config, |
|
use_fp8_w8a8=use_fp8_w8a8, |
|
use_int8_w8a16=use_int8_w8a16, |
|
w1_scale=w1_scale, |
|
w2_scale=w2_scale, |
|
a1_scale=a1_scale, |
|
a2_scale=a2_scale, |
|
) |
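

# Example usage (illustrative sketch only; assumes a CUDA device, bfloat16
# tensors, and that the compiled custom ops used above are available). The
# shapes follow the conventions asserted in fused_experts/fused_moe:
#
#     num_tokens, hidden, inter, E, topk = 16, 4096, 14336, 8, 2
#     x = torch.randn(num_tokens, hidden, dtype=torch.bfloat16, device="cuda")
#     # w1 stacks the gate and up projections per expert: (E, 2 * inter, hidden)
#     w1 = torch.randn(E, 2 * inter, hidden, dtype=torch.bfloat16, device="cuda")
#     # w2 is the down projection per expert: (E, hidden, inter)
#     w2 = torch.randn(E, hidden, inter, dtype=torch.bfloat16, device="cuda")
#     gating = torch.randn(num_tokens, E, dtype=torch.float32, device="cuda")
#
#     out = fused_moe(x, w1, w2, gating, topk=topk, renormalize=True)
#     assert out.shape == x.shape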
|
|