diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..0cd58331b2a989b68be4ec5676383437fca8687b 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +*.so filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0d24413ec4118bf0c3fcbc7e6006da17894a4fb4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +**/__pycache__/ +**/*egg-info/ \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..408e8efd26190a8e433de4c3741315f63e830e65 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,202 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +#from .grouped_gemm import backend as gg_backend +#from .grouped_gemm import ops as gg_ops + + +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . import layers + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. + + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. 
+ + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + "MyReplacementLayer", + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from .moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/activation_fn.py new file mode 100755 index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from ..stk import Matrix + + +def act_fn( + x: Matrix, 
+ function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/all_to_all.py new file mode 100755 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/arguments.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/arguments.py new file mode 100755 index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/arguments.py @@ -0,0 +1,101 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. 
+    moe_num_experts: int = 1
+    moe_top_k: int = 1
+    moe_capacity_factor: int = 1
+    moe_normalize_expert_weights: Optional[Union[int, float]] = None
+    moe_loss_weight: float = 0.1
+    moe_jitter_eps: Optional[float] = None
+    moe_lbl_in_fp32: bool = False
+
+    # Parallelism arguments.
+    moe_expert_model_parallelism: bool = False
+    expert_parallel_group: Optional[dist.ProcessGroup] = None
+    pipeline_model_parallel_size: int = 1
+    num_layers_per_virtual_pipeline_stage: Optional[int] = None
+
+    # Compute arguments.
+    memory_optimized_mlp: bool = False
+    mlp_type: str = 'mlp'
+    mlp_impl: str = 'sparse'
+
+    # Initialization arguments.
+    fp16: bool = True
+    bf16: bool = False
+    device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
+    init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
+    output_layer_init_method: InitFn = init_method
+
+    # Benchmarking arguments.
+    uniform_expert_assignment: bool = False
+
+    # Shared expert arguments.
+    shared_expert: bool = False  # enable using shared expert
+    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in the shared expert (allows a custom FC layer, e.g. te.Linear for FP8)
+    fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,)  # kwargs for custom fc layers
+    remat_act_fn: bool = True  # enable act fn to be rematerialized instead of stored
+    shared_expert_hidden_size: Optional[
+        int] = None  # hidden size of the shared expert IF we want to set it to something different from hidden_size
+    shared_expert_weighted_sum: bool = False  # enable using a weighted sum for the shared expert output (weighted by the number of experts used)
+
+    # Router Z-loss arguments.
+    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
+    moe_zloss_in_fp32: bool = False
+
+    def __post_init__(self):
+        # Sparse MLP is not supported with triton >=3.2.0
+        # TODO: Remove this once sparse is supported with triton >=3.2.0
+        if self.__getattribute__('mlp_impl') == 'sparse':
+            try:
+                import triton
+                if triton.__version__ >= '3.2.0':
+                    raise ValueError(
+                        'Sparse MLP is not supported with triton >=3.2.0.
Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/common.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/common.py new file mode 100755 index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from .arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmlp_registry.py new file mode 100755 index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from . import glu, mlp +from .arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + """Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. 
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmoe.py new file mode 100755 index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmoe.py @@ -0,0 +1,337 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) + +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import stk +from .. import ops +from . import common, dmlp_registry, moe, mpu +from .arguments import Arguments + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. + nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. 
Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. 
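+        # Block-aligning each expert's token count lets every expert's rows start on a
+        # `self.blocking` boundary; `padded_bins` below holds each expert's cumulative
+        # end offset inside the padded, gathered activation buffer.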
+ padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. + return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/gelu.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/gelu.py new file mode 100755 index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/gelu.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. import stk + +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. 
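+    # The gradient is scaled in place on `grad.data` with the tanh-approximation GELU
+    # derivative and rewrapped with x's sparsity metadata, so no new storage is allocated.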
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/glu.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/glu.py new file mode 100755 index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/glu.py @@ -0,0 +1,244 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + +import torch + +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. 
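+        # Both projections are dense x dense -> block-sparse (sdd) over the shared routing
+        # topology `topo`; the activated gate is multiplied element-wise with the
+        # up-projection before the block-sparse x dense (dsd) down-projection.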
+ x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. + dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. 
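+        # The input gradient sums both branches: the gate path (dsdd_out @ w1), written
+        # into the reused buffer via `c=dx`, plus the up-projection path (dv1_out @ v1).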
+ dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/memory_test.py new file mode 100755 index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/memory_test.py @@ -0,0 +1,103 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. + mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. 
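+    # The factor of 2 converts element counts to bytes, since this test runs in bf16
+    # (2 bytes per parameter and per gradient element).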
+ weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mlp.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mlp.py @@ -0,0 +1,587 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + +import torch +from packaging import version + +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. 
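+    # Every rank materializes and initializes the full (num_experts, ffn_hidden_size,
+    # hidden_size) master weight, then slices out its own shard below when expert model
+    # parallelism is enabled.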
+ master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. + expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
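+        # w1 is stored pre-transposed as (experts, hidden_size, ffn_per_rank) so the
+        # forward pass can call torch.bmm directly; w2 keeps its natural
+        # (experts, ffn_per_rank, hidden_size) layout.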
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
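+        # The triton sdd kernel below writes ddsd_out @ w2.t() straight into the existing
+        # `dactivation_fn_out.data` storage, avoiding a second block-sparse allocation.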
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
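+        # Only the nonzero blocks selected by `topo` are materialized for the intermediate
+        # activation, keeping memory and compute proportional to the routed tokens.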
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
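+        # The flat (num_rows_per_rank, hidden_size) parameters are viewed as one
+        # (ffn_per_rank, hidden_size) slab per local expert, matching the per-expert
+        # `batch_sizes` consumed by the grouped GEMM.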
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/moe.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/moe.py new file mode 100755 index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/moe.py @@ -0,0 +1,507 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . 
import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. 
+ # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. + self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. 
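+        # binned_gather packs the routed rows into a dense
+        # [num_experts, expert_capacity, hidden_size] buffer; rows beyond an
+        # expert's capacity are dropped. For example (hypothetical values),
+        # with top_k=1, sufficient expert_capacity and four tokens where
+        # tokens {0, 2, 3} go to expert 0 and token 1 goes to expert 1:
+        #   bins    = [3, 4]        # inclusive cumsum of tokens per expert
+        #   indices = [0, 2, 3, 1]  # token rows sorted by expert id
+        # so expert 0 receives rows 0, 2, 3 and expert 1 receives row 1, with
+        # unused capacity slots left as zeros.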
+ x = x.view(-1, x.shape[-1]) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. 
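+        # After the gather the rows are replicated top_k times and ordered by
+        # expert id (and therefore by destination device), giving a
+        # [sl * bs * top_k, hs] tensor.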
+ x = x.view(-1, x.shape[-1]) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. 
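+            # Summing over the source dimension of the
+            # [world_size, experts_per_rank] view gives the total number of
+            # tokens each local expert received across all devices.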
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. + shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) + + # Un-permute locally to setup for the next series of operations. + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. 
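+        # ParallelMLP permutes the tokens, runs the expert MLPs and
+        # un-permutes the result; during training it also records the load
+        # balancing loss when moe_loss_weight > 0.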
+ out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mpu.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mpu.py new file mode 100755 index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mpu.py @@ -0,0 +1,94 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/router.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/router.py new file mode 100755 index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/router.py @@ -0,0 +1,116 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. 
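+        # A single bias-free linear projection produces one routing logit per
+        # expert for every token.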
+ self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/sharedexpert_registry.py new file mode 100755 index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/sharedexpert_registry.py @@ -0,0 +1,32 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..5ca3d5090a6bc39a34640a8a15bfa9aa056bfa16 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a19bba459394ac0d93b6405084772af39eb92d4f280f6b5c586d1beeb1589051 +size 5573536 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..76dc5db49710ad2461c9bb1ba76f3fdb3de9f802 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _megablocks_20250730102509 +ops = torch.ops._megablocks_20250730102509 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_20250730102509::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/kernels.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/kernels.py new file mode 100755 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. 
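+    # When SCALE is set, the copied row is multiplied by its routing weight
+    # (weights[index_a]); this is how the top-k expert weights are applied
+    # when un-permuting the expert outputs.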
+ scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. 
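+# The kernel below computes, for every routed (token, k) pair, the dot product
+# of its gathered activation row with the corresponding output-gradient row,
+# i.e. the gradient with respect to the scalar expert weight used in the
+# scatter.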
+@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. 
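+    # The launch grid is (num_experts, expert_capacity), so each program
+    # handles at most one row for one capacity slot of one expert.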
+ index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. 
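+# The kernel below mirrors _binned_copy but, instead of copying rows, it
+# reduces x[slot] * grad[token] over the hidden dimension to produce one
+# scalar per routed token: the gradient of that token's expert weight.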
+@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. 
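+    # x holds the per-expert outputs, [num_experts, expert_capacity, hidden],
+    # and grad the incoming gradient, [tokens, hidden]; the result contains
+    # one scalar per (token, top_k) assignment.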
+ assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/bak.__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/bak.__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/benchmark_util.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/benchmark_util.py new file mode 100755 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. + for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100755 index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,33 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. 
+# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... import _ops as backend +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100755 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100755 index 0000000000000000000000000000000000000000..a6f36b90d362ad6e5e26475e4ab3b3a5f4a1b02d --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,31 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + # import grouped_gemm + pass + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. 
Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +#from .grouped_gemm import backend as ops +#from .grouped_gemm import ops as backend diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/layers.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/layers.py new file mode 100755 index 0000000000000000000000000000000000000000..c22fa16689f648d46c04b1ad39c45adba5f0ea9d --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/layers.py @@ -0,0 +1,1001 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward( + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, +): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +# Shared expert MLP forward pass +def shared_mlp_forward( + x: torch.Tensor, + up_proj_weight: torch.Tensor, + down_proj_weight: torch.Tensor, + up_proj_bias: Optional[torch.Tensor] = None, + down_proj_bias: Optional[torch.Tensor] = None, + activation_fn: Optional[Any] = None, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + # Default 
activation function + if activation_fn is None: + activation_fn = torch.nn.functional.gelu + + # Scale weights + up_proj_weight = scale_grad(up_proj_weight, gradient_scale) + down_proj_weight = scale_grad(down_proj_weight, gradient_scale) + if up_proj_bias is not None: + up_proj_bias = scale_grad(up_proj_bias, gradient_scale) + if down_proj_bias is not None: + down_proj_bias = scale_grad(down_proj_bias, gradient_scale) + + # Resolve dtensors + up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) + down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) + if up_proj_bias is not None: + up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) + if down_proj_bias is not None: + down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) + + # Up projection + x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) + + # Activation + x = activation_fn(x) + + # Down projection + x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) + + return x + + +# Combine outputs from shared expert and regular experts +def combine_expert_shared_outputs( + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + shared_expert_weighted_sum: bool = False, + moe_top_k: int = 1, +) -> torch.Tensor: + if shared_expert_weighted_sum: + # Weighted sum based on number of experts used + total_experts = moe_top_k + 1 + shared_weight = 1.0 / total_experts + expert_weight = moe_top_k / total_experts + return shared_expert_out * shared_weight + expert_out * expert_weight + else: + # Simple addition + return shared_expert_out + expert_out + + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. 
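+    # tokens_per_expert[i] has shape (num_experts,) and expert_scores[i] has
+    # shape (tokens, num_experts), one entry per MoE layer in this stage.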
+ assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. +def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + 
bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + mlp_impl: Optional[str] = None, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, + mlp_impl: Optional[str] = "sparse", +): + # Flatten inputs + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + # TODO: remove debugging var + # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 + + with torch.no_grad(): + # Step 1: Local permutation setup + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate sharding parameters + world_size = dist.get_world_size(expert_parallel_group) + hidden_sharding_deg = hidden_sharding_degree( + world_size, num_experts, hidden_size + ) + experts_per_rank_val = experts_per_rank(num_experts, world_size) + + # Replicate token counts for hidden sharding + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, (hidden_sharding_deg,) + ) + + # Exchange token counts across devices + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) + + # Ensure CUB knows which device to use + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=expert_parallel_group, + async_op=True, + ) + + # Step 
2: Local permutation - group tokens by target device + x = x.view(-1, x.shape[-1]) # [sl * bs, hs] + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Step 3: Compute communication counts and exchange tokens + with torch.no_grad(): + tpe_handle.wait() + + # Reshape for per-device calculations + repeated_tokens_per_expert = repeated_tokens_per_expert.view( + world_size, experts_per_rank_val + ) + parallel_tokens_per_expert = parallel_tokens_per_expert.view( + world_size, experts_per_rank_val + ) + + # Calculate send/recv counts + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() + # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() + tokens_received = sum(recv_counts) + + # Replicate for hidden sharding + x = ops.repeat(x, (hidden_sharding_deg, 1)) + + # Cross-device token exchange + parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( + x, recv_counts, send_counts, expert_parallel_group, async_op=True + ) + + with torch.no_grad(): + # Step 4: Setup for local expert computation + replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) + replicate_bins = ( + replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins + ) + + # Create expert indices for received tokens + parallel_top_expert = torch.remainder( + torch.arange( + num_experts * hidden_sharding_deg, + dtype=torch.int32, + device=indices.device, + ), + experts_per_rank_val, + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # Sort tokens by expert assignment + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + sort_end_bit, + ) + + # Calculate bins for local experts + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, dtype=torch.int + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = ( + parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins + ) + + # Calculate expert capacity + expert_capacity = expert_capacity_fn( + tokens_received, + top_k, + experts_per_rank_val, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if mlp_impl == "grouped": + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. 
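+ # Summing over dim=0 collapses the per-source-rank rows into a single token count
+ # per local expert; reusing the CPU copy taken for the all_to_all split sizes avoids
+ # a second device-to-host synchronization here.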
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + + # Step 5: Expert computation + parallel_x_handle.wait() + + parallel_x = permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=gradient_scale, + alpha=alpha, + ) + + # Step 6: Reverse communication - send results back + x, _ = _layers.all_to_all.all_to_all( + parallel_x, send_counts, recv_counts, expert_parallel_group + ) + + # Step 7: Reduce across hidden sharding dimension + shape = (hidden_sharding_deg, -1, hidden_size) + x = x.view(shape).sum(dim=0) + + # Step 8: Final local unpermutation + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + return x, tokens_per_expert.flatten() + + +def moe_forward( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, + mlp_impl: str = "grouped", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + "mlp_impl": mlp_impl, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + +def moe_forward_with_shared_expert( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + w1: 
torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, + mlp_impl: str = "grouped", + # Shared expert parameters + shared_up_proj_weight: Optional[torch.Tensor] = None, + shared_down_proj_weight: Optional[torch.Tensor] = None, + shared_up_proj_bias: Optional[torch.Tensor] = None, + shared_down_proj_bias: Optional[torch.Tensor] = None, + shared_expert_weighted_sum: bool = False, + shared_activation_fn: Optional[Any] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # First, compute regular MoE forward pass + expert_out, expert_weights, router_scores = moe_forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=moe_jitter_eps, + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + training=training, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=gradient_scale, + alpha=alpha, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=moe_capacity_factor, + moe_expert_model_parallelism=moe_expert_model_parallelism, + forward_fn=forward_fn, + hidden_size=hidden_size, + mlp_impl=mlp_impl, + ) + + # If shared expert weights provided, compute shared expert output + if shared_up_proj_weight is not None and shared_down_proj_weight is not None: + shared_expert_out = shared_mlp_forward( + x=x, + up_proj_weight=shared_up_proj_weight, + down_proj_weight=shared_down_proj_weight, + up_proj_bias=shared_up_proj_bias, + down_proj_bias=shared_down_proj_bias, + activation_fn=shared_activation_fn, + gradient_scale=gradient_scale, + ) + + # Combine expert outputs + combined_out = combine_expert_shared_outputs( + shared_expert_out=shared_expert_out, + expert_out=expert_out, + shared_expert_weighted_sum=shared_expert_weighted_sum, + moe_top_k=moe_top_k, + ) + + return combined_out, expert_weights, router_scores + + # Return regular MoE output if no shared expert + return expert_out, expert_weights, router_scores + + +def create_shared_expert_weights( + hidden_size: int, + shared_expert_hidden_size: int, + device: torch.device, + dtype: torch.dtype, + init_method: Any, + output_layer_init_method: Any = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + + if output_layer_init_method is None: + output_layer_init_method = init_method + + # Create weight tensors + up_proj_weight = torch.empty( + shared_expert_hidden_size, + hidden_size, + device=device, + dtype=dtype, + ) + down_proj_weight = torch.empty( + hidden_size, + shared_expert_hidden_size, + device=device, + dtype=dtype, + ) + + # Initialize weights + init_method(up_proj_weight) + output_layer_init_method(down_proj_weight) + + # No bias by default + return up_proj_weight, down_proj_weight, None, None + +# HACK: Extract device_mesh from pre-hook closure - required for transformers integration +# This exists because device_mesh is trapped in hook closures with no model attribute +# Fragile - breaks if hook structure changes or Python internals change +# TODO: Replace with a more robust solution when available +def get_device_mesh(model): + # Extract device_mesh from 
child's unused pre_hook closure + try: + # Find the pre-hook that contains 'device_mesh' in its closure + hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) + # Extract the device_mesh from the closure + return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents + except Exception: + return None + + +class MegaBlocksMoeMLP(torch.nn.Module): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + moe_top_k = getattr(self.router, "top_k", 4) + moe_num_experts = getattr(self.experts, "num_experts", 128) + gradient_scale = getattr(self.experts, "gradient_scale", None) + alpha = getattr(self.experts, "alpha", 1.0) + moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) + moe_jitter_eps = getattr(self.experts, "jitter_eps", None) + moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) + uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) + + expert_parallel_group = getattr(self, "expert_parallel_group", None) + if expert_parallel_group is None: + device_mesh = get_device_mesh(self) + expert_parallel_group = device_mesh.get_group() if device_mesh else None + + has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 + forward_fn = parallel_forward_once if has_parallel else forward_once + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + mlp_impl = getattr(self, "mlp_impl", "grouped") + + output, expert_weights_out, *_ = moe_forward( + x=x, + router_weight=self.router.weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=moe_jitter_eps, + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + training=self.training, + w1=self.experts.gate_up_proj, + w2=self.experts.down_proj, + w1_bias=self.experts.gate_up_proj_bias, + w2_bias=self.experts.down_proj_bias, + gradient_scale=gradient_scale, + alpha=alpha, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=moe_capacity_factor, + moe_expert_model_parallelism=has_parallel, + forward_fn=forward_fn, + hidden_size=self.experts.hidden_size, + mlp_impl=mlp_impl, + ) + return output, expert_weights_out + + +class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): + + def __init__(self): + super().__init__() + # Shared expert weights will be set by the user + self.shared_up_proj_weight = None + self.shared_down_proj_weight = None + self.shared_up_proj_bias = None + self.shared_down_proj_bias = None + self.shared_expert_weighted_sum = False + self.shared_activation_fn = None + + def set_shared_expert_weights( + self, + up_proj_weight: torch.Tensor, + down_proj_weight: torch.Tensor, + up_proj_bias: Optional[torch.Tensor] = None, + down_proj_bias: Optional[torch.Tensor] = None, + weighted_sum: bool = False, + activation_fn: Optional[Any] = None, + ): + self.shared_up_proj_weight = up_proj_weight + self.shared_down_proj_weight = down_proj_weight + self.shared_up_proj_bias = up_proj_bias + self.shared_down_proj_bias = down_proj_bias + self.shared_expert_weighted_sum = weighted_sum + self.shared_activation_fn = activation_fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + moe_top_k = getattr(self.router, "top_k", 4) + moe_num_experts = getattr(self.experts, "num_experts", 128) + gradient_scale = getattr(self.experts, "gradient_scale", None) + alpha = 
getattr(self.experts, "alpha", 1.0) + moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) + moe_jitter_eps = getattr(self.experts, "jitter_eps", None) + moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) + uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) + + expert_parallel_group = getattr(self, "expert_parallel_group", None) + if expert_parallel_group is None: + device_mesh = get_device_mesh(self) + expert_parallel_group = device_mesh.get_group() if device_mesh else None + + has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 + forward_fn = parallel_forward_once if has_parallel else forward_once + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + mlp_impl = getattr(self, "mlp_impl", "grouped") + + output, expert_weights_out, *_ = moe_forward_with_shared_expert( + x=x, + router_weight=self.router.weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=moe_jitter_eps, + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + training=self.training, + w1=self.experts.gate_up_proj, + w2=self.experts.down_proj, + w1_bias=self.experts.gate_up_proj_bias, + w2_bias=self.experts.down_proj_bias, + gradient_scale=gradient_scale, + alpha=alpha, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=moe_capacity_factor, + moe_expert_model_parallelism=has_parallel, + forward_fn=forward_fn, + hidden_size=self.experts.hidden_size, + mlp_impl=mlp_impl, + # Shared expert parameters + shared_up_proj_weight=self.shared_up_proj_weight, + shared_down_proj_weight=self.shared_down_proj_weight, + shared_up_proj_bias=self.shared_up_proj_bias, + shared_down_proj_bias=self.shared_down_proj_bias, + shared_expert_weighted_sum=self.shared_expert_weighted_sum, + shared_activation_fn=self.shared_activation_fn, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100755 index 
0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,63 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from .._layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. + } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100755 index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for binned_gather kernel. 
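+# Minimal usage sketch (illustrative; `ops` refers to the sibling `megablocks.ops`
+# package and the shapes follow the benchmarks there, not a guaranteed API):
+#
+#   top_expert = torch.randint(0, num_experts, (num_tokens,)).cuda().int()
+#   bin_ids, indices = ops.sort(top_expert)
+#   bins = ops.inclusive_cumsum(ops.histogram(top_expert, num_experts), 0)
+#   expert_x = binned_gather(x, indices, bins, expert_capacity, top_k)
+#   # expert_x: [num_experts, expert_capacity, hidden_size]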
+class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100755 index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. + ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/cumsum.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/cumsum.py new file mode 100755 index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. 
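+# Example (illustrative): for tokens_per_expert = [3, 1, 2] along dim 0,
+#   inclusive_cumsum(tokens_per_expert, 0) -> [3, 4, 6]  (per-expert bin end offsets)
+#   exclusive_cumsum(tokens_per_expert, 0) -> [0, 3, 4]  (per-expert bin start offsets)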
+class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/gather.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/gather.py new file mode 100755 index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for gather kernel. +class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram.py new file mode 100755 index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. 
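+# Example (illustrative): counting how many tokens were routed to each of 4 experts.
+#   top_expert = torch.tensor([2, 0, 2, 1], dtype=torch.int32).cuda()
+#   histogram(top_expert, 4)  # -> [1, 1, 2, 0], one count per value in [0, 4)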
+class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from .. import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,415 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + +import torch +from absl.testing import parameterized + +from .. import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. 
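+# The view returned below aliases the input's storage: only the shape and strides are
+# swapped, so a contiguous (rows, cols) tensor is exposed as (cols, rows) with strides
+# (1, cols) and no data movement.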
+def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + 
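+ # Naming note (assumed convention from stk-style block-sparse kernels, not stated in
+ # this file): the three letters give the formats of output::lhs::rhs with S = block-sparse
+ # and D = dense, so SDD builds a sparse output from dense inputs and DSD multiplies a
+ # sparse lhs by a dense rhs; the NT/NN/TN suffix records which operands are transposed.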
@parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 
'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100755 index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for padded_gather kernel. 
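+# Minimal sketch of the padded permutation setup (illustrative; it mirrors the
+# benchmarks in this package, and the 128 padding granularity is an assumption that
+# matches the block size used there):
+#
+#   bin_ids, indices = ops.sort(top_expert)
+#   tokens_per_expert = ops.histogram(top_expert, num_experts)
+#   bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#   padded_bins = ops.inclusive_cumsum(ops.round_up(tokens_per_expert, 128), 0)
+#   out = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)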
+class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100755 index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for padded_scatter kernel. +class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from .. 
import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from .. import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. 
+ x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. 
+ x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/repeat.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/repeat.py new file mode 100755 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/replicate.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/replicate.py new file mode 100755 index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for replicate kernel. +class ReplicateOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): + ctx.save_for_backward(bins) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) + ops.replicate_forward(x, bins, out) + return out + + @staticmethod + def backward(ctx: Any, grad: torch.Tensor): + bins, = ctx.saved_tensors + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) + ops.replicate_backward(grad, bins, out) + return out, None, None + + +replicate = ReplicateOp.apply diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/round_up.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/round_up.py new file mode 100755 index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/round_up.py @@ -0,0 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def round_up(x: torch.Tensor, value: int): + assert isinstance(value, int) + assert x.dtype == torch.int32 + + # TODO(tgale): If this becomes and issue + # do this in a custom kernel. We only expect + # to use this on arrays of less than 1k elements. 
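+ # e.g. value=128: [3, 1, 130] -> [128, 128, 256]; exact multiples are left unchanged.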
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/scatter.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/scatter.py new file mode 100755 index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/scatter.py @@ -0,0 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for scatter kernel. +class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort.py new file mode 100755 index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
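+# Example (illustrative): argsort of expert assignments with 2 bits covering values 0..3.
+#   top_expert = torch.tensor([2, 0, 2, 1], dtype=torch.int32).cuda()
+#   bin_ids, indices = sort(top_expert, 2)
+#   # bin_ids -> [0, 1, 2, 2]   (expert ids in sorted order)
+#   # indices -> [1, 3, 0, 2]   (original token positions, i.e. the permutation)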
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from .. import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100755 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return 
{_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sum.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sum.py new file mode 100755 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/topology.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/topology.py new file mode 100755 index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . 
import ops +from .matrix import Matrix diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100755 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100755 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) 
and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, 
+ row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100755 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), 
dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + 
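# (Editor's note, not part of the patch: in this dds kernel each program
#  instance computes one BLOCK_M x BLOCK_N tile of the dense output. The
#  `offsets`/`column_indices` arguments are chosen by the dds() wrapper so
#  that they index the sparse operand by block-column; the loop below walks
#  the nonzero blocks contributing to output block-column pid_n and pairs
#  each with the matching K-block of the dense operand.)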
# matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, 
stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/matrix.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/matrix.py new file mode 100755 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). 
" + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
+ """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. + self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return 
self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. + size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). 
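    Note (editorial): the result reuses a's index and offset tensors
    directly (no copies); only the data tensors are multiplied element-wise.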
+ """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = _dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. 
+ grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
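# (Editor's note: a minimal, illustrative round-trip showing how the helpers
#  in this file fit together; it assumes a CUDA device and the `stk`/`torch`
#  imports above, is not invoked by the test runner, and the function name is
#  hypothetical.)
def _example_dsd_roundtrip(blocking=128, dtype=torch.float16):
    # Build a 75%-block-sparse (512, 512) LHS, multiply by a dense (512, 256)
    # RHS with the block-sparse kernel, and compare against a dense reference.
    mask = stk.random.dense_mask(512, 512, 0.75, blocking)
    lhs_dense = (torch.randn(512, 512) * 0.1 * mask).type(dtype)
    lhs = stk.ops.to_sparse(lhs_dense, blocking).cuda()
    lhs_dense = lhs_dense.cuda()
    rhs = (torch.randn(512, 256) * 0.1).type(dtype).cuda()
    out = stk.ops.dsd(lhs, rhs)  # dense (512, 256) result
    return allclose(out, torch.mm(lhs_dense, rhs))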
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
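# (Editor's note: operand shapes are allocated pre-transposed when
#  trans_a/trans_b is set because _with_transpose above re-applies .t(),
#  so the op under test always sees a logical (m, k) @ (k, n) product.)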
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
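# (Editor's note, illustrative worked example for the helper below with
#  blocking=2:
#    idxs = torch.tensor([[1, 0]])    # one nonzero block at block (row=1, col=0)
#    _expand_for_blocking(idxs, 2)    # -> [[2, 0], [2, 1], [3, 0], [3, 1]]
#  i.e. each block coordinate expands to the element coordinates it covers.)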
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
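# (Editor's note: at this point `nonzero_indices` holds the flattened element
#  offset of every entry inside every retained block, in block-major order, so
#  a single gather over x.flatten() yields exactly nnz_blocks * blocking**2
#  values, ready to be reshaped into the [nnz, blocking, blocking] data tensor.)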
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
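# (Editor's note: exact equality is safe here because to_sparse/to_dense only
#  move values via gather/scatter and never recompute them, and entries outside
#  the retained blocks are exactly zero in both x and dense_x.)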
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
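# (Editor's note: dense_mask draws a 0/1 pattern over the block grid and tiles
#  each entry across a blocking x blocking tile, so every element is exactly
#  0.0 or 1.0. For the largest case above, (512, 512, .875, 128): 4 * 4 = 16
#  blocks, round(16 * 0.125) = 2 kept blocks, hence nnz = 2 * 128**2 = 32768.)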
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..408e8efd26190a8e433de4c3741315f63e830e65 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,202 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +#from .grouped_gemm import backend as gg_backend +#from .grouped_gemm import ops as gg_ops + + +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . import layers + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. + + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. 
+ + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + "MyReplacementLayer", + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from .moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/activation_fn.py new file mode 100755 index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from ..stk import Matrix + + +def act_fn( + x: Matrix, 
+ function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/all_to_all.py new file mode 100755 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/arguments.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/arguments.py new file mode 100755 index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/arguments.py @@ -0,0 +1,101 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. 
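+ # Illustrative sketch (assumed values, not defaults): a small top-2 MoE layer
+ # might use Arguments(hidden_size=1024, ffn_hidden_size=4096,
+ # moe_num_experts=8, moe_top_k=2, mlp_impl='grouped').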
+ moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. + uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/common.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/common.py new file mode 100755 index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from .arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmlp_registry.py new file mode 100755 index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from . import glu, mlp +from .arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + """Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. 
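+ Example (illustrative): Arguments(mlp_type='glu', mlp_impl='grouped')
+ resolves to glu.GroupedGLU, while the defaults ('mlp', 'sparse') resolve
+ to mlp.SparseMLP.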
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmoe.py new file mode 100755 index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmoe.py @@ -0,0 +1,337 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) + +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import stk +from .. import ops +from . import common, dmlp_registry, moe, mpu +from .arguments import Arguments + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. + nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. 
Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. 
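+ # Illustrative sketch (assumed values): with blocking=128 and
+ # tokens_per_expert=[3, 1, 0, 2], round_up yields [128, 128, 0, 128] and the
+ # inclusive cumsum yields padded_bins=[128, 256, 256, 384], i.e. the end
+ # offset of each expert's padded bin in the permuted activations.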
+ padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. + return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/gelu.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/gelu.py new file mode 100755 index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/gelu.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. import stk + +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. 
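+ # grad.data is modified in place and the returned Matrix reuses x's index
+ # tensors (row_indices, column_indices, offsets, ...), so both operands must
+ # describe the same set of nonzero blocks.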
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/glu.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/glu.py new file mode 100755 index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/glu.py @@ -0,0 +1,244 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + +import torch + +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. 
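+ # Roughly: out = (act(x @ w1.T) * (x @ v1.T)) @ w2, where both sdd products
+ # only materialize the blocks selected by `topo` and the final dsd produces
+ # a dense [tokens, hidden_size] result.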
+ x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. + dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. 
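+ # The w1 branch of dL/dx is written into the reused ddsd_out buffer via the
+ # `c=` output argument of gmm; the v1 branch is then accumulated with `+=`.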
+ dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/memory_test.py new file mode 100755 index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/memory_test.py @@ -0,0 +1,103 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. + mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. 
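+ # The factor of 2 below assumes 2 bytes per element (this test runs in
+ # bf16), so numel * 2 approximates the parameter (and gradient) bytes.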
+ weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/mlp.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/mlp.py @@ -0,0 +1,587 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + +import torch +from packaging import version + +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. 
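+ # Every rank materializes the full [num_experts, ffn_hidden_size,
+ # hidden_size] master tensor and slices out its own chunk below, so the
+ # sampled values do not depend on the parallel layout. Illustrative sketch
+ # (assumed degrees): with 8 experts, expert_sharding_degree=4 and
+ # hidden_sharding_degree=2, each rank keeps 2 experts and
+ # ffn_hidden_size // 2 of the rows.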
+ master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. + expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
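+ # self.w1 is stored as [experts, hidden_size, ffn_per_rank], so the freshly
+ # sampled [experts, ffn, hidden] weights are transposed before the copy;
+ # self.w2 already matches its [experts, ffn_per_rank, hidden_size] layout.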
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
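+ # The triton sdd kernel writes ddsd_out @ w2.T straight into the existing
+ # block-sparse buffers rather than allocating a new Matrix.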
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
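+ # sdd: dense x dense -> block-sparse output restricted to `topo`;
+ # dsd: block-sparse x dense -> dense. So this path computes roughly
+ # act(x @ w1.T) @ w2 over the padded, expert-grouped tokens.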
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
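+ # The flat [num_rows_per_rank, hidden_size] parameters are viewed as
+ # [experts_per_rank, ffn_per_expert, hidden_size]; gmm then runs one GEMM
+ # per expert, with `batch_sizes` giving each expert's row count in the
+ # permuted input.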
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/moe.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/moe.py new file mode 100755 index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/moe.py @@ -0,0 +1,507 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . 
import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. 
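+ # sort_end_bit below is the number of radix-sort bits needed for the expert
+ # ids, e.g. (illustrative) 64 experts -> 6 bits, with a floor of 1 bit for
+ # the single-expert case.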
+ # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. + self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. 
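+ # binned_gather packs the permuted tokens into a fixed-capacity
+ # [num_experts, expert_capacity, hidden_size] buffer; tokens beyond an
+ # expert's capacity are dropped, and binned_scatter below restores the
+ # original token order for the surviving rows.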
+ x = x.view(-1, x.shape[-1]) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. 
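+        # For example, with sl=2048, bs=4 and hs=1024 (illustrative numbers only)
+        # this reshapes [2048, 4, 1024] into [8192, 1024] before the gather.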
+ x = x.view(-1, x.shape[-1]) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. 
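+            # Toy example (hypothetical counts): if the summed token counts per
+            # local expert were [3, 1, 4], the inclusive cumsum yields
+            # bins = [3, 4, 8], i.e. expert i owns rows [bins[i-1], bins[i]) of
+            # the locally permuted tensor (with the lower bound taken as 0 for
+            # expert 0).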
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. + shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) + + # Un-permute locally to setup for the next series of operations. + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. 
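+        # Commented usage sketch (hypothetical argument values, not part of the
+        # original code):
+        #
+        #   args = Arguments(hidden_size=1024, ffn_hidden_size=4096,
+        #                    moe_num_experts=8, moe_top_k=2)
+        #   moe = MoE(args)
+        #   y = moe(torch.randn(2048, 4, 1024, device=args.device))
+        #
+        # The output keeps the [sl, bs, hidden_size] input shape; when args.bias
+        # and args.return_bias are set, the expert layer returns (output, bias).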
+ out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/mpu.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/mpu.py new file mode 100755 index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/mpu.py @@ -0,0 +1,94 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/router.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/router.py new file mode 100755 index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/router.py @@ -0,0 +1,116 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. 
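+        # The router is a single Linear(hidden_size -> moe_num_experts) layer;
+        # for a flattened batch of T tokens it produces logits of shape
+        # [T, moe_num_experts], from which the top-k experts are chosen below.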
+ self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/sharedexpert_registry.py new file mode 100755 index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_layers/sharedexpert_registry.py @@ -0,0 +1,32 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..5cd3425b99470e22cab8eb8f1dae40d1f2728033 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d191af003fd8f496122d36d0813cf6847cd1f96d736faf7b2c807bfe08807688 +size 5577856 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..76dc5db49710ad2461c9bb1ba76f3fdb3de9f802 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _megablocks_20250730102509 +ops = torch.ops._megablocks_20250730102509 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_20250730102509::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/backend/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/backend/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/backend/kernels.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/backend/kernels.py new file mode 100755 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. 
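+    # Worked example (hypothetical values): with bins = [2, 5] and
+    # padded_bins = [4, 8], the row handled by program_id 3 falls in bin 1,
+    # so offset_in_bin = 3 - 2 = 1 and index_b = 1 + 4 = 5 in the padded output.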
+ scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. 
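+#
+# The wgrad kernel below computes, for every routed copy of a token, the dot
+# product of the expert output row with the corresponding output-gradient row,
+# i.e. the gradient of the loss with respect to that copy's top-k routing weight.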
+@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. 
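+    # Example (hypothetical sizes): with expert_capacity=64, the entry at
+    # (expert_idx=2, entry_idx=5) maps to row 2 * 64 + 5 = 133 of 'b'.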
+ index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. 
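+#
+# Commented usage sketch for binned_gather/binned_scatter defined above
+# (illustrative only; assumes CUDA tensors and indices/bins built as in the
+# MoE layer):
+#
+#   out = binned_gather(x, indices, None, bins, expert_capacity, 1)
+#   y = binned_scatter(out, indices, None, bins, 1)   # back to (tokens, hidden)
+#
+# Tokens beyond an expert's capacity are not gathered, so their rows come back
+# as zeros, which is how capacity-based token dropping is realized.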
+@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. 
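+    # Expected shapes, mirroring binned_scatter: x is
+    # (num_experts, expert_capacity, hidden_size), grad is (tokens, hidden_size),
+    # and the result is a flat (tokens * top_k,) tensor of routing-weight grads.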
+ assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/bak.__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/bak.__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/benchmark_util.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/benchmark_util.py new file mode 100755 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. + for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100755 index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,33 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. 
+# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... import _ops as backend +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100755 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100755 index 0000000000000000000000000000000000000000..a6f36b90d362ad6e5e26475e4ab3b3a5f4a1b02d --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,31 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + # import grouped_gemm + pass + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. 
Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +#from .grouped_gemm import backend as ops +#from .grouped_gemm import ops as backend diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/layers.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/layers.py new file mode 100755 index 0000000000000000000000000000000000000000..c22fa16689f648d46c04b1ad39c45adba5f0ea9d --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/layers.py @@ -0,0 +1,1001 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward( + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, +): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +# Shared expert MLP forward pass +def shared_mlp_forward( + x: torch.Tensor, + up_proj_weight: torch.Tensor, + down_proj_weight: torch.Tensor, + up_proj_bias: Optional[torch.Tensor] = None, + down_proj_bias: Optional[torch.Tensor] = None, + activation_fn: Optional[Any] = None, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + # Default 
activation function + if activation_fn is None: + activation_fn = torch.nn.functional.gelu + + # Scale weights + up_proj_weight = scale_grad(up_proj_weight, gradient_scale) + down_proj_weight = scale_grad(down_proj_weight, gradient_scale) + if up_proj_bias is not None: + up_proj_bias = scale_grad(up_proj_bias, gradient_scale) + if down_proj_bias is not None: + down_proj_bias = scale_grad(down_proj_bias, gradient_scale) + + # Resolve dtensors + up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) + down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) + if up_proj_bias is not None: + up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) + if down_proj_bias is not None: + down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) + + # Up projection + x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) + + # Activation + x = activation_fn(x) + + # Down projection + x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) + + return x + + +# Combine outputs from shared expert and regular experts +def combine_expert_shared_outputs( + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + shared_expert_weighted_sum: bool = False, + moe_top_k: int = 1, +) -> torch.Tensor: + if shared_expert_weighted_sum: + # Weighted sum based on number of experts used + total_experts = moe_top_k + 1 + shared_weight = 1.0 / total_experts + expert_weight = moe_top_k / total_experts + return shared_expert_out * shared_weight + expert_out * expert_weight + else: + # Simple addition + return shared_expert_out + expert_out + + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. 
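+    # Each saved entry is a (tokens_per_expert, expert_scores) pair from
+    # save_load_balancing_loss: tokens_per_expert[i] has shape
+    # (moe_num_experts,) and expert_scores[i] has shape (tokens, moe_num_experts).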
+ assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. +def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + 
bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + mlp_impl: Optional[str] = None, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, + mlp_impl: Optional[str] = "sparse", +): + # Flatten inputs + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + # TODO: remove debugging var + # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 + + with torch.no_grad(): + # Step 1: Local permutation setup + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate sharding parameters + world_size = dist.get_world_size(expert_parallel_group) + hidden_sharding_deg = hidden_sharding_degree( + world_size, num_experts, hidden_size + ) + experts_per_rank_val = experts_per_rank(num_experts, world_size) + + # Replicate token counts for hidden sharding + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, (hidden_sharding_deg,) + ) + + # Exchange token counts across devices + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) + + # Ensure CUB knows which device to use + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=expert_parallel_group, + async_op=True, + ) + + # Step 
2: Local permutation - group tokens by target device + x = x.view(-1, x.shape[-1]) # [sl * bs, hs] + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Step 3: Compute communication counts and exchange tokens + with torch.no_grad(): + tpe_handle.wait() + + # Reshape for per-device calculations + repeated_tokens_per_expert = repeated_tokens_per_expert.view( + world_size, experts_per_rank_val + ) + parallel_tokens_per_expert = parallel_tokens_per_expert.view( + world_size, experts_per_rank_val + ) + + # Calculate send/recv counts + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() + # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() + tokens_received = sum(recv_counts) + + # Replicate for hidden sharding + x = ops.repeat(x, (hidden_sharding_deg, 1)) + + # Cross-device token exchange + parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( + x, recv_counts, send_counts, expert_parallel_group, async_op=True + ) + + with torch.no_grad(): + # Step 4: Setup for local expert computation + replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) + replicate_bins = ( + replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins + ) + + # Create expert indices for received tokens + parallel_top_expert = torch.remainder( + torch.arange( + num_experts * hidden_sharding_deg, + dtype=torch.int32, + device=indices.device, + ), + experts_per_rank_val, + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # Sort tokens by expert assignment + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + sort_end_bit, + ) + + # Calculate bins for local experts + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, dtype=torch.int + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = ( + parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins + ) + + # Calculate expert capacity + expert_capacity = expert_capacity_fn( + tokens_received, + top_k, + experts_per_rank_val, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if mlp_impl == "grouped": + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. 
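+            # Added commentary (not in the original source): after the .view()
+            # above, parallel_tokens_per_expert_cpu has shape
+            # [world_size, experts_per_rank]; summing over dim 0 below yields
+            # one token count per local expert for the grouped GEMM, without
+            # another device-to-host copy.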
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + + # Step 5: Expert computation + parallel_x_handle.wait() + + parallel_x = permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=gradient_scale, + alpha=alpha, + ) + + # Step 6: Reverse communication - send results back + x, _ = _layers.all_to_all.all_to_all( + parallel_x, send_counts, recv_counts, expert_parallel_group + ) + + # Step 7: Reduce across hidden sharding dimension + shape = (hidden_sharding_deg, -1, hidden_size) + x = x.view(shape).sum(dim=0) + + # Step 8: Final local unpermutation + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + return x, tokens_per_expert.flatten() + + +def moe_forward( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, + mlp_impl: str = "grouped", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + "mlp_impl": mlp_impl, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + +def moe_forward_with_shared_expert( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + w1: 
torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, + mlp_impl: str = "grouped", + # Shared expert parameters + shared_up_proj_weight: Optional[torch.Tensor] = None, + shared_down_proj_weight: Optional[torch.Tensor] = None, + shared_up_proj_bias: Optional[torch.Tensor] = None, + shared_down_proj_bias: Optional[torch.Tensor] = None, + shared_expert_weighted_sum: bool = False, + shared_activation_fn: Optional[Any] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # First, compute regular MoE forward pass + expert_out, expert_weights, router_scores = moe_forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=moe_jitter_eps, + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + training=training, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=gradient_scale, + alpha=alpha, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=moe_capacity_factor, + moe_expert_model_parallelism=moe_expert_model_parallelism, + forward_fn=forward_fn, + hidden_size=hidden_size, + mlp_impl=mlp_impl, + ) + + # If shared expert weights provided, compute shared expert output + if shared_up_proj_weight is not None and shared_down_proj_weight is not None: + shared_expert_out = shared_mlp_forward( + x=x, + up_proj_weight=shared_up_proj_weight, + down_proj_weight=shared_down_proj_weight, + up_proj_bias=shared_up_proj_bias, + down_proj_bias=shared_down_proj_bias, + activation_fn=shared_activation_fn, + gradient_scale=gradient_scale, + ) + + # Combine expert outputs + combined_out = combine_expert_shared_outputs( + shared_expert_out=shared_expert_out, + expert_out=expert_out, + shared_expert_weighted_sum=shared_expert_weighted_sum, + moe_top_k=moe_top_k, + ) + + return combined_out, expert_weights, router_scores + + # Return regular MoE output if no shared expert + return expert_out, expert_weights, router_scores + + +def create_shared_expert_weights( + hidden_size: int, + shared_expert_hidden_size: int, + device: torch.device, + dtype: torch.dtype, + init_method: Any, + output_layer_init_method: Any = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + + if output_layer_init_method is None: + output_layer_init_method = init_method + + # Create weight tensors + up_proj_weight = torch.empty( + shared_expert_hidden_size, + hidden_size, + device=device, + dtype=dtype, + ) + down_proj_weight = torch.empty( + hidden_size, + shared_expert_hidden_size, + device=device, + dtype=dtype, + ) + + # Initialize weights + init_method(up_proj_weight) + output_layer_init_method(down_proj_weight) + + # No bias by default + return up_proj_weight, down_proj_weight, None, None + +# HACK: Extract device_mesh from pre-hook closure - required for transformers integration +# This exists because device_mesh is trapped in hook closures with no model attribute +# Fragile - breaks if hook structure changes or Python internals change +# TODO: Replace with a more robust solution when available +def get_device_mesh(model): + # Extract device_mesh from 
child's unused pre_hook closure + try: + # Find the pre-hook that contains 'device_mesh' in its closure + hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) + # Extract the device_mesh from the closure + return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents + except Exception: + return None + + +class MegaBlocksMoeMLP(torch.nn.Module): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + moe_top_k = getattr(self.router, "top_k", 4) + moe_num_experts = getattr(self.experts, "num_experts", 128) + gradient_scale = getattr(self.experts, "gradient_scale", None) + alpha = getattr(self.experts, "alpha", 1.0) + moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) + moe_jitter_eps = getattr(self.experts, "jitter_eps", None) + moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) + uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) + + expert_parallel_group = getattr(self, "expert_parallel_group", None) + if expert_parallel_group is None: + device_mesh = get_device_mesh(self) + expert_parallel_group = device_mesh.get_group() if device_mesh else None + + has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 + forward_fn = parallel_forward_once if has_parallel else forward_once + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + mlp_impl = getattr(self, "mlp_impl", "grouped") + + output, expert_weights_out, *_ = moe_forward( + x=x, + router_weight=self.router.weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=moe_jitter_eps, + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + training=self.training, + w1=self.experts.gate_up_proj, + w2=self.experts.down_proj, + w1_bias=self.experts.gate_up_proj_bias, + w2_bias=self.experts.down_proj_bias, + gradient_scale=gradient_scale, + alpha=alpha, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=moe_capacity_factor, + moe_expert_model_parallelism=has_parallel, + forward_fn=forward_fn, + hidden_size=self.experts.hidden_size, + mlp_impl=mlp_impl, + ) + return output, expert_weights_out + + +class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): + + def __init__(self): + super().__init__() + # Shared expert weights will be set by the user + self.shared_up_proj_weight = None + self.shared_down_proj_weight = None + self.shared_up_proj_bias = None + self.shared_down_proj_bias = None + self.shared_expert_weighted_sum = False + self.shared_activation_fn = None + + def set_shared_expert_weights( + self, + up_proj_weight: torch.Tensor, + down_proj_weight: torch.Tensor, + up_proj_bias: Optional[torch.Tensor] = None, + down_proj_bias: Optional[torch.Tensor] = None, + weighted_sum: bool = False, + activation_fn: Optional[Any] = None, + ): + self.shared_up_proj_weight = up_proj_weight + self.shared_down_proj_weight = down_proj_weight + self.shared_up_proj_bias = up_proj_bias + self.shared_down_proj_bias = down_proj_bias + self.shared_expert_weighted_sum = weighted_sum + self.shared_activation_fn = activation_fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + moe_top_k = getattr(self.router, "top_k", 4) + moe_num_experts = getattr(self.experts, "num_experts", 128) + gradient_scale = getattr(self.experts, "gradient_scale", None) + alpha = 
getattr(self.experts, "alpha", 1.0) + moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) + moe_jitter_eps = getattr(self.experts, "jitter_eps", None) + moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) + uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) + + expert_parallel_group = getattr(self, "expert_parallel_group", None) + if expert_parallel_group is None: + device_mesh = get_device_mesh(self) + expert_parallel_group = device_mesh.get_group() if device_mesh else None + + has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 + forward_fn = parallel_forward_once if has_parallel else forward_once + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + mlp_impl = getattr(self, "mlp_impl", "grouped") + + output, expert_weights_out, *_ = moe_forward_with_shared_expert( + x=x, + router_weight=self.router.weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=moe_jitter_eps, + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + training=self.training, + w1=self.experts.gate_up_proj, + w2=self.experts.down_proj, + w1_bias=self.experts.gate_up_proj_bias, + w2_bias=self.experts.down_proj_bias, + gradient_scale=gradient_scale, + alpha=alpha, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=moe_capacity_factor, + moe_expert_model_parallelism=has_parallel, + forward_fn=forward_fn, + hidden_size=self.experts.hidden_size, + mlp_impl=mlp_impl, + # Shared expert parameters + shared_up_proj_weight=self.shared_up_proj_weight, + shared_down_proj_weight=self.shared_down_proj_weight, + shared_up_proj_bias=self.shared_up_proj_bias, + shared_down_proj_bias=self.shared_down_proj_bias, + shared_expert_weighted_sum=self.shared_expert_weighted_sum, + shared_activation_fn=self.shared_activation_fn, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100755 index 
0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,63 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from .._layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. + } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100755 index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for binned_gather kernel. 
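+# Illustrative usage sketch (added commentary, not part of the upstream file).
+# Routing metadata comes from the sibling megablocks.ops helpers; shapes are
+# assumptions inferred from the callers: `x` is [num_tokens, hidden] and the
+# result is a dense [num_experts, bin_size, hidden] tensor of capacity slots.
+#
+#   top_expert = torch.randint(0, ne, (n,), device='cuda').int()
+#   _, indices = ops.sort(top_expert)
+#   bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)
+#   out = binned_gather(x, indices, bins, expert_capacity, top_k)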
+class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100755 index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. + ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/cumsum.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/cumsum.py new file mode 100755 index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. 
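+# Illustrative note (added commentary): assuming the kernels follow the usual
+# definitions, for a 1D int32 tensor t = [1, 2, 3]:
+#   inclusive_cumsum(t, 0) -> [1, 3, 6]
+#   exclusive_cumsum(t, 0) -> [0, 1, 3]
+# i.e. the exclusive variant is the inclusive result shifted right by one with
+# a leading zero, which is what the downstream bin-offset computations rely on.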
+class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/gather.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/gather.py new file mode 100755 index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for gather kernel. +class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram.py new file mode 100755 index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. 
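+# Illustrative note (added commentary): for non-negative integer input this
+# mirrors torch.histc(x, bins=max_val, min=0, max=max_val - 1), e.g.
+# x = [0, 2, 2, 3] with max_val = 4 -> [1, 0, 2, 1]. It is used to count how
+# many tokens were routed to each expert.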
+class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from .. import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,415 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + +import torch +from absl.testing import parameterized + +from .. import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. 
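+# Illustrative note (added commentary): for a contiguous x of shape (m, n)
+# with strides (n, 1), transpose_view(x) has shape (n, m) and strides (1, n),
+# the same metadata x.t() would produce, built with a single as_strided call.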
+def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + 
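+    # Added commentary: the `flops` value passed to log_benchmark above is the
+    # standard 2*M*K*N dense-matmul count, written as x.numel() * fhs * 2
+    # since x.numel() == padded_tokens * hs. Assuming benchmark_util reports
+    # `time` in milliseconds, flops / 1e9 / time then comes out in TFLOP/s.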
@parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 
'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100755 index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for padded_gather kernel. 
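+# Illustrative usage sketch (added commentary, not part of the upstream file);
+# the routing metadata mirrors the benchmarks elsewhere in this package:
+#
+#   bin_ids, indices = ops.sort(top_expert)
+#   tokens_per_expert = ops.histogram(top_expert, ne)
+#   padded_bins = ops.inclusive_cumsum(ops.round_up(tokens_per_expert, 128), 0)
+#   bins = ops.inclusive_cumsum(tokens_per_expert, 0)
+#   out = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)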
+class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100755 index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for padded_scatter kernel. +class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from .. 
import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from .. import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. 
+ x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. 
+ x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/repeat.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/repeat.py new file mode 100755 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/replicate.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/replicate.py new file mode 100755 index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for replicate kernel. +class ReplicateOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): + ctx.save_for_backward(bins) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) + ops.replicate_forward(x, bins, out) + return out + + @staticmethod + def backward(ctx: Any, grad: torch.Tensor): + bins, = ctx.saved_tensors + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) + ops.replicate_backward(grad, bins, out) + return out, None, None + + +replicate = ReplicateOp.apply diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/round_up.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/round_up.py new file mode 100755 index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/round_up.py @@ -0,0 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def round_up(x: torch.Tensor, value: int): + assert isinstance(value, int) + assert x.dtype == torch.int32 + + # TODO(tgale): If this becomes and issue + # do this in a custom kernel. We only expect + # to use this on arrays of less than 1k elements. 
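+    # Added commentary: the expression below is an integer ceil-to-multiple,
+    # ceil(x / value) * value, e.g.
+    # round_up(torch.tensor([1, 128, 129], dtype=torch.int32), 128)
+    # -> tensor([128, 128, 256]).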
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/scatter.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/scatter.py new file mode 100755 index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/scatter.py @@ -0,0 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for scatter kernel. +class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/sort.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/sort.py new file mode 100755 index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
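+# Illustrative note (added commentary): this behaves like a radix argsort that
+# also returns the sorted keys, e.g. x = [3, 1, 2, 1] (int32) gives
+# x_out = [1, 1, 2, 3] and iota_out = [1, 3, 2, 0] (tie order depends on the
+# underlying radix sort). `end_bit` bounds how many key bits are inspected;
+# callers routing across num_experts experts pass roughly ceil(log2(num_experts)).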
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from .. import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100755 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return 
{_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/sum.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/sum.py new file mode 100755 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/topology.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/topology.py new file mode 100755 index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . 
import ops +from .matrix import Matrix diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100755 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100755 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) 
and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, 
+ row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100755 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), 
dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + 
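+    # Each program instance computes one BLOCK_M x BLOCK_N tile of the dense
+    # output C = A (dense) @ B (block-sparse, BCSR): `offsets` bounds the
+    # sparse blocks that contribute to output block column pid_n, and
+    # `column_indices` maps each of those blocks to its position along K so
+    # the matching slice of A can be loaded and accumulated with tl.dot.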
# matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, 
stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/matrix.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/matrix.py new file mode 100755 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). 
" + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
+ """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. + self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return 
self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. + size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). 
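+
+    Example (illustrative; assumes a CUDA device is available):
+        >>> a = stk.random.randn((256, 256), sparsity=0.5, blocking=128).cuda()
+        >>> b = stk.ops.ones_like(a)
+        >>> c = stk.ops.mul(a, b)  # same topology as a, entries equal a.data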
+ """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = _dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. 
+ grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
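+# Each size entry is (m, k, n, sparsity); the transpose flags, blocking (128)
+# and dtype are appended in _generate_testcases().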
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
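+# Given an [nnz, 2] tensor of block-level (row, col) coordinates and a block
+# size, this returns the [nnz * blocking**2, 2] element-level coordinates
+# covering every entry of those blocks (used by to_dense/to_sparse below to
+# scatter and gather block data).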
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
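+    # nonzero_indices now addresses every element of every retained block in
+    # the flattened dense matrix, so a single gather pulls the block data out,
+    # which is then reshaped to [nnz, blocking, blocking].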
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
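+        # Exact equality holds: masked-out blocks are exactly zero in x, and
+        # kept blocks are copied bit-for-bit through to_sparse/to_dense.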
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..408e8efd26190a8e433de4c3741315f63e830e65 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/__init__.py @@ -0,0 +1,202 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from ._ops import ops + +#from .grouped_gemm import backend as gg_backend +#from .grouped_gemm import ops as gg_ops + + +from ._layers.arguments import Arguments +from ._layers.dmoe import ParallelDroplessMLP, dMoE +from ._layers.glu import SparseGLU +from ._layers.mlp import MLP, SparseMLP +from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +from . import layers + +# This section contains the direct kernel exports (not inlcuded in the original code) +def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute exclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.exclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor: + """ + Compute inclusive cumulative sum along the specified dimension. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + result = ops.inclusive_cumsum(x, dim) + out.copy_(result) + return out + + +def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor: + """ + Compute histogram of input tensor values. + + Args: + x: Input tensor + num_bins: Number of histogram bins + + Returns: + Histogram tensor with counts for each bin + """ + return ops.histogram(x, num_bins) + + +def indices( + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, +) -> torch.Tensor: + """ + Construct indices from padded bins for sparse operations. + + Args: + padded_bins: Tensor containing bin boundaries + block_size: Size of each block + output_block_rows: Number of rows in output blocks + output_block_columns: Number of columns in output blocks + + Returns: + Tensor containing constructed indices + """ + return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns) + + +def replicate_forward( + x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Forward pass of replicate operation - replicate values according to bin sizes. 
+ + Args: + x: Input tensor with values to replicate + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_forward(x, bins, out) + + +def replicate_backward( + grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor +) -> torch.Tensor: + """ + Backward pass of replicate operation - reduce gradients back to bins. + + Args: + grad: Gradient tensor to reduce + bins: Tensor containing bin sizes + out: Output tensor (modified in-place) + + Returns: + The output tensor + """ + return ops.replicate_backward(grad, bins, out) + + +def sort( + x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor +) -> torch.Tensor: + """ + Radix sort with index tracking. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + x_out: Output tensor for sorted values + iota_out: Output tensor for sorted indices + + Returns: + The sorted values tensor + """ + return ops.sort(x, end_bit, x_out, iota_out) + + +# Convenience functions for common use cases +def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor: + """ + Compute cumulative sum with automatic output allocation. + + Args: + x: Input tensor + dim: Dimension along which to compute cumsum (default: last dimension) + exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum + + Returns: + New tensor containing the cumulative sum + """ + out = torch.empty_like(x) + if exclusive: + return exclusive_cumsum(x, dim, out) + else: + return inclusive_cumsum(x, dim, out) + + +def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]: + """ + Sort tensor and return both sorted values and indices. + + Args: + x: Input tensor to sort + end_bit: Number of bits to consider in sorting + + Returns: + Tuple of (sorted_values, sorted_indices) + """ + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + sort(x, end_bit, x_out, iota_out) + return x_out, iota_out + + +# Export public API +__all__ = [ + "MyReplacementLayer", + # Direct kernel exports + "exclusive_cumsum", + "inclusive_cumsum", + "histogram", + "indices", + "replicate_forward", + "replicate_backward", + "sort", + "cumsum", + "argsort", + # Original exports + "Arguments", + "ParallelDroplessMLP", + "dMoE", + "SparseGLU", + "MLP", + "SparseMLP", + "MoE", + "ParallelMLP", + "get_load_balancing_loss", +] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a720e7a2cc4e44636f6e433a2750e945dc38e8b2 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/__init__.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# from megablocks.layers.dmoe import dMoE +from .moe import MoE + +__all__ = [ + 'MoE', + # 'dMoE', +] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/activation_fn.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/activation_fn.py new file mode 100755 index 0000000000000000000000000000000000000000..0e1d956704840aa4daf7d1d71d24e051567feab9 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/activation_fn.py @@ -0,0 +1,33 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union + +import torch +from ..stk import Matrix + + +def act_fn( + x: Matrix, 
+ function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) + with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): + if return_grad_fn: + x.data.requires_grad = True + out = function(x.data, **kwargs) + y = Matrix( + x.size(), + out, + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + if return_grad_fn: + return y, out.backward + return y diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/all_to_all.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/all_to_all.py new file mode 100755 index 0000000000000000000000000000000000000000..5ac7067bcaa34db1d82b340c43550fe3577aa7a3 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/all_to_all.py @@ -0,0 +1,54 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + + +class AllToAllOp(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) + + ctx.input_shape = x.shape + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + ctx.group = group + handle = dist.all_to_all_single( + out, + x, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + async_op=async_op, + ) + return out, handle + + @staticmethod + def backward(ctx, grad, _): + if ctx.needs_input_grad[0]: + out = torch.empty( + ctx.input_shape, + device=grad.device, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, + output_split_sizes=ctx.input_split_sizes, + input_split_sizes=ctx.output_split_sizes, + group=ctx.group, + ) + return out, None, None, None, None + return None, None, None, None, None + + +def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): + return AllToAllOp.apply( + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/arguments.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/arguments.py new file mode 100755 index 0000000000000000000000000000000000000000..4db9b1bd38bc2e2f421625c124f86b85f45c5ae0 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/arguments.py @@ -0,0 +1,101 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import dataclasses +from functools import partial +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn.functional as F + +# import megablocks.grouped_gemm_util as grouped_gemm +from .. import grouped_gemm_util as grouped_gemm + +# Type annotation for in-place Tensor initialization function. +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] + +_ALLOWED_BITWIDTHS = (-1, 4, 8) + +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') + + +@dataclasses.dataclass +class Arguments: + # Model arguments. + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN + + # MoE arguments. 
+ moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False + + # Parallelism arguments. + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None + + # Compute arguments. + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' + + # Initialization arguments. + fp16: bool = True + bf16: bool = False + device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device) + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method + + # Benchmarking arguments. + uniform_expert_assignment: bool = False + + # shared expert arguments + shared_expert: bool = False # enable using shared expert + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers + remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) + + # Router Z-loss arguments + moe_zloss_weight: float = 0 # 1e-3 is a reasonable value + moe_zloss_in_fp32: bool = False + + def __post_init__(self): + # Sparse MLP is not supported with triton >=3.2.0 + # TODO: Remove this once sparse is supported with triton >=3.2.0 + if self.__getattribute__('mlp_impl') == 'sparse': + try: + import triton + if triton.__version__ >= '3.2.0': + raise ValueError( + 'Sparse MLP is not supported with triton >=3.2.0. 
Please use mlp_impl="grouped" instead.', + ) + except ImportError: + raise ImportError('Triton is required for sparse MLP implementation') + + if self.__getattribute__('mlp_impl') == 'grouped': + grouped_gemm.assert_grouped_gemm_is_available() + + if self.shared_expert_hidden_size is None: + self.shared_expert_hidden_size = self.ffn_hidden_size + + +def from_megatron(megatron_args: Any): + args = Arguments() + for field in dataclasses.fields(args): + if hasattr(megatron_args, field.name): + setattr(args, field.name, getattr(megatron_args, field.name)) + return args diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/common.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/common.py new file mode 100755 index 0000000000000000000000000000000000000000..2d07109702963ba48a3b94ab860807954dfd79c1 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/common.py @@ -0,0 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from .arguments import Arguments + + +def dtype(args: Arguments): + if args.fp16: + return torch.float16 + elif args.bf16: + return torch.bfloat16 + return None + + +def cast_if_autocast_enabled(tensor): + if torch.is_autocast_enabled(): + if tensor.device.type == 'cuda': + dtype = torch.get_autocast_gpu_dtype() + elif tensor.device.type == 'cpu': + dtype = torch.get_autocast_cpu_dtype() + else: + raise NotImplementedError() + return tensor.to(dtype=dtype) + return tensor diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/dmlp_registry.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/dmlp_registry.py new file mode 100755 index 0000000000000000000000000000000000000000..de2ed047042e438c7190ebb139b6f7f30009734c --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/dmlp_registry.py @@ -0,0 +1,42 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +from . import glu, mlp +from .arguments import Arguments + +MlpType = Union[mlp.SparseMLP, glu.SparseGLU] + +_REGISTRY = { + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, +} + + +def get(args: Arguments) -> MlpType: + """Returns an MLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + MLP instance. This only contains MLPs for use in dMoEs + (ie. only for the dropless versions of MoEs). + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated MLP constructed using the input args. 
+ """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + if args.mlp_impl not in _REGISTRY[args.mlp_type]: + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) + + return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/dmoe.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/dmoe.py new file mode 100755 index 0000000000000000000000000000000000000000..6d0375a4df2f27134c4127e60be04f3b45693050 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/dmoe.py @@ -0,0 +1,337 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.', +# ) + +# import megablocks.ops as ops +# # from megablocks.ops import ops +# from megablocks.layers import common, dmlp_registry, moe, mpu +# from megablocks.layers.arguments import Arguments + +from .. import stk +from .. import ops +from . import common, dmlp_registry, moe, mpu +from .arguments import Arguments + +def promote_scalar(x): + return x.view(1) if not len(x.size()) else x + + +class ParallelDroplessMLP(moe.ParallelMLP): + + def __init__(self, args: Arguments): + super(ParallelDroplessMLP, self).__init__(args) + self.hidden_size = args.hidden_size + self.ffn_hidden_size = mpu.features_per_rank(args) + self.blocking = 128 + self.mlp = dmlp_registry.get(args) + + # Calculate the number of bits needed to represent the column indices + # in the intermediate sparse matrix. + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) + self.transpose_sort_end_bit = max( + int(np.ceil(np.log2(max_column_index))), + 1, + ) + + def sparse_transpose(self, size, row_indices, column_indices, offsets): + block_columns = size[1] // self.blocking + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + # + # NOTE: Our sort operation uses the same width indices as the input values. + # To avoid overflow when we have large activation matrices we cast to + # 32-bit before sorting. + _, gather_indices = ops.sort( + column_indices.int(), + self.transpose_sort_end_bit, + ) + + # There are a constant number of blocks in every row of the sparse matrix. + # A blocks offset is: + # + # row_index * blocks_per_row + column_index % blocks_per_row + # + # Once we have the block offsets ordered for transposition we can divide + # by blocks_per_row to get the transposed column indices. + column_indices_t = row_indices.gather(0, gather_indices.long()) + block_offsets_t = gather_indices.int() + + zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device) + nnz_per_column = ops.histogram(column_indices, block_columns) + nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0) + if nnz_per_column.dim() == 0: + # This addresses an edge case when ffn_hidden_size is equal to self.blocking. + nnz_per_column = nnz_per_column.unsqueeze(0) + offsets_t = torch.cat([zero, nnz_per_column]) + return column_indices_t, offsets_t, block_offsets_t + + def topology(self, x, padded_bins): + padded_tokens, _ = x.size() + assert padded_tokens % self.blocking == 0 + if self.ffn_hidden_size % self.blocking != 0: + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. 
Please update your configuration.', + ) + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // self.blocking + blocks_per_row = self.ffn_hidden_size // self.blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) + + # TODO(tgale): This is unused. Remove the need for this in stk. + # For now, use meta init to save the device memory. + data = torch.empty( + column_indices.numel(), + self.blocking, + self.blocking, + dtype=common.dtype(self.args), + device='meta', + ) + shape = ( + padded_tokens, + self.ffn_hidden_size * mpu.experts_per_rank(self.args), + ) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) + + def indices_and_padded_bins(self, top_experts): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + top_experts = top_experts.int() + bin_ids, indices = ops.sort(top_experts, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + tokens_per_expert = ops.histogram(top_experts, self.num_experts) + + # Round the token counts up to the block size used in + # the matrix muliplications. Caculate the starting + # position of each bin. + padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = promote_scalar(bins) + return indices, bin_ids, bins, padded_bins, tokens_per_expert + + def sparse_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather( + x, + indices, + bin_ids, + bins, + padded_bins, + self.top_k, + ) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + x = ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + self.top_k, + ) + return x, tokens_per_expert + + # For use in the base-class parallel_forward_once. + def sparse_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Round the token counts up to the block size used in the matrix + # multiplication. Calculate the starting position of each bin. 
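+ # For example (illustrative numbers): with blocking=128 and
+ # tokens_per_expert=[3, 200, 0], round_up gives [128, 256, 0] and the
+ # inclusive cumsum yields padded_bins=[128, 384, 384], so each expert's
+ # slice of the padded activations starts on a block boundary.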
+ padded_tokens_per_expert = ops.round_up( + tokens_per_expert, + self.blocking, + ) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + padded_bins = promote_scalar(padded_bins) + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + # Create the sparse matrix topology. + with torch.no_grad(): + topo = self.topology(x, padded_bins) + + # Perform the expert computation. + x = self.mlp(x, topo) + + # Un-route the data for the MoE output. + return ops.padded_scatter( + x, + indices, + bin_ids, + expert_weights, + bins, + padded_bins, + top_k, + ) + + def grouped_forward_once(self, x, expert_weights, top_experts): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + out = self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + -1, # unused + self.args.moe_top_k, + ) + return out, tokens_per_expert + + def grouped_permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): + + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Perform the expert computation. + x = self.mlp(x, tokens_per_expert) + + # Un-route the data for the MoE output. + return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + def forward_once(self, x, expert_weights, top_experts): + if self.args.mlp_impl == 'sparse': + return self.sparse_forward_once(x, expert_weights, top_experts) + else: + return self.grouped_forward_once(x, expert_weights, top_experts) + + def permute_and_compute( + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): + if self.args.mlp_impl == 'sparse': + return self.sparse_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + else: + return self.grouped_permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ) + + +class dMoE(moe.MoE): + + def _init_experts_mlp(self, args: Arguments): + return ParallelDroplessMLP(args) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/gelu.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/gelu.py new file mode 100755 index 0000000000000000000000000000000000000000..c4c9e6532798615b5c12c96694241a4c18ee8f7b --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/gelu.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.', +# ) + +from .. import stk + +import torch +import torch.nn.functional as F + + +@torch.jit.script +def _gelu_backward_inplace(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) + return g.mul_(ff) + + +def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): + # NOTE: The two sparse matrices must have the same topology. 
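+ # That is, `grad` and `x` must share row_indices, column_indices and
+ # offsets. The elementwise backward only touches the stored nonzero
+ # blocks and multiplies into grad.data in place.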
+ if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix): + return stk.Matrix( + x.size(), + _gelu_backward_inplace(grad.data, x.data), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) + return _gelu_backward_inplace(grad, x) + + +def gelu(x: stk.Matrix): + assert isinstance(x, stk.Matrix) + return stk.Matrix( + x.size(), + F.gelu(x.data, approximate='tanh'), + x.row_indices, + x.column_indices, + x.offsets, + x.column_indices_t, + x.offsets_t, + x.block_offsets_t, + ) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/glu.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/glu.py new file mode 100755 index 0000000000000000000000000000000000000000..5f297a41ff6a1a2a285f5b461951672364b898da --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/glu.py @@ -0,0 +1,244 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# import stk.ops +# try: +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.', +# ) + +from .. import stk + +import torch + +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import Arguments +# from megablocks.layers.mlp import ( +# SharedMLP, +# SparseMLP, +# create_dmoe_expert_weights, +# resolve_dtensor, +# ) + +from .. import grouped_gemm_util as gg +from . import common, mpu +from .activation_fn import act_fn +from .arguments import Arguments +from .mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + + +class SparseGLU(SparseMLP): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + with torch.no_grad(): + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + + mpu.set_expert_model_parallel_attributes( + self.v1, + self._should_set_parallelism_attribute, + ) + + def forward(self, x, topo): + if self.args.memory_optimized_mlp: + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) + + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Compute the GLU. 
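+ # y = (activation_fn(x @ w1.t()) * (x @ v1.t())) @ w2, where both
+ # intermediate products are block-sparse matrices with topology `topo`.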
+ x1 = stk.ops.sdd(x, w1.t(), topo) + x2 = stk.ops.sdd(x, v1.t(), topo) + + activation_fn_out = act_fn(x1, self.args.activation_fn) + x1 = stk.ops.mul(activation_fn_out, x2) + + return stk.ops.dsd(x1, w2) + + +class MemoryOptimizedGroupedGLU(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + v1 = v1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) + + # GeLU. + activation_fn_out = activation_fn(sdd_out) * v1_out + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, v1, w2 = saved_tensors[:3] + batch_sizes = saved_tensors[3] + x = saved_tensors[4] + sdd_out, v1_out = saved_tensors[5:7] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + dv1_out = v1_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dv1. + dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. 
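+ # dx = dsdd_out @ w1 + dv1_out @ v1. The first grouped GEMM writes its
+ # result into the storage of ddsd_out via the `c=` argument and the
+ # second term is accumulated on top with `+=`.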
+ dx = ddsd_out + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx) + dx += gg.backend.gmm(dv1_out, v1, batch_sizes) + return dx, dw1, dv1, dw2, None, None + + +memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply + + +class GroupedGLU(SparseGLU): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) + + # Re-shape the weights for the grouped GEMMs. + ne = mpu.experts_per_rank(self.args) + w1 = w1.view(ne, -1, self.args.hidden_size) + v1 = v1.view(ne, -1, self.args.hidden_size) + w2 = w2.view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_glu( + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) + x1 = self.args.activation_fn(x1) * x2 + return gg.ops.gmm(x1, w2, batch_sizes) + + +class SharedGLU(SharedMLP): + """GPU for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class + """ + + def __init__(self, args: Arguments): + super().__init__(args) + self.gate_proj = args.fc_cls( + args.hidden_size, + self.args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x)) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/memory_test.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/memory_test.py new file mode 100755 index 0000000000000000000000000000000000000000..74d1166931b712635131985b25a89f4ca23e576d --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/memory_test.py @@ -0,0 +1,103 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import gc + +import torch +import torch.distributed as dist + +# from megablocks.layers import arguments, dmoe +from . import arguments, dmoe + +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) + + +def get_tensors(): + ptrs = set() + out = [] + for obj in gc.get_objects(): + if torch.is_tensor(obj): + if not obj.is_contiguous() or obj.data_ptr() in ptrs: + continue + out.append(obj) + ptrs.add(obj.data_ptr()) + return out + + +def test_memory( + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): + args = arguments.Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=num_experts, + moe_top_k=top_k, + moe_expert_model_parallelism=True, + expert_parallel_group=group, + fp16=False, + bf16=True, + device=torch.cuda.current_device(), + ) + layer = dmoe.dMoE(args).cuda() + + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) + torch.cuda.empty_cache() + + # Run forward + backward. + # with torch.autograd.detect_anomaly(): + out, _ = layer(x) + out.mean().backward() + + # Report peak memory. + mem = torch.cuda.max_memory_allocated() + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) + + # Calculate weight and gradient memory usage. 
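+ # The factor of 2 below assumes 2 bytes per element, matching the
+ # bf16=True configuration used for this benchmark.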
+ weight_memory = 2 * ( + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) + + def grad_numel(x): + if x.grad is not None: + return x.grad.numel() + return 0 + + grad_memory = 2 * ( + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) + weight_memory += grad_memory + + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) + + # Manually calculate GPU memory usage from the garbage + # collector. + gc.collect() + total = 0 + tensors = get_tensors() + tensors = sorted(tensors, key=lambda x: -x.numel()) + for i, t in enumerate(tensors): + total += t.numel() + print(f'{i}: {t.shape}, {t.numel() * 2}') + del tensors + + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _TESTS: + test_memory(group, *args) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/mlp.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/mlp.py new file mode 100755 index 0000000000000000000000000000000000000000..c99afb9904c24a8b6a83e79059cd1251dbbfd99e --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/mlp.py @@ -0,0 +1,587 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# try: +# import stk +# import stk.backend.triton_kernels +# import stk.ops +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.', +# ) + +from .. import stk + +import torch +from packaging import version + +# from megablocks import grouped_gemm_util as gg +# from megablocks.layers import common, gelu, mpu +# from megablocks.layers.activation_fn import act_fn +# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +from .. import grouped_gemm_util as gg +from . import common, gelu, mpu +from .activation_fn import act_fn +from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn + +class ScaleGradient(torch.autograd.Function): + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): + ctx.scale = scale + return x + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): + return grad * ctx.scale, None + + +scale_gradient = ScaleGradient.apply + + +def resolve_dtensor(weight: torch.Tensor): + if version.parse(torch.__version__) >= version.parse('2.0.0'): + from torch.distributed._tensor import DTensor + if isinstance(weight, DTensor): + return weight.to_local() + return weight + + +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): + # Create the entire weight matrix such that the sampled weights will + # not vary between data parallelism and expert model parallelism for + # the same random seed. 
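+ # Concretely (illustrative numbers): with 8 experts,
+ # expert_sharding_degree=4 and hidden_sharding_degree=2, rank 5 keeps
+ # experts [2, 4) and rows [ffn_hidden_size // 2, ffn_hidden_size) of
+ # the master weights created below.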
+ master_weights = torch.empty( + num_experts, + ffn_hidden_size, + hidden_size, + device=args.device, + dtype=common.dtype(args), + ) + init_method(master_weights) + + if not args.moe_expert_model_parallelism: + return master_weights + + # Calculate the amount of sharding in each dimension. + expert_sharding_degree = mpu.expert_sharding_degree(args) + hidden_sharding_degree = mpu.hidden_sharding_degree(args) + + # Calculate the experts per rank. + # + # NOTE: We assign ranks to be expert parallel before going + # tensor parallel. + rank = mpu.get_expert_parallel_rank(args) + expert_rank = rank % expert_sharding_degree + num_experts_per_rank = num_experts // expert_sharding_degree + start_expert = expert_rank * num_experts_per_rank + end_expert = (expert_rank + 1) * num_experts_per_rank + + # Calculate the rows per rank. + row_rank = rank // expert_sharding_degree + num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree + start_row = row_rank * num_rows_per_rank + end_row = (row_rank + 1) * num_rows_per_rank + + # Slice the weight matrix to get the chunk for this rank. + with torch.no_grad(): + weights = master_weights[start_expert:end_expert, start_row:end_row] + return weights + + +class MLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + experts_per_rank = mpu.experts_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + mpu.set_expert_model_parallel_attributes( + self.w1, + args.moe_expert_model_parallelism, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + args.moe_expert_model_parallelism, + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. 
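+ # The copy below stores w1 transposed, as
+ # [experts, hidden_size, features_per_rank], so forward() can apply it
+ # directly with torch.bmm(x, w1).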
+ with torch.no_grad(): + w1 = create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) + self.w1.copy_(w1.transpose(1, 2).contiguous()) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + x = torch.bmm(x, w1) + x = self.args.activation_fn(x) + return torch.bmm(x, w2) + + +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): + weights = create_moe_expert_weights( + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) + + +class MemoryOptimizedMLP(torch.autograd.Function): + """Sparse MLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, topo, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) + + # Layer 0: x @ w1.t(). + sdd_out = stk.ops.sdd(x, w1.t(), topo) + + # GeLU. + activation_fn_out = act_fn(sdd_out, activation_fn) + + # Layer 1: x @ w2. + dsd_out = stk.ops.dsd(activation_fn_out, w2) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.shape = topo.shape + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.data.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx, ddsd_out): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + topo_tensors = saved_tensors[2:8] + x = saved_tensors[8] + sdd_out_data = saved_tensors[9] + + # rematerialize activation function output + activation_fn = ctx.activation_fn + sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) + + # Compute dw2 with recomputed activation_fn output. + dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. 
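+ # The SDD kernel below computes ddsd_out @ w2.t() restricted to the
+ # sparse topology of activation_fn_out, writing the result into its
+ # existing data buffer.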
+ dactivation_fn_out = activation_fn_out + stk.backend.triton_kernels.sdd( + ddsd_out, + w2.t(), + dactivation_fn_out.shape, + dactivation_fn_out.data, + dactivation_fn_out.offsets, + dactivation_fn_out.row_indices, + dactivation_fn_out.column_indices, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out.data) + dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors) + + # Compute dw1. + dw1 = stk.ops.dsd(dsdd_out.t(), x) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + stk.backend.triton_kernels.dsd( + dsdd_out.shape, + dsdd_out.data, + dsdd_out.offsets, + dsdd_out.row_indices, + dsdd_out.column_indices, + dsdd_out.offsets_t, + dsdd_out.column_indices_t, + dsdd_out.block_offsets_t, + False, + w1, + ddsd_out, + ) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_mlp = MemoryOptimizedMLP.apply + + +class SparseMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + + # Initialize the parameters for the MLP. + # + # NOTE: It is important that we create the weight tensors prior + # to creating the master weights and slicing our the piece for + # this rank. If the master weights are created first the PyTorch + # caching allocator appears to use the same memory block for these + # and the slice which causes large increases in our peak memory + # usage. + with torch.no_grad(): + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism + mpu.set_expert_model_parallel_attributes( + self.w1, + self._should_set_parallelism_attribute, + ) + mpu.set_expert_model_parallel_attributes( + self.w2, + self._should_set_parallelism_attribute, + ) + + self.gradient_scale = None + if self.args.moe_expert_model_parallelism: + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) + + def scale_grad(self, w): + if self.gradient_scale is None: + return w + return scale_gradient(w, self.gradient_scale) + + def forward(self, x, topo): + w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) + w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) + if self.args.memory_optimized_mlp: + return memory_optimized_mlp( + x, + w1, + w2, + topo, + self.args.activation_fn, + ) + + # Compute the MLP. 
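+ # y = activation_fn(x @ w1.t()) @ w2, with the intermediate held as a
+ # block-sparse matrix whose topology is given by `topo`.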
+ x = stk.ops.sdd(x, w1.t(), topo) + activation_fn_out = act_fn(x, self.args.activation_fn) + return stk.ops.dsd(activation_fn_out, w2) + + +class MemoryOptimizedGroupedMLP(torch.autograd.Function): + """GroupedMLP with manually scheduled memory reuse.""" + + @staticmethod + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx, x, w1, w2, batch_sizes, activation_fn): + # Cast inputs using ctx dtype from AMP + if ctx._fwd_used_autocast: + x = x.to(ctx._dtype) + w1 = w1.to(ctx._dtype) + w2 = w2.to(ctx._dtype) + # x: [m, k], w1: [n, k], w2: [n, k] + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): + raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") + + # Layer 0: x @ w1.t(). + assert gg.backend is not None + sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) + + # activation_fn + activation_fn_out = activation_fn(sdd_out) + + # Layer 1: x @ w2. + dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes) + + # NOTE: Save the input to the layer and the activation_fn input for + # gradient computation. We'll re-compute the activation_fn forward + # pass in the backward pass to avoid materializing another + # intermediate. + ctx.x_shape = x.shape + ctx.sdd_out_shape = sdd_out.shape + ctx.dtype = x.dtype + ctx.activation_fn = activation_fn + ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out) + return dsd_out + + @staticmethod + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') + + # Unpack saved tensors + # dtype = ctx.dtype + saved_tensors = ctx.saved_tensors + w1, w2 = saved_tensors[:2] + batch_sizes = saved_tensors[2] + x = saved_tensors[3] + sdd_out = saved_tensors[4] + + # Rematerialize activation_fn output. + activation_fn = ctx.activation_fn + with torch.set_grad_enabled(True): + sdd_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) + activation_grad_fn = activation_fn_out.backward + + # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None + dw2 = gg.backend.gmm( + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) + + # Compute dactivation_fn_out. + # + # NOTE: We reuse the activation_fn_out allocation. + dactivation_fn_out = activation_fn_out + gg.backend.gmm( + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) + + # Compute dsdd_out. + # + # NOTE: This reuses the dactivation_fn_out allocation. + if activation_fn is DEFAULT_ACTIVATION_FN: + dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out) + else: + assert activation_grad_fn is not None + activation_grad_fn(dactivation_fn_out) + dsdd_out = sdd_out.grad + + # Compute dw1. + dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True) + + # Compute dx. + # + # NOTE: This reuses the ddsd_out allocation. + gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out) + dx = ddsd_out + return dx, dw1, dw2, None, None + + +memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply + + +class GroupedMLP(SparseMLP): + + def forward(self, x, tokens_per_expert): + batch_sizes = tokens_per_expert.cpu().to(torch.long) + w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) + + # Re-shape the weights for the grouped GEMMs. 
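+ # Each flat [experts_per_rank * features_per_rank, hidden_size]
+ # parameter becomes [experts_per_rank, features_per_rank, hidden_size];
+ # the -1 in each view resolves to features_per_rank.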
+ ne = mpu.experts_per_rank(self.args) + w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) + w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) + + if self.args.memory_optimized_mlp: + return memory_optimized_grouped_mlp( + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) + + # Compute the MLP. + assert gg.ops is not None + x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) + x = self.args.activation_fn(x) + return gg.ops.gmm(x, w2, batch_sizes) + + +class SharedMLP(torch.nn.Module): + """MLP for shared expert. + + Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class + """ + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + self.fc_kwargs: dict[str, Any] = { + 'bias': args.bias, + 'device': args.device, + } + self.fc_kwargs.update(args.fc_kwargs) + + self.up_proj = args.fc_cls( + args.hidden_size, + args.shared_expert_hidden_size, + **self.fc_kwargs, + ) + self.act = args.activation_fn + self.down_proj = args.fc_cls( + args.shared_expert_hidden_size, + args.hidden_size, + **self.fc_kwargs, + ) + self.down_proj._is_residual = True # a flag for llm-foundry init + + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: + # Helper function to add expert output to shared expert output + # with optional weighted sum. + if self.args.shared_expert_weighted_sum: + # enable using weighted sum for shared expert output + # wieghted by number of experts used + t_experts = self.args.moe_top_k + 1 + sh_mlp_out = shared_expert_out / t_experts + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) + + return shared_expert_out + expert_out + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.act(self.up_proj(x))) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/moe.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/moe.py new file mode 100755 index 0000000000000000000000000000000000000000..d0a4aeaacc9c86fc70944e730c53f7a55644e05e --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/moe.py @@ -0,0 +1,507 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist + +# import megablocks.ops as ops +# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +# from megablocks.layers.all_to_all import all_to_all +# from megablocks.layers.arguments import Arguments + +from ..ops import ( + sort, + histogram, + inclusive_cumsum, + exclusive_cumsum, + binned_gather, + binned_scatter, + gather, + scatter, + repeat, + replicate, +) + +from . 
import common, mlp, mpu, router, sharedexpert_registry +from .arguments import Arguments +from .all_to_all import all_to_all + +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args: Arguments): + if args.moe_loss_weight == 0: + return 0.0 + + # tokens_per_expert[i].shape = (num_experts) + # expert_scores[i].shape = (tokens, num_experts) + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) + + tokens = expert_scores[0].shape[0] + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# NOTE: This class defines MoE expert computation, including expert model parallel +# communication. When using FSDP on top of MegaBlocks this is the module that should +# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model +# parallel all2all. +class ParallelMLP(torch.nn.Module): + + def __init__(self, args: Arguments): + super(ParallelMLP, self).__init__() + self.args = args + + # Calculate the number of experts in total and the number of experts + # owned by this rank. 
+ # world_size = mpu.get_expert_parallel_world_size(args) + self.num_experts = args.moe_num_experts + self.top_k = self.args.moe_top_k + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + + # Expert MLP. + self.mlp = mlp.MLP(args) + + self.bias: Optional[torch.Tensor] + if self.args.bias: + # Note that the output bias is not parallelized with expert + # model parallelism. + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + # Select the forward function for the operating mode. + self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) + + def expert_capacity(self, tokens: int) -> int: + world_size = mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) + return int(self.args.moe_capacity_factor * tokens_per_expert) + + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + num_experts, = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + # output = ops.sort(top_expert, self.sort_end_bit) + output = sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + # tokens_per_expert = ops.histogram(top_expert, self.num_experts) + tokens_per_expert = histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + # bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None + bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + + return indices, bin_ids, bins, tokens_per_expert + + def permute_and_compute( + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): + # Route the tokens for MoE computation. 
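+ # binned_gather groups the tokens by expert into a buffer of shape
+ # [num_experts, expert_capacity, hidden_size]; in this capacity-based
+ # (non-dropless) path, tokens beyond an expert's capacity are dropped.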
+ x = x.view(-1, x.shape[-1]) + # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output + + # Perform the expert computation. Note that we don't + # use biases for these linear operations. + x = self.mlp(x) + + # Un-route the data for the MoE output. + # return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return binned_scatter(x, indices, expert_weights, bins, top_k) + + + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + sl, bs, _ = x.size() + expert_capacity = self.expert_capacity(sl * bs) + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = self.permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + self.top_k, + ) + return x, tokens_per_expert + + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so every device gets the counts. + # repeated_tokens_per_expert = ops.repeat( + repeated_tokens_per_expert = repeat( + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=self.args.expert_parallel_group, + async_op=True, + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. 
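+ # After the local `gather` below, rows are ordered by expert id; since each
+ # rank owns a contiguous block of experts, every remote device's tokens end
+ # up in one contiguous slice for the upcoming all_to_all. The row count also
+ # grows to sl * bs * top_k because each token is replicated once per
+ # selected expert.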
+ x = x.view(-1, x.shape[-1]) + # output = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = mpu.get_expert_parallel_world_size(self.args) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # If we're sharding the experts along the hidden dimension + # multiple devices own parts of the same sets of experts. + # Replicate the token counts so devices that share experts + # get all of the tokens assigned to them. + # + # TODO(tgale): Fuse this into the prior, local permutation. + # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1)) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, + recv_counts, + send_counts, + self.args.expert_parallel_group, + async_op=True, + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + # replicate_bins = ops.inclusive_cumsum( + replicate_bins = inclusive_cumsum( + parallel_tokens_per_expert.flatten(), + 0, + ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts * mpu.hidden_sharding_degree(self.args), + dtype=torch.int32, + device=indices.device, + ), + mpu.experts_per_rank(self.args), + ) + # parallel_top_expert = ops.replicate( + parallel_top_expert = replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + # parallel_bin_ids, parallel_indices = ops.sort( + parallel_bin_ids, parallel_indices = sort( + parallel_top_expert, + self.sort_end_bit, + ) + + # Calculate the bins boundaries from the token counts. 
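+ # Summing over the source-device dimension collapses the
+ # (world_size, experts_per_rank) counts into a single total per local
+ # expert, which is what the local binned permutation below needs.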
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, + dtype=torch.int, + ) + # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if self.args.mlp_impl == 'grouped': + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. + parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all( + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) + + # Reduce along the hidden sharding to get the final outputs. + # + # TODO(tgale): Fuse this into the following local permutation. + shape = ( + mpu.hidden_sharding_degree(self.args), + -1, + self.args.hidden_size, + ) + # x = ops.sum(x.view(shape), dim=0) + x = x.view(shape).sum(dim=0) + + # Un-permute locally to setup for the next series of operations. + # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) + return x, tokens_per_expert.flatten() + + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): + in_shape = x.size() + + # Compute the experts. + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + if self.training and self.args.moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, scores)) + x = x.view(in_shape) + if self.bias is not None: + if self.args.return_bias: + return x, self.bias + return x + self.bias + return x + + +class MoE(torch.nn.Module): + + def __init__(self, args: Arguments): + super(MoE, self).__init__() + + # Token router. + self.router = router.LearnedRouter(args) + + # Expert computation helper. + self.experts = self._init_experts_mlp(args) + + self.shared_expert = None + if args.shared_expert: + # SharedExpert computation helper. + self.shared_expert = sharedexpert_registry.get(args) + + def _init_experts_mlp(self, args: Arguments): + return ParallelMLP(args) + + def forward(self, x: torch.Tensor): + # NOTE: If we're going to cast the activations to lower precision + # do it before we permute the tokens to save bandwidth. + x = common.cast_if_autocast_enabled(x) + + # Compute the expert scores and assignments. + scores, expert_weights, top_experts = self.router(x) + + # Compute the experts. 
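+ # `self.experts` consumes the router outputs (the full score matrix plus the
+ # per-token top-k weights and expert indices) and returns the combined
+ # expert output reshaped to the input's [sl, bs, hs] layout, with the output
+ # bias returned separately when `args.return_bias` is set.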
+ out = self.experts(x, scores, expert_weights, top_experts) + if self.shared_expert is not None: + shared_expert_out = self.shared_expert(x) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) + return out diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/mpu.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/mpu.py new file mode 100755 index 0000000000000000000000000000000000000000..434e143ab42bf3f83406d69e9dd1f72777716e22 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/mpu.py @@ -0,0 +1,94 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch +import torch.distributed as dist + +# from megablocks.layers.arguments import Arguments +from .arguments import Arguments + + +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') + + +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + + +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + + +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) + + +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) + + +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 'expert_model_parallel'), + ) + + +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + for i in range(world_size): + dist.barrier(group) + if i == rank: + print(f'rank = {rank}', *x) + + +# Helpers for expert/tensor sharding. +def expert_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = min(world_size, args.moe_num_experts) + + if (args.moe_num_experts % esd) != 0: + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) + return esd + + +def hidden_sharding_degree(args: Arguments) -> int: + world_size = get_expert_parallel_world_size(args) + esd = expert_sharding_degree(args) + hsd = world_size // esd + + if (args.ffn_hidden_size % hsd) != 0: + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. 
'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) + return hsd + + +def experts_per_rank(args: Arguments) -> int: + return args.moe_num_experts // expert_sharding_degree(args) + + +def features_per_rank(args: Arguments) -> int: + return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/router.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/router.py new file mode 100755 index 0000000000000000000000000000000000000000..37cb2782348d62583376f1a183c7ede83601216d --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/router.py @@ -0,0 +1,116 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + +# from megablocks.layers import common +# from megablocks.layers.arguments import Arguments +from . import common +from .arguments import Arguments + +_ROUTER_LOGITS = [] + + +def _save_router_logits(logits: torch.Tensor, args: Arguments): + if args.moe_zloss_weight == 0: + return + global _ROUTER_LOGITS + _ROUTER_LOGITS.append(logits) + + +def clear_router_zloss(): + global _ROUTER_LOGITS + _ROUTER_LOGITS.clear() + + +def batched_router_zloss(args: Arguments): + global _ROUTER_LOGITS + + if args.moe_zloss_weight == 0: + import warnings + warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0') + return 0 + + logits_per_router = _ROUTER_LOGITS + + if args.moe_zloss_in_fp32: + logits_per_router = [logits.float() for logits in logits_per_router] + + unscaled_zloss_per_router = torch.stack([ + torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router + ]) + + return args.moe_zloss_weight * unscaled_zloss_per_router + + +# NOTE: To enable end-to-end benchmarking without convergence we +# support a flag to force the router to assign tokens uniformly +# across the experts. We do this with a custom autograd operation +# so that PyTorch still executes the full set of router operation. +class _UniformExpertAssignment(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, num_experts: int): + out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) + out = torch.remainder(out, num_experts) + return out.view(x.shape) + + +_uniform_expert_assignment = _UniformExpertAssignment.apply + + +class LearnedRouter(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.args = args + + # Learned router parameters. + # + # NOTE: This weight matrix is not parallelized with expert model + # parallelism. Each device needs the entire router weight matrix + # so that it can route its batch of data correctly. 
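+ # Concretely, the router is a single bias-free linear map from hidden_size
+ # to moe_num_experts logits per token, and the same weight is replicated on
+ # every rank.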
+ self.layer = torch.nn.Linear( + args.hidden_size, + args.moe_num_experts, + bias=False, + dtype=common.dtype(args), + device=args.device, + ) + args.init_method(self.layer.weight) + + def jitter(self, x: torch.Tensor): + low: float = 1.0 - self.args.moe_jitter_eps + high: float = 1.0 + self.args.moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return low + noise * (high - low) + + def _top_k(self, scores: torch.Tensor): + if self.args.moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, self.args.moe_top_k, dim=-1) + + def forward(self, x: torch.Tensor): + if self.training and self.args.moe_jitter_eps is not None: + x = x * self.jitter(x) + + logits = self.layer(x.view(-1, x.shape[-1])) + _save_router_logits(logits, self.args) + scores = logits.softmax(dim=-1) + expert_weights, expert_indices = self._top_k(scores) + if self.args.moe_normalize_expert_weights: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + + expert_indices = ( + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices + ) + return scores, expert_weights, expert_indices diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/sharedexpert_registry.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/sharedexpert_registry.py new file mode 100755 index 0000000000000000000000000000000000000000..5840862f88f370ace5fd49bd0612fc98d186cc49 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_layers/sharedexpert_registry.py @@ -0,0 +1,32 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +# from megablocks.layers import glu, mlp +# from megablocks.layers.arguments import Arguments +from . import glu, mlp +from .arguments import Arguments + +_REGISTRY = { + 'mlp': mlp.SharedMLP, + 'glu': glu.SharedGLU, +} + + +def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: + """Returns an SharedMLP for use in a dMoE instance. + + Uses the provided arguments to instantiate the appropriate + SharedMLP instance. + + Args: + args: propagated Arguments dataclass. + + Returns: + An instantiated SharedMLP constructed using the input args. + """ + if args.mlp_type not in _REGISTRY: + raise ValueError(f'Unsupported mlp type: {args.mlp_type}') + + return _REGISTRY[args.mlp_type](args) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..1f33c6c11677f38f19cedc4e356283ab31dc52b8 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:25626aa99a75712594ff18f13a6b029b0814b5fc59ddf26e250a337047509e66 +size 5578408 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..76dc5db49710ad2461c9bb1ba76f3fdb3de9f802 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/_ops.py @@ -0,0 +1,9 @@ +import torch +from . 
import _megablocks_20250730102509 +ops = torch.ops._megablocks_20250730102509 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_megablocks_20250730102509::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/backend/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/backend/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..9d4e43e9b7e6a9a1c3bde2df34914643ca5d8332 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/backend/kernels.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/backend/kernels.py new file mode 100755 index 0000000000000000000000000000000000000000..b584ceede926ca30abef2dec581cb3ff329e8e16 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/backend/kernels.py @@ -0,0 +1,543 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import triton +import triton.language as tl + + +def assert_is_tensor(x, ndim): + if x.ndim != ndim: + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') + + +def assert_is_matrix(x): + assert_is_tensor(x, 2) + + +def assert_is_vector(x): + if x.ndim != 1: + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') + + +def assert_equal(a, b): + if a != b: + raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) + + +# a: (tokens, hidden_size), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy( + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Our index into array 'a'. + index_a = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'b'. + index_b = offset_in_bin + if bin_idx > 0: + index_b += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. 
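+ # When SCALE is set, each copied row is multiplied by its routing weight
+ # (`weights[index_a]`); the multiply below happens in fp32 and the result is
+ # cast back to the output dtype on store.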
+ scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: Because of the padding, the output size is dynamic. + # We load the final padded bin bound to get the output rows. + output_rows = padded_bins[-1].cpu().item() + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def gather(x, indices, bin_ids, weights, bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + assert_equal(bin_ids.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + # NOTE: There is no padding so the output rows equals the + # input rows multiplied by top_k. + output_rows = x.shape[0] * top_k + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + x, + out, + indices, + bin_ids, + weights, + bins, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) + _padded_copy[(indices.shape[0],)]( + out, + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) + + +def scatter(x, indices, bin_ids, weights, bins, top_k): + return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) + + +# x: (tokens, top_k, hidden_size), real +# grad: (tokens, hidden_size), real. +# wgrad: (tokens, top_k), real. +# indices: (tokens * top_k), integer. +# bin_ids: (tokens * top_k), integer. +# bins: (num_experts), integer. +# padded_bins: (num_experts), integer. 
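+# Each program instance of the kernel below reduces one routed row: it
+# accumulates dot(x[row], grad[token]) in fp32 and writes the scalar into
+# wgrad, i.e. the gradient of that row's routing weight.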
+@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _padded_copy_wgrad( + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Our index into 'tokens * top_k'. + index_out = tl.load(indices + tl.program_id(0)) + + # One threadblock per row in 'a'. Array 'b' has greater or equal + # number of rows since they could be padded. + bin_idx = tl.load(bin_ids + tl.program_id(0)) + + # Now we know what bin we're assigned to, but we need to know how + # many threadblocks were assigned to earlier bins so we can offset + # in our bin properly. + offset_in_bin = tl.program_id(0) + if bin_idx > 0: + offset_in_bin -= tl.load(bins + bin_idx - 1) + + # Load the starting index of our bin in array 'x'. + index_x = offset_in_bin + if bin_idx > 0: + index_x += tl.load(padded_bins + bin_idx - 1) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bin_ids) + assert_is_vector(bins) + assert_is_vector(padded_bins) + assert_equal(indices.shape[0], bin_ids.shape[0]) + assert_equal(bins.size(), padded_bins.size()) + + tokens = indices.shape[0] // top_k + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) + _padded_copy_wgrad[(indices.shape[0],)]( + x, + grad, + out, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS=x.shape[1], + TOP_K=top_k, + ) + return out + + +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): + return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy( + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. 
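+ # The launch grid is (num_experts, expert_capacity), so this program handles
+ # capacity slot `entry_idx` of expert `expert_idx` in the dense
+ # (num_experts, expert_capacity, num_columns) buffer.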
+ index_b = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_a = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + # + # If we're going from A to B, divide the input index to copy + # the same input repeatedly. If we're going from B to A we + # need to reduce the result. Using atomics is slow, so we + # do the reduce step in a second kernel. + offset = index_a // TOP_K if A_TO_B else index_a + a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS) + b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + # Load the scale, if requested. + scale = tl.load(weights + index_a) if SCALE else 1 + + # Swap the pointers depending on the direction. + # + # NOTE: We need to zero the output in both directions. + iptr = a if A_TO_B else b + optr = b if A_TO_B else a + + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + x = tl.load(iptr + offsets, mask=mask) + x = x.to(tl.float32) * scale.to(tl.float32) + + tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask) + + offsets += BLOCK_X + + +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): + # Validate the input shapes. + assert_is_matrix(x) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(indices.shape[0], x.shape[0] * top_k) + + if weights is not None: + assert_equal(weights.shape[0], x.shape[0] * top_k) + + num_experts = bins.shape[0] + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + + _binned_copy[(num_experts, expert_capacity)]( + x, + out, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=x.shape[1], + A_TO_B=True, + TOP_K=top_k, + SCALE=weights is not None, + ) + return out + + +def binned_scatter(x, indices, weights, bins, top_k): + # Validate the input shapes. + assert_is_tensor(x, 3) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + if weights is not None: + assert_equal(indices.shape[0], weights.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( + out, + x, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS=hidden_size, + A_TO_B=False, + TOP_K=top_k, + SCALE=weights is not None, + ) + + # Reduce along the top-k dimension, if needed. + return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) + + +# a: (tokens, hidden_size), real. +# b: (num_experts, expert_capacity, num_columns), real. +# indices: (tokens * top_k), integer. +# weights: (tokens * top_k), real. +# bins: (num_experts), integer. 
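+# The kernel below is the wgrad counterpart of `_binned_copy`: it walks the
+# same (expert, capacity-slot) grid and reduces each expert-output row against
+# the corresponding token's upstream gradient to produce that slot's
+# routing-weight gradient.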
+@triton.autotune( + configs=[ + triton.Config({'BLOCK_X': 64}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=2), + triton.Config({'BLOCK_X': 256}, num_warps=2), + triton.Config({'BLOCK_X': 128}, num_warps=4), + triton.Config({'BLOCK_X': 256}, num_warps=4), + ], + key=['NUM_COLUMNS'], +) +@triton.jit +def _binned_copy_wgrad( + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): + # Load our indices into the output. + expert_idx = tl.program_id(0) + entry_idx = tl.program_id(1) + + # Calculate our offset into the output. + index_x = expert_idx * expert_capacity + entry_idx + + # Load the index bounds for our bin and calculate + # the number of tokens assigned to our expert. + start = 0 + if expert_idx > 0: + start = tl.load(bins + expert_idx - 1) + end = tl.load(bins + expert_idx) + num_tokens = end - start + + # Calculate our offset into the input. If we don't + # have an input exit early. + if entry_idx >= num_tokens: + return + index_out = tl.load(indices + start + entry_idx) + + # Offset the input and output pointers. + wgrad += index_out + grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS) + x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS) + offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X) + + acc = tl.zeros((BLOCK_X,), dtype=tl.float32) + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): + mask = offsets < NUM_COLUMNS + data = tl.load(x + offsets, mask=mask).to(tl.float32) + scale = tl.load(grad + offsets, mask=mask).to(tl.float32) + acc += data * scale + offsets += BLOCK_X + + # Reduce to get the final result and store. + out = tl.sum(acc).to(wgrad.dtype.element_ty) + tl.store(wgrad, out) + + +def binned_scatter_wgrad(x, grad, indices, bins, top_k): + # Validate the input shapes. 
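+ # Here `x` is the dense (num_experts, expert_capacity, hidden_size) expert
+ # output, `grad` is the (tokens, hidden_size) upstream gradient, and the
+ # result is a flat (tokens * top_k,) vector of routing-weight gradients.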
+ assert_is_tensor(x, 3) + assert_is_matrix(grad) + assert_is_vector(indices) + assert_is_vector(bins) + assert_equal(bins.shape[0], x.shape[0]) + + num_experts, expert_capacity, hidden_size = x.shape + tokens = indices.shape[0] // top_k + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) + _binned_copy_wgrad[(num_experts, expert_capacity)]( + x, + grad, + out, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS=hidden_size, + TOP_K=top_k, + ) + return out diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/bak.__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/bak.__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..5217959caf74527e3bf7f80db6f93be21c016963 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/bak.__init__.py @@ -0,0 +1,23 @@ +from megablocks_moe.megablocks import ( + MoE, + dMoE, + get_load_balancing_loss, + ParallelMLP, + ParallelDroplessMLP, + SparseMLP, + MLP, + SparseGLU, + Arguments, +) + +__all__ = [ + "MoE", + "dMoE", + "get_load_balancing_loss", + "ParallelMLP", + "ParallelDroplessMLP", + "SparseMLP", + "MLP", + "SparseGLU", + "Arguments", +] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/benchmark_util.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/benchmark_util.py new file mode 100755 index 0000000000000000000000000000000000000000..02612d95e3ead1175a596e2878fa34b5bf85ad6f --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/benchmark_util.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import torch + + +def log_benchmark(name, arguments, time, std): + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) + + +def benchmark_function(fn, iterations=100, warmup=10): + # Warmup iterations. + for _ in range(warmup): + fn() + + times = [] + for i in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + fn() + end.record() + + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + return np.mean(times), np.std(times) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b91c8308f0c24f4c4171b6e4f15b6f76dabf295a --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/__init__.py @@ -0,0 +1,2 @@ +from . import ops +from . import backend diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/backend.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/backend.py new file mode 100755 index 0000000000000000000000000000000000000000..76037d8039cbfc2f0577275c78e4bc0be762592a --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/backend.py @@ -0,0 +1,33 @@ +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# # TODO(tgale): Wrap this in a try-block with better +# # error message and instructions for building the +# # c++ operations. 
+# import grouped_gemm_backend as backend + +# We import the backend operations from the megablocks package as +# grouped_gemm is vendored in megablocks in this repository. +# from ... import _ops as backend +# from megablocks._ops import ops as backend # type: ignore +from .._ops import ops as backend # type: ignore + +def _allocate_output(a, b, batch_sizes, trans_a, trans_b): + assert not (trans_a and trans_b) + assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes" + assert a.ndim == 2, "Expected 2d tensor for 'a'" + assert b.ndim == (2 if trans_a else 3) + + shape = ( + (batch_sizes.shape[0], a.shape[1], b.shape[1]) + if trans_a else + (a.shape[0], (b.shape[1] if trans_b else b.shape[2])) + ) + return torch.empty(*shape, device=a.device, dtype=a.dtype) + +def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None): + if c is None: + c = _allocate_output(a, b, batch_sizes, trans_a, trans_b) + backend.gmm(a, b, c, batch_sizes, trans_a, trans_b) + return c diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/ops.py new file mode 100755 index 0000000000000000000000000000000000000000..4b30dd14e23837ea3b12334f4e31337ed9ad2b69 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm/ops.py @@ -0,0 +1,33 @@ +from . import backend +import torch + + +class GroupedGemm(torch.autograd.Function): + + @staticmethod + def forward(ctx, a, b, batch_sizes, trans_b): + ctx.save_for_backward(a, b, batch_sizes) + ctx.trans_b = trans_b + return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b) + + @staticmethod + def backward(ctx, grad): + grad = grad.contiguous() + a, b, batch_sizes = ctx.saved_tensors + trans_b = ctx.trans_b + + agrad = None + if ctx.needs_input_grad[0]: + agrad = backend.gmm( + grad, b, batch_sizes, trans_a=False, trans_b=not trans_b) + + bgrad = None + if ctx.needs_input_grad[1]: + lhs, rhs = (grad, a) if trans_b else (a, grad) + bgrad = backend.gmm( + lhs, rhs, batch_sizes, trans_a=True, trans_b=False) + return agrad, bgrad, None, None + + +def gmm(a, b, batch_sizes, trans_b=False): + return GroupedGemm.apply(a, b, batch_sizes, trans_b) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm_util.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm_util.py new file mode 100755 index 0000000000000000000000000000000000000000..a6f36b90d362ad6e5e26475e4ab3b3a5f4a1b02d --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/grouped_gemm_util.py @@ -0,0 +1,31 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False +try: + # import grouped_gemm + pass + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + + +def grouped_gemm_is_available(): + return _grouped_gemm_is_available + + +def assert_grouped_gemm_is_available(): + msg = ( + 'Grouped GEMM not available. 
Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + + +# backend = grouped_gemm.backend if grouped_gemm_is_available() else None +# ops = grouped_gemm.ops if grouped_gemm_is_available() else None + + +#from .grouped_gemm import backend as ops +#from .grouped_gemm import ops as backend diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/layers.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/layers.py new file mode 100755 index 0000000000000000000000000000000000000000..c22fa16689f648d46c04b1ad39c45adba5f0ea9d --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/layers.py @@ -0,0 +1,1001 @@ +import torch +import torch.distributed as dist + +from typing import Optional, Any + +from . import _layers +from . import ops + + +# Set the expert model parallel attributes on a tensor +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, "expert_model_parallel") + setattr(tensor, "expert_model_parallel", is_parallel) + + +# Get the expert model parallel attributes from a tensor +def expert_sharding_degree( + world_size: int, + moe_num_experts: int, +) -> int: + esd = min(world_size, moe_num_experts) + if (moe_num_experts % esd) != 0: + raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.") + return esd + + +# Calculate the hidden sharding degree based on world size and expert sharding degree +def hidden_sharding_degree( + world_size: int, + moe_num_experts: int, + ffn_hidden_size: int, +) -> int: + esd = expert_sharding_degree(world_size, moe_num_experts) + hsd = world_size // esd + if (ffn_hidden_size % hsd) != 0: + raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.") + if (esd * hsd) != world_size: + raise ValueError( + f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})." 
+ ) + return hsd + + +# Calculate the number of experts per rank based on world size and expert sharding degree +def experts_per_rank( + moe_num_experts: int, + world_size: int, +) -> int: + return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts) + + +# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree +def features_per_rank( + ffn_hidden_size: int, world_size: int, moe_num_experts: int +) -> int: + return ffn_hidden_size // hidden_sharding_degree( + world_size, moe_num_experts, ffn_hidden_size + ) + + +# Apply jitter to the input tensor +def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor: + low = 1.0 - moe_jitter_eps + high = 1.0 + moe_jitter_eps + noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) + return x * (low + noise * (high - low)) + + +# Compute the top-k scores from the logits +def compute_top_k(scores: torch.Tensor, moe_top_k: int): + if moe_top_k == 1: + return scores.max(dim=-1, keepdim=True) + return torch.topk(scores, moe_top_k, dim=-1) + + +# Route tokens to experts and compute expert weights and indices +def route_tokens( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if training and moe_jitter_eps is not None: + x = apply_jitter(x, moe_jitter_eps) + + x_flat = x.view(-1, x.shape[-1]) + logits = torch.nn.functional.linear(x_flat, router_weight) + expert_weights, expert_indices = compute_top_k(logits, moe_top_k) + expert_weights = expert_weights.softmax(dim=-1) + if moe_normalize_expert_weights is not None: + expert_weights = expert_weights / torch.norm( + expert_weights, + p=moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) + if uniform_expert_assignment: + expert_indices = _layers.router._uniform_expert_assignment( + expert_indices, + moe_num_experts, + ) + + return logits, expert_weights, expert_indices + + +# Scale the gradient of the weights +def scale_grad( + w: torch.Tensor, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + if gradient_scale is None: + return w + return _layers.mlp.scale_gradient(w, gradient_scale) + + +# Forward pass for the MLP layer +def mlp_forward( + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, +): + # Scale weights + w1 = scale_grad(w1, gradient_scale) + w2 = scale_grad(w2, gradient_scale) + w1_bias = scale_grad(w1_bias, gradient_scale) + w2_bias = scale_grad(w2_bias, gradient_scale) + + # Resolve dtensors + w1 = _layers.mlp.resolve_dtensor(w1) + w2 = _layers.mlp.resolve_dtensor(w2) + w1_bias = _layers.mlp.resolve_dtensor(w1_bias) + w2_bias = _layers.mlp.resolve_dtensor(w2_bias) + + # Forward pass + gate_up = torch.bmm(x, w1) + w1_bias[..., None, :] + gate, up = gate_up.chunk(2, dim=-1) + + glu = gate * torch.sigmoid(gate * alpha) + x = (up + 1) * glu + + return torch.bmm(x, w2) + w2_bias[..., None, :] + + +# Shared expert MLP forward pass +def shared_mlp_forward( + x: torch.Tensor, + up_proj_weight: torch.Tensor, + down_proj_weight: torch.Tensor, + up_proj_bias: Optional[torch.Tensor] = None, + down_proj_bias: Optional[torch.Tensor] = None, + activation_fn: Optional[Any] = None, + gradient_scale: Optional[float] = None, +) -> torch.Tensor: + # Default 
activation function + if activation_fn is None: + activation_fn = torch.nn.functional.gelu + + # Scale weights + up_proj_weight = scale_grad(up_proj_weight, gradient_scale) + down_proj_weight = scale_grad(down_proj_weight, gradient_scale) + if up_proj_bias is not None: + up_proj_bias = scale_grad(up_proj_bias, gradient_scale) + if down_proj_bias is not None: + down_proj_bias = scale_grad(down_proj_bias, gradient_scale) + + # Resolve dtensors + up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight) + down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight) + if up_proj_bias is not None: + up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias) + if down_proj_bias is not None: + down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias) + + # Up projection + x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias) + + # Activation + x = activation_fn(x) + + # Down projection + x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias) + + return x + + +# Combine outputs from shared expert and regular experts +def combine_expert_shared_outputs( + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + shared_expert_weighted_sum: bool = False, + moe_top_k: int = 1, +) -> torch.Tensor: + if shared_expert_weighted_sum: + # Weighted sum based on number of experts used + total_experts = moe_top_k + 1 + shared_weight = 1.0 / total_experts + expert_weight = moe_top_k / total_experts + return shared_expert_out * shared_weight + expert_out * expert_weight + else: + # Simple addition + return shared_expert_out + expert_out + + +# Global variable to store load balancing loss +_LOAD_BALANCING_LOSS = [] + + +def save_load_balancing_loss(loss): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.append(loss) + + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + + +def batched_load_balancing_loss(args): + if args.moe_loss_weight == 0: + return 0.0 + + tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) + num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size + if args.num_layers_per_virtual_pipeline_stage is not None: + num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage + + if len(tokens_per_expert) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} token_per_experts " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + if len(expert_scores) != num_layers_per_pipeline_stage: + raise ValueError( + f"Expected {num_layers_per_pipeline_stage} expert_scores " + f"but found {len(tokens_per_expert)}.\nnum_layers = " + f"{args.num_layers}\npipeline_model_parallel_size = " + f"{args.pipeline_model_parallel_size}\n" + "num_layers_per_virtual_pipeline_stage" + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) + + # Verify the shape of the tokens_per_expert and expert_scores tensors. 
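+ # Each per-layer entry is expected to pair a (moe_num_experts,) token-count
+ # vector with a (tokens, moe_num_experts) score matrix; the asserts below
+ # enforce this before the layers are concatenated for the final dot product.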
+ assert all( + (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert) + ) + + tokens = expert_scores[0].shape[0] + assert all( + ( + ( + x.ndim == 2 + and x.shape[1] == args.moe_num_experts + and x.shape[0] == tokens + ) + for x in expert_scores + ) + ) + + # Concatenate the contributions of each layer and convert to + # the correct types and formats for the dot product. + expert_scores = torch.cat(expert_scores, dim=1) + if args.moe_lbl_in_fp32: + expert_scores = expert_scores.float() + if tokens != 0: + expert_scores = expert_scores.mean(dim=0) + else: + expert_scores = expert_scores.sum(dim=0) + tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + + expected_values = num_layers_per_pipeline_stage * args.moe_num_experts + assert tokens_per_expert.numel() == expected_values + assert expert_scores.numel() == expected_values + + # Calculate the total scale across all factors. + # + # loss_weight * num_experts / (num_layers * tokens * top_k) + scale_numerator = args.moe_num_experts * args.moe_loss_weight + scale_denominator = args.num_layers * tokens * args.moe_top_k + scale = scale_numerator / scale_denominator + return scale * torch.dot(tokens_per_expert, expert_scores) + + +# Calculate the expert capacity based on tokens, top_k, number of experts, +# expert parallel group, capacity factor, and whether expert model parallelism is used. +def expert_capacity( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: int, + moe_capacity_factor: float, + moe_expert_model_parallelism: bool, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def load_balancing_loss( + tokens_per_expert: torch.Tensor, + expert_scores: torch.Tensor, + top_k: int, + num_experts: int, +): + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == num_experts + scale = num_experts / (tokens * top_k) + return scale * torch.dot( + tokens_per_expert.to(expert_scores.dtype), + expert_scores.mean(dim=0), + ) + + +def indices_and_bins( + top_expert: torch.Tensor, + sort_end_bit: int, + num_experts: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + top_expert = top_expert.int() + + # Ensure contiguous memory layout + top_expert = top_expert.contiguous() + + # Ensure CUB knows which device to use + with torch.cuda.device(top_expert.device): + output = ops.sort(top_expert, sort_end_bit) + bin_ids, indices = output + tokens_per_expert = ops.histogram(top_expert, num_experts) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + bins = bins.view(1) if not len(bins.size()) else bins + return indices, bin_ids, bins, tokens_per_expert + + +def expert_capacity_fn( + tokens: int, + top_k: int, + num_experts: int, + expert_parallel_group: torch.distributed.ProcessGroup, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, +) -> int: + world_size = ( + dist.get_world_size(expert_parallel_group) + if moe_expert_model_parallelism + else 1 + ) + tokens_per_expert = top_k * tokens * world_size / num_experts + return int(moe_capacity_factor * tokens_per_expert) + + +def permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + 
bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, +): + # Route tokens to experts + x = x.view(-1, x.shape[-1]) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Expert computation + x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha) + + # Ensure CUB knows which device to use + with torch.cuda.device(x.device): + # Route tokens back + out = ops.binned_scatter(x, indices, expert_weights, bins, top_k) + return out + + +def forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: int = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + mlp_impl: Optional[str] = None, +): + # x: [sl, bs, hs] + # expert_weights: [sl * bs, top-k] + # top_experts: [sl * bs, top-k] + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate expert capacity + sl, bs, _ = x.size() + + expert_capacity = expert_capacity_fn( + sl * bs, + top_k, + num_experts, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + + if expert_capacity == 0: + expert_capacity = torch.max(tokens_per_expert).item() + + x = permute_and_compute( + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capacity, + top_k, + w1, + w2, + w1_bias, + w2_bias, + gradient_scale, + alpha, + ) + return x, tokens_per_expert + + +def parallel_forward_once( + x: torch.Tensor, + expert_weights: torch.Tensor, + top_experts: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_bias: torch.Tensor, + w2_bias: torch.Tensor, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + top_k: int = 4, + num_experts: int = 128, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = True, + hidden_size: int = 1152, + mlp_impl: Optional[str] = "sparse", +): + # Flatten inputs + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + # TODO: remove debugging var + # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0 + + with torch.no_grad(): + # Step 1: Local permutation setup + indices, bin_ids, bins, tokens_per_expert = indices_and_bins( + top_experts, sort_end_bit, num_experts + ) + + # Calculate sharding parameters + world_size = dist.get_world_size(expert_parallel_group) + hidden_sharding_deg = hidden_sharding_degree( + world_size, num_experts, hidden_size + ) + experts_per_rank_val = experts_per_rank(num_experts, world_size) + + # Replicate token counts for hidden sharding + repeated_tokens_per_expert = ops.repeat( + tokens_per_expert, (hidden_sharding_deg,) + ) + + # Exchange token counts across devices + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) + + # Ensure CUB knows which device to use + tpe_handle = dist.all_to_all_single( + parallel_tokens_per_expert, + repeated_tokens_per_expert, + group=expert_parallel_group, + async_op=True, + ) + + # Step 
2: Local permutation - group tokens by target device + x = x.view(-1, x.shape[-1]) # [sl * bs, hs] + x = ops.gather(x, indices, bin_ids, bins, top_k) + + # Step 3: Compute communication counts and exchange tokens + with torch.no_grad(): + tpe_handle.wait() + + # Reshape for per-device calculations + repeated_tokens_per_expert = repeated_tokens_per_expert.view( + world_size, experts_per_rank_val + ) + parallel_tokens_per_expert = parallel_tokens_per_expert.view( + world_size, experts_per_rank_val + ) + + # Calculate send/recv counts + send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist() + # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist() + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist() + tokens_received = sum(recv_counts) + + # Replicate for hidden sharding + x = ops.repeat(x, (hidden_sharding_deg, 1)) + + # Cross-device token exchange + parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all( + x, recv_counts, send_counts, expert_parallel_group, async_op=True + ) + + with torch.no_grad(): + # Step 4: Setup for local expert computation + replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) + replicate_bins = ( + replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins + ) + + # Create expert indices for received tokens + parallel_top_expert = torch.remainder( + torch.arange( + num_experts * hidden_sharding_deg, + dtype=torch.int32, + device=indices.device, + ), + experts_per_rank_val, + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, + tokens_received, + ).flatten() + + # Sort tokens by expert assignment + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, + sort_end_bit, + ) + + # Calculate bins for local experts + parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, dtype=torch.int + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = ( + parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins + ) + + # Calculate expert capacity + expert_capacity = expert_capacity_fn( + tokens_received, + top_k, + experts_per_rank_val, + expert_parallel_group, + moe_capacity_factor, + moe_expert_model_parallelism, + ) + if expert_capacity == 0: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + if mlp_impl == "grouped": + # GroupedMLP requires counts on CPU. We can use the tensor already + # moved to CPU for the prior all_to_all, which avoids an extra + # device synchronization. 
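+            # Summing the CPU copy over the world-size dimension yields the
+            # per-local-expert token counts of shape [experts_per_rank_val]
+            # without forcing another GPU->CPU copy at this point.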
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( + dim=0, + dtype=torch.int, + ) + + # Step 5: Expert computation + parallel_x_handle.wait() + + parallel_x = permute_and_compute( + parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=gradient_scale, + alpha=alpha, + ) + + # Step 6: Reverse communication - send results back + x, _ = _layers.all_to_all.all_to_all( + parallel_x, send_counts, recv_counts, expert_parallel_group + ) + + # Step 7: Reduce across hidden sharding dimension + shape = (hidden_sharding_deg, -1, hidden_size) + x = x.view(shape).sum(dim=0) + + # Step 8: Final local unpermutation + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) + + return x, tokens_per_expert.flatten() + + +def moe_forward( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + w1: torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, + mlp_impl: str = "grouped", +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # Route tokens to experts + logits, expert_weights, expert_indices = route_tokens( + x, + router_weight, + moe_top_k, + moe_num_experts, + moe_jitter_eps, + moe_normalize_expert_weights, + uniform_expert_assignment, + training, + ) + + # Create router scores for output + router_scores = ( + torch.zeros_like(logits) + .scatter_(1, expert_indices, expert_weights) + .transpose(0, 1) + ) + + in_shape = x.size() + + # Prepare forward function arguments + forward_args = { + "x": x, + "expert_weights": expert_weights, + "top_experts": expert_indices, + "w1": w1, + "w2": w2, + "w1_bias": w1_bias, + "w2_bias": w2_bias, + "gradient_scale": gradient_scale, + "alpha": alpha, + "sort_end_bit": sort_end_bit, + "top_k": moe_top_k, + "num_experts": moe_num_experts, + "expert_parallel_group": expert_parallel_group, + "moe_capacity_factor": moe_capacity_factor, + "moe_expert_model_parallelism": moe_expert_model_parallelism, + "mlp_impl": mlp_impl, + } + + # Add hidden_size for parallel forward + if moe_expert_model_parallelism and hidden_size is not None: + forward_args["hidden_size"] = hidden_size + elif moe_expert_model_parallelism and hidden_size is None: + # Infer hidden_size from input shape + forward_args["hidden_size"] = x.shape[-1] + + # Compute expert outputs + x, tokens_per_expert = forward_fn(**forward_args) + + # Save load balancing loss if needed + moe_loss_weight = 0.0 # Can be made configurable + if training and moe_loss_weight > 0: + save_load_balancing_loss((tokens_per_expert, logits)) + + # Restore original shape + x = x.view(in_shape) + + return x, expert_weights, router_scores + + +def moe_forward_with_shared_expert( + x: torch.Tensor, + router_weight: torch.Tensor, + moe_top_k: int, + moe_num_experts: int, + moe_jitter_eps: float = None, + moe_normalize_expert_weights: int = None, + uniform_expert_assignment: bool = False, + training: bool = False, + w1: 
torch.Tensor = None, + w2: torch.Tensor = None, + w1_bias: torch.Tensor = None, + w2_bias: torch.Tensor = None, + gradient_scale: Optional[float] = None, + alpha: float = 1.702, + sort_end_bit: int = 0, + expert_parallel_group: torch.distributed.ProcessGroup = None, + moe_capacity_factor: float = 1.0, + moe_expert_model_parallelism: bool = False, + forward_fn: Any = None, + hidden_size: int = None, + mlp_impl: str = "grouped", + # Shared expert parameters + shared_up_proj_weight: Optional[torch.Tensor] = None, + shared_down_proj_weight: Optional[torch.Tensor] = None, + shared_up_proj_bias: Optional[torch.Tensor] = None, + shared_down_proj_bias: Optional[torch.Tensor] = None, + shared_expert_weighted_sum: bool = False, + shared_activation_fn: Optional[Any] = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + # First, compute regular MoE forward pass + expert_out, expert_weights, router_scores = moe_forward( + x=x, + router_weight=router_weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=moe_jitter_eps, + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + training=training, + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + gradient_scale=gradient_scale, + alpha=alpha, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=moe_capacity_factor, + moe_expert_model_parallelism=moe_expert_model_parallelism, + forward_fn=forward_fn, + hidden_size=hidden_size, + mlp_impl=mlp_impl, + ) + + # If shared expert weights provided, compute shared expert output + if shared_up_proj_weight is not None and shared_down_proj_weight is not None: + shared_expert_out = shared_mlp_forward( + x=x, + up_proj_weight=shared_up_proj_weight, + down_proj_weight=shared_down_proj_weight, + up_proj_bias=shared_up_proj_bias, + down_proj_bias=shared_down_proj_bias, + activation_fn=shared_activation_fn, + gradient_scale=gradient_scale, + ) + + # Combine expert outputs + combined_out = combine_expert_shared_outputs( + shared_expert_out=shared_expert_out, + expert_out=expert_out, + shared_expert_weighted_sum=shared_expert_weighted_sum, + moe_top_k=moe_top_k, + ) + + return combined_out, expert_weights, router_scores + + # Return regular MoE output if no shared expert + return expert_out, expert_weights, router_scores + + +def create_shared_expert_weights( + hidden_size: int, + shared_expert_hidden_size: int, + device: torch.device, + dtype: torch.dtype, + init_method: Any, + output_layer_init_method: Any = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + + if output_layer_init_method is None: + output_layer_init_method = init_method + + # Create weight tensors + up_proj_weight = torch.empty( + shared_expert_hidden_size, + hidden_size, + device=device, + dtype=dtype, + ) + down_proj_weight = torch.empty( + hidden_size, + shared_expert_hidden_size, + device=device, + dtype=dtype, + ) + + # Initialize weights + init_method(up_proj_weight) + output_layer_init_method(down_proj_weight) + + # No bias by default + return up_proj_weight, down_proj_weight, None, None + +# HACK: Extract device_mesh from pre-hook closure - required for transformers integration +# This exists because device_mesh is trapped in hook closures with no model attribute +# Fragile - breaks if hook structure changes or Python internals change +# TODO: Replace with a more robust solution when available +def get_device_mesh(model): + # Extract device_mesh from 
child's unused pre_hook closure + try: + # Find the pre-hook that contains 'device_mesh' in its closure + hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars) + # Extract the device_mesh from the closure + return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents + except Exception: + return None + + +class MegaBlocksMoeMLP(torch.nn.Module): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + moe_top_k = getattr(self.router, "top_k", 4) + moe_num_experts = getattr(self.experts, "num_experts", 128) + gradient_scale = getattr(self.experts, "gradient_scale", None) + alpha = getattr(self.experts, "alpha", 1.0) + moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) + moe_jitter_eps = getattr(self.experts, "jitter_eps", None) + moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) + uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) + + expert_parallel_group = getattr(self, "expert_parallel_group", None) + if expert_parallel_group is None: + device_mesh = get_device_mesh(self) + expert_parallel_group = device_mesh.get_group() if device_mesh else None + + has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 + forward_fn = parallel_forward_once if has_parallel else forward_once + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + mlp_impl = getattr(self, "mlp_impl", "grouped") + + output, expert_weights_out, *_ = moe_forward( + x=x, + router_weight=self.router.weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=moe_jitter_eps, + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + training=self.training, + w1=self.experts.gate_up_proj, + w2=self.experts.down_proj, + w1_bias=self.experts.gate_up_proj_bias, + w2_bias=self.experts.down_proj_bias, + gradient_scale=gradient_scale, + alpha=alpha, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=moe_capacity_factor, + moe_expert_model_parallelism=has_parallel, + forward_fn=forward_fn, + hidden_size=self.experts.hidden_size, + mlp_impl=mlp_impl, + ) + return output, expert_weights_out + + +class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP): + + def __init__(self): + super().__init__() + # Shared expert weights will be set by the user + self.shared_up_proj_weight = None + self.shared_down_proj_weight = None + self.shared_up_proj_bias = None + self.shared_down_proj_bias = None + self.shared_expert_weighted_sum = False + self.shared_activation_fn = None + + def set_shared_expert_weights( + self, + up_proj_weight: torch.Tensor, + down_proj_weight: torch.Tensor, + up_proj_bias: Optional[torch.Tensor] = None, + down_proj_bias: Optional[torch.Tensor] = None, + weighted_sum: bool = False, + activation_fn: Optional[Any] = None, + ): + self.shared_up_proj_weight = up_proj_weight + self.shared_down_proj_weight = down_proj_weight + self.shared_up_proj_bias = up_proj_bias + self.shared_down_proj_bias = down_proj_bias + self.shared_expert_weighted_sum = weighted_sum + self.shared_activation_fn = activation_fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + moe_top_k = getattr(self.router, "top_k", 4) + moe_num_experts = getattr(self.experts, "num_experts", 128) + gradient_scale = getattr(self.experts, "gradient_scale", None) + alpha = 
getattr(self.experts, "alpha", 1.0) + moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0) + moe_jitter_eps = getattr(self.experts, "jitter_eps", None) + moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None) + uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False) + + expert_parallel_group = getattr(self, "expert_parallel_group", None) + if expert_parallel_group is None: + device_mesh = get_device_mesh(self) + expert_parallel_group = device_mesh.get_group() if device_mesh else None + + has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1 + forward_fn = parallel_forward_once if has_parallel else forward_once + + sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1) + mlp_impl = getattr(self, "mlp_impl", "grouped") + + output, expert_weights_out, *_ = moe_forward_with_shared_expert( + x=x, + router_weight=self.router.weight, + moe_top_k=moe_top_k, + moe_num_experts=moe_num_experts, + moe_jitter_eps=moe_jitter_eps, + moe_normalize_expert_weights=moe_normalize_expert_weights, + uniform_expert_assignment=uniform_expert_assignment, + training=self.training, + w1=self.experts.gate_up_proj, + w2=self.experts.down_proj, + w1_bias=self.experts.gate_up_proj_bias, + w2_bias=self.experts.down_proj_bias, + gradient_scale=gradient_scale, + alpha=alpha, + sort_end_bit=sort_end_bit, + expert_parallel_group=expert_parallel_group, + moe_capacity_factor=moe_capacity_factor, + moe_expert_model_parallelism=has_parallel, + forward_fn=forward_fn, + hidden_size=self.experts.hidden_size, + mlp_impl=mlp_impl, + # Shared expert parameters + shared_up_proj_weight=self.shared_up_proj_weight, + shared_down_proj_weight=self.shared_down_proj_weight, + shared_up_proj_bias=self.shared_up_proj_bias, + shared_down_proj_bias=self.shared_down_proj_bias, + shared_expert_weighted_sum=self.shared_expert_weighted_sum, + shared_activation_fn=self.shared_activation_fn, + ) + return output, expert_weights_out \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b944080df810d0b0cfc571f3009b0098a651f9b7 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from .binned_gather import binned_gather +from .binned_scatter import binned_scatter +from .cumsum import exclusive_cumsum, inclusive_cumsum +from .gather import gather +from .histogram import histogram +from .padded_gather import padded_gather +from .padded_scatter import padded_scatter +from .repeat import repeat +from .replicate import replicate +from .round_up import round_up +from .scatter import scatter +from .sort import sort +from .sum import sum +from .topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/all_to_all_benchmark.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/all_to_all_benchmark.py new file mode 100755 index 
0000000000000000000000000000000000000000..4c939818edca3345f6344bbc7cef07ffe3cd0181 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/all_to_all_benchmark.py @@ -0,0 +1,63 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.distributed as dist + +# from megablocks import benchmark_util +# from megablocks.layers.all_to_all import all_to_all + +from .. import benchmark_util +from .._layers.all_to_all import all_to_all + +_ALL_TO_ALL_BENCHMARK = ( + (8, 1024), + (16, 1024), + (32, 1024), + (64, 1024), + (128, 1024), + (256, 1024), + (512, 1024), + (1024, 1024), + (2 * 1024, 1024), + (4 * 1024, 1024), + (8 * 1024, 1024), + (16 * 1024, 1024), + (32 * 1024, 1024), + (64 * 1024, 1024), + (128 * 1024, 1024), + (256 * 1024, 1024), + (512 * 1024, 1024), + (1024 * 1024, 1024), +) + + +def benchmark_all_to_all(group, sl, hs): + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() + + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. + } + + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + + time, std = benchmark_util.benchmark_function(benchmark) + + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) + + +if __name__ == '__main__': + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) + torch.cuda.set_device(local_rank) + + for args in _ALL_TO_ALL_BENCHMARK: + benchmark_all_to_all(group, *args) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/binned_gather.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/binned_gather.py new file mode 100755 index 0000000000000000000000000000000000000000..189a7fa3518d660f29ea32e7a04827164af98d60 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/binned_gather.py @@ -0,0 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for binned_gather kernel. 
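+# binned_gather groups tokens into fixed-size per-expert bins (the expert
+# capacity); its backward pass is the matching binned_scatter op.
+# Illustrative shapes (assumed, not asserted here): x is [num_tokens, hidden]
+# and the gathered result is [num_experts, bin_size, hidden].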
+class BinnedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): + ctx.save_for_backward(indices, bins) + ctx.top_k = top_k + return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + indices, bins = ctx.saved_tensors + out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) + return out, None, None, None, None + + +binned_gather = BinnedGatherOp.apply diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/binned_scatter.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/binned_scatter.py new file mode 100755 index 0000000000000000000000000000000000000000..cb937c0c106662ce8108c1cb926f8f063b163d3d --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/binned_scatter.py @@ -0,0 +1,59 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for binned_scatter kernel. +class BinnedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + assert len(x.size()) == 3 + ctx.bin_size = x.size(1) + ctx.top_k = top_k + + # TODO(tgale): Don't save 'x' for backwards if we don't need to + # calculate the gradient w.r.t. 'weights'. + ctx.save_for_backward(x, indices, weights, bins) + return kernels.binned_scatter(x, indices, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + x, indices, weights, bins = ctx.saved_tensors + out = kernels.binned_gather( + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[2]: + wgrad = kernels.binned_scatter_wgrad( + x, + grad, + indices, + bins, + ctx.top_k, + ) + return out, None, wgrad, None, None + + +binned_scatter = BinnedScatterOp.apply diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/cumsum.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/cumsum.py new file mode 100755 index 0000000000000000000000000000000000000000..e2b7572391e20045d335cf7337246e8a9b9f57ef --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/cumsum.py @@ -0,0 +1,52 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + # import megablocks_ops as ops # type: ignore + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrappers for cumsum kernels. +# NOTE: Does not support gradients. 
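+# Example (illustrative values): for x = [1, 2, 3],
+#   exclusive_cumsum -> [0, 1, 3]
+#   inclusive_cumsum -> [1, 3, 6]
+# 1-D inputs are handled below by viewing them as [1, -1] and squeezing the result.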
+class ExclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int): + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.exclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.exclusive_cumsum(x, dim, out) + return out + + +exclusive_cumsum = ExclusiveCumsumOp.apply + + +class InclusiveCumsumOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: + if len(x.size()) == 1: + x = x.view([1, -1]) + out = torch.empty_like(x) + ops.inclusive_cumsum(x, 1, out) + return out.squeeze() + out = torch.empty_like(x) + ops.inclusive_cumsum(x, dim, out) + return out + + +inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/gather.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/gather.py new file mode 100755 index 0000000000000000000000000000000000000000..f1f87c1e7bed8d3589dd790805234976e0b05898 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/gather.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for gather kernel. +class GatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins) + ctx.top_k = top_k + return kernels.gather(x, indices, bin_ids, None, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins = ctx.saved_tensors + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) + return out, None, None, None, None, None + + +gather = GatherOp.apply diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/histogram.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/histogram.py new file mode 100755 index 0000000000000000000000000000000000000000..7b3f058ec373cbba7555704fb5e4212c3cc75d9d --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/histogram.py @@ -0,0 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for histogram kernel. +# NOTE: Does not support gradients. 
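+# Example (illustrative values, assuming integer inputs in [0, max_val)):
+#   ops.histogram(torch.tensor([0, 1, 1, 3], dtype=torch.int32).cuda(), 4) -> [1, 2, 0, 1]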
+class HistogramOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, max_val: float): + return ops.histogram(x, max_val) + + +histogram = HistogramOp.apply diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/histogram_benchmark.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/histogram_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..c57b7bf8228e01237236748147368b09ffdf8072 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/histogram_benchmark.py @@ -0,0 +1,78 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from .. import ops + +_HISTOGRAM_TESTS = ( + (16384, torch.int32, 2), + (16384, torch.int32, 4), + (16384, torch.int32, 8), + (16384, torch.int32, 16), + (16384, torch.int32, 32), + (16384, torch.int32, 64), + (16384, torch.int32, 128), + (16384, torch.int32, 256), +) + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class HistogramBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testHistogram(self, n, dtype, max_val): + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_HISTOGRAM_TESTS) + def testTorchHistogram(self, n, dtype, max_val): + x = torch.randint(0, 128, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/matmul_benchmark.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/matmul_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..7ccc5dcec5e9a663794fad944c45285869c4d1c1 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/matmul_benchmark.py @@ -0,0 +1,415 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + + +# import stk + +# try: +# import stk +# except ImportError: +# import warnings +# warnings.warn( +# 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.', +# ) + +from .. import stk + +import torch +from absl.testing import parameterized + +from .. import benchmark_util, ops + + +# Calling tensor.t() calls tensor.transpose(0, 1) which calls +# torch.as_strided(...). Circumvent this chain to avoid an overhead +# this adds. 
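+# The view below simply swaps the shape and stride pairs, so no data is copied
+# and the result is equivalent to x.t() for 2-D tensors.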
+def transpose_view(x): + return torch.as_strided( + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) + + +_MATMUL_TESTS = ( + (64 * 1024, 512, 2048, 64), + (32 * 1024, 768, 3072, 64), + (8 * 1024, 1024, 4096, 64), + (4 * 2048, 4096, 4 * 4096, 4), +) + + +def log_benchmark(name, arguments, time, std, flops): + benchmark_util.log_benchmark(name, arguments, time, std) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) + + +class MatmulBenchmark(parameterized.TestCase): + + def build_sparse_matrix(self, x, padded_bins, fhs, ne): + blocking = 128 + padded_tokens, _ = x.size() + assert padded_tokens % blocking == 0 + assert fhs % blocking == 0 + + # Offsets for the sparse matrix. All rows have the + # same number of nonzero blocks dictated by the + # dimensionality of a single expert. + block_rows = padded_tokens // blocking + blocks_per_row = fhs // blocking + offsets = torch.arange( + 0, + block_rows * blocks_per_row + 1, + blocks_per_row, + dtype=torch.int32, + device=x.device, + ) + + # Indices for the sparse matrix. The indices for + # the intermediate matrix are dynamic depending + # on the mapping of tokens to experts. + column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) + data = torch.empty( + column_indices.numel(), + blocking, + blocking, + dtype=torch.float16, + device=x.device, + ) + shape = (padded_tokens, fhs * ne) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) + + def build_input_matrix(self, sl, hs, ne): + x = torch.randn((sl, hs)).cuda().half() + + # Assign tokens to experts uniformly. + top_expert = torch.arange(0, sl).cuda().int() % ne + + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1) + return out, padded_bins + + def build_weight_matrix(self, ne, hs, fhs): + return torch.randn((hs, ne * fhs)).cuda().half() + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(x, w, topo) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(topo, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + 
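+    # Benchmark names follow <layer>::<pass>::<matmul kind>::<operand layout>;
+    # e.g. '0::GradW::DSD::TN' is the weight gradient of the first linear layer,
+    # a dense = sparse x dense product with the sparse operand transposed
+    # (interpretation inferred from the log_benchmark calls in this file).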
@parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) + topo = topo.t() + + def benchmark(): + return stk.ops.dsd(topo, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + + def benchmark(): + return stk.ops.dsd(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + w = transpose_view(w) + + def benchmark(): + return stk.ops.sdd(out, w, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): + x, padded_bins = self.build_input_matrix(sl, hs, ne) + w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() + x = self.build_sparse_matrix(x, padded_bins, fhs, ne) + out = stk.ops.dsd(x, w) + x = x.t() + + def benchmark(): + return stk.ops.dsd(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + + w = w.transpose(1, 2).contiguous() + w = w.transpose(1, 2) + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + w = w.transpose(1, 2).contiguous() + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 
'num_experts': ne, + } + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, hs)).cuda().half() + w = torch.randn((ne, hs, fhs)).cuda().half() + out = torch.bmm(x, w) + out = out.transpose(1, 2) + + def benchmark(): + return torch.bmm(out, x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + + def benchmark(): + return torch.bmm(x, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + w = torch.transpose(w, 1, 2) + + def benchmark(): + return torch.bmm(out, w) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + @parameterized.parameters(*_MATMUL_TESTS) + def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): + assert (sl % ne) == 0 + x = torch.randn((ne, sl // ne, fhs)).cuda().half() + w = torch.randn((ne, fhs, hs)).cuda().half() + out = torch.bmm(x, w) + x = torch.transpose(x, 1, 2) + + def benchmark(): + return torch.bmm(x, out) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, + } + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_gather.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_gather.py new file mode 100755 index 0000000000000000000000000000000000000000..c1cf4047c9494394d2a3884ba8830179013db7ff --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_gather.py @@ -0,0 +1,55 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for padded_gather kernel. 
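+# Unlike the capacity-based binned_* ops, the padded_* ops pad each expert's
+# token count up to a block multiple (see ops.round_up and the padded_bins
+# built in the benchmarks) rather than using a fixed per-expert capacity.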
+class PaddedGatherOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + ctx.save_for_backward(indices, bin_ids, bins, padded_bins) + ctx.top_k = top_k + return kernels.padded_gather( + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + + indices, bin_ids, bins, padded_bins = ctx.saved_tensors + out = kernels.padded_scatter( + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) + return out, None, None, None, None, None + + +padded_gather = PaddedGatherOp.apply diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_scatter.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_scatter.py new file mode 100755 index 0000000000000000000000000000000000000000..61e021b81497e472cda5d72bdac557a0ca92d262 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_scatter.py @@ -0,0 +1,98 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for padded_scatter kernel. +class PaddedScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward( + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.padded_gather( + grad, + indices, + bin_ids, + weights, + bins, + padded_bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.padded_scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + padded_bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None, None + + +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..c575cfe7487d346ba9ec18bbb7ef17f2eb77ec51 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py @@ -0,0 +1,66 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from .. 
import benchmark_util, ops + +_PADDED_SCATTER_BENCHMARK = ( + # dMoE-Medium, 8-way EMP. + (1024 * 16, 1024, 8, 4), + # dMoE-Medium, post-all-to-all. + (1024 * 16 * 4, 1024, 8, 1), +) + + +class PaddedScatterTest(parameterized.TestCase): + + @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK) + def testPaddedScatter(self, sl, hs, ne, top_k): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + # Sample weights for the scatter reduce. + weights = torch.rand((sl * top_k,)).cuda().half() + + # Gather the data to prepare for backwards. + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) + + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + + time, std = benchmark_util.benchmark_function(benchmark) + benchmark_util.log_benchmark( + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, + time, + std, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/permute_benchmark.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/permute_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..6536eeeae402659a087e5c51ef9840627af56501 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/permute_benchmark.py @@ -0,0 +1,149 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import torch +from absl.testing import parameterized + +from .. import benchmark_util, ops + +_PERMUTE_TESTS = ( + (16384, 768, 2), + (16384, 768, 4), + (16384, 768, 8), + (16384, 768, 16), + (16384, 768, 32), + (16384, 768, 64), + (16384, 768, 128), + (16384 * 8, 768, 2), + (16384 * 8, 768, 4), + (16384 * 8, 768, 8), + (16384 * 8, 768, 16), + (16384 * 8, 768, 32), + (16384 * 8, 768, 64), + (16384 * 8, 768, 128), +) + + +class PermuteBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedGather(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testBinnedScatter(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + ec = sl // ne + + # Create the data and indices. 
+ x = torch.randn((sl, hs)).cuda().half() + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(indices, ne) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.binned_gather(x, indices, bins, ec) + + def benchmark(): + return ops.binned_scatter(x, indices, bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedGather(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testPaddedScatter(self, sl, hs, ne): + # Create the data and indices. + x = torch.randn((sl, hs)).cuda().half() + + # Randomly assign tokens to experts. + top_expert = torch.randint(0, ne, (sl,)).cuda().int() + bin_ids, indices = ops.sort(top_expert) + tokens_per_expert = ops.histogram(top_expert, ne) + padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128) + padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) + + @parameterized.parameters(*_PERMUTE_TESTS) + def testCopy(self, sl, hs, ne): + # NOTE: Capacity factor == 1. + # ec = sl // ne + + # Create the data and indices. 
+ x = torch.randn((sl, hs)).cuda().half() + y = x.clone() + + def benchmark(): + return y.copy_(x) + + mean_t, std_t = benchmark_util.benchmark_function(benchmark) + arguments = { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + } + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/repeat.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/repeat.py new file mode 100755 index 0000000000000000000000000000000000000000..7e9e09de5f857d51cd758ab30b2f3a846d6f9275 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/repeat.py @@ -0,0 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): + return x + return x.repeat(*tiling) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/replicate.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/replicate.py new file mode 100755 index 0000000000000000000000000000000000000000..26daf0eede330603a4b8ea7167faf1411d07ca93 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/replicate.py @@ -0,0 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for replicate kernel. +class ReplicateOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): + ctx.save_for_backward(bins) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) + ops.replicate_forward(x, bins, out) + return out + + @staticmethod + def backward(ctx: Any, grad: torch.Tensor): + bins, = ctx.saved_tensors + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) + ops.replicate_backward(grad, bins, out) + return out, None, None + + +replicate = ReplicateOp.apply diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/round_up.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/round_up.py new file mode 100755 index 0000000000000000000000000000000000000000..6cf6bc873c9f448c5fa9126ebcfd66e8688002af --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/round_up.py @@ -0,0 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def round_up(x: torch.Tensor, value: int): + assert isinstance(value, int) + assert x.dtype == torch.int32 + + # TODO(tgale): If this becomes and issue + # do this in a custom kernel. We only expect + # to use this on arrays of less than 1k elements. 
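+    # The expression below is ceil(x / value) * value for non-negative integer
+    # x, computed with integer arithmetic so the result stays int32.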
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/scatter.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/scatter.py new file mode 100755 index 0000000000000000000000000000000000000000..f4605d9b46f387761b070352365f223dbfe69d47 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/scatter.py @@ -0,0 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + +import torch +from .stk_autocast import custom_bwd, custom_fwd + +from ..backend import kernels + + +# Autograd wrapper for scatter kernel. +class ScatterOp(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: + maybe_x = [x] if ctx.needs_input_grad[3] else [] + ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) + ctx.top_k = top_k + ctx.x_shape = x.shape + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) + + @staticmethod + @custom_bwd + def backward(ctx: Any, grad: torch.Tensor): + grad = grad.contiguous() + saved_tensors = ctx.saved_tensors + + indices, bin_ids, weights, bins = saved_tensors[:4] + dgrad = None + if ctx.needs_input_grad[0]: + dgrad = kernels.gather( + grad, + indices, + bin_ids, + weights, + bins, + ctx.top_k, + ) + + wgrad = None + if ctx.needs_input_grad[3]: # need wgrad + x = saved_tensors[-1] + wgrad = kernels.scatter_wgrad( + x, + grad, + indices, + bin_ids, + bins, + ctx.top_k, + ) + return dgrad, None, None, wgrad, None, None, None + + +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: + return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/sort.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/sort.py new file mode 100755 index 0000000000000000000000000000000000000000..bda3bf64283e39533c2eae3627e76bb2d0262c9f --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/sort.py @@ -0,0 +1,38 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + +_BITS_FOR_DTYPE = { + torch.int16: 16, + torch.int32: 32, + torch.int64: 64, +} + + +# Autograd wrapper for sort kernel. +# NOTE: Does not support gradients. 
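+# Returns (sorted values, permutation indices), e.g. for x = [3, 1, 2] the
+# result is ([1, 2, 3], [1, 2, 0]) (illustrative values). When end_bit is None
+# it defaults to the full bit width of the input dtype via _BITS_FOR_DTYPE.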
+class SortOp(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: + if end_bit is None: + end_bit = _BITS_FOR_DTYPE[x.dtype] + x_out = torch.empty_like(x) + iota_out = torch.empty_like(x) + ops.sort(x, end_bit, x_out, iota_out) + return (x_out, iota_out) + + +sort = SortOp.apply diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/sort_benchmark.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/sort_benchmark.py new file mode 100755 index 0000000000000000000000000000000000000000..a92ff957d4c552c6e61d9279a7989795472af7b7 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/sort_benchmark.py @@ -0,0 +1,85 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np +import torch +from absl.testing import parameterized + +from .. import ops + +_SORT_TESTS = ( + (16384, torch.int32, None), + (16384, torch.int32, 2), + (16384, torch.int32, 128), +) + +_BASELINE_SORT_TESTS = ((16384,),) + + +def numpy_dtype(dtype): + types = { + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + } + return types[dtype] + + +def benchmark_function(fn, iterations=10): + # Run once to get rid of startup overhead. + fn() + times = [] + for _ in range(iterations): + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + times = np.array(times) + return times.mean(), times.std(), times.max(), times.min() + + +def log_benchmark(arguments, mean_t, std_t): + print('=' * 60) + print('Benchmark Parameters:') + for (key, value) in arguments.items(): + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) + + +class SortBenchmark(parameterized.TestCase): + + @parameterized.parameters(*_SORT_TESTS) + def testSort(self, n, dtype, max_val): + if max_val is None: + max_val = np.iinfo(numpy_dtype(dtype)).max + end_bit = int(np.ceil(np.log2(max_val))) + x = torch.randint(0, max_val, (n,)).cuda().to(dtype) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) + arguments = { + 'n': n, + 'dtype': dtype, + 'max_val': max_val, + } + log_benchmark(arguments, mean_t, std_t) + + @parameterized.parameters(*_BASELINE_SORT_TESTS) + def testTorchSort(self, n): + x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) + + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } + log_benchmark(arguments, mean_t, std_t) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/stk_autocast.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/stk_autocast.py new file mode 100755 index 0000000000000000000000000000000000000000..7a3626e5e0eec51339c95a448bca84be14a2ca93 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/stk_autocast.py @@ -0,0 +1,39 @@ +# vendored from +# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return 
{_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/sum.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/sum.py new file mode 100755 index 0000000000000000000000000000000000000000..e00c1aa68e1193f5b72f75a2edc37de8d505facc --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/sum.py @@ -0,0 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def sum(x: torch.Tensor, dim: int = 0): + if x.shape[dim] == 1: + return x.squeeze(dim=dim) + return x.sum(dim=dim) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/topology.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/topology.py new file mode 100755 index 0000000000000000000000000000000000000000..76a50d3164db20534b099dcb4d8487a7aef25d15 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/ops/topology.py @@ -0,0 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +# NOTE: Torch needs to be imported before the custom +# extensions. Otherwise libc10.so cannot be found. +import torch + +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + from .._ops import ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + + +# Autograd wrapper for topology kernel. +# NOTE: Does not support gradients. +class TopologyOp(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) + return out + + +topology = TopologyOp.apply diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..73c40c267b1c3f4949e9c957a5d2c9f682dfc1a6 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/__init__.py @@ -0,0 +1,7 @@ +# import stk.random +# import stk.ops +# from stk.matrix import Matrix + +from . import random +from . 
import ops +from .matrix import Matrix diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/autocast.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/autocast.py new file mode 100755 index 0000000000000000000000000000000000000000..97f6e919a60f3fd579ed0215031008d14111dc96 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/autocast.py @@ -0,0 +1,37 @@ +import functools +import torch + + +def _is_eligible(x): + return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64) + + +def _cast(x, dtype): + if isinstance(x, torch.Tensor) and _is_eligible(x): + return x.to(dtype) + elif isinstance(x, map): + return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()} + elif isinstance(x, list) or isinstance(x, tuple): + return type(x)(map(lambda y: _cast(y, dtype), x)) + return x + + +def custom_fwd(fwd): + """Wrap a custom autograd function that always uses autocast dtype.""" + + @functools.wraps(fwd) + def decorate_fwd(*args, **kwargs): + if torch.is_autocast_enabled(): + with torch.autocast(device_type="cuda", enabled=False): + dtype = torch.get_autocast_gpu_dtype() + return fwd(*_cast(args, dtype), **_cast(kwargs, dtype)) + return fwd(*args, **kwargs) + return decorate_fwd + + +def custom_bwd(bwd): + @functools.wraps(bwd) + def decorate_bwd(*args, **kwargs): + with torch.autocast(device_type="cuda", enabled=False): + return bwd(*args, **kwargs) + return decorate_bwd diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/sputnik.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/sputnik.py new file mode 100755 index 0000000000000000000000000000000000000000..220c947bc1e932e8c77cc30f4069e9930f1aa962 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/sputnik.py @@ -0,0 +1,316 @@ +import torch + +from ..backend import triton_kernels as backend +from ..backend.autocast import custom_bwd, custom_fwd + + +def _standardize_shape(x, transpose): + if transpose: + return torch.Size((x[1], x[0])) + return x + + +def _sparse_transpose(x): + return (torch.Size((x[0][1], x[0][0])), ) + x[1:] + + +def _transpose_helper(x, transpose): + if isinstance(x, torch.Tensor): + return x.t() if transpose else x + if transpose: + x = _sparse_transpose(x) + return x + (transpose,) + + +def _wrap(x): + if isinstance(x, torch.Tensor): + return (x,) + return x + + +def _is_transposed(x): + return (not x.is_contiguous() and + x.stride()[0] == 1 and + x.stride()[1] == x.size()[0]) + + +def _call_helper(op, out, a, b, trans_a, trans_b): + args = (_wrap(_transpose_helper(a, trans_a)) + + _wrap(_transpose_helper(b, trans_b))) + if isinstance(out, tuple): + args = args + out + return op(*args) + + +def _preprocess_inputs(lhs, rhs, dy): + if isinstance(lhs, torch.Tensor) and _is_transposed(lhs): + lhs = lhs.t() + if isinstance(rhs, torch.Tensor) and _is_transposed(rhs): + rhs = rhs.t() + if (isinstance(dy, torch.Tensor) and + not dy.is_contiguous() and + not _is_transposed(dy)): + dy = dy.contiguous() + if isinstance(dy, tuple) and not dy[1].is_contiguous(): + dy = (dy[0], dy[1].contiguous()) + dy[2:] + return lhs, rhs, dy + + +def _postprocess_outputs(x, transpose, grad): + if isinstance(x, torch.Tensor) 
and transpose: + return grad.t() + return grad + + +def _lhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (rhs, dy) if trans_lhs else (dy, rhs) + trans_a = trans_lhs and trans_rhs + trans_b = trans_lhs or not trans_rhs + out = _call_helper(op, lhs, a, b, trans_a, trans_b) + return _postprocess_outputs(lhs, trans_lhs, out) + + +def _rhs_gradient(op, lhs, rhs, dy, trans_lhs, trans_rhs): + lhs, rhs, dy = _preprocess_inputs(lhs, rhs, dy) + + a, b = (dy, lhs) if trans_rhs else (lhs, dy) + trans_a = not trans_lhs or trans_rhs + trans_b = trans_lhs and trans_rhs + out = _call_helper(op, rhs, a, b, trans_a, trans_b) + return _postprocess_outputs(rhs, trans_rhs, out) + + +class DSD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs): + ctx.save_for_backward(data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + rhs) + ctx.shape = _standardize_shape(shape, transpose_a) + ctx.transpose_a = transpose_a + + out = torch.empty( + (shape[0], rhs.size()[1]), + dtype=rhs.dtype, + device=rhs.device) + + backend.dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = (ctx.shape,) + saved_tensors[:-1] + rhs = saved_tensors[-1] + trans_a = ctx.transpose_a + trans_b = _is_transposed(rhs) + + ddata = None + if ctx.needs_input_grad[1]: + ddata = _lhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[-1]: + op = dds if trans_b else dsd + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return None, ddata, None, None, None, None, None, None, None, drhs + + +dsd = DSD.apply + + +class DDS(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b): + ctx.save_for_backward(lhs, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = _standardize_shape(shape, transpose_b) + ctx.transpose_b = transpose_b + out = torch.empty((lhs.size()[0], shape[1]), + dtype=lhs.dtype, + device=lhs.device) + backend.dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs = saved_tensors[0] + rhs = (ctx.shape,) + saved_tensors[1:] + trans_a = _is_transposed(lhs) + trans_b = ctx.transpose_b + + dlhs = None + if ctx.needs_input_grad[0]: + op = dsd if trans_a else dds + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + ddata = None + if ctx.needs_input_grad[2]: + ddata = _rhs_gradient(sdd, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, None, ddata, None, None, None, None, None, None, None + + +dds = DDS.apply + + +class SDD(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, + lhs, + rhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t): + ctx.save_for_backward( + lhs, + rhs, + offsets, 
+ row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t) + ctx.shape = shape + out = torch.empty( + data.shape, + dtype=lhs.dtype, + device=lhs.device) + backend.sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices) + return out + + @staticmethod + @custom_bwd + def backward(ctx, dy): + saved_tensors = ctx.saved_tensors + lhs, rhs = saved_tensors[:2] + dy = (ctx.shape, dy) + saved_tensors[2:] + trans_a = _is_transposed(lhs) + trans_b = _is_transposed(rhs) + + dlhs = None + if ctx.needs_input_grad[0]: + op = dds if trans_a else dsd + dlhs = _lhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + drhs = None + if ctx.needs_input_grad[1]: + op = dsd if trans_b else dds + drhs = _rhs_gradient(op, + lhs, + rhs, + dy, + trans_a, + trans_b) + return dlhs, drhs, None, None, None, None, None, None, None, None + + +sdd = SDD.apply + +class RowIndices(torch.autograd.Function): + + @staticmethod + def forward(ctx, shape, data, offsets, column_indices): + out = torch.empty( + column_indices.shape, + dtype=column_indices.dtype, + device=column_indices.device) + backend.row_indices(shape, data, offsets, column_indices, out) + return out + + +row_indices = RowIndices.apply diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/triton_kernels.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/triton_kernels.py new file mode 100755 index 0000000000000000000000000000000000000000..c535309f3321249f475367164a558f94a4f8eb86 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/backend/triton_kernels.py @@ -0,0 +1,393 @@ +import torch +import triton +import triton.language as tl +from dataclasses import dataclass + +@dataclass +class TritonConfig: + BLOCK_M: int = 128 + BLOCK_N: int = 128 + BLOCK_K: int = 32 + BLOCK_SIZE: int = 128 + NUM_STAGES: int = 4 + NUM_WARPS: int = 4 + +def _validate_matmul_dims(M: int, K: int, N: int): + error_string = "incompatible dimensions: tensor has dim with length: {}, which must be divisible by {}" + assert M % TritonConfig.BLOCK_M == 0, error_string.format(M, TritonConfig.BLOCK_M) + assert K % TritonConfig.BLOCK_K == 0, error_string.format(K, TritonConfig.BLOCK_K) + assert N % TritonConfig.BLOCK_N == 0, error_string.format(N, TritonConfig.BLOCK_N) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _sdd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_m = tl.load(row_indices + pid) + pid_n = tl.load(column_indices + pid) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), 
dtype=ACC_TYPE) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A) + b = tl.load(B) + acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + #Store to sparse matrix + acc = acc.to(C.dtype.element_ty) + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + cm = tl.arange(0, BLOCK_M) + cn = tl.arange(0, BLOCK_N) + C = C + pid * BLOCK_ELEMENTS + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dsd_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + # matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_m) + end_inx = tl.load(offsets + pid_m + 1) + + # pointers to sparse matrix + rm = tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to dense matrix + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + ak_sub_incr = BLOCK_K * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + bk_block_incr = BLOCK_SIZE * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_A: + ptr_A = A + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + else: + ptr_A = A + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * ak_sub_incr + + ptr_B = B + tl.load(column_indices + start_inx + block_inx) * bk_block_incr + sub_block_inx * bk_sub_incr + + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': TritonConfig.BLOCK_M, + 'BLOCK_N': TritonConfig.BLOCK_N, + 'BLOCK_K': TritonConfig.BLOCK_K, + 'BLOCK_SIZE': TritonConfig.BLOCK_SIZE + }, num_stages=TritonConfig.NUM_STAGES, num_warps=TritonConfig.NUM_WARPS), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def _dds_kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + row_indices, column_indices, offsets, + block_offsets_t, trans_A: tl.constexpr, trans_B: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, GROUP_M: tl.constexpr, ACC_TYPE: tl.constexpr, + ): + + 
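+    # Each program instance computes one BLOCK_M x BLOCK_N tile of the dense
+    # output C = A (dense) @ B (block-sparse): it walks the nonzero blocks of
+    # B's block-column `pid_n` (bounded by `offsets`) and gathers the matching
+    # K-slices of A through `column_indices`.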
# matrix multiplication + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_M) + + start_inx = tl.load(offsets + pid_n) + end_inx = tl.load(offsets + pid_n + 1) + + # pointers to dense matrix + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rak = tl.arange(0, BLOCK_K) + + A += (rm[:, None] * stride_am + rak[None, :] * stride_ak) + + # pointers to sparse matrix + rn = tl.arange(0, BLOCK_N) + rbk = tl.arange(0, BLOCK_K) + B += (rbk[:, None] * stride_bk + rn[None, :] * stride_bn) + + # do matrix multiplication + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + nsub_blocks = tl.cdiv(BLOCK_SIZE, BLOCK_K) + + BLOCK_ELEMENTS = BLOCK_SIZE * BLOCK_SIZE + + ak_sub_incr = BLOCK_K * stride_ak + ak_block_incr = BLOCK_SIZE * stride_ak + bk_sub_incr = BLOCK_K * stride_bk + + for k in range(nsub_blocks * (end_inx - start_inx)): + sub_block_inx = k % nsub_blocks + block_inx = k // nsub_blocks + + if trans_B: + ptr_B = B + (start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + else: + ptr_B = B + tl.load(block_offsets_t + start_inx + block_inx) * BLOCK_ELEMENTS + sub_block_inx * bk_sub_incr + + ptr_A = A + tl.load(column_indices + start_inx + block_inx) * ak_block_incr + sub_block_inx * ak_sub_incr + a = tl.load(ptr_A) + b = tl.load(ptr_B) + acc += tl.dot(a, b) + + acc = acc.to(C.dtype.element_ty) + cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (cm[:, None] * stride_cm + cn[None, :] * stride_cn) + tl.store(C, acc, mask=True) + +def dsd(shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_a, + rhs, + out + ): + + device = rhs.device + trans_A = transpose_a + trans_B = False + + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if rhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, stride_ak = data.stride(1), data.stride(2) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + a_column_indices = column_indices + a_offsets = offsets + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = data.stride(2), data.stride(1) + a_column_indices, a_offsets = column_indices_t, offsets_t + + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _dsd_kernel[grid]( + data.data, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, a_column_indices, a_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + # return out + +def dds(lhs, + shape, + data, + offsets, + row_indices, + column_indices, + offsets_t, + column_indices_t, + block_offsets_t, + transpose_b, + out + ): + + device = lhs.device + trans_B = transpose_b + trans_A = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + + # checks constraints + assert lhs.shape[1] == shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if lhs.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + stride_am, 
stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = data.stride(1), data.stride(2) + b_column_indices = column_indices_t + b_offsets = offsets_t + + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']), triton.cdiv(N, META['BLOCK_N'])) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = data.stride(2), data.stride(1) + b_column_indices, b_offsets = column_indices, offsets + + _dds_kernel[grid]( + lhs, data, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(0), out.stride(1), + row_indices, b_column_indices, b_offsets, + block_offsets_t, trans_A, trans_B, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +def sdd(lhs, + rhs, + shape, + out, + offsets, + row_indices, + column_indices + ): + + device = out.device + trans_A = False + trans_B = False + + if lhs.stride(0) > 1 and lhs.stride(1) > 1: + trans_A = True + if rhs.stride(0) > 1 and rhs.stride(1) > 1: + trans_B = True + + # checks constraints + assert lhs.shape[1] == rhs.shape[0], "incompatible dimensions" + M, K = lhs.shape + _, N = rhs.shape + + _validate_matmul_dims(M, K, N) + + # accumulator types + ACC_TYPE = tl.float32 if out.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32 + + # launch kernel + nnz_blocks = len(row_indices) + grid = lambda META: (nnz_blocks,) + + stride_am, stride_ak = lhs.stride(0), lhs.stride(1) + stride_bk, stride_bn = rhs.stride(0), rhs.stride(1) + + if trans_A: + stride_am, stride_ak = lhs.stride(1), lhs.stride(0) + if trans_B: + stride_bk, stride_bn = rhs.stride(1), rhs.stride(0) + + _sdd_kernel[grid]( + lhs, rhs, out, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + out.stride(1), out.stride(2), + row_indices, column_indices, + GROUP_M=128, ACC_TYPE=ACC_TYPE + ) + +@triton.jit +def _row_indices_kernel(offsets, out): + pid = tl.program_id(0) + row_offset = tl.load(offsets + pid) + nnz_blocks = tl.load(offsets + pid + 1) - row_offset + for nnz_block in range(nnz_blocks): + tl.store(out + row_offset + nnz_block, pid) + +def row_indices( + shape, data, offsets, column_indices, out +): + block_rows = len(offsets) - 1 + _row_indices_kernel[(block_rows, )](offsets, out) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/matrix.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/matrix.py new file mode 100755 index 0000000000000000000000000000000000000000..80f42263d6aada287adbfa52a61fe950162a9e28 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/matrix.py @@ -0,0 +1,329 @@ +import numpy as np +import torch + +# 1. Add heavyweight (data) validation helper. +# 2. Add construction helpers +# 3. Make indentation consistent +# 4. Replace asserts with descriptive errors. + +## +### Validation helpers. +## + + +def _validate_matrix(shape, data, row_indices, column_indices, offsets): + # Data should be [nnz, block_size, block_size] + if data.dim() == 1: + data = torch.reshape(data, [data.numel(), 1, 1]) + + # Blocks should be square. + if data.shape[-2] != data.shape[-1]: + raise ValueError( + "Expected square blocking in data. " + f"Got block shape {[data.shape[-2], data.shape[-1]]}") + + # Flatten batch dimensions on data - original shape preserved + # in shape argument. + block_size = data.shape[-1] + data = data.view([-1, block_size, block_size]) + + if data.dim() != 3: + raise ValueError( + "Expected 3D shape for data (nnz, block, block). 
" + f"Got shape {data.dim()}D shape.") + + block_size = data.shape[1] + if shape[-2] % block_size != 0 or shape[-1] % block_size != 0: + raise ValueError( + "Matrix shape must be dividible by blocking. " + f"Got shape {shape} with " + f"{[block_size, block_size]} blocking.") + + if np.prod(shape) < data.numel(): + raise ValueError( + "Invalid matrix. Number of nonzeros exceeds matrix capacity " + f"({data.numel()} v. {np.prod(shape)})") + + if row_indices.dim() != 1: + raise ValueError( + f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.") + + if column_indices.dim() != 1: + raise ValueError( + f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.") + + if offsets.dim() != 1: + raise ValueError( + f"Expected 1D offsets. Got {offsets.dim()}D offsets.") + + if row_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks") + + if column_indices.numel() != data.shape[0]: + raise ValueError( + "Expected 1 index per nonzero block. " + f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks") + + block_rows = np.prod(shape[:-1]) / block_size + if offsets.numel() != block_rows + 1: + raise ValueError( + "Expected one offset per block row plus one. " + f"Got {offsets.numel()} offsets with {block_rows} block rows.") + + is_cuda = (data.is_cuda and + row_indices.is_cuda and + column_indices.is_cuda and + offsets.is_cuda) + is_cpu = (not data.is_cuda and + not row_indices.is_cuda and + not column_indices.is_cuda and + not offsets.is_cuda) + if not (is_cuda or is_cpu): + raise ValueError( + "Expected data & meta-data on common device. " + f"Got data on {data.device}, row_indices on {row_indices.device} " + f"column_indices on {column_indices.device} and " + f"offsets on {offsets.device}.") + + if data.dtype != torch.float16: + raise ValueError( + f"Expected float16 data. Got {data.dtype} data.") + if row_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.") + if column_indices.dtype != torch.int16: + raise ValueError( + f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.") + if offsets.dtype != torch.int32: + raise ValueError( + f"Expected int32 offsets. Got {offsets.dtype} offsets.") + return data + + +def _transpose(size, data, row_indices, column_indices, offsets): + block_columns = size[1] // data.shape[1] + + # Sort row indices by column indices to get the transposed matrix's + # column indices. + gather_indices = column_indices.argsort() + column_indices_t = row_indices.gather(0, gather_indices) + block_offsets_t = gather_indices.int() + + # NOTE: Histogram is not implemented for any integer type on CPU. Do + # the histogram in 32-bit float, which can exactly represent 16-bit + # integers. + column_indices_float = column_indices.float() + + zero = torch.zeros((1,), dtype=torch.int32, device=data.device) + nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns) + nnz_per_column = nnz_per_column.int() + offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)]) + return column_indices_t, offsets_t, block_offsets_t + + +class Matrix(torch.nn.Module): + """A matrix stored in sparse format. + + Underlying format is block compressed sparse row (BCSR). + + TODO(tgale): Make this mirror torch.Tensor API as much as possible. 
+ """ + + def __init__(self, + size, + data, + row_indices, + column_indices, + offsets, + column_indices_t=None, + offsets_t=None, + block_offsets_t=None): + super().__init__() + self._size = size + self._data = data + self._row_indices = row_indices + self._column_indices = column_indices + self._offsets = offsets + + # Produce the transpose meta-data if it is not passed in. + if ((column_indices_t is None) or (offsets_t is None) or + (block_offsets_t is None)): + column_indices_t, offsets_t, block_offsets_t = _transpose( + size, data, row_indices, column_indices, offsets) + self._column_indices_t = column_indices_t + self._offsets_t = offsets_t + self._block_offsets_t = block_offsets_t + + self._transposed = False + + # Validate that our metadata will not overflow. + max_dim = np.iinfo(np.int16).max * self.blocking + if column_indices.dtype == torch.int16: + if size[0] > max_dim or size[1] > max_dim: + raise ValueError( + "Sparse matrix with shape {size} exceeds representable " + "size with 16-bit indices.") + + def validate(self): + _validate_matrix(self._size, + self._data, + self._row_indices, + self._column_indices, + self._offsets) + + # TODO(tgale): Add heavyweight data validation. + + def to(self, device): + # TODO(tgale): Handle type conversions here. We + # need to set the appropriate meta-data type for + # the given floating-point type. + self._data = self._data.to(device) + self._row_indices = self._row_indices.to(device) + self._column_indices = self._column_indices.to(device) + self._offsets = self._offsets.to(device) + self._column_indices_t = self._column_indices_t.to(device) + self._offsets_t = self._offsets_t.to(device) + self._block_offsets_t = self._block_offsets_t.to(device) + return self + + def cuda(self): + return self.to(torch.cuda.current_device()) + + def clone(self): + return Matrix( + self.size(), + self.data.clone(), + self.row_indices.clone(), + self.column_indices.clone(), + self.offsets.clone(), + self.column_indices_t.clone(), + self.offsets_t.clone(), + self.block_offsets_t.clone()) + + def t(self): + if self.dim() != 2: + raise ValueError( + "t() expects a tensor with <= 2 dimensions, " + f"but self is {self.dim()}D.") + out = Matrix(self.size(), + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + out._transposed = not self._transposed + out._size = torch.Size((self._size[1], self._size[0])) + return out + + def contiguous(self): + raise ValueError("Not yet implemented.") + + def is_contiguous(self): + return not self._transposed + + @property + def is_cuda(self): + return self._data.is_cuda + + @property + def device(self): + return self._data.device + + def size(self): + return self._size + + @property + def shape(self): + return self.size() + + def dim(self): + return len(self._size) + + @property + def data(self): + return self._data + + @property + def row_indices(self): + return self._row_indices + + @property + def column_indices(self): + return self._column_indices + + @property + def offsets(self): + return self._offsets + + @property + def offsets_t(self): + return self._offsets_t + + @property + def column_indices_t(self): + return self._column_indices_t + + @property + def block_offsets_t(self): + return self._block_offsets_t + + @property + def dtype(self): + return self.data.dtype + + @property + def nnz(self): + return self.data.numel() + + @property + def blocking(self): + return self.data.shape[1] + + @property + def requires_grad(self): + return 
self.data.requires_grad + + def requires_grad_(self, x): + self.data.requires_grad_(x) + return self + + def view(self, *shape): + assert self.is_contiguous() + if shape[-1] != self.size()[-1]: + raise ValueError( + "Can't change view on compressed dimension. " + f"{self.size()[-1]} v. {shape[-1]}.") + if np.prod(shape) != np.prod(self.size()): + raise ValueError( + "Mismatch in numel of Matrix and new shape. " + f"{np.prod(self.size())} v. {np.prod(shape)}") + return Matrix(shape, + self.data, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + + @property + def grad(self): + # TODO(tgale): Make sure this mirrors torch.Tensor + # behavior in the case where we ask for the gradient + # of a non-contiguous tensor. + size = self.size() + if not self.is_contiguous(): + size = torch.Size((size[1], size[0])) + out = Matrix(size, + self.data.grad, + self.row_indices, + self.column_indices, + self.offsets, + self.column_indices_t, + self.offsets_t, + self.block_offsets_t) + return out if self.is_contiguous() else out.t() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..fc873b236f4cd4036964c016a4036e3ce5ebf1ac --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/__init__.py @@ -0,0 +1,3 @@ +from .linear_ops import dds, dsd, sdd +from .matrix_ops import ones_like, row_indices, sum, to_dense, to_sparse +from .eltwise_ops import mul diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/eltwise_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/eltwise_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..ba7d7332320250fd01fa60e60528f19de3e8ed03 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/eltwise_ops.py @@ -0,0 +1,28 @@ +from ..matrix import Matrix + +def mul(a, b): + """Performs element-wise multiplication of matrices a and b. + + It is the user's responsibility to make sure that a and b + follow the same matrix topology. This function assumes it is safe + to use the topoplogy of a. + + Args: + a: stk.Matrix. + b: stk.Matrix with a's matrix topology. + + Returns: + stk.Matrix where the entries correspond to torch.mul(a, b). 
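+
+    Note: only the stored blocks are multiplied; the result reuses a's
+    row/column indices and offsets, which is why the two operands must share
+    a topology.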
+ """ + assert isinstance(a, Matrix) + assert isinstance(b, Matrix) + assert a.size() == b.size() + + return Matrix(a.size(), + a.data * b.data, + a.row_indices, + a.column_indices, + a.offsets, + a.column_indices_t, + a.offsets_t, + a.block_offsets_t) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..66bfd4f6af77042d3c5bdb1fe18d00e457478d46 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/eltwise_ops_test.py @@ -0,0 +1,86 @@ +import unittest +import itertools +import torch +from absl.testing import parameterized + +import stk +from stk.ops.linear_ops_test import allclose, _dense_and_sparse + +_MATRIX_SIZES = ( + (128, 128, 0.0), + (256, 256, 0.5), + (2048, 1024, 0.8), + (512, 128, 0.0), + (128, 512, 0.0), + (1024, 512, 0.0), + (1024, 512, 0.5), + (1024, 512, 0.75), + (512, 1024, 0.0), + (512, 1024, 0.5), + (512, 1024, 0.75), + (1024, 1024, 0.0), + (1024, 1024, 0.5), + (1024, 1024, 0.75), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _DTYPE) + testcases = [(*size, 128, dtype) for + (size, dtype) in testcases] + return testcases + +_ELTWISE_OP_TESTS = _generate_testcases() + +def _dense_and_sparse_like(x, std=0.1): + dense_data = torch.randn_like(x.data, device=x.device) * std + sparse = stk.Matrix(x.size(), + dense_data, + x.row_indices, + x.column_indices, + x.offsets) + dense = stk.ops.to_dense(sparse) + + return (dense.requires_grad_(True), + sparse.requires_grad_(True)) + +@parameterized.parameters(_ELTWISE_OP_TESTS) +class EltwiseOpsTest(parameterized.TestCase): + + def testEltwiseMul(self, m, n, sparsity, blocking, dtype): + + a_dense, a = _dense_and_sparse(m, n, sparsity, blocking, dtype) + b_dense, b = _dense_and_sparse_like(a) + + out = stk.ops.mul(a, b) + expected_out = torch.mul(a_dense, b_dense) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size(), out.size()) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = a_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. 
+ grad = stk.ops.to_dense(b.grad) + expected_grad = b_dense.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size(), grad.size()) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/linear_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/linear_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..9d277c8c07f9e30addc31900a12175c8a1f4d7ad --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/linear_ops.py @@ -0,0 +1,59 @@ +import torch + +from ..backend import sputnik +from ..matrix import Matrix + + +def dsd(a, b): + assert isinstance(a, Matrix) + assert isinstance(b, torch.Tensor) + return sputnik.dsd( + a.size(), + a.data, a.offsets, + a.row_indices, + a.column_indices, + a.offsets_t, + a.column_indices_t, + a.block_offsets_t, + not a.is_contiguous(), + b) + + +def dds(a, b): + assert isinstance(a, torch.Tensor) + assert isinstance(b, Matrix) + return sputnik.dds( + a, + b.size(), + b.data, b.offsets, + b.row_indices, + b.column_indices, + b.offsets_t, + b.column_indices_t, + b.block_offsets_t, + not b.is_contiguous()) + + +def sdd(a, b, topo): + assert isinstance(a, torch.Tensor) + assert isinstance(b, torch.Tensor) + assert isinstance(topo, Matrix) + assert topo.is_contiguous() + out = sputnik.sdd( + a, b, + topo.size(), + topo.data, + topo.offsets, + topo.row_indices, + topo.column_indices, + topo.offsets_t, + topo.column_indices_t, + topo.block_offsets_t) + return Matrix(topo.size(), + out, + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/linear_ops_test.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/linear_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..ced1d782fbc9f9ca16b3449239f1588dc5ff5e00 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/linear_ops_test.py @@ -0,0 +1,216 @@ +import unittest +import itertools +import numpy as np +import torch +from absl.testing import parameterized + +import stk + + +def allclose(x, y, pct=0.25): + mask = torch.isclose(x, y, rtol=5e-2) + pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 + if pct_diff > pct: + print("{:.2f}% of values not close.".format(pct_diff)) + return False + return True + + +# An assortment of problems designed to make sure +# the bindings are operating correctly. 
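+# Each entry is (m, k, n, sparsity); the transpose flags, blocking (128) and
+# dtype are appended by _generate_testcases below.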
+_MATRIX_SIZES = ( + (128, 128, 128, 0.0), + (256, 256, 256, 0.5), + (2048, 1024, 512, 0.8), + (512, 128, 128, 0.0), + (128, 128, 512, 0.0), + (1024, 512, 512, 0.0), + (1024, 512, 512, 0.5), + (1024, 512, 512, 0.75), + (512, 512, 1024, 0.0), + (512, 512, 1024, 0.5), + (512, 512, 1024, 0.75), + (1024, 1024, 1024, 0.0), + (1024, 1024, 1024, 0.5), + (1024, 1024, 1024, 0.75), +) + +_TRANSPOSE = ( + (False, False), + (False, True), + (True, False), + (True, True), +) + +_DTYPE = ( + torch.float16, torch.bfloat16 +) + +def _generate_testcases(): + testcases = itertools.product(_MATRIX_SIZES, _TRANSPOSE, _DTYPE) + testcases = [(*size, *trans, 128, dtype) for + (size, trans, dtype) in testcases] + return testcases + +_LINEAR_OP_TESTS = _generate_testcases() + +def _dense_and_sparse(rows, cols, sparsity, blocking, dtype, std=0.1): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + dense = (torch.randn(rows, cols) * std * mask).type(dtype) + sparse = stk.ops.to_sparse(dense, blocking) + cuda_device = torch.device("cuda") + return (dense.to(cuda_device).requires_grad_(True), + sparse.to(cuda_device).requires_grad_(True)) + + +def _dense(rows, cols, dtype, std=0.1): + cuda_device = torch.device("cuda") + out = (torch.randn(rows, cols) * std).type(dtype) + return out.to(cuda_device).requires_grad_(True) + + +def _dense_2x(rows, cols, dtype): + a = _dense(rows, cols, dtype) + return a, a.detach().requires_grad_(True) + + +def _with_transpose(op, a, b, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b) + + +def _mmm(a, b, topo): + mask = stk.ops.to_dense(stk.ops.ones_like(topo)) + return torch.mm(a, b) * mask + + +def _sparse_out_with_transpose(op, a, b, topo, trans_a, trans_b): + a = a.t() if trans_a else a + b = b.t() if trans_b else b + return op(a, b, topo) + + +def _mask(x, mask): + mask = stk.ops.to_dense(stk.ops.ones_like(mask)) + return x * mask + + +@parameterized.parameters(*_LINEAR_OP_TESTS) +class LinearOpsTest(parameterized.TestCase): + + def testLinearOps_Dsd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a_dense, a = _dense_and_sparse(*a_shape, sparsity, blocking, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dsd, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, a_dense, bcp, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = stk.ops.to_dense(a.grad) + expected_grad = _mask(a_dense.grad, a.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Dds(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. 
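+        # Dense lhs a (with detached copy acp for the dense reference matmul)
+        # and block-sparse rhs b (with dense reference b_dense).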
+ a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b_dense, b = _dense_and_sparse(*b_shape, sparsity, blocking, dtype) + + # Execute the matmul. + out = _with_transpose(stk.ops.dds, a, b, trans_a, trans_b) + expected_out = _with_transpose(torch.mm, acp, b_dense, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + out.sum().backward() + + # Validate the results. + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = stk.ops.to_dense(b.grad) + expected_grad = _mask(b_dense.grad, b.grad) + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + def testLinearOps_Sdd(self, m, k, n, sparsity, trans_a, trans_b, blocking, dtype): + # Construct the operands. + a_shape = (k, m) if trans_a else (m, k) + a, acp = _dense_2x(*a_shape, dtype) + b_shape = (n, k) if trans_b else (k, n) + b, bcp = _dense_2x(*b_shape, dtype) + _, topo = _dense_and_sparse(m, n, sparsity, blocking, dtype) + + # Execute the matmul. + out = _sparse_out_with_transpose(stk.ops.sdd, a, b, topo, trans_a, trans_b) + expected_out = _sparse_out_with_transpose(_mmm, acp, bcp, topo, trans_a, trans_b) + + # Compute the gradients w.r.t. the inputs. + expected_out.sum().backward() + stk.ops.sum(out).backward() + + # Validate the results. + out = stk.ops.to_dense(out) + self.assertEqual(out.dim(), 2) + self.assertEqual(expected_out.size()[0], out.size()[0]) + self.assertEqual(expected_out.size()[1], out.size()[1]) + self.assertTrue(allclose(out, expected_out)) + + # LHS gradient. + grad = a.grad + expected_grad = acp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + + # RHS gradient. + grad = b.grad + expected_grad = bcp.grad + self.assertEqual(grad.dim(), 2) + self.assertEqual(expected_grad.size()[0], grad.size()[0]) + self.assertEqual(expected_grad.size()[1], grad.size()[1]) + self.assertTrue(allclose(grad, expected_grad)) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/matrix_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/matrix_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..447c72dc73439d84f58c917676cc04e64f13e97d --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/matrix_ops.py @@ -0,0 +1,98 @@ +from ..backend import sputnik +from ..matrix import Matrix +import torch +import numpy as np + + +@torch.no_grad() +def row_indices(shape, data, offsets, column_indices): + return sputnik.row_indices(shape, data, offsets, column_indices) + + +# TODO(tgale): Replace this helper with a custom kernel. This operation +# is much simpler to do than how it's currently implemented. 
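+# Given [nnz, 2] (block_row, block_col) indices, _expand_for_blocking returns
+# the element-level (row, col) coordinates covered by each block, i.e. an
+# [nnz * blocking * blocking, 2] tensor used to scatter/gather block data.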
+@torch.no_grad() +def _expand_for_blocking(idxs, blocking): + # Duplicate for block column dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, 2]).repeat(1, blocking, 1) + + # Update the column indices. + idxs[:, :, 1] *= blocking + idxs[:, :, 1] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking]) + + # Duplicate for block row dimension. + idxs = torch.reshape(idxs, [idxs.size()[0], 1, blocking, 2]) + idxs = idxs.repeat(1, blocking, 1, 1) + + # Update the row indices. + idxs[:, :, :, 0] *= blocking + idxs[:, :, :, 0] += torch.reshape(torch.arange(blocking, device=idxs.device), [1, blocking, 1]) + idxs = torch.reshape(idxs, [-1, 2]) + return idxs + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_dense(x): + assert isinstance(x, Matrix) + + shape = (np.prod(x.shape[:-1]), x.shape[-1]) + row_idxs = x.row_indices.type(torch.int32) + col_idxs = x.column_indices.type(torch.int32) + indices = _expand_for_blocking(torch.stack([row_idxs, col_idxs], dim=1), x.blocking) + indices = (indices[:, 0] * shape[1] + indices[:, 1]).type(torch.int64) + + out = torch.zeros(shape[0] * shape[1], dtype=x.dtype, device=x.device) + out.scatter_(0, indices, x.data.flatten()) + return out.reshape(x.size()) + + +@torch.no_grad() +def _mask(x, blocking=1): + assert x.dim() == 2 + assert x.size()[0] % blocking == 0 + assert x.size()[1] % blocking == 0 + block_rows = x.size()[0] // blocking + block_cols = x.size()[1] // blocking + x = torch.reshape(x, [block_rows, blocking, block_cols, blocking]) + x = torch.sum(torch.abs(x), dim=(1, 3)) + return x != 0 + + +# TODO(tgale): Add input type checking. +@torch.no_grad() +def to_sparse(x, blocking=1): + m = _mask(x, blocking) + + # TODO(tgale): Set to appropriate type for input matrix. + row_nnzs = torch.sum(m, dim=1).type(torch.int32) + zeros = torch.zeros((1,), dtype=row_nnzs.dtype, device=row_nnzs.device) + offsets = torch.cat([zeros, torch.cumsum(row_nnzs, dim=0)]) + offsets = offsets.type(torch.int32) + + indices = torch.nonzero(m).type(torch.int16) + row_indices = indices[:, 0] + column_indices = indices[:, 1] + + # Nonzero indices in the dense matrix. + nonzero_indices = torch.nonzero(m) + nonzero_indices = _expand_for_blocking(nonzero_indices, blocking) + nonzero_indices = nonzero_indices[:, 0] * x.size()[1] + nonzero_indices[:, 1] + + # Gather the data and construct the sparse matrix. 
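+    # nonzero_indices are flat element offsets into x; the gathered values are
+    # reshaped into [nnz, blocking, blocking] blocks.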
+ data = torch.gather(x.flatten(), dim=0, index=nonzero_indices) + data = torch.reshape(data, [-1, blocking, blocking]) + return Matrix(x.size(), data, row_indices, column_indices, offsets) + + +@torch.no_grad() +def ones_like(x): + return Matrix(x.size(), + torch.ones_like(x.data), + x.row_indices, + x.column_indices, x.offsets) + + +def sum(x): + assert isinstance(x, Matrix) + return x.data.sum() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..3af04c0760483e578f93303dc457415948a2a34c --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/ops/matrix_ops_test.py @@ -0,0 +1,62 @@ +import unittest + +from absl.testing import parameterized +import stk +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class MatrixOpsTest(parameterized.TestCase): + + def testMatrixOps_FormatConversion(self, rows, cols, sparsity, blocking): + mask = stk.random.dense_mask(rows, cols, sparsity, blocking) + x = (torch.randn(rows, cols) * mask).type(torch.float16) + + # Convert the matrix to sparse format. + sparse_x = stk.ops.to_sparse(x, blocking) + + # Validate the matrix. + sparse_x.validate() + + # Validate the shape. + self.assertEqual(sparse_x.dim(), 2) + self.assertEqual(sparse_x.size()[0], rows) + self.assertEqual(sparse_x.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(sparse_x.nnz, nnz) + + # Convert back to dense format. + dense_x = stk.ops.to_dense(sparse_x) + + # Validate the shape. + self.assertEqual(dense_x.dim(), 2) + self.assertEqual(dense_x.size()[0], rows) + self.assertEqual(dense_x.size()[1], cols) + + # Validate the sparsity + self.assertEqual(torch.count_nonzero(dense_x).item(), nnz) + + # Validate the output. 
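+        # Round-tripping to_sparse -> to_dense should reproduce x exactly.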
+ self.assertTrue(torch.all(torch.eq(x, dense_x))) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/random/__init__.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/random/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2576d1ca27283f77569a9a620c7c99fa68aaf30e --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/random/__init__.py @@ -0,0 +1,2 @@ +# from stk.random.random_ops import dense_mask, mask, randn +from .random_ops import dense_mask, mask, randn diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/random/random_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/random/random_ops.py new file mode 100755 index 0000000000000000000000000000000000000000..d1b36771e0eb8e7abf46bcb3b136b5fb1d29df93 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/random/random_ops.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from ..ops import matrix_ops + + +@torch.no_grad() +def dense_mask(rows, cols, sparsity, blocking=1): + assert sparsity >= 0.0 and sparsity <= 1.0 + assert rows % blocking == 0 and cols % blocking == 0 + + block_rows, block_cols = (rows // blocking, cols // blocking) + nnz = round(block_rows * block_cols * (1 - sparsity)) + + out = np.ones(block_rows * block_cols) + mask = np.random.choice(out.size, out.size - nnz, replace=False) + out[mask] = 0.0 + + out = np.tile( + np.reshape(out, [block_rows, 1, block_cols, 1]), + (1, blocking, 1, blocking)) + out = np.reshape(out, [rows, cols]) + return torch.from_numpy(out.astype(np.float32)) + + +@torch.no_grad() +def mask(m, n, sparsity, blocking=1): + out = dense_mask(m, n, sparsity, blocking).type(torch.float16) + return matrix_ops.to_sparse(out, blocking=blocking) + + +@torch.no_grad() +def randn(shape, sparsity, blocking=1): + shape_2d = (np.prod(shape[:-1]), shape[-1]) + out = mask(*shape_2d, sparsity, blocking) + out.data.copy_(torch.randn(*out.data.shape)) + return out.view(*shape) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/random/random_ops_test.py b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/random/random_ops_test.py new file mode 100755 index 0000000000000000000000000000000000000000..587b44ec890c861879c6296b8f9028f5d99ab82f --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/megablocks/stk/random/random_ops_test.py @@ -0,0 +1,73 @@ +import unittest + +from absl.testing import parameterized +from . import random +import torch + + +@parameterized.parameters( + (8, 16, 0.0, 1), + (8, 16, 0.5, 1), + (8, 16, .95, 1), + (16, 8, 0.0, 1), + (16, 8, 0.5, 1), + (16, 8, .95, 1), + (8, 16, 0.0, 8), + (8, 16, 0.5, 8), + (8, 16, 1.0, 8), + (16, 8, 0.0, 8), + (16, 8, 0.5, 8), + (16, 8, 1.0, 8), + (128, 256, 0.5, 16), + (256, 128, 0.75, 32), + (512, 512, .875, 128)) +class RandomOpsTest(parameterized.TestCase): + + def testRandomOps_DenseMask(self, rows, cols, sparsity, blocking): + mask = random.dense_mask( + rows, cols, sparsity, blocking) + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual( + torch.count_nonzero(mask).item(), + nnz) + + # Check values are zero or one. 
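+        # dense_mask tiles a {0, 1} block pattern and returns it as float32.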
+ self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask, 0), + torch.eq(mask, 1)))) + + def testRandomOps_SparseMask(self, rows, cols, sparsity, blocking): + mask = random.mask( + rows, cols, sparsity, blocking) + + # Validate the matrix. + mask.validate() + + # Validate the shape. + self.assertEqual(mask.dim(), 2) + self.assertEqual(mask.size()[0], rows) + self.assertEqual(mask.size()[1], cols) + + # Validate the sparsity. + numblocks = rows // blocking * cols // blocking + nnz = round(numblocks * (1 - sparsity)) * blocking ** 2 + self.assertEqual(mask.nnz, nnz) + + # Check values are zero or one. + self.assertTrue( + torch.all(torch.logical_or( + torch.eq(mask.data, 0), + torch.eq(mask.data, 1)))) + + +if __name__ == '__main__': + unittest.main() diff --git a/flake.lock b/flake.lock index 4c4c82013acdcdf89fea6337081ce317689ed44d..0ae9c9208a0f6916c92ba009ed320cf0b55983c7 100644 --- a/flake.lock +++ b/flake.lock @@ -73,11 +73,11 @@ "nixpkgs": "nixpkgs" }, "locked": { - "lastModified": 1751968576, - "narHash": "sha256-cmKrlWpNTG/hq1bCaHXfbdm9T+Y6V+5//EHAVc1TLBE=", + "lastModified": 1753867039, + "narHash": "sha256-lROFmwSsxtjNsf7U/mQ/jW4GQbnwUQstu0MJrhPdGhg=", "owner": "huggingface", "repo": "hf-nix", - "rev": "3fcd1e1b46da91b6691261640ffd6b7123d0cb9e", + "rev": "dab75553a43839d0dad876c4e59bbcd8b41acd72", "type": "github" }, "original": { @@ -98,32 +98,33 @@ ] }, "locked": { - "lastModified": 1753256281, - "narHash": "sha256-CfL3Fyf2ih7OtyL7ScZUCwOeCj+gjlRyPykhR6Zbt3I=", + "lastModified": 1753867163, + "narHash": "sha256-x3D5QthCR+buQ7oX9+HRhYt4sqxHzYqLYdK+HHAf0Qc=", "owner": "huggingface", "repo": "kernel-builder", - "rev": "dcbbdf2d3c8e78b27321b205b2c9d67ffce6a706", + "rev": "768759f1423ad96ec1232fb16f936718a6742cc9", "type": "github" }, "original": { "owner": "huggingface", + "ref": "torch-2.8", "repo": "kernel-builder", "type": "github" } }, "nixpkgs": { "locked": { - "lastModified": 1747820358, - "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=", - "owner": "danieldk", + "lastModified": 1752785354, + "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=", + "owner": "nixos", "repo": "nixpkgs", - "rev": "d3c1681180717528068082103bf323147de6ab0b", + "rev": "d38025438a6ee456758dc03188ca6873a415463b", "type": "github" }, "original": { - "owner": "danieldk", - "ref": "cudatoolkit-12.9-kernel-builder", + "owner": "nixos", "repo": "nixpkgs", + "rev": "d38025438a6ee456758dc03188ca6873a415463b", "type": "github" } }, diff --git a/flake.nix b/flake.nix index 8e228480f9c833760932be6e90951fa5570f159c..ee6bd0b89ebcd6fb0b422d4fb39542fdf1d1b70d 100644 --- a/flake.nix +++ b/flake.nix @@ -2,7 +2,7 @@ description = "Flake for megablocks_moe kernel"; inputs = { - kernel-builder.url = "github:huggingface/kernel-builder"; + kernel-builder.url = "github:huggingface/kernel-builder/torch-2.8"; }; outputs = diff --git a/torch-ext/megablocks/_megablocks_g6r5q4zwqmcls.abi3.so b/torch-ext/megablocks/_megablocks_g6r5q4zwqmcls.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..347b646ef98a0402050a3e0cc729b046846d2409 --- /dev/null +++ b/torch-ext/megablocks/_megablocks_g6r5q4zwqmcls.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb79fcc23aae59174b8ff2a9d6433246ddcb44f9006fd9164fa8788190f8606a +size 5197576 diff --git a/torch-ext/megablocks/_megablocks_lboajlxwrdgqk.abi3.so b/torch-ext/megablocks/_megablocks_lboajlxwrdgqk.abi3.so new file mode 100755 index 
0000000000000000000000000000000000000000..5ce82f63f681d14563fa83f4a6b6cb748cf5d480 --- /dev/null +++ b/torch-ext/megablocks/_megablocks_lboajlxwrdgqk.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5f7bf298897957464fff4d0c082f95fb5da5eb1a718a95d1d77082ccb2b891bc +size 5189304 diff --git a/torch-ext/megablocks/_megablocks_twnwbxx53zgza.abi3.so b/torch-ext/megablocks/_megablocks_twnwbxx53zgza.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..40c01fefcba1be4ec9d161f3276075b2c548a342 --- /dev/null +++ b/torch-ext/megablocks/_megablocks_twnwbxx53zgza.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1995b97e8d3124fd380527ae526d518cafb663297c9ee172aa840dd4b54235f +size 5208064
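Illustrative usage sketch for the vendored stk package added in this patch. This is a minimal sketch, not part of the diff itself; it assumes the built megablocks wheel is importable, a CUDA/ROCm device with Triton support is available, and the file name below is hypothetical.

# illustrative_stk_usage.py (hypothetical)
import torch
from megablocks import stk

# Dense fp16 operands whose dimensions are multiples of the 128x128 blocking.
a = torch.randn(512, 512, dtype=torch.float16, device="cuda")
b = torch.randn(512, 512, dtype=torch.float16, device="cuda")

# Random block-sparse topology keeping ~50% of the 128x128 blocks.
topo = stk.random.mask(512, 512, 0.5, 128).cuda()

out = stk.ops.sdd(a, b, topo)        # block-sparse = dense @ dense
y = stk.ops.dsd(out, b)              # dense = block-sparse @ dense
dense_out = stk.ops.to_dense(out)    # materialize the sparse result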