Akos Hadnagy committed on
Commit dd2b6c2 · 1 Parent(s): 1e1ffe8

Push build

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gitignore +2 -0
  3. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/__init__.py +202 -0
  4. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/__init__.py +10 -0
  5. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/activation_fn.py +33 -0
  6. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/all_to_all.py +54 -0
  7. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/arguments.py +101 -0
  8. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/common.py +26 -0
  9. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmlp_registry.py +42 -0
  10. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmoe.py +337 -0
  11. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/gelu.py +52 -0
  12. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/glu.py +244 -0
  13. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/memory_test.py +103 -0
  14. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mlp.py +587 -0
  15. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/moe.py +507 -0
  16. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mpu.py +94 -0
  17. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/router.py +116 -0
  18. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +32 -0
  19. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so +3 -0
  20. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_ops.py +9 -0
  21. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/__init__.py +2 -0
  22. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/kernels.py +543 -0
  23. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/bak.__init__.py +23 -0
  24. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/benchmark_util.py +35 -0
  25. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/__init__.py +2 -0
  26. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/backend.py +33 -0
  27. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/ops.py +33 -0
  28. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm_util.py +31 -0
  29. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/layers.py +1001 -0
  30. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/__init__.py +35 -0
  31. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +63 -0
  32. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_gather.py +37 -0
  33. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_scatter.py +59 -0
  34. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/cumsum.py +52 -0
  35. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/gather.py +38 -0
  36. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram.py +27 -0
  37. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram_benchmark.py +78 -0
  38. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/matmul_benchmark.py +415 -0
  39. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_gather.py +55 -0
  40. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter.py +98 -0
  41. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +66 -0
  42. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/permute_benchmark.py +149 -0
  43. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/repeat.py +10 -0
  44. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/replicate.py +36 -0
  45. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/round_up.py +14 -0
  46. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/scatter.py +72 -0
  47. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort.py +38 -0
  48. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort_benchmark.py +85 -0
  49. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/stk_autocast.py +39 -0
  50. build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sum.py +9 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+**/__pycache__/
+**/*egg-info/
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/__init__.py ADDED
@@ -0,0 +1,202 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch

from ._ops import ops

#from .grouped_gemm import backend as gg_backend
#from .grouped_gemm import ops as gg_ops


from ._layers.arguments import Arguments
from ._layers.dmoe import ParallelDroplessMLP, dMoE
from ._layers.glu import SparseGLU
from ._layers.mlp import MLP, SparseMLP
from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss

from . import layers

# This section contains the direct kernel exports (not included in the original code)
def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute exclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    result = ops.exclusive_cumsum(x, dim)
    out.copy_(result)
    return out


def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute inclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    result = ops.inclusive_cumsum(x, dim)
    out.copy_(result)
    return out


def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    """
    Compute histogram of input tensor values.

    Args:
        x: Input tensor
        num_bins: Number of histogram bins

    Returns:
        Histogram tensor with counts for each bin
    """
    return ops.histogram(x, num_bins)


def indices(
    padded_bins: torch.Tensor,
    block_size: int,
    output_block_rows: int,
    output_block_columns: int,
) -> torch.Tensor:
    """
    Construct indices from padded bins for sparse operations.

    Args:
        padded_bins: Tensor containing bin boundaries
        block_size: Size of each block
        output_block_rows: Number of rows in output blocks
        output_block_columns: Number of columns in output blocks

    Returns:
        Tensor containing constructed indices
    """
    return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)


def replicate_forward(
    x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Forward pass of replicate operation - replicate values according to bin sizes.

    Args:
        x: Input tensor with values to replicate
        bins: Tensor containing bin sizes
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.replicate_forward(x, bins, out)


def replicate_backward(
    grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Backward pass of replicate operation - reduce gradients back to bins.

    Args:
        grad: Gradient tensor to reduce
        bins: Tensor containing bin sizes
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.replicate_backward(grad, bins, out)


def sort(
    x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
) -> torch.Tensor:
    """
    Radix sort with index tracking.

    Args:
        x: Input tensor to sort
        end_bit: Number of bits to consider in sorting
        x_out: Output tensor for sorted values
        iota_out: Output tensor for sorted indices

    Returns:
        The sorted values tensor
    """
    return ops.sort(x, end_bit, x_out, iota_out)


# Convenience functions for common use cases
def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
    """
    Compute cumulative sum with automatic output allocation.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum (default: last dimension)
        exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum

    Returns:
        New tensor containing the cumulative sum
    """
    out = torch.empty_like(x)
    if exclusive:
        return exclusive_cumsum(x, dim, out)
    else:
        return inclusive_cumsum(x, dim, out)


def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sort tensor and return both sorted values and indices.

    Args:
        x: Input tensor to sort
        end_bit: Number of bits to consider in sorting

    Returns:
        Tuple of (sorted_values, sorted_indices)
    """
    x_out = torch.empty_like(x)
    iota_out = torch.empty_like(x)
    sort(x, end_bit, x_out, iota_out)
    return x_out, iota_out


# Export public API
__all__ = [
    "MyReplacementLayer",
    # Direct kernel exports
    "exclusive_cumsum",
    "inclusive_cumsum",
    "histogram",
    "indices",
    "replicate_forward",
    "replicate_backward",
    "sort",
    "cumsum",
    "argsort",
    # Original exports
    "Arguments",
    "ParallelDroplessMLP",
    "dMoE",
    "SparseGLU",
    "MLP",
    "SparseMLP",
    "MoE",
    "ParallelMLP",
    "get_load_balancing_loss",
]
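As a quick illustration of the convenience wrappers exported above, here is a minimal smoke test. This is a sketch, not part of the commit: it assumes the built wheel is importable as `megablocks` and that a CUDA device is available, since the underlying ops are compiled GPU kernels.

import torch
import megablocks

# Small integer tensor on the GPU; the sort/histogram kernels operate on ints.
x = torch.randint(0, 8, (16,), dtype=torch.int32, device="cuda")

# Inclusive vs. exclusive cumulative sums (each call allocates a fresh output).
inc = megablocks.cumsum(x, dim=0)
exc = megablocks.cumsum(x, dim=0, exclusive=True)

# Radix argsort over the low 3 bits; returns (sorted_values, sorted_indices).
vals, idx = megablocks.argsort(x, end_bit=3)
print(inc[-1].item(), exc[0].item(), vals[0].item())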
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/__init__.py ADDED
@@ -0,0 +1,10 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

# from megablocks.layers.dmoe import dMoE
from .moe import MoE

__all__ = [
    'MoE',
    # 'dMoE',
]
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/activation_fn.py ADDED
@@ -0,0 +1,33 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Union

import torch
from ..stk import Matrix


def act_fn(
    x: Matrix,
    function: Callable,
    return_grad_fn: bool = False,
    **kwargs,
) -> Union[tuple[Matrix, Any] | Matrix]:
    assert isinstance(x, Matrix)
    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
        if return_grad_fn:
            x.data.requires_grad = True
        out = function(x.data, **kwargs)
        y = Matrix(
            x.size(),
            out,
            x.row_indices,
            x.column_indices,
            x.offsets,
            x.column_indices_t,
            x.offsets_t,
            x.block_offsets_t,
        )
        if return_grad_fn:
            return y, out.backward
        return y
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/all_to_all.py ADDED
@@ -0,0 +1,54 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributed as dist


class AllToAllOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
        out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)

        ctx.input_shape = x.shape
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
        ctx.group = group
        handle = dist.all_to_all_single(
            out,
            x,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op,
        )
        return out, handle

    @staticmethod
    def backward(ctx, grad, _):
        if ctx.needs_input_grad[0]:
            out = torch.empty(
                ctx.input_shape,
                device=grad.device,
                dtype=grad.dtype,
            )
            dist.all_to_all_single(
                out,
                grad,
                output_split_sizes=ctx.input_split_sizes,
                input_split_sizes=ctx.output_split_sizes,
                group=ctx.group,
            )
            return out, None, None, None, None
        return None, None, None, None, None


def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
    return AllToAllOp.apply(
        x,
        output_split_sizes,
        input_split_sizes,
        group,
        async_op,
    )
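For context, a minimal sketch of how this autograd-aware all_to_all wrapper might be exercised inside an already-initialized process group (e.g. launched with torchrun). The split sizes below are illustrative assumptions, not values used anywhere in this commit.

import torch
import torch.distributed as dist
from megablocks._layers.all_to_all import all_to_all

def demo():
    # Assumes dist.init_process_group(...) has already run and a GPU is bound.
    world_size = dist.get_world_size()
    x = torch.randn(world_size, 8, device="cuda")
    splits = [1] * world_size  # send/receive one row per rank (illustrative)
    out, handle = all_to_all(x, splits, splits, dist.group.WORLD, async_op=True)
    handle.wait()  # async_op=True returns a work handle; wait before using `out`
    return out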
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/arguments.py ADDED
@@ -0,0 +1,101 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import dataclasses
from functools import partial
from typing import Any, Callable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F

# import megablocks.grouped_gemm_util as grouped_gemm
from .. import grouped_gemm_util as grouped_gemm

# Type annotation for in-place Tensor initialization function.
InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]

_ALLOWED_BITWIDTHS = (-1, 4, 8)

DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')


@dataclasses.dataclass
class Arguments:
    # Model arguments.
    hidden_size: int = 1024
    ffn_hidden_size: int = 4096
    num_layers: int = 1
    bias: bool = True
    return_bias: bool = True
    activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN

    # MoE arguments.
    moe_num_experts: int = 1
    moe_top_k: int = 1
    moe_capacity_factor: int = 1
    moe_normalize_expert_weights: Optional[Union[int, float]] = None
    moe_loss_weight: float = 0.1
    moe_jitter_eps: Optional[float] = None
    moe_lbl_in_fp32: bool = False

    # Parallelism arguments.
    moe_expert_model_parallelism: bool = False
    expert_parallel_group: Optional[dist.ProcessGroup] = None
    pipeline_model_parallel_size: int = 1
    num_layers_per_virtual_pipeline_stage: Optional[int] = None

    # Compute arguments.
    memory_optimized_mlp: bool = False
    mlp_type: str = 'mlp'
    mlp_impl: str = 'sparse'

    # Initialization arguments.
    fp16: bool = True
    bf16: bool = False
    device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
    init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
    output_layer_init_method: InitFn = init_method

    # Benchmarking arguments.
    uniform_expert_assignment: bool = False

    # shared expert arguments
    shared_expert: bool = False  # enable using shared expert
    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
    fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,)  # kwargs for custom fc layers
    remat_act_fn: bool = True  # enable act fn to be rematerialized instead of stored
    shared_expert_hidden_size: Optional[
        int] = None  # hidden size of the shared expert IF we want to set it to something different from hidden_size
    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)

    # Router Z-loss arguments
    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
    moe_zloss_in_fp32: bool = False

    def __post_init__(self):
        # Sparse MLP is not supported with triton >=3.2.0
        # TODO: Remove this once sparse is supported with triton >=3.2.0
        if self.__getattribute__('mlp_impl') == 'sparse':
            try:
                import triton
                if triton.__version__ >= '3.2.0':
                    raise ValueError(
                        'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
                    )
            except ImportError:
                raise ImportError('Triton is required for sparse MLP implementation')

        if self.__getattribute__('mlp_impl') == 'grouped':
            grouped_gemm.assert_grouped_gemm_is_available()

        if self.shared_expert_hidden_size is None:
            self.shared_expert_hidden_size = self.ffn_hidden_size


def from_megatron(megatron_args: Any):
    args = Arguments()
    for field in dataclasses.fields(args):
        if hasattr(megatron_args, field.name):
            setattr(args, field.name, getattr(megatron_args, field.name))
    return args
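To make the dataclass above concrete, here is an illustrative construction of `Arguments` for a small dMoE layer. The values are example assumptions only; `mlp_impl="grouped"` sidesteps the triton>=3.2.0 restriction enforced in `__post_init__` but requires the grouped_gemm backend to be available.

import torch
from megablocks import Arguments, dMoE

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_impl="grouped",   # or "sparse" with triton < 3.2.0
    fp16=False,
    bf16=True,
    device=torch.device("cuda"),
)
layer = dMoE(args)  # experts MLP is built via the registry described below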
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/common.py ADDED
@@ -0,0 +1,26 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch

from .arguments import Arguments


def dtype(args: Arguments):
    if args.fp16:
        return torch.float16
    elif args.bf16:
        return torch.bfloat16
    return None


def cast_if_autocast_enabled(tensor):
    if torch.is_autocast_enabled():
        if tensor.device.type == 'cuda':
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == 'cpu':
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmlp_registry.py ADDED
@@ -0,0 +1,42 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Union

from . import glu, mlp
from .arguments import Arguments

MlpType = Union[mlp.SparseMLP, glu.SparseGLU]

_REGISTRY = {
    'mlp': {
        'grouped': mlp.GroupedMLP,
        'sparse': mlp.SparseMLP,
    },
    'glu': {
        'grouped': glu.GroupedGLU,
        'sparse': glu.SparseGLU,
    },
}


def get(args: Arguments) -> MlpType:
    """Returns an MLP for use in a dMoE instance.

    Uses the provided arguments to instantiate the appropriate
    MLP instance. This only contains MLPs for use in dMoEs
    (i.e. only for the dropless versions of MoEs).

    Args:
        args: propagated Arguments dataclass.

    Returns:
        An instantiated MLP constructed using the input args.
    """
    if args.mlp_type not in _REGISTRY:
        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')

    if args.mlp_impl not in _REGISTRY[args.mlp_type]:
        raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)

    return _REGISTRY[args.mlp_type][args.mlp_impl](args)
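The registry dispatches purely on the `(mlp_type, mlp_impl)` pair from `Arguments`. A hypothetical lookup, assuming a CUDA device so the default device factory in `Arguments` resolves:

from megablocks._layers import dmlp_registry
from megablocks._layers.arguments import Arguments

args = Arguments(mlp_type="glu", mlp_impl="grouped")  # selects glu.GroupedGLU
expert_mlp = dmlp_registry.get(args)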
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmoe.py ADDED
@@ -0,0 +1,337 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch

# try:
#     import stk.ops
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
#     )

# import megablocks.ops as ops
# # from megablocks.ops import ops
# from megablocks.layers import common, dmlp_registry, moe, mpu
# from megablocks.layers.arguments import Arguments

from .. import stk
from .. import ops
from . import common, dmlp_registry, moe, mpu
from .arguments import Arguments

def promote_scalar(x):
    return x.view(1) if not len(x.size()) else x


class ParallelDroplessMLP(moe.ParallelMLP):

    def __init__(self, args: Arguments):
        super(ParallelDroplessMLP, self).__init__(args)
        self.hidden_size = args.hidden_size
        self.ffn_hidden_size = mpu.features_per_rank(args)
        self.blocking = 128
        self.mlp = dmlp_registry.get(args)

        # Calculate the number of bits needed to represent the column indices
        # in the intermediate sparse matrix.
        max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
        self.transpose_sort_end_bit = max(
            int(np.ceil(np.log2(max_column_index))),
            1,
        )

    def sparse_transpose(self, size, row_indices, column_indices, offsets):
        block_columns = size[1] // self.blocking

        # Sort row indices by column indices to get the transposed matrix's
        # column indices.
        #
        # NOTE: Our sort operation uses the same width indices as the input values.
        # To avoid overflow when we have large activation matrices we cast to
        # 32-bit before sorting.
        _, gather_indices = ops.sort(
            column_indices.int(),
            self.transpose_sort_end_bit,
        )

        # There are a constant number of blocks in every row of the sparse matrix.
        # A block's offset is:
        #
        # row_index * blocks_per_row + column_index % blocks_per_row
        #
        # Once we have the block offsets ordered for transposition we can divide
        # by blocks_per_row to get the transposed column indices.
        column_indices_t = row_indices.gather(0, gather_indices.long())
        block_offsets_t = gather_indices.int()

        zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
        nnz_per_column = ops.histogram(column_indices, block_columns)
        nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
        if nnz_per_column.dim() == 0:
            # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
            nnz_per_column = nnz_per_column.unsqueeze(0)
        offsets_t = torch.cat([zero, nnz_per_column])
        return column_indices_t, offsets_t, block_offsets_t

    def topology(self, x, padded_bins):
        padded_tokens, _ = x.size()
        assert padded_tokens % self.blocking == 0
        if self.ffn_hidden_size % self.blocking != 0:
            raise ValueError(
                f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
                f'the block size {self.blocking}. Please update your configuration.',
            )

        # Offsets for the sparse matrix. All rows have the
        # same number of nonzero blocks dictated by the
        # dimensionality of a single expert.
        block_rows = padded_tokens // self.blocking
        blocks_per_row = self.ffn_hidden_size // self.blocking
        offsets = torch.arange(
            0,
            block_rows * blocks_per_row + 1,
            blocks_per_row,
            dtype=torch.int32,
            device=x.device,
        )

        # Indices for the sparse matrix. The indices for
        # the intermediate matrix are dynamic depending
        # on the mapping of tokens to experts.
        column_indices = ops.topology(
            padded_bins,
            self.blocking,
            block_rows,
            blocks_per_row,
        )

        # TODO(tgale): This is unused. Remove the need for this in stk.
        # For now, use meta init to save the device memory.
        data = torch.empty(
            column_indices.numel(),
            self.blocking,
            self.blocking,
            dtype=common.dtype(self.args),
            device='meta',
        )
        shape = (
            padded_tokens,
            self.ffn_hidden_size * mpu.experts_per_rank(self.args),
        )
        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
        column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
            shape,
            row_indices,
            column_indices,
            offsets,
        )
        return stk.Matrix(
            shape,
            data,
            row_indices,
            column_indices,
            offsets,
            column_indices_t,
            offsets_t,
            block_offsets_t,
        )

    def indices_and_padded_bins(self, top_experts):
        # Sort the expert ids to produce the scatter/gather
        # indices for the permutation.
        top_experts = top_experts.int()
        bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)

        # Histogram the expert ids to identify the number of
        # tokens routed to each expert.
        tokens_per_expert = ops.histogram(top_experts, self.num_experts)

        # Round the token counts up to the block size used in
        # the matrix multiplications. Calculate the starting
        # position of each bin.
        padded_tokens_per_expert = ops.round_up(
            tokens_per_expert,
            self.blocking,
        )
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        padded_bins = promote_scalar(padded_bins)

        # Calculate the bin bounds for the sorted tokens.
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
        bins = promote_scalar(bins)
        return indices, bin_ids, bins, padded_bins, tokens_per_expert

    def sparse_forward_once(self, x, expert_weights, top_experts):
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.padded_gather(
            x,
            indices,
            bin_ids,
            bins,
            padded_bins,
            self.top_k,
        )

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(x, padded_bins)

        # Perform the expert computation.
        x = self.mlp(x, topo)

        # Un-route the data for the MoE output.
        x = ops.padded_scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            self.top_k,
        )
        return x, tokens_per_expert

    # For use in the base-class parallel_forward_once.
    def sparse_permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,  # unused
        top_k,
    ):

        # Round the token counts up to the block size used in the matrix
        # multiplication. Calculate the starting position of each bin.
        padded_tokens_per_expert = ops.round_up(
            tokens_per_expert,
            self.blocking,
        )
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        padded_bins = promote_scalar(padded_bins)

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(x, padded_bins)

        # Perform the expert computation.
        x = self.mlp(x, topo)

        # Un-route the data for the MoE output.
        return ops.padded_scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            top_k,
        )

    def grouped_forward_once(self, x, expert_weights, top_experts):
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))

        out = self.grouped_permute_and_compute(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            -1,  # unused
            self.args.moe_top_k,
        )
        return out, tokens_per_expert

    def grouped_permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,  # unused
        top_k,
    ):

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.gather(x, indices, bin_ids, bins, top_k)

        # Perform the expert computation.
        x = self.mlp(x, tokens_per_expert)

        # Un-route the data for the MoE output.
        return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)

    def forward_once(self, x, expert_weights, top_experts):
        if self.args.mlp_impl == 'sparse':
            return self.sparse_forward_once(x, expert_weights, top_experts)
        else:
            return self.grouped_forward_once(x, expert_weights, top_experts)

    def permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,
        top_k,
    ):
        if self.args.mlp_impl == 'sparse':
            return self.sparse_permute_and_compute(
                x,
                tokens_per_expert,
                indices,
                bin_ids,
                expert_weights,
                bins,
                expert_capactiy,
                top_k,
            )
        else:
            return self.grouped_permute_and_compute(
                x,
                tokens_per_expert,
                indices,
                bin_ids,
                expert_weights,
                bins,
                expert_capactiy,
                top_k,
            )


class dMoE(moe.MoE):

    def _init_experts_mlp(self, args: Arguments):
        return ParallelDroplessMLP(args)
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/gelu.py ADDED
@@ -0,0 +1,52 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

# try:
#     import stk
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
#     )

from .. import stk

import torch
import torch.nn.functional as F


@torch.jit.script
def _gelu_backward_inplace(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
    return g.mul_(ff)


def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
    # NOTE: The two sparse matrices must have the same topology.
    if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
        return stk.Matrix(
            x.size(),
            _gelu_backward_inplace(grad.data, x.data),
            x.row_indices,
            x.column_indices,
            x.offsets,
            x.column_indices_t,
            x.offsets_t,
            x.block_offsets_t,
        )
    return _gelu_backward_inplace(grad, x)


def gelu(x: stk.Matrix):
    assert isinstance(x, stk.Matrix)
    return stk.Matrix(
        x.size(),
        F.gelu(x.data, approximate='tanh'),
        x.row_indices,
        x.column_indices,
        x.offsets,
        x.column_indices_t,
        x.offsets_t,
        x.block_offsets_t,
    )
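As a sanity check on the hand-derived backward above (dense tensors only, the sparse stk.Matrix path is not exercised), the gradient it produces should agree with autograd through the tanh-approximated GELU. This is an illustrative sketch; the tolerance is an assumption.

import torch
import torch.nn.functional as F
from megablocks._layers.gelu import _gelu_backward_inplace

x = torch.randn(64, dtype=torch.float64, requires_grad=True)
g = torch.randn(64, dtype=torch.float64)

# Reference gradient via autograd through the tanh-approximated GELU.
y = F.gelu(x, approximate="tanh")
(autograd_grad,) = torch.autograd.grad(y, x, grad_outputs=g)

# Hand-written backward; note it mutates its first argument in place.
manual_grad = _gelu_backward_inplace(g.clone(), x.detach())
print(torch.allclose(autograd_grad, manual_grad, atol=1e-6))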
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/glu.py ADDED
@@ -0,0 +1,244 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

# import stk.ops
# try:
#     import stk.ops
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
#     )

from .. import stk

import torch

# from megablocks import grouped_gemm_util as gg
# from megablocks.layers import common, mpu
# from megablocks.layers.activation_fn import act_fn
# from megablocks.layers.arguments import Arguments
# from megablocks.layers.mlp import (
#     SharedMLP,
#     SparseMLP,
#     create_dmoe_expert_weights,
#     resolve_dtensor,
# )

from .. import grouped_gemm_util as gg
from . import common, mpu
from .activation_fn import act_fn
from .arguments import Arguments
from .mlp import (
    SharedMLP,
    SparseMLP,
    create_dmoe_expert_weights,
    resolve_dtensor,
)


class SparseGLU(SparseMLP):

    def __init__(self, args: Arguments):
        super().__init__(args)
        self.v1 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        with torch.no_grad():
            self.v1.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.init_method,
                ),
            )

        mpu.set_expert_model_parallel_attributes(
            self.v1,
            self._should_set_parallelism_attribute,
        )

    def forward(self, x, topo):
        if self.args.memory_optimized_mlp:
            raise NotImplementedError(
                'Memory optimized implementation not yet supported with GLU with sparse kernels.',
            )

        w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)

        # Compute the GLU.
        x1 = stk.ops.sdd(x, w1.t(), topo)
        x2 = stk.ops.sdd(x, v1.t(), topo)

        activation_fn_out = act_fn(x1, self.args.activation_fn)
        x1 = stk.ops.mul(activation_fn_out, x2)

        return stk.ops.dsd(x1, w2)


class MemoryOptimizedGroupedGLU(torch.autograd.Function):
    """GroupedMLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
        # Cast inputs using ctx dtype from AMP
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            v1 = v1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)
        # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
        if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")

        # Layer 0: x @ w1.t().
        assert gg.backend is not None
        sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
        v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)

        # GeLU.
        activation_fn_out = activation_fn(sdd_out) * v1_out

        # Layer 1: x @ w2.
        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)

        # NOTE: Save the input to the layer and the activation_fn input for
        # gradient computation. We'll re-compute the activation_fn forward
        # pass in the backward pass to avoid materializing another
        # intermediate.
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, ddsd_out):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack saved tensors
        # dtype = ctx.dtype
        saved_tensors = ctx.saved_tensors
        w1, v1, w2 = saved_tensors[:3]
        batch_sizes = saved_tensors[3]
        x = saved_tensors[4]
        sdd_out, v1_out = saved_tensors[5:7]

        # Rematerialize activation_fn output.
        activation_fn = ctx.activation_fn
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            v1_out.requires_grad = True
            activation_fn_out = activation_fn(sdd_out) * v1_out
            activation_grad_fn = activation_fn_out.backward

        # Compute dw2 with recomputed activation_fn output.
        assert gg.backend is not None
        dw2 = gg.backend.gmm(
            activation_fn_out,
            ddsd_out,
            batch_sizes,
            trans_a=True,
        )

        # Compute dactivation_fn_out.
        #
        # NOTE: We reuse the activation_fn_out allocation.
        dactivation_fn_out = activation_fn_out
        gg.backend.gmm(
            ddsd_out,
            w2,
            batch_sizes,
            trans_b=True,
            c=dactivation_fn_out,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        assert activation_grad_fn is not None
        activation_grad_fn(dactivation_fn_out)
        dsdd_out = sdd_out.grad
        dv1_out = v1_out.grad

        # Compute dw1.
        dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)

        # Compute dv1.
        dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        dx = ddsd_out
        gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
        dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
        return dx, dw1, dv1, dw2, None, None


memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply


class GroupedGLU(SparseGLU):

    def forward(self, x, tokens_per_expert):
        batch_sizes = tokens_per_expert.cpu().to(torch.long)
        w1, v1, w2 = (
            self.scale_grad(self.w1),
            self.scale_grad(self.v1),
            self.scale_grad(self.w2),
        )
        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)

        # Re-shape the weights for the grouped GEMMs.
        ne = mpu.experts_per_rank(self.args)
        w1 = w1.view(ne, -1, self.args.hidden_size)
        v1 = v1.view(ne, -1, self.args.hidden_size)
        w2 = w2.view(ne, -1, self.args.hidden_size)

        if self.args.memory_optimized_mlp:
            return memory_optimized_grouped_glu(
                x,
                w1,
                v1,
                w2,
                batch_sizes,
                self.args.activation_fn,
            )

        # Compute the MLP.
        assert gg.ops is not None
        x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
        x1 = self.args.activation_fn(x1) * x2
        return gg.ops.gmm(x1, w2, batch_sizes)


class SharedGLU(SharedMLP):
    """GLU for shared expert.

    Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
    """

    def __init__(self, args: Arguments):
        super().__init__(args)
        self.gate_proj = args.fc_cls(
            args.hidden_size,
            self.args.shared_expert_hidden_size,
            **self.fc_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/memory_test.py ADDED
@@ -0,0 +1,103 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import gc

import torch
import torch.distributed as dist

# from megablocks.layers import arguments, dmoe
from . import arguments, dmoe

_TESTS = ((8, 2048, 4096, 4096, 32, 4),)


def get_tensors():
    ptrs = set()
    out = []
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            if not obj.is_contiguous() or obj.data_ptr() in ptrs:
                continue
            out.append(obj)
            ptrs.add(obj.data_ptr())
    return out


def test_memory(
    group,
    batch_size,
    sequence_length,
    hidden_size,
    ffn_hidden_size,
    num_experts,
    top_k,
):
    args = arguments.Arguments(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
        moe_expert_model_parallelism=True,
        expert_parallel_group=group,
        fp16=False,
        bf16=True,
        device=torch.cuda.current_device(),
    )
    layer = dmoe.dMoE(args).cuda()

    x = torch.randn((batch_size, sequence_length, hidden_size),
                    device=torch.cuda.current_device(),
                    dtype=torch.bfloat16).requires_grad_(True)
    torch.cuda.empty_cache()

    # Run forward + backward.
    # with torch.autograd.detect_anomaly():
    out, _ = layer(x)
    out.mean().backward()

    # Report peak memory.
    mem = torch.cuda.max_memory_allocated()
    print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
    print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)

    # Calculate weight and gradient memory usage.
    weight_memory = 2 * (
        layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
    )

    def grad_numel(x):
        if x.grad is not None:
            return x.grad.numel()
        return 0

    grad_memory = 2 * (
        grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
    )
    weight_memory += grad_memory

    print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
    print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)

    # Manually calculate GPU memory usage from the garbage
    # collector.
    gc.collect()
    total = 0
    tensors = get_tensors()
    tensors = sorted(tensors, key=lambda x: -x.numel())
    for i, t in enumerate(tensors):
        total += t.numel()
        print(f'{i}: {t.shape}, {t.numel() * 2}')
    del tensors

    print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))


if __name__ == '__main__':
    assert dist.is_available()
    group = dist.init_process_group(backend='nccl')
    local_rank = dist.get_rank(group)
    torch.cuda.set_device(local_rank)

    for args in _TESTS:
        test_memory(group, *args)
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mlp.py ADDED
@@ -0,0 +1,587 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any

# try:
#     import stk
#     import stk.backend.triton_kernels
#     import stk.ops
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
#     )

from .. import stk

import torch
from packaging import version

# from megablocks import grouped_gemm_util as gg
# from megablocks.layers import common, gelu, mpu
# from megablocks.layers.activation_fn import act_fn
# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn

from .. import grouped_gemm_util as gg
from . import common, gelu, mpu
from .activation_fn import act_fn
from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn

class ScaleGradient(torch.autograd.Function):

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx: Any, x: torch.Tensor, scale: float):
        ctx.scale = scale
        return x

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx: torch.Tensor, grad: torch.Tensor):
        return grad * ctx.scale, None


scale_gradient = ScaleGradient.apply


def resolve_dtensor(weight: torch.Tensor):
    if version.parse(torch.__version__) >= version.parse('2.0.0'):
        from torch.distributed._tensor import DTensor
        if isinstance(weight, DTensor):
            return weight.to_local()
    return weight


def create_moe_expert_weights(
    args: Arguments,
    num_experts: int,
    ffn_hidden_size: int,
    hidden_size: int,
    init_method: InitFn,
):
    # Create the entire weight matrix such that the sampled weights will
    # not vary between data parallelism and expert model parallelism for
    # the same random seed.
    master_weights = torch.empty(
        num_experts,
        ffn_hidden_size,
        hidden_size,
        device=args.device,
        dtype=common.dtype(args),
    )
    init_method(master_weights)

    if not args.moe_expert_model_parallelism:
        return master_weights

    # Calculate the amount of sharding in each dimension.
    expert_sharding_degree = mpu.expert_sharding_degree(args)
    hidden_sharding_degree = mpu.hidden_sharding_degree(args)

    # Calculate the experts per rank.
    #
    # NOTE: We assign ranks to be expert parallel before going
    # tensor parallel.
    rank = mpu.get_expert_parallel_rank(args)
    expert_rank = rank % expert_sharding_degree
    num_experts_per_rank = num_experts // expert_sharding_degree
    start_expert = expert_rank * num_experts_per_rank
    end_expert = (expert_rank + 1) * num_experts_per_rank

    # Calculate the rows per rank.
    row_rank = rank // expert_sharding_degree
    num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
    start_row = row_rank * num_rows_per_rank
    end_row = (row_rank + 1) * num_rows_per_rank

    # Slice the weight matrix to get the chunk for this rank.
    with torch.no_grad():
        weights = master_weights[start_expert:end_expert, start_row:end_row]
    return weights


class MLP(torch.nn.Module):

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
        experts_per_rank = mpu.experts_per_rank(args)

        self.w1 = torch.nn.Parameter(
            torch.empty(
                experts_per_rank,
                args.hidden_size,
                mpu.features_per_rank(args),
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        self.w2 = torch.nn.Parameter(
            torch.empty(
                experts_per_rank,
                mpu.features_per_rank(args),
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        mpu.set_expert_model_parallel_attributes(
            self.w1,
            args.moe_expert_model_parallelism,
        )
        mpu.set_expert_model_parallel_attributes(
            self.w2,
            args.moe_expert_model_parallelism,
        )

        # Initialize the parameters for the MLP.
        #
        # NOTE: It is important that we create the weight tensors prior
        # to creating the master weights and slicing out the piece for
        # this rank. If the master weights are created first the PyTorch
        # caching allocator appears to use the same memory block for these
        # and the slice which causes large increases in our peak memory
        # usage.
        with torch.no_grad():
            w1 = create_moe_expert_weights(
                args,
                args.moe_num_experts,
                args.ffn_hidden_size,
                args.hidden_size,
                args.init_method,
            )
            self.w1.copy_(w1.transpose(1, 2).contiguous())
            self.w2.copy_(
                create_moe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.output_layer_init_method,
                ),
            )

        self.gradient_scale = None
        if self.args.moe_expert_model_parallelism:
            self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)

    def scale_grad(self, w):
        if self.gradient_scale is None:
            return w
        return scale_gradient(w, self.gradient_scale)

    def forward(self, x):
        w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
        w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
        x = torch.bmm(x, w1)
        x = self.args.activation_fn(x)
        return torch.bmm(x, w2)


def create_dmoe_expert_weights(
    args: Arguments,
    num_experts: int,
    rows: int,
    columns: int,
    init_method: InitFn,
):
    weights = create_moe_expert_weights(
        args,
        num_experts,
        rows,
        columns,
        init_method,
    )
    return weights.view([-1, columns])


class MemoryOptimizedMLP(torch.autograd.Function):
    """Sparse MLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, w2, topo, activation_fn):
        # Cast inputs using ctx dtype from AMP
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)
        # x: [m, k], w1: [n, k], w2: [n, k]
        if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")

        topo_tensors = (
            topo.row_indices,
            topo.column_indices,
            topo.offsets,
            topo.column_indices_t,
            topo.offsets_t,
            topo.block_offsets_t,
        )

        # Layer 0: x @ w1.t().
        sdd_out = stk.ops.sdd(x, w1.t(), topo)

        # GeLU.
        activation_fn_out = act_fn(sdd_out, activation_fn)

        # Layer 1: x @ w2.
        dsd_out = stk.ops.dsd(activation_fn_out, w2)

        # NOTE: Save the input to the layer and the activation_fn input for
        # gradient computation. We'll re-compute the activation_fn forward
        # pass in the backward pass to avoid materializing another
        # intermediate.
        ctx.shape = topo.shape
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.data.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, ddsd_out):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # unpack saved tensors
        # dtype = ctx.dtype
        saved_tensors = ctx.saved_tensors
        w1, w2 = saved_tensors[:2]
        topo_tensors = saved_tensors[2:8]
        x = saved_tensors[8]
        sdd_out_data = saved_tensors[9]

        # rematerialize activation function output
        activation_fn = ctx.activation_fn
        sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
        activation_fn_out, activation_grad_fn = act_fn(
            sdd_out,
            activation_fn,
            return_grad_fn=True,
        )

        # Compute dw2 with recomputed activation_fn output.
        dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)

        # Compute dactivation_fn_out.
        #
        # NOTE: We reuse the activation_fn_out allocation.
        dactivation_fn_out = activation_fn_out
        stk.backend.triton_kernels.sdd(
            ddsd_out,
            w2.t(),
            dactivation_fn_out.shape,
            dactivation_fn_out.data,
            dactivation_fn_out.offsets,
            dactivation_fn_out.row_indices,
            dactivation_fn_out.column_indices,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        if activation_fn is DEFAULT_ACTIVATION_FN:
            dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
        else:
            assert activation_grad_fn is not None
            activation_grad_fn(dactivation_fn_out.data)
            dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)

        # Compute dw1.
        dw1 = stk.ops.dsd(dsdd_out.t(), x)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        stk.backend.triton_kernels.dsd(
            dsdd_out.shape,
            dsdd_out.data,
            dsdd_out.offsets,
            dsdd_out.row_indices,
            dsdd_out.column_indices,
            dsdd_out.offsets_t,
            dsdd_out.column_indices_t,
            dsdd_out.block_offsets_t,
            False,
            w1,
            ddsd_out,
        )
        dx = ddsd_out
        return dx, dw1, dw2, None, None


memory_optimized_mlp = MemoryOptimizedMLP.apply


class SparseMLP(torch.nn.Module):

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)

        self.w1 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        self.w2 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )

        # Initialize the parameters for the MLP.
        #
        # NOTE: It is important that we create the weight tensors prior
        # to creating the master weights and slicing out the piece for
        # this rank. If the master weights are created first the PyTorch
        # caching allocator appears to use the same memory block for these
        # and the slice which causes large increases in our peak memory
        # usage.
        with torch.no_grad():
            self.w1.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.init_method,
                ),
            )
            self.w2.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.output_layer_init_method,
                ),
            )

        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
        mpu.set_expert_model_parallel_attributes(
            self.w1,
            self._should_set_parallelism_attribute,
        )
        mpu.set_expert_model_parallel_attributes(
            self.w2,
            self._should_set_parallelism_attribute,
        )

        self.gradient_scale = None
        if self.args.moe_expert_model_parallelism:
            self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)

    def scale_grad(self, w):
        if self.gradient_scale is None:
            return w
        return scale_gradient(w, self.gradient_scale)

    def forward(self, x, topo):
        w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
        w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
        if self.args.memory_optimized_mlp:
            return memory_optimized_mlp(
                x,
                w1,
                w2,
                topo,
                self.args.activation_fn,
            )

        # Compute the MLP.
        x = stk.ops.sdd(x, w1.t(), topo)
        activation_fn_out = act_fn(x, self.args.activation_fn)
        return stk.ops.dsd(activation_fn_out, w2)


class MemoryOptimizedGroupedMLP(torch.autograd.Function):
    """GroupedMLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
        # Cast inputs using ctx dtype from AMP
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)
        # x: [m, k], w1: [n, k], w2: [n, k]
        if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")

        # Layer 0: x @ w1.t().
        assert gg.backend is not None
        sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)

        # activation_fn
        activation_fn_out = activation_fn(sdd_out)

        # Layer 1: x @ w2.
        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)

        # NOTE: Save the input to the layer and the activation_fn input for
        # gradient computation. We'll re-compute the activation_fn forward
        # pass in the backward pass to avoid materializing another
        # intermediate.
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx: Any, ddsd_out: torch.Tensor):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack saved tensors
        # dtype = ctx.dtype
        saved_tensors = ctx.saved_tensors
        w1, w2 = saved_tensors[:2]
        batch_sizes = saved_tensors[2]
        x = saved_tensors[3]
        sdd_out = saved_tensors[4]

        # Rematerialize activation_fn output.
        activation_fn = ctx.activation_fn
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            activation_fn_out = activation_fn(sdd_out)
            activation_grad_fn = activation_fn_out.backward

        # Compute dw2 with recomputed activation_fn output.
        assert gg.backend is not None
        dw2 = gg.backend.gmm(
            activation_fn_out,
            ddsd_out,
            batch_sizes,
            trans_a=True,
        )

        # Compute dactivation_fn_out.
        #
        # NOTE: We reuse the activation_fn_out allocation.
        dactivation_fn_out = activation_fn_out
        gg.backend.gmm(
            ddsd_out,
            w2,
            batch_sizes,
            trans_b=True,
            c=dactivation_fn_out,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        if activation_fn is DEFAULT_ACTIVATION_FN:
            dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
493
+ else:
494
+ assert activation_grad_fn is not None
495
+ activation_grad_fn(dactivation_fn_out)
496
+ dsdd_out = sdd_out.grad
497
+
498
+ # Compute dw1.
499
+ dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
500
+
501
+ # Compute dx.
502
+ #
503
+ # NOTE: This reuses the ddsd_out allocation.
504
+ gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
505
+ dx = ddsd_out
506
+ return dx, dw1, dw2, None, None
507
+
508
+
509
+ memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
510
+
511
+
512
+ class GroupedMLP(SparseMLP):
513
+
514
+ def forward(self, x, tokens_per_expert):
515
+ batch_sizes = tokens_per_expert.cpu().to(torch.long)
516
+ w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
517
+
518
+ # Re-shape the weights for the grouped GEMMs.
519
+ ne = mpu.experts_per_rank(self.args)
520
+ w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
521
+ w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
522
+
523
+ if self.args.memory_optimized_mlp:
524
+ return memory_optimized_grouped_mlp(
525
+ x,
526
+ w1,
527
+ w2,
528
+ batch_sizes,
529
+ self.args.activation_fn,
530
+ )
531
+
532
+ # Compute the MLP.
533
+ assert gg.ops is not None
534
+ x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
535
+ x = self.args.activation_fn(x)
536
+ return gg.ops.gmm(x, w2, batch_sizes)
537
+
538
+
539
+ class SharedMLP(torch.nn.Module):
540
+ """MLP for shared expert.
541
+
542
+ Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class
543
+ """
544
+
545
+ def __init__(self, args: Arguments):
546
+ super().__init__()
547
+ self.args = args
548
+ self.fc_kwargs: dict[str, Any] = {
549
+ 'bias': args.bias,
550
+ 'device': args.device,
551
+ }
552
+ self.fc_kwargs.update(args.fc_kwargs)
553
+
554
+ self.up_proj = args.fc_cls(
555
+ args.hidden_size,
556
+ args.shared_expert_hidden_size,
557
+ **self.fc_kwargs,
558
+ )
559
+ self.act = args.activation_fn
560
+ self.down_proj = args.fc_cls(
561
+ args.shared_expert_hidden_size,
562
+ args.hidden_size,
563
+ **self.fc_kwargs,
564
+ )
565
+ self.down_proj._is_residual = True # a flag for llm-foundry init
566
+
567
+ def add_experts_sharedexpert(
568
+ self,
569
+ shared_expert_out: torch.Tensor,
570
+ expert_out: torch.Tensor,
571
+ ) -> torch.Tensor:
572
+ # Helper function to add expert output to shared expert output
573
+ # with optional weighted sum.
574
+ if self.args.shared_expert_weighted_sum:
575
+ # enable using weighted sum for shared expert output
576
+ # weighted by the number of experts used
577
+ t_experts = self.args.moe_top_k + 1
578
+ sh_mlp_out = shared_expert_out / t_experts
579
+ return sh_mlp_out.add(
580
+ expert_out,
581
+ alpha=(self.args.moe_top_k / t_experts),
582
+ )
583
+
584
+ return shared_expert_out + expert_out
585
+
586
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
587
+ return self.down_proj(self.act(self.up_proj(x)))
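For reference, the weighted-sum branch in add_experts_sharedexpert above averages over moe_top_k + 1 contributors: the shared expert gets weight 1 / (top_k + 1) and the routed output gets weight top_k / (top_k + 1). A minimal sketch of that arithmetic with made-up values (illustrative only, not part of the vendored file):

import torch

moe_top_k = 2
t_experts = moe_top_k + 1  # routed experts plus the shared expert

shared_expert_out = torch.full((4,), 3.0)
expert_out = torch.full((4,), 6.0)

# Weighted-sum path: 3/3 + 6 * (2/3) = 5.
weighted = shared_expert_out / t_experts + expert_out * (moe_top_k / t_experts)

# Plain-sum path used when shared_expert_weighted_sum is disabled.
plain = shared_expert_out + expert_out
print(weighted, plain)  # tensor([5., 5., 5., 5.]) tensor([9., 9., 9., 9.])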
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/moe.py ADDED
@@ -0,0 +1,507 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Optional, Tuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ # import megablocks.ops as ops
10
+ # from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
11
+ # from megablocks.layers.all_to_all import all_to_all
12
+ # from megablocks.layers.arguments import Arguments
13
+
14
+ from ..ops import (
15
+ sort,
16
+ histogram,
17
+ inclusive_cumsum,
18
+ exclusive_cumsum,
19
+ binned_gather,
20
+ binned_scatter,
21
+ gather,
22
+ scatter,
23
+ repeat,
24
+ replicate,
25
+ )
26
+
27
+ from . import common, mlp, mpu, router, sharedexpert_registry
28
+ from .arguments import Arguments
29
+ from .all_to_all import all_to_all
30
+
31
+ _LOAD_BALANCING_LOSS = []
32
+
33
+
34
+ def save_load_balancing_loss(loss):
35
+ global _LOAD_BALANCING_LOSS
36
+ _LOAD_BALANCING_LOSS.append(loss)
37
+
38
+
39
+ def get_load_balancing_loss():
40
+ global _LOAD_BALANCING_LOSS
41
+ return _LOAD_BALANCING_LOSS
42
+
43
+
44
+ def clear_load_balancing_loss():
45
+ global _LOAD_BALANCING_LOSS
46
+ _LOAD_BALANCING_LOSS.clear()
47
+
48
+
49
+ def batched_load_balancing_loss(args: Arguments):
50
+ if args.moe_loss_weight == 0:
51
+ return 0.0
52
+
53
+ # tokens_per_expert[i].shape = (num_experts)
54
+ # expert_scores[i].shape = (tokens, num_experts)
55
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
56
+ num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
57
+ if args.num_layers_per_virtual_pipeline_stage is not None:
58
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
59
+
60
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
61
+ raise ValueError(
62
+ f'Expected {num_layers_per_pipeline_stage} tokens_per_expert '
63
+ f'but found {len(tokens_per_expert)}.\nnum_layers = '
64
+ f'{args.num_layers}\npipeline_model_parallel_size = '
65
+ f'{args.pipeline_model_parallel_size}\n'
66
+ 'num_layers_per_virtual_pipeline_stage'
67
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
68
+ )
69
+ if len(expert_scores) != num_layers_per_pipeline_stage:
70
+ raise ValueError(
71
+ f'Expected {num_layers_per_pipeline_stage} expert_scores '
72
+ f'but found {len(expert_scores)}.\nnum_layers = '
73
+ f'{args.num_layers}\npipeline_model_parallel_size = '
74
+ f'{args.pipeline_model_parallel_size}\n'
75
+ 'num_layers_per_virtual_pipeline_stage'
76
+ f' = {args.num_layers_per_virtual_pipeline_stage}',
77
+ )
78
+
79
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
80
+ assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
81
+
82
+ tokens = expert_scores[0].shape[0]
83
+ assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
84
+
85
+ # Concatenate the contributions of each layer and convert to
86
+ # the correct types and formats for the dot product.
87
+ expert_scores = torch.cat(expert_scores, dim=1)
88
+ if args.moe_lbl_in_fp32:
89
+ expert_scores = expert_scores.float()
90
+ if tokens != 0:
91
+ expert_scores = expert_scores.mean(dim=0)
92
+ else:
93
+ expert_scores = expert_scores.sum(dim=0)
94
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
95
+
96
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
97
+ assert tokens_per_expert.numel() == expected_values
98
+ assert expert_scores.numel() == expected_values
99
+
100
+ # Calculate the total scale across all factors.
101
+ #
102
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
103
+ scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
104
+ scale_denominator = (args.num_layers * tokens * args.moe_top_k)
105
+ scale = scale_numerator / scale_denominator
106
+ return scale * torch.dot(tokens_per_expert, expert_scores)
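The scale comment above fixes the normalization of the auxiliary loss: moe_loss_weight * moe_num_experts / (num_layers * tokens * moe_top_k), applied to the dot product of per-expert token counts and mean router scores. A self-contained single-layer sketch with toy values (illustrative only):

import torch

num_experts, tokens, top_k = 4, 8, 2
moe_loss_weight, num_layers = 0.1, 1

# Toy routing statistics: token counts per expert and router scores per token.
tokens_per_expert = torch.tensor([6., 4., 3., 3.])  # sums to tokens * top_k
expert_scores = torch.softmax(torch.randn(tokens, num_experts), dim=-1)

scale = (num_experts * moe_loss_weight) / (num_layers * tokens * top_k)
loss = scale * torch.dot(tokens_per_expert, expert_scores.mean(dim=0))
print(loss)  # scalar penalty; grows as routing becomes more imbalanced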
107
+
108
+
109
+ # NOTE: This class defines MoE expert computation, including expert model parallel
110
+ # communication. When using FSDP on top of MegaBlocks this is the module that should
111
+ # be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
112
+ # parallel all2all.
113
+ class ParallelMLP(torch.nn.Module):
114
+
115
+ def __init__(self, args: Arguments):
116
+ super(ParallelMLP, self).__init__()
117
+ self.args = args
118
+
119
+ # Calculate the number of experts in total and the number of experts
120
+ # owned by this rank.
121
+ # world_size = mpu.get_expert_parallel_world_size(args)
122
+ self.num_experts = args.moe_num_experts
123
+ self.top_k = self.args.moe_top_k
124
+
125
+ # Calculate the number of bits needed to represent the expert indices
126
+ # so that we can pass it to radix sort.
127
+ self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
128
+
129
+ # Expert MLP.
130
+ self.mlp = mlp.MLP(args)
131
+
132
+ self.bias: Optional[torch.Tensor]
133
+ if self.args.bias:
134
+ # Note that the output bias is not parallelized with expert
135
+ # model parallelism.
136
+ self.bias = torch.nn.Parameter(
137
+ torch.empty(
138
+ args.hidden_size,
139
+ device=args.device,
140
+ dtype=common.dtype(args),
141
+ ),
142
+ )
143
+ torch.nn.init.zeros_(self.bias)
144
+ else:
145
+ self.register_parameter('bias', None)
146
+
147
+ # Select the forward function for the operating mode.
148
+ self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
149
+
150
+ def expert_capacity(self, tokens: int) -> int:
151
+ world_size = mpu.get_expert_parallel_world_size(self.args)
152
+ tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
153
+ return int(self.args.moe_capacity_factor * tokens_per_expert)
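expert_capacity converts the capacity factor into a per-expert token budget of top_k * tokens * world_size / num_experts slots. A hypothetical numeric example:

top_k, tokens, world_size, num_experts = 2, 1024, 4, 8
moe_capacity_factor = 1.25

tokens_per_expert = top_k * tokens * world_size / num_experts  # 1024.0
expert_capacity = int(moe_capacity_factor * tokens_per_expert)  # 1280 slots per expert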
154
+
155
+ def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
156
+ """Calculate the load balancing loss contribution."""
157
+ assert len(expert_scores.size()) == 2
158
+ tokens, num_experts = expert_scores.size()
159
+ assert num_experts == self.num_experts
160
+ assert len(tokens_per_expert.size()) == 1
161
+ num_experts, = tokens_per_expert.size()
162
+ assert num_experts == self.num_experts
163
+ scale = self.num_experts / (tokens * self.top_k)
164
+ return scale * torch.dot(
165
+ tokens_per_expert.to(expert_scores.dtype),
166
+ expert_scores.mean(dim=0),
167
+ )
168
+
169
+ def indices_and_bins(self,
170
+ top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
171
+ # Sort the expert ids to produce the scatter/gather
172
+ # indices for the permutation.
173
+ #
174
+ # TODO(tgale): Is it worth doing this conversion to 32-bit
175
+ # prior? Could we place the `torch.max` operation to return
176
+ # 32-bit expert indices?
177
+ top_expert = top_expert.int()
178
+ # output = ops.sort(top_expert, self.sort_end_bit)
179
+ output = sort(top_expert, self.sort_end_bit)
180
+ assert output is not None
181
+ bin_ids, indices = output
182
+
183
+ # Histogram the expert ids to identify the number of
184
+ # tokens routed to each expert.
185
+ #
186
+ # TODO(tgale): Does the sorted data produce a more favorable
187
+ # data distribution for histogram? Or is the op parallelism
188
+ # worth more?
189
+ # tokens_per_expert = ops.histogram(top_expert, self.num_experts)
190
+ tokens_per_expert = histogram(top_expert, self.num_experts)
191
+
192
+ # Calculate the bin bounds for the sorted tokens.
193
+ # bins = ops.inclusive_cumsum(tokens_per_expert, 0)
194
+ bins = inclusive_cumsum(tokens_per_expert, 0)
195
+ assert bins is not None
196
+ bins = bins.view(1) if not len(bins.size()) else bins
197
+
198
+ assert isinstance(indices, torch.Tensor)
199
+ assert isinstance(bin_ids, torch.Tensor)
200
+ assert isinstance(bins, torch.Tensor)
201
+ assert isinstance(tokens_per_expert, torch.Tensor)
202
+
203
+ return indices, bin_ids, bins, tokens_per_expert
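indices_and_bins relies on the fused sort/histogram/cumsum ops, but its outputs are easy to reproduce with plain PyTorch. A rough eager equivalent (reference only; it ignores the radix-sort bit width and the custom kernels):

import torch

num_experts = 4
top_expert = torch.tensor([2, 0, 2, 1, 0, 3], dtype=torch.int32)

# bin_ids: expert id of each routed token in sorted order; indices: the permutation.
bin_ids, indices = torch.sort(top_expert)

# Tokens routed to each expert, then an inclusive cumsum gives the bin upper bounds.
tokens_per_expert = torch.histc(top_expert.float(), bins=num_experts, min=0, max=num_experts - 1).int()
bins = torch.cumsum(tokens_per_expert, dim=0)

print(bin_ids.tolist())            # [0, 0, 1, 2, 2, 3]
print(tokens_per_expert.tolist())  # [2, 1, 2, 1]
print(bins.tolist())               # [2, 3, 5, 6]  (end offset of each expert's bin)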
204
+
205
+ def permute_and_compute(
206
+ self,
207
+ x: torch.Tensor,
208
+ tokens_per_expert: int, # unused
209
+ indices: torch.Tensor,
210
+ bin_ids: torch.Tensor, # unused
211
+ expert_weights: torch.Tensor,
212
+ bins: torch.Tensor,
213
+ expert_capacity: int,
214
+ top_k: int,
215
+ ):
216
+ # Route the tokens for MoE computation.
217
+ x = x.view(-1, x.shape[-1])
218
+ # output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
219
+ output = binned_gather(x, indices, bins, expert_capacity, top_k)
220
+ assert output is not None
221
+ x = output
222
+
223
+ # Perform the expert computation. Note that we don't
224
+ # use biases for these linear operations.
225
+ x = self.mlp(x)
226
+
227
+ # Un-route the data for the MoE output.
228
+ # return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
229
+ return binned_scatter(x, indices, expert_weights, bins, top_k)
230
+
231
+
232
+ def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
233
+ # x: [sl, bs, hs]
234
+ # expert_weights: [sl * bs, top-k]
235
+ # top_experts: [sl * bs, top-k]
236
+ expert_weights = expert_weights.flatten()
237
+ top_experts = top_experts.flatten()
238
+ with torch.no_grad():
239
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
240
+
241
+ # If expert_capacity is set to zero, set the number of tokens
242
+ # per expert to the maximum we need to avoid dropping tokens.
243
+ sl, bs, _ = x.size()
244
+ expert_capacity = self.expert_capacity(sl * bs)
245
+ if expert_capacity == 0:
246
+ expert_capacity = torch.max(tokens_per_expert).item()
247
+
248
+ x = self.permute_and_compute(
249
+ x,
250
+ tokens_per_expert,
251
+ indices,
252
+ bin_ids,
253
+ expert_weights,
254
+ bins,
255
+ expert_capacity,
256
+ self.top_k,
257
+ )
258
+ return x, tokens_per_expert
259
+
260
+ def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
261
+ # NOTE: This function implements the same computation as forward_once
262
+ # but with expert model parallelism.
263
+ #
264
+ # 1. Permute the tokens locally so that they are grouped by their
265
+ # expert assignments. This allows us to transfer all of the tokens
266
+ # for a remote device in one communication primitive.
267
+ #
268
+ # 2. Permute the tokens across the expert parallel devices. After
269
+ # this is completed each device has all of the tokens assigned to
270
+ # its set of experts in its local HBM.
271
+ #
272
+ # 3. Permute the tokens locally so that they are grouped by their
273
+ # expert assignment. After the distributed permutation the tokens
274
+ # are grouped by which device they came from. We re-order them
275
+ # locally to allow for efficient computation.
276
+ #
277
+ # After this series of permutations we compute the linear layers
278
+ # and then repeat these three steps in reverse to produce the final
279
+ # output.
280
+ #
281
+ # Compute the mapping of local tokens to experts.
282
+ expert_weights = expert_weights.flatten()
283
+ top_experts = top_experts.flatten()
284
+ with torch.no_grad():
285
+ indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
286
+
287
+ # If we're sharding the experts along the hidden dimension
288
+ # multiple devices own parts of the same sets of experts.
289
+ # Replicate the token counts so every device gets the counts.
290
+ # repeated_tokens_per_expert = ops.repeat(
291
+ repeated_tokens_per_expert = repeat(
292
+ tokens_per_expert,
293
+ (mpu.hidden_sharding_degree(self.args),),
294
+ )
295
+
296
+ # Pass token count information to the device on which the
297
+ # target expert resides.
298
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
299
+ tpe_handle = dist.all_to_all_single(
300
+ parallel_tokens_per_expert,
301
+ repeated_tokens_per_expert,
302
+ group=self.args.expert_parallel_group,
303
+ async_op=True,
304
+ )
305
+
306
+ # Permute locally and without any padding so that tokens for each
307
+ # parallel device are stored contiguously.
308
+ #
309
+ # This view updates the shape of the tensor from [sl, bs, hs] to
310
+ # [sl * bs, hs] prior to the permutation.
311
+ x = x.view(-1, x.shape[-1])
312
+ # output = ops.gather(x, indices, bin_ids, bins, self.top_k)
313
+ output = gather(x, indices, bin_ids, bins, self.top_k)
314
+ assert output is not None
315
+ x = output
316
+
317
+ # Compute the number of tokens that will be received from each
318
+ # device and permute the input data across the devices.
319
+ with torch.no_grad():
320
+ tpe_handle.wait()
321
+ experts_per_rank = mpu.experts_per_rank(self.args)
322
+
323
+ # Reshape to [world_size, num_experts_per_rank].
324
+ world_size = mpu.get_expert_parallel_world_size(self.args)
325
+ repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
326
+ parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
327
+
328
+ # TODO(tgale): It might be faster to do this on the GPU and
329
+ # then communicate the results back to the host.
330
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
331
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
332
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
333
+
334
+ # Convert the send/recv counts to lists.
335
+ send_counts = send_counts.tolist()
336
+ recv_counts = recv_counts.tolist()
337
+ tokens_received = sum(recv_counts)
338
+
339
+ # If we're sharding the experts along the hidden dimension
340
+ # multiple devices own parts of the same sets of experts.
341
+ # Replicate the token counts so devices that share experts
342
+ # get all of the tokens assigned to them.
343
+ #
344
+ # TODO(tgale): Fuse this into the prior, local permutation.
345
+ # x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
346
+ x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
347
+
348
+ # Start the cross-device permutation asynchronously so we can
349
+ # overlap communication with computation.
350
+ parallel_x, parallel_x_handle = all_to_all(
351
+ x,
352
+ recv_counts,
353
+ send_counts,
354
+ self.args.expert_parallel_group,
355
+ async_op=True,
356
+ )
357
+
358
+ with torch.no_grad():
359
+ # After we do the cross-device permutation we have the tokens on the
360
+ # correct device but not yet grouped by expert because we received
361
+ # tokens from each device as contiguous chunks. To group the tokens
362
+ # for expert computation we'll do one more local permutation. The
363
+ # rest of this torch.no_grad() scope sets up the indices and bins
364
+ # for this permutation.
365
+ # replicate_bins = ops.inclusive_cumsum(
366
+ replicate_bins = inclusive_cumsum(
367
+ parallel_tokens_per_expert.flatten(),
368
+ 0,
369
+ )
370
+ replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
371
+
372
+ # Construct the expert indices for the permuted tokens.
373
+ parallel_top_expert = torch.remainder(
374
+ torch.arange(
375
+ self.num_experts * mpu.hidden_sharding_degree(self.args),
376
+ dtype=torch.int32,
377
+ device=indices.device,
378
+ ),
379
+ mpu.experts_per_rank(self.args),
380
+ )
381
+ # parallel_top_expert = ops.replicate(
382
+ parallel_top_expert = replicate(
383
+ parallel_top_expert.unsqueeze(dim=0),
384
+ replicate_bins,
385
+ tokens_received,
386
+ ).flatten()
387
+
388
+ # TODO(tgale): The sort_end_bit here can be reduced.
389
+ # parallel_bin_ids, parallel_indices = ops.sort(
390
+ parallel_bin_ids, parallel_indices = sort(
391
+ parallel_top_expert,
392
+ self.sort_end_bit,
393
+ )
394
+
395
+ # Calculate the bins boundaries from the token counts.
396
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
397
+ dim=0,
398
+ dtype=torch.int,
399
+ )
400
+ # parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
401
+ parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
402
+ parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
403
+
404
+ # If expert_capacity is set to zero, set the number of tokens
405
+ # per expert to the maximum we need to avoid dropping tokens.
406
+ tokens, _ = x.size()
407
+ expert_capacity = self.expert_capacity(tokens)
408
+ if expert_capacity == 0:
409
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
410
+
411
+ # Locally permute the tokens and perform the expert computation.
412
+ # Block to make sure that the cross-device permutation is complete.
413
+ if self.args.mlp_impl == 'grouped':
414
+ # GroupedMLP requires counts on CPU. We can use the tensor already
415
+ # moved to CPU for the prior all_to_all, which avoids an extra
416
+ # device synchronization.
417
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
418
+ dim=0,
419
+ dtype=torch.int,
420
+ )
421
+ parallel_x_handle.wait()
422
+ parallel_x = self.permute_and_compute(
423
+ parallel_x,
424
+ parallel_tokens_per_expert,
425
+ parallel_indices,
426
+ parallel_bin_ids,
427
+ None, # expert_weights
428
+ parallel_bins,
429
+ expert_capacity,
430
+ top_k=1,
431
+ )
432
+
433
+ # Un-permute the tokens across the devices.
434
+ x, _ = all_to_all(
435
+ parallel_x,
436
+ send_counts,
437
+ recv_counts,
438
+ self.args.expert_parallel_group,
439
+ )
440
+
441
+ # Reduce along the hidden sharding to get the final outputs.
442
+ #
443
+ # TODO(tgale): Fuse this into the following local permutation.
444
+ shape = (
445
+ mpu.hidden_sharding_degree(self.args),
446
+ -1,
447
+ self.args.hidden_size,
448
+ )
449
+ # x = ops.sum(x.view(shape), dim=0)
450
+ x = x.view(shape).sum(dim=0)
451
+
452
+ # Un-permute locally to setup for the next series of operations.
453
+ # x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
454
+ x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
455
+ return x, tokens_per_expert.flatten()
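One piece of the bookkeeping above that is easy to miss: the send/receive counts for the all_to_all come from reshaping the flat [world_size * experts_per_rank] token counts to one row per device and summing each row. A single-process illustration with toy counts (no process group involved):

import torch

world_size, experts_per_rank = 2, 2  # 4 experts in total

# Number of local tokens destined for each (globally indexed) expert.
repeated_tokens_per_expert = torch.tensor([5, 3, 2, 6])

# Row i holds the counts for the experts owned by device i.
send_counts = repeated_tokens_per_expert.view(world_size, experts_per_rank).sum(dim=-1)
print(send_counts.tolist())  # [8, 8] -> tokens sent to device 0 and device 1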
456
+
457
+ def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
458
+ in_shape = x.size()
459
+
460
+ # Compute the experts.
461
+ x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
462
+ if self.training and self.args.moe_loss_weight > 0:
463
+ save_load_balancing_loss((tokens_per_expert, scores))
464
+ x = x.view(in_shape)
465
+ if self.bias is not None:
466
+ if self.args.return_bias:
467
+ return x, self.bias
468
+ return x + self.bias
469
+ return x
470
+
471
+
472
+ class MoE(torch.nn.Module):
473
+
474
+ def __init__(self, args: Arguments):
475
+ super(MoE, self).__init__()
476
+
477
+ # Token router.
478
+ self.router = router.LearnedRouter(args)
479
+
480
+ # Expert computation helper.
481
+ self.experts = self._init_experts_mlp(args)
482
+
483
+ self.shared_expert = None
484
+ if args.shared_expert:
485
+ # SharedExpert computation helper.
486
+ self.shared_expert = sharedexpert_registry.get(args)
487
+
488
+ def _init_experts_mlp(self, args: Arguments):
489
+ return ParallelMLP(args)
490
+
491
+ def forward(self, x: torch.Tensor):
492
+ # NOTE: If we're going to cast the activations to lower precision
493
+ # do it before we permute the tokens to save bandwidth.
494
+ x = common.cast_if_autocast_enabled(x)
495
+
496
+ # Compute the expert scores and assignments.
497
+ scores, expert_weights, top_experts = self.router(x)
498
+
499
+ # Compute the experts.
500
+ out = self.experts(x, scores, expert_weights, top_experts)
501
+ if self.shared_expert is not None:
502
+ shared_expert_out = self.shared_expert(x)
503
+ out = self.shared_expert.add_experts_sharedexpert(
504
+ shared_expert_out,
505
+ out,
506
+ )
507
+ return out
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mpu.py ADDED
@@ -0,0 +1,94 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.distributed as dist
8
+
9
+ # from megablocks.layers.arguments import Arguments
10
+ from .arguments import Arguments
11
+
12
+
13
+ class MoeParam(torch.Tensor):
14
+
15
+ def __init__(self):
16
+ super().__init__(self)
17
+ self.expert_model_parallel: bool
18
+
19
+
20
+ def is_moe_param(tensor: torch.Tensor) -> bool:
21
+ return hasattr(tensor, 'expert_model_parallel')
22
+
23
+
24
+ def get_expert_parallel_world_size(args: Arguments) -> int:
25
+ return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
26
+
27
+
28
+ def get_expert_parallel_rank(args: Arguments) -> int:
29
+ return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
30
+
31
+
32
+ def set_expert_model_parallel_attributes(
33
+ tensor: torch.Tensor,
34
+ is_parallel: bool,
35
+ ):
36
+ assert not hasattr(tensor, 'expert_model_parallel')
37
+ setattr(tensor, 'expert_model_parallel', is_parallel)
38
+
39
+
40
+ def param_is_expert_model_parallel(param: MoeParam) -> bool:
41
+ return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
42
+
43
+
44
+ def copy_expert_model_parallel_attributes(
45
+ destination_tensor: torch.Tensor,
46
+ source_tensor: torch.Tensor,
47
+ ):
48
+ if hasattr(source_tensor, 'expert_model_parallel'):
49
+ setattr(
50
+ destination_tensor,
51
+ 'expert_model_parallel',
52
+ getattr(source_tensor, 'expert_model_parallel'),
53
+ )
54
+
55
+
56
+ def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
57
+ world_size = dist.get_world_size(group)
58
+ rank = dist.get_rank(group)
59
+ for i in range(world_size):
60
+ dist.barrier(group)
61
+ if i == rank:
62
+ print(f'rank = {rank}', *x)
63
+
64
+
65
+ # Helpers for expert/tensor sharding.
66
+ def expert_sharding_degree(args: Arguments) -> int:
67
+ world_size = get_expert_parallel_world_size(args)
68
+ esd = min(world_size, args.moe_num_experts)
69
+
70
+ if (args.moe_num_experts % esd) != 0:
71
+ raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
72
+ return esd
73
+
74
+
75
+ def hidden_sharding_degree(args: Arguments) -> int:
76
+ world_size = get_expert_parallel_world_size(args)
77
+ esd = expert_sharding_degree(args)
78
+ hsd = world_size // esd
79
+
80
+ if (args.ffn_hidden_size % hsd) != 0:
81
+ raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
82
+ if (esd * hsd) != world_size:
83
+ raise ValueError(
84
+ f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
85
+ )
86
+ return hsd
87
+
88
+
89
+ def experts_per_rank(args: Arguments) -> int:
90
+ return args.moe_num_experts // expert_sharding_degree(args)
91
+
92
+
93
+ def features_per_rank(args: Arguments) -> int:
94
+ return args.ffn_hidden_size // hidden_sharding_degree(args)
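The sharding helpers above factor the expert-parallel world size as expert_sharding_degree * hidden_sharding_degree, splitting first across experts and then across each expert's ffn_hidden_size. A numeric example with illustrative values:

world_size, moe_num_experts, ffn_hidden_size = 8, 4, 4096

esd = min(world_size, moe_num_experts)      # 4: the expert set is split 4 ways
hsd = world_size // esd                     # 2: each expert's hidden dim is split 2 ways

experts_per_rank = moe_num_experts // esd   # 1
features_per_rank = ffn_hidden_size // hsd  # 2048
assert esd * hsd == world_size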
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/router.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+
7
+ # from megablocks.layers import common
8
+ # from megablocks.layers.arguments import Arguments
9
+ from . import common
10
+ from .arguments import Arguments
11
+
12
+ _ROUTER_LOGITS = []
13
+
14
+
15
+ def _save_router_logits(logits: torch.Tensor, args: Arguments):
16
+ if args.moe_zloss_weight == 0:
17
+ return
18
+ global _ROUTER_LOGITS
19
+ _ROUTER_LOGITS.append(logits)
20
+
21
+
22
+ def clear_router_zloss():
23
+ global _ROUTER_LOGITS
24
+ _ROUTER_LOGITS.clear()
25
+
26
+
27
+ def batched_router_zloss(args: Arguments):
28
+ global _ROUTER_LOGITS
29
+
30
+ if args.moe_zloss_weight == 0:
31
+ import warnings
32
+ warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
33
+ return 0
34
+
35
+ logits_per_router = _ROUTER_LOGITS
36
+
37
+ if args.moe_zloss_in_fp32:
38
+ logits_per_router = [logits.float() for logits in logits_per_router]
39
+
40
+ unscaled_zloss_per_router = torch.stack([
41
+ torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
42
+ ])
43
+
44
+ return args.moe_zloss_weight * unscaled_zloss_per_router
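The router z-loss above is moe_zloss_weight * mean(logsumexp(logits, dim=1)^2) per router, which discourages the router logits from drifting to large magnitudes. A self-contained sketch for a single router with toy logits:

import torch

moe_zloss_weight = 1e-3
logits = torch.randn(16, 8)  # (tokens, num_experts)

unscaled_zloss = torch.logsumexp(logits, dim=1).square().mean()
zloss = moe_zloss_weight * unscaled_zloss
print(float(zloss))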
45
+
46
+
47
+ # NOTE: To enable end-to-end benchmarking without convergence we
48
+ # support a flag to force the router to assign tokens uniformly
49
+ # across the experts. We do this with a custom autograd operation
50
+ # so that PyTorch still executes the full set of router operation.
51
+ class _UniformExpertAssignment(torch.autograd.Function):
52
+
53
+ @staticmethod
54
+ def forward(ctx: Any, x: torch.Tensor, num_experts: int):
55
+ out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
56
+ out = torch.remainder(out, num_experts)
57
+ return out.view(x.shape)
58
+
59
+
60
+ _uniform_expert_assignment = _UniformExpertAssignment.apply
61
+
62
+
63
+ class LearnedRouter(torch.nn.Module):
64
+
65
+ def __init__(self, args: Arguments):
66
+ super().__init__()
67
+ self.args = args
68
+
69
+ # Learned router parameters.
70
+ #
71
+ # NOTE: This weight matrix is not parallelized with expert model
72
+ # parallelism. Each device needs the entire router weight matrix
73
+ # so that it can route its batch of data correctly.
74
+ self.layer = torch.nn.Linear(
75
+ args.hidden_size,
76
+ args.moe_num_experts,
77
+ bias=False,
78
+ dtype=common.dtype(args),
79
+ device=args.device,
80
+ )
81
+ args.init_method(self.layer.weight)
82
+
83
+ def jitter(self, x: torch.Tensor):
84
+ low: float = 1.0 - self.args.moe_jitter_eps
85
+ high: float = 1.0 + self.args.moe_jitter_eps
86
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
87
+ return low + noise * (high - low)
88
+
89
+ def _top_k(self, scores: torch.Tensor):
90
+ if self.args.moe_top_k == 1:
91
+ return scores.max(dim=-1, keepdim=True)
92
+ return torch.topk(scores, self.args.moe_top_k, dim=-1)
93
+
94
+ def forward(self, x: torch.Tensor):
95
+ if self.training and self.args.moe_jitter_eps is not None:
96
+ x = x * self.jitter(x)
97
+
98
+ logits = self.layer(x.view(-1, x.shape[-1]))
99
+ _save_router_logits(logits, self.args)
100
+ scores = logits.softmax(dim=-1)
101
+ expert_weights, expert_indices = self._top_k(scores)
102
+ if self.args.moe_normalize_expert_weights:
103
+ expert_weights = expert_weights / torch.norm(
104
+ expert_weights,
105
+ p=self.args.moe_normalize_expert_weights,
106
+ dim=-1,
107
+ keepdim=True,
108
+ )
109
+
110
+ expert_indices = (
111
+ _uniform_expert_assignment(
112
+ expert_indices,
113
+ self.args.moe_num_experts,
114
+ ) if self.args.uniform_expert_assignment else expert_indices
115
+ )
116
+ return scores, expert_weights, expert_indices
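The routing decision above boils down to: softmax over expert logits, keep the top-k scores as expert weights, and optionally re-normalize those weights with a p-norm. A plain-PyTorch sketch of that selection step (the linear layer, jitter, and uniform-assignment paths are omitted):

import torch

moe_top_k, p = 2, 1.0
scores = torch.softmax(torch.randn(4, 8), dim=-1)  # (tokens, num_experts)

expert_weights, expert_indices = torch.topk(scores, moe_top_k, dim=-1)

# With p=1 the selected weights are re-normalized to sum to one per token.
expert_weights = expert_weights / torch.norm(expert_weights, p=p, dim=-1, keepdim=True)
print(expert_weights.sum(dim=-1))  # tensor([1., 1., 1., 1.])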
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/sharedexpert_registry.py ADDED
@@ -0,0 +1,32 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Union
5
+
6
+ # from megablocks.layers import glu, mlp
7
+ # from megablocks.layers.arguments import Arguments
8
+ from . import glu, mlp
9
+ from .arguments import Arguments
10
+
11
+ _REGISTRY = {
12
+ 'mlp': mlp.SharedMLP,
13
+ 'glu': glu.SharedGLU,
14
+ }
15
+
16
+
17
+ def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
18
+ """Returns a SharedMLP for use in a dMoE instance.
19
+
20
+ Uses the provided arguments to instantiate the appropriate
21
+ SharedMLP instance.
22
+
23
+ Args:
24
+ args: propagated Arguments dataclass.
25
+
26
+ Returns:
27
+ An instantiated SharedMLP constructed using the input args.
28
+ """
29
+ if args.mlp_type not in _REGISTRY:
30
+ raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
31
+
32
+ return _REGISTRY[args.mlp_type](args)
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a19bba459394ac0d93b6405084772af39eb92d4f280f6b5c586d1beeb1589051
3
+ size 5573536
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_ops.py ADDED
@@ -0,0 +1,9 @@
1
+ import torch
2
+ from . import _megablocks_20250730102509
3
+ ops = torch.ops._megablocks_20250730102509
4
+
5
+ def add_op_namespace_prefix(op_name: str):
6
+ """
7
+ Prefix op by namespace.
8
+ """
9
+ return f"_megablocks_20250730102509::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/kernels.py ADDED
@@ -0,0 +1,543 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import triton
6
+ import triton.language as tl
7
+
8
+
9
+ def assert_is_tensor(x, ndim):
10
+ if x.ndim != ndim:
11
+ raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
12
+
13
+
14
+ def assert_is_matrix(x):
15
+ assert_is_tensor(x, 2)
16
+
17
+
18
+ def assert_is_vector(x):
19
+ if x.ndim != 1:
20
+ raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
21
+
22
+
23
+ def assert_equal(a, b):
24
+ if a != b:
25
+ raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
26
+
27
+
28
+ # a: (tokens, hidden_size), real.
29
+ # indices: (tokens * top_k), integer.
30
+ # bin_ids: (tokens * top_k), integer.
31
+ # weights: (tokens * top_k), real.
32
+ # bins: (num_experts), integer.
33
+ # padded_bins: (num_experts), integer.
34
+ @triton.autotune(
35
+ configs=[
36
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
37
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
38
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
39
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
40
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
41
+ ],
42
+ key=['NUM_COLUMNS'],
43
+ )
44
+ @triton.jit
45
+ def _padded_copy(
46
+ a,
47
+ b,
48
+ indices,
49
+ bin_ids,
50
+ weights,
51
+ bins,
52
+ padded_bins,
53
+ NUM_COLUMNS: tl.constexpr,
54
+ TOP_K: tl.constexpr,
55
+ BLOCK_X: tl.constexpr,
56
+ A_TO_B: tl.constexpr,
57
+ SCALE: tl.constexpr,
58
+ ):
59
+ # Our index into array 'a'.
60
+ index_a = tl.load(indices + tl.program_id(0))
61
+
62
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
63
+ # number of rows since they could be padded.
64
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
65
+
66
+ # Now we know what bin we're assigned to, but we need to know how
67
+ # many threadblocks were assigned to earlier bins so we can offset
68
+ # in our bin properly.
69
+ offset_in_bin = tl.program_id(0)
70
+ if bin_idx > 0:
71
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
72
+
73
+ # Load the starting index of our bin in array 'b'.
74
+ index_b = offset_in_bin
75
+ if bin_idx > 0:
76
+ index_b += tl.load(padded_bins + bin_idx - 1)
77
+
78
+ # Offset the input and output pointers.
79
+ #
80
+ # If we're going from A to B, divide the input index to copy
81
+ # the same input repeatedly. If we're going from B to A we
82
+ # need to reduce the result. Using atomics is slow, so we
83
+ # do the reduce step in a second kernel.
84
+ offset = index_a // TOP_K if A_TO_B else index_a
85
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
86
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
87
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
88
+
89
+ # Load the scale, if requested.
90
+ scale = tl.load(weights + index_a) if SCALE else 1
91
+
92
+ # Swap the pointers depending on the direction.
93
+ iptr = a if A_TO_B else b
94
+ optr = b if A_TO_B else a
95
+
96
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
97
+ for _ in range(iterations):
98
+ mask = offsets < NUM_COLUMNS
99
+ x = tl.load(iptr + offsets, mask=mask)
100
+ x = x.to(tl.float32) * scale.to(tl.float32)
101
+
102
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
103
+
104
+ offsets += BLOCK_X
105
+
106
+
107
+ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
108
+ # Validate the input shapes.
109
+ assert_is_matrix(x)
110
+ assert_is_vector(indices)
111
+ assert_is_vector(bin_ids)
112
+ assert_is_vector(bins)
113
+ assert_is_vector(padded_bins)
114
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
115
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
116
+ assert_equal(bins.size(), padded_bins.size())
117
+
118
+ if weights is not None:
119
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
120
+
121
+ # NOTE: Because of the padding, the output size is dynamic.
122
+ # We load the final padded bin bound to get the output rows.
123
+ output_rows = padded_bins[-1].cpu().item()
124
+ out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
125
+ _padded_copy[(indices.shape[0],)](
126
+ x,
127
+ out,
128
+ indices,
129
+ bin_ids,
130
+ weights,
131
+ bins,
132
+ padded_bins,
133
+ NUM_COLUMNS=x.shape[1],
134
+ A_TO_B=True,
135
+ TOP_K=top_k,
136
+ SCALE=weights is not None,
137
+ )
138
+ return out
139
+
140
+
141
+ def gather(x, indices, bin_ids, weights, bins, top_k):
142
+ # Validate the input shapes.
143
+ assert_is_matrix(x)
144
+ assert_is_vector(indices)
145
+ assert_is_vector(bin_ids)
146
+ assert_is_vector(bins)
147
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
148
+ assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
149
+
150
+ if weights is not None:
151
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
152
+
153
+ # NOTE: There is no padding so the output rows equals the
154
+ # input rows multiplied by top_k.
155
+ output_rows = x.shape[0] * top_k
156
+ out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
157
+ _padded_copy[(indices.shape[0],)](
158
+ x,
159
+ out,
160
+ indices,
161
+ bin_ids,
162
+ weights,
163
+ bins,
164
+ bins,
165
+ NUM_COLUMNS=x.shape[1],
166
+ A_TO_B=True,
167
+ TOP_K=top_k,
168
+ SCALE=weights is not None,
169
+ )
170
+ return out
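Reading _padded_copy, the non-padded gather above launches one program per routed token and, in the A_TO_B direction, writes row i of the output from row indices[i] // top_k of x, scaled by weights[indices[i]] when weights are given. A rough eager reference for that contract (a mental model, not the production path and not guaranteed bit-exact):

import torch

def gather_reference(x, indices, weights, top_k):
    # x: (tokens, hidden); indices, weights: (tokens * top_k,)
    idx = indices.long()
    out = x[idx // top_k]
    if weights is not None:
        out = out * weights[idx].unsqueeze(-1).to(out.dtype)
    return out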
171
+
172
+
173
+ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
174
+ # Validate the input shapes.
175
+ assert_is_matrix(x)
176
+ assert_is_vector(indices)
177
+ assert_is_vector(bin_ids)
178
+ assert_is_vector(bins)
179
+ assert_is_vector(padded_bins)
180
+ assert_equal(indices.shape[0], bin_ids.shape[0])
181
+ assert_equal(bins.size(), padded_bins.size())
182
+
183
+ if weights is not None:
184
+ assert_equal(indices.shape[0], weights.shape[0])
185
+
186
+ tokens = indices.shape[0] // top_k
187
+ out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
188
+ _padded_copy[(indices.shape[0],)](
189
+ out,
190
+ x,
191
+ indices,
192
+ bin_ids,
193
+ weights,
194
+ bins,
195
+ padded_bins,
196
+ NUM_COLUMNS=x.shape[1],
197
+ A_TO_B=False,
198
+ TOP_K=top_k,
199
+ SCALE=weights is not None,
200
+ )
201
+
202
+ # Reduce along the top-k dimension, if needed.
203
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
204
+
205
+
206
+ def scatter(x, indices, bin_ids, weights, bins, top_k):
207
+ return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
208
+
209
+
210
+ # x: (tokens, top_k, hidden_size), real
211
+ # grad: (tokens, hidden_size), real.
212
+ # wgrad: (tokens, top_k), real.
213
+ # indices: (tokens * top_k), integer.
214
+ # bin_ids: (tokens * top_k), integer.
215
+ # bins: (num_experts), integer.
216
+ # padded_bins: (num_experts), integer.
217
+ @triton.autotune(
218
+ configs=[
219
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
220
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
221
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
222
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
223
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
224
+ ],
225
+ key=['NUM_COLUMNS'],
226
+ )
227
+ @triton.jit
228
+ def _padded_copy_wgrad(
229
+ x,
230
+ grad,
231
+ wgrad,
232
+ indices,
233
+ bin_ids,
234
+ bins,
235
+ padded_bins,
236
+ NUM_COLUMNS: tl.constexpr,
237
+ TOP_K: tl.constexpr,
238
+ BLOCK_X: tl.constexpr,
239
+ ):
240
+ # Our index into 'tokens * top_k'.
241
+ index_out = tl.load(indices + tl.program_id(0))
242
+
243
+ # One threadblock per row in 'a'. Array 'b' has greater or equal
244
+ # number of rows since they could be padded.
245
+ bin_idx = tl.load(bin_ids + tl.program_id(0))
246
+
247
+ # Now we know what bin we're assigned to, but we need to know how
248
+ # many threadblocks were assigned to earlier bins so we can offset
249
+ # in our bin properly.
250
+ offset_in_bin = tl.program_id(0)
251
+ if bin_idx > 0:
252
+ offset_in_bin -= tl.load(bins + bin_idx - 1)
253
+
254
+ # Load the starting index of our bin in array 'x'.
255
+ index_x = offset_in_bin
256
+ if bin_idx > 0:
257
+ index_x += tl.load(padded_bins + bin_idx - 1)
258
+
259
+ # Offset the input and output pointers.
260
+ wgrad += index_out
261
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
262
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
263
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
264
+
265
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
266
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
267
+ for _ in range(iterations):
268
+ mask = offsets < NUM_COLUMNS
269
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
270
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
271
+ acc += data * scale
272
+ offsets += BLOCK_X
273
+
274
+ # Reduce to get the final result and store.
275
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
276
+ tl.store(wgrad, out)
277
+
278
+
279
+ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
280
+ # Validate the input shapes.
281
+ assert_is_matrix(x)
282
+ assert_is_matrix(grad)
283
+ assert_is_vector(indices)
284
+ assert_is_vector(bin_ids)
285
+ assert_is_vector(bins)
286
+ assert_is_vector(padded_bins)
287
+ assert_equal(indices.shape[0], bin_ids.shape[0])
288
+ assert_equal(bins.size(), padded_bins.size())
289
+
290
+ tokens = indices.shape[0] // top_k
291
+ out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
292
+ _padded_copy_wgrad[(indices.shape[0],)](
293
+ x,
294
+ grad,
295
+ out,
296
+ indices,
297
+ bin_ids,
298
+ bins,
299
+ padded_bins,
300
+ NUM_COLUMNS=x.shape[1],
301
+ TOP_K=top_k,
302
+ )
303
+ return out
304
+
305
+
306
+ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
307
+ return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
308
+
309
+
310
+ # a: (tokens, hidden_size), real.
311
+ # b: (num_experts, expert_capacity, num_columns), real.
312
+ # indices: (tokens * top_k), integer.
313
+ # weights: (tokens * top_k), real.
314
+ # bins: (num_experts), integer.
315
+ @triton.autotune(
316
+ configs=[
317
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
318
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
319
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
320
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
321
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
322
+ ],
323
+ key=['NUM_COLUMNS'],
324
+ )
325
+ @triton.jit
326
+ def _binned_copy(
327
+ a,
328
+ b,
329
+ num_experts,
330
+ expert_capacity,
331
+ indices,
332
+ weights,
333
+ bins,
334
+ NUM_COLUMNS: tl.constexpr,
335
+ TOP_K: tl.constexpr,
336
+ BLOCK_X: tl.constexpr,
337
+ A_TO_B: tl.constexpr,
338
+ SCALE: tl.constexpr,
339
+ ):
340
+ # Load our indices into the output.
341
+ expert_idx = tl.program_id(0)
342
+ entry_idx = tl.program_id(1)
343
+
344
+ # Calculate our offset into the output.
345
+ index_b = expert_idx * expert_capacity + entry_idx
346
+
347
+ # Load the index bounds for our bin and calculate
348
+ # the number of tokens assigned to our expert.
349
+ start = 0
350
+ if expert_idx > 0:
351
+ start = tl.load(bins + expert_idx - 1)
352
+ end = tl.load(bins + expert_idx)
353
+ num_tokens = end - start
354
+
355
+ # Calculate our offset into the input. If we don't
356
+ # have an input exit early.
357
+ if entry_idx >= num_tokens:
358
+ return
359
+ index_a = tl.load(indices + start + entry_idx)
360
+
361
+ # Offset the input and output pointers.
362
+ #
363
+ # If we're going from A to B, divide the input index to copy
364
+ # the same input repeatedly. If we're going from B to A we
365
+ # need to reduce the result. Using atomics is slow, so we
366
+ # do the reduce step in a second kernel.
367
+ offset = index_a // TOP_K if A_TO_B else index_a
368
+ a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
369
+ b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
370
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
371
+
372
+ # Load the scale, if requested.
373
+ scale = tl.load(weights + index_a) if SCALE else 1
374
+
375
+ # Swap the pointers depending on the direction.
376
+ #
377
+ # NOTE: We need to zero the output in both directions.
378
+ iptr = a if A_TO_B else b
379
+ optr = b if A_TO_B else a
380
+
381
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
382
+ for _ in range(iterations):
383
+ mask = offsets < NUM_COLUMNS
384
+ x = tl.load(iptr + offsets, mask=mask)
385
+ x = x.to(tl.float32) * scale.to(tl.float32)
386
+
387
+ tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
388
+
389
+ offsets += BLOCK_X
390
+
391
+
392
+ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
393
+ # Validate the input shapes.
394
+ assert_is_matrix(x)
395
+ assert_is_vector(indices)
396
+ assert_is_vector(bins)
397
+ assert_equal(indices.shape[0], x.shape[0] * top_k)
398
+
399
+ if weights is not None:
400
+ assert_equal(weights.shape[0], x.shape[0] * top_k)
401
+
402
+ num_experts = bins.shape[0]
403
+ out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
404
+
405
+ _binned_copy[(num_experts, expert_capacity)](
406
+ x,
407
+ out,
408
+ num_experts,
409
+ expert_capacity,
410
+ indices,
411
+ weights,
412
+ bins,
413
+ NUM_COLUMNS=x.shape[1],
414
+ A_TO_B=True,
415
+ TOP_K=top_k,
416
+ SCALE=weights is not None,
417
+ )
418
+ return out
419
+
420
+
421
+ def binned_scatter(x, indices, weights, bins, top_k):
422
+ # Validate the input shapes.
423
+ assert_is_tensor(x, 3)
424
+ assert_is_vector(indices)
425
+ assert_is_vector(bins)
426
+ assert_equal(bins.shape[0], x.shape[0])
427
+
428
+ if weights is not None:
429
+ assert_equal(indices.shape[0], weights.shape[0])
430
+
431
+ num_experts, expert_capacity, hidden_size = x.shape
432
+ tokens = indices.shape[0] // top_k
433
+ out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
434
+ _binned_copy[(num_experts, expert_capacity)](
435
+ out,
436
+ x,
437
+ num_experts,
438
+ expert_capacity,
439
+ indices,
440
+ weights,
441
+ bins,
442
+ NUM_COLUMNS=hidden_size,
443
+ A_TO_B=False,
444
+ TOP_K=top_k,
445
+ SCALE=weights is not None,
446
+ )
447
+
448
+ # Reduce along the top-k dimension, if needed.
449
+ return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
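binned_gather and binned_scatter above work on a dense (num_experts, expert_capacity, hidden) layout: each expert's bin fills its capacity slots in sorted order, and tokens beyond the capacity are dropped. A loop-based eager sketch of the gather direction, without the optional weight scaling (illustrative reference only):

import torch

def binned_gather_reference(x, indices, bins, expert_capacity, top_k):
    # x: (tokens, hidden); bins: inclusive cumsum of tokens per expert.
    num_experts = bins.shape[0]
    out = torch.zeros(num_experts, expert_capacity, x.shape[1], dtype=x.dtype, device=x.device)
    start = 0
    for e in range(num_experts):
        end = int(bins[e])
        for slot in range(min(end - start, expert_capacity)):  # extra tokens are dropped
            out[e, slot] = x[int(indices[start + slot]) // top_k]
        start = end
    return out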
450
+
451
+
452
+ # a: (tokens, hidden_size), real.
453
+ # b: (num_experts, expert_capacity, num_columns), real.
454
+ # indices: (tokens * top_k), integer.
455
+ # weights: (tokens * top_k), real.
456
+ # bins: (num_experts), integer.
457
+ @triton.autotune(
458
+ configs=[
459
+ triton.Config({'BLOCK_X': 64}, num_warps=2),
460
+ triton.Config({'BLOCK_X': 128}, num_warps=2),
461
+ triton.Config({'BLOCK_X': 256}, num_warps=2),
462
+ triton.Config({'BLOCK_X': 128}, num_warps=4),
463
+ triton.Config({'BLOCK_X': 256}, num_warps=4),
464
+ ],
465
+ key=['NUM_COLUMNS'],
466
+ )
467
+ @triton.jit
468
+ def _binned_copy_wgrad(
469
+ x,
470
+ grad,
471
+ wgrad,
472
+ num_experts,
473
+ expert_capacity,
474
+ indices,
475
+ bins,
476
+ NUM_COLUMNS: tl.constexpr,
477
+ TOP_K: tl.constexpr,
478
+ BLOCK_X: tl.constexpr,
479
+ ):
480
+ # Load our indices into the output.
481
+ expert_idx = tl.program_id(0)
482
+ entry_idx = tl.program_id(1)
483
+
484
+ # Calculate our offset into the output.
485
+ index_x = expert_idx * expert_capacity + entry_idx
486
+
487
+ # Load the index bounds for our bin and calculate
488
+ # the number of tokens assigned to our expert.
489
+ start = 0
490
+ if expert_idx > 0:
491
+ start = tl.load(bins + expert_idx - 1)
492
+ end = tl.load(bins + expert_idx)
493
+ num_tokens = end - start
494
+
495
+ # Calculate our offset into the input. If we don't
496
+ # have an input exit early.
497
+ if entry_idx >= num_tokens:
498
+ return
499
+ index_out = tl.load(indices + start + entry_idx)
500
+
501
+ # Offset the input and output pointers.
502
+ wgrad += index_out
503
+ grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
504
+ x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
505
+ offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
506
+
507
+ acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
508
+ iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
509
+ for _ in range(iterations):
510
+ mask = offsets < NUM_COLUMNS
511
+ data = tl.load(x + offsets, mask=mask).to(tl.float32)
512
+ scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
513
+ acc += data * scale
514
+ offsets += BLOCK_X
515
+
516
+ # Reduce to get the final result and store.
517
+ out = tl.sum(acc).to(wgrad.dtype.element_ty)
518
+ tl.store(wgrad, out)
519
+
520
+
521
+ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
522
+ # Validate the input shapes.
523
+ assert_is_tensor(x, 3)
524
+ assert_is_matrix(grad)
525
+ assert_is_vector(indices)
526
+ assert_is_vector(bins)
527
+ assert_equal(bins.shape[0], x.shape[0])
528
+
529
+ num_experts, expert_capacity, hidden_size = x.shape
530
+ tokens = indices.shape[0] // top_k
531
+ out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
532
+ _binned_copy_wgrad[(num_experts, expert_capacity)](
533
+ x,
534
+ grad,
535
+ out,
536
+ num_experts,
537
+ expert_capacity,
538
+ indices,
539
+ bins,
540
+ NUM_COLUMNS=hidden_size,
541
+ TOP_K=top_k,
542
+ )
543
+ return out
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/bak.__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ from megablocks_moe.megablocks import (
2
+ MoE,
3
+ dMoE,
4
+ get_load_balancing_loss,
5
+ ParallelMLP,
6
+ ParallelDroplessMLP,
7
+ SparseMLP,
8
+ MLP,
9
+ SparseGLU,
10
+ Arguments,
11
+ )
12
+
13
+ __all__ = [
14
+ "MoE",
15
+ "dMoE",
16
+ "get_load_balancing_loss",
17
+ "ParallelMLP",
18
+ "ParallelDroplessMLP",
19
+ "SparseMLP",
20
+ "MLP",
21
+ "SparseGLU",
22
+ "Arguments",
23
+ ]
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/benchmark_util.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+
8
+ def log_benchmark(name, arguments, time, std):
9
+ print('=' * 60)
10
+ print(f'{name} Benchmark')
11
+ print('Benchmark Parameters:')
12
+ for (key, value) in arguments.items():
13
+ print(f'{key} = {value}')
14
+ print('Results:')
15
+ print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
16
+ print('=' * 60)
17
+
18
+
19
+ def benchmark_function(fn, iterations=100, warmup=10):
20
+ # Warmup iterations.
21
+ for _ in range(warmup):
22
+ fn()
23
+
24
+ times = []
25
+ for i in range(iterations):
26
+ start = torch.cuda.Event(enable_timing=True)
27
+ end = torch.cuda.Event(enable_timing=True)
28
+
29
+ start.record()
30
+ fn()
31
+ end.record()
32
+
33
+ torch.cuda.synchronize()
34
+ times.append(start.elapsed_time(end))
35
+ return np.mean(times), np.std(times)
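benchmark_function times a callable with CUDA events after a warmup and returns the mean and standard deviation in milliseconds; log_benchmark prints them alongside the parameters. A hypothetical usage sketch (requires a CUDA device):

import torch

a = torch.randn(4096, 4096, device='cuda')
b = torch.randn(4096, 4096, device='cuda')

mean_ms, std_ms = benchmark_function(lambda: torch.mm(a, b), iterations=100, warmup=10)
log_benchmark('matmul', {'m': 4096, 'n': 4096, 'k': 4096}, mean_ms, std_ms)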
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from . import ops
2
+ from . import backend
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/backend.py ADDED
@@ -0,0 +1,33 @@
1
+ # NOTE: Torch needs to be imported before the custom
2
+ # extensions. Otherwise libc10.so cannot be found.
3
+ import torch
4
+
5
+ # # TODO(tgale): Wrap this in a try-block with better
6
+ # # error message and instructions for building the
7
+ # # c++ operations.
8
+ # import grouped_gemm_backend as backend
9
+
10
+ # We import the backend operations from the megablocks package as
11
+ # grouped_gemm is vendored in megablocks in this repository.
12
+ # from ... import _ops as backend
13
+ # from megablocks._ops import ops as backend # type: ignore
14
+ from .._ops import ops as backend # type: ignore
15
+
16
+ def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
17
+ assert not (trans_a and trans_b)
18
+ assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
19
+ assert a.ndim == 2, "Expected 2d tensor for 'a'"
20
+ assert b.ndim == (2 if trans_a else 3)
21
+
22
+ shape = (
23
+ (batch_sizes.shape[0], a.shape[1], b.shape[1])
24
+ if trans_a else
25
+ (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
26
+ )
27
+ return torch.empty(*shape, device=a.device, dtype=a.dtype)
28
+
29
+ def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
30
+ if c is None:
31
+ c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
32
+ backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
33
+ return c
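A hedged usage sketch for the grouped GEMM entry point above. It assumes the compiled megablocks extension is installed and a GPU is available; with the default `trans_a=False, trans_b=False`, `a` is `[total_rows, k]`, `b` is `[num_groups, k, n]`, and `batch_sizes` gives the rows of `a` belonging to each group (commonly expected as a CPU int64 tensor, check the kernel's requirements).

```python
import torch
from megablocks.grouped_gemm import backend

if torch.cuda.is_available():
    num_groups, k, n = 4, 64, 128
    batch_sizes = torch.tensor([8, 16, 4, 12], dtype=torch.long)  # rows of `a` per group
    a = torch.randn(int(batch_sizes.sum()), k, device='cuda', dtype=torch.bfloat16)
    b = torch.randn(num_groups, k, n, device='cuda', dtype=torch.bfloat16)
    c = backend.gmm(a, b, batch_sizes)  # -> [total_rows, n]
```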
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/ops.py ADDED
@@ -0,0 +1,33 @@
1
+ from . import backend
2
+ import torch
3
+
4
+
5
+ class GroupedGemm(torch.autograd.Function):
6
+
7
+ @staticmethod
8
+ def forward(ctx, a, b, batch_sizes, trans_b):
9
+ ctx.save_for_backward(a, b, batch_sizes)
10
+ ctx.trans_b = trans_b
11
+ return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
12
+
13
+ @staticmethod
14
+ def backward(ctx, grad):
15
+ grad = grad.contiguous()
16
+ a, b, batch_sizes = ctx.saved_tensors
17
+ trans_b = ctx.trans_b
18
+
19
+ agrad = None
20
+ if ctx.needs_input_grad[0]:
21
+ agrad = backend.gmm(
22
+ grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)
23
+
24
+ bgrad = None
25
+ if ctx.needs_input_grad[1]:
26
+ lhs, rhs = (grad, a) if trans_b else (a, grad)
27
+ bgrad = backend.gmm(
28
+ lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
29
+ return agrad, bgrad, None, None
30
+
31
+
32
+ def gmm(a, b, batch_sizes, trans_b=False):
33
+ return GroupedGemm.apply(a, b, batch_sizes, trans_b)
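A hedged sketch of the differentiable wrapper, under the same assumptions as the backend example (compiled extension, GPU, CPU `batch_sizes`). Gradients flow to both operands through the custom backward defined above.

```python
import torch
from megablocks.grouped_gemm import ops as gg_ops

if torch.cuda.is_available():
    batch_sizes = torch.tensor([8, 8], dtype=torch.long)
    a = torch.randn(16, 32, device='cuda', dtype=torch.bfloat16, requires_grad=True)
    b = torch.randn(2, 32, 64, device='cuda', dtype=torch.bfloat16, requires_grad=True)
    out = gg_ops.gmm(a, b, batch_sizes)  # -> [16, 64]
    out.sum().backward()                 # populates a.grad and b.grad
```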
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm_util.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import warnings
4
+
5
+ _grouped_gemm_is_available: bool = False
6
+ try:
7
+ # import grouped_gemm
8
+ pass
9
+ _grouped_gemm_is_available = True
10
+ except ImportError as error:
11
+ warnings.warn('Grouped GEMM not available.')
12
+
13
+
14
+ def grouped_gemm_is_available():
15
+ return _grouped_gemm_is_available
16
+
17
+
18
+ def assert_grouped_gemm_is_available():
19
+ msg = (
20
+ 'Grouped GEMM not available. Please run '
21
+ '`pip install git+https://github.com/tgale96/grouped_gemm@main`.'
22
+ )
23
+ assert _grouped_gemm_is_available, msg
24
+
25
+
26
+ # backend = grouped_gemm.backend if grouped_gemm_is_available() else None
27
+ # ops = grouped_gemm.ops if grouped_gemm_is_available() else None
28
+
29
+
30
+ #from .grouped_gemm import backend as ops
31
+ #from .grouped_gemm import ops as backend
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/layers.py ADDED
@@ -0,0 +1,1001 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+
4
+ from typing import Optional, Any
5
+
6
+ from . import _layers
7
+ from . import ops
8
+
9
+
10
+ # Set the expert model parallel attributes on a tensor
11
+ def set_expert_model_parallel_attributes(
12
+ tensor: torch.Tensor,
13
+ is_parallel: bool,
14
+ ):
15
+ assert not hasattr(tensor, "expert_model_parallel")
16
+ setattr(tensor, "expert_model_parallel", is_parallel)
17
+
18
+
19
+ # Get the expert model parallel attributes from a tensor
20
+ def expert_sharding_degree(
21
+ world_size: int,
22
+ moe_num_experts: int,
23
+ ) -> int:
24
+ esd = min(world_size, moe_num_experts)
25
+ if (moe_num_experts % esd) != 0:
26
+ raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
27
+ return esd
28
+
29
+
30
+ # Calculate the hidden sharding degree based on world size and expert sharding degree
31
+ def hidden_sharding_degree(
32
+ world_size: int,
33
+ moe_num_experts: int,
34
+ ffn_hidden_size: int,
35
+ ) -> int:
36
+ esd = expert_sharding_degree(world_size, moe_num_experts)
37
+ hsd = world_size // esd
38
+ if (ffn_hidden_size % hsd) != 0:
39
+ raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
40
+ if (esd * hsd) != world_size:
41
+ raise ValueError(
42
+ f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
43
+ )
44
+ return hsd
45
+
46
+
47
+ # Calculate the number of experts per rank based on world size and expert sharding degree
48
+ def experts_per_rank(
49
+ moe_num_experts: int,
50
+ world_size: int,
51
+ ) -> int:
52
+ return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
53
+
54
+
55
+ # Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
56
+ def features_per_rank(
57
+ ffn_hidden_size: int, world_size: int, moe_num_experts: int
58
+ ) -> int:
59
+ return ffn_hidden_size // hidden_sharding_degree(
60
+ world_size, moe_num_experts, ffn_hidden_size
61
+ )
62
+
63
+
64
+ # Apply jitter to the input tensor
65
+ def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
66
+ low = 1.0 - moe_jitter_eps
67
+ high = 1.0 + moe_jitter_eps
68
+ noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
69
+ return x * (low + noise * (high - low))
70
+
71
+
72
+ # Compute the top-k scores from the logits
73
+ def compute_top_k(scores: torch.Tensor, moe_top_k: int):
74
+ if moe_top_k == 1:
75
+ return scores.max(dim=-1, keepdim=True)
76
+ return torch.topk(scores, moe_top_k, dim=-1)
77
+
78
+
79
+ # Route tokens to experts and compute expert weights and indices
80
+ def route_tokens(
81
+ x: torch.Tensor,
82
+ router_weight: torch.Tensor,
83
+ moe_top_k: int,
84
+ moe_num_experts: int,
85
+ moe_jitter_eps: float = None,
86
+ moe_normalize_expert_weights: int = None,
87
+ uniform_expert_assignment: bool = False,
88
+ training: bool = False,
89
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
90
+ if training and moe_jitter_eps is not None:
91
+ x = apply_jitter(x, moe_jitter_eps)
92
+
93
+ x_flat = x.view(-1, x.shape[-1])
94
+ logits = torch.nn.functional.linear(x_flat, router_weight)
95
+ expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
96
+ expert_weights = expert_weights.softmax(dim=-1)
97
+ if moe_normalize_expert_weights is not None:
98
+ expert_weights = expert_weights / torch.norm(
99
+ expert_weights,
100
+ p=moe_normalize_expert_weights,
101
+ dim=-1,
102
+ keepdim=True,
103
+ )
104
+ if uniform_expert_assignment:
105
+ expert_indices = _layers.router._uniform_expert_assignment(
106
+ expert_indices,
107
+ moe_num_experts,
108
+ )
109
+
110
+ return logits, expert_weights, expert_indices
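A CPU-only sketch of the routing math above with toy sizes (no jitter, no weight normalization, no uniform assignment); it re-implements the same topk-then-softmax pattern rather than importing this module, since the real code path depends on the compiled kernels.

```python
import torch

tokens, hidden, num_experts, top_k = 6, 16, 4, 2
x_flat = torch.randn(tokens, hidden)
router_weight = torch.randn(num_experts, hidden)

logits = torch.nn.functional.linear(x_flat, router_weight)   # [tokens, num_experts]
expert_weights, expert_indices = torch.topk(logits, top_k, dim=-1)
expert_weights = expert_weights.softmax(dim=-1)               # normalized over the top-k only
print(logits.shape, expert_weights.shape, expert_indices.shape)
```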
111
+
112
+
113
+ # Scale the gradient of the weights
114
+ def scale_grad(
115
+ w: torch.Tensor,
116
+ gradient_scale: Optional[float] = None,
117
+ ) -> torch.Tensor:
118
+ if gradient_scale is None:
119
+ return w
120
+ return _layers.mlp.scale_gradient(w, gradient_scale)
121
+
122
+
123
+ # Forward pass for the MLP layer
124
+ def mlp_forward(
125
+ x: torch.Tensor,
126
+ w1: torch.Tensor,
127
+ w2: torch.Tensor,
128
+ w1_bias: torch.Tensor,
129
+ w2_bias: torch.Tensor,
130
+ gradient_scale: Optional[float] = None,
131
+ alpha: float = 1.702,
132
+ ):
133
+ # Scale weights
134
+ w1 = scale_grad(w1, gradient_scale)
135
+ w2 = scale_grad(w2, gradient_scale)
136
+ w1_bias = scale_grad(w1_bias, gradient_scale)
137
+ w2_bias = scale_grad(w2_bias, gradient_scale)
138
+
139
+ # Resolve dtensors
140
+ w1 = _layers.mlp.resolve_dtensor(w1)
141
+ w2 = _layers.mlp.resolve_dtensor(w2)
142
+ w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
143
+ w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
144
+
145
+ # Forward pass
146
+ gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
147
+ gate, up = gate_up.chunk(2, dim=-1)
148
+
149
+ glu = gate * torch.sigmoid(gate * alpha)
150
+ x = (up + 1) * glu
151
+
152
+ return torch.bmm(x, w2) + w2_bias[..., None, :]
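A CPU sketch of the gated activation used in `mlp_forward`: the fused gate/up projection is split in half, the gate half passes through a sigmoid-weighted nonlinearity with slope `alpha`, and the shifted up half scales it before the down projection.

```python
import torch

alpha = 1.702
gate_up = torch.randn(3, 8)             # stand-in for torch.bmm(x, w1) + w1_bias
gate, up = gate_up.chunk(2, dim=-1)     # each [3, 4]
glu = gate * torch.sigmoid(gate * alpha)
hidden = (up + 1) * glu                 # what gets multiplied by w2 next
```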
153
+
154
+
155
+ # Shared expert MLP forward pass
156
+ def shared_mlp_forward(
157
+ x: torch.Tensor,
158
+ up_proj_weight: torch.Tensor,
159
+ down_proj_weight: torch.Tensor,
160
+ up_proj_bias: Optional[torch.Tensor] = None,
161
+ down_proj_bias: Optional[torch.Tensor] = None,
162
+ activation_fn: Optional[Any] = None,
163
+ gradient_scale: Optional[float] = None,
164
+ ) -> torch.Tensor:
165
+ # Default activation function
166
+ if activation_fn is None:
167
+ activation_fn = torch.nn.functional.gelu
168
+
169
+ # Scale weights
170
+ up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
171
+ down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
172
+ if up_proj_bias is not None:
173
+ up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
174
+ if down_proj_bias is not None:
175
+ down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
176
+
177
+ # Resolve dtensors
178
+ up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
179
+ down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
180
+ if up_proj_bias is not None:
181
+ up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
182
+ if down_proj_bias is not None:
183
+ down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
184
+
185
+ # Up projection
186
+ x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
187
+
188
+ # Activation
189
+ x = activation_fn(x)
190
+
191
+ # Down projection
192
+ x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
193
+
194
+ return x
195
+
196
+
197
+ # Combine outputs from shared expert and regular experts
198
+ def combine_expert_shared_outputs(
199
+ shared_expert_out: torch.Tensor,
200
+ expert_out: torch.Tensor,
201
+ shared_expert_weighted_sum: bool = False,
202
+ moe_top_k: int = 1,
203
+ ) -> torch.Tensor:
204
+ if shared_expert_weighted_sum:
205
+ # Weighted sum based on number of experts used
206
+ total_experts = moe_top_k + 1
207
+ shared_weight = 1.0 / total_experts
208
+ expert_weight = moe_top_k / total_experts
209
+ return shared_expert_out * shared_weight + expert_out * expert_weight
210
+ else:
211
+ # Simple addition
212
+ return shared_expert_out + expert_out
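A worked example of the weighted-sum branch: with `moe_top_k = 2` the shared expert contributes one third of the combined output and the routed experts two thirds.

```python
moe_top_k = 2
total_experts = moe_top_k + 1
shared_weight = 1.0 / total_experts        # ~0.333
expert_weight = moe_top_k / total_experts  # ~0.667
assert abs(shared_weight + expert_weight - 1.0) < 1e-9
```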
213
+
214
+
215
+ # Global variable to store load balancing loss
216
+ _LOAD_BALANCING_LOSS = []
217
+
218
+
219
+ def save_load_balancing_loss(loss):
220
+ global _LOAD_BALANCING_LOSS
221
+ _LOAD_BALANCING_LOSS.append(loss)
222
+
223
+
224
+ def get_load_balancing_loss():
225
+ global _LOAD_BALANCING_LOSS
226
+ return _LOAD_BALANCING_LOSS
227
+
228
+
229
+ def clear_load_balancing_loss():
230
+ global _LOAD_BALANCING_LOSS
231
+ _LOAD_BALANCING_LOSS.clear()
232
+
233
+
234
+ def batched_load_balancing_loss(args):
235
+ if args.moe_loss_weight == 0:
236
+ return 0.0
237
+
238
+ tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
239
+ num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
240
+ if args.num_layers_per_virtual_pipeline_stage is not None:
241
+ num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
242
+
243
+ if len(tokens_per_expert) != num_layers_per_pipeline_stage:
244
+ raise ValueError(
245
+ f"Expected {num_layers_per_pipeline_stage} token_per_experts "
246
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
247
+ f"{args.num_layers}\npipeline_model_parallel_size = "
248
+ f"{args.pipeline_model_parallel_size}\n"
249
+ "num_layers_per_virtual_pipeline_stage"
250
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
251
+ )
252
+ if len(expert_scores) != num_layers_per_pipeline_stage:
253
+ raise ValueError(
254
+ f"Expected {num_layers_per_pipeline_stage} expert_scores "
255
+ f"but found {len(tokens_per_expert)}.\nnum_layers = "
256
+ f"{args.num_layers}\npipeline_model_parallel_size = "
257
+ f"{args.pipeline_model_parallel_size}\n"
258
+ "num_layers_per_virtual_pipeline_stage"
259
+ f" = {args.num_layers_per_virtual_pipeline_stage}",
260
+ )
261
+
262
+ # Verify the shape of the tokens_per_expert and expert_scores tensors.
263
+ assert all(
264
+ (x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
265
+ )
266
+
267
+ tokens = expert_scores[0].shape[0]
268
+ assert all(
269
+ (
270
+ (
271
+ x.ndim == 2
272
+ and x.shape[1] == args.moe_num_experts
273
+ and x.shape[0] == tokens
274
+ )
275
+ for x in expert_scores
276
+ )
277
+ )
278
+
279
+ # Concatenate the contributions of each layer and convert to
280
+ # the correct types and formats for the dot product.
281
+ expert_scores = torch.cat(expert_scores, dim=1)
282
+ if args.moe_lbl_in_fp32:
283
+ expert_scores = expert_scores.float()
284
+ if tokens != 0:
285
+ expert_scores = expert_scores.mean(dim=0)
286
+ else:
287
+ expert_scores = expert_scores.sum(dim=0)
288
+ tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
289
+
290
+ expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
291
+ assert tokens_per_expert.numel() == expected_values
292
+ assert expert_scores.numel() == expected_values
293
+
294
+ # Calculate the total scale across all factors.
295
+ #
296
+ # loss_weight * num_experts / (num_layers * tokens * top_k)
297
+ scale_numerator = args.moe_num_experts * args.moe_loss_weight
298
+ scale_denominator = args.num_layers * tokens * args.moe_top_k
299
+ scale = scale_numerator / scale_denominator
300
+ return scale * torch.dot(tokens_per_expert, expert_scores)
301
+
302
+
303
+ # Calculate the expert capacity based on tokens, top_k, number of experts,
304
+ # expert parallel group, capacity factor, and whether expert model parallelism is used.
305
+ def expert_capacity(
306
+ tokens: int,
307
+ top_k: int,
308
+ num_experts: int,
309
+ expert_parallel_group: int,
310
+ moe_capacity_factor: float,
311
+ moe_expert_model_parallelism: bool,
312
+ ) -> int:
313
+ world_size = (
314
+ dist.get_world_size(expert_parallel_group)
315
+ if moe_expert_model_parallelism
316
+ else 1
317
+ )
318
+
319
+ tokens_per_expert = top_k * tokens * world_size / num_experts
320
+ return int(moe_capacity_factor * tokens_per_expert)
321
+
322
+
323
+ def load_balancing_loss(
324
+ tokens_per_expert: torch.Tensor,
325
+ expert_scores: torch.Tensor,
326
+ top_k: int,
327
+ num_experts: int,
328
+ ):
329
+ assert len(expert_scores.size()) == 2
330
+ tokens, num_experts = expert_scores.size()
331
+ assert num_experts == num_experts
332
+ assert len(tokens_per_expert.size()) == 1
333
+ (num_experts,) = tokens_per_expert.size()
334
+ assert num_experts == num_experts
335
+ scale = num_experts / (tokens * top_k)
336
+ return scale * torch.dot(
337
+ tokens_per_expert.to(expert_scores.dtype),
338
+ expert_scores.mean(dim=0),
339
+ )
340
+
341
+
342
+ def indices_and_bins(
343
+ top_expert: torch.Tensor,
344
+ sort_end_bit: int,
345
+ num_experts: int,
346
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
347
+ top_expert = top_expert.int()
348
+
349
+ # Ensure contiguous memory layout
350
+ top_expert = top_expert.contiguous()
351
+
352
+ # Ensure CUB knows which device to use
353
+ with torch.cuda.device(top_expert.device):
354
+ output = ops.sort(top_expert, sort_end_bit)
355
+ bin_ids, indices = output
356
+ tokens_per_expert = ops.histogram(top_expert, num_experts)
357
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
358
+
359
+ bins = bins.view(1) if not len(bins.size()) else bins
360
+ return indices, bin_ids, bins, tokens_per_expert
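A CPU analogue of `indices_and_bins` built from plain PyTorch ops; the real path uses the fused sort/histogram/cumsum kernels, but the quantities it produces are the same.

```python
import torch

num_experts = 4
top_expert = torch.tensor([2, 0, 2, 1, 0, 2])

bin_ids, indices = torch.sort(top_expert)                        # expert id per sorted slot, original token positions
tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)
bins = torch.cumsum(tokens_per_expert, dim=0)                    # inclusive end offset of each expert's bin
# bin_ids -> [0, 0, 1, 2, 2, 2]; tokens_per_expert -> [2, 1, 3, 0]; bins -> [2, 3, 6, 6]
```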
361
+
362
+
363
+ def expert_capacity_fn(
364
+ tokens: int,
365
+ top_k: int,
366
+ num_experts: int,
367
+ expert_parallel_group: torch.distributed.ProcessGroup,
368
+ moe_capacity_factor: float = 1.0,
369
+ moe_expert_model_parallelism: bool = False,
370
+ ) -> int:
371
+ world_size = (
372
+ dist.get_world_size(expert_parallel_group)
373
+ if moe_expert_model_parallelism
374
+ else 1
375
+ )
376
+ tokens_per_expert = top_k * tokens * world_size / num_experts
377
+ return int(moe_capacity_factor * tokens_per_expert)
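A worked example of the capacity formula above: 4096 tokens, top-k of 4, 128 experts on a single rank, and a capacity factor of 1.0 give each expert room for 128 token slots.

```python
tokens, top_k, num_experts, world_size, capacity_factor = 4096, 4, 128, 1, 1.0
tokens_per_expert = top_k * tokens * world_size / num_experts  # 128.0
capacity = int(capacity_factor * tokens_per_expert)            # 128
```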
378
+
379
+
380
+ def permute_and_compute(
381
+ x,
382
+ tokens_per_expert,
383
+ indices,
384
+ bin_ids,
385
+ expert_weights,
386
+ bins,
387
+ expert_capacity,
388
+ top_k,
389
+ w1,
390
+ w2,
391
+ w1_bias,
392
+ w2_bias,
393
+ gradient_scale,
394
+ alpha,
395
+ ):
396
+ # Route tokens to experts
397
+ x = x.view(-1, x.shape[-1])
398
+
399
+ # Ensure CUB knows which device to use
400
+ with torch.cuda.device(x.device):
401
+ x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
402
+
403
+ # Expert computation
404
+ x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
405
+
406
+ # Ensure CUB knows which device to use
407
+ with torch.cuda.device(x.device):
408
+ # Route tokens back
409
+ out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
410
+ return out
411
+
412
+
413
+ def forward_once(
414
+ x: torch.Tensor,
415
+ expert_weights: torch.Tensor,
416
+ top_experts: torch.Tensor,
417
+ w1: torch.Tensor,
418
+ w2: torch.Tensor,
419
+ w1_bias: torch.Tensor,
420
+ w2_bias: torch.Tensor,
421
+ gradient_scale: Optional[float] = None,
422
+ alpha: float = 1.702,
423
+ sort_end_bit: int = 0,
424
+ top_k: int = 4,
425
+ num_experts: int = 128,
426
+ expert_parallel_group: int = None,
427
+ moe_capacity_factor: float = 1.0,
428
+ moe_expert_model_parallelism: bool = False,
429
+ mlp_impl: Optional[str] = None,
430
+ ):
431
+ # x: [sl, bs, hs]
432
+ # expert_weights: [sl * bs, top-k]
433
+ # top_experts: [sl * bs, top-k]
434
+ expert_weights = expert_weights.flatten()
435
+ top_experts = top_experts.flatten()
436
+
437
+ with torch.no_grad():
438
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
439
+ top_experts, sort_end_bit, num_experts
440
+ )
441
+
442
+ # Calculate expert capacity
443
+ sl, bs, _ = x.size()
444
+
445
+ expert_capacity = expert_capacity_fn(
446
+ sl * bs,
447
+ top_k,
448
+ num_experts,
449
+ expert_parallel_group,
450
+ moe_capacity_factor,
451
+ moe_expert_model_parallelism,
452
+ )
453
+
454
+ if expert_capacity == 0:
455
+ expert_capacity = torch.max(tokens_per_expert).item()
456
+
457
+ x = permute_and_compute(
458
+ x,
459
+ tokens_per_expert,
460
+ indices,
461
+ bin_ids,
462
+ expert_weights,
463
+ bins,
464
+ expert_capacity,
465
+ top_k,
466
+ w1,
467
+ w2,
468
+ w1_bias,
469
+ w2_bias,
470
+ gradient_scale,
471
+ alpha,
472
+ )
473
+ return x, tokens_per_expert
474
+
475
+
476
+ def parallel_forward_once(
477
+ x: torch.Tensor,
478
+ expert_weights: torch.Tensor,
479
+ top_experts: torch.Tensor,
480
+ w1: torch.Tensor,
481
+ w2: torch.Tensor,
482
+ w1_bias: torch.Tensor,
483
+ w2_bias: torch.Tensor,
484
+ gradient_scale: Optional[float] = None,
485
+ alpha: float = 1.702,
486
+ sort_end_bit: int = 0,
487
+ top_k: int = 4,
488
+ num_experts: int = 128,
489
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
490
+ moe_capacity_factor: float = 1.0,
491
+ moe_expert_model_parallelism: bool = True,
492
+ hidden_size: int = 1152,
493
+ mlp_impl: Optional[str] = "sparse",
494
+ ):
495
+ # Flatten inputs
496
+ expert_weights = expert_weights.flatten()
497
+ top_experts = top_experts.flatten()
498
+
499
+ # TODO: remove debugging var
500
+ # my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
501
+
502
+ with torch.no_grad():
503
+ # Step 1: Local permutation setup
504
+ indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
505
+ top_experts, sort_end_bit, num_experts
506
+ )
507
+
508
+ # Calculate sharding parameters
509
+ world_size = dist.get_world_size(expert_parallel_group)
510
+ hidden_sharding_deg = hidden_sharding_degree(
511
+ world_size, num_experts, hidden_size
512
+ )
513
+ experts_per_rank_val = experts_per_rank(num_experts, world_size)
514
+
515
+ # Replicate token counts for hidden sharding
516
+ repeated_tokens_per_expert = ops.repeat(
517
+ tokens_per_expert, (hidden_sharding_deg,)
518
+ )
519
+
520
+ # Exchange token counts across devices
521
+ parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
522
+
523
+ # Ensure CUB knows which device to use
524
+ tpe_handle = dist.all_to_all_single(
525
+ parallel_tokens_per_expert,
526
+ repeated_tokens_per_expert,
527
+ group=expert_parallel_group,
528
+ async_op=True,
529
+ )
530
+
531
+ # Step 2: Local permutation - group tokens by target device
532
+ x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
533
+ x = ops.gather(x, indices, bin_ids, bins, top_k)
534
+
535
+ # Step 3: Compute communication counts and exchange tokens
536
+ with torch.no_grad():
537
+ tpe_handle.wait()
538
+
539
+ # Reshape for per-device calculations
540
+ repeated_tokens_per_expert = repeated_tokens_per_expert.view(
541
+ world_size, experts_per_rank_val
542
+ )
543
+ parallel_tokens_per_expert = parallel_tokens_per_expert.view(
544
+ world_size, experts_per_rank_val
545
+ )
546
+
547
+ # Calculate send/recv counts
548
+ send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
549
+ # recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
550
+ parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
551
+ recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
552
+ tokens_received = sum(recv_counts)
553
+
554
+ # Replicate for hidden sharding
555
+ x = ops.repeat(x, (hidden_sharding_deg, 1))
556
+
557
+ # Cross-device token exchange
558
+ parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
559
+ x, recv_counts, send_counts, expert_parallel_group, async_op=True
560
+ )
561
+
562
+ with torch.no_grad():
563
+ # Step 4: Setup for local expert computation
564
+ replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
565
+ replicate_bins = (
566
+ replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
567
+ )
568
+
569
+ # Create expert indices for received tokens
570
+ parallel_top_expert = torch.remainder(
571
+ torch.arange(
572
+ num_experts * hidden_sharding_deg,
573
+ dtype=torch.int32,
574
+ device=indices.device,
575
+ ),
576
+ experts_per_rank_val,
577
+ )
578
+ parallel_top_expert = ops.replicate(
579
+ parallel_top_expert.unsqueeze(dim=0),
580
+ replicate_bins,
581
+ tokens_received,
582
+ ).flatten()
583
+
584
+ # Sort tokens by expert assignment
585
+ parallel_bin_ids, parallel_indices = ops.sort(
586
+ parallel_top_expert,
587
+ sort_end_bit,
588
+ )
589
+
590
+ # Calculate bins for local experts
591
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
592
+ dim=0, dtype=torch.int
593
+ )
594
+ parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
595
+ parallel_bins = (
596
+ parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
597
+ )
598
+
599
+ # Calculate expert capacity
600
+ expert_capacity = expert_capacity_fn(
601
+ tokens_received,
602
+ top_k,
603
+ experts_per_rank_val,
604
+ expert_parallel_group,
605
+ moe_capacity_factor,
606
+ moe_expert_model_parallelism,
607
+ )
608
+ if expert_capacity == 0:
609
+ expert_capacity = torch.max(parallel_tokens_per_expert).item()
610
+
611
+ # Locally permute the tokens and perform the expert computation.
612
+ # Block to make sure that the cross-device permutation is complete.
613
+ if mlp_impl == "grouped":
614
+ # GroupedMLP requires counts on CPU. We can use the tensor already
615
+ # moved to CPU for the prior all_to_all, which avoids an extra
616
+ # device synchronization.
617
+ parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
618
+ dim=0,
619
+ dtype=torch.int,
620
+ )
621
+
622
+ # Step 5: Expert computation
623
+ parallel_x_handle.wait()
624
+
625
+ parallel_x = permute_and_compute(
626
+ parallel_x,
627
+ parallel_tokens_per_expert,
628
+ parallel_indices,
629
+ parallel_bin_ids,
630
+ None, # expert_weights
631
+ parallel_bins,
632
+ expert_capacity,
633
+ top_k=1,
634
+ w1=w1,
635
+ w2=w2,
636
+ w1_bias=w1_bias,
637
+ w2_bias=w2_bias,
638
+ gradient_scale=gradient_scale,
639
+ alpha=alpha,
640
+ )
641
+
642
+ # Step 6: Reverse communication - send results back
643
+ x, _ = _layers.all_to_all.all_to_all(
644
+ parallel_x, send_counts, recv_counts, expert_parallel_group
645
+ )
646
+
647
+ # Step 7: Reduce across hidden sharding dimension
648
+ shape = (hidden_sharding_deg, -1, hidden_size)
649
+ x = x.view(shape).sum(dim=0)
650
+
651
+ # Step 8: Final local unpermutation
652
+ x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
653
+
654
+ return x, tokens_per_expert.flatten()
655
+
656
+
657
+ def moe_forward(
658
+ x: torch.Tensor,
659
+ router_weight: torch.Tensor,
660
+ moe_top_k: int,
661
+ moe_num_experts: int,
662
+ moe_jitter_eps: float = None,
663
+ moe_normalize_expert_weights: int = None,
664
+ uniform_expert_assignment: bool = False,
665
+ training: bool = False,
666
+ w1: torch.Tensor = None,
667
+ w2: torch.Tensor = None,
668
+ w1_bias: torch.Tensor = None,
669
+ w2_bias: torch.Tensor = None,
670
+ gradient_scale: Optional[float] = None,
671
+ alpha: float = 1.702,
672
+ sort_end_bit: int = 0,
673
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
674
+ moe_capacity_factor: float = 1.0,
675
+ moe_expert_model_parallelism: bool = False,
676
+ forward_fn: Any = None,
677
+ hidden_size: int = None,
678
+ mlp_impl: str = "grouped",
679
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
680
+
681
+ # Route tokens to experts
682
+ logits, expert_weights, expert_indices = route_tokens(
683
+ x,
684
+ router_weight,
685
+ moe_top_k,
686
+ moe_num_experts,
687
+ moe_jitter_eps,
688
+ moe_normalize_expert_weights,
689
+ uniform_expert_assignment,
690
+ training,
691
+ )
692
+
693
+ # Create router scores for output
694
+ router_scores = (
695
+ torch.zeros_like(logits)
696
+ .scatter_(1, expert_indices, expert_weights)
697
+ .transpose(0, 1)
698
+ )
699
+
700
+ in_shape = x.size()
701
+
702
+ # Prepare forward function arguments
703
+ forward_args = {
704
+ "x": x,
705
+ "expert_weights": expert_weights,
706
+ "top_experts": expert_indices,
707
+ "w1": w1,
708
+ "w2": w2,
709
+ "w1_bias": w1_bias,
710
+ "w2_bias": w2_bias,
711
+ "gradient_scale": gradient_scale,
712
+ "alpha": alpha,
713
+ "sort_end_bit": sort_end_bit,
714
+ "top_k": moe_top_k,
715
+ "num_experts": moe_num_experts,
716
+ "expert_parallel_group": expert_parallel_group,
717
+ "moe_capacity_factor": moe_capacity_factor,
718
+ "moe_expert_model_parallelism": moe_expert_model_parallelism,
719
+ "mlp_impl": mlp_impl,
720
+ }
721
+
722
+ # Add hidden_size for parallel forward
723
+ if moe_expert_model_parallelism and hidden_size is not None:
724
+ forward_args["hidden_size"] = hidden_size
725
+ elif moe_expert_model_parallelism and hidden_size is None:
726
+ # Infer hidden_size from input shape
727
+ forward_args["hidden_size"] = x.shape[-1]
728
+
729
+ # Compute expert outputs
730
+ x, tokens_per_expert = forward_fn(**forward_args)
731
+
732
+ # Save load balancing loss if needed
733
+ moe_loss_weight = 0.0 # Can be made configurable
734
+ if training and moe_loss_weight > 0:
735
+ save_load_balancing_loss((tokens_per_expert, logits))
736
+
737
+ # Restore original shape
738
+ x = x.view(in_shape)
739
+
740
+ return x, expert_weights, router_scores
741
+
742
+
743
+ def moe_forward_with_shared_expert(
744
+ x: torch.Tensor,
745
+ router_weight: torch.Tensor,
746
+ moe_top_k: int,
747
+ moe_num_experts: int,
748
+ moe_jitter_eps: float = None,
749
+ moe_normalize_expert_weights: int = None,
750
+ uniform_expert_assignment: bool = False,
751
+ training: bool = False,
752
+ w1: torch.Tensor = None,
753
+ w2: torch.Tensor = None,
754
+ w1_bias: torch.Tensor = None,
755
+ w2_bias: torch.Tensor = None,
756
+ gradient_scale: Optional[float] = None,
757
+ alpha: float = 1.702,
758
+ sort_end_bit: int = 0,
759
+ expert_parallel_group: torch.distributed.ProcessGroup = None,
760
+ moe_capacity_factor: float = 1.0,
761
+ moe_expert_model_parallelism: bool = False,
762
+ forward_fn: Any = None,
763
+ hidden_size: int = None,
764
+ mlp_impl: str = "grouped",
765
+ # Shared expert parameters
766
+ shared_up_proj_weight: Optional[torch.Tensor] = None,
767
+ shared_down_proj_weight: Optional[torch.Tensor] = None,
768
+ shared_up_proj_bias: Optional[torch.Tensor] = None,
769
+ shared_down_proj_bias: Optional[torch.Tensor] = None,
770
+ shared_expert_weighted_sum: bool = False,
771
+ shared_activation_fn: Optional[Any] = None,
772
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
773
+
774
+ # First, compute regular MoE forward pass
775
+ expert_out, expert_weights, router_scores = moe_forward(
776
+ x=x,
777
+ router_weight=router_weight,
778
+ moe_top_k=moe_top_k,
779
+ moe_num_experts=moe_num_experts,
780
+ moe_jitter_eps=moe_jitter_eps,
781
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
782
+ uniform_expert_assignment=uniform_expert_assignment,
783
+ training=training,
784
+ w1=w1,
785
+ w2=w2,
786
+ w1_bias=w1_bias,
787
+ w2_bias=w2_bias,
788
+ gradient_scale=gradient_scale,
789
+ alpha=alpha,
790
+ sort_end_bit=sort_end_bit,
791
+ expert_parallel_group=expert_parallel_group,
792
+ moe_capacity_factor=moe_capacity_factor,
793
+ moe_expert_model_parallelism=moe_expert_model_parallelism,
794
+ forward_fn=forward_fn,
795
+ hidden_size=hidden_size,
796
+ mlp_impl=mlp_impl,
797
+ )
798
+
799
+ # If shared expert weights provided, compute shared expert output
800
+ if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
801
+ shared_expert_out = shared_mlp_forward(
802
+ x=x,
803
+ up_proj_weight=shared_up_proj_weight,
804
+ down_proj_weight=shared_down_proj_weight,
805
+ up_proj_bias=shared_up_proj_bias,
806
+ down_proj_bias=shared_down_proj_bias,
807
+ activation_fn=shared_activation_fn,
808
+ gradient_scale=gradient_scale,
809
+ )
810
+
811
+ # Combine expert outputs
812
+ combined_out = combine_expert_shared_outputs(
813
+ shared_expert_out=shared_expert_out,
814
+ expert_out=expert_out,
815
+ shared_expert_weighted_sum=shared_expert_weighted_sum,
816
+ moe_top_k=moe_top_k,
817
+ )
818
+
819
+ return combined_out, expert_weights, router_scores
820
+
821
+ # Return regular MoE output if no shared expert
822
+ return expert_out, expert_weights, router_scores
823
+
824
+
825
+ def create_shared_expert_weights(
826
+ hidden_size: int,
827
+ shared_expert_hidden_size: int,
828
+ device: torch.device,
829
+ dtype: torch.dtype,
830
+ init_method: Any,
831
+ output_layer_init_method: Any = None,
832
+ ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
833
+
834
+ if output_layer_init_method is None:
835
+ output_layer_init_method = init_method
836
+
837
+ # Create weight tensors
838
+ up_proj_weight = torch.empty(
839
+ shared_expert_hidden_size,
840
+ hidden_size,
841
+ device=device,
842
+ dtype=dtype,
843
+ )
844
+ down_proj_weight = torch.empty(
845
+ hidden_size,
846
+ shared_expert_hidden_size,
847
+ device=device,
848
+ dtype=dtype,
849
+ )
850
+
851
+ # Initialize weights
852
+ init_method(up_proj_weight)
853
+ output_layer_init_method(down_proj_weight)
854
+
855
+ # No bias by default
856
+ return up_proj_weight, down_proj_weight, None, None
857
+
858
+ # HACK: Extract device_mesh from pre-hook closure - required for transformers integration
859
+ # This exists because device_mesh is trapped in hook closures with no model attribute
860
+ # Fragile - breaks if hook structure changes or Python internals change
861
+ # TODO: Replace with a more robust solution when available
862
+ def get_device_mesh(model):
863
+ # Extract device_mesh from child's unused pre_hook closure
864
+ try:
865
+ # Find the pre-hook that contains 'device_mesh' in its closure
866
+ hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
867
+ # Extract the device_mesh from the closure
868
+ return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
869
+ except Exception:
870
+ return None
871
+
872
+
873
+ class MegaBlocksMoeMLP(torch.nn.Module):
874
+
875
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
876
+ moe_top_k = getattr(self.router, "top_k", 4)
877
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
878
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
879
+ alpha = getattr(self.experts, "alpha", 1.0)
880
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
881
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
882
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
883
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
884
+
885
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
886
+ if expert_parallel_group is None:
887
+ device_mesh = get_device_mesh(self)
888
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
889
+
890
+ has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
891
+ forward_fn = parallel_forward_once if has_parallel else forward_once
892
+
893
+ sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
894
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
895
+
896
+ output, expert_weights_out, *_ = moe_forward(
897
+ x=x,
898
+ router_weight=self.router.weight,
899
+ moe_top_k=moe_top_k,
900
+ moe_num_experts=moe_num_experts,
901
+ moe_jitter_eps=moe_jitter_eps,
902
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
903
+ uniform_expert_assignment=uniform_expert_assignment,
904
+ training=self.training,
905
+ w1=self.experts.gate_up_proj,
906
+ w2=self.experts.down_proj,
907
+ w1_bias=self.experts.gate_up_proj_bias,
908
+ w2_bias=self.experts.down_proj_bias,
909
+ gradient_scale=gradient_scale,
910
+ alpha=alpha,
911
+ sort_end_bit=sort_end_bit,
912
+ expert_parallel_group=expert_parallel_group,
913
+ moe_capacity_factor=moe_capacity_factor,
914
+ moe_expert_model_parallelism=has_parallel,
915
+ forward_fn=forward_fn,
916
+ hidden_size=self.experts.hidden_size,
917
+ mlp_impl=mlp_impl,
918
+ )
919
+ return output, expert_weights_out
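A hedged wiring sketch for `MegaBlocksMoeMLP`: `forward` reads the attributes below off `self.router` and `self.experts`. The toy sizes and parameter shapes here are assumptions inferred from `mlp_forward` (`gate_up_proj` as `[num_experts, hidden, 2 * ffn]`, `down_proj` as `[num_experts, ffn, hidden]`); actually calling `forward` additionally requires the compiled kernels and a CUDA device.

```python
import torch
from megablocks.layers import MegaBlocksMoeMLP

hidden_size, ffn_hidden_size, num_experts, top_k = 128, 256, 8, 2

mlp = MegaBlocksMoeMLP()
mlp.router = torch.nn.Linear(hidden_size, num_experts, bias=False)
mlp.router.top_k = top_k

experts = torch.nn.Module()
experts.num_experts = num_experts
experts.hidden_size = hidden_size
experts.gate_up_proj = torch.nn.Parameter(0.02 * torch.randn(num_experts, hidden_size, 2 * ffn_hidden_size))
experts.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(num_experts, 2 * ffn_hidden_size))
experts.down_proj = torch.nn.Parameter(0.02 * torch.randn(num_experts, ffn_hidden_size, hidden_size))
experts.down_proj_bias = torch.nn.Parameter(torch.zeros(num_experts, hidden_size))
mlp.experts = experts
```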
920
+
921
+
922
+ class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
923
+
924
+ def __init__(self):
925
+ super().__init__()
926
+ # Shared expert weights will be set by the user
927
+ self.shared_up_proj_weight = None
928
+ self.shared_down_proj_weight = None
929
+ self.shared_up_proj_bias = None
930
+ self.shared_down_proj_bias = None
931
+ self.shared_expert_weighted_sum = False
932
+ self.shared_activation_fn = None
933
+
934
+ def set_shared_expert_weights(
935
+ self,
936
+ up_proj_weight: torch.Tensor,
937
+ down_proj_weight: torch.Tensor,
938
+ up_proj_bias: Optional[torch.Tensor] = None,
939
+ down_proj_bias: Optional[torch.Tensor] = None,
940
+ weighted_sum: bool = False,
941
+ activation_fn: Optional[Any] = None,
942
+ ):
943
+ self.shared_up_proj_weight = up_proj_weight
944
+ self.shared_down_proj_weight = down_proj_weight
945
+ self.shared_up_proj_bias = up_proj_bias
946
+ self.shared_down_proj_bias = down_proj_bias
947
+ self.shared_expert_weighted_sum = weighted_sum
948
+ self.shared_activation_fn = activation_fn
949
+
950
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
951
+ moe_top_k = getattr(self.router, "top_k", 4)
952
+ moe_num_experts = getattr(self.experts, "num_experts", 128)
953
+ gradient_scale = getattr(self.experts, "gradient_scale", None)
954
+ alpha = getattr(self.experts, "alpha", 1.0)
955
+ moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
956
+ moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
957
+ moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
958
+ uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
959
+
960
+ expert_parallel_group = getattr(self, "expert_parallel_group", None)
961
+ if expert_parallel_group is None:
962
+ device_mesh = get_device_mesh(self)
963
+ expert_parallel_group = device_mesh.get_group() if device_mesh else None
964
+
965
+ has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
966
+ forward_fn = parallel_forward_once if has_parallel else forward_once
967
+
968
+ sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
969
+ mlp_impl = getattr(self, "mlp_impl", "grouped")
970
+
971
+ output, expert_weights_out, *_ = moe_forward_with_shared_expert(
972
+ x=x,
973
+ router_weight=self.router.weight,
974
+ moe_top_k=moe_top_k,
975
+ moe_num_experts=moe_num_experts,
976
+ moe_jitter_eps=moe_jitter_eps,
977
+ moe_normalize_expert_weights=moe_normalize_expert_weights,
978
+ uniform_expert_assignment=uniform_expert_assignment,
979
+ training=self.training,
980
+ w1=self.experts.gate_up_proj,
981
+ w2=self.experts.down_proj,
982
+ w1_bias=self.experts.gate_up_proj_bias,
983
+ w2_bias=self.experts.down_proj_bias,
984
+ gradient_scale=gradient_scale,
985
+ alpha=alpha,
986
+ sort_end_bit=sort_end_bit,
987
+ expert_parallel_group=expert_parallel_group,
988
+ moe_capacity_factor=moe_capacity_factor,
989
+ moe_expert_model_parallelism=has_parallel,
990
+ forward_fn=forward_fn,
991
+ hidden_size=self.experts.hidden_size,
992
+ mlp_impl=mlp_impl,
993
+ # Shared expert parameters
994
+ shared_up_proj_weight=self.shared_up_proj_weight,
995
+ shared_down_proj_weight=self.shared_down_proj_weight,
996
+ shared_up_proj_bias=self.shared_up_proj_bias,
997
+ shared_down_proj_bias=self.shared_down_proj_bias,
998
+ shared_expert_weighted_sum=self.shared_expert_weighted_sum,
999
+ shared_activation_fn=self.shared_activation_fn,
1000
+ )
1001
+ return output, expert_weights_out
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from .binned_gather import binned_gather
5
+ from .binned_scatter import binned_scatter
6
+ from .cumsum import exclusive_cumsum, inclusive_cumsum
7
+ from .gather import gather
8
+ from .histogram import histogram
9
+ from .padded_gather import padded_gather
10
+ from .padded_scatter import padded_scatter
11
+ from .repeat import repeat
12
+ from .replicate import replicate
13
+ from .round_up import round_up
14
+ from .scatter import scatter
15
+ from .sort import sort
16
+ from .sum import sum
17
+ from .topology import topology
18
+
19
+ __all__ = [
20
+ 'binned_gather',
21
+ 'binned_scatter',
22
+ 'exclusive_cumsum',
23
+ 'inclusive_cumsum',
24
+ 'gather',
25
+ 'histogram',
26
+ 'padded_gather',
27
+ 'padded_scatter',
28
+ 'repeat',
29
+ 'replicate',
30
+ 'round_up',
31
+ 'scatter',
32
+ 'sort',
33
+ 'sum',
34
+ 'topology',
35
+ ]
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/all_to_all_benchmark.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+ # from megablocks import benchmark_util
8
+ # from megablocks.layers.all_to_all import all_to_all
9
+
10
+ from .. import benchmark_util
11
+ from .._layers.all_to_all import all_to_all
12
+
13
+ _ALL_TO_ALL_BENCHMARK = (
14
+ (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (256, 1024),
20
+ (512, 1024),
21
+ (1024, 1024),
22
+ (2 * 1024, 1024),
23
+ (4 * 1024, 1024),
24
+ (8 * 1024, 1024),
25
+ (16 * 1024, 1024),
26
+ (32 * 1024, 1024),
27
+ (64 * 1024, 1024),
28
+ (128 * 1024, 1024),
29
+ (256 * 1024, 1024),
30
+ (512 * 1024, 1024),
31
+ (1024 * 1024, 1024),
32
+ )
33
+
34
+
35
+ def benchmark_all_to_all(group, sl, hs):
36
+ world_size = dist.get_world_size(group)
37
+ assert (sl % world_size) == 0
38
+ send_recv_sizes = [sl // world_size] * world_size
39
+
40
+ x = torch.randn((sl, hs)).cuda().half()
41
+
42
+ details = {
43
+ 'world_size': world_size,
44
+ 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements.
45
+ }
46
+
47
+ def benchmark():
48
+ return all_to_all(x, send_recv_sizes, send_recv_sizes, group)
49
+
50
+ time, std = benchmark_util.benchmark_function(benchmark)
51
+
52
+ if dist.get_rank(group) == 0:
53
+ benchmark_util.log_benchmark('All-To-All', details, time, std)
54
+
55
+
56
+ if __name__ == '__main__':
57
+ assert dist.is_available()
58
+ group = dist.init_process_group(backend='nccl')
59
+ local_rank = dist.get_rank(group)
60
+ torch.cuda.set_device(local_rank)
61
+
62
+ for args in _ALL_TO_ALL_BENCHMARK:
63
+ benchmark_all_to_all(group, *args)
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_gather.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from .stk_autocast import custom_bwd, custom_fwd
7
+
8
+ from ..backend import kernels
9
+
10
+
11
+ # Autograd wrapper for binned_gather kernel.
12
+ class BinnedGatherOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bins: torch.Tensor,
21
+ bin_size: int,
22
+ top_k: int,
23
+ ):
24
+ ctx.save_for_backward(indices, bins)
25
+ ctx.top_k = top_k
26
+ return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)
27
+
28
+ @staticmethod
29
+ @custom_bwd
30
+ def backward(ctx: Any, grad: torch.Tensor):
31
+ grad = grad.contiguous()
32
+ indices, bins = ctx.saved_tensors
33
+ out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
34
+ return out, None, None, None, None
35
+
36
+
37
+ binned_gather = BinnedGatherOp.apply
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_scatter.py ADDED
@@ -0,0 +1,59 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from .stk_autocast import custom_bwd, custom_fwd
7
+
8
+ from ..backend import kernels
9
+
10
+
11
+ # Autograd wrapper for binned_scatter kernel.
12
+ class BinnedScatterOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ weights: torch.Tensor,
21
+ bins: torch.Tensor,
22
+ top_k: int,
23
+ ):
24
+ assert len(x.size()) == 3
25
+ ctx.bin_size = x.size(1)
26
+ ctx.top_k = top_k
27
+
28
+ # TODO(tgale): Don't save 'x' for backwards if we don't need to
29
+ # calculate the gradient w.r.t. 'weights'.
30
+ ctx.save_for_backward(x, indices, weights, bins)
31
+ return kernels.binned_scatter(x, indices, weights, bins, top_k)
32
+
33
+ @staticmethod
34
+ @custom_bwd
35
+ def backward(ctx: Any, grad: torch.Tensor):
36
+ grad = grad.contiguous()
37
+ x, indices, weights, bins = ctx.saved_tensors
38
+ out = kernels.binned_gather(
39
+ grad,
40
+ indices,
41
+ weights,
42
+ bins,
43
+ ctx.bin_size,
44
+ ctx.top_k,
45
+ )
46
+
47
+ wgrad = None
48
+ if ctx.needs_input_grad[2]:
49
+ wgrad = kernels.binned_scatter_wgrad(
50
+ x,
51
+ grad,
52
+ indices,
53
+ bins,
54
+ ctx.top_k,
55
+ )
56
+ return out, None, wgrad, None, None
57
+
58
+
59
+ binned_scatter = BinnedScatterOp.apply
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/cumsum.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ # import megablocks_ops as ops # type: ignore
14
+ from .._ops import ops # type: ignore
15
+ except ModuleNotFoundError as e:
16
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
17
+
18
+
19
+ # Autograd wrappers for cumsum kernels.
20
+ # NOTE: Does not support gradients.
21
+ class ExclusiveCumsumOp(torch.autograd.Function):
22
+
23
+ @staticmethod
24
+ def forward(ctx: Any, x: torch.Tensor, dim: int):
25
+ if len(x.size()) == 1:
26
+ x = x.view([1, -1])
27
+ out = torch.empty_like(x)
28
+ ops.exclusive_cumsum(x, 1, out)
29
+ return out.squeeze()
30
+ out = torch.empty_like(x)
31
+ ops.exclusive_cumsum(x, dim, out)
32
+ return out
33
+
34
+
35
+ exclusive_cumsum = ExclusiveCumsumOp.apply
36
+
37
+
38
+ class InclusiveCumsumOp(torch.autograd.Function):
39
+
40
+ @staticmethod
41
+ def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
42
+ if len(x.size()) == 1:
43
+ x = x.view([1, -1])
44
+ out = torch.empty_like(x)
45
+ ops.inclusive_cumsum(x, 1, out)
46
+ return out.squeeze()
47
+ out = torch.empty_like(x)
48
+ ops.inclusive_cumsum(x, dim, out)
49
+ return out
50
+
51
+
52
+ inclusive_cumsum = InclusiveCumsumOp.apply
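A hedged usage sketch (needs the compiled megablocks extension and a CUDA device); for a 1-D integer input these match `torch.cumsum`, with the exclusive variant shifted by one.

```python
import torch
from megablocks.ops import exclusive_cumsum, inclusive_cumsum

if torch.cuda.is_available():
    counts = torch.tensor([2, 1, 3, 0], dtype=torch.int32, device='cuda')
    print(inclusive_cumsum(counts, 0))  # expected: [2, 3, 6, 6]
    print(exclusive_cumsum(counts, 0))  # expected: [0, 2, 3, 6]
```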
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/gather.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from .stk_autocast import custom_bwd, custom_fwd
7
+
8
+ from ..backend import kernels
9
+
10
+
11
+ # Autograd wrapper for gather kernel.
12
+ class GatherOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bin_ids: torch.Tensor,
21
+ bins: torch.Tensor,
22
+ top_k: int,
23
+ ):
24
+ ctx.save_for_backward(indices, bin_ids, bins)
25
+ ctx.top_k = top_k
26
+ return kernels.gather(x, indices, bin_ids, None, bins, top_k)
27
+
28
+ @staticmethod
29
+ @custom_bwd
30
+ def backward(ctx: Any, grad: torch.Tensor):
31
+ grad = grad.contiguous()
32
+
33
+ indices, bin_ids, bins = ctx.saved_tensors
34
+ out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
35
+ return out, None, None, None, None, None
36
+
37
+
38
+ gather = GatherOp.apply
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ from .._ops import ops # type: ignore
14
+ except ModuleNotFoundError as e:
15
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
+
17
+
18
+ # Autograd wrapper for histogram kernel.
19
+ # NOTE: Does not support gradients.
20
+ class HistogramOp(torch.autograd.Function):
21
+
22
+ @staticmethod
23
+ def forward(ctx: Any, x: torch.Tensor, max_val: float):
24
+ return ops.histogram(x, max_val)
25
+
26
+
27
+ histogram = HistogramOp.apply
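A hedged usage sketch (needs the compiled extension and a CUDA device); for integer inputs in `[0, num_bins)` the result should agree with `torch.histc` over the same range, which is also how the benchmark file added next compares them.

```python
import torch
from megablocks.ops import histogram

if torch.cuda.is_available():
    num_experts = 8
    top_expert = torch.randint(0, num_experts, (1024,), device='cuda').int()
    counts = histogram(top_expert, num_experts)
    reference = torch.histc(top_expert.float(), num_experts, 0, num_experts - 1)
    print(counts, reference)  # the two should agree
```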
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram_benchmark.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import numpy as np
7
+ import torch
8
+ from absl.testing import parameterized
9
+
10
+ from .. import ops
11
+
12
+ _HISTOGRAM_TESTS = (
13
+ (16384, torch.int32, 2),
14
+ (16384, torch.int32, 4),
15
+ (16384, torch.int32, 8),
16
+ (16384, torch.int32, 16),
17
+ (16384, torch.int32, 32),
18
+ (16384, torch.int32, 64),
19
+ (16384, torch.int32, 128),
20
+ (16384, torch.int32, 256),
21
+ )
22
+
23
+
24
+ def benchmark_function(fn, iterations=10):
25
+ # Run once to get rid of startup overhead.
26
+ fn()
27
+ times = []
28
+ for _ in range(iterations):
29
+ start = torch.cuda.Event(enable_timing=True)
30
+ end = torch.cuda.Event(enable_timing=True)
31
+ start.record()
32
+ fn()
33
+ end.record()
34
+ torch.cuda.synchronize()
35
+ times.append(start.elapsed_time(end))
36
+ times = np.array(times)
37
+ return times.mean(), times.std(), times.max(), times.min()
38
+
39
+
40
+ def log_benchmark(arguments, mean_t, std_t):
41
+ print('=' * 60)
42
+ print('Benchmark Parameters:')
43
+ for (key, value) in arguments.items():
44
+ print(f'{key} = {value}')
45
+ print('Results:')
46
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
47
+ print('=' * 60)
48
+
49
+
50
+ class HistogramBenchmark(parameterized.TestCase):
51
+
52
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
53
+ def testHistogram(self, n, dtype, max_val):
54
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
55
+
56
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
57
+ arguments = {
58
+ 'n': n,
59
+ 'dtype': dtype,
60
+ 'max_val': max_val,
61
+ }
62
+ log_benchmark(arguments, mean_t, std_t)
63
+
64
+ @parameterized.parameters(*_HISTOGRAM_TESTS)
65
+ def testTorchHistogram(self, n, dtype, max_val):
66
+ x = torch.randint(0, 128, (n,)).cuda().to(dtype)
67
+
68
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
69
+ arguments = {
70
+ 'n': n,
71
+ 'dtype': dtype,
72
+ 'max_val': max_val,
73
+ }
74
+ log_benchmark(arguments, mean_t, std_t)
75
+
76
+
77
+ if __name__ == '__main__':
78
+ unittest.main()
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/matmul_benchmark.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+
7
+ # import stk
8
+
9
+ # try:
10
+ # import stk
11
+ # except ImportError:
12
+ # import warnings
13
+ # warnings.warn(
14
+ # 'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
15
+ # )
16
+
17
+ from .. import stk
18
+
19
+ import torch
20
+ from absl.testing import parameterized
21
+
22
+ from .. import benchmark_util, ops
23
+
24
+
25
+ # Calling tensor.t() calls tensor.transpose(0, 1) which calls
26
+ # torch.as_strided(...). Circumvent this chain to avoid an overhead
27
+ # this adds.
28
+ def transpose_view(x):
29
+ return torch.as_strided(
30
+ x,
31
+ (x.shape[1], x.shape[0]),
32
+ (x.stride()[1], x.stride()[0]),
33
+ )
34
+
35
+
36
+ _MATMUL_TESTS = (
37
+ (64 * 1024, 512, 2048, 64),
38
+ (32 * 1024, 768, 3072, 64),
39
+ (8 * 1024, 1024, 4096, 64),
40
+ (4 * 2048, 4096, 4 * 4096, 4),
41
+ )
42
+
43
+
44
+ def log_benchmark(name, arguments, time, std, flops):
45
+ benchmark_util.log_benchmark(name, arguments, time, std)
46
+ print('flops = {:.2f}B'.format(flops / 1e9))
47
+ print('throughput = {:.2f}T'.format(flops / 1e9 / time))
48
+ print('=' * 60)
49
+
50
+
51
+ class MatmulBenchmark(parameterized.TestCase):
52
+
53
+ def build_sparse_matrix(self, x, padded_bins, fhs, ne):
54
+ blocking = 128
55
+ padded_tokens, _ = x.size()
56
+ assert padded_tokens % blocking == 0
57
+ assert fhs % blocking == 0
58
+
59
+ # Offsets for the sparse matrix. All rows have the
60
+ # same number of nonzero blocks dictated by the
61
+ # dimensionality of a single expert.
62
+ block_rows = padded_tokens // blocking
63
+ blocks_per_row = fhs // blocking
64
+ offsets = torch.arange(
65
+ 0,
66
+ block_rows * blocks_per_row + 1,
67
+ blocks_per_row,
68
+ dtype=torch.int32,
69
+ device=x.device,
70
+ )
71
+
72
+ # Indices for the sparse matrix. The indices for
73
+ # the intermediate matrix are dynamic depending
74
+ # on the mapping of tokens to experts.
75
+ column_indices = ops.topology(
76
+ padded_bins,
77
+ blocking,
78
+ block_rows,
79
+ blocks_per_row,
80
+ )
81
+ data = torch.empty(
82
+ column_indices.numel(),
83
+ blocking,
84
+ blocking,
85
+ dtype=torch.float16,
86
+ device=x.device,
87
+ )
88
+ shape = (padded_tokens, fhs * ne)
89
+ row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
90
+ return stk.Matrix(shape, data, row_indices, column_indices, offsets)
91
+
92
+ def build_input_matrix(self, sl, hs, ne):
93
+ x = torch.randn((sl, hs)).cuda().half()
94
+
95
+ # Assign tokens to experts uniformly.
96
+ top_expert = torch.arange(0, sl).cuda().int() % ne
97
+
98
+ bin_ids, indices = ops.sort(top_expert)
99
+ tokens_per_expert = ops.histogram(top_expert, ne)
100
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
101
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
102
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
103
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
104
+ return out, padded_bins
105
+
106
+ def build_weight_matrix(self, ne, hs, fhs):
107
+ return torch.randn((hs, ne * fhs)).cuda().half()
108
+
109
+ @parameterized.parameters(*_MATMUL_TESTS)
110
+ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
111
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
112
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
113
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
114
+ w = transpose_view(w)
115
+
116
+ def benchmark():
117
+ return stk.ops.sdd(x, w, topo)
118
+
119
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ arguments = {
121
+ 'sequence_length': sl,
122
+ 'hidden_size': hs,
123
+ 'ffn_hidden_size': fhs,
124
+ 'num_experts': ne,
125
+ }
126
+ log_benchmark(
127
+ '0::Fwd::SDD::NT',
128
+ arguments,
129
+ mean_t,
130
+ std_t,
131
+ x.numel() * fhs * 2,
132
+ )
133
+
134
+ @parameterized.parameters(*_MATMUL_TESTS)
135
+ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
136
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
137
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
138
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
139
+
140
+ def benchmark():
141
+ return stk.ops.dsd(topo, w)
142
+
143
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
144
+ arguments = {
145
+ 'sequence_length': sl,
146
+ 'hidden_size': hs,
147
+ 'ffn_hidden_size': fhs,
148
+ 'num_experts': ne,
149
+ }
150
+ log_benchmark(
151
+ '0::GradX::DSD::NN',
152
+ arguments,
153
+ mean_t,
154
+ std_t,
155
+ x.numel() * fhs * 2,
156
+ )
157
+
158
+ @parameterized.parameters(*_MATMUL_TESTS)
159
+ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
160
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
161
+ topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
162
+ topo = topo.t()
163
+
164
+ def benchmark():
165
+ return stk.ops.dsd(topo, x)
166
+
167
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
168
+ arguments = {
169
+ 'sequence_length': sl,
170
+ 'hidden_size': hs,
171
+ 'ffn_hidden_size': fhs,
172
+ 'num_experts': ne,
173
+ }
174
+ log_benchmark(
175
+ '0::GradW::DSD::TN',
176
+ arguments,
177
+ mean_t,
178
+ std_t,
179
+ x.numel() * fhs * 2,
180
+ )
181
+
182
+ @parameterized.parameters(*_MATMUL_TESTS)
183
+ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
184
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
185
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
186
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
187
+
188
+ def benchmark():
189
+ return stk.ops.dsd(x, w)
190
+
191
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
192
+ arguments = {
193
+ 'sequence_length': sl,
194
+ 'hidden_size': hs,
195
+ 'ffn_hidden_size': fhs,
196
+ 'num_experts': ne,
197
+ }
198
+ log_benchmark(
199
+ '1::Fwd::DSD::NN',
200
+ arguments,
201
+ mean_t,
202
+ std_t,
203
+ x.nnz * hs * 2,
204
+ )
205
+
206
+ @parameterized.parameters(*_MATMUL_TESTS)
207
+ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
208
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
209
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
210
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
211
+ out = stk.ops.dsd(x, w)
212
+ w = transpose_view(w)
213
+
214
+ def benchmark():
215
+ return stk.ops.sdd(out, w, x)
216
+
217
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
218
+ arguments = {
219
+ 'sequence_length': sl,
220
+ 'hidden_size': hs,
221
+ 'ffn_hidden_size': fhs,
222
+ 'num_experts': ne,
223
+ }
224
+ log_benchmark(
225
+ '1::GradX::SDD::NT',
226
+ arguments,
227
+ mean_t,
228
+ std_t,
229
+ x.nnz * hs * 2,
230
+ )
231
+
232
+ @parameterized.parameters(*_MATMUL_TESTS)
233
+ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
234
+ x, padded_bins = self.build_input_matrix(sl, hs, ne)
235
+ w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
236
+ x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
237
+ out = stk.ops.dsd(x, w)
238
+ x = x.t()
239
+
240
+ def benchmark():
241
+ return stk.ops.dsd(x, out)
242
+
243
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
244
+ arguments = {
245
+ 'sequence_length': sl,
246
+ 'hidden_size': hs,
247
+ 'ffn_hidden_size': fhs,
248
+ 'num_experts': ne,
249
+ }
250
+ log_benchmark(
251
+ '1::GradW::DSD::TN',
252
+ arguments,
253
+ mean_t,
254
+ std_t,
255
+ x.nnz * hs * 2,
256
+ )
257
+
258
+ @parameterized.parameters(*_MATMUL_TESTS)
259
+ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
260
+ assert (sl % ne) == 0
261
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
262
+ w = torch.randn((ne, hs, fhs)).cuda().half()
263
+
264
+ w = w.transpose(1, 2).contiguous()
265
+ w = w.transpose(1, 2)
266
+
267
+ def benchmark():
268
+ return torch.bmm(x, w)
269
+
270
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
271
+ arguments = {
272
+ 'sequence_length': sl,
273
+ 'hidden_size': hs,
274
+ 'ffn_hidden_size': fhs,
275
+ 'num_experts': ne,
276
+ }
277
+ log_benchmark(
278
+ '0::Fwd::DDD::NT',
279
+ arguments,
280
+ mean_t,
281
+ std_t,
282
+ x.numel() * fhs * 2,
283
+ )
284
+
285
+ @parameterized.parameters(*_MATMUL_TESTS)
286
+ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
287
+ assert (sl % ne) == 0
288
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
289
+ w = torch.randn((ne, hs, fhs)).cuda().half()
290
+ out = torch.bmm(x, w)
291
+ w = w.transpose(1, 2).contiguous()
292
+
293
+ def benchmark():
294
+ return torch.bmm(out, w)
295
+
296
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
297
+ arguments = {
298
+ 'sequence_length': sl,
299
+ 'hidden_size': hs,
300
+ 'ffn_hidden_size': fhs,
301
+ 'num_experts': ne,
302
+ }
303
+ log_benchmark(
304
+ '0::GradX::DDD::NN',
305
+ arguments,
306
+ mean_t,
307
+ std_t,
308
+ x.numel() * fhs * 2,
309
+ )
310
+
311
+ @parameterized.parameters(*_MATMUL_TESTS)
312
+ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
313
+ assert (sl % ne) == 0
314
+ x = torch.randn((ne, sl // ne, hs)).cuda().half()
315
+ w = torch.randn((ne, hs, fhs)).cuda().half()
316
+ out = torch.bmm(x, w)
317
+ out = out.transpose(1, 2)
318
+
319
+ def benchmark():
320
+ return torch.bmm(out, x)
321
+
322
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
323
+ arguments = {
324
+ 'sequence_length': sl,
325
+ 'hidden_size': hs,
326
+ 'ffn_hidden_size': fhs,
327
+ 'num_experts': ne,
328
+ }
329
+ log_benchmark(
330
+ '0::GradW::DDD::TN',
331
+ arguments,
332
+ mean_t,
333
+ std_t,
334
+ x.numel() * fhs * 2,
335
+ )
336
+
337
+ @parameterized.parameters(*_MATMUL_TESTS)
338
+ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
339
+ assert (sl % ne) == 0
340
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
341
+ w = torch.randn((ne, fhs, hs)).cuda().half()
342
+
343
+ def benchmark():
344
+ return torch.bmm(x, w)
345
+
346
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
347
+ arguments = {
348
+ 'sequence_length': sl,
349
+ 'hidden_size': hs,
350
+ 'ffn_hidden_size': fhs,
351
+ 'num_experts': ne,
352
+ }
353
+ log_benchmark(
354
+ '1::Fwd::DDD::NN',
355
+ arguments,
356
+ mean_t,
357
+ std_t,
358
+ x.numel() * hs * 2,
359
+ )
360
+
361
+ @parameterized.parameters(*_MATMUL_TESTS)
362
+ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
363
+ assert (sl % ne) == 0
364
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
365
+ w = torch.randn((ne, fhs, hs)).cuda().half()
366
+ out = torch.bmm(x, w)
367
+ w = torch.transpose(w, 1, 2)
368
+
369
+ def benchmark():
370
+ return torch.bmm(out, w)
371
+
372
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
373
+ arguments = {
374
+ 'sequence_length': sl,
375
+ 'hidden_size': hs,
376
+ 'ffn_hidden_size': fhs,
377
+ 'num_experts': ne,
378
+ }
379
+ log_benchmark(
380
+ '1::GradX::DDD::NT',
381
+ arguments,
382
+ mean_t,
383
+ std_t,
384
+ x.numel() * hs * 2,
385
+ )
386
+
387
+ @parameterized.parameters(*_MATMUL_TESTS)
388
+ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
389
+ assert (sl % ne) == 0
390
+ x = torch.randn((ne, sl // ne, fhs)).cuda().half()
391
+ w = torch.randn((ne, fhs, hs)).cuda().half()
392
+ out = torch.bmm(x, w)
393
+ x = torch.transpose(x, 1, 2)
394
+
395
+ def benchmark():
396
+ return torch.bmm(x, out)
397
+
398
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
399
+ arguments = {
400
+ 'sequence_length': sl,
401
+ 'hidden_size': hs,
402
+ 'ffn_hidden_size': fhs,
403
+ 'num_experts': ne,
404
+ }
405
+ log_benchmark(
406
+ '1::GradW::DDD::TN',
407
+ arguments,
408
+ mean_t,
409
+ std_t,
410
+ x.numel() * hs * 2,
411
+ )
412
+
413
+
414
+ if __name__ == '__main__':
415
+ unittest.main()
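
A note on the numbers these benchmarks report: every test passes `x.numel() * fhs * 2` (or `x.nnz * hs * 2` for the sparse operand) as the FLOP count, i.e. one multiply and one add per input element per output feature, and `log_benchmark` divides GFLOPs by the mean time in milliseconds, which is numerically a TFLOP/s figure. A small sketch of that accounting (the helper names are illustrative, not megablocks APIs):

    # Illustrative FLOP accounting for the first FFN linear layer benchmarked above.
    def linear0_flops(tokens: int, hs: int, fhs: int) -> int:
        # One multiply and one add per (token, hidden, ffn_hidden) triple.
        return 2 * tokens * hs * fhs

    def tflops_per_sec(flops: int, mean_time_ms: float) -> float:
        # GFLOPs per millisecond == TFLOP/s, matching the 'throughput = ...T' line.
        return flops / 1e9 / mean_time_ms

    # Example: the (32 * 1024, 768, 3072, 64) configuration at 1.0 ms -> ~154.6 TFLOP/s.
    print(tflops_per_sec(linear0_flops(32 * 1024, 768, 3072), 1.0))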
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_gather.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from .stk_autocast import custom_bwd, custom_fwd
7
+
8
+ from ..backend import kernels
9
+
10
+
11
+ # Autograd wrapper for padded_gather kernel.
12
+ class PaddedGatherOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bin_ids: torch.Tensor,
21
+ bins: torch.Tensor,
22
+ padded_bins: torch.Tensor,
23
+ top_k: int,
24
+ ):
25
+ ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
26
+ ctx.top_k = top_k
27
+ return kernels.padded_gather(
28
+ x,
29
+ indices,
30
+ bin_ids,
31
+ None,
32
+ bins,
33
+ padded_bins,
34
+ top_k,
35
+ )
36
+
37
+ @staticmethod
38
+ @custom_bwd
39
+ def backward(ctx: Any, grad: torch.Tensor):
40
+ grad = grad.contiguous()
41
+
42
+ indices, bin_ids, bins, padded_bins = ctx.saved_tensors
43
+ out = kernels.padded_scatter(
44
+ grad,
45
+ indices,
46
+ bin_ids,
47
+ None,
48
+ bins,
49
+ padded_bins,
50
+ ctx.top_k,
51
+ )
52
+ return out, None, None, None, None, None
53
+
54
+
55
+ padded_gather = PaddedGatherOp.apply
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from typing import Any
4
+
5
+ import torch
6
+ from .stk_autocast import custom_bwd, custom_fwd
7
+
8
+ from ..backend import kernels
9
+
10
+
11
+ # Autograd wrapper for padded_scatter kernel.
12
+ class PaddedScatterOp(torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ @custom_fwd
16
+ def forward(
17
+ ctx: Any,
18
+ x: torch.Tensor,
19
+ indices: torch.Tensor,
20
+ bin_ids: torch.Tensor,
21
+ weights: torch.Tensor,
22
+ bins: torch.Tensor,
23
+ padded_bins: torch.Tensor,
24
+ top_k: int,
25
+ ):
26
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
27
+ ctx.save_for_backward(
28
+ indices,
29
+ bin_ids,
30
+ weights,
31
+ bins,
32
+ padded_bins,
33
+ *maybe_x,
34
+ )
35
+ ctx.top_k = top_k
36
+ ctx.x_shape = x.shape
37
+ return kernels.padded_scatter(
38
+ x,
39
+ indices,
40
+ bin_ids,
41
+ weights,
42
+ bins,
43
+ padded_bins,
44
+ top_k,
45
+ )
46
+
47
+ @staticmethod
48
+ @custom_bwd
49
+ def backward(ctx: Any, grad: torch.Tensor):
50
+ grad = grad.contiguous()
51
+ saved_tensors = ctx.saved_tensors
52
+
53
+ indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
54
+ dgrad = None
55
+ if ctx.needs_input_grad[0]:
56
+ dgrad = kernels.padded_gather(
57
+ grad,
58
+ indices,
59
+ bin_ids,
60
+ weights,
61
+ bins,
62
+ padded_bins,
63
+ ctx.top_k,
64
+ )
65
+
66
+ wgrad = None
67
+ if ctx.needs_input_grad[3]: # need wgrad
68
+ x = saved_tensors[-1]
69
+ wgrad = kernels.padded_scatter_wgrad(
70
+ x,
71
+ grad,
72
+ indices,
73
+ bin_ids,
74
+ bins,
75
+ padded_bins,
76
+ ctx.top_k,
77
+ )
78
+ return dgrad, None, None, wgrad, None, None, None, None
79
+
80
+
81
+ def padded_scatter(
82
+ x: torch.Tensor,
83
+ indices: torch.Tensor,
84
+ bin_ids: torch.Tensor,
85
+ weights: torch.Tensor,
86
+ bins: torch.Tensor,
87
+ padded_bins: torch.Tensor,
88
+ top_k: int,
89
+ ):
90
+ return PaddedScatterOp.apply(
91
+ x,
92
+ indices,
93
+ bin_ids,
94
+ weights,
95
+ bins,
96
+ padded_bins,
97
+ top_k,
98
+ )
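
To make the gather/scatter pairing concrete, here is a small round-trip sketch (not part of the library; it assumes a CUDA device and the compiled megablocks extension, and mirrors the metadata construction used by the benchmarks in this directory). With `top_k == 1` and unit weights, `padded_scatter` undoes `padded_gather`:

    import torch
    from megablocks import ops

    sl, hs, ne = 512, 128, 4
    x = torch.randn((sl, hs)).cuda().half()

    # Route each token to a random expert and build the permutation metadata.
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    # Permute into expert-contiguous, 128-padded bins, then scatter back.
    gathered = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
    weights = torch.ones((sl,)).cuda().half()
    restored = ops.padded_scatter(gathered, indices, bin_ids, weights, bins, padded_bins, 1)
    assert torch.allclose(restored, x)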
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import torch
7
+ from absl.testing import parameterized
8
+
9
+ from .. import benchmark_util, ops
10
+
11
+ _PADDED_SCATTER_BENCHMARK = (
12
+ # dMoE-Medium, 8-way EMP.
13
+ (1024 * 16, 1024, 8, 4),
14
+ # dMoE-Medium, post-all-to-all.
15
+ (1024 * 16 * 4, 1024, 8, 1),
16
+ )
17
+
18
+
19
+ class PaddedScatterTest(parameterized.TestCase):
20
+
21
+ @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
22
+ def testPaddedScatter(self, sl, hs, ne, top_k):
23
+ # Create the data and indices.
24
+ x = torch.randn((sl, hs)).cuda().half()
25
+
26
+ # Randomly assign tokens to experts.
27
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
28
+ bin_ids, indices = ops.sort(top_expert)
29
+ tokens_per_expert = ops.histogram(top_expert, ne)
30
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
31
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
32
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
33
+
34
+ # Sample weights for the scatter reduce.
35
+ weights = torch.rand((sl * top_k,)).cuda().half()
36
+
37
+ # Gather the data to prepare for backwards.
38
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
39
+
40
+ def benchmark():
41
+ return ops.padded_scatter(
42
+ x,
43
+ indices,
44
+ bin_ids,
45
+ weights,
46
+ bins,
47
+ padded_bins,
48
+ top_k,
49
+ )
50
+
51
+ time, std = benchmark_util.benchmark_function(benchmark)
52
+ benchmark_util.log_benchmark(
53
+ 'Padded Scatter',
54
+ {
55
+ 'sequence_length': sl,
56
+ 'hidden_size': hs,
57
+ 'num_experts': ne,
58
+ 'top_k': top_k,
59
+ },
60
+ time,
61
+ std,
62
+ )
63
+
64
+
65
+ if __name__ == '__main__':
66
+ unittest.main()
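
For readers following the metadata pipeline above (sort, histogram, round_up, inclusive_cumsum), here is a tiny worked example using plain PyTorch stand-ins for the custom ops, illustrative only:

    import torch

    top_expert = torch.tensor([1, 0, 1, 1, 0, 2])                # 6 tokens, 3 experts
    tokens_per_expert = torch.bincount(top_expert, minlength=3)  # tensor([2, 3, 1])
    bins = torch.cumsum(tokens_per_expert, 0)                    # tensor([2, 5, 6])

    # Pad each expert's token count up to a multiple of the 128-row block size.
    padded = (tokens_per_expert + 127) // 128 * 128              # tensor([128, 128, 128])
    padded_bins = torch.cumsum(padded, 0)                        # tensor([128, 256, 384])

    # padded_gather writes expert e's tokens starting at padded_bins[e - 1] (0 for e == 0);
    # padded_scatter reads the results back out using the same offsets.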
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/permute_benchmark.py ADDED
@@ -0,0 +1,149 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import torch
7
+ from absl.testing import parameterized
8
+
9
+ from .. import benchmark_util, ops
10
+
11
+ _PERMUTE_TESTS = (
12
+ (16384, 768, 2),
13
+ (16384, 768, 4),
14
+ (16384, 768, 8),
15
+ (16384, 768, 16),
16
+ (16384, 768, 32),
17
+ (16384, 768, 64),
18
+ (16384, 768, 128),
19
+ (16384 * 8, 768, 2),
20
+ (16384 * 8, 768, 4),
21
+ (16384 * 8, 768, 8),
22
+ (16384 * 8, 768, 16),
23
+ (16384 * 8, 768, 32),
24
+ (16384 * 8, 768, 64),
25
+ (16384 * 8, 768, 128),
26
+ )
27
+
28
+
29
+ class PermuteBenchmark(parameterized.TestCase):
30
+
31
+ @parameterized.parameters(*_PERMUTE_TESTS)
32
+ def testBinnedGather(self, sl, hs, ne):
33
+ # NOTE: Capacity factor == 1.
34
+ ec = sl // ne
35
+
36
+ # Create the data and indices.
37
+ x = torch.randn((sl, hs)).cuda().half()
38
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
39
+ bin_ids, indices = ops.sort(top_expert)
40
+ tokens_per_expert = ops.histogram(indices, ne)
41
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
42
+
43
+ def benchmark():
44
+ return ops.binned_gather(x, indices, bins, ec)
45
+
46
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
47
+ arguments = {
48
+ 'sequence_length': sl,
49
+ 'hidden_size': hs,
50
+ 'num_experts': ne,
51
+ }
52
+ benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)
53
+
54
+ @parameterized.parameters(*_PERMUTE_TESTS)
55
+ def testBinnedScatter(self, sl, hs, ne):
56
+ # NOTE: Capacity factor == 1.
57
+ ec = sl // ne
58
+
59
+ # Create the data and indices.
60
+ x = torch.randn((sl, hs)).cuda().half()
61
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
62
+ bin_ids, indices = ops.sort(top_expert)
63
+ tokens_per_expert = ops.histogram(indices, ne)
64
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
65
+ x = ops.binned_gather(x, indices, bins, ec)
66
+
67
+ def benchmark():
68
+ return ops.binned_scatter(x, indices, bins)
69
+
70
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
71
+ arguments = {
72
+ 'sequence_length': sl,
73
+ 'hidden_size': hs,
74
+ 'num_experts': ne,
75
+ }
76
+ benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)
77
+
78
+ @parameterized.parameters(*_PERMUTE_TESTS)
79
+ def testPaddedGather(self, sl, hs, ne):
80
+ # Create the data and indices.
81
+ x = torch.randn((sl, hs)).cuda().half()
82
+
83
+ # Randomly assign tokens to experts.
84
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
85
+ bin_ids, indices = ops.sort(top_expert)
86
+ tokens_per_expert = ops.histogram(top_expert, ne)
87
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
88
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
89
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
90
+
91
+ def benchmark():
92
+ return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
93
+
94
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
95
+ arguments = {
96
+ 'sequence_length': sl,
97
+ 'hidden_size': hs,
98
+ 'num_experts': ne,
99
+ }
100
+ benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)
101
+
102
+ @parameterized.parameters(*_PERMUTE_TESTS)
103
+ def testPaddedScatter(self, sl, hs, ne):
104
+ # Create the data and indices.
105
+ x = torch.randn((sl, hs)).cuda().half()
106
+
107
+ # Randomly assign tokens to experts.
108
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
109
+ bin_ids, indices = ops.sort(top_expert)
110
+ tokens_per_expert = ops.histogram(top_expert, ne)
111
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
112
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
113
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
114
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)
115
+
116
+ def benchmark():
117
+ return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
118
+
119
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
120
+ arguments = {
121
+ 'sequence_length': sl,
122
+ 'hidden_size': hs,
123
+ 'num_experts': ne,
124
+ }
125
+ benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)
126
+
127
+ @parameterized.parameters(*_PERMUTE_TESTS)
128
+ def testCopy(self, sl, hs, ne):
129
+ # NOTE: Capacity factor == 1.
130
+ # ec = sl // ne
131
+
132
+ # Create the data and indices.
133
+ x = torch.randn((sl, hs)).cuda().half()
134
+ y = x.clone()
135
+
136
+ def benchmark():
137
+ return y.copy_(x)
138
+
139
+ mean_t, std_t = benchmark_util.benchmark_function(benchmark)
140
+ arguments = {
141
+ 'sequence_length': sl,
142
+ 'hidden_size': hs,
143
+ 'num_experts': ne,
144
+ }
145
+ benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)
146
+
147
+
148
+ if __name__ == '__main__':
149
+ unittest.main()
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/repeat.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+
7
+ def repeat(x: torch.Tensor, tiling: torch.Size):
8
+ if all((t == 1 for t in tiling)):
9
+ return x
10
+ return x.repeat(*tiling)
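
A two-line illustration of the fast path, using the `repeat` defined above (pure PyTorch, CPU is fine):

    import torch

    x = torch.arange(6).reshape(2, 3)
    assert repeat(x, torch.Size((1, 1))) is x             # all-ones tiling: input returned untouched
    assert repeat(x, torch.Size((2, 3))).shape == (4, 9)  # otherwise defers to torch.Tensor.repeat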
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/replicate.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ from .._ops import ops # type: ignore
14
+ except ModuleNotFoundError as e:
15
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
+
17
+
18
+ # Autograd wrapper for replicate kernel.
19
+ class ReplicateOp(torch.autograd.Function):
20
+
21
+ @staticmethod
22
+ def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
23
+ ctx.save_for_backward(bins)
24
+ out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
25
+ ops.replicate_forward(x, bins, out)
26
+ return out
27
+
28
+ @staticmethod
29
+ def backward(ctx: Any, grad: torch.Tensor):
30
+ bins, = ctx.saved_tensors
31
+ out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
32
+ ops.replicate_backward(grad, bins, out)
33
+ return out, None, None
34
+
35
+
36
+ replicate = ReplicateOp.apply
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/round_up.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+
7
+ def round_up(x: torch.Tensor, value: int):
8
+ assert isinstance(value, int)
9
+ assert x.dtype == torch.int32
10
+
11
+ # TODO(tgale): If this becomes an issue
12
+ # do this in a custom kernel. We only expect
13
+ # to use this on arrays of less than 1k elements.
14
+ return torch.div(x + (value - 1), value, rounding_mode='trunc') * value
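
The trunc-division trick above rounds every entry up to the next multiple of `value`; a quick worked check in plain PyTorch:

    import torch

    x = torch.tensor([1, 128, 129, 300], dtype=torch.int32)
    # (x + value - 1) // value * value, e.g. 129 -> 256 for value == 128.
    print(torch.div(x + 127, 128, rounding_mode='trunc') * 128)
    # tensor([128, 128, 256, 384], dtype=torch.int32)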
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/scatter.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, Optional
5
+
6
+ import torch
7
+ from .stk_autocast import custom_bwd, custom_fwd
8
+
9
+ from ..backend import kernels
10
+
11
+
12
+ # Autograd wrapper for scatter kernel.
13
+ class ScatterOp(torch.autograd.Function):
14
+
15
+ @staticmethod
16
+ @custom_fwd
17
+ def forward(
18
+ ctx: Any,
19
+ x: torch.Tensor,
20
+ indices: torch.Tensor,
21
+ bin_ids: torch.Tensor,
22
+ weights: torch.Tensor,
23
+ bins: torch.Tensor,
24
+ top_k: int,
25
+ ) -> torch.Tensor:
26
+ maybe_x = [x] if ctx.needs_input_grad[3] else []
27
+ ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
28
+ ctx.top_k = top_k
29
+ ctx.x_shape = x.shape
30
+ return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)
31
+
32
+ @staticmethod
33
+ @custom_bwd
34
+ def backward(ctx: Any, grad: torch.Tensor):
35
+ grad = grad.contiguous()
36
+ saved_tensors = ctx.saved_tensors
37
+
38
+ indices, bin_ids, weights, bins = saved_tensors[:4]
39
+ dgrad = None
40
+ if ctx.needs_input_grad[0]:
41
+ dgrad = kernels.gather(
42
+ grad,
43
+ indices,
44
+ bin_ids,
45
+ weights,
46
+ bins,
47
+ ctx.top_k,
48
+ )
49
+
50
+ wgrad = None
51
+ if ctx.needs_input_grad[3]: # need wgrad
52
+ x = saved_tensors[-1]
53
+ wgrad = kernels.scatter_wgrad(
54
+ x,
55
+ grad,
56
+ indices,
57
+ bin_ids,
58
+ bins,
59
+ ctx.top_k,
60
+ )
61
+ return dgrad, None, None, wgrad, None, None, None
62
+
63
+
64
+ def scatter(
65
+ x: torch.Tensor,
66
+ indices: torch.Tensor,
67
+ bin_ids: torch.Tensor,
68
+ weights: torch.Tensor,
69
+ bins: torch.Tensor,
70
+ top_k: int,
71
+ ) -> Optional[torch.Tensor]:
72
+ return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, Optional, Tuple
5
+
6
+ # NOTE: Torch needs to be imported before the custom
7
+ # extensions. Otherwise libc10.so cannot be found.
8
+ import torch
9
+
10
+ # Wrap this in a try-block with better error message and
11
+ # instructions for building the c++ operations.
12
+ try:
13
+ from .._ops import ops # type: ignore
14
+ except ModuleNotFoundError as e:
15
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
16
+
17
+ _BITS_FOR_DTYPE = {
18
+ torch.int16: 16,
19
+ torch.int32: 32,
20
+ torch.int64: 64,
21
+ }
22
+
23
+
24
+ # Autograd wrapper for sort kernel.
25
+ # NOTE: Does not support gradients.
26
+ class SortOp(torch.autograd.Function):
27
+
28
+ @staticmethod
29
+ def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
30
+ if end_bit is None:
31
+ end_bit = _BITS_FOR_DTYPE[x.dtype]
32
+ x_out = torch.empty_like(x)
33
+ iota_out = torch.empty_like(x)
34
+ ops.sort(x, end_bit, x_out, iota_out)
35
+ return (x_out, iota_out)
36
+
37
+
38
+ sort = SortOp.apply
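
A short usage sketch (illustrative; assumes a CUDA device and the compiled extension): the op returns the sorted keys together with the permutation that produced them, which is how the routing code recovers the original token positions.

    import torch
    from megablocks import ops

    top_expert = torch.randint(0, 8, (16,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)      # full 32-bit radix sort by default
    # ops.sort(top_expert, 3) would sort on only the low 3 bits, enough for 8 experts.

    # bin_ids is top_expert in ascending order; indices holds the source positions.
    assert torch.equal(bin_ids, top_expert[indices.long()])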
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort_benchmark.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import unittest
5
+
6
+ import numpy as np
7
+ import torch
8
+ from absl.testing import parameterized
9
+
10
+ from .. import ops
11
+
12
+ _SORT_TESTS = (
13
+ (16384, torch.int32, None),
14
+ (16384, torch.int32, 2),
15
+ (16384, torch.int32, 128),
16
+ )
17
+
18
+ _BASELINE_SORT_TESTS = ((16384,),)
19
+
20
+
21
+ def numpy_dtype(dtype):
22
+ types = {
23
+ torch.int16: np.int16,
24
+ torch.int32: np.int32,
25
+ torch.int64: np.int64,
26
+ }
27
+ return types[dtype]
28
+
29
+
30
+ def benchmark_function(fn, iterations=10):
31
+ # Run once to get rid of startup overhead.
32
+ fn()
33
+ times = []
34
+ for _ in range(iterations):
35
+ start = torch.cuda.Event(enable_timing=True)
36
+ end = torch.cuda.Event(enable_timing=True)
37
+ start.record()
38
+ fn()
39
+ end.record()
40
+ torch.cuda.synchronize()
41
+ times.append(start.elapsed_time(end))
42
+ times = np.array(times)
43
+ return times.mean(), times.std(), times.max(), times.min()
44
+
45
+
46
+ def log_benchmark(arguments, mean_t, std_t):
47
+ print('=' * 60)
48
+ print('Benchmark Parameters:')
49
+ for (key, value) in arguments.items():
50
+ print(f'{key} = {value}')
51
+ print('Results:')
52
+ print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
53
+ print('=' * 60)
54
+
55
+
56
+ class SortBenchmark(parameterized.TestCase):
57
+
58
+ @parameterized.parameters(*_SORT_TESTS)
59
+ def testSort(self, n, dtype, max_val):
60
+ if max_val is None:
61
+ max_val = np.iinfo(numpy_dtype(dtype)).max
62
+ end_bit = int(np.ceil(np.log2(max_val)))
63
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
64
+
65
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
66
+ arguments = {
67
+ 'n': n,
68
+ 'dtype': dtype,
69
+ 'max_val': max_val,
70
+ }
71
+ log_benchmark(arguments, mean_t, std_t)
72
+
73
+ @parameterized.parameters(*_BASELINE_SORT_TESTS)
74
+ def testTorchSort(self, n):
75
+ x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
76
+
77
+ mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
78
+ arguments = {
79
+ 'n': n,
80
+ }
81
+ log_benchmark(arguments, mean_t, std_t)
82
+
83
+
84
+ if __name__ == '__main__':
85
+ unittest.main()
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/stk_autocast.py ADDED
@@ -0,0 +1,39 @@
1
+ # vendored from
2
+ # https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
3
+ import functools
4
+ import torch
5
+
6
+
7
+ def _is_eligible(x):
8
+ return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
9
+
10
+
11
+ def _cast(x, dtype):
12
+ if isinstance(x, torch.Tensor) and _is_eligible(x):
13
+ return x.to(dtype)
14
+ elif isinstance(x, map):
15
+ return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
16
+ elif isinstance(x, list) or isinstance(x, tuple):
17
+ return type(x)(map(lambda y: _cast(y, dtype), x))
18
+ return x
19
+
20
+
21
+ def custom_fwd(fwd):
22
+ """Wrap a custom autograd function that always uses autocast dtype."""
23
+
24
+ @functools.wraps(fwd)
25
+ def decorate_fwd(*args, **kwargs):
26
+ if torch.is_autocast_enabled():
27
+ with torch.autocast(device_type="cuda", enabled=False):
28
+ dtype = torch.get_autocast_gpu_dtype()
29
+ return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
30
+ return fwd(*args, **kwargs)
31
+ return decorate_fwd
32
+
33
+
34
+ def custom_bwd(bwd):
35
+ @functools.wraps(bwd)
36
+ def decorate_bwd(*args, **kwargs):
37
+ with torch.autocast(device_type="cuda", enabled=False):
38
+ return bwd(*args, **kwargs)
39
+ return decorate_bwd
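
A minimal sketch of how the op wrappers in this directory use these decorators: under autocast, `custom_fwd` disables autocast inside the function but first casts eligible tensor arguments to the autocast dtype, and `custom_bwd` runs the backward with autocast disabled. `ScaleOp` below is a hypothetical toy op, not part of megablocks; a CUDA device is assumed.

    import torch
    from megablocks.ops.stk_autocast import custom_bwd, custom_fwd

    class ScaleOp(torch.autograd.Function):

        @staticmethod
        @custom_fwd
        def forward(ctx, x):
            # Under autocast, x arrives here already cast to the autocast dtype.
            return x * 2

        @staticmethod
        @custom_bwd
        def backward(ctx, grad):
            # Runs with autocast disabled.
            return grad * 2

    x = torch.randn(4, device='cuda')                 # float32 input
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        y = ScaleOp.apply(x)
    assert y.dtype == torch.float16                   # cast happened on entry to forward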
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sum.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ import torch
4
+
5
+
6
+ def sum(x: torch.Tensor, dim: int = 0):
7
+ if x.shape[dim] == 1:
8
+ return x.squeeze(dim=dim)
9
+ return x.sum(dim=dim)
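
A quick illustration of why the helper exists, using the `sum` defined above (it intentionally shadows the builtin): when the reduced dimension has size 1 it is simply squeezed away instead of launching a reduction.

    import torch

    a = torch.randn(1, 5)
    b = torch.randn(3, 5)
    assert sum(a, dim=0).shape == (5,)  # size-1 dim: squeeze, no reduction kernel
    assert sum(b, dim=0).shape == (5,)  # otherwise a real sum over dim 0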