Akos Hadnagy committed
Commit dd2b6c2 · Parent(s): 1e1ffe8
Push build
This view is limited to 50 files because it contains too many changes.
- .gitattributes +1 -0
- .gitignore +2 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/__init__.py +202 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/__init__.py +10 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/activation_fn.py +33 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/all_to_all.py +54 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/arguments.py +101 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/common.py +26 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmlp_registry.py +42 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmoe.py +337 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/gelu.py +52 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/glu.py +244 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/memory_test.py +103 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mlp.py +587 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/moe.py +507 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mpu.py +94 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/router.py +116 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/sharedexpert_registry.py +32 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so +3 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_ops.py +9 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/__init__.py +2 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/kernels.py +543 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/bak.__init__.py +23 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/benchmark_util.py +35 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/__init__.py +2 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/backend.py +33 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/ops.py +33 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm_util.py +31 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/layers.py +1001 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/__init__.py +35 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/all_to_all_benchmark.py +63 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_gather.py +37 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_scatter.py +59 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/cumsum.py +52 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/gather.py +38 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram.py +27 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram_benchmark.py +78 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/matmul_benchmark.py +415 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_gather.py +55 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter.py +98 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py +66 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/permute_benchmark.py +149 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/repeat.py +10 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/replicate.py +36 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/round_up.py +14 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/scatter.py +72 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort.py +38 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort_benchmark.py +85 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/stk_autocast.py +39 -0
- build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sum.py +9 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.so filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,2 @@
**/__pycache__/
**/*egg-info/
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/__init__.py
ADDED
@@ -0,0 +1,202 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch

from ._ops import ops

#from .grouped_gemm import backend as gg_backend
#from .grouped_gemm import ops as gg_ops


from ._layers.arguments import Arguments
from ._layers.dmoe import ParallelDroplessMLP, dMoE
from ._layers.glu import SparseGLU
from ._layers.mlp import MLP, SparseMLP
from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss

from . import layers

# This section contains the direct kernel exports (not included in the original code)
def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute exclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    result = ops.exclusive_cumsum(x, dim)
    out.copy_(result)
    return out


def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute inclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    result = ops.inclusive_cumsum(x, dim)
    out.copy_(result)
    return out


def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    """
    Compute histogram of input tensor values.

    Args:
        x: Input tensor
        num_bins: Number of histogram bins

    Returns:
        Histogram tensor with counts for each bin
    """
    return ops.histogram(x, num_bins)


def indices(
    padded_bins: torch.Tensor,
    block_size: int,
    output_block_rows: int,
    output_block_columns: int,
) -> torch.Tensor:
    """
    Construct indices from padded bins for sparse operations.

    Args:
        padded_bins: Tensor containing bin boundaries
        block_size: Size of each block
        output_block_rows: Number of rows in output blocks
        output_block_columns: Number of columns in output blocks

    Returns:
        Tensor containing constructed indices
    """
    return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)


def replicate_forward(
    x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Forward pass of replicate operation - replicate values according to bin sizes.

    Args:
        x: Input tensor with values to replicate
        bins: Tensor containing bin sizes
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.replicate_forward(x, bins, out)


def replicate_backward(
    grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Backward pass of replicate operation - reduce gradients back to bins.

    Args:
        grad: Gradient tensor to reduce
        bins: Tensor containing bin sizes
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.replicate_backward(grad, bins, out)


def sort(
    x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
) -> torch.Tensor:
    """
    Radix sort with index tracking.

    Args:
        x: Input tensor to sort
        end_bit: Number of bits to consider in sorting
        x_out: Output tensor for sorted values
        iota_out: Output tensor for sorted indices

    Returns:
        The sorted values tensor
    """
    return ops.sort(x, end_bit, x_out, iota_out)


# Convenience functions for common use cases
def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
    """
    Compute cumulative sum with automatic output allocation.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum (default: last dimension)
        exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum

    Returns:
        New tensor containing the cumulative sum
    """
    out = torch.empty_like(x)
    if exclusive:
        return exclusive_cumsum(x, dim, out)
    else:
        return inclusive_cumsum(x, dim, out)


def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sort tensor and return both sorted values and indices.

    Args:
        x: Input tensor to sort
        end_bit: Number of bits to consider in sorting

    Returns:
        Tuple of (sorted_values, sorted_indices)
    """
    x_out = torch.empty_like(x)
    iota_out = torch.empty_like(x)
    sort(x, end_bit, x_out, iota_out)
    return x_out, iota_out


# Export public API
__all__ = [
    "MyReplacementLayer",
    # Direct kernel exports
    "exclusive_cumsum",
    "inclusive_cumsum",
    "histogram",
    "indices",
    "replicate_forward",
    "replicate_backward",
    "sort",
    "cumsum",
    "argsort",
    # Original exports
    "Arguments",
    "ParallelDroplessMLP",
    "dMoE",
    "SparseGLU",
    "MLP",
    "SparseMLP",
    "MoE",
    "ParallelMLP",
    "get_load_balancing_loss",
]
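Note: a brief, illustrative usage sketch of the convenience wrappers exported above. It assumes the compiled _megablocks extension loads and that the kernels receive integer tensors on a ROCm/CUDA device, which is how the MoE layers use them; the sizes and dtypes below are placeholders.

# Hypothetical usage of the direct kernel exports; sizes, dtypes and device are assumptions.
import torch
import megablocks

top_experts = torch.randint(0, 8, (4096,), dtype=torch.int32, device='cuda')

counts = megablocks.histogram(top_experts, num_bins=8)           # tokens routed to each expert
bins = megablocks.cumsum(counts, dim=0)                          # inclusive bin boundaries
sorted_ids, order = megablocks.argsort(top_experts, end_bit=3)   # 8 experts fit in 3 bits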
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

# from megablocks.layers.dmoe import dMoE
from .moe import MoE

__all__ = [
    'MoE',
    # 'dMoE',
]
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/activation_fn.py
ADDED
@@ -0,0 +1,33 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Union

import torch
from ..stk import Matrix


def act_fn(
    x: Matrix,
    function: Callable,
    return_grad_fn: bool = False,
    **kwargs,
) -> Union[tuple[Matrix, Any] | Matrix]:
    assert isinstance(x, Matrix)
    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
        if return_grad_fn:
            x.data.requires_grad = True
        out = function(x.data, **kwargs)
        y = Matrix(
            x.size(),
            out,
            x.row_indices,
            x.column_indices,
            x.offsets,
            x.column_indices_t,
            x.offsets_t,
            x.block_offsets_t,
        )
        if return_grad_fn:
            return y, out.backward
        return y
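Note: the return_grad_fn path above is the hook the memory-optimized MLPs use to rematerialize the activation during backward. A dense-tensor sketch of that pattern (stk.Matrix omitted; purely illustrative):

# Capture the activation's .backward as a callable and invoke it later with an
# upstream gradient, mirroring act_fn(..., return_grad_fn=True).
import torch
import torch.nn.functional as F

x = torch.randn(4, 8, requires_grad=True)
with torch.set_grad_enabled(True):
    out = F.gelu(x, approximate='tanh')
    grad_fn = out.backward

upstream = torch.ones_like(out)   # stand-in for the gradient from the next layer
grad_fn(upstream)                 # populates x.grad through just this node
print(x.grad.shape)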
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/all_to_all.py
ADDED
@@ -0,0 +1,54 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributed as dist


class AllToAllOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
        out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)

        ctx.input_shape = x.shape
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
        ctx.group = group
        handle = dist.all_to_all_single(
            out,
            x,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op,
        )
        return out, handle

    @staticmethod
    def backward(ctx, grad, _):
        if ctx.needs_input_grad[0]:
            out = torch.empty(
                ctx.input_shape,
                device=grad.device,
                dtype=grad.dtype,
            )
            dist.all_to_all_single(
                out,
                grad,
                output_split_sizes=ctx.input_split_sizes,
                input_split_sizes=ctx.output_split_sizes,
                group=ctx.group,
            )
            return out, None, None, None, None
        return None, None, None, None, None


def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
    return AllToAllOp.apply(
        x,
        output_split_sizes,
        input_split_sizes,
        group,
        async_op,
    )
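Note: a minimal usage sketch of the all_to_all wrapper above. It assumes the script is launched under torchrun with an NCCL process group and one GPU per rank; the split sizes and shapes are placeholders, and the import path assumes the built wheel layout shown in this commit.

# Exchange one row with every rank; gradients flow back through the reverse all-to-all.
import torch
import torch.distributed as dist
from megablocks._layers.all_to_all import all_to_all

dist.init_process_group(backend='nccl')
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

x = torch.full((world, 16), float(rank), device='cuda', requires_grad=True)
splits = [1] * world
out, _ = all_to_all(x, splits, splits, dist.group.WORLD, async_op=False)
out.sum().backward()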
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/arguments.py
ADDED
@@ -0,0 +1,101 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import dataclasses
from functools import partial
from typing import Any, Callable, Optional, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F

# import megablocks.grouped_gemm_util as grouped_gemm
from .. import grouped_gemm_util as grouped_gemm

# Type annotation for in-place Tensor initialization function.
InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]]

_ALLOWED_BITWIDTHS = (-1, 4, 8)

DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh')


@dataclasses.dataclass
class Arguments:
    # Model arguments.
    hidden_size: int = 1024
    ffn_hidden_size: int = 4096
    num_layers: int = 1
    bias: bool = True
    return_bias: bool = True
    activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN

    # MoE arguments.
    moe_num_experts: int = 1
    moe_top_k: int = 1
    moe_capacity_factor: int = 1
    moe_normalize_expert_weights: Optional[Union[int, float]] = None
    moe_loss_weight: float = 0.1
    moe_jitter_eps: Optional[float] = None
    moe_lbl_in_fp32: bool = False

    # Parallelism arguments.
    moe_expert_model_parallelism: bool = False
    expert_parallel_group: Optional[dist.ProcessGroup] = None
    pipeline_model_parallel_size: int = 1
    num_layers_per_virtual_pipeline_stage: Optional[int] = None

    # Compute arguments.
    memory_optimized_mlp: bool = False
    mlp_type: str = 'mlp'
    mlp_impl: str = 'sparse'

    # Initialization arguments.
    fp16: bool = True
    bf16: bool = False
    device: Union[int, torch.device] = dataclasses.field(default_factory=torch.cuda.current_device)
    init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02)
    output_layer_init_method: InitFn = init_method

    # Benchmarking arguments.
    uniform_expert_assignment: bool = False

    # shared expert arguments
    shared_expert: bool = False  # enable using shared expert
    fc_cls: Any = torch.nn.Linear  # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8))
    fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,)  # kwargs for custom fc layers
    remat_act_fn: bool = True  # enable act fn to be rematerialized instead of stored
    shared_expert_hidden_size: Optional[
        int] = None  # hidden size of the shared expert IF we want to set it to something different from hidden_size
    shared_expert_weighted_sum: bool = False  # enable using weighted sum for shared expert output (weighted by number of experts used)

    # Router Z-loss arguments
    moe_zloss_weight: float = 0  # 1e-3 is a reasonable value
    moe_zloss_in_fp32: bool = False

    def __post_init__(self):
        # Sparse MLP is not supported with triton >=3.2.0
        # TODO: Remove this once sparse is supported with triton >=3.2.0
        if self.__getattribute__('mlp_impl') == 'sparse':
            try:
                import triton
                if triton.__version__ >= '3.2.0':
                    raise ValueError(
                        'Sparse MLP is not supported with triton >=3.2.0. Please use mlp_impl="grouped" instead.',
                    )
            except ImportError:
                raise ImportError('Triton is required for sparse MLP implementation')

        if self.__getattribute__('mlp_impl') == 'grouped':
            grouped_gemm.assert_grouped_gemm_is_available()

        if self.shared_expert_hidden_size is None:
            self.shared_expert_hidden_size = self.ffn_hidden_size


def from_megatron(megatron_args: Any):
    args = Arguments()
    for field in dataclasses.fields(args):
        if hasattr(megatron_args, field.name):
            setattr(args, field.name, getattr(megatron_args, field.name))
    return args
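Note: a minimal sketch of constructing an Arguments instance for a small dropless MoE, mirroring the pattern used in memory_test.py further below. The values are placeholders; it assumes a CUDA/ROCm device is available (the device field defaults to torch.cuda.current_device()) and, for mlp_impl='grouped', that the grouped_gemm backend is installed.

# Illustrative configuration; field values are assumptions, not recommendations.
import torch
from megablocks import Arguments, dMoE

args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    moe_num_experts=8,
    moe_top_k=2,
    mlp_impl='grouped',   # sidesteps the triton >= 3.2.0 check on the sparse path
    fp16=False,
    bf16=True,
    device=torch.cuda.current_device(),
)
layer = dMoE(args).cuda()
x = torch.randn(2, 128, args.hidden_size, device='cuda', dtype=torch.bfloat16)
out, _ = layer(x)   # the second value feeds the load-balancing loss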
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/common.py
ADDED
@@ -0,0 +1,26 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch

from .arguments import Arguments


def dtype(args: Arguments):
    if args.fp16:
        return torch.float16
    elif args.bf16:
        return torch.bfloat16
    return None


def cast_if_autocast_enabled(tensor):
    if torch.is_autocast_enabled():
        if tensor.device.type == 'cuda':
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == 'cpu':
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmlp_registry.py
ADDED
@@ -0,0 +1,42 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Union

from . import glu, mlp
from .arguments import Arguments

MlpType = Union[mlp.SparseMLP, glu.SparseGLU]

_REGISTRY = {
    'mlp': {
        'grouped': mlp.GroupedMLP,
        'sparse': mlp.SparseMLP,
    },
    'glu': {
        'grouped': glu.GroupedGLU,
        'sparse': glu.SparseGLU,
    },
}


def get(args: Arguments) -> MlpType:
    """Returns an MLP for use in a dMoE instance.

    Uses the provided arguments to instantiate the appropriate
    MLP instance. This only contains MLPs for use in dMoEs
    (ie. only for the dropless versions of MoEs).

    Args:
        args: propagated Arguments dataclass.

    Returns:
        An instantiated MLP constructed using the input args.
    """
    if args.mlp_type not in _REGISTRY:
        raise ValueError(f'Unsupported mlp type: {args.mlp_type}')

    if args.mlp_impl not in _REGISTRY[args.mlp_type]:
        raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',)

    return _REGISTRY[args.mlp_type][args.mlp_impl](args)
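Note: the dispatch above is a two-level table lookup; get() validates args.mlp_type and args.mlp_impl and then instantiates the selected class. A hedged peek at the table (it assumes the bundled stk shim and grouped_gemm utilities import cleanly):

# Which class get() would instantiate for a given (mlp_type, mlp_impl) pair.
from megablocks._layers import dmlp_registry

print(dmlp_registry._REGISTRY['mlp']['grouped'])   # mlp.GroupedMLP
print(dmlp_registry._REGISTRY['glu']['sparse'])    # glu.SparseGLU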
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/dmoe.py
ADDED
@@ -0,0 +1,337 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch

# try:
#     import stk.ops
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/_layers/dmoe.py is needed.',
#     )

# import megablocks.ops as ops
# # from megablocks.ops import ops
# from megablocks.layers import common, dmlp_registry, moe, mpu
# from megablocks.layers.arguments import Arguments

from .. import stk
from .. import ops
from . import common, dmlp_registry, moe, mpu
from .arguments import Arguments

def promote_scalar(x):
    return x.view(1) if not len(x.size()) else x


class ParallelDroplessMLP(moe.ParallelMLP):

    def __init__(self, args: Arguments):
        super(ParallelDroplessMLP, self).__init__(args)
        self.hidden_size = args.hidden_size
        self.ffn_hidden_size = mpu.features_per_rank(args)
        self.blocking = 128
        self.mlp = dmlp_registry.get(args)

        # Calculate the number of bits needed to represent the column indices
        # in the intermediate sparse matrix.
        max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking)
        self.transpose_sort_end_bit = max(
            int(np.ceil(np.log2(max_column_index))),
            1,
        )

    def sparse_transpose(self, size, row_indices, column_indices, offsets):
        block_columns = size[1] // self.blocking

        # Sort row indices by column indices to get the transposed matrix's
        # column indices.
        #
        # NOTE: Our sort operation uses the same width indices as the input values.
        # To avoid overflow when we have large activation matrices we cast to
        # 32-bit before sorting.
        _, gather_indices = ops.sort(
            column_indices.int(),
            self.transpose_sort_end_bit,
        )

        # There are a constant number of blocks in every row of the sparse matrix.
        # A blocks offset is:
        #
        # row_index * blocks_per_row + column_index % blocks_per_row
        #
        # Once we have the block offsets ordered for transposition we can divide
        # by blocks_per_row to get the transposed column indices.
        column_indices_t = row_indices.gather(0, gather_indices.long())
        block_offsets_t = gather_indices.int()

        zero = torch.zeros((1,), dtype=torch.int32, device=row_indices.device)
        nnz_per_column = ops.histogram(column_indices, block_columns)
        nnz_per_column = ops.inclusive_cumsum(nnz_per_column, 0)
        if nnz_per_column.dim() == 0:
            # This addresses an edge case when ffn_hidden_size is equal to self.blocking.
            nnz_per_column = nnz_per_column.unsqueeze(0)
        offsets_t = torch.cat([zero, nnz_per_column])
        return column_indices_t, offsets_t, block_offsets_t

    def topology(self, x, padded_bins):
        padded_tokens, _ = x.size()
        assert padded_tokens % self.blocking == 0
        if self.ffn_hidden_size % self.blocking != 0:
            raise ValueError(
                f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' +
                f'the block size {self.blocking}. Please update your configuration.',
            )

        # Offsets for the sparse matrix. All rows have the
        # same number of nonzero blocks dictated by the
        # dimensionality of a single expert.
        block_rows = padded_tokens // self.blocking
        blocks_per_row = self.ffn_hidden_size // self.blocking
        offsets = torch.arange(
            0,
            block_rows * blocks_per_row + 1,
            blocks_per_row,
            dtype=torch.int32,
            device=x.device,
        )

        # Indices for the sparse matrix. The indices for
        # the intermediate matrix are dynamic depending
        # on the mapping of tokens to experts.
        column_indices = ops.topology(
            padded_bins,
            self.blocking,
            block_rows,
            blocks_per_row,
        )

        # TODO(tgale): This is unused. Remove the need for this in stk.
        # For now, use meta init to save the device memory.
        data = torch.empty(
            column_indices.numel(),
            self.blocking,
            self.blocking,
            dtype=common.dtype(self.args),
            device='meta',
        )
        shape = (
            padded_tokens,
            self.ffn_hidden_size * mpu.experts_per_rank(self.args),
        )
        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
        column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose(
            shape,
            row_indices,
            column_indices,
            offsets,
        )
        return stk.Matrix(
            shape,
            data,
            row_indices,
            column_indices,
            offsets,
            column_indices_t,
            offsets_t,
            block_offsets_t,
        )

    def indices_and_padded_bins(self, top_experts):
        # Sort the expert ids to produce the scatter/gather
        # indices for the permutation.
        top_experts = top_experts.int()
        bin_ids, indices = ops.sort(top_experts, self.sort_end_bit)

        # Histogram the expert ids to identify the number of
        # tokens routed to each expert.
        tokens_per_expert = ops.histogram(top_experts, self.num_experts)

        # Round the token counts up to the block size used in
        # the matrix multiplications. Calculate the starting
        # position of each bin.
        padded_tokens_per_expert = ops.round_up(
            tokens_per_expert,
            self.blocking,
        )
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        padded_bins = promote_scalar(padded_bins)

        # Calculate the bin bounds for the sorted tokens.
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
        bins = promote_scalar(bins)
        return indices, bin_ids, bins, padded_bins, tokens_per_expert

    def sparse_forward_once(self, x, expert_weights, top_experts):
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts))

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.padded_gather(
            x,
            indices,
            bin_ids,
            bins,
            padded_bins,
            self.top_k,
        )

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(x, padded_bins)

        # Perform the expert computation.
        x = self.mlp(x, topo)

        # Un-route the data for the MoE output.
        x = ops.padded_scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            self.top_k,
        )
        return x, tokens_per_expert

    # For use in the base-class parallel_forward_once.
    def sparse_permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,  # unused
        top_k,
    ):

        # Round the token counts up to the block size used in the matrix
        # multiplication. Calculate the starting position of each bin.
        padded_tokens_per_expert = ops.round_up(
            tokens_per_expert,
            self.blocking,
        )
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        padded_bins = promote_scalar(padded_bins)

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

        # Create the sparse matrix topology.
        with torch.no_grad():
            topo = self.topology(x, padded_bins)

        # Perform the expert computation.
        x = self.mlp(x, topo)

        # Un-route the data for the MoE output.
        return ops.padded_scatter(
            x,
            indices,
            bin_ids,
            expert_weights,
            bins,
            padded_bins,
            top_k,
        )

    def grouped_forward_once(self, x, expert_weights, top_experts):
        # x: [sl, bs, hs]
        # expert_weights: [sl * bs, top-k]
        # top_experts: [sl * bs, top-k]
        expert_weights = expert_weights.flatten()
        top_experts = top_experts.flatten()
        with torch.no_grad():
            indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))

        out = self.grouped_permute_and_compute(
            x,
            tokens_per_expert,
            indices,
            bin_ids,
            expert_weights,
            bins,
            -1,  # unused
            self.args.moe_top_k,
        )
        return out, tokens_per_expert

    def grouped_permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,  # unused
        top_k,
    ):

        # Route the tokens for MoE computation.
        x = x.view(-1, x.shape[-1])
        x = ops.gather(x, indices, bin_ids, bins, top_k)

        # Perform the expert computation.
        x = self.mlp(x, tokens_per_expert)

        # Un-route the data for the MoE output.
        return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)

    def forward_once(self, x, expert_weights, top_experts):
        if self.args.mlp_impl == 'sparse':
            return self.sparse_forward_once(x, expert_weights, top_experts)
        else:
            return self.grouped_forward_once(x, expert_weights, top_experts)

    def permute_and_compute(
        self,
        x,
        tokens_per_expert,
        indices,
        bin_ids,
        expert_weights,
        bins,
        expert_capactiy,
        top_k,
    ):
        if self.args.mlp_impl == 'sparse':
            return self.sparse_permute_and_compute(
                x,
                tokens_per_expert,
                indices,
                bin_ids,
                expert_weights,
                bins,
                expert_capactiy,
                top_k,
            )
        else:
            return self.grouped_permute_and_compute(
                x,
                tokens_per_expert,
                indices,
                bin_ids,
                expert_weights,
                bins,
                expert_capactiy,
                top_k,
            )


class dMoE(moe.MoE):

    def _init_experts_mlp(self, args: Arguments):
        return ParallelDroplessMLP(args)
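Note: an illustrative, CPU-only re-implementation of the binning arithmetic in indices_and_padded_bins(): token counts per expert are rounded up to the 128-wide block size and cumulatively summed to get each expert's padded bin boundary. The real code calls the fused megablocks ops; this sketch only shows the arithmetic.

import torch

blocking = 128
num_experts = 4
top_experts = torch.tensor([2, 0, 2, 1, 2, 0])   # expert chosen per token (top-1 here)

tokens_per_expert = torch.bincount(top_experts, minlength=num_experts)  # [2, 1, 3, 0]
padded = ((tokens_per_expert + blocking - 1) // blocking) * blocking    # round_up to block size
padded_bins = torch.cumsum(padded, 0)       # end offset of each expert's padded bin
bins = torch.cumsum(tokens_per_expert, 0)   # end offset of each expert's real tokens
print(padded.tolist(), padded_bins.tolist(), bins.tolist())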
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/gelu.py
ADDED
@@ -0,0 +1,52 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

# try:
#     import stk
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/_layers/gelu.py is needed.',
#     )

from .. import stk

import torch
import torch.nn.functional as F


@torch.jit.script
def _gelu_backward_inplace(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out))
    return g.mul_(ff)


def gelu_backward_(grad: stk.Matrix, x: stk.Matrix):
    # NOTE: The two sparse matrices must have the same topology.
    if isinstance(grad, stk.Matrix) and isinstance(x, stk.Matrix):
        return stk.Matrix(
            x.size(),
            _gelu_backward_inplace(grad.data, x.data),
            x.row_indices,
            x.column_indices,
            x.offsets,
            x.column_indices_t,
            x.offsets_t,
            x.block_offsets_t,
        )
    return _gelu_backward_inplace(grad, x)


def gelu(x: stk.Matrix):
    assert isinstance(x, stk.Matrix)
    return stk.Matrix(
        x.size(),
        F.gelu(x.data, approximate='tanh'),
        x.row_indices,
        x.column_indices,
        x.offsets,
        x.column_indices_t,
        x.offsets_t,
        x.block_offsets_t,
    )
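Note: a sanity-check sketch of the closed-form derivative used in _gelu_backward_inplace, compared against autograd of the tanh-approximated GeLU. Dense tensors stand in for the stk.Matrix .data payload; this is illustrative only.

import torch
import torch.nn.functional as F

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)

# Closed form, as in _gelu_backward_inplace (without the in-place multiply).
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)

# Autograd reference.
(ref,) = torch.autograd.grad(F.gelu(x, approximate='tanh').sum(), x)
print(torch.allclose(ff, ref, atol=1e-5))   # expected: True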
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/glu.py
ADDED
@@ -0,0 +1,244 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

# import stk.ops
# try:
#     import stk.ops
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/_layers/glu.py is needed.',
#     )

from .. import stk

import torch

# from megablocks import grouped_gemm_util as gg
# from megablocks.layers import common, mpu
# from megablocks.layers.activation_fn import act_fn
# from megablocks.layers.arguments import Arguments
# from megablocks.layers.mlp import (
#     SharedMLP,
#     SparseMLP,
#     create_dmoe_expert_weights,
#     resolve_dtensor,
# )

from .. import grouped_gemm_util as gg
from . import common, mpu
from .activation_fn import act_fn
from .arguments import Arguments
from .mlp import (
    SharedMLP,
    SparseMLP,
    create_dmoe_expert_weights,
    resolve_dtensor,
)


class SparseGLU(SparseMLP):

    def __init__(self, args: Arguments):
        super().__init__(args)
        self.v1 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        with torch.no_grad():
            self.v1.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.init_method,
                ),
            )

        mpu.set_expert_model_parallel_attributes(
            self.v1,
            self._should_set_parallelism_attribute,
        )

    def forward(self, x, topo):
        if self.args.memory_optimized_mlp:
            raise NotImplementedError(
                'Memory optimized implementation not yet supported with GLU with sparse kernels.',
            )

        w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2)
        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)

        # Compute the GLU.
        x1 = stk.ops.sdd(x, w1.t(), topo)
        x2 = stk.ops.sdd(x, v1.t(), topo)

        activation_fn_out = act_fn(x1, self.args.activation_fn)
        x1 = stk.ops.mul(activation_fn_out, x2)

        return stk.ops.dsd(x1, w2)


class MemoryOptimizedGroupedGLU(torch.autograd.Function):
    """GroupedMLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn):
        # Cast inputs using ctx dtype from AMP
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            v1 = v1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)
        # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k]
        if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.")

        # Layer 0: x @ w1.t().
        assert gg.backend is not None
        sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
        v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True)

        # GeLU.
        activation_fn_out = activation_fn(sdd_out) * v1_out

        # Layer 1: x @ w2.
        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)

        # NOTE: Save the input to the layer and the activation_fn input for
        # gradient computation. We'll re-compute the activation_fn forward
        # pass in the backward pass to avoid materializing another
        # intermediate.
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, v1, w2, batch_sizes, x, sdd_out, v1_out)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, ddsd_out):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        # Unpack saved tensors
        # dtype = ctx.dtype
        saved_tensors = ctx.saved_tensors
        w1, v1, w2 = saved_tensors[:3]
        batch_sizes = saved_tensors[3]
        x = saved_tensors[4]
        sdd_out, v1_out = saved_tensors[5:7]

        # Rematerialize activation_fn output.
        activation_fn = ctx.activation_fn
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            v1_out.requires_grad = True
            activation_fn_out = activation_fn(sdd_out) * v1_out
            activation_grad_fn = activation_fn_out.backward

        # Compute dw2 with recomputed activation_fn output.
        assert gg.backend is not None
        dw2 = gg.backend.gmm(
            activation_fn_out,
            ddsd_out,
            batch_sizes,
            trans_a=True,
        )

        # Compute dactivation_fn_out.
        #
        # NOTE: We reuse the activation_fn_out allocation.
        dactivation_fn_out = activation_fn_out
        gg.backend.gmm(
            ddsd_out,
            w2,
            batch_sizes,
            trans_b=True,
            c=dactivation_fn_out,
        )

        # Compute dsdd_out.
        #
        # NOTE: This reuses the dactivation_fn_out allocation.
        assert activation_grad_fn is not None
        activation_grad_fn(dactivation_fn_out)
        dsdd_out = sdd_out.grad
        dv1_out = v1_out.grad

        # Compute dw1.
        dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)

        # Compute dv1.
        dv1 = gg.backend.gmm(dv1_out, x, batch_sizes, trans_a=True)

        # Compute dx.
        #
        # NOTE: This reuses the ddsd_out allocation.
        dx = ddsd_out
        gg.backend.gmm(dsdd_out, w1, batch_sizes, c=dx)
        dx += gg.backend.gmm(dv1_out, v1, batch_sizes)
        return dx, dw1, dv1, dw2, None, None


memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply


class GroupedGLU(SparseGLU):

    def forward(self, x, tokens_per_expert):
        batch_sizes = tokens_per_expert.cpu().to(torch.long)
        w1, v1, w2 = (
            self.scale_grad(self.w1),
            self.scale_grad(self.v1),
            self.scale_grad(self.w2),
        )
        w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2)

        # Re-shape the weights for the grouped GEMMs.
        ne = mpu.experts_per_rank(self.args)
        w1 = w1.view(ne, -1, self.args.hidden_size)
        v1 = v1.view(ne, -1, self.args.hidden_size)
        w2 = w2.view(ne, -1, self.args.hidden_size)

        if self.args.memory_optimized_mlp:
            return memory_optimized_grouped_glu(
                x,
                w1,
                v1,
                w2,
                batch_sizes,
                self.args.activation_fn,
            )

        # Compute the MLP.
        assert gg.ops is not None
        x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
        x1 = self.args.activation_fn(x1) * x2
        return gg.ops.gmm(x1, w2, batch_sizes)


class SharedGLU(SharedMLP):
    """GLU for shared expert.

    Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class
    """

    def __init__(self, args: Arguments):
        super().__init__(args)
        self.gate_proj = args.fc_cls(
            args.hidden_size,
            self.args.shared_expert_hidden_size,
            **self.fc_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
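Note: a dense, single-expert analogue of the GLU the sparse and grouped kernels compute above, out = (act(x @ w1.T) * (x @ v1.T)) @ w2, with weight shapes following the [n, k] comments in MemoryOptimizedGroupedGLU. Plain PyTorch for intuition, not the megablocks kernels.

import torch
import torch.nn.functional as F

m, k, n = 256, 1024, 4096            # tokens, hidden_size, ffn_hidden_size
x = torch.randn(m, k)
w1, v1, w2 = (torch.randn(n, k) for _ in range(3))

gate = F.gelu(x @ w1.t(), approximate='tanh')   # the "sdd_out" branch
up = x @ v1.t()                                 # the "v1_out" branch
out = (gate * up) @ w2                          # second GEMM back to hidden size
print(out.shape)                                # torch.Size([256, 1024])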
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/memory_test.py
ADDED
@@ -0,0 +1,103 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import gc

import torch
import torch.distributed as dist

# from megablocks.layers import arguments, dmoe
from . import arguments, dmoe

_TESTS = ((8, 2048, 4096, 4096, 32, 4),)


def get_tensors():
    ptrs = set()
    out = []
    for obj in gc.get_objects():
        if torch.is_tensor(obj):
            if not obj.is_contiguous() or obj.data_ptr() in ptrs:
                continue
            out.append(obj)
            ptrs.add(obj.data_ptr())
    return out


def test_memory(
    group,
    batch_size,
    sequence_length,
    hidden_size,
    ffn_hidden_size,
    num_experts,
    top_k,
):
    args = arguments.Arguments(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
        moe_expert_model_parallelism=True,
        expert_parallel_group=group,
        fp16=False,
        bf16=True,
        device=torch.cuda.current_device(),
    )
    layer = dmoe.dMoE(args).cuda()

    x = torch.randn((batch_size, sequence_length, hidden_size),
                    device=torch.cuda.current_device(),
                    dtype=torch.bfloat16).requires_grad_(True)
    torch.cuda.empty_cache()

    # Run forward + backward.
    # with torch.autograd.detect_anomaly():
    out, _ = layer(x)
    out.mean().backward()

    # Report peak memory.
    mem = torch.cuda.max_memory_allocated()
    print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6))
    print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),)

    # Calculate weight and gradient memory usage.
    weight_memory = 2 * (
        layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel()
    )

    def grad_numel(x):
        if x.grad is not None:
            return x.grad.numel()
        return 0

    grad_memory = 2 * (
        grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2)
    )
    weight_memory += grad_memory

    print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6))
    print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),)

    # Manually calculate GPU memory usage from the garbage
    # collector.
    gc.collect()
    total = 0
    tensors = get_tensors()
    tensors = sorted(tensors, key=lambda x: -x.numel())
    for i, t in enumerate(tensors):
        total += t.numel()
        print(f'{i}: {t.shape}, {t.numel() * 2}')
    del tensors

    print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6))


if __name__ == '__main__':
    assert dist.is_available()
    group = dist.init_process_group(backend='nccl')
    local_rank = dist.get_rank(group)
    torch.cuda.set_device(local_rank)

    for args in _TESTS:
        test_memory(group, *args)
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mlp.py
ADDED
@@ -0,0 +1,587 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from typing import Any
|
5 |
+
|
6 |
+
# try:
|
7 |
+
# import stk
|
8 |
+
# import stk.backend.triton_kernels
|
9 |
+
# import stk.ops
|
10 |
+
# except ImportError:
|
11 |
+
# import warnings
|
12 |
+
# warnings.warn(
|
13 |
+
# 'Please add `stanford-stk` if megablocks/_layers/mlp.py is needed.',
|
14 |
+
# )
|
15 |
+
|
16 |
+
from .. import stk
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from packaging import version
|
20 |
+
|
21 |
+
# from megablocks import grouped_gemm_util as gg
|
22 |
+
# from megablocks.layers import common, gelu, mpu
|
23 |
+
# from megablocks.layers.activation_fn import act_fn
|
24 |
+
# from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
|
25 |
+
|
26 |
+
from .. import grouped_gemm_util as gg
|
27 |
+
from . import common, gelu, mpu
|
28 |
+
from .activation_fn import act_fn
|
29 |
+
from .arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn
|
30 |
+
|
31 |
+
class ScaleGradient(torch.autograd.Function):
|
32 |
+
|
33 |
+
@staticmethod
|
34 |
+
@torch.amp.autocast_mode.custom_fwd(device_type='cuda')
|
35 |
+
def forward(ctx: Any, x: torch.Tensor, scale: float):
|
36 |
+
ctx.scale = scale
|
37 |
+
return x
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
@torch.amp.autocast_mode.custom_bwd(device_type='cuda')
|
41 |
+
def backward(ctx: torch.Tensor, grad: torch.Tensor):
|
42 |
+
return grad * ctx.scale, None
|
43 |
+
|
44 |
+
|
45 |
+
scale_gradient = ScaleGradient.apply
|
46 |
+
|
47 |
+
|
48 |
+
def resolve_dtensor(weight: torch.Tensor):
|
49 |
+
if version.parse(torch.__version__) >= version.parse('2.0.0'):
|
50 |
+
from torch.distributed._tensor import DTensor
|
51 |
+
if isinstance(weight, DTensor):
|
52 |
+
return weight.to_local()
|
53 |
+
return weight
|
54 |
+
|
55 |
+
|
56 |
+
def create_moe_expert_weights(
|
57 |
+
args: Arguments,
|
58 |
+
num_experts: int,
|
59 |
+
ffn_hidden_size: int,
|
60 |
+
hidden_size: int,
|
61 |
+
init_method: InitFn,
|
62 |
+
):
|
63 |
+
# Create the entire weight matrix such that the sampled weights will
|
64 |
+
# not vary between data parallelism and expert model parallelism for
|
65 |
+
# the same random seed.
|
66 |
+
master_weights = torch.empty(
|
67 |
+
num_experts,
|
68 |
+
ffn_hidden_size,
|
69 |
+
hidden_size,
|
70 |
+
device=args.device,
|
71 |
+
dtype=common.dtype(args),
|
72 |
+
)
|
73 |
+
init_method(master_weights)
|
74 |
+
|
75 |
+
if not args.moe_expert_model_parallelism:
|
76 |
+
return master_weights
|
77 |
+
|
78 |
+
# Calculate the amount of sharding in each dimension.
|
79 |
+
expert_sharding_degree = mpu.expert_sharding_degree(args)
|
80 |
+
hidden_sharding_degree = mpu.hidden_sharding_degree(args)
|
81 |
+
|
82 |
+
# Calculate the experts per rank.
|
83 |
+
#
|
84 |
+
# NOTE: We assign ranks to be expert parallel before going
|
85 |
+
# tensor parallel.
|
86 |
+
rank = mpu.get_expert_parallel_rank(args)
|
87 |
+
expert_rank = rank % expert_sharding_degree
|
88 |
+
num_experts_per_rank = num_experts // expert_sharding_degree
|
89 |
+
start_expert = expert_rank * num_experts_per_rank
|
90 |
+
end_expert = (expert_rank + 1) * num_experts_per_rank
|
91 |
+
|
92 |
+
# Calculate the rows per rank.
|
93 |
+
row_rank = rank // expert_sharding_degree
|
94 |
+
num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
|
95 |
+
start_row = row_rank * num_rows_per_rank
|
96 |
+
end_row = (row_rank + 1) * num_rows_per_rank
|
97 |
+
|
98 |
+
# Slice the weight matrix to get the chunk for this rank.
|
99 |
+
with torch.no_grad():
|
100 |
+
weights = master_weights[start_expert:end_expert, start_row:end_row]
|
101 |
+
return weights
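To make the slicing arithmetic above concrete, here is a minimal standalone sketch of the rank-to-slice mapping for a hypothetical configuration (8 experts, ffn_hidden_size of 16, four ranks split as expert_sharding_degree=4 and hidden_sharding_degree=1); the numbers are illustrative only, not the library's defaults.

# Standalone sketch of the rank -> (expert slice, row slice) mapping used above.
# Hypothetical sizes; not the library's defaults.
num_experts = 8
ffn_hidden_size = 16
expert_sharding_degree = 4   # ranks that split the experts
hidden_sharding_degree = 1   # ranks that split the ffn rows

for rank in range(expert_sharding_degree * hidden_sharding_degree):
    expert_rank = rank % expert_sharding_degree
    num_experts_per_rank = num_experts // expert_sharding_degree
    start_expert = expert_rank * num_experts_per_rank
    end_expert = (expert_rank + 1) * num_experts_per_rank

    row_rank = rank // expert_sharding_degree
    num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
    start_row = row_rank * num_rows_per_rank
    end_row = (row_rank + 1) * num_rows_per_rank

    print(rank, (start_expert, end_expert), (start_row, end_row))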
|
102 |
+
|
103 |
+
|
104 |
+
class MLP(torch.nn.Module):
|
105 |
+
|
106 |
+
def __init__(self, args: Arguments):
|
107 |
+
super().__init__()
|
108 |
+
self.args = args
|
109 |
+
# expert_parallel_world_size = mpu.get_expert_parallel_world_size(args)
|
110 |
+
experts_per_rank = mpu.experts_per_rank(args)
|
111 |
+
|
112 |
+
self.w1 = torch.nn.Parameter(
|
113 |
+
torch.empty(
|
114 |
+
experts_per_rank,
|
115 |
+
args.hidden_size,
|
116 |
+
mpu.features_per_rank(args),
|
117 |
+
device=args.device,
|
118 |
+
dtype=common.dtype(args),
|
119 |
+
),
|
120 |
+
)
|
121 |
+
self.w2 = torch.nn.Parameter(
|
122 |
+
torch.empty(
|
123 |
+
experts_per_rank,
|
124 |
+
mpu.features_per_rank(args),
|
125 |
+
args.hidden_size,
|
126 |
+
device=args.device,
|
127 |
+
dtype=common.dtype(args),
|
128 |
+
),
|
129 |
+
)
|
130 |
+
mpu.set_expert_model_parallel_attributes(
|
131 |
+
self.w1,
|
132 |
+
args.moe_expert_model_parallelism,
|
133 |
+
)
|
134 |
+
mpu.set_expert_model_parallel_attributes(
|
135 |
+
self.w2,
|
136 |
+
args.moe_expert_model_parallelism,
|
137 |
+
)
|
138 |
+
|
139 |
+
# Initialize the parameters for the MLP.
|
140 |
+
#
|
141 |
+
# NOTE: It is important that we create the weight tensors prior
|
142 |
+
# to creating the master weights and slicing out the piece for
|
143 |
+
# this rank. If the master weights are created first the PyTorch
|
144 |
+
# caching allocator appears to use the same memory block for these
|
145 |
+
# and the slice which causes large increases in our peak memory
|
146 |
+
# usage.
|
147 |
+
with torch.no_grad():
|
148 |
+
w1 = create_moe_expert_weights(
|
149 |
+
args,
|
150 |
+
args.moe_num_experts,
|
151 |
+
args.ffn_hidden_size,
|
152 |
+
args.hidden_size,
|
153 |
+
args.init_method,
|
154 |
+
)
|
155 |
+
self.w1.copy_(w1.transpose(1, 2).contiguous())
|
156 |
+
self.w2.copy_(
|
157 |
+
create_moe_expert_weights(
|
158 |
+
args,
|
159 |
+
args.moe_num_experts,
|
160 |
+
args.ffn_hidden_size,
|
161 |
+
args.hidden_size,
|
162 |
+
args.output_layer_init_method,
|
163 |
+
),
|
164 |
+
)
|
165 |
+
|
166 |
+
self.gradient_scale = None
|
167 |
+
if self.args.moe_expert_model_parallelism:
|
168 |
+
self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
|
169 |
+
|
170 |
+
def scale_grad(self, w):
|
171 |
+
if self.gradient_scale is None:
|
172 |
+
return w
|
173 |
+
return scale_gradient(w, self.gradient_scale)
|
174 |
+
|
175 |
+
def forward(self, x):
|
176 |
+
w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
|
177 |
+
w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
|
178 |
+
x = torch.bmm(x, w1)
|
179 |
+
x = self.args.activation_fn(x)
|
180 |
+
return torch.bmm(x, w2)
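As a usage sketch (not part of the diff), the batched-GEMM formulation above expects the routed tokens already grouped per expert, i.e. x of shape [experts_per_rank, tokens_per_expert, hidden_size]; the sizes and the GeLU stand-in below are illustrative assumptions.

import torch

# Shapes mirroring MLP.forward: x @ w1 -> activation -> @ w2, all via torch.bmm.
experts, tokens, hidden, ffn = 4, 8, 32, 64
x = torch.randn(experts, tokens, hidden)
w1 = torch.randn(experts, hidden, ffn)   # analogous to self.w1
w2 = torch.randn(experts, ffn, hidden)   # analogous to self.w2

h = torch.bmm(x, w1)
h = torch.nn.functional.gelu(h)          # stand-in for args.activation_fn
out = torch.bmm(h, w2)
assert out.shape == (experts, tokens, hidden)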
|
181 |
+
|
182 |
+
|
183 |
+
def create_dmoe_expert_weights(
|
184 |
+
args: Arguments,
|
185 |
+
num_experts: int,
|
186 |
+
rows: int,
|
187 |
+
columns: int,
|
188 |
+
init_method: InitFn,
|
189 |
+
):
|
190 |
+
weights = create_moe_expert_weights(
|
191 |
+
args,
|
192 |
+
num_experts,
|
193 |
+
rows,
|
194 |
+
columns,
|
195 |
+
init_method,
|
196 |
+
)
|
197 |
+
return weights.view([-1, columns])
|
198 |
+
|
199 |
+
|
200 |
+
class MemoryOptimizedMLP(torch.autograd.Function):
|
201 |
+
"""Sparse MLP with manually scheduled memory reuse."""
|
202 |
+
|
203 |
+
@staticmethod
|
204 |
+
@torch.amp.autocast_mode.custom_fwd(device_type='cuda')
|
205 |
+
def forward(ctx, x, w1, w2, topo, activation_fn):
|
206 |
+
# Cast inputs using ctx dtype from AMP
|
207 |
+
if ctx._fwd_used_autocast:
|
208 |
+
x = x.to(ctx._dtype)
|
209 |
+
w1 = w1.to(ctx._dtype)
|
210 |
+
w2 = w2.to(ctx._dtype)
|
211 |
+
# x: [m, k], w1: [n, k], w2: [n, k]
|
212 |
+
if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
|
213 |
+
raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
|
214 |
+
|
215 |
+
topo_tensors = (
|
216 |
+
topo.row_indices,
|
217 |
+
topo.column_indices,
|
218 |
+
topo.offsets,
|
219 |
+
topo.column_indices_t,
|
220 |
+
topo.offsets_t,
|
221 |
+
topo.block_offsets_t,
|
222 |
+
)
|
223 |
+
|
224 |
+
# Layer 0: x @ w1.t().
|
225 |
+
sdd_out = stk.ops.sdd(x, w1.t(), topo)
|
226 |
+
|
227 |
+
# GeLU.
|
228 |
+
activation_fn_out = act_fn(sdd_out, activation_fn)
|
229 |
+
|
230 |
+
# Layer 1: x @ w2.
|
231 |
+
dsd_out = stk.ops.dsd(activation_fn_out, w2)
|
232 |
+
|
233 |
+
# NOTE: Save the input to the layer and the activation_fn input for
|
234 |
+
# gradient computation. We'll re-compute the activation_fn forward
|
235 |
+
# pass in the backward pass to avoid materializing another
|
236 |
+
# intermediate.
|
237 |
+
ctx.shape = topo.shape
|
238 |
+
ctx.x_shape = x.shape
|
239 |
+
ctx.sdd_out_shape = sdd_out.data.shape
|
240 |
+
ctx.dtype = x.dtype
|
241 |
+
ctx.activation_fn = activation_fn
|
242 |
+
ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
|
243 |
+
return dsd_out
|
244 |
+
|
245 |
+
@staticmethod
|
246 |
+
@torch.amp.autocast_mode.custom_bwd(device_type='cuda')
|
247 |
+
def backward(ctx, ddsd_out):
|
248 |
+
if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
|
249 |
+
raise ValueError('Expected all MLP inputs to need grad.')
|
250 |
+
|
251 |
+
# unpack saved tensors
|
252 |
+
# dtype = ctx.dtype
|
253 |
+
saved_tensors = ctx.saved_tensors
|
254 |
+
w1, w2 = saved_tensors[:2]
|
255 |
+
topo_tensors = saved_tensors[2:8]
|
256 |
+
x = saved_tensors[8]
|
257 |
+
sdd_out_data = saved_tensors[9]
|
258 |
+
|
259 |
+
# rematerialize activation function output
|
260 |
+
activation_fn = ctx.activation_fn
|
261 |
+
sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
|
262 |
+
activation_fn_out, activation_grad_fn = act_fn(
|
263 |
+
sdd_out,
|
264 |
+
activation_fn,
|
265 |
+
return_grad_fn=True,
|
266 |
+
)
|
267 |
+
|
268 |
+
# Compute dw2 with recomputed activation_fn output.
|
269 |
+
dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
|
270 |
+
|
271 |
+
# Compute dactivation_fn_out.
|
272 |
+
#
|
273 |
+
# NOTE: We reuse the activation_fn_out allocation.
|
274 |
+
dactivation_fn_out = activation_fn_out
|
275 |
+
stk.backend.triton_kernels.sdd(
|
276 |
+
ddsd_out,
|
277 |
+
w2.t(),
|
278 |
+
dactivation_fn_out.shape,
|
279 |
+
dactivation_fn_out.data,
|
280 |
+
dactivation_fn_out.offsets,
|
281 |
+
dactivation_fn_out.row_indices,
|
282 |
+
dactivation_fn_out.column_indices,
|
283 |
+
)
|
284 |
+
|
285 |
+
# Compute dsdd_out.
|
286 |
+
#
|
287 |
+
# NOTE: This reuses the dactivation_fn_out allocation.
|
288 |
+
if activation_fn is DEFAULT_ACTIVATION_FN:
|
289 |
+
dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
|
290 |
+
else:
|
291 |
+
assert activation_grad_fn is not None
|
292 |
+
activation_grad_fn(dactivation_fn_out.data)
|
293 |
+
dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
|
294 |
+
|
295 |
+
# Compute dw1.
|
296 |
+
dw1 = stk.ops.dsd(dsdd_out.t(), x)
|
297 |
+
|
298 |
+
# Compute dx.
|
299 |
+
#
|
300 |
+
# NOTE: This reuses the ddsd_out allocation.
|
301 |
+
stk.backend.triton_kernels.dsd(
|
302 |
+
dsdd_out.shape,
|
303 |
+
dsdd_out.data,
|
304 |
+
dsdd_out.offsets,
|
305 |
+
dsdd_out.row_indices,
|
306 |
+
dsdd_out.column_indices,
|
307 |
+
dsdd_out.offsets_t,
|
308 |
+
dsdd_out.column_indices_t,
|
309 |
+
dsdd_out.block_offsets_t,
|
310 |
+
False,
|
311 |
+
w1,
|
312 |
+
ddsd_out,
|
313 |
+
)
|
314 |
+
dx = ddsd_out
|
315 |
+
return dx, dw1, dw2, None, None
|
316 |
+
|
317 |
+
|
318 |
+
memory_optimized_mlp = MemoryOptimizedMLP.apply
|
319 |
+
|
320 |
+
|
321 |
+
class SparseMLP(torch.nn.Module):
|
322 |
+
|
323 |
+
def __init__(self, args: Arguments):
|
324 |
+
super().__init__()
|
325 |
+
self.args = args
|
326 |
+
self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)
|
327 |
+
|
328 |
+
self.w1 = torch.nn.Parameter(
|
329 |
+
torch.empty(
|
330 |
+
self._num_rows_per_rank,
|
331 |
+
args.hidden_size,
|
332 |
+
device=args.device,
|
333 |
+
dtype=common.dtype(args),
|
334 |
+
),
|
335 |
+
)
|
336 |
+
self.w2 = torch.nn.Parameter(
|
337 |
+
torch.empty(
|
338 |
+
self._num_rows_per_rank,
|
339 |
+
args.hidden_size,
|
340 |
+
device=args.device,
|
341 |
+
dtype=common.dtype(args),
|
342 |
+
),
|
343 |
+
)
|
344 |
+
|
345 |
+
# Initialize the parameters for the MLP.
|
346 |
+
#
|
347 |
+
# NOTE: It is important that we create the weight tensors prior
|
348 |
+
# to creating the master weights and slicing out the piece for
|
349 |
+
# this rank. If the master weights are created first the PyTorch
|
350 |
+
# caching allocator appears to use the same memory block for these
|
351 |
+
# and the slice which causes large increases in our peak memory
|
352 |
+
# usage.
|
353 |
+
with torch.no_grad():
|
354 |
+
self.w1.copy_(
|
355 |
+
create_dmoe_expert_weights(
|
356 |
+
args,
|
357 |
+
args.moe_num_experts,
|
358 |
+
args.ffn_hidden_size,
|
359 |
+
args.hidden_size,
|
360 |
+
args.init_method,
|
361 |
+
),
|
362 |
+
)
|
363 |
+
self.w2.copy_(
|
364 |
+
create_dmoe_expert_weights(
|
365 |
+
args,
|
366 |
+
args.moe_num_experts,
|
367 |
+
args.ffn_hidden_size,
|
368 |
+
args.hidden_size,
|
369 |
+
args.output_layer_init_method,
|
370 |
+
),
|
371 |
+
)
|
372 |
+
|
373 |
+
self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
|
374 |
+
mpu.set_expert_model_parallel_attributes(
|
375 |
+
self.w1,
|
376 |
+
self._should_set_parallelism_attribute,
|
377 |
+
)
|
378 |
+
mpu.set_expert_model_parallel_attributes(
|
379 |
+
self.w2,
|
380 |
+
self._should_set_parallelism_attribute,
|
381 |
+
)
|
382 |
+
|
383 |
+
self.gradient_scale = None
|
384 |
+
if self.args.moe_expert_model_parallelism:
|
385 |
+
self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,)
|
386 |
+
|
387 |
+
def scale_grad(self, w):
|
388 |
+
if self.gradient_scale is None:
|
389 |
+
return w
|
390 |
+
return scale_gradient(w, self.gradient_scale)
|
391 |
+
|
392 |
+
def forward(self, x, topo):
|
393 |
+
w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
|
394 |
+
w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
|
395 |
+
if self.args.memory_optimized_mlp:
|
396 |
+
return memory_optimized_mlp(
|
397 |
+
x,
|
398 |
+
w1,
|
399 |
+
w2,
|
400 |
+
topo,
|
401 |
+
self.args.activation_fn,
|
402 |
+
)
|
403 |
+
|
404 |
+
# Compute the MLP.
|
405 |
+
x = stk.ops.sdd(x, w1.t(), topo)
|
406 |
+
activation_fn_out = act_fn(x, self.args.activation_fn)
|
407 |
+
return stk.ops.dsd(activation_fn_out, w2)
|
408 |
+
|
409 |
+
|
410 |
+
class MemoryOptimizedGroupedMLP(torch.autograd.Function):
|
411 |
+
"""GroupedMLP with manually scheduled memory reuse."""
|
412 |
+
|
413 |
+
@staticmethod
|
414 |
+
@torch.amp.autocast_mode.custom_fwd(device_type='cuda')
|
415 |
+
def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
|
416 |
+
# Cast inputs using ctx dtype from AMP
|
417 |
+
if ctx._fwd_used_autocast:
|
418 |
+
x = x.to(ctx._dtype)
|
419 |
+
w1 = w1.to(ctx._dtype)
|
420 |
+
w2 = w2.to(ctx._dtype)
|
421 |
+
# x: [m, k], w1: [n, k], w2: [n, k]
|
422 |
+
if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
|
423 |
+
raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
|
424 |
+
|
425 |
+
# Layer 0: x @ w1.t().
|
426 |
+
assert gg.backend is not None
|
427 |
+
sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
|
428 |
+
|
429 |
+
# activation_fn
|
430 |
+
activation_fn_out = activation_fn(sdd_out)
|
431 |
+
|
432 |
+
# Layer 1: x @ w2.
|
433 |
+
dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
|
434 |
+
|
435 |
+
# NOTE: Save the input to the layer and the activation_fn input for
|
436 |
+
# gradient computation. We'll re-compute the activation_fn forward
|
437 |
+
# pass in the backward pass to avoid materializing another
|
438 |
+
# intermediate.
|
439 |
+
ctx.x_shape = x.shape
|
440 |
+
ctx.sdd_out_shape = sdd_out.shape
|
441 |
+
ctx.dtype = x.dtype
|
442 |
+
ctx.activation_fn = activation_fn
|
443 |
+
ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
|
444 |
+
return dsd_out
|
445 |
+
|
446 |
+
@staticmethod
|
447 |
+
@torch.amp.autocast_mode.custom_bwd(device_type='cuda')
|
448 |
+
def backward(ctx: Any, ddsd_out: torch.Tensor):
|
449 |
+
if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
|
450 |
+
raise ValueError('Expected all MLP inputs to need grad.')
|
451 |
+
|
452 |
+
# Unpack saved tensors
|
453 |
+
# dtype = ctx.dtype
|
454 |
+
saved_tensors = ctx.saved_tensors
|
455 |
+
w1, w2 = saved_tensors[:2]
|
456 |
+
batch_sizes = saved_tensors[2]
|
457 |
+
x = saved_tensors[3]
|
458 |
+
sdd_out = saved_tensors[4]
|
459 |
+
|
460 |
+
# Rematerialize activation_fn output.
|
461 |
+
activation_fn = ctx.activation_fn
|
462 |
+
with torch.set_grad_enabled(True):
|
463 |
+
sdd_out.requires_grad = True
|
464 |
+
activation_fn_out = activation_fn(sdd_out)
|
465 |
+
activation_grad_fn = activation_fn_out.backward
|
466 |
+
|
467 |
+
# Compute dw2 with recomputed activation_fn output.
|
468 |
+
assert gg.backend is not None
|
469 |
+
dw2 = gg.backend.gmm(
|
470 |
+
activation_fn_out,
|
471 |
+
ddsd_out,
|
472 |
+
batch_sizes,
|
473 |
+
trans_a=True,
|
474 |
+
)
|
475 |
+
|
476 |
+
# Compute dactivation_fn_out.
|
477 |
+
#
|
478 |
+
# NOTE: We reuse the activation_fn_out allocation.
|
479 |
+
dactivation_fn_out = activation_fn_out
|
480 |
+
gg.backend.gmm(
|
481 |
+
ddsd_out,
|
482 |
+
w2,
|
483 |
+
batch_sizes,
|
484 |
+
trans_b=True,
|
485 |
+
c=dactivation_fn_out,
|
486 |
+
)
|
487 |
+
|
488 |
+
# Compute dsdd_out.
|
489 |
+
#
|
490 |
+
# NOTE: This reuses the dactivation_fn_out allocation.
|
491 |
+
if activation_fn is DEFAULT_ACTIVATION_FN:
|
492 |
+
dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
|
493 |
+
else:
|
494 |
+
assert activation_grad_fn is not None
|
495 |
+
activation_grad_fn(dactivation_fn_out)
|
496 |
+
dsdd_out = sdd_out.grad
|
497 |
+
|
498 |
+
# Compute dw1.
|
499 |
+
dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
|
500 |
+
|
501 |
+
# Compute dx.
|
502 |
+
#
|
503 |
+
# NOTE: This reuses the ddsd_out allocation.
|
504 |
+
gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
|
505 |
+
dx = ddsd_out
|
506 |
+
return dx, dw1, dw2, None, None
|
507 |
+
|
508 |
+
|
509 |
+
memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
|
510 |
+
|
511 |
+
|
512 |
+
class GroupedMLP(SparseMLP):
|
513 |
+
|
514 |
+
def forward(self, x, tokens_per_expert):
|
515 |
+
batch_sizes = tokens_per_expert.cpu().to(torch.long)
|
516 |
+
w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
|
517 |
+
|
518 |
+
# Re-shape the weights for the grouped GEMMs.
|
519 |
+
ne = mpu.experts_per_rank(self.args)
|
520 |
+
w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
|
521 |
+
w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)
|
522 |
+
|
523 |
+
if self.args.memory_optimized_mlp:
|
524 |
+
return memory_optimized_grouped_mlp(
|
525 |
+
x,
|
526 |
+
w1,
|
527 |
+
w2,
|
528 |
+
batch_sizes,
|
529 |
+
self.args.activation_fn,
|
530 |
+
)
|
531 |
+
|
532 |
+
# Compute the MLP.
|
533 |
+
assert gg.ops is not None
|
534 |
+
x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
|
535 |
+
x = self.args.activation_fn(x)
|
536 |
+
return gg.ops.gmm(x, w2, batch_sizes)
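For readers without the grouped_gemm backend, the call pattern above can be emulated with a plain per-expert loop; the sketch below is only a reference for the semantics (tokens packed contiguously per expert, chunk sizes given by batch_sizes) and not the optimized gg.ops.gmm kernel.

import torch

def gmm_reference(x, w, batch_sizes, trans_b=False):
    # x: [sum(batch_sizes), k]; w: [num_experts, n, k] if trans_b else [num_experts, k, n].
    outs, start = [], 0
    for e, n_tokens in enumerate(batch_sizes.tolist()):
        chunk = x[start:start + n_tokens]
        w_e = w[e].t() if trans_b else w[e]
        outs.append(chunk @ w_e)
        start += n_tokens
    return torch.cat(outs, dim=0)

x = torch.randn(6, 8)                              # 6 routed tokens, hidden_size 8
w1 = torch.randn(2, 16, 8)                         # per-expert [ffn, hidden], used with trans_b=True
batch_sizes = torch.tensor([4, 2], dtype=torch.long)
h = gmm_reference(x, w1, batch_sizes, trans_b=True)
assert h.shape == (6, 16)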
|
537 |
+
|
538 |
+
|
539 |
+
class SharedMLP(torch.nn.Module):
|
540 |
+
"""MLP for shared expert.
|
541 |
+
|
542 |
+
Note: this is a copy -> paste -> modify of the LLM-Foundry MPTMLP class
|
543 |
+
"""
|
544 |
+
|
545 |
+
def __init__(self, args: Arguments):
|
546 |
+
super().__init__()
|
547 |
+
self.args = args
|
548 |
+
self.fc_kwargs: dict[str, Any] = {
|
549 |
+
'bias': args.bias,
|
550 |
+
'device': args.device,
|
551 |
+
}
|
552 |
+
self.fc_kwargs.update(args.fc_kwargs)
|
553 |
+
|
554 |
+
self.up_proj = args.fc_cls(
|
555 |
+
args.hidden_size,
|
556 |
+
args.shared_expert_hidden_size,
|
557 |
+
**self.fc_kwargs,
|
558 |
+
)
|
559 |
+
self.act = args.activation_fn
|
560 |
+
self.down_proj = args.fc_cls(
|
561 |
+
args.shared_expert_hidden_size,
|
562 |
+
args.hidden_size,
|
563 |
+
**self.fc_kwargs,
|
564 |
+
)
|
565 |
+
self.down_proj._is_residual = True # a flag for llm-foundry init
|
566 |
+
|
567 |
+
def add_experts_sharedexpert(
|
568 |
+
self,
|
569 |
+
shared_expert_out: torch.Tensor,
|
570 |
+
expert_out: torch.Tensor,
|
571 |
+
) -> torch.Tensor:
|
572 |
+
# Helper function to add expert output to shared expert output
|
573 |
+
# with optional weighted sum.
|
574 |
+
if self.args.shared_expert_weighted_sum:
|
575 |
+
# enable using weighted sum for shared expert output
|
576 |
+
# weighted by the number of experts used
|
577 |
+
t_experts = self.args.moe_top_k + 1
|
578 |
+
sh_mlp_out = shared_expert_out / t_experts
|
579 |
+
return sh_mlp_out.add(
|
580 |
+
expert_out,
|
581 |
+
alpha=(self.args.moe_top_k / t_experts),
|
582 |
+
)
|
583 |
+
|
584 |
+
return shared_expert_out + expert_out
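A small numeric sketch of the weighted-sum branch above: with moe_top_k routed experts plus one shared expert, the shared output is scaled by 1/(top_k + 1) and the routed output by top_k/(top_k + 1); the tensors are toy values.

import torch

moe_top_k = 2
t_experts = moe_top_k + 1
shared_expert_out = torch.ones(4)
expert_out = torch.full((4,), 2.0)

combined = (shared_expert_out / t_experts).add(expert_out, alpha=moe_top_k / t_experts)
# 1/3 * 1.0 + 2/3 * 2.0 = 5/3 per element.
assert torch.allclose(combined, torch.full((4,), 5.0 / 3.0))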
|
585 |
+
|
586 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
587 |
+
return self.down_proj(self.act(self.up_proj(x)))
|
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/moe.py
ADDED
@@ -0,0 +1,507 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
from typing import Optional, Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
|
9 |
+
# import megablocks.ops as ops
|
10 |
+
# from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry
|
11 |
+
# from megablocks.layers.all_to_all import all_to_all
|
12 |
+
# from megablocks.layers.arguments import Arguments
|
13 |
+
|
14 |
+
from ..ops import (
|
15 |
+
sort,
|
16 |
+
histogram,
|
17 |
+
inclusive_cumsum,
|
18 |
+
exclusive_cumsum,
|
19 |
+
binned_gather,
|
20 |
+
binned_scatter,
|
21 |
+
gather,
|
22 |
+
scatter,
|
23 |
+
repeat,
|
24 |
+
replicate,
|
25 |
+
)
|
26 |
+
|
27 |
+
from . import common, mlp, mpu, router, sharedexpert_registry
|
28 |
+
from .arguments import Arguments
|
29 |
+
from .all_to_all import all_to_all
|
30 |
+
|
31 |
+
_LOAD_BALANCING_LOSS = []
|
32 |
+
|
33 |
+
|
34 |
+
def save_load_balancing_loss(loss):
|
35 |
+
global _LOAD_BALANCING_LOSS
|
36 |
+
_LOAD_BALANCING_LOSS.append(loss)
|
37 |
+
|
38 |
+
|
39 |
+
def get_load_balancing_loss():
|
40 |
+
global _LOAD_BALANCING_LOSS
|
41 |
+
return _LOAD_BALANCING_LOSS
|
42 |
+
|
43 |
+
|
44 |
+
def clear_load_balancing_loss():
|
45 |
+
global _LOAD_BALANCING_LOSS
|
46 |
+
_LOAD_BALANCING_LOSS.clear()
|
47 |
+
|
48 |
+
|
49 |
+
def batched_load_balancing_loss(args: Arguments):
|
50 |
+
if args.moe_loss_weight == 0:
|
51 |
+
return 0.0
|
52 |
+
|
53 |
+
# tokens_per_expert[i].shape = (num_experts)
|
54 |
+
# expert_scores[i].shape = (tokens, num_experts)
|
55 |
+
tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
|
56 |
+
num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size)
|
57 |
+
if args.num_layers_per_virtual_pipeline_stage is not None:
|
58 |
+
num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
|
59 |
+
|
60 |
+
if len(tokens_per_expert) != num_layers_per_pipeline_stage:
|
61 |
+
raise ValueError(
|
62 |
+
f'Expected {num_layers_per_pipeline_stage} tokens_per_expert entries '
|
63 |
+
f'but found {len(tokens_per_expert)}.\nnum_layers = '
|
64 |
+
f'{args.num_layers}\npipeline_model_parallel_size = '
|
65 |
+
f'{args.pipeline_model_parallel_size}\n'
|
66 |
+
'num_layers_per_virtual_pipeline_stage'
|
67 |
+
f' = {args.num_layers_per_virtual_pipeline_stage}',
|
68 |
+
)
|
69 |
+
if len(expert_scores) != num_layers_per_pipeline_stage:
|
70 |
+
raise ValueError(
|
71 |
+
f'Expected {num_layers_per_pipeline_stage} expert_scores '
|
72 |
+
f'but found {len(expert_scores)}.\nnum_layers = '
|
73 |
+
f'{args.num_layers}\npipeline_model_parallel_size = '
|
74 |
+
f'{args.pipeline_model_parallel_size}\n'
|
75 |
+
'num_layers_per_virtual_pipeline_stage'
|
76 |
+
f' = {args.num_layers_per_virtual_pipeline_stage}',
|
77 |
+
)
|
78 |
+
|
79 |
+
# Verify the shape of the tokens_per_expert and expert_scores tensors.
|
80 |
+
assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert))
|
81 |
+
|
82 |
+
tokens = expert_scores[0].shape[0]
|
83 |
+
assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores))
|
84 |
+
|
85 |
+
# Concatenate the contributions of each layer and convert to
|
86 |
+
# the correct types and formats for the dot product.
|
87 |
+
expert_scores = torch.cat(expert_scores, dim=1)
|
88 |
+
if args.moe_lbl_in_fp32:
|
89 |
+
expert_scores = expert_scores.float()
|
90 |
+
if tokens != 0:
|
91 |
+
expert_scores = expert_scores.mean(dim=0)
|
92 |
+
else:
|
93 |
+
expert_scores = expert_scores.sum(dim=0)
|
94 |
+
tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
|
95 |
+
|
96 |
+
expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
|
97 |
+
assert tokens_per_expert.numel() == expected_values
|
98 |
+
assert expert_scores.numel() == expected_values
|
99 |
+
|
100 |
+
# Calculate the total scale across all factors.
|
101 |
+
#
|
102 |
+
# loss_weight * num_experts / (num_layers * tokens * top_k)
|
103 |
+
scale_numerator = (args.moe_num_experts * args.moe_loss_weight)
|
104 |
+
scale_denominator = (args.num_layers * tokens * args.moe_top_k)
|
105 |
+
scale = scale_numerator / scale_denominator
|
106 |
+
return scale * torch.dot(tokens_per_expert, expert_scores)
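The scale can be checked on a toy example; this sketch assumes a single layer per pipeline stage and uses made-up counts and scores, and shows that perfectly balanced routing yields the minimum value, which equals moe_loss_weight itself.

import torch

num_experts, top_k, tokens, num_layers, loss_weight = 4, 1, 8, 1, 0.1
tokens_per_expert = torch.tensor([2., 2., 2., 2.])               # perfectly balanced counts
expert_scores = torch.full((tokens, num_experts), 1.0 / num_experts)

scale = (num_experts * loss_weight) / (num_layers * tokens * top_k)
loss = scale * torch.dot(tokens_per_expert, expert_scores.mean(dim=0))
assert abs(loss.item() - loss_weight) < 1e-6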
|
107 |
+
|
108 |
+
|
109 |
+
# NOTE: This class defines MoE expert computation, including expert model parallel
|
110 |
+
# communication. When using FSDP on top of MegaBlocks this is the module that should
|
111 |
+
# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
|
112 |
+
# parallel all2all.
|
113 |
+
class ParallelMLP(torch.nn.Module):
|
114 |
+
|
115 |
+
def __init__(self, args: Arguments):
|
116 |
+
super(ParallelMLP, self).__init__()
|
117 |
+
self.args = args
|
118 |
+
|
119 |
+
# Calculate the number of experts in total and the number of experts
|
120 |
+
# owned by this rank.
|
121 |
+
# world_size = mpu.get_expert_parallel_world_size(args)
|
122 |
+
self.num_experts = args.moe_num_experts
|
123 |
+
self.top_k = self.args.moe_top_k
|
124 |
+
|
125 |
+
# Calculate the number of bits needed to represent the expert indices
|
126 |
+
# so that we can pass it to radix sort.
|
127 |
+
self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)
|
128 |
+
|
129 |
+
# Expert MLP.
|
130 |
+
self.mlp = mlp.MLP(args)
|
131 |
+
|
132 |
+
self.bias: Optional[torch.Tensor]
|
133 |
+
if self.args.bias:
|
134 |
+
# Note that the output bias is not parallelized with expert
|
135 |
+
# model parallelism.
|
136 |
+
self.bias = torch.nn.Parameter(
|
137 |
+
torch.empty(
|
138 |
+
args.hidden_size,
|
139 |
+
device=args.device,
|
140 |
+
dtype=common.dtype(args),
|
141 |
+
),
|
142 |
+
)
|
143 |
+
torch.nn.init.zeros_(self.bias)
|
144 |
+
else:
|
145 |
+
self.register_parameter('bias', None)
|
146 |
+
|
147 |
+
# Select the forward function for the operating mode.
|
148 |
+
self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once)
|
149 |
+
|
150 |
+
def expert_capacity(self, tokens: int) -> int:
|
151 |
+
world_size = mpu.get_expert_parallel_world_size(self.args)
|
152 |
+
tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts)
|
153 |
+
return int(self.args.moe_capacity_factor * tokens_per_expert)
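A quick arithmetic check of the capacity formula above, with hypothetical sizes:

# Hypothetical sizes mirroring expert_capacity().
top_k, tokens, world_size, num_experts, capacity_factor = 2, 1024, 4, 16, 1.25
tokens_per_expert = top_k * tokens * world_size / num_experts   # 512.0
expert_capacity = int(capacity_factor * tokens_per_expert)      # 640
assert expert_capacity == 640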
|
154 |
+
|
155 |
+
def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor):
|
156 |
+
"""Calculate the load balancing loss contribution."""
|
157 |
+
assert len(expert_scores.size()) == 2
|
158 |
+
tokens, num_experts = expert_scores.size()
|
159 |
+
assert num_experts == self.num_experts
|
160 |
+
assert len(tokens_per_expert.size()) == 1
|
161 |
+
num_experts, = tokens_per_expert.size()
|
162 |
+
assert num_experts == self.num_experts
|
163 |
+
scale = self.num_experts / (tokens * self.top_k)
|
164 |
+
return scale * torch.dot(
|
165 |
+
tokens_per_expert.to(expert_scores.dtype),
|
166 |
+
expert_scores.mean(dim=0),
|
167 |
+
)
|
168 |
+
|
169 |
+
def indices_and_bins(self,
|
170 |
+
top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
171 |
+
# Sort the expert ids to produce the scatter/gather
|
172 |
+
# indices for the permutation.
|
173 |
+
#
|
174 |
+
# TODO(tgale): Is it worth doing this conversion to 32-bit
|
175 |
+
# prior? Could we place the `torch.max` operation to return
|
176 |
+
# 32-bit expert indices?
|
177 |
+
top_expert = top_expert.int()
|
178 |
+
# output = ops.sort(top_expert, self.sort_end_bit)
|
179 |
+
output = sort(top_expert, self.sort_end_bit)
|
180 |
+
assert output is not None
|
181 |
+
bin_ids, indices = output
|
182 |
+
|
183 |
+
# Histogram the expert ids to identify the number of
|
184 |
+
# tokens routed to each expert.
|
185 |
+
#
|
186 |
+
# TODO(tgale): Does the sorted data produce a more favorable
|
187 |
+
# data distribution for histogram? Or is the op parallelism
|
188 |
+
# worth more?
|
189 |
+
# tokens_per_expert = ops.histogram(top_expert, self.num_experts)
|
190 |
+
tokens_per_expert = histogram(top_expert, self.num_experts)
|
191 |
+
|
192 |
+
# Calculate the bin bounds for the sorted tokens.
|
193 |
+
# bins = ops.inclusive_cumsum(tokens_per_expert, 0)
|
194 |
+
bins = inclusive_cumsum(tokens_per_expert, 0)
|
195 |
+
assert bins is not None
|
196 |
+
bins = bins.view(1) if not len(bins.size()) else bins
|
197 |
+
|
198 |
+
assert isinstance(indices, torch.Tensor)
|
199 |
+
assert isinstance(bin_ids, torch.Tensor)
|
200 |
+
assert isinstance(bins, torch.Tensor)
|
201 |
+
assert isinstance(tokens_per_expert, torch.Tensor)
|
202 |
+
|
203 |
+
return indices, bin_ids, bins, tokens_per_expert
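The same bookkeeping can be reproduced with stock PyTorch ops, which may help when reading the custom sort/histogram/cumsum kernels; this is a reference sketch that matches the semantics for small inputs, not a drop-in replacement.

import torch

def indices_and_bins_reference(top_expert, num_experts):
    top_expert = top_expert.long()
    bin_ids, indices = torch.sort(top_expert)                     # expert id per slot, permutation
    tokens_per_expert = torch.bincount(top_expert, minlength=num_experts)
    bins = torch.cumsum(tokens_per_expert, dim=0)                 # inclusive cumsum of counts
    return indices, bin_ids.int(), bins.int(), tokens_per_expert.int()

top_expert = torch.tensor([2, 0, 1, 2, 0])
indices, bin_ids, bins, tpe = indices_and_bins_reference(top_expert, num_experts=4)
assert tpe.tolist() == [2, 1, 2, 0] and bins.tolist() == [2, 3, 5, 5]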
|
204 |
+
|
205 |
+
def permute_and_compute(
|
206 |
+
self,
|
207 |
+
x: torch.Tensor,
|
208 |
+
tokens_per_expert: int, # unused
|
209 |
+
indices: torch.Tensor,
|
210 |
+
bin_ids: torch.Tensor, # unused
|
211 |
+
expert_weights: torch.Tensor,
|
212 |
+
bins: torch.Tensor,
|
213 |
+
expert_capacity: int,
|
214 |
+
top_k: int,
|
215 |
+
):
|
216 |
+
# Route the tokens for MoE computation.
|
217 |
+
x = x.view(-1, x.shape[-1])
|
218 |
+
# output = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
|
219 |
+
output = binned_gather(x, indices, bins, expert_capacity, top_k)
|
220 |
+
assert output is not None
|
221 |
+
x = output
|
222 |
+
|
223 |
+
# Perform the expert computation. Note that we don't
|
224 |
+
# use biases for these linear operations.
|
225 |
+
x = self.mlp(x)
|
226 |
+
|
227 |
+
# Un-route the data for the MoE output.
|
228 |
+
# return ops.binned_scatter(x, indices, expert_weights, bins, top_k)
|
229 |
+
return binned_scatter(x, indices, expert_weights, bins, top_k)
|
230 |
+
|
231 |
+
|
232 |
+
def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
|
233 |
+
# x: [sl, bs, hs]
|
234 |
+
# expert_weights: [sl * bs, top-k]
|
235 |
+
# top_experts: [sl * bs, top-k]
|
236 |
+
expert_weights = expert_weights.flatten()
|
237 |
+
top_experts = top_experts.flatten()
|
238 |
+
with torch.no_grad():
|
239 |
+
indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
|
240 |
+
|
241 |
+
# If expert_capacity is set to zero, set the number of tokens
|
242 |
+
# per expert to the maximum we need to avoid dropping tokens.
|
243 |
+
sl, bs, _ = x.size()
|
244 |
+
expert_capacity = self.expert_capacity(sl * bs)
|
245 |
+
if expert_capacity == 0:
|
246 |
+
expert_capacity = torch.max(tokens_per_expert).item()
|
247 |
+
|
248 |
+
x = self.permute_and_compute(
|
249 |
+
x,
|
250 |
+
tokens_per_expert,
|
251 |
+
indices,
|
252 |
+
bin_ids,
|
253 |
+
expert_weights,
|
254 |
+
bins,
|
255 |
+
expert_capacity,
|
256 |
+
self.top_k,
|
257 |
+
)
|
258 |
+
return x, tokens_per_expert
|
259 |
+
|
260 |
+
def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
|
261 |
+
# NOTE: This function implements the same computation as forward_once
|
262 |
+
# but with expert model parallelism.
|
263 |
+
#
|
264 |
+
# 1. Permute the tokens locally so that they are grouped by their
|
265 |
+
# expert assignments. This allows us to transfer all of the tokens
|
266 |
+
# for a remote device in one communication primitive.
|
267 |
+
#
|
268 |
+
# 2. Permute the tokens across the expert parallel devices. After
|
269 |
+
# this is completed each device has all of the tokens assigned to
|
270 |
+
# its set of experts in its local HBM.
|
271 |
+
#
|
272 |
+
# 3. Permute the tokens locally so that they are grouped by their
|
273 |
+
# expert assignment. After the distributed permutation the tokens
|
274 |
+
# are grouped by which device they came from. We re-order them
|
275 |
+
# locally to allow for efficient computation.
|
276 |
+
#
|
277 |
+
# After this series of permutations we compute the linear layers
|
278 |
+
# and then repeat these three steps in reverse to produce the final
|
279 |
+
# output.
|
280 |
+
#
|
281 |
+
# Compute the mapping of local tokens to experts.
|
282 |
+
expert_weights = expert_weights.flatten()
|
283 |
+
top_experts = top_experts.flatten()
|
284 |
+
with torch.no_grad():
|
285 |
+
indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts))
|
286 |
+
|
287 |
+
# If we're sharding the experts along the hidden dimension
|
288 |
+
# multiple devices own parts of the same sets of experts.
|
289 |
+
# Replicate the token counts so every device gets the counts.
|
290 |
+
# repeated_tokens_per_expert = ops.repeat(
|
291 |
+
repeated_tokens_per_expert = repeat(
|
292 |
+
tokens_per_expert,
|
293 |
+
(mpu.hidden_sharding_degree(self.args),),
|
294 |
+
)
|
295 |
+
|
296 |
+
# Pass token count information to the device on which the
|
297 |
+
# target expert resides.
|
298 |
+
parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,)
|
299 |
+
tpe_handle = dist.all_to_all_single(
|
300 |
+
parallel_tokens_per_expert,
|
301 |
+
repeated_tokens_per_expert,
|
302 |
+
group=self.args.expert_parallel_group,
|
303 |
+
async_op=True,
|
304 |
+
)
|
305 |
+
|
306 |
+
# Permute locally and without any padding so that tokens for each
|
307 |
+
# parallel device are stored contiguously.
|
308 |
+
#
|
309 |
+
# This view updates the shape of the tensor from [sl, bs, hs] to
|
310 |
+
# [sl * bs, hs] prior to the permutation.
|
311 |
+
x = x.view(-1, x.shape[-1])
|
312 |
+
# output = ops.gather(x, indices, bin_ids, bins, self.top_k)
|
313 |
+
output = gather(x, indices, bin_ids, bins, self.top_k)
|
314 |
+
assert output is not None
|
315 |
+
x = output
|
316 |
+
|
317 |
+
# Compute the number of tokens that will be received from each
|
318 |
+
# device and permute the input data across the devices.
|
319 |
+
with torch.no_grad():
|
320 |
+
tpe_handle.wait()
|
321 |
+
experts_per_rank = mpu.experts_per_rank(self.args)
|
322 |
+
|
323 |
+
# Reshape to [world_size, num_experts_per_rank].
|
324 |
+
world_size = mpu.get_expert_parallel_world_size(self.args)
|
325 |
+
repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank))
|
326 |
+
parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank))
|
327 |
+
|
328 |
+
# TODO(tgale): It might be faster to do this on the GPU and
|
329 |
+
# then communicate the results back to the host.
|
330 |
+
send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
|
331 |
+
parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
|
332 |
+
recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
|
333 |
+
|
334 |
+
# Convert the send/recv counts to lists.
|
335 |
+
send_counts = send_counts.tolist()
|
336 |
+
recv_counts = recv_counts.tolist()
|
337 |
+
tokens_received = sum(recv_counts)
|
338 |
+
|
339 |
+
# If we're sharding the experts along the hidden dimension
|
340 |
+
# multiple devices own parts of the same sets of experts.
|
341 |
+
# Replicate the token counts so devices that share experts
|
342 |
+
# get all of the tokens assigned to them.
|
343 |
+
#
|
344 |
+
# TODO(tgale): Fuse this into the prior, local permutation.
|
345 |
+
# x = ops.repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
|
346 |
+
x = repeat(x, (mpu.hidden_sharding_degree(self.args), 1))
|
347 |
+
|
348 |
+
# Start the cross-device permutation asynchronously so we can
|
349 |
+
# overlap communication with computation.
|
350 |
+
parallel_x, parallel_x_handle = all_to_all(
|
351 |
+
x,
|
352 |
+
recv_counts,
|
353 |
+
send_counts,
|
354 |
+
self.args.expert_parallel_group,
|
355 |
+
async_op=True,
|
356 |
+
)
|
357 |
+
|
358 |
+
with torch.no_grad():
|
359 |
+
# After we do the cross-device permutation we have the tokens on the
|
360 |
+
# correct device but not yet grouped by expert because we received
|
361 |
+
# tokens from each device as contiguous chunks. To group the tokens
|
362 |
+
# for expert computation we'll do one more local permutation. The
|
363 |
+
# rest of this torch.no_grad() scope sets up the indices and bins
|
364 |
+
# for this permutation.
|
365 |
+
# replicate_bins = ops.inclusive_cumsum(
|
366 |
+
replicate_bins = inclusive_cumsum(
|
367 |
+
parallel_tokens_per_expert.flatten(),
|
368 |
+
0,
|
369 |
+
)
|
370 |
+
replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins)
|
371 |
+
|
372 |
+
# Construct the expert indices for the permuted tokens.
|
373 |
+
parallel_top_expert = torch.remainder(
|
374 |
+
torch.arange(
|
375 |
+
self.num_experts * mpu.hidden_sharding_degree(self.args),
|
376 |
+
dtype=torch.int32,
|
377 |
+
device=indices.device,
|
378 |
+
),
|
379 |
+
mpu.experts_per_rank(self.args),
|
380 |
+
)
|
381 |
+
# parallel_top_expert = ops.replicate(
|
382 |
+
parallel_top_expert = replicate(
|
383 |
+
parallel_top_expert.unsqueeze(dim=0),
|
384 |
+
replicate_bins,
|
385 |
+
tokens_received,
|
386 |
+
).flatten()
|
387 |
+
|
388 |
+
# TODO(tgale): The sort_end_bit here can be reduced.
|
389 |
+
# parallel_bin_ids, parallel_indices = ops.sort(
|
390 |
+
parallel_bin_ids, parallel_indices = sort(
|
391 |
+
parallel_top_expert,
|
392 |
+
self.sort_end_bit,
|
393 |
+
)
|
394 |
+
|
395 |
+
# Calculate the bins boundaries from the token counts.
|
396 |
+
parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
|
397 |
+
dim=0,
|
398 |
+
dtype=torch.int,
|
399 |
+
)
|
400 |
+
# parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
|
401 |
+
parallel_bins = inclusive_cumsum(parallel_tokens_per_expert, 0)
|
402 |
+
parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins)
|
403 |
+
|
404 |
+
# If expert_capacity is set to zero, set the number of tokens
|
405 |
+
# per expert to the maximum we need to avoid dropping tokens.
|
406 |
+
tokens, _ = x.size()
|
407 |
+
expert_capacity = self.expert_capacity(tokens)
|
408 |
+
if expert_capacity == 0:
|
409 |
+
expert_capacity = torch.max(parallel_tokens_per_expert).item()
|
410 |
+
|
411 |
+
# Locally permute the tokens and perform the expert computation.
|
412 |
+
# Block to make sure that the cross-device permutation is complete.
|
413 |
+
if self.args.mlp_impl == 'grouped':
|
414 |
+
# GroupedMLP requires counts on CPU. We can use the tensor already
|
415 |
+
# moved to CPU for the prior all_to_all, which avoids an extra
|
416 |
+
# device synchronization.
|
417 |
+
parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
|
418 |
+
dim=0,
|
419 |
+
dtype=torch.int,
|
420 |
+
)
|
421 |
+
parallel_x_handle.wait()
|
422 |
+
parallel_x = self.permute_and_compute(
|
423 |
+
parallel_x,
|
424 |
+
parallel_tokens_per_expert,
|
425 |
+
parallel_indices,
|
426 |
+
parallel_bin_ids,
|
427 |
+
None, # expert_weights
|
428 |
+
parallel_bins,
|
429 |
+
expert_capacity,
|
430 |
+
top_k=1,
|
431 |
+
)
|
432 |
+
|
433 |
+
# Un-permute the tokens across the devices.
|
434 |
+
x, _ = all_to_all(
|
435 |
+
parallel_x,
|
436 |
+
send_counts,
|
437 |
+
recv_counts,
|
438 |
+
self.args.expert_parallel_group,
|
439 |
+
)
|
440 |
+
|
441 |
+
# Reduce along the hidden sharding to get the final outputs.
|
442 |
+
#
|
443 |
+
# TODO(tgale): Fuse this into the following local permutation.
|
444 |
+
shape = (
|
445 |
+
mpu.hidden_sharding_degree(self.args),
|
446 |
+
-1,
|
447 |
+
self.args.hidden_size,
|
448 |
+
)
|
449 |
+
# x = ops.sum(x.view(shape), dim=0)
|
450 |
+
x = x.view(shape).sum(dim=0)
|
451 |
+
|
452 |
+
# Un-permute locally to setup for the next series of operations.
|
453 |
+
# x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
|
454 |
+
x = scatter(x, indices, bin_ids, expert_weights, bins, self.top_k)
|
455 |
+
return x, tokens_per_expert.flatten()
|
456 |
+
|
457 |
+
def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor):
|
458 |
+
in_shape = x.size()
|
459 |
+
|
460 |
+
# Compute the experts.
|
461 |
+
x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts)
|
462 |
+
if self.training and self.args.moe_loss_weight > 0:
|
463 |
+
save_load_balancing_loss((tokens_per_expert, scores))
|
464 |
+
x = x.view(in_shape)
|
465 |
+
if self.bias is not None:
|
466 |
+
if self.args.return_bias:
|
467 |
+
return x, self.bias
|
468 |
+
return x + self.bias
|
469 |
+
return x
|
470 |
+
|
471 |
+
|
472 |
+
class MoE(torch.nn.Module):
|
473 |
+
|
474 |
+
def __init__(self, args: Arguments):
|
475 |
+
super(MoE, self).__init__()
|
476 |
+
|
477 |
+
# Token router.
|
478 |
+
self.router = router.LearnedRouter(args)
|
479 |
+
|
480 |
+
# Expert computation helper.
|
481 |
+
self.experts = self._init_experts_mlp(args)
|
482 |
+
|
483 |
+
self.shared_expert = None
|
484 |
+
if args.shared_expert:
|
485 |
+
# SharedExpert computation helper.
|
486 |
+
self.shared_expert = sharedexpert_registry.get(args)
|
487 |
+
|
488 |
+
def _init_experts_mlp(self, args: Arguments):
|
489 |
+
return ParallelMLP(args)
|
490 |
+
|
491 |
+
def forward(self, x: torch.Tensor):
|
492 |
+
# NOTE: If we're going to cast the activations to lower precision
|
493 |
+
# do it before we permute the tokens to save bandwidth.
|
494 |
+
x = common.cast_if_autocast_enabled(x)
|
495 |
+
|
496 |
+
# Compute the expert scores and assignments.
|
497 |
+
scores, expert_weights, top_experts = self.router(x)
|
498 |
+
|
499 |
+
# Compute the experts.
|
500 |
+
out = self.experts(x, scores, expert_weights, top_experts)
|
501 |
+
if self.shared_expert is not None:
|
502 |
+
shared_expert_out = self.shared_expert(x)
|
503 |
+
out = self.shared_expert.add_experts_sharedexpert(
|
504 |
+
shared_expert_out,
|
505 |
+
out,
|
506 |
+
)
|
507 |
+
return out
|
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/mpu.py
ADDED
@@ -0,0 +1,94 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.distributed as dist
|
8 |
+
|
9 |
+
# from megablocks.layers.arguments import Arguments
|
10 |
+
from .arguments import Arguments
|
11 |
+
|
12 |
+
|
13 |
+
class MoeParam(torch.Tensor):
|
14 |
+
|
15 |
+
def __init__(self):
|
16 |
+
super().__init__(self)
|
17 |
+
self.expert_model_parallel: bool
|
18 |
+
|
19 |
+
|
20 |
+
def is_moe_param(tensor: torch.Tensor) -> bool:
|
21 |
+
return hasattr(tensor, 'expert_model_parallel')
|
22 |
+
|
23 |
+
|
24 |
+
def get_expert_parallel_world_size(args: Arguments) -> int:
|
25 |
+
return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1)
|
26 |
+
|
27 |
+
|
28 |
+
def get_expert_parallel_rank(args: Arguments) -> int:
|
29 |
+
return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0)
|
30 |
+
|
31 |
+
|
32 |
+
def set_expert_model_parallel_attributes(
|
33 |
+
tensor: torch.Tensor,
|
34 |
+
is_parallel: bool,
|
35 |
+
):
|
36 |
+
assert not hasattr(tensor, 'expert_model_parallel')
|
37 |
+
setattr(tensor, 'expert_model_parallel', is_parallel)
|
38 |
+
|
39 |
+
|
40 |
+
def param_is_expert_model_parallel(param: MoeParam) -> bool:
|
41 |
+
return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel)
|
42 |
+
|
43 |
+
|
44 |
+
def copy_expert_model_parallel_attributes(
|
45 |
+
destination_tensor: torch.Tensor,
|
46 |
+
source_tensor: torch.Tensor,
|
47 |
+
):
|
48 |
+
if hasattr(source_tensor, 'expert_model_parallel'):
|
49 |
+
setattr(
|
50 |
+
destination_tensor,
|
51 |
+
'expert_model_parallel',
|
52 |
+
getattr(source_tensor, 'expert_model_parallel'),
|
53 |
+
)
|
54 |
+
|
55 |
+
|
56 |
+
def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor):
|
57 |
+
world_size = dist.get_world_size(group)
|
58 |
+
rank = dist.get_rank(group)
|
59 |
+
for i in range(world_size):
|
60 |
+
dist.barrier(group)
|
61 |
+
if i == rank:
|
62 |
+
print(f'rank = {rank}', *x)
|
63 |
+
|
64 |
+
|
65 |
+
# Helpers for expert/tensor sharding.
|
66 |
+
def expert_sharding_degree(args: Arguments) -> int:
|
67 |
+
world_size = get_expert_parallel_world_size(args)
|
68 |
+
esd = min(world_size, args.moe_num_experts)
|
69 |
+
|
70 |
+
if (args.moe_num_experts % esd) != 0:
|
71 |
+
raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',)
|
72 |
+
return esd
|
73 |
+
|
74 |
+
|
75 |
+
def hidden_sharding_degree(args: Arguments) -> int:
|
76 |
+
world_size = get_expert_parallel_world_size(args)
|
77 |
+
esd = expert_sharding_degree(args)
|
78 |
+
hsd = world_size // esd
|
79 |
+
|
80 |
+
if (args.ffn_hidden_size % hsd) != 0:
|
81 |
+
raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',)
|
82 |
+
if (esd * hsd) != world_size:
|
83 |
+
raise ValueError(
|
84 |
+
f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).",
|
85 |
+
)
|
86 |
+
return hsd
|
87 |
+
|
88 |
+
|
89 |
+
def experts_per_rank(args: Arguments) -> int:
|
90 |
+
return args.moe_num_experts // expert_sharding_degree(args)
|
91 |
+
|
92 |
+
|
93 |
+
def features_per_rank(args: Arguments) -> int:
|
94 |
+
return args.ffn_hidden_size // hidden_sharding_degree(args)
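As a worked example of the helpers above, with hypothetical sizes: eight expert-parallel ranks and four experts give an expert sharding degree of four and a hidden sharding degree of two, so each rank owns one expert and half of its ffn rows.

# Hypothetical sizes illustrating the sharding helpers above.
world_size, moe_num_experts, ffn_hidden_size = 8, 4, 4096
esd = min(world_size, moe_num_experts)        # expert_sharding_degree -> 4
hsd = world_size // esd                       # hidden_sharding_degree -> 2
experts_per_rank = moe_num_experts // esd     # -> 1
features_per_rank = ffn_hidden_size // hsd    # -> 2048
assert esd * hsd == world_size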
|
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/router.py
ADDED
@@ -0,0 +1,116 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
# from megablocks.layers import common
|
8 |
+
# from megablocks.layers.arguments import Arguments
|
9 |
+
from . import common
|
10 |
+
from .arguments import Arguments
|
11 |
+
|
12 |
+
_ROUTER_LOGITS = []
|
13 |
+
|
14 |
+
|
15 |
+
def _save_router_logits(logits: torch.Tensor, args: Arguments):
|
16 |
+
if args.moe_zloss_weight == 0:
|
17 |
+
return
|
18 |
+
global _ROUTER_LOGITS
|
19 |
+
_ROUTER_LOGITS.append(logits)
|
20 |
+
|
21 |
+
|
22 |
+
def clear_router_zloss():
|
23 |
+
global _ROUTER_LOGITS
|
24 |
+
_ROUTER_LOGITS.clear()
|
25 |
+
|
26 |
+
|
27 |
+
def batched_router_zloss(args: Arguments):
|
28 |
+
global _ROUTER_LOGITS
|
29 |
+
|
30 |
+
if args.moe_zloss_weight == 0:
|
31 |
+
import warnings
|
32 |
+
warnings.warn('Call to batched_router_zloss, but moe_zloss_weight=0')
|
33 |
+
return 0
|
34 |
+
|
35 |
+
logits_per_router = _ROUTER_LOGITS
|
36 |
+
|
37 |
+
if args.moe_zloss_in_fp32:
|
38 |
+
logits_per_router = [logits.float() for logits in logits_per_router]
|
39 |
+
|
40 |
+
unscaled_zloss_per_router = torch.stack([
|
41 |
+
torch.logsumexp(logits, dim=1).square().mean() for logits in logits_per_router
|
42 |
+
])
|
43 |
+
|
44 |
+
return args.moe_zloss_weight * unscaled_zloss_per_router
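The per-router z-loss above is the mean squared logsumexp of the router logits scaled by moe_zloss_weight; this toy sketch reproduces the computation on random logits with assumed sizes.

import torch

moe_zloss_weight = 0.01
logits = torch.randn(16, 8)                   # [tokens, num_experts], toy values
unscaled = torch.logsumexp(logits, dim=1).square().mean()
zloss = moe_zloss_weight * unscaled
assert zloss.item() >= 0.0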
|
45 |
+
|
46 |
+
|
47 |
+
# NOTE: To enable end-to-end benchmarking without convergence we
|
48 |
+
# support a flag to force the router to assign tokens uniformly
|
49 |
+
# across the experts. We do this with a custom autograd operation
|
50 |
+
# so that PyTorch still executes the full set of router operations.
|
51 |
+
class _UniformExpertAssignment(torch.autograd.Function):
|
52 |
+
|
53 |
+
@staticmethod
|
54 |
+
def forward(ctx: Any, x: torch.Tensor, num_experts: int):
|
55 |
+
out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
|
56 |
+
out = torch.remainder(out, num_experts)
|
57 |
+
return out.view(x.shape)
|
58 |
+
|
59 |
+
|
60 |
+
_uniform_expert_assignment = _UniformExpertAssignment.apply
|
61 |
+
|
62 |
+
|
63 |
+
class LearnedRouter(torch.nn.Module):
|
64 |
+
|
65 |
+
def __init__(self, args: Arguments):
|
66 |
+
super().__init__()
|
67 |
+
self.args = args
|
68 |
+
|
69 |
+
# Learned router parameters.
|
70 |
+
#
|
71 |
+
# NOTE: This weight matrix is not parallelized with expert model
|
72 |
+
# parallelism. Each device needs the entire router weight matrix
|
73 |
+
# so that it can route its batch of data correctly.
|
74 |
+
self.layer = torch.nn.Linear(
|
75 |
+
args.hidden_size,
|
76 |
+
args.moe_num_experts,
|
77 |
+
bias=False,
|
78 |
+
dtype=common.dtype(args),
|
79 |
+
device=args.device,
|
80 |
+
)
|
81 |
+
args.init_method(self.layer.weight)
|
82 |
+
|
83 |
+
def jitter(self, x: torch.Tensor):
|
84 |
+
low: float = 1.0 - self.args.moe_jitter_eps
|
85 |
+
high: float = 1.0 + self.args.moe_jitter_eps
|
86 |
+
noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
|
87 |
+
return low + noise * (high - low)
|
88 |
+
|
89 |
+
def _top_k(self, scores: torch.Tensor):
|
90 |
+
if self.args.moe_top_k == 1:
|
91 |
+
return scores.max(dim=-1, keepdim=True)
|
92 |
+
return torch.topk(scores, self.args.moe_top_k, dim=-1)
|
93 |
+
|
94 |
+
def forward(self, x: torch.Tensor):
|
95 |
+
if self.training and self.args.moe_jitter_eps is not None:
|
96 |
+
x = x * self.jitter(x)
|
97 |
+
|
98 |
+
logits = self.layer(x.view(-1, x.shape[-1]))
|
99 |
+
_save_router_logits(logits, self.args)
|
100 |
+
scores = logits.softmax(dim=-1)
|
101 |
+
expert_weights, expert_indices = self._top_k(scores)
|
102 |
+
if self.args.moe_normalize_expert_weights:
|
103 |
+
expert_weights = expert_weights / torch.norm(
|
104 |
+
expert_weights,
|
105 |
+
p=self.args.moe_normalize_expert_weights,
|
106 |
+
dim=-1,
|
107 |
+
keepdim=True,
|
108 |
+
)
|
109 |
+
|
110 |
+
expert_indices = (
|
111 |
+
_uniform_expert_assignment(
|
112 |
+
expert_indices,
|
113 |
+
self.args.moe_num_experts,
|
114 |
+
) if self.args.uniform_expert_assignment else expert_indices
|
115 |
+
)
|
116 |
+
return scores, expert_weights, expert_indices
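A minimal routing sketch with plain torch, mirroring the softmax-then-top-k flow above (jitter, weight normalization, z-loss bookkeeping and uniform assignment omitted); sizes are illustrative.

import torch

hidden_size, num_experts, top_k = 32, 8, 2
x = torch.randn(4, 6, hidden_size)            # [sl, bs, hs]
layer = torch.nn.Linear(hidden_size, num_experts, bias=False)

logits = layer(x.view(-1, hidden_size))
scores = logits.softmax(dim=-1)
expert_weights, expert_indices = torch.topk(scores, top_k, dim=-1)
assert expert_weights.shape == (24, top_k) and expert_indices.shape == (24, top_k)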
|
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_layers/sharedexpert_registry.py
ADDED
@@ -0,0 +1,32 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
# from megablocks.layers import glu, mlp
|
7 |
+
# from megablocks.layers.arguments import Arguments
|
8 |
+
from . import glu, mlp
|
9 |
+
from .arguments import Arguments
|
10 |
+
|
11 |
+
_REGISTRY = {
|
12 |
+
'mlp': mlp.SharedMLP,
|
13 |
+
'glu': glu.SharedGLU,
|
14 |
+
}
|
15 |
+
|
16 |
+
|
17 |
+
def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]:
|
18 |
+
"""Returns an SharedMLP for use in a dMoE instance.
|
19 |
+
|
20 |
+
Uses the provided arguments to instantiate the appropriate
|
21 |
+
SharedMLP instance.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
args: propagated Arguments dataclass.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
An instantiated SharedMLP constructed using the input args.
|
28 |
+
"""
|
29 |
+
if args.mlp_type not in _REGISTRY:
|
30 |
+
raise ValueError(f'Unsupported mlp type: {args.mlp_type}')
|
31 |
+
|
32 |
+
return _REGISTRY[args.mlp_type](args)
|
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_megablocks_20250730102509.abi3.so
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a19bba459394ac0d93b6405084772af39eb92d4f280f6b5c586d1beeb1589051
|
3 |
+
size 5573536
|
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/_ops.py
ADDED
@@ -0,0 +1,9 @@
1 |
+
import torch
|
2 |
+
from . import _megablocks_20250730102509
|
3 |
+
ops = torch.ops._megablocks_20250730102509
|
4 |
+
|
5 |
+
def add_op_namespace_prefix(op_name: str):
|
6 |
+
"""
|
7 |
+
Prefix op by namespace.
|
8 |
+
"""
|
9 |
+
return f"_megablocks_20250730102509::{op_name}"
|
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/backend/kernels.py
ADDED
@@ -0,0 +1,543 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import triton
|
6 |
+
import triton.language as tl
|
7 |
+
|
8 |
+
|
9 |
+
def assert_is_tensor(x, ndim):
|
10 |
+
if x.ndim != ndim:
|
11 |
+
raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')
|
12 |
+
|
13 |
+
|
14 |
+
def assert_is_matrix(x):
|
15 |
+
assert_is_tensor(x, 2)
|
16 |
+
|
17 |
+
|
18 |
+
def assert_is_vector(x):
|
19 |
+
if x.ndim != 1:
|
20 |
+
raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')
|
21 |
+
|
22 |
+
|
23 |
+
def assert_equal(a, b):
|
24 |
+
if a != b:
|
25 |
+
raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)
|
26 |
+
|
27 |
+
|
28 |
+
# a: (tokens, hidden_size), real.
|
29 |
+
# indices: (tokens * top_k), integer.
|
30 |
+
# bin_ids: (tokens * top_k), integer.
|
31 |
+
# weights: (tokens * top_k), real.
|
32 |
+
# bins: (num_experts), integer.
|
33 |
+
# padded_bins: (num_experts), integer.
|
34 |
+
@triton.autotune(
|
35 |
+
configs=[
|
36 |
+
triton.Config({'BLOCK_X': 64}, num_warps=2),
|
37 |
+
triton.Config({'BLOCK_X': 128}, num_warps=2),
|
38 |
+
triton.Config({'BLOCK_X': 256}, num_warps=2),
|
39 |
+
triton.Config({'BLOCK_X': 128}, num_warps=4),
|
40 |
+
triton.Config({'BLOCK_X': 256}, num_warps=4),
|
41 |
+
],
|
42 |
+
key=['NUM_COLUMNS'],
|
43 |
+
)
|
44 |
+
@triton.jit
|
45 |
+
def _padded_copy(
|
46 |
+
a,
|
47 |
+
b,
|
48 |
+
indices,
|
49 |
+
bin_ids,
|
50 |
+
weights,
|
51 |
+
bins,
|
52 |
+
padded_bins,
|
53 |
+
NUM_COLUMNS: tl.constexpr,
|
54 |
+
TOP_K: tl.constexpr,
|
55 |
+
BLOCK_X: tl.constexpr,
|
56 |
+
A_TO_B: tl.constexpr,
|
57 |
+
SCALE: tl.constexpr,
|
58 |
+
):
|
59 |
+
# Our index into array 'a'.
|
60 |
+
index_a = tl.load(indices + tl.program_id(0))
|
61 |
+
|
62 |
+
# One threadblock per row in 'a'. Array 'b' has a greater or equal
|
63 |
+
# number of rows since it could be padded.
|
64 |
+
bin_idx = tl.load(bin_ids + tl.program_id(0))
|
65 |
+
|
66 |
+
# Now we know what bin we're assigned to, but we need to know how
|
67 |
+
# many threadblocks were assigned to earlier bins so we can offset
|
68 |
+
# in our bin properly.
|
69 |
+
offset_in_bin = tl.program_id(0)
|
70 |
+
if bin_idx > 0:
|
71 |
+
offset_in_bin -= tl.load(bins + bin_idx - 1)
|
72 |
+
|
73 |
+
# Load the starting index of our bin in array 'b'.
|
74 |
+
index_b = offset_in_bin
|
75 |
+
if bin_idx > 0:
|
76 |
+
index_b += tl.load(padded_bins + bin_idx - 1)
|
77 |
+
|
78 |
+
# Offset the input and output pointers.
|
79 |
+
#
|
80 |
+
# If we're going from A to B, divide the input index to copy
|
81 |
+
# the same input repeatedly. If we're going from B to A we
|
82 |
+
# need to reduce the result. Using atomics is slow, so we
|
83 |
+
# do the reduce step in a second kernel.
|
84 |
+
offset = index_a // TOP_K if A_TO_B else index_a
|
85 |
+
a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
|
86 |
+
b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
|
87 |
+
offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
|
88 |
+
|
89 |
+
# Load the scale, if requested.
|
90 |
+
scale = tl.load(weights + index_a) if SCALE else 1
|
91 |
+
|
92 |
+
# Swap the pointers depending on the direction.
|
93 |
+
iptr = a if A_TO_B else b
|
94 |
+
optr = b if A_TO_B else a
|
95 |
+
|
96 |
+
iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
|
97 |
+
for _ in range(iterations):
|
98 |
+
mask = offsets < NUM_COLUMNS
|
99 |
+
x = tl.load(iptr + offsets, mask=mask)
|
100 |
+
x = x.to(tl.float32) * scale.to(tl.float32)
|
101 |
+
|
102 |
+
tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
|
103 |
+
|
104 |
+
offsets += BLOCK_X
|
105 |
+
|
106 |
+
|
107 |
+
def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
|
108 |
+
# Validate the input shapes.
|
109 |
+
assert_is_matrix(x)
|
110 |
+
assert_is_vector(indices)
|
111 |
+
assert_is_vector(bin_ids)
|
112 |
+
assert_is_vector(bins)
|
113 |
+
assert_is_vector(padded_bins)
|
114 |
+
assert_equal(indices.shape[0], x.shape[0] * top_k)
|
115 |
+
assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
|
116 |
+
assert_equal(bins.size(), padded_bins.size())
|
117 |
+
|
118 |
+
if weights is not None:
|
119 |
+
assert_equal(weights.shape[0], x.shape[0] * top_k)
|
120 |
+
|
121 |
+
# NOTE: Because of the padding, the output size is dynamic.
|
122 |
+
# We load the final padded bin bound to get the output rows.
|
123 |
+
output_rows = padded_bins[-1].cpu().item()
|
124 |
+
out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
|
125 |
+
_padded_copy[(indices.shape[0],)](
|
126 |
+
x,
|
127 |
+
out,
|
128 |
+
indices,
|
129 |
+
bin_ids,
|
130 |
+
weights,
|
131 |
+
bins,
|
132 |
+
padded_bins,
|
133 |
+
NUM_COLUMNS=x.shape[1],
|
134 |
+
A_TO_B=True,
|
135 |
+
TOP_K=top_k,
|
136 |
+
SCALE=weights is not None,
|
137 |
+
)
|
138 |
+
return out
|
139 |
+
|
140 |
+
|
141 |
+
def gather(x, indices, bin_ids, weights, bins, top_k):
|
142 |
+
# Validate the input shapes.
|
143 |
+
assert_is_matrix(x)
|
144 |
+
assert_is_vector(indices)
|
145 |
+
assert_is_vector(bin_ids)
|
146 |
+
assert_is_vector(bins)
|
147 |
+
assert_equal(indices.shape[0], x.shape[0] * top_k)
|
148 |
+
assert_equal(bin_ids.shape[0], x.shape[0] * top_k)
|
149 |
+
|
150 |
+
if weights is not None:
|
151 |
+
assert_equal(weights.shape[0], x.shape[0] * top_k)
|
152 |
+
|
153 |
+
# NOTE: There is no padding so the output rows equals the
|
154 |
+
# input rows multiplied by top_k.
|
155 |
+
output_rows = x.shape[0] * top_k
|
156 |
+
out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
|
157 |
+
_padded_copy[(indices.shape[0],)](
|
158 |
+
x,
|
159 |
+
out,
|
160 |
+
indices,
|
161 |
+
bin_ids,
|
162 |
+
weights,
|
163 |
+
bins,
|
164 |
+
bins,
|
165 |
+
NUM_COLUMNS=x.shape[1],
|
166 |
+
A_TO_B=True,
|
167 |
+
TOP_K=top_k,
|
168 |
+
SCALE=weights is not None,
|
169 |
+
)
|
170 |
+
return out
|
171 |
+
|
172 |
+
|
173 |
+
def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
|
174 |
+
# Validate the input shapes.
|
175 |
+
assert_is_matrix(x)
|
176 |
+
assert_is_vector(indices)
|
177 |
+
assert_is_vector(bin_ids)
|
178 |
+
assert_is_vector(bins)
|
179 |
+
assert_is_vector(padded_bins)
|
180 |
+
assert_equal(indices.shape[0], bin_ids.shape[0])
|
181 |
+
assert_equal(bins.size(), padded_bins.size())
|
182 |
+
|
183 |
+
if weights is not None:
|
184 |
+
assert_equal(indices.shape[0], weights.shape[0])
|
185 |
+
|
186 |
+
tokens = indices.shape[0] // top_k
|
187 |
+
out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device)
|
188 |
+
_padded_copy[(indices.shape[0],)](
|
189 |
+
out,
|
190 |
+
x,
|
191 |
+
indices,
|
192 |
+
bin_ids,
|
193 |
+
weights,
|
194 |
+
bins,
|
195 |
+
padded_bins,
|
196 |
+
NUM_COLUMNS=x.shape[1],
|
197 |
+
A_TO_B=False,
|
198 |
+
TOP_K=top_k,
|
199 |
+
SCALE=weights is not None,
|
200 |
+
)
|
201 |
+
|
202 |
+
# Reduce along the top-k dimension, if needed.
|
203 |
+
return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])
|
204 |
+
|
205 |
+
|
206 |
+
def scatter(x, indices, bin_ids, weights, bins, top_k):
|
207 |
+
return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)
|
208 |
+
|
209 |
+
|
210 |
+
# x: (tokens, top_k, hidden_size), real
|
211 |
+
# grad: (tokens, hidden_size), real.
|
212 |
+
# wgrad: (tokens, top_k), real.
|
213 |
+
# indices: (tokens * top_k), integer.
|
214 |
+
# bin_ids: (tokens * top_k), integer.
|
215 |
+
# bins: (num_experts), integer.
|
216 |
+
# padded_bins: (num_experts), integer.
|
217 |
+
@triton.autotune(
|
218 |
+
configs=[
|
219 |
+
triton.Config({'BLOCK_X': 64}, num_warps=2),
|
220 |
+
triton.Config({'BLOCK_X': 128}, num_warps=2),
|
221 |
+
triton.Config({'BLOCK_X': 256}, num_warps=2),
|
222 |
+
triton.Config({'BLOCK_X': 128}, num_warps=4),
|
223 |
+
triton.Config({'BLOCK_X': 256}, num_warps=4),
|
224 |
+
],
|
225 |
+
key=['NUM_COLUMNS'],
|
226 |
+
)
|
227 |
+
@triton.jit
|
228 |
+
def _padded_copy_wgrad(
|
229 |
+
x,
|
230 |
+
grad,
|
231 |
+
wgrad,
|
232 |
+
indices,
|
233 |
+
bin_ids,
|
234 |
+
bins,
|
235 |
+
padded_bins,
|
236 |
+
NUM_COLUMNS: tl.constexpr,
|
237 |
+
TOP_K: tl.constexpr,
|
238 |
+
BLOCK_X: tl.constexpr,
|
239 |
+
):
|
240 |
+
# Our index into 'tokens * top_k'.
|
241 |
+
index_out = tl.load(indices + tl.program_id(0))
|
242 |
+
|
243 |
+
# One threadblock per row in 'a'. Array 'b' has greater or equal
|
244 |
+
# number of rows since they could be padded.
|
245 |
+
bin_idx = tl.load(bin_ids + tl.program_id(0))
|
246 |
+
|
247 |
+
# Now we know what bin we're assigned to, but we need to know how
|
248 |
+
# many threadblocks were assigned to earlier bins so we can offset
|
249 |
+
# in our bin properly.
|
250 |
+
offset_in_bin = tl.program_id(0)
|
251 |
+
if bin_idx > 0:
|
252 |
+
offset_in_bin -= tl.load(bins + bin_idx - 1)
|
253 |
+
|
254 |
+
# Load the starting index of our bin in array 'x'.
|
255 |
+
index_x = offset_in_bin
|
256 |
+
if bin_idx > 0:
|
257 |
+
index_x += tl.load(padded_bins + bin_idx - 1)
|
258 |
+
|
259 |
+
# Offset the input and output pointers.
|
260 |
+
wgrad += index_out
|
261 |
+
grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
|
262 |
+
x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
|
263 |
+
offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
|
264 |
+
|
265 |
+
acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
|
266 |
+
iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
|
267 |
+
for _ in range(iterations):
|
268 |
+
mask = offsets < NUM_COLUMNS
|
269 |
+
data = tl.load(x + offsets, mask=mask).to(tl.float32)
|
270 |
+
scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
|
271 |
+
acc += data * scale
|
272 |
+
offsets += BLOCK_X
|
273 |
+
|
274 |
+
# Reduce to get the final result and store.
|
275 |
+
out = tl.sum(acc).to(wgrad.dtype.element_ty)
|
276 |
+
tl.store(wgrad, out)
|
277 |
+
|
278 |
+
|
279 |
+
def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
|
280 |
+
# Validate the input shapes.
|
281 |
+
assert_is_matrix(x)
|
282 |
+
assert_is_matrix(grad)
|
283 |
+
assert_is_vector(indices)
|
284 |
+
assert_is_vector(bin_ids)
|
285 |
+
assert_is_vector(bins)
|
286 |
+
assert_is_vector(padded_bins)
|
287 |
+
assert_equal(indices.shape[0], bin_ids.shape[0])
|
288 |
+
assert_equal(bins.size(), padded_bins.size())
|
289 |
+
|
290 |
+
tokens = indices.shape[0] // top_k
|
291 |
+
out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device)
|
292 |
+
_padded_copy_wgrad[(indices.shape[0],)](
|
293 |
+
x,
|
294 |
+
grad,
|
295 |
+
out,
|
296 |
+
indices,
|
297 |
+
bin_ids,
|
298 |
+
bins,
|
299 |
+
padded_bins,
|
300 |
+
NUM_COLUMNS=x.shape[1],
|
301 |
+
TOP_K=top_k,
|
302 |
+
)
|
303 |
+
return out
|
304 |
+
|
305 |
+
|
306 |
+
def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
|
307 |
+
return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)
|
308 |
+
|
309 |
+
|
310 |
+
# a: (tokens, hidden_size), real.
|
311 |
+
# b: (num_experts, expert_capacity, num_columns), real.
|
312 |
+
# indices: (tokens * top_k), integer.
|
313 |
+
# weights: (tokens * top_k), real.
|
314 |
+
# bins: (num_experts), integer.
|
315 |
+
@triton.autotune(
|
316 |
+
configs=[
|
317 |
+
triton.Config({'BLOCK_X': 64}, num_warps=2),
|
318 |
+
triton.Config({'BLOCK_X': 128}, num_warps=2),
|
319 |
+
triton.Config({'BLOCK_X': 256}, num_warps=2),
|
320 |
+
triton.Config({'BLOCK_X': 128}, num_warps=4),
|
321 |
+
triton.Config({'BLOCK_X': 256}, num_warps=4),
|
322 |
+
],
|
323 |
+
key=['NUM_COLUMNS'],
|
324 |
+
)
|
325 |
+
@triton.jit
|
326 |
+
def _binned_copy(
|
327 |
+
a,
|
328 |
+
b,
|
329 |
+
num_experts,
|
330 |
+
expert_capacity,
|
331 |
+
indices,
|
332 |
+
weights,
|
333 |
+
bins,
|
334 |
+
NUM_COLUMNS: tl.constexpr,
|
335 |
+
TOP_K: tl.constexpr,
|
336 |
+
BLOCK_X: tl.constexpr,
|
337 |
+
A_TO_B: tl.constexpr,
|
338 |
+
SCALE: tl.constexpr,
|
339 |
+
):
|
340 |
+
# Load our indices into the output.
|
341 |
+
expert_idx = tl.program_id(0)
|
342 |
+
entry_idx = tl.program_id(1)
|
343 |
+
|
344 |
+
# Calculate our offset into the output.
|
345 |
+
index_b = expert_idx * expert_capacity + entry_idx
|
346 |
+
|
347 |
+
# Load the index bounds for our bin and calculate
|
348 |
+
# the number of tokens assigned to our expert.
|
349 |
+
start = 0
|
350 |
+
if expert_idx > 0:
|
351 |
+
start = tl.load(bins + expert_idx - 1)
|
352 |
+
end = tl.load(bins + expert_idx)
|
353 |
+
num_tokens = end - start
|
354 |
+
|
355 |
+
# Calculate our offset into the input. If we don't
|
356 |
+
# have an input exit early.
|
357 |
+
if entry_idx >= num_tokens:
|
358 |
+
return
|
359 |
+
index_a = tl.load(indices + start + entry_idx)
|
360 |
+
|
361 |
+
# Offset the input and output pointers.
|
362 |
+
#
|
363 |
+
# If we're going from A to B, divide the input index to copy
|
364 |
+
# the same input repeatedly. If we're going from B to A we
|
365 |
+
# need to reduce the result. Using atomics is slow, so we
|
366 |
+
# do the reduce step in a second kernel.
|
367 |
+
offset = index_a // TOP_K if A_TO_B else index_a
|
368 |
+
a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
|
369 |
+
b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
|
370 |
+
offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
|
371 |
+
|
372 |
+
# Load the scale, if requested.
|
373 |
+
scale = tl.load(weights + index_a) if SCALE else 1
|
374 |
+
|
375 |
+
# Swap the pointers depending on the direction.
|
376 |
+
#
|
377 |
+
# NOTE: We need to zero the output in both directions.
|
378 |
+
iptr = a if A_TO_B else b
|
379 |
+
optr = b if A_TO_B else a
|
380 |
+
|
381 |
+
iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
|
382 |
+
for _ in range(iterations):
|
383 |
+
mask = offsets < NUM_COLUMNS
|
384 |
+
x = tl.load(iptr + offsets, mask=mask)
|
385 |
+
x = x.to(tl.float32) * scale.to(tl.float32)
|
386 |
+
|
387 |
+
tl.store(optr + offsets, x.to(optr.dtype.element_ty), mask=mask)
|
388 |
+
|
389 |
+
offsets += BLOCK_X
|
390 |
+
|
391 |
+
|
392 |
+
def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
|
393 |
+
# Validate the input shapes.
|
394 |
+
assert_is_matrix(x)
|
395 |
+
assert_is_vector(indices)
|
396 |
+
assert_is_vector(bins)
|
397 |
+
assert_equal(indices.shape[0], x.shape[0] * top_k)
|
398 |
+
|
399 |
+
if weights is not None:
|
400 |
+
assert_equal(weights.shape[0], x.shape[0] * top_k)
|
401 |
+
|
402 |
+
num_experts = bins.shape[0]
|
403 |
+
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)
|
404 |
+
|
405 |
+
_binned_copy[(num_experts, expert_capacity)](
|
406 |
+
x,
|
407 |
+
out,
|
408 |
+
num_experts,
|
409 |
+
expert_capacity,
|
410 |
+
indices,
|
411 |
+
weights,
|
412 |
+
bins,
|
413 |
+
NUM_COLUMNS=x.shape[1],
|
414 |
+
A_TO_B=True,
|
415 |
+
TOP_K=top_k,
|
416 |
+
SCALE=weights is not None,
|
417 |
+
)
|
418 |
+
return out
|
419 |
+
|
420 |
+
|
421 |
+
def binned_scatter(x, indices, weights, bins, top_k):
|
422 |
+
# Validate the input shapes.
|
423 |
+
assert_is_tensor(x, 3)
|
424 |
+
assert_is_vector(indices)
|
425 |
+
assert_is_vector(bins)
|
426 |
+
assert_equal(bins.shape[0], x.shape[0])
|
427 |
+
|
428 |
+
if weights is not None:
|
429 |
+
assert_equal(indices.shape[0], weights.shape[0])
|
430 |
+
|
431 |
+
num_experts, expert_capacity, hidden_size = x.shape
|
432 |
+
tokens = indices.shape[0] // top_k
|
433 |
+
out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device)
|
434 |
+
_binned_copy[(num_experts, expert_capacity)](
|
435 |
+
out,
|
436 |
+
x,
|
437 |
+
num_experts,
|
438 |
+
expert_capacity,
|
439 |
+
indices,
|
440 |
+
weights,
|
441 |
+
bins,
|
442 |
+
NUM_COLUMNS=hidden_size,
|
443 |
+
A_TO_B=False,
|
444 |
+
TOP_K=top_k,
|
445 |
+
SCALE=weights is not None,
|
446 |
+
)
|
447 |
+
|
448 |
+
# Reduce along the top-k dimension, if needed.
|
449 |
+
return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size)
|
450 |
+
|
451 |
+
|
452 |
+
# a: (tokens, hidden_size), real.
|
453 |
+
# b: (num_experts, expert_capacity, num_columns), real.
|
454 |
+
# indices: (tokens * top_k), integer.
|
455 |
+
# weights: (tokens * top_k), real.
|
456 |
+
# bins: (num_experts), integer.
|
457 |
+
@triton.autotune(
|
458 |
+
configs=[
|
459 |
+
triton.Config({'BLOCK_X': 64}, num_warps=2),
|
460 |
+
triton.Config({'BLOCK_X': 128}, num_warps=2),
|
461 |
+
triton.Config({'BLOCK_X': 256}, num_warps=2),
|
462 |
+
triton.Config({'BLOCK_X': 128}, num_warps=4),
|
463 |
+
triton.Config({'BLOCK_X': 256}, num_warps=4),
|
464 |
+
],
|
465 |
+
key=['NUM_COLUMNS'],
|
466 |
+
)
|
467 |
+
@triton.jit
|
468 |
+
def _binned_copy_wgrad(
|
469 |
+
x,
|
470 |
+
grad,
|
471 |
+
wgrad,
|
472 |
+
num_experts,
|
473 |
+
expert_capacity,
|
474 |
+
indices,
|
475 |
+
bins,
|
476 |
+
NUM_COLUMNS: tl.constexpr,
|
477 |
+
TOP_K: tl.constexpr,
|
478 |
+
BLOCK_X: tl.constexpr,
|
479 |
+
):
|
480 |
+
# Load our indices into the output.
|
481 |
+
expert_idx = tl.program_id(0)
|
482 |
+
entry_idx = tl.program_id(1)
|
483 |
+
|
484 |
+
# Calculate our offset into the output.
|
485 |
+
index_x = expert_idx * expert_capacity + entry_idx
|
486 |
+
|
487 |
+
# Load the index bounds for our bin and calculate
|
488 |
+
# the number of tokens assigned to our expert.
|
489 |
+
start = 0
|
490 |
+
if expert_idx > 0:
|
491 |
+
start = tl.load(bins + expert_idx - 1)
|
492 |
+
end = tl.load(bins + expert_idx)
|
493 |
+
num_tokens = end - start
|
494 |
+
|
495 |
+
# Calculate our offset into the input. If we don't
|
496 |
+
# have an input exit early.
|
497 |
+
if entry_idx >= num_tokens:
|
498 |
+
return
|
499 |
+
index_out = tl.load(indices + start + entry_idx)
|
500 |
+
|
501 |
+
# Offset the input and output pointers.
|
502 |
+
wgrad += index_out
|
503 |
+
grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
|
504 |
+
x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
|
505 |
+
offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
|
506 |
+
|
507 |
+
acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
|
508 |
+
iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X)
|
509 |
+
for _ in range(iterations):
|
510 |
+
mask = offsets < NUM_COLUMNS
|
511 |
+
data = tl.load(x + offsets, mask=mask).to(tl.float32)
|
512 |
+
scale = tl.load(grad + offsets, mask=mask).to(tl.float32)
|
513 |
+
acc += data * scale
|
514 |
+
offsets += BLOCK_X
|
515 |
+
|
516 |
+
# Reduce to get the final result and store.
|
517 |
+
out = tl.sum(acc).to(wgrad.dtype.element_ty)
|
518 |
+
tl.store(wgrad, out)
|
519 |
+
|
520 |
+
|
521 |
+
def binned_scatter_wgrad(x, grad, indices, bins, top_k):
|
522 |
+
# Validate the input shapes.
|
523 |
+
assert_is_tensor(x, 3)
|
524 |
+
assert_is_matrix(grad)
|
525 |
+
assert_is_vector(indices)
|
526 |
+
assert_is_vector(bins)
|
527 |
+
assert_equal(bins.shape[0], x.shape[0])
|
528 |
+
|
529 |
+
num_experts, expert_capacity, hidden_size = x.shape
|
530 |
+
tokens = indices.shape[0] // top_k
|
531 |
+
out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device)
|
532 |
+
_binned_copy_wgrad[(num_experts, expert_capacity)](
|
533 |
+
x,
|
534 |
+
grad,
|
535 |
+
out,
|
536 |
+
num_experts,
|
537 |
+
expert_capacity,
|
538 |
+
indices,
|
539 |
+
bins,
|
540 |
+
NUM_COLUMNS=hidden_size,
|
541 |
+
TOP_K=top_k,
|
542 |
+
)
|
543 |
+
return out
|
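
For orientation, a minimal sketch of how these gather/scatter kernels can be driven end to end. It is not part of the diff: it assumes a CUDA device, top_k = 1, and the `megablocks.backend.kernels` module path of this build, and it builds `bins` with plain torch ops instead of the `megablocks.ops` sort/histogram/cumsum wrappers.

import torch
from megablocks.backend import kernels  # assumed import path for this build

tokens, hidden, num_experts, top_k = 8, 16, 4, 1
x = torch.randn(tokens, hidden, device='cuda', dtype=torch.float16)
top_expert = torch.randint(0, num_experts, (tokens * top_k,), device='cuda')

# Sort token slots by expert so each expert's rows become contiguous.
bin_ids, indices = torch.sort(top_expert)
bin_ids, indices = bin_ids.int(), indices.int()

# Inclusive cumsum of tokens-per-expert gives the bin boundaries.
bins = torch.bincount(top_expert, minlength=num_experts).cumsum(0).int()

# Group rows by expert, then scatter them back (weights=None disables scaling).
grouped = kernels.gather(x, indices, bin_ids, None, bins, top_k)
restored = kernels.scatter(grouped, indices, bin_ids, None, bins, top_k)
assert torch.allclose(restored, x)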
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/bak.__init__.py
ADDED
@@ -0,0 +1,23 @@
from megablocks_moe.megablocks import (
    MoE,
    dMoE,
    get_load_balancing_loss,
    ParallelMLP,
    ParallelDroplessMLP,
    SparseMLP,
    MLP,
    SparseGLU,
    Arguments,
)

__all__ = [
    "MoE",
    "dMoE",
    "get_load_balancing_loss",
    "ParallelMLP",
    "ParallelDroplessMLP",
    "SparseMLP",
    "MLP",
    "SparseGLU",
    "Arguments",
]
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/benchmark_util.py
ADDED
@@ -0,0 +1,35 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch


def log_benchmark(name, arguments, time, std):
    print('=' * 60)
    print(f'{name} Benchmark')
    print('Benchmark Parameters:')
    for (key, value) in arguments.items():
        print(f'{key} = {value}')
    print('Results:')
    print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std))
    print('=' * 60)


def benchmark_function(fn, iterations=100, warmup=10):
    # Warmup iterations.
    for _ in range(warmup):
        fn()

    times = []
    for i in range(iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        fn()
        end.record()

        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    return np.mean(times), np.std(times)
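
A small usage sketch for these helpers, not part of the diff; the matmul workload and sizes are illustrative, a CUDA device is assumed, and the `megablocks.benchmark_util` import path is an assumption about this build.

import torch
from megablocks.benchmark_util import benchmark_function, log_benchmark  # assumed import path

a = torch.randn(4096, 4096, device='cuda')
b = torch.randn(4096, 4096, device='cuda')

# Time 100 measured iterations after 10 warmup calls, then pretty-print the result.
mean_ms, std_ms = benchmark_function(lambda: torch.matmul(a, b), iterations=100, warmup=10)
log_benchmark('matmul', {'m': 4096, 'n': 4096, 'k': 4096}, mean_ms, std_ms)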
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/__init__.py
ADDED
@@ -0,0 +1,2 @@
from . import ops
from . import backend
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/backend.py
ADDED
@@ -0,0 +1,33 @@
# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# # TODO(tgale): Wrap this in a try-block with better
# # error message and instructions for building the
# # c++ operations.
# import grouped_gemm_backend as backend

# We import the backend operations from the megablocks package as
# grouped_gemm is vendored in megablocks in this repository.
# from ... import _ops as backend
# from megablocks._ops import ops as backend  # type: ignore
from .._ops import ops as backend  # type: ignore

def _allocate_output(a, b, batch_sizes, trans_a, trans_b):
    assert not (trans_a and trans_b)
    assert batch_sizes.ndim == 1, "Expected 1d tensor for batch_sizes"
    assert a.ndim == 2, "Expected 2d tensor for 'a'"
    assert b.ndim == (2 if trans_a else 3)

    shape = (
        (batch_sizes.shape[0], a.shape[1], b.shape[1])
        if trans_a else
        (a.shape[0], (b.shape[1] if trans_b else b.shape[2]))
    )
    return torch.empty(*shape, device=a.device, dtype=a.dtype)

def gmm(a, b, batch_sizes, trans_a=False, trans_b=False, c=None):
    if c is None:
        c = _allocate_output(a, b, batch_sizes, trans_a, trans_b)
    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
    return c
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm/ops.py
ADDED
@@ -0,0 +1,33 @@
from . import backend
import torch


class GroupedGemm(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, b, batch_sizes, trans_b):
        ctx.save_for_backward(a, b, batch_sizes)
        ctx.trans_b = trans_b
        return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)

    @staticmethod
    def backward(ctx, grad):
        grad = grad.contiguous()
        a, b, batch_sizes = ctx.saved_tensors
        trans_b = ctx.trans_b

        agrad = None
        if ctx.needs_input_grad[0]:
            agrad = backend.gmm(
                grad, b, batch_sizes, trans_a=False, trans_b=not trans_b)

        bgrad = None
        if ctx.needs_input_grad[1]:
            lhs, rhs = (grad, a) if trans_b else (a, grad)
            bgrad = backend.gmm(
                lhs, rhs, batch_sizes, trans_a=True, trans_b=False)
        return agrad, bgrad, None, None


def gmm(a, b, batch_sizes, trans_b=False):
    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
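
A hedged sketch of calling the vendored grouped GEMM through this autograd wrapper. The output shape follows `_allocate_output` in backend.py above; the bfloat16 dtype and the CPU placement of `batch_sizes` are assumptions about what the underlying kernel accepts, not something this diff states.

import torch
from megablocks.grouped_gemm import ops  # assumed import path for this build

num_groups, k, n = 4, 64, 128
batch_sizes = torch.tensor([3, 5, 2, 6])  # rows of `a` owned by each group (assumed to stay on CPU)
a = torch.randn(int(batch_sizes.sum()), k, device='cuda', dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(num_groups, k, n, device='cuda', dtype=torch.bfloat16, requires_grad=True)

c = ops.gmm(a, b, batch_sizes)  # one GEMM per group -> (16, 128)
c.sum().backward()              # gradients flow through GroupedGemm.backward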
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/grouped_gemm_util.py
ADDED
@@ -0,0 +1,31 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import warnings

_grouped_gemm_is_available: bool = False
try:
    # import grouped_gemm
    pass
    _grouped_gemm_is_available = True
except ImportError as error:
    warnings.warn('Grouped GEMM not available.')


def grouped_gemm_is_available():
    return _grouped_gemm_is_available


def assert_grouped_gemm_is_available():
    msg = (
        'Grouped GEMM not available. Please run '
        '`pip install git+https://github.com/tgale96/grouped_gemm@main`.',
    )
    assert _grouped_gemm_is_available, msg


# backend = grouped_gemm.backend if grouped_gemm_is_available() else None
# ops = grouped_gemm.ops if grouped_gemm_is_available() else None


#from .grouped_gemm import backend as ops
#from .grouped_gemm import ops as backend
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/layers.py
ADDED
@@ -0,0 +1,1001 @@
1 |
+
import torch
|
2 |
+
import torch.distributed as dist
|
3 |
+
|
4 |
+
from typing import Optional, Any
|
5 |
+
|
6 |
+
from . import _layers
|
7 |
+
from . import ops
|
8 |
+
|
9 |
+
|
10 |
+
# Set the expert model parallel attributes on a tensor
|
11 |
+
def set_expert_model_parallel_attributes(
|
12 |
+
tensor: torch.Tensor,
|
13 |
+
is_parallel: bool,
|
14 |
+
):
|
15 |
+
assert not hasattr(tensor, "expert_model_parallel")
|
16 |
+
setattr(tensor, "expert_model_parallel", is_parallel)
|
17 |
+
|
18 |
+
|
19 |
+
# Get the expert model parallel attributes from a tensor
|
20 |
+
def expert_sharding_degree(
|
21 |
+
world_size: int,
|
22 |
+
moe_num_experts: int,
|
23 |
+
) -> int:
|
24 |
+
esd = min(world_size, moe_num_experts)
|
25 |
+
if (moe_num_experts % esd) != 0:
|
26 |
+
raise ValueError(f"Cannot shard {moe_num_experts} experts {esd} ways.")
|
27 |
+
return esd
|
28 |
+
|
29 |
+
|
30 |
+
# Calculate the hidden sharding degree based on world size and expert sharding degree
|
31 |
+
def hidden_sharding_degree(
|
32 |
+
world_size: int,
|
33 |
+
moe_num_experts: int,
|
34 |
+
ffn_hidden_size: int,
|
35 |
+
) -> int:
|
36 |
+
esd = expert_sharding_degree(world_size, moe_num_experts)
|
37 |
+
hsd = world_size // esd
|
38 |
+
if (ffn_hidden_size % hsd) != 0:
|
39 |
+
raise ValueError(f"Cannot shard {ffn_hidden_size} features {hsd} ways.")
|
40 |
+
if (esd * hsd) != world_size:
|
41 |
+
raise ValueError(
|
42 |
+
f"Invalid sharding. expert_sharding_degree ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size})."
|
43 |
+
)
|
44 |
+
return hsd
|
45 |
+
|
46 |
+
|
47 |
+
# Calculate the number of experts per rank based on world size and expert sharding degree
|
48 |
+
def experts_per_rank(
|
49 |
+
moe_num_experts: int,
|
50 |
+
world_size: int,
|
51 |
+
) -> int:
|
52 |
+
return moe_num_experts // expert_sharding_degree(world_size, moe_num_experts)
|
53 |
+
|
54 |
+
|
55 |
+
# Calculate the number of features per rank based on ffn hidden size and hidden sharding degree
|
56 |
+
def features_per_rank(
|
57 |
+
ffn_hidden_size: int, world_size: int, moe_num_experts: int
|
58 |
+
) -> int:
|
59 |
+
return ffn_hidden_size // hidden_sharding_degree(
|
60 |
+
world_size, moe_num_experts, ffn_hidden_size
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
# Apply jitter to the input tensor
|
65 |
+
def apply_jitter(x: torch.Tensor, moe_jitter_eps: float) -> torch.Tensor:
|
66 |
+
low = 1.0 - moe_jitter_eps
|
67 |
+
high = 1.0 + moe_jitter_eps
|
68 |
+
noise = torch.rand(x.size(), dtype=x.dtype, device=x.device)
|
69 |
+
return x * (low + noise * (high - low))
|
70 |
+
|
71 |
+
|
72 |
+
# Compute the top-k scores from the logits
|
73 |
+
def compute_top_k(scores: torch.Tensor, moe_top_k: int):
|
74 |
+
if moe_top_k == 1:
|
75 |
+
return scores.max(dim=-1, keepdim=True)
|
76 |
+
return torch.topk(scores, moe_top_k, dim=-1)
|
77 |
+
|
78 |
+
|
79 |
+
# Route tokens to experts and compute expert weights and indices
|
80 |
+
def route_tokens(
|
81 |
+
x: torch.Tensor,
|
82 |
+
router_weight: torch.Tensor,
|
83 |
+
moe_top_k: int,
|
84 |
+
moe_num_experts: int,
|
85 |
+
moe_jitter_eps: float = None,
|
86 |
+
moe_normalize_expert_weights: int = None,
|
87 |
+
uniform_expert_assignment: bool = False,
|
88 |
+
training: bool = False,
|
89 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
90 |
+
if training and moe_jitter_eps is not None:
|
91 |
+
x = apply_jitter(x, moe_jitter_eps)
|
92 |
+
|
93 |
+
x_flat = x.view(-1, x.shape[-1])
|
94 |
+
logits = torch.nn.functional.linear(x_flat, router_weight)
|
95 |
+
expert_weights, expert_indices = compute_top_k(logits, moe_top_k)
|
96 |
+
expert_weights = expert_weights.softmax(dim=-1)
|
97 |
+
if moe_normalize_expert_weights is not None:
|
98 |
+
expert_weights = expert_weights / torch.norm(
|
99 |
+
expert_weights,
|
100 |
+
p=moe_normalize_expert_weights,
|
101 |
+
dim=-1,
|
102 |
+
keepdim=True,
|
103 |
+
)
|
104 |
+
if uniform_expert_assignment:
|
105 |
+
expert_indices = _layers.router._uniform_expert_assignment(
|
106 |
+
expert_indices,
|
107 |
+
moe_num_experts,
|
108 |
+
)
|
109 |
+
|
110 |
+
return logits, expert_weights, expert_indices
|
111 |
+
|
112 |
+
|
113 |
+
# Scale the gradient of the weights
|
114 |
+
def scale_grad(
|
115 |
+
w: torch.Tensor,
|
116 |
+
gradient_scale: Optional[float] = None,
|
117 |
+
) -> torch.Tensor:
|
118 |
+
if gradient_scale is None:
|
119 |
+
return w
|
120 |
+
return _layers.mlp.scale_gradient(w, gradient_scale)
|
121 |
+
|
122 |
+
|
123 |
+
# Forward pass for the MLP layer
|
124 |
+
def mlp_forward(
|
125 |
+
x: torch.Tensor,
|
126 |
+
w1: torch.Tensor,
|
127 |
+
w2: torch.Tensor,
|
128 |
+
w1_bias: torch.Tensor,
|
129 |
+
w2_bias: torch.Tensor,
|
130 |
+
gradient_scale: Optional[float] = None,
|
131 |
+
alpha: float = 1.702,
|
132 |
+
):
|
133 |
+
# Scale weights
|
134 |
+
w1 = scale_grad(w1, gradient_scale)
|
135 |
+
w2 = scale_grad(w2, gradient_scale)
|
136 |
+
w1_bias = scale_grad(w1_bias, gradient_scale)
|
137 |
+
w2_bias = scale_grad(w2_bias, gradient_scale)
|
138 |
+
|
139 |
+
# Resolve dtensors
|
140 |
+
w1 = _layers.mlp.resolve_dtensor(w1)
|
141 |
+
w2 = _layers.mlp.resolve_dtensor(w2)
|
142 |
+
w1_bias = _layers.mlp.resolve_dtensor(w1_bias)
|
143 |
+
w2_bias = _layers.mlp.resolve_dtensor(w2_bias)
|
144 |
+
|
145 |
+
# Forward pass
|
146 |
+
gate_up = torch.bmm(x, w1) + w1_bias[..., None, :]
|
147 |
+
gate, up = gate_up.chunk(2, dim=-1)
|
148 |
+
|
149 |
+
glu = gate * torch.sigmoid(gate * alpha)
|
150 |
+
x = (up + 1) * glu
|
151 |
+
|
152 |
+
return torch.bmm(x, w2) + w2_bias[..., None, :]
|
153 |
+
|
154 |
+
|
155 |
+
# Shared expert MLP forward pass
|
156 |
+
def shared_mlp_forward(
|
157 |
+
x: torch.Tensor,
|
158 |
+
up_proj_weight: torch.Tensor,
|
159 |
+
down_proj_weight: torch.Tensor,
|
160 |
+
up_proj_bias: Optional[torch.Tensor] = None,
|
161 |
+
down_proj_bias: Optional[torch.Tensor] = None,
|
162 |
+
activation_fn: Optional[Any] = None,
|
163 |
+
gradient_scale: Optional[float] = None,
|
164 |
+
) -> torch.Tensor:
|
165 |
+
# Default activation function
|
166 |
+
if activation_fn is None:
|
167 |
+
activation_fn = torch.nn.functional.gelu
|
168 |
+
|
169 |
+
# Scale weights
|
170 |
+
up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
|
171 |
+
down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
|
172 |
+
if up_proj_bias is not None:
|
173 |
+
up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
|
174 |
+
if down_proj_bias is not None:
|
175 |
+
down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
|
176 |
+
|
177 |
+
# Resolve dtensors
|
178 |
+
up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
|
179 |
+
down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
|
180 |
+
if up_proj_bias is not None:
|
181 |
+
up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
|
182 |
+
if down_proj_bias is not None:
|
183 |
+
down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
|
184 |
+
|
185 |
+
# Up projection
|
186 |
+
x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
|
187 |
+
|
188 |
+
# Activation
|
189 |
+
x = activation_fn(x)
|
190 |
+
|
191 |
+
# Down projection
|
192 |
+
x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
|
193 |
+
|
194 |
+
return x
|
195 |
+
|
196 |
+
|
197 |
+
# Combine outputs from shared expert and regular experts
|
198 |
+
def combine_expert_shared_outputs(
|
199 |
+
shared_expert_out: torch.Tensor,
|
200 |
+
expert_out: torch.Tensor,
|
201 |
+
shared_expert_weighted_sum: bool = False,
|
202 |
+
moe_top_k: int = 1,
|
203 |
+
) -> torch.Tensor:
|
204 |
+
if shared_expert_weighted_sum:
|
205 |
+
# Weighted sum based on number of experts used
|
206 |
+
total_experts = moe_top_k + 1
|
207 |
+
shared_weight = 1.0 / total_experts
|
208 |
+
expert_weight = moe_top_k / total_experts
|
209 |
+
return shared_expert_out * shared_weight + expert_out * expert_weight
|
210 |
+
else:
|
211 |
+
# Simple addition
|
212 |
+
return shared_expert_out + expert_out
|
213 |
+
|
214 |
+
|
215 |
+
# Global variable to store load balancing loss
|
216 |
+
_LOAD_BALANCING_LOSS = []
|
217 |
+
|
218 |
+
|
219 |
+
def save_load_balancing_loss(loss):
|
220 |
+
global _LOAD_BALANCING_LOSS
|
221 |
+
_LOAD_BALANCING_LOSS.append(loss)
|
222 |
+
|
223 |
+
|
224 |
+
def get_load_balancing_loss():
|
225 |
+
global _LOAD_BALANCING_LOSS
|
226 |
+
return _LOAD_BALANCING_LOSS
|
227 |
+
|
228 |
+
|
229 |
+
def clear_load_balancing_loss():
|
230 |
+
global _LOAD_BALANCING_LOSS
|
231 |
+
_LOAD_BALANCING_LOSS.clear()
|
232 |
+
|
233 |
+
|
234 |
+
def batched_load_balancing_loss(args):
|
235 |
+
if args.moe_loss_weight == 0:
|
236 |
+
return 0.0
|
237 |
+
|
238 |
+
tokens_per_expert, expert_scores = zip(*get_load_balancing_loss())
|
239 |
+
num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
|
240 |
+
if args.num_layers_per_virtual_pipeline_stage is not None:
|
241 |
+
num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage
|
242 |
+
|
243 |
+
if len(tokens_per_expert) != num_layers_per_pipeline_stage:
|
244 |
+
raise ValueError(
|
245 |
+
f"Expected {num_layers_per_pipeline_stage} token_per_experts "
|
246 |
+
f"but found {len(tokens_per_expert)}.\nnum_layers = "
|
247 |
+
f"{args.num_layers}\npipeline_model_parallel_size = "
|
248 |
+
f"{args.pipeline_model_parallel_size}\n"
|
249 |
+
"num_layers_per_virtual_pipeline_stage"
|
250 |
+
f" = {args.num_layers_per_virtual_pipeline_stage}",
|
251 |
+
)
|
252 |
+
if len(expert_scores) != num_layers_per_pipeline_stage:
|
253 |
+
raise ValueError(
|
254 |
+
f"Expected {num_layers_per_pipeline_stage} expert_scores "
|
255 |
+
f"but found {len(tokens_per_expert)}.\nnum_layers = "
|
256 |
+
f"{args.num_layers}\npipeline_model_parallel_size = "
|
257 |
+
f"{args.pipeline_model_parallel_size}\n"
|
258 |
+
"num_layers_per_virtual_pipeline_stage"
|
259 |
+
f" = {args.num_layers_per_virtual_pipeline_stage}",
|
260 |
+
)
|
261 |
+
|
262 |
+
# Verify the shape of the tokens_per_expert and expert_scores tensors.
|
263 |
+
assert all(
|
264 |
+
(x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)
|
265 |
+
)
|
266 |
+
|
267 |
+
tokens = expert_scores[0].shape[0]
|
268 |
+
assert all(
|
269 |
+
(
|
270 |
+
(
|
271 |
+
x.ndim == 2
|
272 |
+
and x.shape[1] == args.moe_num_experts
|
273 |
+
and x.shape[0] == tokens
|
274 |
+
)
|
275 |
+
for x in expert_scores
|
276 |
+
)
|
277 |
+
)
|
278 |
+
|
279 |
+
# Concatenate the contributions of each layer and convert to
|
280 |
+
# the correct types and formats for the dot product.
|
281 |
+
expert_scores = torch.cat(expert_scores, dim=1)
|
282 |
+
if args.moe_lbl_in_fp32:
|
283 |
+
expert_scores = expert_scores.float()
|
284 |
+
if tokens != 0:
|
285 |
+
expert_scores = expert_scores.mean(dim=0)
|
286 |
+
else:
|
287 |
+
expert_scores = expert_scores.sum(dim=0)
|
288 |
+
tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
|
289 |
+
|
290 |
+
expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
|
291 |
+
assert tokens_per_expert.numel() == expected_values
|
292 |
+
assert expert_scores.numel() == expected_values
|
293 |
+
|
294 |
+
# Calculate the total scale across all factors.
|
295 |
+
#
|
296 |
+
# loss_weight * num_experts / (num_layers * tokens * top_k)
|
297 |
+
scale_numerator = args.moe_num_experts * args.moe_loss_weight
|
298 |
+
scale_denominator = args.num_layers * tokens * args.moe_top_k
|
299 |
+
scale = scale_numerator / scale_denominator
|
300 |
+
return scale * torch.dot(tokens_per_expert, expert_scores)
|
301 |
+
|
302 |
+
|
303 |
+
# Calculate the expert capacity based on tokens, top_k, number of experts,
|
304 |
+
# expert parallel group, capacity factor, and whether expert model parallelism is used.
|
305 |
+
def expert_capacity(
|
306 |
+
tokens: int,
|
307 |
+
top_k: int,
|
308 |
+
num_experts: int,
|
309 |
+
expert_parallel_group: int,
|
310 |
+
moe_capacity_factor: float,
|
311 |
+
moe_expert_model_parallelism: bool,
|
312 |
+
) -> int:
|
313 |
+
world_size = (
|
314 |
+
dist.get_world_size(expert_parallel_group)
|
315 |
+
if moe_expert_model_parallelism
|
316 |
+
else 1
|
317 |
+
)
|
318 |
+
|
319 |
+
tokens_per_expert = top_k * tokens * world_size / num_experts
|
320 |
+
return int(moe_capacity_factor * tokens_per_expert)
|
321 |
+
|
322 |
+
|
323 |
+
def load_balancing_loss(
|
324 |
+
tokens_per_expert: torch.Tensor,
|
325 |
+
expert_scores: torch.Tensor,
|
326 |
+
top_k: int,
|
327 |
+
num_experts: int,
|
328 |
+
):
|
329 |
+
assert len(expert_scores.size()) == 2
|
330 |
+
tokens, num_experts = expert_scores.size()
|
331 |
+
assert num_experts == num_experts
|
332 |
+
assert len(tokens_per_expert.size()) == 1
|
333 |
+
(num_experts,) = tokens_per_expert.size()
|
334 |
+
assert num_experts == num_experts
|
335 |
+
scale = num_experts / (tokens * top_k)
|
336 |
+
return scale * torch.dot(
|
337 |
+
tokens_per_expert.to(expert_scores.dtype),
|
338 |
+
expert_scores.mean(dim=0),
|
339 |
+
)
|
340 |
+
|
341 |
+
|
342 |
+
def indices_and_bins(
|
343 |
+
top_expert: torch.Tensor,
|
344 |
+
sort_end_bit: int,
|
345 |
+
num_experts: int,
|
346 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
347 |
+
top_expert = top_expert.int()
|
348 |
+
|
349 |
+
# Ensure contiguous memory layout
|
350 |
+
top_expert = top_expert.contiguous()
|
351 |
+
|
352 |
+
# Ensure CUB knows which device to use
|
353 |
+
with torch.cuda.device(top_expert.device):
|
354 |
+
output = ops.sort(top_expert, sort_end_bit)
|
355 |
+
bin_ids, indices = output
|
356 |
+
tokens_per_expert = ops.histogram(top_expert, num_experts)
|
357 |
+
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
|
358 |
+
|
359 |
+
bins = bins.view(1) if not len(bins.size()) else bins
|
360 |
+
return indices, bin_ids, bins, tokens_per_expert
|
361 |
+
|
362 |
+
|
363 |
+
def expert_capacity_fn(
|
364 |
+
tokens: int,
|
365 |
+
top_k: int,
|
366 |
+
num_experts: int,
|
367 |
+
expert_parallel_group: torch.distributed.ProcessGroup,
|
368 |
+
moe_capacity_factor: float = 1.0,
|
369 |
+
moe_expert_model_parallelism: bool = False,
|
370 |
+
) -> int:
|
371 |
+
world_size = (
|
372 |
+
dist.get_world_size(expert_parallel_group)
|
373 |
+
if moe_expert_model_parallelism
|
374 |
+
else 1
|
375 |
+
)
|
376 |
+
tokens_per_expert = top_k * tokens * world_size / num_experts
|
377 |
+
return int(moe_capacity_factor * tokens_per_expert)
|
378 |
+
|
379 |
+
|
380 |
+
def permute_and_compute(
|
381 |
+
x,
|
382 |
+
tokens_per_expert,
|
383 |
+
indices,
|
384 |
+
bin_ids,
|
385 |
+
expert_weights,
|
386 |
+
bins,
|
387 |
+
expert_capacity,
|
388 |
+
top_k,
|
389 |
+
w1,
|
390 |
+
w2,
|
391 |
+
w1_bias,
|
392 |
+
w2_bias,
|
393 |
+
gradient_scale,
|
394 |
+
alpha,
|
395 |
+
):
|
396 |
+
# Route tokens to experts
|
397 |
+
x = x.view(-1, x.shape[-1])
|
398 |
+
|
399 |
+
# Ensure CUB knows which device to use
|
400 |
+
with torch.cuda.device(x.device):
|
401 |
+
x = ops.binned_gather(x, indices, bins, expert_capacity, top_k)
|
402 |
+
|
403 |
+
# Expert computation
|
404 |
+
x = mlp_forward(x, w1, w2, w1_bias, w2_bias, gradient_scale, alpha)
|
405 |
+
|
406 |
+
# Ensure CUB knows which device to use
|
407 |
+
with torch.cuda.device(x.device):
|
408 |
+
# Route tokens back
|
409 |
+
out = ops.binned_scatter(x, indices, expert_weights, bins, top_k)
|
410 |
+
return out
|
411 |
+
|
412 |
+
|
413 |
+
def forward_once(
|
414 |
+
x: torch.Tensor,
|
415 |
+
expert_weights: torch.Tensor,
|
416 |
+
top_experts: torch.Tensor,
|
417 |
+
w1: torch.Tensor,
|
418 |
+
w2: torch.Tensor,
|
419 |
+
w1_bias: torch.Tensor,
|
420 |
+
w2_bias: torch.Tensor,
|
421 |
+
gradient_scale: Optional[float] = None,
|
422 |
+
alpha: float = 1.702,
|
423 |
+
sort_end_bit: int = 0,
|
424 |
+
top_k: int = 4,
|
425 |
+
num_experts: int = 128,
|
426 |
+
expert_parallel_group: int = None,
|
427 |
+
moe_capacity_factor: float = 1.0,
|
428 |
+
moe_expert_model_parallelism: bool = False,
|
429 |
+
mlp_impl: Optional[str] = None,
|
430 |
+
):
|
431 |
+
# x: [sl, bs, hs]
|
432 |
+
# expert_weights: [sl * bs, top-k]
|
433 |
+
# top_experts: [sl * bs, top-k]
|
434 |
+
expert_weights = expert_weights.flatten()
|
435 |
+
top_experts = top_experts.flatten()
|
436 |
+
|
437 |
+
with torch.no_grad():
|
438 |
+
indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
|
439 |
+
top_experts, sort_end_bit, num_experts
|
440 |
+
)
|
441 |
+
|
442 |
+
# Calculate expert capacity
|
443 |
+
sl, bs, _ = x.size()
|
444 |
+
|
445 |
+
expert_capacity = expert_capacity_fn(
|
446 |
+
sl * bs,
|
447 |
+
top_k,
|
448 |
+
num_experts,
|
449 |
+
expert_parallel_group,
|
450 |
+
moe_capacity_factor,
|
451 |
+
moe_expert_model_parallelism,
|
452 |
+
)
|
453 |
+
|
454 |
+
if expert_capacity == 0:
|
455 |
+
expert_capacity = torch.max(tokens_per_expert).item()
|
456 |
+
|
457 |
+
x = permute_and_compute(
|
458 |
+
x,
|
459 |
+
tokens_per_expert,
|
460 |
+
indices,
|
461 |
+
bin_ids,
|
462 |
+
expert_weights,
|
463 |
+
bins,
|
464 |
+
expert_capacity,
|
465 |
+
top_k,
|
466 |
+
w1,
|
467 |
+
w2,
|
468 |
+
w1_bias,
|
469 |
+
w2_bias,
|
470 |
+
gradient_scale,
|
471 |
+
alpha,
|
472 |
+
)
|
473 |
+
return x, tokens_per_expert
|
474 |
+
|
475 |
+
|
476 |
+
def parallel_forward_once(
|
477 |
+
x: torch.Tensor,
|
478 |
+
expert_weights: torch.Tensor,
|
479 |
+
top_experts: torch.Tensor,
|
480 |
+
w1: torch.Tensor,
|
481 |
+
w2: torch.Tensor,
|
482 |
+
w1_bias: torch.Tensor,
|
483 |
+
w2_bias: torch.Tensor,
|
484 |
+
gradient_scale: Optional[float] = None,
|
485 |
+
alpha: float = 1.702,
|
486 |
+
sort_end_bit: int = 0,
|
487 |
+
top_k: int = 4,
|
488 |
+
num_experts: int = 128,
|
489 |
+
expert_parallel_group: torch.distributed.ProcessGroup = None,
|
490 |
+
moe_capacity_factor: float = 1.0,
|
491 |
+
moe_expert_model_parallelism: bool = True,
|
492 |
+
hidden_size: int = 1152,
|
493 |
+
mlp_impl: Optional[str] = "sparse",
|
494 |
+
):
|
495 |
+
# Flatten inputs
|
496 |
+
expert_weights = expert_weights.flatten()
|
497 |
+
top_experts = top_experts.flatten()
|
498 |
+
|
499 |
+
# TODO: remove debugging var
|
500 |
+
# my_rank = dist.get_rank(expert_parallel_group) if expert_parallel_group else 0
|
501 |
+
|
502 |
+
with torch.no_grad():
|
503 |
+
# Step 1: Local permutation setup
|
504 |
+
indices, bin_ids, bins, tokens_per_expert = indices_and_bins(
|
505 |
+
top_experts, sort_end_bit, num_experts
|
506 |
+
)
|
507 |
+
|
508 |
+
# Calculate sharding parameters
|
509 |
+
world_size = dist.get_world_size(expert_parallel_group)
|
510 |
+
hidden_sharding_deg = hidden_sharding_degree(
|
511 |
+
world_size, num_experts, hidden_size
|
512 |
+
)
|
513 |
+
experts_per_rank_val = experts_per_rank(num_experts, world_size)
|
514 |
+
|
515 |
+
# Replicate token counts for hidden sharding
|
516 |
+
repeated_tokens_per_expert = ops.repeat(
|
517 |
+
tokens_per_expert, (hidden_sharding_deg,)
|
518 |
+
)
|
519 |
+
|
520 |
+
# Exchange token counts across devices
|
521 |
+
parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert)
|
522 |
+
|
523 |
+
# Ensure CUB knows which device to use
|
524 |
+
tpe_handle = dist.all_to_all_single(
|
525 |
+
parallel_tokens_per_expert,
|
526 |
+
repeated_tokens_per_expert,
|
527 |
+
group=expert_parallel_group,
|
528 |
+
async_op=True,
|
529 |
+
)
|
530 |
+
|
531 |
+
# Step 2: Local permutation - group tokens by target device
|
532 |
+
x = x.view(-1, x.shape[-1]) # [sl * bs, hs]
|
533 |
+
x = ops.gather(x, indices, bin_ids, bins, top_k)
|
534 |
+
|
535 |
+
# Step 3: Compute communication counts and exchange tokens
|
536 |
+
with torch.no_grad():
|
537 |
+
tpe_handle.wait()
|
538 |
+
|
539 |
+
# Reshape for per-device calculations
|
540 |
+
repeated_tokens_per_expert = repeated_tokens_per_expert.view(
|
541 |
+
world_size, experts_per_rank_val
|
542 |
+
)
|
543 |
+
parallel_tokens_per_expert = parallel_tokens_per_expert.view(
|
544 |
+
world_size, experts_per_rank_val
|
545 |
+
)
|
546 |
+
|
547 |
+
# Calculate send/recv counts
|
548 |
+
send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
|
549 |
+
# recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1).tolist()
|
550 |
+
parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
|
551 |
+
recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()
|
552 |
+
tokens_received = sum(recv_counts)
|
553 |
+
|
554 |
+
# Replicate for hidden sharding
|
555 |
+
x = ops.repeat(x, (hidden_sharding_deg, 1))
|
556 |
+
|
557 |
+
# Cross-device token exchange
|
558 |
+
parallel_x, parallel_x_handle = _layers.all_to_all.all_to_all(
|
559 |
+
x, recv_counts, send_counts, expert_parallel_group, async_op=True
|
560 |
+
)
|
561 |
+
|
562 |
+
with torch.no_grad():
|
563 |
+
# Step 4: Setup for local expert computation
|
564 |
+
replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0)
|
565 |
+
replicate_bins = (
|
566 |
+
replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins
|
567 |
+
)
|
568 |
+
|
569 |
+
# Create expert indices for received tokens
|
570 |
+
parallel_top_expert = torch.remainder(
|
571 |
+
torch.arange(
|
572 |
+
num_experts * hidden_sharding_deg,
|
573 |
+
dtype=torch.int32,
|
574 |
+
device=indices.device,
|
575 |
+
),
|
576 |
+
experts_per_rank_val,
|
577 |
+
)
|
578 |
+
parallel_top_expert = ops.replicate(
|
579 |
+
parallel_top_expert.unsqueeze(dim=0),
|
580 |
+
replicate_bins,
|
581 |
+
tokens_received,
|
582 |
+
).flatten()
|
583 |
+
|
584 |
+
# Sort tokens by expert assignment
|
585 |
+
parallel_bin_ids, parallel_indices = ops.sort(
|
586 |
+
parallel_top_expert,
|
587 |
+
sort_end_bit,
|
588 |
+
)
|
589 |
+
|
590 |
+
# Calculate bins for local experts
|
591 |
+
parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
|
592 |
+
dim=0, dtype=torch.int
|
593 |
+
)
|
594 |
+
parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0)
|
595 |
+
parallel_bins = (
|
596 |
+
parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins
|
597 |
+
)
|
598 |
+
|
599 |
+
# Calculate expert capacity
|
600 |
+
expert_capacity = expert_capacity_fn(
|
601 |
+
tokens_received,
|
602 |
+
top_k,
|
603 |
+
experts_per_rank_val,
|
604 |
+
expert_parallel_group,
|
605 |
+
moe_capacity_factor,
|
606 |
+
moe_expert_model_parallelism,
|
607 |
+
)
|
608 |
+
if expert_capacity == 0:
|
609 |
+
expert_capacity = torch.max(parallel_tokens_per_expert).item()
|
610 |
+
|
611 |
+
# Locally permute the tokens and perform the expert computation.
|
612 |
+
# Block to make sure that the cross-device permutation is complete.
|
613 |
+
if mlp_impl == "grouped":
|
614 |
+
# GroupedMLP requires counts on CPU. We can use the tensor already
|
615 |
+
# moved to CPU for the prior all_to_all, which avoids an extra
|
616 |
+
# device synchronization.
|
617 |
+
parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
|
618 |
+
dim=0,
|
619 |
+
dtype=torch.int,
|
620 |
+
)
|
621 |
+
|
622 |
+
# Step 5: Expert computation
|
623 |
+
parallel_x_handle.wait()
|
624 |
+
|
625 |
+
parallel_x = permute_and_compute(
|
626 |
+
parallel_x,
|
627 |
+
parallel_tokens_per_expert,
|
628 |
+
parallel_indices,
|
629 |
+
parallel_bin_ids,
|
630 |
+
None, # expert_weights
|
631 |
+
parallel_bins,
|
632 |
+
expert_capacity,
|
633 |
+
top_k=1,
|
634 |
+
w1=w1,
|
635 |
+
w2=w2,
|
636 |
+
w1_bias=w1_bias,
|
637 |
+
w2_bias=w2_bias,
|
638 |
+
gradient_scale=gradient_scale,
|
639 |
+
alpha=alpha,
|
640 |
+
)
|
641 |
+
|
642 |
+
# Step 6: Reverse communication - send results back
|
643 |
+
x, _ = _layers.all_to_all.all_to_all(
|
644 |
+
parallel_x, send_counts, recv_counts, expert_parallel_group
|
645 |
+
)
|
646 |
+
|
647 |
+
# Step 7: Reduce across hidden sharding dimension
|
648 |
+
shape = (hidden_sharding_deg, -1, hidden_size)
|
649 |
+
x = x.view(shape).sum(dim=0)
|
650 |
+
|
651 |
+
# Step 8: Final local unpermutation
|
652 |
+
x = ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k)
|
653 |
+
|
654 |
+
return x, tokens_per_expert.flatten()
|
655 |
+
|
656 |
+
|
657 |
+
def moe_forward(
|
658 |
+
x: torch.Tensor,
|
659 |
+
router_weight: torch.Tensor,
|
660 |
+
moe_top_k: int,
|
661 |
+
moe_num_experts: int,
|
662 |
+
moe_jitter_eps: float = None,
|
663 |
+
moe_normalize_expert_weights: int = None,
|
664 |
+
uniform_expert_assignment: bool = False,
|
665 |
+
training: bool = False,
|
666 |
+
w1: torch.Tensor = None,
|
667 |
+
w2: torch.Tensor = None,
|
668 |
+
w1_bias: torch.Tensor = None,
|
669 |
+
w2_bias: torch.Tensor = None,
|
670 |
+
gradient_scale: Optional[float] = None,
|
671 |
+
alpha: float = 1.702,
|
672 |
+
sort_end_bit: int = 0,
|
673 |
+
expert_parallel_group: torch.distributed.ProcessGroup = None,
|
674 |
+
moe_capacity_factor: float = 1.0,
|
675 |
+
moe_expert_model_parallelism: bool = False,
|
676 |
+
forward_fn: Any = None,
|
677 |
+
hidden_size: int = None,
|
678 |
+
mlp_impl: str = "grouped",
|
679 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
680 |
+
|
681 |
+
# Route tokens to experts
|
682 |
+
logits, expert_weights, expert_indices = route_tokens(
|
683 |
+
x,
|
684 |
+
router_weight,
|
685 |
+
moe_top_k,
|
686 |
+
moe_num_experts,
|
687 |
+
moe_jitter_eps,
|
688 |
+
moe_normalize_expert_weights,
|
689 |
+
uniform_expert_assignment,
|
690 |
+
training,
|
691 |
+
)
|
692 |
+
|
693 |
+
# Create router scores for output
|
694 |
+
router_scores = (
|
695 |
+
torch.zeros_like(logits)
|
696 |
+
.scatter_(1, expert_indices, expert_weights)
|
697 |
+
.transpose(0, 1)
|
698 |
+
)
|
699 |
+
|
700 |
+
in_shape = x.size()
|
701 |
+
|
702 |
+
# Prepare forward function arguments
|
703 |
+
forward_args = {
|
704 |
+
"x": x,
|
705 |
+
"expert_weights": expert_weights,
|
706 |
+
"top_experts": expert_indices,
|
707 |
+
"w1": w1,
|
708 |
+
"w2": w2,
|
709 |
+
"w1_bias": w1_bias,
|
710 |
+
"w2_bias": w2_bias,
|
711 |
+
"gradient_scale": gradient_scale,
|
712 |
+
"alpha": alpha,
|
713 |
+
"sort_end_bit": sort_end_bit,
|
714 |
+
"top_k": moe_top_k,
|
715 |
+
"num_experts": moe_num_experts,
|
716 |
+
"expert_parallel_group": expert_parallel_group,
|
717 |
+
"moe_capacity_factor": moe_capacity_factor,
|
718 |
+
"moe_expert_model_parallelism": moe_expert_model_parallelism,
|
719 |
+
"mlp_impl": mlp_impl,
|
720 |
+
}
|
721 |
+
|
722 |
+
# Add hidden_size for parallel forward
|
723 |
+
if moe_expert_model_parallelism and hidden_size is not None:
|
724 |
+
forward_args["hidden_size"] = hidden_size
|
725 |
+
elif moe_expert_model_parallelism and hidden_size is None:
|
726 |
+
# Infer hidden_size from input shape
|
727 |
+
forward_args["hidden_size"] = x.shape[-1]

    # Compute expert outputs
    x, tokens_per_expert = forward_fn(**forward_args)

    # Save load balancing loss if needed
    moe_loss_weight = 0.0  # Can be made configurable
    if training and moe_loss_weight > 0:
        save_load_balancing_loss((tokens_per_expert, logits))

    # Restore original shape
    x = x.view(in_shape)

    return x, expert_weights, router_scores


def moe_forward_with_shared_expert(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    moe_top_k: int,
    moe_num_experts: int,
    moe_jitter_eps: float = None,
    moe_normalize_expert_weights: int = None,
    uniform_expert_assignment: bool = False,
    training: bool = False,
    w1: torch.Tensor = None,
    w2: torch.Tensor = None,
    w1_bias: torch.Tensor = None,
    w2_bias: torch.Tensor = None,
    gradient_scale: Optional[float] = None,
    alpha: float = 1.702,
    sort_end_bit: int = 0,
    expert_parallel_group: torch.distributed.ProcessGroup = None,
    moe_capacity_factor: float = 1.0,
    moe_expert_model_parallelism: bool = False,
    forward_fn: Any = None,
    hidden_size: int = None,
    mlp_impl: str = "grouped",
    # Shared expert parameters
    shared_up_proj_weight: Optional[torch.Tensor] = None,
    shared_down_proj_weight: Optional[torch.Tensor] = None,
    shared_up_proj_bias: Optional[torch.Tensor] = None,
    shared_down_proj_bias: Optional[torch.Tensor] = None,
    shared_expert_weighted_sum: bool = False,
    shared_activation_fn: Optional[Any] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

    # First, compute regular MoE forward pass
    expert_out, expert_weights, router_scores = moe_forward(
        x=x,
        router_weight=router_weight,
        moe_top_k=moe_top_k,
        moe_num_experts=moe_num_experts,
        moe_jitter_eps=moe_jitter_eps,
        moe_normalize_expert_weights=moe_normalize_expert_weights,
        uniform_expert_assignment=uniform_expert_assignment,
        training=training,
        w1=w1,
        w2=w2,
        w1_bias=w1_bias,
        w2_bias=w2_bias,
        gradient_scale=gradient_scale,
        alpha=alpha,
        sort_end_bit=sort_end_bit,
        expert_parallel_group=expert_parallel_group,
        moe_capacity_factor=moe_capacity_factor,
        moe_expert_model_parallelism=moe_expert_model_parallelism,
        forward_fn=forward_fn,
        hidden_size=hidden_size,
        mlp_impl=mlp_impl,
    )

    # If shared expert weights provided, compute shared expert output
    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
        shared_expert_out = shared_mlp_forward(
            x=x,
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            up_proj_bias=shared_up_proj_bias,
            down_proj_bias=shared_down_proj_bias,
            activation_fn=shared_activation_fn,
            gradient_scale=gradient_scale,
        )

        # Combine expert outputs
        combined_out = combine_expert_shared_outputs(
            shared_expert_out=shared_expert_out,
            expert_out=expert_out,
            shared_expert_weighted_sum=shared_expert_weighted_sum,
            moe_top_k=moe_top_k,
        )

        return combined_out, expert_weights, router_scores

    # Return regular MoE output if no shared expert
    return expert_out, expert_weights, router_scores


def create_shared_expert_weights(
    hidden_size: int,
    shared_expert_hidden_size: int,
    device: torch.device,
    dtype: torch.dtype,
    init_method: Any,
    output_layer_init_method: Any = None,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:

    if output_layer_init_method is None:
        output_layer_init_method = init_method

    # Create weight tensors
    up_proj_weight = torch.empty(
        shared_expert_hidden_size,
        hidden_size,
        device=device,
        dtype=dtype,
    )
    down_proj_weight = torch.empty(
        hidden_size,
        shared_expert_hidden_size,
        device=device,
        dtype=dtype,
    )

    # Initialize weights
    init_method(up_proj_weight)
    output_layer_init_method(down_proj_weight)

    # No bias by default
    return up_proj_weight, down_proj_weight, None, None


# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
# This exists because device_mesh is trapped in hook closures with no model attribute
# Fragile - breaks if hook structure changes or Python internals change
# TODO: Replace with a more robust solution when available
def get_device_mesh(model):
    # Extract device_mesh from child's unused pre_hook closure
    try:
        # Find the pre-hook that contains 'device_mesh' in its closure
        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
        # Extract the device_mesh from the closure
        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
    except Exception:
        return None


class MegaBlocksMoeMLP(torch.nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        moe_top_k = getattr(self.router, "top_k", 4)
        moe_num_experts = getattr(self.experts, "num_experts", 128)
        gradient_scale = getattr(self.experts, "gradient_scale", None)
        alpha = getattr(self.experts, "alpha", 1.0)
        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)

        expert_parallel_group = getattr(self, "expert_parallel_group", None)
        if expert_parallel_group is None:
            device_mesh = get_device_mesh(self)
            expert_parallel_group = device_mesh.get_group() if device_mesh else None

        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
        forward_fn = parallel_forward_once if has_parallel else forward_once

        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
        mlp_impl = getattr(self, "mlp_impl", "grouped")

        output, expert_weights_out, *_ = moe_forward(
            x=x,
            router_weight=self.router.weight,
            moe_top_k=moe_top_k,
            moe_num_experts=moe_num_experts,
            moe_jitter_eps=moe_jitter_eps,
            moe_normalize_expert_weights=moe_normalize_expert_weights,
            uniform_expert_assignment=uniform_expert_assignment,
            training=self.training,
            w1=self.experts.gate_up_proj,
            w2=self.experts.down_proj,
            w1_bias=self.experts.gate_up_proj_bias,
            w2_bias=self.experts.down_proj_bias,
            gradient_scale=gradient_scale,
            alpha=alpha,
            sort_end_bit=sort_end_bit,
            expert_parallel_group=expert_parallel_group,
            moe_capacity_factor=moe_capacity_factor,
            moe_expert_model_parallelism=has_parallel,
            forward_fn=forward_fn,
            hidden_size=self.experts.hidden_size,
            mlp_impl=mlp_impl,
        )
        return output, expert_weights_out


class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):

    def __init__(self):
        super().__init__()
        # Shared expert weights will be set by the user
        self.shared_up_proj_weight = None
        self.shared_down_proj_weight = None
        self.shared_up_proj_bias = None
        self.shared_down_proj_bias = None
        self.shared_expert_weighted_sum = False
        self.shared_activation_fn = None

    def set_shared_expert_weights(
        self,
        up_proj_weight: torch.Tensor,
        down_proj_weight: torch.Tensor,
        up_proj_bias: Optional[torch.Tensor] = None,
        down_proj_bias: Optional[torch.Tensor] = None,
        weighted_sum: bool = False,
        activation_fn: Optional[Any] = None,
    ):
        self.shared_up_proj_weight = up_proj_weight
        self.shared_down_proj_weight = down_proj_weight
        self.shared_up_proj_bias = up_proj_bias
        self.shared_down_proj_bias = down_proj_bias
        self.shared_expert_weighted_sum = weighted_sum
        self.shared_activation_fn = activation_fn

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        moe_top_k = getattr(self.router, "top_k", 4)
        moe_num_experts = getattr(self.experts, "num_experts", 128)
        gradient_scale = getattr(self.experts, "gradient_scale", None)
        alpha = getattr(self.experts, "alpha", 1.0)
        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)

        expert_parallel_group = getattr(self, "expert_parallel_group", None)
        if expert_parallel_group is None:
            device_mesh = get_device_mesh(self)
            expert_parallel_group = device_mesh.get_group() if device_mesh else None

        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
        forward_fn = parallel_forward_once if has_parallel else forward_once

        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
        mlp_impl = getattr(self, "mlp_impl", "grouped")

        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
            x=x,
            router_weight=self.router.weight,
            moe_top_k=moe_top_k,
            moe_num_experts=moe_num_experts,
            moe_jitter_eps=moe_jitter_eps,
            moe_normalize_expert_weights=moe_normalize_expert_weights,
            uniform_expert_assignment=uniform_expert_assignment,
            training=self.training,
            w1=self.experts.gate_up_proj,
            w2=self.experts.down_proj,
            w1_bias=self.experts.gate_up_proj_bias,
            w2_bias=self.experts.down_proj_bias,
            gradient_scale=gradient_scale,
            alpha=alpha,
            sort_end_bit=sort_end_bit,
            expert_parallel_group=expert_parallel_group,
            moe_capacity_factor=moe_capacity_factor,
            moe_expert_model_parallelism=has_parallel,
            forward_fn=forward_fn,
            hidden_size=self.experts.hidden_size,
            mlp_impl=mlp_impl,
            # Shared expert parameters
            shared_up_proj_weight=self.shared_up_proj_weight,
            shared_down_proj_weight=self.shared_down_proj_weight,
            shared_up_proj_bias=self.shared_up_proj_bias,
            shared_down_proj_bias=self.shared_down_proj_bias,
            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
            shared_activation_fn=self.shared_activation_fn,
        )
        return output, expert_weights_out

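A minimal usage sketch for the shared-expert module above (not part of the diff): it wires create_shared_expert_weights() into MegaBlocksMoeMLPWithSharedExpert. The hidden sizes, dtype, and the fact that `router`/`experts` sub-modules are attached elsewhere are all illustrative assumptions.

    # Hypothetical wiring sketch; router/experts must be assigned before forward().
    import torch

    hidden_size, shared_hidden = 1024, 4096
    up_w, down_w, up_b, down_b = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_hidden,
        device=torch.device("cuda"),
        dtype=torch.bfloat16,
        init_method=torch.nn.init.xavier_uniform_,
    )

    mlp = MegaBlocksMoeMLPWithSharedExpert()
    mlp.set_shared_expert_weights(
        up_proj_weight=up_w,
        down_proj_weight=down_w,
        weighted_sum=False,                      # add shared output rather than weight it
        activation_fn=torch.nn.functional.gelu,  # assumed activation choice
    )
    # out, expert_weights = mlp(torch.randn(2, 16, hidden_size, device="cuda"))
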
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/__init__.py
ADDED
@@ -0,0 +1,35 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from .binned_gather import binned_gather
from .binned_scatter import binned_scatter
from .cumsum import exclusive_cumsum, inclusive_cumsum
from .gather import gather
from .histogram import histogram
from .padded_gather import padded_gather
from .padded_scatter import padded_scatter
from .repeat import repeat
from .replicate import replicate
from .round_up import round_up
from .scatter import scatter
from .sort import sort
from .sum import sum
from .topology import topology

__all__ = [
    'binned_gather',
    'binned_scatter',
    'exclusive_cumsum',
    'inclusive_cumsum',
    'gather',
    'histogram',
    'padded_gather',
    'padded_scatter',
    'repeat',
    'replicate',
    'round_up',
    'scatter',
    'sort',
    'sum',
    'topology',
]

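The ops exported above are typically composed into a token-routing pipeline (sort by expert, count, build bin offsets, gather, compute, scatter). A small sketch of that composition follows; it mirrors the benchmark files below, but the import path, shapes, and expert count are illustrative assumptions.

    # Illustrative composition of the exported ops.
    import torch
    from megablocks import ops  # assumed import path

    ne, top_k = 8, 1
    x = torch.randn(16384, 768, device="cuda", dtype=torch.float16)
    top_expert = torch.randint(0, ne, (x.shape[0],), device="cuda", dtype=torch.int32)

    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)
    weights = torch.rand(x.shape[0] * top_k, device="cuda", dtype=torch.float16)

    gathered = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
    # ... expert computation on `gathered` would go here ...
    restored = ops.padded_scatter(gathered, indices, bin_ids, weights, bins, padded_bins, top_k)
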
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/all_to_all_benchmark.py
ADDED
@@ -0,0 +1,63 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributed as dist

# from megablocks import benchmark_util
# from megablocks.layers.all_to_all import all_to_all

from .. import benchmark_util
from .._layers.all_to_all import all_to_all

_ALL_TO_ALL_BENCHMARK = (
    (8, 1024),
    (16, 1024),
    (32, 1024),
    (64, 1024),
    (128, 1024),
    (256, 1024),
    (512, 1024),
    (1024, 1024),
    (2 * 1024, 1024),
    (4 * 1024, 1024),
    (8 * 1024, 1024),
    (16 * 1024, 1024),
    (32 * 1024, 1024),
    (64 * 1024, 1024),
    (128 * 1024, 1024),
    (256 * 1024, 1024),
    (512 * 1024, 1024),
    (1024 * 1024, 1024),
)


def benchmark_all_to_all(group, sl, hs):
    world_size = dist.get_world_size(group)
    assert (sl % world_size) == 0
    send_recv_sizes = [sl // world_size] * world_size

    x = torch.randn((sl, hs)).cuda().half()

    details = {
        'world_size': world_size,
        'message_size (B)': send_recv_sizes[0] * hs * 2,  # 2B elements.
    }

    def benchmark():
        return all_to_all(x, send_recv_sizes, send_recv_sizes, group)

    time, std = benchmark_util.benchmark_function(benchmark)

    if dist.get_rank(group) == 0:
        benchmark_util.log_benchmark('All-To-All', details, time, std)


if __name__ == '__main__':
    assert dist.is_available()
    group = dist.init_process_group(backend='nccl')
    local_rank = dist.get_rank(group)
    torch.cuda.set_device(local_rank)

    for args in _ALL_TO_ALL_BENCHMARK:
        benchmark_all_to_all(group, *args)

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_gather.py
ADDED
@@ -0,0 +1,37 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for binned_gather kernel.
class BinnedGatherOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bins: torch.Tensor,
        bin_size: int,
        top_k: int,
    ):
        ctx.save_for_backward(indices, bins)
        ctx.top_k = top_k
        return kernels.binned_gather(x, indices, None, bins, bin_size, top_k)

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()
        indices, bins = ctx.saved_tensors
        out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k)
        return out, None, None, None, None


binned_gather = BinnedGatherOp.apply

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/binned_scatter.py
ADDED
@@ -0,0 +1,59 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for binned_scatter kernel.
class BinnedScatterOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        assert len(x.size()) == 3
        ctx.bin_size = x.size(1)
        ctx.top_k = top_k

        # TODO(tgale): Don't save 'x' for backwards if we don't need to
        # calculate the gradient w.r.t. 'weights'.
        ctx.save_for_backward(x, indices, weights, bins)
        return kernels.binned_scatter(x, indices, weights, bins, top_k)

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()
        x, indices, weights, bins = ctx.saved_tensors
        out = kernels.binned_gather(
            grad,
            indices,
            weights,
            bins,
            ctx.bin_size,
            ctx.top_k,
        )

        wgrad = None
        if ctx.needs_input_grad[2]:
            wgrad = kernels.binned_scatter_wgrad(
                x,
                grad,
                indices,
                bins,
                ctx.top_k,
            )
        return out, None, wgrad, None, None


binned_scatter = BinnedScatterOp.apply

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/cumsum.py
ADDED
@@ -0,0 +1,52 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any

# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# Wrap this in a try-block with better error message and
# instructions for building the c++ operations.
try:
    # import megablocks_ops as ops  # type: ignore
    from .._ops import ops  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e


# Autograd wrappers for cumsum kernels.
# NOTE: Does not support gradients.
class ExclusiveCumsumOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, dim: int):
        if len(x.size()) == 1:
            x = x.view([1, -1])
            out = torch.empty_like(x)
            ops.exclusive_cumsum(x, 1, out)
            return out.squeeze()
        out = torch.empty_like(x)
        ops.exclusive_cumsum(x, dim, out)
        return out


exclusive_cumsum = ExclusiveCumsumOp.apply


class InclusiveCumsumOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor:
        if len(x.size()) == 1:
            x = x.view([1, -1])
            out = torch.empty_like(x)
            ops.inclusive_cumsum(x, 1, out)
            return out.squeeze()
        out = torch.empty_like(x)
        ops.inclusive_cumsum(x, dim, out)
        return out


inclusive_cumsum = InclusiveCumsumOp.apply

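For reference, the two wrappers differ only in whether the current element is included in the running sum; a small worked example of the expected semantics (values assumed, standard cumsum behavior):

    # tokens_per_expert = [3, 1, 4]
    #   inclusive_cumsum -> [3, 4, 8]   (bin end offsets, used as `bins`)
    #   exclusive_cumsum -> [0, 3, 4]   (bin start offsets)
    import torch
    t = torch.tensor([3, 1, 4], dtype=torch.int32, device="cuda")
    bins_incl = inclusive_cumsum(t, 0)   # expected: tensor([3, 4, 8])
    bins_excl = exclusive_cumsum(t, 0)   # expected: tensor([0, 3, 4])
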
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/gather.py
ADDED
@@ -0,0 +1,38 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for gather kernel.
class GatherOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        ctx.save_for_backward(indices, bin_ids, bins)
        ctx.top_k = top_k
        return kernels.gather(x, indices, bin_ids, None, bins, top_k)

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()

        indices, bin_ids, bins = ctx.saved_tensors
        out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k)
        return out, None, None, None, None, None


gather = GatherOp.apply

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any

# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# Wrap this in a try-block with better error message and
# instructions for building the c++ operations.
try:
    from .._ops import ops  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e


# Autograd wrapper for histogram kernel.
# NOTE: Does not support gradients.
class HistogramOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, max_val: float):
        return ops.histogram(x, max_val)


histogram = HistogramOp.apply

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/histogram_benchmark.py
ADDED
@@ -0,0 +1,78 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import unittest

import numpy as np
import torch
from absl.testing import parameterized

from .. import ops

_HISTOGRAM_TESTS = (
    (16384, torch.int32, 2),
    (16384, torch.int32, 4),
    (16384, torch.int32, 8),
    (16384, torch.int32, 16),
    (16384, torch.int32, 32),
    (16384, torch.int32, 64),
    (16384, torch.int32, 128),
    (16384, torch.int32, 256),
)


def benchmark_function(fn, iterations=10):
    # Run once to get rid of startup overhead.
    fn()
    times = []
    for _ in range(iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))
    times = np.array(times)
    return times.mean(), times.std(), times.max(), times.min()


def log_benchmark(arguments, mean_t, std_t):
    print('=' * 60)
    print('Benchmark Parameters:')
    for (key, value) in arguments.items():
        print(f'{key} = {value}')
    print('Results:')
    print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
    print('=' * 60)


class HistogramBenchmark(parameterized.TestCase):

    @parameterized.parameters(*_HISTOGRAM_TESTS)
    def testHistogram(self, n, dtype, max_val):
        x = torch.randint(0, max_val, (n,)).cuda().to(dtype)

        mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),)
        arguments = {
            'n': n,
            'dtype': dtype,
            'max_val': max_val,
        }
        log_benchmark(arguments, mean_t, std_t)

    @parameterized.parameters(*_HISTOGRAM_TESTS)
    def testTorchHistogram(self, n, dtype, max_val):
        x = torch.randint(0, 128, (n,)).cuda().to(dtype)

        mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),)
        arguments = {
            'n': n,
            'dtype': dtype,
            'max_val': max_val,
        }
        log_benchmark(arguments, mean_t, std_t)


if __name__ == '__main__':
    unittest.main()

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/matmul_benchmark.py
ADDED
@@ -0,0 +1,415 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import unittest


# import stk

# try:
#     import stk
# except ImportError:
#     import warnings
#     warnings.warn(
#         'Please add `stanford-stk` if megablocks/ops/matmul_benchmark.py is needed.',
#     )

from .. import stk

import torch
from absl.testing import parameterized

from .. import benchmark_util, ops


# Calling tensor.t() calls tensor.transpose(0, 1) which calls
# torch.as_strided(...). Circumvent this chain to avoid an overhead
# this adds.
def transpose_view(x):
    return torch.as_strided(
        x,
        (x.shape[1], x.shape[0]),
        (x.stride()[1], x.stride()[0]),
    )


_MATMUL_TESTS = (
    (64 * 1024, 512, 2048, 64),
    (32 * 1024, 768, 3072, 64),
    (8 * 1024, 1024, 4096, 64),
    (4 * 2048, 4096, 4 * 4096, 4),
)


def log_benchmark(name, arguments, time, std, flops):
    benchmark_util.log_benchmark(name, arguments, time, std)
    print('flops = {:.2f}B'.format(flops / 1e9))
    print('throughput = {:.2f}T'.format(flops / 1e9 / time))
    print('=' * 60)


class MatmulBenchmark(parameterized.TestCase):

    def build_sparse_matrix(self, x, padded_bins, fhs, ne):
        blocking = 128
        padded_tokens, _ = x.size()
        assert padded_tokens % blocking == 0
        assert fhs % blocking == 0

        # Offsets for the sparse matrix. All rows have the
        # same number of nonzero blocks dictated by the
        # dimensionality of a single expert.
        block_rows = padded_tokens // blocking
        blocks_per_row = fhs // blocking
        offsets = torch.arange(
            0,
            block_rows * blocks_per_row + 1,
            blocks_per_row,
            dtype=torch.int32,
            device=x.device,
        )

        # Indices for the sparse matrix. The indices for
        # the intermediate matrix are dynamic depending
        # on the mapping of tokens to experts.
        column_indices = ops.topology(
            padded_bins,
            blocking,
            block_rows,
            blocks_per_row,
        )
        data = torch.empty(
            column_indices.numel(),
            blocking,
            blocking,
            dtype=torch.float16,
            device=x.device,
        )
        shape = (padded_tokens, fhs * ne)
        row_indices = stk.ops.row_indices(shape, data, offsets, column_indices)
        return stk.Matrix(shape, data, row_indices, column_indices, offsets)

    def build_input_matrix(self, sl, hs, ne):
        x = torch.randn((sl, hs)).cuda().half()

        # Assign tokens to experts uniformly.
        top_expert = torch.arange(0, sl).cuda().int() % ne

        bin_ids, indices = ops.sort(top_expert)
        tokens_per_expert = ops.histogram(top_expert, ne)
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
        out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
        return out, padded_bins

    def build_weight_matrix(self, ne, hs, fhs):
        return torch.randn((hs, ne * fhs)).cuda().half()

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        w = transpose_view(w)

        def benchmark():
            return stk.ops.sdd(x, w, topo)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::Fwd::SDD::NT',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)

        def benchmark():
            return stk.ops.dsd(topo, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::GradX::DSD::NN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        topo = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        topo = topo.t()

        def benchmark():
            return stk.ops.dsd(topo, x)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::GradW::DSD::TN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        x = self.build_sparse_matrix(x, padded_bins, fhs, ne)

        def benchmark():
            return stk.ops.dsd(x, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::Fwd::DSD::NN',
            arguments,
            mean_t,
            std_t,
            x.nnz * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        out = stk.ops.dsd(x, w)
        w = transpose_view(w)

        def benchmark():
            return stk.ops.sdd(out, w, x)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradX::SDD::NT',
            arguments,
            mean_t,
            std_t,
            x.nnz * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne):
        x, padded_bins = self.build_input_matrix(sl, hs, ne)
        w = self.build_weight_matrix(ne, hs, fhs).t().contiguous()
        x = self.build_sparse_matrix(x, padded_bins, fhs, ne)
        out = stk.ops.dsd(x, w)
        x = x.t()

        def benchmark():
            return stk.ops.dsd(x, out)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradW::DSD::TN',
            arguments,
            mean_t,
            std_t,
            x.nnz * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, hs)).cuda().half()
        w = torch.randn((ne, hs, fhs)).cuda().half()

        w = w.transpose(1, 2).contiguous()
        w = w.transpose(1, 2)

        def benchmark():
            return torch.bmm(x, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0::Fwd:DDD::NT',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, hs)).cuda().half()
        w = torch.randn((ne, hs, fhs)).cuda().half()
        out = torch.bmm(x, w)
        w = w.transpose(1, 2).contiguous()

        def benchmark():
            return torch.bmm(out, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0:GradX:DDD::NN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, hs)).cuda().half()
        w = torch.randn((ne, hs, fhs)).cuda().half()
        out = torch.bmm(x, w)
        out = out.transpose(1, 2)

        def benchmark():
            return torch.bmm(out, x)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '0:GradW:DDD::TN',
            arguments,
            mean_t,
            std_t,
            x.numel() * fhs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, fhs)).cuda().half()
        w = torch.randn((ne, fhs, hs)).cuda().half()

        def benchmark():
            return torch.bmm(x, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::Fwd::DDD::NN',
            arguments,
            mean_t,
            std_t,
            x.numel() * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, fhs)).cuda().half()
        w = torch.randn((ne, fhs, hs)).cuda().half()
        out = torch.bmm(x, w)
        w = torch.transpose(w, 1, 2)

        def benchmark():
            return torch.bmm(out, w)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradX::DDD::NT',
            arguments,
            mean_t,
            std_t,
            x.numel() * hs * 2,
        )

    @parameterized.parameters(*_MATMUL_TESTS)
    def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne):
        assert (sl % ne) == 0
        x = torch.randn((ne, sl // ne, fhs)).cuda().half()
        w = torch.randn((ne, fhs, hs)).cuda().half()
        out = torch.bmm(x, w)
        x = torch.transpose(x, 1, 2)

        def benchmark():
            return torch.bmm(x, out)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'ffn_hidden_size': fhs,
            'num_experts': ne,
        }
        log_benchmark(
            '1::GradW::DDD::TN',
            arguments,
            mean_t,
            std_t,
            x.numel() * hs * 2,
        )


if __name__ == '__main__':
    unittest.main()

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_gather.py
ADDED
@@ -0,0 +1,55 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for padded_gather kernel.
class PaddedGatherOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        ctx.save_for_backward(indices, bin_ids, bins, padded_bins)
        ctx.top_k = top_k
        return kernels.padded_gather(
            x,
            indices,
            bin_ids,
            None,
            bins,
            padded_bins,
            top_k,
        )

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()

        indices, bin_ids, bins, padded_bins = ctx.saved_tensors
        out = kernels.padded_scatter(
            grad,
            indices,
            bin_ids,
            None,
            bins,
            padded_bins,
            ctx.top_k,
        )
        return out, None, None, None, None, None


padded_gather = PaddedGatherOp.apply

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter.py
ADDED
@@ -0,0 +1,98 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for padded_scatter kernel.
class PaddedScatterOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        maybe_x = [x] if ctx.needs_input_grad[3] else []
        ctx.save_for_backward(
            indices,
            bin_ids,
            weights,
            bins,
            padded_bins,
            *maybe_x,
        )
        ctx.top_k = top_k
        ctx.x_shape = x.shape
        return kernels.padded_scatter(
            x,
            indices,
            bin_ids,
            weights,
            bins,
            padded_bins,
            top_k,
        )

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()
        saved_tensors = ctx.saved_tensors

        indices, bin_ids, weights, bins, padded_bins = saved_tensors[:5]
        dgrad = None
        if ctx.needs_input_grad[0]:
            dgrad = kernels.padded_gather(
                grad,
                indices,
                bin_ids,
                weights,
                bins,
                padded_bins,
                ctx.top_k,
            )

        wgrad = None
        if ctx.needs_input_grad[3]:  # need wgrad
            x = saved_tensors[-1]
            wgrad = kernels.padded_scatter_wgrad(
                x,
                grad,
                indices,
                bin_ids,
                bins,
                padded_bins,
                ctx.top_k,
            )
        return dgrad, None, None, wgrad, None, None, None, None


def padded_scatter(
    x: torch.Tensor,
    indices: torch.Tensor,
    bin_ids: torch.Tensor,
    weights: torch.Tensor,
    bins: torch.Tensor,
    padded_bins: torch.Tensor,
    top_k: int,
):
    return PaddedScatterOp.apply(
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        top_k,
    )

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/padded_scatter_benchmark.py
ADDED
@@ -0,0 +1,66 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import unittest

import torch
from absl.testing import parameterized

from .. import benchmark_util, ops

_PADDED_SCATTER_BENCHMARK = (
    # dMoE-Medium, 8-way EMP.
    (1024 * 16, 1024, 8, 4),
    # dMoE-Medium, post-all-to-all.
    (1024 * 16 * 4, 1024, 8, 1),
)


class PaddedScatterTest(parameterized.TestCase):

    @parameterized.parameters(*_PADDED_SCATTER_BENCHMARK)
    def testPaddedScatter(self, sl, hs, ne, top_k):
        # Create the data and indices.
        x = torch.randn((sl, hs)).cuda().half()

        # Randomly assign tokens to experts.
        top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
        bin_ids, indices = ops.sort(top_expert)
        tokens_per_expert = ops.histogram(top_expert, ne)
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)

        # Sample weights for the scatter reduce.
        weights = torch.rand((sl * top_k,)).cuda().half()

        # Gather the data to prepare for backwards.
        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

        def benchmark():
            return ops.padded_scatter(
                x,
                indices,
                bin_ids,
                weights,
                bins,
                padded_bins,
                top_k,
            )

        time, std = benchmark_util.benchmark_function(benchmark)
        benchmark_util.log_benchmark(
            'Padded Scatter',
            {
                'sequence_length': sl,
                'hidden_size': hs,
                'num_experts': ne,
                'top_k': top_k,
            },
            time,
            std,
        )


if __name__ == '__main__':
    unittest.main()

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/permute_benchmark.py
ADDED
@@ -0,0 +1,149 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import unittest

import torch
from absl.testing import parameterized

from .. import benchmark_util, ops

_PERMUTE_TESTS = (
    (16384, 768, 2),
    (16384, 768, 4),
    (16384, 768, 8),
    (16384, 768, 16),
    (16384, 768, 32),
    (16384, 768, 64),
    (16384, 768, 128),
    (16384 * 8, 768, 2),
    (16384 * 8, 768, 4),
    (16384 * 8, 768, 8),
    (16384 * 8, 768, 16),
    (16384 * 8, 768, 32),
    (16384 * 8, 768, 64),
    (16384 * 8, 768, 128),
)


class PermuteBenchmark(parameterized.TestCase):

    @parameterized.parameters(*_PERMUTE_TESTS)
    def testBinnedGather(self, sl, hs, ne):
        # NOTE: Capacity factor == 1.
        ec = sl // ne

        # Create the data and indices.
        x = torch.randn((sl, hs)).cuda().half()
        top_expert = torch.randint(0, ne, (sl,)).cuda().int()
        bin_ids, indices = ops.sort(top_expert)
        tokens_per_expert = ops.histogram(indices, ne)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)

        def benchmark():
            return ops.binned_gather(x, indices, bins, ec)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'num_experts': ne,
        }
        benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t)

    @parameterized.parameters(*_PERMUTE_TESTS)
    def testBinnedScatter(self, sl, hs, ne):
        # NOTE: Capacity factor == 1.
        ec = sl // ne

        # Create the data and indices.
        x = torch.randn((sl, hs)).cuda().half()
        top_expert = torch.randint(0, ne, (sl,)).cuda().int()
        bin_ids, indices = ops.sort(top_expert)
        tokens_per_expert = ops.histogram(indices, ne)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
        x = ops.binned_gather(x, indices, bins, ec)

        def benchmark():
            return ops.binned_scatter(x, indices, bins)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'num_experts': ne,
        }
        benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t)

    @parameterized.parameters(*_PERMUTE_TESTS)
    def testPaddedGather(self, sl, hs, ne):
        # Create the data and indices.
        x = torch.randn((sl, hs)).cuda().half()

        # Randomly assign tokens to experts.
        top_expert = torch.randint(0, ne, (sl,)).cuda().int()
        bin_ids, indices = ops.sort(top_expert)
        tokens_per_expert = ops.histogram(top_expert, ne)
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)

        def benchmark():
            return ops.padded_gather(x, indices, bin_ids, bins, padded_bins)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'num_experts': ne,
        }
        benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t)

    @parameterized.parameters(*_PERMUTE_TESTS)
    def testPaddedScatter(self, sl, hs, ne):
        # Create the data and indices.
        x = torch.randn((sl, hs)).cuda().half()

        # Randomly assign tokens to experts.
        top_expert = torch.randint(0, ne, (sl,)).cuda().int()
        bin_ids, indices = ops.sort(top_expert)
        tokens_per_expert = ops.histogram(top_expert, ne)
        padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
        padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
        bins = ops.inclusive_cumsum(tokens_per_expert, 0)
        x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins)

        def benchmark():
            return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'num_experts': ne,
        }
        benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t)

    @parameterized.parameters(*_PERMUTE_TESTS)
    def testCopy(self, sl, hs, ne):
        # NOTE: Capacity factor == 1.
        # ec = sl // ne

        # Create the data and indices.
        x = torch.randn((sl, hs)).cuda().half()
        y = x.clone()

        def benchmark():
            return y.copy_(x)

        mean_t, std_t = benchmark_util.benchmark_function(benchmark)
        arguments = {
            'sequence_length': sl,
            'hidden_size': hs,
            'num_experts': ne,
        }
        benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t)


if __name__ == '__main__':
    unittest.main()

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/repeat.py
ADDED
@@ -0,0 +1,10 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch


def repeat(x: torch.Tensor, tiling: torch.Size):
    if all((t == 1 for t in tiling)):
        return x
    return x.repeat(*tiling)

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/replicate.py
ADDED
@@ -0,0 +1,36 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any

# NOTE: Torch needs to be imported before the custom
# extensions. Otherwise libc10.so cannot be found.
import torch

# Wrap this in a try-block with better error message and
# instructions for building the c++ operations.
try:
    from .._ops import ops  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e


# Autograd wrapper for replicate kernel.
class ReplicateOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
        ctx.save_for_backward(bins)
        out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device)
        ops.replicate_forward(x, bins, out)
        return out

    @staticmethod
    def backward(ctx: Any, grad: torch.Tensor):
        bins, = ctx.saved_tensors
        out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device)
        ops.replicate_backward(grad, bins, out)
        return out, None, None


replicate = ReplicateOp.apply

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/round_up.py
ADDED
@@ -0,0 +1,14 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch


def round_up(x: torch.Tensor, value: int):
    assert isinstance(value, int)
    assert x.dtype == torch.int32

    # TODO(tgale): If this becomes an issue
    # do this in a custom kernel. We only expect
    # to use this on arrays of less than 1k elements.
    return torch.div(x + (value - 1), value, rounding_mode='trunc') * value

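A quick worked example of the rounding identity used here (input values assumed): round_up(x, value) computes ceil(x / value) * value with integer arithmetic.

    # e.g. tokens_per_expert = [1, 128, 129] with value=128 -> [128, 128, 256]
    import torch
    x = torch.tensor([1, 128, 129], dtype=torch.int32)
    print(torch.div(x + 127, 128, rounding_mode='trunc') * 128)
    # tensor([128, 128, 256], dtype=torch.int32)
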
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/scatter.py
ADDED
@@ -0,0 +1,72 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

import torch
from .stk_autocast import custom_bwd, custom_fwd

from ..backend import kernels


# Autograd wrapper for scatter kernel.
class ScatterOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(
        ctx: Any,
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ) -> torch.Tensor:
        maybe_x = [x] if ctx.needs_input_grad[3] else []
        ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x)
        ctx.top_k = top_k
        ctx.x_shape = x.shape
        return kernels.scatter(x, indices, bin_ids, weights, bins, top_k)

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, grad: torch.Tensor):
        grad = grad.contiguous()
        saved_tensors = ctx.saved_tensors

        indices, bin_ids, weights, bins = saved_tensors[:4]
        dgrad = None
        if ctx.needs_input_grad[0]:
            dgrad = kernels.gather(
                grad,
                indices,
                bin_ids,
                weights,
                bins,
                ctx.top_k,
            )

        wgrad = None
        if ctx.needs_input_grad[3]:  # need wgrad
            x = saved_tensors[-1]
            wgrad = kernels.scatter_wgrad(
                x,
                grad,
                indices,
                bin_ids,
                bins,
                ctx.top_k,
            )
        return dgrad, None, None, wgrad, None, None, None


def scatter(
    x: torch.Tensor,
    indices: torch.Tensor,
    bin_ids: torch.Tensor,
    weights: torch.Tensor,
    bins: torch.Tensor,
    top_k: int,
) -> Optional[torch.Tensor]:
    return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k)

build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort.py
ADDED
@@ -0,0 +1,38 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Optional, Tuple
+
+# NOTE: Torch needs to be imported before the custom
+# extensions. Otherwise libc10.so cannot be found.
+import torch
+
+# Wrap this in a try-block with better error message and
+# instructions for building the c++ operations.
+try:
+    from .._ops import ops  # type: ignore
+except ModuleNotFoundError as e:
+    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
+
+_BITS_FOR_DTYPE = {
+    torch.int16: 16,
+    torch.int32: 32,
+    torch.int64: 64,
+}
+
+
+# Autograd wrapper for sort kernel.
+# NOTE: Does not support gradients.
+class SortOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if end_bit is None:
+            end_bit = _BITS_FOR_DTYPE[x.dtype]
+        x_out = torch.empty_like(x)
+        iota_out = torch.empty_like(x)
+        ops.sort(x, end_bit, x_out, iota_out)
+        return (x_out, iota_out)
+
+
+sort = SortOp.apply
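
The op radix-sorts a flat integer tensor and also returns the original position of each sorted element, which is how the MoE layers recover the token permutation. end_bit bounds how many low-order bits the radix sort visits, so callers that know the maximum value (for example, the number of experts) can pass a tighter bound. A small hypothetical call, assuming the compiled extension and a GPU:

import torch
from megablocks import ops

expert_ids = torch.tensor([3, 1, 2, 1, 0, 2], dtype=torch.int32, device='cuda')
# Values are < 4, so 2 bits are enough for the radix sort.
sorted_ids, source_rows = ops.sort(expert_ids, 2)
# sorted_ids  -> [0, 1, 1, 2, 2, 3]
# source_rows -> original index in expert_ids of each sorted element
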
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sort_benchmark.py
ADDED
@@ -0,0 +1,85 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+
+import unittest
+
+import numpy as np
+import torch
+from absl.testing import parameterized
+
+from .. import ops
+
+_SORT_TESTS = (
+    (16384, torch.int32, None),
+    (16384, torch.int32, 2),
+    (16384, torch.int32, 128),
+)
+
+_BASELINE_SORT_TESTS = ((16384,),)
+
+
+def numpy_dtype(dtype):
+    types = {
+        torch.int16: np.int16,
+        torch.int32: np.int32,
+        torch.int64: np.int64,
+    }
+    return types[dtype]
+
+
+def benchmark_function(fn, iterations=10):
+    # Run once to get rid of startup overhead.
+    fn()
+    times = []
+    for _ in range(iterations):
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+        fn()
+        end.record()
+        torch.cuda.synchronize()
+        times.append(start.elapsed_time(end))
+    times = np.array(times)
+    return times.mean(), times.std(), times.max(), times.min()
+
+
+def log_benchmark(arguments, mean_t, std_t):
+    print('=' * 60)
+    print('Benchmark Parameters:')
+    for (key, value) in arguments.items():
+        print(f'{key} = {value}')
+    print('Results:')
+    print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t))
+    print('=' * 60)
+
+
+class SortBenchmark(parameterized.TestCase):
+
+    @parameterized.parameters(*_SORT_TESTS)
+    def testSort(self, n, dtype, max_val):
+        if max_val is None:
+            max_val = np.iinfo(numpy_dtype(dtype)).max
+        end_bit = int(np.ceil(np.log2(max_val)))
+        x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
+
+        mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),)
+        arguments = {
+            'n': n,
+            'dtype': dtype,
+            'max_val': max_val,
+        }
+        log_benchmark(arguments, mean_t, std_t)
+
+    @parameterized.parameters(*_BASELINE_SORT_TESTS)
+    def testTorchSort(self, n):
+        x = torch.randint(0, 128, (n,)).cuda().to(torch.int32)
+
+        mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x))
+        arguments = {
+            'n': n,
+        }
+        log_benchmark(arguments, mean_t, std_t)
+
+
+if __name__ == '__main__':
+    unittest.main()
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/stk_autocast.py
ADDED
@@ -0,0 +1,39 @@
+# vendored from
+# https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/autocast.py
+import functools
+import torch
+
+
+def _is_eligible(x):
+    return x.is_floating_point() and x.is_cuda and (x.dtype is not torch.float64)
+
+
+def _cast(x, dtype):
+    if isinstance(x, torch.Tensor) and _is_eligible(x):
+        return x.to(dtype)
+    elif isinstance(x, map):
+        return {_cast(k, dtype): _cast(v, dtype) for k, v in x.items()}
+    elif isinstance(x, list) or isinstance(x, tuple):
+        return type(x)(map(lambda y: _cast(y, dtype), x))
+    return x
+
+
+def custom_fwd(fwd):
+    """Wrap a custom autograd function that always uses autocast dtype."""
+
+    @functools.wraps(fwd)
+    def decorate_fwd(*args, **kwargs):
+        if torch.is_autocast_enabled():
+            with torch.autocast(device_type="cuda", enabled=False):
+                dtype = torch.get_autocast_gpu_dtype()
+                return fwd(*_cast(args, dtype), **_cast(kwargs, dtype))
+        return fwd(*args, **kwargs)
+    return decorate_fwd
+
+
+def custom_bwd(bwd):
+    @functools.wraps(bwd)
+    def decorate_bwd(*args, **kwargs):
+        with torch.autocast(device_type="cuda", enabled=False):
+            return bwd(*args, **kwargs)
+    return decorate_bwd
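
These decorators mirror the torch.cuda.amp custom_fwd/custom_bwd pattern: under autocast, the wrapped forward receives tensor arguments already cast to the active autocast dtype and then runs with autocast disabled, which is what the hand-written kernels expect. (Note that _cast's isinstance(x, map) branch never matches a dict, so dict arguments pass through uncast.) A minimal hypothetical Function using the decorators, assuming a GPU and that the module is importable from the built wheel:

import torch
from megablocks.ops.stk_autocast import custom_bwd, custom_fwd

class ScaleOp(torch.autograd.Function):

    @staticmethod
    @custom_fwd
    def forward(ctx, x, scale: float):
        ctx.scale = scale
        return x * scale            # x arrives already cast to the autocast dtype

    @staticmethod
    @custom_bwd
    def backward(ctx, grad):
        return grad * ctx.scale, None

x = torch.randn(4, 4, device='cuda', requires_grad=True)
with torch.autocast(device_type='cuda', dtype=torch.float16):
    y = ScaleOp.apply(x, 2.0)       # forward sees a float16 copy of x
assert y.dtype == torch.float16
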
build/torch27-cxx11-rocm63-x86_64-linux/megablocks/ops/sum.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright 2024 Databricks
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+def sum(x: torch.Tensor, dim: int = 0):
+    if x.shape[dim] == 1:
+        return x.squeeze(dim=dim)
+    return x.sum(dim=dim)
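
The size-1 fast path avoids launching a reduction kernel when the summed dimension is already a singleton. A tiny worked example of the two branches, using a local copy of the helper so it runs with plain torch on CPU:

import torch

def _sum(x: torch.Tensor, dim: int = 0):
    # Mirrors megablocks.ops.sum above.
    if x.shape[dim] == 1:
        return x.squeeze(dim=dim)
    return x.sum(dim=dim)

print(_sum(torch.arange(6.0).reshape(1, 6), dim=0))  # squeeze only -> tensor([0., 1., 2., 3., 4., 5.])
print(_sum(torch.arange(6.0).reshape(2, 3), dim=0))  # real reduction -> tensor([3., 5., 7.])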