Akos Hadnagy committed
Commit: 1e1ffe8
Parent(s): ff615fc

WIP

This view is limited to 50 files because it contains too many changes.
- build.toml +35 -0
- csrc/bak.ops.cu +21 -0
- csrc/cuda_util.h +62 -0
- csrc/cumsum.h +163 -0
- csrc/grouped_gemm/fill_arguments.cuh +141 -0
- csrc/grouped_gemm/grouped_gemm.cu +567 -0
- csrc/grouped_gemm/grouped_gemm.h +20 -0
- csrc/grouped_gemm/ops.cu +11 -0
- csrc/histogram.h +86 -0
- csrc/indices.h +95 -0
- csrc/new_cumsum.cu +161 -0
- csrc/new_cumsum.h +11 -0
- csrc/new_histogram.cu +85 -0
- csrc/new_histogram.h +10 -0
- csrc/new_indices.cu +97 -0
- csrc/new_indices.h +14 -0
- csrc/new_replicate.cu +220 -0
- csrc/new_replicate.h +17 -0
- csrc/new_sort.cu +90 -0
- csrc/new_sort.h +13 -0
- csrc/replicate.h +211 -0
- csrc/sort.h +91 -0
- flake.lock +168 -0
- flake.nix +24 -0
- tests/__init__.py +0 -0
- tests/conftest.py +110 -0
- tests/fixtures/autouse.py +107 -0
- tests/fixtures/fixtures.py +13 -0
- tests/layer_test.py +53 -0
- tests/layers/architectures.py +53 -0
- tests/layers/moe_test.py +199 -0
- tests/ops/binned_gather_test.py +71 -0
- tests/ops/binned_scatter_test.py +87 -0
- tests/ops/cumsum_test.py +44 -0
- tests/ops/histogram_test.py +82 -0
- tests/ops/padded_gather_test.py +94 -0
- tests/ops/padded_scatter_test.py +155 -0
- tests/ops/replicate_test.py +108 -0
- tests/ops/sort_test.py +65 -0
- tests/ops/topology_test.py +81 -0
- tests/ops_test.py +171 -0
- tests/parallel_layer_test.py +94 -0
- tests/test_gg.py +57 -0
- tests/test_mb_moe.py +48 -0
- tests/test_mb_moe_shared_expert.py +139 -0
- tests/test_mb_moe_shared_expert_multi.py +200 -0
- torch-ext/megablocks/__init__.py +202 -0
- torch-ext/megablocks/_layers/__init__.py +10 -0
- torch-ext/megablocks/_layers/activation_fn.py +33 -0
- torch-ext/megablocks/_layers/all_to_all.py +54 -0
build.toml
ADDED
@@ -0,0 +1,35 @@
+[general]
+name = "megablocks"
+universal = false
+
+[torch]
+src = [
+  "torch-ext/torch_binding.cpp",
+  "torch-ext/torch_binding.h"
+]
+
+[kernel.megablocks]
+backend = "rocm"
+rocm-archs = [
+  "gfx942",
+  "gfx1030",
+  "gfx1100",
+  "gfx1101",
+]
+depends = ["torch"]
+src = [
+  "csrc/new_cumsum.h",
+  "csrc/new_cumsum.cu",
+  "csrc/new_histogram.h",
+  "csrc/new_histogram.cu",
+  "csrc/new_indices.h",
+  "csrc/new_indices.cu",
+  "csrc/new_replicate.cu",
+  "csrc/new_replicate.h",
+  "csrc/new_sort.h",
+  "csrc/new_sort.cu",
+  # vendored grouped gemm
+  #"csrc/grouped_gemm/fill_arguments.cuh",
+  #"csrc/grouped_gemm/grouped_gemm.cu",
+  #"csrc/grouped_gemm/grouped_gemm.h",
+]
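Note: the [torch] sources reference torch-ext/torch_binding.cpp, which is not part of this diff. Purely for orientation, a minimal sketch of what such a binding translation unit could look like, mirroring the pybind11 style of csrc/bak.ops.cu below; the op set and declarations here are assumptions, not the actual file contents.

// Hypothetical sketch only; the real torch-ext/torch_binding.cpp is not shown in this diff.
#include <torch/extension.h>

namespace megablocks {
// Forward declarations of entry points defined under csrc/ (see new_cumsum.h, histogram.h).
void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);
torch::Tensor histogram(torch::Tensor x, int num_bins);
}  // namespace megablocks

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("exclusive_cumsum", &megablocks::exclusive_cumsum, "batched exclusive cumsum.");
  m.def("histogram", &megablocks::histogram, "even width histogram.");
}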
csrc/bak.ops.cu
ADDED
@@ -0,0 +1,21 @@
+#include "cumsum.h"
+#include "histogram.h"
+#include "indices.h"
+#include "replicate.h"
+#include "sort.h"
+
+#include <torch/extension.h>
+
+namespace megablocks {
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum.");
+  m.def("histogram", &histogram, "even width histogram.");
+  m.def("inclusive_cumsum", &inclusive_cumsum, "batched inclusive cumsum");
+  m.def("indices", &indices, "indices construction for sparse matrix.");
+  m.def("replicate_forward", &replicate_forward, "(fwd) replicate a vector dynamically.");
+  m.def("replicate_backward", &replicate_backward, "(bwd) replicate a vector dynamically.");
+  m.def("sort", &sort, "key/value sort.");
+}
+
+} // namespace megablocks
csrc/cuda_util.h
ADDED
@@ -0,0 +1,62 @@
+#ifndef BLOCKPARTY_CSRC_CUDA_UTIL_H_
+#define BLOCKPARTY_CSRC_CUDA_UTIL_H_
+
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+// #include <torch/extension.h>
+
+namespace megablocks {
+
+typedef __half2 half2;
+
+struct __align__(8) half4 {
+  half2 x, y;
+};
+
+struct __align__(16) half8 {
+  half2 x, y, z, w;
+};
+
+template <class To, class From>
+__device__ __forceinline__ To BitCast(const From& src) noexcept {
+  To dst;
+  std::memcpy(&dst, &src, sizeof(To));
+  return dst;
+}
+
+template <typename T>
+__device__ __forceinline__ void Store(const T& value, T* ptr) {
+  *ptr = value;
+}
+
+template <typename T>
+__device__ __forceinline__ T Load(const T* address) {
+  return __ldg(address);
+}
+
+__device__ __forceinline__ half4 Load(const half4* address) {
+  float2 x = __ldg(reinterpret_cast<const float2*>(address));
+  return BitCast<half4>(x);
+}
+
+__device__ __forceinline__ half8 Load(const half8* address) {
+  float4 x = __ldg(reinterpret_cast<const float4*>(address));
+  return BitCast<half8>(x);
+}
+
+template <typename T>
+__device__ __forceinline__ T Zero() { return 0; };
+
+template <>
+__device__ __forceinline__ half2 Zero<half2>() {
+  return {(c10::Half)0., (c10::Half)0.};
+};
+
+template <>
+__device__ __forceinline__ half4 Zero<half4>() {
+  return {Zero<half2>(), Zero<half2>()};
+};
+
+} // namespace megablocks
+
+#endif  // BLOCKPARTY_CSRC_CUDA_UTIL_H_
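For reference, a minimal device-side sketch (not part of the commit) of how the helpers above are meant to compose: `Load` on a `half8*` issues a single 16-byte `__ldg`, and `Store` writes the packed value back. The kernel name and shapes are hypothetical.

// Hypothetical example kernel using the vectorized Load/Store helpers above.
#include "cuda_util.h"

namespace megablocks {

// Copies `n` half8 elements (8 fp16 values each) from `in` to `out`.
__global__ void CopyHalf8Kernel(const half8* __restrict__ in,
                                half8* __restrict__ out,
                                int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) {
    half8 v = Load(in + idx);   // one 16-byte read via __ldg
    Store(v, out + idx);        // one 16-byte write
  }
}

}  // namespace megablocks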
csrc/cumsum.h
ADDED
@@ -0,0 +1,163 @@
+#define CUB_IGNORE_DEPRECATED_API
+
+#undef CUB_WRAPPED_NAMESPACE
+#define CUB_WRAPPED_NAMESPACE megablocks
+
+#include <cstdint>
+
+#include <cub/cub.cuh>
+#include <c10/cuda/CUDAStream.h>
+#include <torch/all.h>
+// #include <torch/extension.h>
+
+#define CUDA_CALL(code)                                     \
+  do {                                                      \
+    cudaError_t status = code;                              \
+    std::string err = cudaGetErrorString(status);           \
+    TORCH_CHECK(status == cudaSuccess, err);                \
+  } while (0)
+
+namespace megablocks {
+
+struct Inclusive {};
+struct Exclusive {};
+
+template <typename Type> struct Cumsum {
+
+  template<
+    typename InputIteratorT,
+    typename OutputIteratorT>
+  static void Run(void * d_temp_storage,
+                  size_t & temp_storage_bytes,
+                  InputIteratorT d_in,
+                  OutputIteratorT d_out,
+                  int num_items,
+                  cudaStream_t stream = 0,
+                  bool debug_synchronous = false) {
+    CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage,
+                                            temp_storage_bytes,
+                                            d_in,
+                                            d_out,
+                                            num_items,
+                                            stream));//,
+                                            //debug_synchronous));
+  }
+};
+
+template <> struct Cumsum<Inclusive> {
+  template<
+    typename InputIteratorT,
+    typename OutputIteratorT>
+  static void Run(void * d_temp_storage,
+                  size_t & temp_storage_bytes,
+                  InputIteratorT d_in,
+                  OutputIteratorT d_out,
+                  int num_items,
+                  cudaStream_t stream = 0,
+                  bool debug_synchronous = false) {
+    CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage,
+                                            temp_storage_bytes,
+                                            d_in,
+                                            d_out,
+                                            num_items,
+                                            stream));//,
+                                            //debug_synchronous));
+  }
+};
+
+template <typename SumType, typename T>
+void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
+  // Get temporary storage size.
+  size_t scratchpad_bytes = 0;
+  Cumsum<SumType>::Run(nullptr,
+                       scratchpad_bytes,
+                       x.data_ptr<T>(),
+                       out.data_ptr<T>(),
+                       x.size(1),
+                       c10::cuda::getCurrentCUDAStream());
+
+  // Allocate scratchpad.
+  //
+  // NOTE: We scale for the batch dimension so we can run in parallel.
+  auto options = torch::TensorOptions()
+    .dtype(torch::kInt8)
+    .device(x.device());
+  torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
+                                          options);
+
+  // Run the kernel.
+  //
+  // NOTE: Using different streams for each issue does not appear to
+  // yield performance gains for our problem set. The overhead of
+  // event/stream synchronization appears to outweigh the benefits.
+  // We could write a true batched cumsum, but this would require
+  // significant code duplication from cub and we might move away
+  // from this formulation anyways.
+  for (int i = 0; i < x.size(0); ++i) {
+    void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
+    Cumsum<SumType>::Run(scratchpad_ptr,
+                         scratchpad_bytes,
+                         x.data_ptr<T>() + x.size(1) * i,
+                         out.data_ptr<T>() + x.size(1) * i,
+                         x.size(1),
+                         c10::cuda::getCurrentCUDAStream());
+  }
+}
+
+void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
+  // Validate the input matrix.
+  TORCH_CHECK(x.is_cuda());
+  TORCH_CHECK(x.ndimension() == 2);
+  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
+              x.scalar_type() == torch::kInt32 ||
+              x.scalar_type() == torch::kInt64);
+  TORCH_CHECK(out.is_cuda());
+  TORCH_CHECK(out.ndimension() == 2);
+  TORCH_CHECK(out.scalar_type() == x.scalar_type());
+
+  // NOTE: We currently only support contraction across the contiguous
+  // dimension in the matrix.
+  TORCH_CHECK(dim == 1);
+
+  switch (x.scalar_type()) {
+    case torch::kInt16:
+      cub_cumsum<Exclusive, short>(x, dim, out);
+      return;
+    case torch::kInt32:
+      cub_cumsum<Exclusive, int>(x, dim, out);
+      return;
+  }
+  TORCH_CHECK(x.scalar_type() == torch::kInt64);
+  cub_cumsum<Exclusive, long>(x, dim, out);
+}
+
+void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
+  // Validate the input matrix.
+  TORCH_CHECK(x.is_cuda());
+  TORCH_CHECK(x.ndimension() == 2);
+  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
+              x.scalar_type() == torch::kInt32 ||
+              x.scalar_type() == torch::kInt64);
+  TORCH_CHECK(out.is_cuda());
+  TORCH_CHECK(out.ndimension() == 2);
+  TORCH_CHECK(out.scalar_type() == x.scalar_type());
+
+  // NOTE: We currently only support contraction across the contiguous
+  // dimension in the matrix.
+  TORCH_CHECK(dim == 1);
+
+  switch (x.scalar_type()) {
+    case torch::kInt16:
+      cub_cumsum<Inclusive, short>(x, dim, out);
+      return;
+    case torch::kInt32:
+      cub_cumsum<Inclusive, int>(x, dim, out);
+      return;
+  }
+  TORCH_CHECK(x.scalar_type() == torch::kInt64);
+  cub_cumsum<Inclusive, long>(x, dim, out);
+}
+
+} // namespace megablocks
+
+#undef CUB_WRAPPED_NAMESPACE
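A small host-side usage sketch for the two entry points above (illustrative only, assuming a CUDA device and that this header is included): both functions scan along dim 1 of a 2-D integer tensor into a preallocated output of the same shape and dtype.

// Hypothetical usage of megablocks::exclusive_cumsum / inclusive_cumsum.
#include <torch/torch.h>
#include "cumsum.h"

void cumsum_example() {
  // One batch row of ones: [1, 1, 1, 1].
  auto x = torch::ones({1, 4}, torch::dtype(torch::kInt32).device(torch::kCUDA));
  auto out = torch::empty_like(x);

  megablocks::exclusive_cumsum(x, /*dim=*/1, out);  // out = [[0, 1, 2, 3]]
  megablocks::inclusive_cumsum(x, /*dim=*/1, out);  // out = [[1, 2, 3, 4]]
}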
csrc/grouped_gemm/fill_arguments.cuh
ADDED
@@ -0,0 +1,141 @@
+#pragma once
+
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <cub/cub.cuh>
+#include <cutlass/bfloat16.h>
+#include <cutlass/gemm_coord.h>
+
+namespace grouped_gemm {
+
+constexpr int kDynamicDim = -1;
+constexpr int kMaxExperts = 512;
+
+struct GemmProblem {
+  ::cutlass::gemm::GemmCoord dims;
+  int64_t lda, ldb, ldc;
+  // All offsets are in elements.
+  int64_t a_offset, b_offset, c_offset;
+};
+
+// TODO: revisit `ExtractGemmProblemK` struct
+// struct ExtractGemmProblemK {
+//   __device__ ::cuda::std::tuple<int&> operator()(GemmProblem& problem) const {
+//     return {problem.dims.k()};
+//   }
+// };
+
+template <
+  // If `k` is dynamic, we sort the problems by `k` in descending order.
+  // Otherwise, `m` is dynamic, and no sorting happens.
+  bool kDynamicK,
+  typename ElementA, typename ElementB, typename ElementC,
+  typename LayoutA, typename LayoutB, typename LayoutC,
+  typename Args
+>
+__global__ void FillArguments(
+  int num_experts, const int64_t* batch_sizes,
+  ElementA* ptr_a, ElementB* ptr_b, ElementC* ptr_c,
+  Args args, ::cutlass::gemm::GemmCoord dims
+) {
+  const int expert_idx = threadIdx.x;
+  const int batch_size = expert_idx < num_experts ? batch_sizes[expert_idx] : -1;
+
+  if (kDynamicK) {
+    assert(dims.k() == kDynamicDim);
+    dims.k() = batch_size;
+  } else {
+    assert(dims.m() == kDynamicDim);
+    dims.m() = batch_size;
+  }
+
+  using BlockScan = cub::BlockScan<int, kMaxExperts>;
+  using BlockSort = cub::BlockRadixSort<int, kMaxExperts, 1, GemmProblem>;
+
+  union SharedMemory {
+    typename BlockScan::TempStorage scan_storage;
+    typename BlockSort::TempStorage sort_storage;
+  };
+  __shared__ SharedMemory shared_memory;
+
+  int dynamic_dim = kDynamicK ? dims.k() : dims.m();
+  int dynamic_dim_cumsum;
+  BlockScan(shared_memory.scan_storage).ExclusiveSum(dynamic_dim, dynamic_dim_cumsum);
+  __syncthreads();
+
+  // We have to use `GemmProblem[1]` here instead of just `GemmProblem` because `SortDescending()` expects
+  // `KeyT (&)[ITEMS_PER_THREAD]` for the `keys` argument (i.e., `GemmProblem (&keys)[1]` in our case).
+  GemmProblem problem[1] = {
+    GemmProblem {
+      .dims = dims,
+      .lda = LayoutA::packed({dims.m(), dims.k()}).stride(0),
+      .ldb = LayoutB::packed({dims.k(), dims.n()}).stride(0),
+      .ldc = LayoutC::packed({dims.m(), dims.n()}).stride(0),
+      .a_offset = kDynamicK
+        ? (dims.m() * dynamic_dim_cumsum)
+        : (dynamic_dim_cumsum * dims.k()),
+      .b_offset = (kDynamicK ? dynamic_dim_cumsum : expert_idx * dims.k()) * dims.n(),
+      .c_offset = (kDynamicK ? expert_idx * dims.m() : dynamic_dim_cumsum) * dims.n(),
+    },
+  };
+
+  if constexpr (kDynamicK) {
+    // Sort by k dimension in descending order
+    // We need to extract the key (k value) for sorting
+    int k_keys[1] = { problem[0].dims.k() };
+
+    BlockSort(shared_memory.sort_storage).SortDescending(k_keys, problem);
+
+    // TODO: revisit original impl without `__syncthreads()`
+    // BlockSort(shared_memory.sort_storage).SortDescending(problem, ExtractGemmProblemK{});
+    // Quoting the CUB documentation (https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockRadixSort.html):
+    // > A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage [...]
+    // > is **to be reused or repurposed**.
+    // We don't need `__syncthreads()` here, since we don't do either of these things.
+  }
+
+  if (expert_idx < num_experts) {
+    args.problem_sizes[expert_idx] = problem[0].dims;
+    args.lda[expert_idx] = problem[0].lda;
+    args.ldb[expert_idx] = problem[0].ldb;
+    args.ldc[expert_idx] = problem[0].ldc;
+
+    args.ptr_A[expert_idx] = ptr_a + problem[0].a_offset;
+    args.ptr_B[expert_idx] = ptr_b + problem[0].b_offset;
+    args.ptr_C[expert_idx] = ptr_c + problem[0].c_offset;
+  }
+}
+
+template <typename Args>
+__global__ void ZeroOutK0Outputs(int num_experts, Args args) {
+  const int64_t start_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t delta = (int64_t)gridDim.x * blockDim.x;
+  for (int ei = 0; ei < num_experts; ++ei) {
+    auto& dims = args.problem_sizes[ei];
+    // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
+    // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
+    // * (here) set the output to zero
+    // * (in `IgnoreK0Problems`) make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
+    if (dims.k() == 0) {
+      // Assume packed layout, run a grid-strided loop over the output.
+      int64_t total_elems = (int64_t)dims.m() * dims.n();
+      auto* out = args.ptr_C[ei];
+      for (int64_t idx = start_idx; idx < total_elems; idx += delta) {
+        out[idx] = {};
+      }
+    }
+  }
+}
+
+template <typename Args>
+__global__ void IgnoreK0Problems(int num_experts, Args args) {
+  const int expert_idx = threadIdx.x;
+  if (expert_idx < num_experts) {
+    auto& dims = args.problem_sizes[expert_idx];
+    if (dims.k() == 0) {
+      dims.m() = 0;
+      dims.n() = 0;
+    }
+  }
+}
+
+} // namespace grouped_gemm
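To make the offset arithmetic in `FillArguments` concrete, here is a small host-side illustration (not part of the commit; the sizes are made up) of the `kDynamicK` case for two experts with `batch_sizes = {2, 3}`: the exclusive scan over the dynamic dimension gives each expert its starting offset into the flat `a`, `b`, and `c` buffers, exactly as `BlockScan::ExclusiveSum` does on the device.

// Hypothetical host-side illustration of the per-expert offsets computed in
// FillArguments for the kDynamicK (dynamic-k) case.
#include <cstdint>
#include <cstdio>

int main() {
  const int m = 4, n = 8;                 // fixed output dims per expert
  const int64_t batch_sizes[2] = {2, 3};  // dynamic k per expert

  int64_t k_cumsum = 0;  // exclusive scan over k
  for (int e = 0; e < 2; ++e) {
    const int64_t k = batch_sizes[e];
    const int64_t a_offset = m * k_cumsum;        // mirrors .a_offset = dims.m() * dynamic_dim_cumsum
    const int64_t b_offset = k_cumsum * n;        // mirrors .b_offset = dynamic_dim_cumsum * dims.n()
    const int64_t c_offset = (int64_t)e * m * n;  // mirrors .c_offset = expert_idx * dims.m() * dims.n()
    std::printf("expert %d: k=%lld a_off=%lld b_off=%lld c_off=%lld\n",
                e, (long long)k, (long long)a_offset, (long long)b_offset,
                (long long)c_offset);
    k_cumsum += k;
  }
  return 0;
}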
csrc/grouped_gemm/grouped_gemm.cu
ADDED
@@ -0,0 +1,567 @@
+#include "grouped_gemm.h"
+#include "fill_arguments.cuh"
+
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/detail/KernelUtils.h>
+#include <c10/util/BFloat16.h>
+#include <c10/cuda/CUDAStream.h>
+#include <cub/cub.cuh>
+#include <torch/torch.h>
+
+#include "cutlass/bfloat16.h"
+#include "cutlass/complex.h"
+#include "cutlass/gemm/kernel/gemm_grouped.h"
+#include "cutlass/gemm/kernel/default_gemm_grouped.h"
+#include "cutlass/gemm/device/gemm_grouped.h"
+
+#include <type_traits>
+
+namespace grouped_gemm {
+
+#define CUDA_CALL(code)                                     \
+  do {                                                      \
+    cudaError_t status = code;                              \
+    std::string err = cudaGetErrorString(status);           \
+    TORCH_CHECK(status == cudaSuccess, err);                \
+  } while (0)
+
+#define CUBLAS_CALL(code)                                         \
+  do {                                                            \
+    cublasStatus_t status = code;                                 \
+    TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "CuBLAS Error"); \
+  } while (0)
+
+#define GROUPED_GEMM_STRINGIFY_HELPER(x) #x
+#define GROUPED_GEMM_STRINGIFY(x) \
+  GROUPED_GEMM_STRINGIFY_HELPER(x)
+
+template <bool trans>
+using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
+
+using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
+  ::cutlass::arch::OpClassTensorOp,
+  ::cutlass::arch::Sm80,
+  ::cutlass::bfloat16_t,
+  ::cutlass::bfloat16_t,
+  ::cutlass::bfloat16_t,
+  float
+>;
+
+// TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
+template <bool trans_a, bool trans_b>
+using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
+  // A operand.
+  ::cutlass::bfloat16_t,
+  GroupedGemmInputLayout<trans_a>,
+  ::cutlass::ComplexTransform::kNone,
+  GroupedGemmConfig::kAlignmentA,
+  // B operand.
+  ::cutlass::bfloat16_t,
+  GroupedGemmInputLayout<trans_b>,
+  ::cutlass::ComplexTransform::kNone,
+  GroupedGemmConfig::kAlignmentB,
+  // C operand.
+  ::cutlass::bfloat16_t,
+  ::cutlass::layout::RowMajor,
+  float,
+  ::cutlass::arch::OpClassTensorOp,
+  ::cutlass::arch::Sm80,
+  GroupedGemmConfig::ThreadblockShape,
+  GroupedGemmConfig::WarpShape,
+  GroupedGemmConfig::InstructionShape,
+  GroupedGemmConfig::EpilogueOutputOp,
+  // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
+  // This parameter is passed in at present to match the APIs of other kernels. The parameter
+  // is unused within the kernel.
+  ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
+  // TODO(tgale): Tune this for SM90.
+  GroupedGemmConfig::kStages>::GemmKernel;
+
+template <bool trans_a, bool trans_b>
+using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;
+
+template <typename T>
+torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) {
+  size_t bytes = x.size() * sizeof(T);
+  auto options = torch::TensorOptions().dtype(torch::kInt8).device(device);
+  torch::Tensor out = torch::empty(bytes, options);
+
+  CUDA_CALL(cudaMemcpyAsync(out.data_ptr(),
+                            x.data(), bytes,
+                            cudaMemcpyHostToDevice,
+                            c10::cuda::getCurrentCUDAStream()));
+  return out;
+}
+
+template <typename T>
+static void ReorderArray(T* data, const std::vector<size_t>& indices) {
+  // For now, simply create a copy of the data and then copy over to the original.
+  std::vector<T> copy(data, data + indices.size());
+  for (size_t i = 0; i < indices.size(); ++i) {
+    data[i] = copy.at(indices[i]);
+  }
+}
+
+template <typename T>
+torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) {
+  return torch::empty(numel * sizeof(T), torch::dtype(torch::kInt8).device(device));
+}
+
+struct RawGemmArguments {
+  torch::Tensor lda, ldb, ldc, ptr_a, ptr_b, ptr_c, problem_sizes;
+  int threadblock_count{};
+};
+
+template <
+  typename Gemm,
+  typename ElementA, typename ElementB, typename ElementC
+>
+RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) {
+  TORCH_CHECK(
+    num_experts <= kMaxExperts,
+    "At most ", kMaxExperts,
+    " experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts
+  );
+
+  return RawGemmArguments {
+    .lda = TypedEmpty<int64_t>(num_experts, device),
+    .ldb = TypedEmpty<int64_t>(num_experts, device),
+    .ldc = TypedEmpty<int64_t>(num_experts, device),
+    .ptr_a = TypedEmpty<ElementA*>(num_experts, device),
+    .ptr_b = TypedEmpty<ElementB*>(num_experts, device),
+    .ptr_c = TypedEmpty<ElementC*>(num_experts, device),
+    .problem_sizes = TypedEmpty<cutlass::gemm::GemmCoord>(num_experts, device),
+
+    // We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here.
+    .threadblock_count = Gemm::sufficient(),
+  };
+}
+
+template <
+  bool kDynamicK,
+  typename Gemm,
+  typename ElementA, typename ElementB, typename ElementC,
+  typename LayoutA, typename LayoutB, typename LayoutC
+>
+RawGemmArguments MakeArgumentsOnHost(torch::Tensor a,
+                                     torch::Tensor b,
+                                     torch::Tensor c,
+                                     torch::Tensor batch_sizes,
+                                     ::cutlass::gemm::GemmCoord coord_template,
+                                     int64_t num_experts) {
+  std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts);
+
+  // Create the host arrays of leading dimension data and pointer data.
+  std::vector<int64_t> lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts);
+  int64_t elements_a = 0, elements_b = 0, elements_c = 0;
+
+  std::vector<ElementA *> ptr_a_host(num_experts), ptr_b_host(num_experts), ptr_c_host(num_experts);
+
+  for (int i = 0; i < num_experts; ++i) {
+    auto& problem = problem_sizes_host[i];
+    problem = coord_template;
+    (kDynamicK ? problem.k() : problem.m()) = batch_sizes.data_ptr<int64_t>()[i];
+
+    lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
+    ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
+    ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
+
+    ptr_a_host[i] = (ElementA*)a.data_ptr() + elements_a;
+    ptr_b_host[i] = (ElementB*)b.data_ptr() + elements_b;
+    ptr_c_host[i] = (ElementC*)c.data_ptr() + elements_c;
+
+    elements_a += problem.m() * problem.k();
+    elements_b += problem.k() * problem.n();
+    elements_c += problem.m() * problem.n();
+
+    if (problem.k() == 0) {
+      // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
+      // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
+      // * set the output to zero with `cudaMemsetAsync()`
+      // * make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
+      CUDA_CALL(cudaMemsetAsync(ptr_c_host[i],
+                                0,
+                                problem.m() * problem.n() * sizeof(ElementC),
+                                c10::cuda::getCurrentCUDAStream()));
+
+      problem.m() = 0;
+      problem.n() = 0;
+    }
+  }
+
+  // Only sort problems when the K dimensions differ.
+  if (kDynamicK) {
+    std::vector<size_t> indices(num_experts);
+    std::iota(indices.begin(), indices.end(), 0);
+    std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
+      return problem_sizes_host[i].k() > problem_sizes_host[j].k();
+    });
+
+    ReorderArray(problem_sizes_host.data(), indices);
+    ReorderArray(lda_host.data(), indices);
+    ReorderArray(ldb_host.data(), indices);
+    ReorderArray(ldc_host.data(), indices);
+    ReorderArray(ptr_a_host.data(), indices);
+    ReorderArray(ptr_b_host.data(), indices);
+    ReorderArray(ptr_c_host.data(), indices);
+  }
+
+  // Copy the problem sizes, pointers and leading dimension data to the device.
+  return RawGemmArguments {
+    .lda = CopyToDevice(lda_host, a.device()),
+    .ldb = CopyToDevice(ldb_host, a.device()),
+    .ldc = CopyToDevice(ldc_host, a.device()),
+    .ptr_a = CopyToDevice(ptr_a_host, a.device()),
+    .ptr_b = CopyToDevice(ptr_b_host, a.device()),
+    .ptr_c = CopyToDevice(ptr_c_host, a.device()),
+    .problem_sizes = CopyToDevice(problem_sizes_host, a.device()),
+
+    // We know the problem dimensions on the host, so we can calculate the number of threadblocks based on that.
+    .threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts),
+  };
+}
+
+template <
+  bool kDynamicK,
+  typename Gemm,
+  typename ElementA, typename ElementB, typename ElementC,
+  typename LayoutA, typename LayoutB, typename LayoutC
+>
+typename Gemm::Arguments MakeArguments(torch::Tensor a,
+                                       torch::Tensor b,
+                                       torch::Tensor c,
+                                       torch::Tensor batch_sizes,
+                                       ::cutlass::gemm::GemmCoord coord_template,
+                                       int64_t num_experts) {
+  RawGemmArguments raw_args;
+  if (batch_sizes.is_cuda()) {
+    raw_args = MakeArgumentsOnDevice<
+      Gemm, ElementA, ElementB, ElementC
+    >(num_experts, a.device());
+  } else {
+    raw_args = MakeArgumentsOnHost<
+      kDynamicK,
+      Gemm,
+      ElementA, ElementB, ElementC,
+      LayoutA, LayoutB, LayoutC
+    >(a, b, c, batch_sizes, coord_template, num_experts);
+  }
+
+  printf("Using %d threadblocks for grouped GEMM.\n", raw_args.threadblock_count);
+  // Validate the result.
+  if (!raw_args.threadblock_count) {
+    TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
+  }
+
+  typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
+  // We currently always use `GroupScheduleMode::kDeviceOnly`, which doesn't use `host_problem_sizes` at all,
+  // so we can safely pass `nullptr` for `host_problem_sizes`.
+  // TODO(tgale): Experiment with `GroupScheduleMode::kHostPrecompute` for `batch_sizes.is_cpu()`, where we
+  // know the problem dimensions on the host.
+  typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(),
+                                     (int)num_experts,
+                                     (int)raw_args.threadblock_count,
+                                     epilogue_op,
+                                     (ElementA**)raw_args.ptr_a.data_ptr(),
+                                     (ElementB**)raw_args.ptr_b.data_ptr(),
+                                     (ElementC**)raw_args.ptr_c.data_ptr(),
+                                     (ElementC**)raw_args.ptr_c.data_ptr(),
+                                     /*lda=*/(int64_t*)raw_args.lda.data_ptr(),
+                                     /*ldb=*/(int64_t*)raw_args.ldb.data_ptr(),
+                                     /*ldc=*/(int64_t*)raw_args.ldc.data_ptr(),
+                                     /*ldd=*/(int64_t*)raw_args.ldc.data_ptr(),
+                                     /*host_problem_sizes=*/nullptr);
+  return arguments;
+}
+
+template <
+  bool trans_a,
+  typename ElementA, typename ElementB, typename ElementC,
+  typename LayoutA, typename LayoutB, typename LayoutC,
+  typename Arguments
+>
+void FillCutlassArguments(int num_experts,
+                          torch::Tensor batch_sizes,
+                          torch::Tensor a,
+                          torch::Tensor b,
+                          torch::Tensor c,
+                          const Arguments& arguments,
+                          ::cutlass::gemm::GemmCoord coord_template) {
+  // Convert the batch sizes to the format CUTLASS understands on the device.
+  // Use a single block here because:
+  // * the number of elements to process is microscopically small
+  // * we don't need any additional global memory
+  FillArguments<
+    /*kDynamicK*/trans_a,
+    ElementA, ElementB, ElementC,
+    LayoutA, LayoutB, LayoutC
+  ><<<1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()>>>(
+    num_experts, batch_sizes.data_ptr<int64_t>(),
+    (ElementA*)a.data_ptr(), (ElementB*)b.data_ptr(), (ElementC*)c.data_ptr(),
+    arguments, coord_template
+  );
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <typename Args>
+void RemoveK0Problems(int num_experts, const Args& arguments) {
+  // For zeroing out the outputs (which might be arbitrarily large), we want to use
+  // as many threadblocks as possible in order to hit the maximum possible global memory bandwidth.
+  // `arguments.threadblock_count`, which we will use for the grouped GEMM proper,
+  // should be a good approximation for this.
+  // When the `k=0` case is fixed in CUTLASS, we can completely remove this function.
+  ZeroOutK0Outputs<><<<
+    arguments.threadblock_count, at::cuda::detail::CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream()
+  >>>(
+    num_experts, arguments
+  );
+  IgnoreK0Problems<><<<
+    1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()
+  >>>(
+    num_experts, arguments
+  );
+}
+
+template <bool trans_a, bool trans_b>
+torch::Tensor CutlassGroupedGemm(torch::Tensor a,
+                                 torch::Tensor b,
+                                 torch::Tensor c,
+                                 torch::Tensor batch_sizes,
+                                 ::cutlass::gemm::GemmCoord coord_template) {
+  using Gemm = GemmGrouped<trans_a, trans_b>;
+  using LayoutA = typename Gemm::LayoutA;
+  using LayoutB = typename Gemm::LayoutB;
+  using LayoutC = typename Gemm::LayoutC;
+
+  using ElementA = typename Gemm::ElementA;
+  using ElementB = typename Gemm::ElementB;
+  using ElementC = typename Gemm::ElementC;
+
+  Gemm gemm;
+  int64_t num_experts = batch_sizes.size(0);
+  auto arguments = MakeArguments<
+    /*kDynamicK*/trans_a,
+    Gemm,
+    ElementA, ElementB, ElementC,
+    LayoutA, LayoutB, LayoutC
+  >(a, b, c, batch_sizes, coord_template, num_experts);
+  int64_t workspace_size = gemm.get_workspace_size(arguments);
+  auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
+  torch::Tensor workspace = torch::empty(workspace_size, options);
+
+  if (batch_sizes.is_cuda()) {
+    FillCutlassArguments<
+      trans_a,
+      ElementA, ElementB, ElementC,
+      LayoutA, LayoutB, LayoutC
+    >(num_experts, batch_sizes, a, b, c, arguments, coord_template);
+
+    RemoveK0Problems<>(num_experts, arguments);
+  }
+
+  // Initialize the kernel.
+  if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) {
+    TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
+  }
+
+  // Execute the kernel in the current stream.
+  if(gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) {
+    TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
+  }
+  return c;
+}
+
+void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
+                c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b,
+                c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) {
+  int m = trans_b ? b_rows : b_cols;
+  int k = trans_b ? b_cols : b_rows;
+  int n = trans_a ? a_cols : a_rows;
+
+  int lda = trans_a ? n : k;
+  int ldb = trans_b ? k : m;
+  cublasOperation_t transpose_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t transpose_b = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+  float alpha = 1.0, beta = 0.0;
+  CUBLAS_CALL(cublasGemmEx(at::cuda::getCurrentCUDABlasHandle(),
+                           transpose_b, transpose_a,
+                           m, n, k, &alpha,
+                           b, CUDA_R_16BF, ldb,
+                           a, CUDA_R_16BF, lda,
+                           &beta,
+                           c, CUDA_R_16BF, c_cols, CUDA_R_32F,
+                           CUBLAS_GEMM_DEFAULT));
+}
+
+void CublasGroupedGemm(torch::Tensor a,
+                       torch::Tensor b,
+                       torch::Tensor c,
+                       torch::Tensor batch_sizes,
+                       bool trans_b) {
+  int64_t bs = batch_sizes.size(0), k = a.size(1);
+  int64_t n = trans_b ? b.size(1) : b.size(2);
+  int64_t b_rows = b.size(1), b_cols = b.size(2);
+  c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
+  c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
+  c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
+  for (int i = 0; i < bs; ++i) {
+    int64_t m = batch_sizes.data_ptr<int64_t>()[i];
+    CublasGemm(a_ptr, m, k, /*trans_a=*/false,
+               b_ptr, b_rows, b_cols, trans_b,
+               c_ptr, m, n);
+    a_ptr += m * k;
+    b_ptr += b_rows * b_cols;
+    c_ptr += m * n;
+  }
+}
+
+void CublasGroupedGemmVariableK(torch::Tensor a,
+                                torch::Tensor b,
+                                torch::Tensor c,
+                                torch::Tensor batch_sizes) {
+  int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1);
+  c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
+  c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
+  c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
+  for (int i = 0; i < bs; ++i) {
+    int64_t k = batch_sizes.data_ptr<int64_t>()[i];
+    CublasGemm(a_ptr, k, m, /*trans_a=*/true,
+               b_ptr, k, n, /*trans_b=*/false,
+               c_ptr, m, n);
+    a_ptr += k * m;
+    b_ptr += k * n;
+    c_ptr += m * n;
+  }
+}
+
+void GroupedGemmVariableK(torch::Tensor a,
+                          torch::Tensor b,
+                          torch::Tensor c,
+                          torch::Tensor batch_sizes) {
+  // We expect a CUDA tensor with two dimensions and shape
+  // (tokens, hidden_out) for 'b'.
+  TORCH_CHECK(b.is_cuda());
+  TORCH_CHECK(b.ndimension() == 2);
+  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
+
+  // Validate the dimensions.
+  int64_t tokens = a.size(0), num_experts = batch_sizes.size(0);
+  int64_t m = a.size(1), n = b.size(1);
+
+  // Validate that we have the same contraction dimension.
+  TORCH_CHECK(tokens == b.size(0));
+
+  // Validate the output shape.
+  TORCH_CHECK(c.is_cuda());
+  TORCH_CHECK(c.ndimension() == 3);
+  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
+  TORCH_CHECK(c.size(0) == num_experts);
+  TORCH_CHECK(c.size(1) == m);
+  TORCH_CHECK(c.size(2) == n);
+
+  // Run the computation.
+  CublasGroupedGemmVariableK(a, b, c, batch_sizes);
+}
+
+// NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
+// assumed to be batched with fixed sized batches.
+//
+// TODO(tgale): Validate alignment is true for every batch element.
+void GroupedGemm(torch::Tensor a,
+                 torch::Tensor b,
+                 torch::Tensor c,
+                 torch::Tensor batch_sizes,
+                 bool trans_a, bool trans_b) {
+  // NOTE: We only support 'trans_a' or 'trans_b', not both.
+  TORCH_CHECK(!(trans_a && trans_b));
+
+#if !defined(GROUPED_GEMM_CUTLASS)
+  // No way to run cuBLAS kernels if the problem dimensions are not known on the host.
+  TORCH_CHECK(batch_sizes.is_cpu());
+#else
+  // CUTLASS can handle both CPU- and CUDA-resident problem dimensions.
+  TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu());
+#endif
+  TORCH_CHECK(batch_sizes.ndimension() == 1);
+  TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64);
+
+  // We expect a CUDA tensor with two dimensions and shape
+  // (tokens, hidden_in) for 'a'.
+  TORCH_CHECK(a.is_cuda());
+  TORCH_CHECK(a.ndimension() == 2);
+  TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
+
+#if !defined(GROUPED_GEMM_CUTLASS)
+  if (trans_a) {
+    // If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS
+    // for the rest of the op.
+    GroupedGemmVariableK(a, b, c, batch_sizes);
+    return;
+  }
+#endif
+
+  TORCH_CHECK(b.is_cuda());
+  TORCH_CHECK(c.is_cuda());
+  TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
+  TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
+
+  // The expected shapes of 'b' and 'c' are:
+  //   * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out)
+  //   * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
+  //   * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden_out)
+  size_t hidden_in{}, hidden_out{};
+  if (trans_a) {
+    hidden_in = a.size(1);
+    hidden_out = b.size(1);
+
+    TORCH_CHECK(b.ndimension() == 2);
+    TORCH_CHECK(c.ndimension() == 3);
+    TORCH_CHECK(b.size(0) == a.size(0));
+    TORCH_CHECK(c.size(0) == batch_sizes.size(0));
+    TORCH_CHECK(c.size(1) == hidden_in);
+    TORCH_CHECK(c.size(2) == hidden_out);
+  } else {
+    TORCH_CHECK(b.ndimension() == 3);
+    TORCH_CHECK(c.ndimension() == 2);
+
+    // Validate the contraction dimensions match.
+    int64_t tokens = a.size(0), num_experts = b.size(0);
+    hidden_in = trans_b ? b.size(2) : b.size(1);
+    hidden_out = trans_b ? b.size(1) : b.size(2);
+    TORCH_CHECK(hidden_in == a.size(1));
+
+    // Validate that we have one size per expert.
+    TORCH_CHECK(batch_sizes.size(0) == num_experts);
+  }
+
+  // NOTE: We support transposition through the 'trans_b' flag.
+  TORCH_CHECK(a.is_contiguous());
+  TORCH_CHECK(b.is_contiguous());
+  TORCH_CHECK(c.is_contiguous());
+
+#if !defined(GROUPED_GEMM_CUTLASS)
+  CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
+  return;
+#else
+  // The `coord_template` argument contains `kDynamicDim` as one of its dimensions
+  // as a placeholder. This placeholder is later expanded into the actual dimension
+  // for every element of the batch, either on the host or on the device
+  // (if we can't do it on the host).
+  const auto coord_template = trans_a
+    ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim)
+    : cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in);
+  if (trans_a) {
+    CutlassGroupedGemm<true, false>(a, b, c, batch_sizes, coord_template);
+    return;
+  }
+  if (trans_b) {
+    CutlassGroupedGemm<false, true>(a, b, c, batch_sizes, coord_template);
+    return;
+  }
+  CutlassGroupedGemm<false, false>(a, b, c, batch_sizes, coord_template);
+  return;
+#endif
+}
+
+} // namespace grouped_gemm
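A host-side usage sketch for `GroupedGemm` in the plain (no-transpose) case, matching the shape checks above; the dimensions here are made up and the call is illustrative only: `a=(tokens, hidden_in)`, `b=(num_experts, hidden_in, hidden_out)`, `c=(tokens, hidden_out)`, with `batch_sizes` a 1-D int64 tensor whose entries sum to `tokens`.

// Hypothetical call into grouped_gemm::GroupedGemm; shapes follow the checks above.
#include <torch/torch.h>
#include "grouped_gemm.h"

void grouped_gemm_example() {
  const int64_t num_experts = 2, hidden_in = 64, hidden_out = 128;
  // Per-expert token counts; tokens = 3 + 5 = 8. CPU-resident sizes take the host path.
  auto batch_sizes = torch::tensor({3, 5}, torch::kInt64);
  const int64_t tokens = batch_sizes.sum().item<int64_t>();

  auto opts = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
  auto a = torch::randn({tokens, hidden_in}, opts);
  auto b = torch::randn({num_experts, hidden_in, hidden_out}, opts);
  auto c = torch::empty({tokens, hidden_out}, opts);

  grouped_gemm::GroupedGemm(a, b, c, batch_sizes,
                            /*trans_a=*/false, /*trans_b=*/false);
}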
csrc/grouped_gemm/grouped_gemm.h
ADDED
@@ -0,0 +1,20 @@
+#pragma once
+
+// // Set default if not already defined
+// #ifndef GROUPED_GEMM_CUTLASS
+// #define GROUPED_GEMM_CUTLASS 0
+// #endif
+
+// #include <torch/extension.h>
+#include <torch/torch.h>
+
+namespace grouped_gemm {
+
+void GroupedGemm(torch::Tensor a,
+                 torch::Tensor b,
+                 torch::Tensor c,
+                 torch::Tensor batch_sizes,
+                 bool trans_a, bool trans_b);
+
+} // namespace grouped_gemm
+
csrc/grouped_gemm/ops.cu
ADDED
@@ -0,0 +1,11 @@
+#include "grouped_gemm.h"
+
+#include <torch/extension.h>
+
+namespace grouped_gemm {
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("gmm", &GroupedGemm, "Grouped GEMM.");
+}
+
+} // namespace grouped_gemm
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#undef CUB_WRAPPED_NAMESPACE
|
2 |
+
#define CUB_WRAPPED_NAMESPACE megablocks
|
3 |
+
|
4 |
+
#include <cstdint>
|
5 |
+
|
6 |
+
#include <cub/cub.cuh>
|
7 |
+
#include <c10/cuda/CUDAStream.h>
|
8 |
+
// #include <torch/extension.h>
|
9 |
+
|
10 |
+
#define CUDA_CALL(code) \
|
11 |
+
do { \
|
12 |
+
cudaError_t status = code; \
|
13 |
+
std::string err = cudaGetErrorString(status); \
|
14 |
+
TORCH_CHECK(status == cudaSuccess, err); \
|
15 |
+
} while (0)
|
16 |
+
|
17 |
+
namespace megablocks {
|
18 |
+
|
19 |
+
template <typename T>
|
20 |
+
torch::Tensor cub_histogram(torch::Tensor x, int num_bins) {
|
21 |
+
// Allocate the count buffer.
|
22 |
+
auto options = torch::TensorOptions()
|
23 |
+
.dtype(torch::kInt32)
|
24 |
+
.device(x.device());
|
25 |
+
torch::Tensor out = torch::empty({x.size(0), num_bins}, options);
|
26 |
+
|
27 |
+
// Exit early if there is not work to do.
|
28 |
+
if (out.numel() == 0) return out;
|
29 |
+
|
30 |
+
// Get scratchpad size.
|
31 |
+
size_t scratchpad_bytes = 0;
|
32 |
+
CUDA_CALL(cub::DeviceHistogram::HistogramEven(nullptr,
|
33 |
+
scratchpad_bytes,
|
34 |
+
x.data_ptr<T>(),
|
35 |
+
out.data_ptr<int>(),
|
36 |
+
/*num_levels=*/num_bins + 1,
|
37 |
+
/*lower_level=*/0,
|
38 |
+
/*upper_level=*/num_bins,
|
39 |
+
/*num_samples=*/int(x.size(1)),
|
40 |
+
c10::cuda::getCurrentCUDAStream()));
|
41 |
+
|
42 |
+
// Allocate scratchpad.
|
43 |
+
options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
|
44 |
+
torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
|
45 |
+
|
46 |
+
// Run the kernel.
|
47 |
+
for (int i = 0; i < x.size(0); ++i) {
|
48 |
+
CUDA_CALL(cub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(),
|
49 |
+
scratchpad_bytes,
|
50 |
+
x.data_ptr<T>() + x.size(1) * i,
|
51 |
+
out.data_ptr<int>() + out.size(1) * i,
|
52 |
+
/*num_levels=*/num_bins + 1,
|
53 |
+
/*lower_level=*/0,
|
54 |
+
/*upper_level=*/num_bins,
|
55 |
+
/*num_samples=*/int(x.size(1)),
|
56 |
+
c10::cuda::getCurrentCUDAStream()));
|
57 |
+
}
|
58 |
+
return out;
|
59 |
+
}
|
60 |
+
|
61 |
+
torch::Tensor histogram(torch::Tensor x, int num_bins) {
|
62 |
+
TORCH_CHECK(x.is_cuda());
|
63 |
+
TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2);
|
64 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
|
65 |
+
x.scalar_type() == torch::kInt32 ||
|
66 |
+
x.scalar_type() == torch::kInt64);
|
67 |
+
bool no_batch = x.ndimension() == 1;
|
68 |
+
if (no_batch) x = x.view({1, x.numel()});
|
69 |
+
|
70 |
+
if (x.scalar_type() == torch::kInt16) {
|
71 |
+
auto out = cub_histogram<short>(x, num_bins);
|
72 |
+
return no_batch ? out.flatten() : out;
|
73 |
+
} else if (x.scalar_type() == torch::kInt32) {
|
74 |
+
auto out = cub_histogram<int>(x, num_bins);
|
75 |
+
return no_batch ? out.flatten() : out;
|
76 |
+
} else {
|
77 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt64);
|
78 |
+
auto out = cub_histogram<long>(x, num_bins);
|
79 |
+
return no_batch ? out.flatten() : out;
|
80 |
+
}
|
81 |
+
}
|
82 |
+
|
83 |
+
} // namespace megablocks
|
84 |
+
|
85 |
+
#undef CUDA_CALL
|
86 |
+
#undef CUB_WRAPPED_NAMESPACE
|
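A host-side usage sketch for `histogram` (illustrative only): for a 1-D integer tensor with values in `[0, num_bins)` the result is a 1-D int32 tensor of per-bin counts, and a 2-D input produces one histogram row per batch row.

// Hypothetical usage of megablocks::histogram.
#include <torch/torch.h>

void histogram_example() {
  // Values in [0, 4) -> per-bin counts: [1, 2, 1, 2].
  auto x = torch::tensor({0, 1, 1, 2, 3, 3},
                         torch::dtype(torch::kInt32).device(torch::kCUDA));
  auto counts = megablocks::histogram(x, /*num_bins=*/4);  // shape: [4], dtype int32
}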
csrc/indices.h
ADDED
@@ -0,0 +1,95 @@
+#include <cstdint>
+#include <c10/util/Half.h>
+// #include <torch/extension.h>
+#include <c10/cuda/CUDAStream.h>
+
+#define CUDA_CALL(code)                                     \
+  do {                                                      \
+    cudaError_t status = code;                              \
+    std::string err = cudaGetErrorString(status);           \
+    TORCH_CHECK(status == cudaSuccess, err);                \
+  } while (0)
+
+namespace megablocks {
+namespace construct_indices {
+
+// We expect the number of outputs per block to be small. For
+// example, with ffn_hidden_size=4096, we only need to write
+// 32 elements per block per iteration.
+const int kThreadsPerBlock = 32;
+
+__global__ void __launch_bounds__(kThreadsPerBlock)
+  ConstructIndicesKernel(short * __restrict__ indices,
+                         int num_columns,
+                         int block_size,
+                         const int * __restrict__ padded_bins) {
+  // Load the offset for this bin's indices.
+  int start = 0;
+  if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1);
+  int end = __ldg(padded_bins + blockIdx.x);
+
+  // Divide the start and end into blocks.
+  start /= block_size;
+  end /= block_size;
+
+  // Offset the output buffer to the start of the bin.
+  indices += (start + blockIdx.y) * num_columns + threadIdx.x;
+
+  // Write the indices to the output.
+  int bin_offset = blockIdx.y;
+  int num_rows = end - start;
+  for (; bin_offset < num_rows; num_rows -= gridDim.y) {
+    short *out = indices;
+    for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
+      *out = bid + (blockIdx.x * num_columns);
+      out += kThreadsPerBlock;
+    }
+    indices += gridDim.y * num_columns;
+  }
+}
+
+cudaError_t ConstructIndices(short * __restrict__ indices,
+                             int output_block_rows,
+                             int output_block_columns,
+                             int block_size,
+                             const int * __restrict__ padded_bins,
+                             int num_bins,
+                             cudaStream_t stream) {
+  dim3 block_dim(kThreadsPerBlock);
+  dim3 grid_dim(num_bins, (int)std::ceil((float)output_block_rows / num_bins));
+  ConstructIndicesKernel<<<grid_dim, block_dim, 0, stream>>>(indices,
+                                                             output_block_columns,
+                                                             block_size,
+                                                             padded_bins);
+  return cudaGetLastError();
+}
+
+} // namespace construct_indices
+
+void indices(torch::Tensor padded_bins,
+             int block_size,
+             int output_block_rows,
+             int output_block_columns,
+             torch::Tensor out) {
+  TORCH_CHECK(padded_bins.is_cuda());
+  TORCH_CHECK(padded_bins.ndimension() == 1);
+  TORCH_CHECK(padded_bins.scalar_type() == torch::kInt);
+
+  TORCH_CHECK(out.is_cuda());
+  TORCH_CHECK(out.ndimension() == 1);
+  TORCH_CHECK(out.scalar_type() == torch::kInt16);
+  TORCH_CHECK(out.numel() == (output_block_rows * output_block_columns));
+
+  // Exit early if there is no work to do.
+  if (out.numel() == 0) return;
+
+  CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr<short>(),
+                                                output_block_rows,
+                                                output_block_columns,
+                                                block_size,
+                                                padded_bins.data_ptr<int>(),
+                                                padded_bins.numel(),
+                                                c10::cuda::getCurrentCUDAStream()));
+}
+
+} // namespace megablocks
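A host-side sketch of how `indices` is driven (illustrative only; the sizes are made up): `padded_bins` is a 1-D int32 tensor of cumulative, block-size-padded bin boundaries, and `out` is a preallocated 1-D int16 buffer of `output_block_rows * output_block_columns` elements.

// Hypothetical call into megablocks::indices.
#include <torch/torch.h>

void indices_example() {
  const int block_size = 128;
  const int output_block_rows = 4, output_block_columns = 32;

  // Cumulative padded bin boundaries for two bins (multiples of block_size).
  auto padded_bins = torch::tensor({256, 512},
                                   torch::dtype(torch::kInt32).device(torch::kCUDA));
  auto out = torch::empty({output_block_rows * output_block_columns},
                          torch::dtype(torch::kInt16).device(torch::kCUDA));

  megablocks::indices(padded_bins, block_size,
                      output_block_rows, output_block_columns, out);
}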
csrc/new_cumsum.cu
ADDED
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#define CUB_IGNORE_DEPRECATED_API
|
2 |
+
|
3 |
+
#undef CUB_WRAPPED_NAMESPACE
|
4 |
+
#define CUB_WRAPPED_NAMESPACE megablocks
|
5 |
+
|
6 |
+
#include "new_cumsum.h"
|
7 |
+
#include <cstdint>
|
8 |
+
#include <hipcub/hipcub.hpp>
|
9 |
+
#include <c10/cuda/CUDAStream.h>
|
10 |
+
|
11 |
+
#define CUDA_CALL(code) \
|
12 |
+
do { \
|
13 |
+
cudaError_t status = code; \
|
14 |
+
std::string err = cudaGetErrorString(status); \
|
15 |
+
TORCH_CHECK(status == cudaSuccess, err); \
|
16 |
+
} while (0)
|
17 |
+
|
18 |
+
namespace megablocks {
|
19 |
+
|
20 |
+
struct Inclusive {};
|
21 |
+
struct Exclusive {};
|
22 |
+
|
23 |
+
template <typename Type> struct Cumsum {
|
24 |
+
|
25 |
+
template<
|
26 |
+
typename InputIteratorT,
|
27 |
+
typename OutputIteratorT>
|
28 |
+
static void Run(void * d_temp_storage,
|
29 |
+
size_t & temp_storage_bytes,
|
30 |
+
InputIteratorT d_in,
|
31 |
+
OutputIteratorT d_out,
|
32 |
+
int num_items,
|
33 |
+
cudaStream_t stream = 0,
|
34 |
+
bool debug_synchronous = false) {
|
35 |
+
CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(d_temp_storage,
|
36 |
+
temp_storage_bytes,
|
37 |
+
d_in,
|
38 |
+
d_out,
|
39 |
+
num_items,
|
40 |
+
stream));//,
|
41 |
+
//debug_synchronous));
|
42 |
+
}
|
43 |
+
};
|
44 |
+
|
45 |
+
template <> struct Cumsum<Inclusive> {
|
46 |
+
template<
|
47 |
+
typename InputIteratorT,
|
48 |
+
typename OutputIteratorT>
|
49 |
+
static void Run(void * d_temp_storage,
|
50 |
+
size_t & temp_storage_bytes,
|
51 |
+
InputIteratorT d_in,
|
52 |
+
OutputIteratorT d_out,
|
53 |
+
int num_items,
|
54 |
+
cudaStream_t stream = 0,
|
55 |
+
bool debug_synchronous = false) {
|
56 |
+
CUDA_CALL(hipcub::DeviceScan::InclusiveSum(d_temp_storage,
|
57 |
+
temp_storage_bytes,
|
58 |
+
d_in,
|
59 |
+
d_out,
|
60 |
+
num_items,
|
61 |
+
stream));//,
|
62 |
+
//debug_synchronous));
|
63 |
+
}
|
64 |
+
};
|
65 |
+
|
66 |
+
template <typename SumType, typename T>
|
67 |
+
void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
|
68 |
+
// Get temporary storage size.
|
69 |
+
size_t scratchpad_bytes = 0;
|
70 |
+
Cumsum<SumType>::Run(nullptr,
|
71 |
+
scratchpad_bytes,
|
72 |
+
x.data_ptr<T>(),
|
73 |
+
out.data_ptr<T>(),
|
74 |
+
x.size(1),
|
75 |
+
c10::cuda::getCurrentCUDAStream());
|
76 |
+
|
77 |
+
// Allocate scratchpad.
|
78 |
+
//
|
79 |
+
// NOTE: We scale the scratchpad by the batch dimension so each batch item gets its own temporary storage.
|
80 |
+
auto options = torch::TensorOptions()
|
81 |
+
.dtype(torch::kInt8)
|
82 |
+
.device(x.device());
|
83 |
+
torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
|
84 |
+
options);
|
85 |
+
|
86 |
+
// Run the kernel.
|
87 |
+
//
|
88 |
+
// NOTE: Using different streams for each issue does not appear to
|
89 |
+
// yield performance gains for our problem set. The overhead of
|
90 |
+
// event/stream synchronization appears to outweigh the benefits.
|
91 |
+
// We could write a true batched cumsum, but this would require
|
92 |
+
// significant code duplication from cub and we might move away
|
93 |
+
// from this formulation anyways.
|
94 |
+
for (int i = 0; i < x.size(0); ++i) {
|
95 |
+
void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
|
96 |
+
Cumsum<SumType>::Run(scratchpad_ptr,
|
97 |
+
scratchpad_bytes,
|
98 |
+
x.data_ptr<T>() + x.size(1) * i,
|
99 |
+
out.data_ptr<T>() + x.size(1) * i,
|
100 |
+
x.size(1),
|
101 |
+
c10::cuda::getCurrentCUDAStream());
|
102 |
+
}
|
103 |
+
}
|
104 |
+
|
105 |
+
void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
|
106 |
+
// Validate the input matrix.
|
107 |
+
TORCH_CHECK(x.is_cuda());
|
108 |
+
TORCH_CHECK(x.ndimension() == 2);
|
109 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
|
110 |
+
x.scalar_type() == torch::kInt32 ||
|
111 |
+
x.scalar_type() == torch::kInt64);
|
112 |
+
TORCH_CHECK(out.is_cuda());
|
113 |
+
TORCH_CHECK(out.ndimension() == 2);
|
114 |
+
TORCH_CHECK(out.scalar_type() == x.scalar_type());
|
115 |
+
|
116 |
+
// NOTE: We currently only support contraction across the contiguous
|
117 |
+
// dimension in the matrix.
|
118 |
+
TORCH_CHECK(dim == 1);
|
119 |
+
|
120 |
+
switch (x.scalar_type()) {
|
121 |
+
case torch::kInt16:
|
122 |
+
cub_cumsum<Exclusive, short>(x, dim, out);
|
123 |
+
return;
|
124 |
+
case torch::kInt32:
|
125 |
+
cub_cumsum<Exclusive, int>(x, dim, out);
|
126 |
+
return;
|
127 |
+
}
|
128 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt64);
|
129 |
+
cub_cumsum<Exclusive, long>(x, dim, out);
|
130 |
+
}
|
131 |
+
|
132 |
+
void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
|
133 |
+
// Validate the input matrix.
|
134 |
+
TORCH_CHECK(x.is_cuda());
|
135 |
+
TORCH_CHECK(x.ndimension() == 2);
|
136 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
|
137 |
+
x.scalar_type() == torch::kInt32 ||
|
138 |
+
x.scalar_type() == torch::kInt64);
|
139 |
+
TORCH_CHECK(out.is_cuda());
|
140 |
+
TORCH_CHECK(out.ndimension() == 2);
|
141 |
+
TORCH_CHECK(out.scalar_type() == x.scalar_type());
|
142 |
+
|
143 |
+
// NOTE: We currently only support contraction across the contiguous
|
144 |
+
// dimension in the matrix.
|
145 |
+
TORCH_CHECK(dim == 1);
|
146 |
+
|
147 |
+
switch (x.scalar_type()) {
|
148 |
+
case torch::kInt16:
|
149 |
+
cub_cumsum<Inclusive, short>(x, dim, out);
|
150 |
+
return;
|
151 |
+
case torch::kInt32:
|
152 |
+
cub_cumsum<Inclusive, int>(x, dim, out);
|
153 |
+
return;
|
154 |
+
}
|
155 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt64);
|
156 |
+
cub_cumsum<Inclusive, long>(x, dim, out);
|
157 |
+
}
|
158 |
+
|
159 |
+
} // namespace megablocks
|
160 |
+
|
161 |
+
#undef CUB_WRAPPED_NAMESPACE
|
csrc/new_cumsum.h
ADDED
@@ -0,0 +1,11 @@
#pragma once

#include <torch/all.h>

namespace megablocks {

// Forward declarations for the public interface functions
void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);
void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);

} // namespace megablocks
csrc/new_histogram.cu
ADDED
@@ -0,0 +1,85 @@
1 |
+
#undef CUB_WRAPPED_NAMESPACE
|
2 |
+
#define CUB_WRAPPED_NAMESPACE megablocks
|
3 |
+
|
4 |
+
#include "new_histogram.h"
|
5 |
+
#include <cstdint>
|
6 |
+
#include <hipcub/hipcub.hpp>
|
7 |
+
#include <c10/cuda/CUDAStream.h>
|
8 |
+
|
9 |
+
#define CUDA_CALL(code) \
|
10 |
+
do { \
|
11 |
+
cudaError_t status = code; \
|
12 |
+
std::string err = cudaGetErrorString(status); \
|
13 |
+
TORCH_CHECK(status == cudaSuccess, err); \
|
14 |
+
} while (0)
|
15 |
+
|
16 |
+
namespace megablocks {
|
17 |
+
|
18 |
+
template <typename T>
|
19 |
+
torch::Tensor cub_histogram(torch::Tensor x, int num_bins) {
|
20 |
+
// Allocate the count buffer.
|
21 |
+
auto options = torch::TensorOptions()
|
22 |
+
.dtype(torch::kInt32)
|
23 |
+
.device(x.device());
|
24 |
+
torch::Tensor out = torch::empty({x.size(0), num_bins}, options);
|
25 |
+
|
26 |
+
// Exit early if there is no work to do.
|
27 |
+
if (out.numel() == 0) return out;
|
28 |
+
|
29 |
+
// Get scratchpad size.
|
30 |
+
size_t scratchpad_bytes = 0;
|
31 |
+
CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(nullptr,
|
32 |
+
scratchpad_bytes,
|
33 |
+
x.data_ptr<T>(),
|
34 |
+
out.data_ptr<int>(),
|
35 |
+
/*num_levels=*/num_bins + 1,
|
36 |
+
/*lower_level=*/0,
|
37 |
+
/*upper_level=*/num_bins,
|
38 |
+
/*num_samples=*/int(x.size(1)),
|
39 |
+
c10::cuda::getCurrentCUDAStream()));
|
40 |
+
|
41 |
+
// Allocate scratchpad.
|
42 |
+
options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
|
43 |
+
torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
|
44 |
+
|
45 |
+
// Run the kernel.
|
46 |
+
for (int i = 0; i < x.size(0); ++i) {
|
47 |
+
CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(),
|
48 |
+
scratchpad_bytes,
|
49 |
+
x.data_ptr<T>() + x.size(1) * i,
|
50 |
+
out.data_ptr<int>() + out.size(1) * i,
|
51 |
+
/*num_levels=*/num_bins + 1,
|
52 |
+
/*lower_level=*/0,
|
53 |
+
/*upper_level=*/num_bins,
|
54 |
+
/*num_samples=*/int(x.size(1)),
|
55 |
+
c10::cuda::getCurrentCUDAStream()));
|
56 |
+
}
|
57 |
+
return out;
|
58 |
+
}
|
59 |
+
|
60 |
+
torch::Tensor histogram(torch::Tensor x, int num_bins) {
|
61 |
+
TORCH_CHECK(x.is_cuda());
|
62 |
+
TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2);
|
63 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
|
64 |
+
x.scalar_type() == torch::kInt32 ||
|
65 |
+
x.scalar_type() == torch::kInt64);
|
66 |
+
bool no_batch = x.ndimension() == 1;
|
67 |
+
if (no_batch) x = x.view({1, x.numel()});
|
68 |
+
|
69 |
+
if (x.scalar_type() == torch::kInt16) {
|
70 |
+
auto out = cub_histogram<short>(x, num_bins);
|
71 |
+
return no_batch ? out.flatten() : out;
|
72 |
+
} else if (x.scalar_type() == torch::kInt32) {
|
73 |
+
auto out = cub_histogram<int>(x, num_bins);
|
74 |
+
return no_batch ? out.flatten() : out;
|
75 |
+
} else {
|
76 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt64);
|
77 |
+
auto out = cub_histogram<long>(x, num_bins);
|
78 |
+
return no_batch ? out.flatten() : out;
|
79 |
+
}
|
80 |
+
}
|
81 |
+
|
82 |
+
} // namespace megablocks
|
83 |
+
|
84 |
+
#undef CUDA_CALL
|
85 |
+
#undef CUB_WRAPPED_NAMESPACE
|
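A hedged usage sketch for histogram() above. The `megablocks.ops` import path is an assumption; the even-width binning over [0, num_bins) and the int32 output follow from the HistogramEven call.

import torch
from megablocks import ops  # assumed binding path

num_bins = 8
top_experts = torch.randint(0, num_bins, (4, 1024), dtype=torch.int32, device='cuda')
counts = ops.histogram(top_experts, num_bins)      # int32 counts, shape [4, num_bins]
counts_1d = ops.histogram(top_experts[0], num_bins)  # 1-D input returns a flattened [num_bins] histogram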
csrc/new_histogram.h
ADDED
@@ -0,0 +1,10 @@
#pragma once

#include <torch/all.h>

namespace megablocks {

// Public interface function for computing histograms
torch::Tensor histogram(torch::Tensor x, int num_bins);

} // namespace megablocks
csrc/new_indices.cu
ADDED
@@ -0,0 +1,97 @@
1 |
+
#include "new_indices.h"
|
2 |
+
#include <cstdint>
|
3 |
+
#include <c10/util/Half.h>
|
4 |
+
#include <c10/cuda/CUDAStream.h>
|
5 |
+
|
6 |
+
#define CUDA_CALL(code) \
|
7 |
+
do { \
|
8 |
+
cudaError_t status = code; \
|
9 |
+
std::string err = cudaGetErrorString(status); \
|
10 |
+
TORCH_CHECK(status == cudaSuccess, err); \
|
11 |
+
} while (0)
|
12 |
+
|
13 |
+
namespace megablocks {
|
14 |
+
namespace construct_indices {
|
15 |
+
|
16 |
+
// We expect the number of outputs per block to be small. For
|
17 |
+
// example, with ffn_hidden_size=4096, we only need to write
|
18 |
+
// 32 elements per block per iteration.
|
19 |
+
const int kThreadsPerBlock = 32;
|
20 |
+
|
21 |
+
__global__ void __launch_bounds__(kThreadsPerBlock)
|
22 |
+
ConstructIndicesKernel(short * __restrict__ indices,
|
23 |
+
int num_columns,
|
24 |
+
int block_size,
|
25 |
+
const int * __restrict__ padded_bins) {
|
26 |
+
// Load the offset for this bin's indices.
|
27 |
+
int start = 0;
|
28 |
+
if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1);
|
29 |
+
int end = __ldg(padded_bins + blockIdx.x);
|
30 |
+
|
31 |
+
// Divide the start and end into blocks.
|
32 |
+
start /= block_size;
|
33 |
+
end /= block_size;
|
34 |
+
|
35 |
+
// Offset the output buffer to the start of the bin.
|
36 |
+
indices += (start + blockIdx.y) * num_columns + threadIdx.x;
|
37 |
+
|
38 |
+
// Write the indices to the output.
|
39 |
+
int bin_offset = blockIdx.y;
|
40 |
+
int num_rows = end - start;
|
41 |
+
for (; bin_offset < num_rows; num_rows -= gridDim.y) {
|
42 |
+
short *out = indices;
|
43 |
+
for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
|
44 |
+
*out = bid + (blockIdx.x * num_columns);
|
45 |
+
out += kThreadsPerBlock;
|
46 |
+
}
|
47 |
+
indices += gridDim.y * num_columns;
|
48 |
+
}
|
49 |
+
}
|
50 |
+
|
51 |
+
cudaError_t ConstructIndices(short * __restrict__ indices,
|
52 |
+
int output_block_rows,
|
53 |
+
int output_block_columns,
|
54 |
+
int block_size,
|
55 |
+
const int * __restrict__ padded_bins,
|
56 |
+
int num_bins,
|
57 |
+
cudaStream_t stream) {
|
58 |
+
dim3 block_dim(kThreadsPerBlock);
|
59 |
+
dim3 grid_dim(num_bins, (int)std::ceil((float)output_block_rows / num_bins));
|
60 |
+
ConstructIndicesKernel<<<grid_dim, block_dim, 0, stream>>>(indices,
|
61 |
+
output_block_columns,
|
62 |
+
block_size,
|
63 |
+
padded_bins);
|
64 |
+
return cudaGetLastError();
|
65 |
+
}
|
66 |
+
|
67 |
+
} // namespace construct_indices
|
68 |
+
|
69 |
+
void indices(torch::Tensor padded_bins,
|
70 |
+
int block_size,
|
71 |
+
int output_block_rows,
|
72 |
+
int output_block_columns,
|
73 |
+
torch::Tensor out) {
|
74 |
+
TORCH_CHECK(padded_bins.is_cuda());
|
75 |
+
TORCH_CHECK(padded_bins.ndimension() == 1);
|
76 |
+
TORCH_CHECK(padded_bins.scalar_type() == torch::kInt);
|
77 |
+
|
78 |
+
TORCH_CHECK(out.is_cuda());
|
79 |
+
TORCH_CHECK(out.ndimension() == 1);
|
80 |
+
TORCH_CHECK(out.scalar_type() == torch::kInt16);
|
81 |
+
TORCH_CHECK(out.numel() == (output_block_rows * output_block_columns));
|
82 |
+
|
83 |
+
// Exit early if there is no work to do.
|
84 |
+
if (out.numel() == 0) return;
|
85 |
+
|
86 |
+
CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr<short>(),
|
87 |
+
output_block_rows,
|
88 |
+
output_block_columns,
|
89 |
+
block_size,
|
90 |
+
padded_bins.data_ptr<int>(),
|
91 |
+
padded_bins.numel(),
|
92 |
+
c10::cuda::getCurrentCUDAStream()));
|
93 |
+
}
|
94 |
+
|
95 |
+
} // namespace megablocks
|
96 |
+
|
97 |
+
#undef CUDA_CALL
|
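To make the kernel's output layout concrete, here is a small pure-PyTorch reference of what ConstructIndices writes: every block row that falls in bin b receives the column indices b*num_columns .. (b+1)*num_columns - 1. This is an illustration derived from the kernel above, not code shipped by the library.

import torch

def construct_indices_reference(padded_bins, block_size, num_block_rows, num_block_columns):
    # padded_bins: 1-D int tensor of cumulative, block-size-padded bin ends.
    out = torch.empty(num_block_rows * num_block_columns, dtype=torch.int16)
    start = 0
    for bin_idx, end in enumerate(padded_bins.tolist()):
        cols = torch.arange(num_block_columns, dtype=torch.int16) + bin_idx * num_block_columns
        for row in range(start // block_size, end // block_size):
            out[row * num_block_columns:(row + 1) * num_block_columns] = cols
        start = end
    return out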
csrc/new_indices.h
ADDED
@@ -0,0 +1,14 @@
#pragma once

#include <torch/all.h>

namespace megablocks {

// Public interface function for constructing indices from padded bins
void indices(torch::Tensor padded_bins,
             int block_size,
             int output_block_rows,
             int output_block_columns,
             torch::Tensor out);

} // namespace megablocks
csrc/new_replicate.cu
ADDED
@@ -0,0 +1,220 @@
1 |
+
// Modifications: Copyright Advanced Micro Devices, Inc. SPDX License: MIT.
|
2 |
+
|
3 |
+
#undef CUB_WRAPPED_NAMESPACE
|
4 |
+
#define CUB_WRAPPED_NAMESPACE megablocks
|
5 |
+
|
6 |
+
#include "new_replicate.h"
|
7 |
+
|
8 |
+
#include <cstdint>
|
9 |
+
|
10 |
+
#include <cub/cub.cuh>
|
11 |
+
#include <c10/util/Half.h>
|
12 |
+
#include <c10/cuda/CUDAStream.h>
|
13 |
+
|
14 |
+
#ifndef USE_ROCM
|
15 |
+
#define _LDG(arg) __ldg(arg)
|
16 |
+
#else
|
17 |
+
#define _LDG(arg) *(arg)
|
18 |
+
#endif
|
19 |
+
|
20 |
+
#define CUDA_CALL(code) \
|
21 |
+
do { \
|
22 |
+
cudaError_t status = code; \
|
23 |
+
std::string err = cudaGetErrorString(status); \
|
24 |
+
TORCH_CHECK(status == cudaSuccess, err); \
|
25 |
+
} while (0)
|
26 |
+
|
27 |
+
namespace megablocks {
|
28 |
+
namespace replicate {
|
29 |
+
|
30 |
+
template <typename T, int kThreadsPerBlock>
|
31 |
+
__global__ void __launch_bounds__(kThreadsPerBlock)
|
32 |
+
ReplicateForwardKernel(T * __restrict__ x,
|
33 |
+
int * __restrict__ bins,
|
34 |
+
T * __restrict__ out,
|
35 |
+
int columns) {
|
36 |
+
// Offset to this thread block's batch.
|
37 |
+
//
|
38 |
+
// x is [batch_size, num_bins]
|
39 |
+
// out is [batch_size, columns]
|
40 |
+
// bins is [num_bins]
|
41 |
+
int batch_idx = blockIdx.y;
|
42 |
+
int num_bins = gridDim.x;
|
43 |
+
x += batch_idx * num_bins;
|
44 |
+
out += batch_idx * columns;
|
45 |
+
|
46 |
+
// Load the start/end for this bin.
|
47 |
+
int bin_idx = blockIdx.x;
|
48 |
+
int start = 0;
|
49 |
+
if (bin_idx > 0) start = _LDG(bins + bin_idx - 1);
|
50 |
+
int end = _LDG(bins + bin_idx);
|
51 |
+
|
52 |
+
// Load the value to replicate.
|
53 |
+
T value = _LDG((T*)x + bin_idx);
|
54 |
+
|
55 |
+
// Offset to this thread block's bin and this thread's
|
56 |
+
// offset within the bin.
|
57 |
+
int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x;
|
58 |
+
out += start + bin_offset;
|
59 |
+
|
60 |
+
// Replicate the value to the output.
|
61 |
+
//
|
62 |
+
// TODO(tgale): Vectorize these stores.
|
63 |
+
int num_elements = end - start;
|
64 |
+
const int kElementsPerLoop = gridDim.z * kThreadsPerBlock;
|
65 |
+
T *out_ptr = (T*)out;
|
66 |
+
for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) {
|
67 |
+
*out_ptr = value;
|
68 |
+
out_ptr += kElementsPerLoop;
|
69 |
+
}
|
70 |
+
}
|
71 |
+
|
72 |
+
template <typename T>
|
73 |
+
cudaError_t ReplicateForward(T *x,
|
74 |
+
int batch_size,
|
75 |
+
int num_bins,
|
76 |
+
int *bins,
|
77 |
+
T *out,
|
78 |
+
int columns,
|
79 |
+
cudaStream_t stream) {
|
80 |
+
const int kThreadsPerBlock = 64;
|
81 |
+
dim3 block_dim(kThreadsPerBlock, 1, 1);
|
82 |
+
int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock));
|
83 |
+
dim3 grid_dim(num_bins, batch_size, group_size);
|
84 |
+
ReplicateForwardKernel<T, kThreadsPerBlock><<<
|
85 |
+
grid_dim, block_dim, 0, stream>>>(x, bins, out, columns);
|
86 |
+
return cudaGetLastError();
|
87 |
+
}
|
88 |
+
|
89 |
+
void cub_segmented_reduce(torch::Tensor grad,
|
90 |
+
torch::Tensor bins,
|
91 |
+
torch::Tensor out,
|
92 |
+
cudaStream_t stream) {
|
93 |
+
// Append a zero to the bin boundaries for CUB.
|
94 |
+
torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options());
|
95 |
+
CUDA_CALL(cudaMemsetAsync(offsets.data_ptr<int>(),
|
96 |
+
0,
|
97 |
+
offsets.numel() * sizeof(int),
|
98 |
+
stream));
|
99 |
+
CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr<int>() + 1,
|
100 |
+
bins.data_ptr<int>(),
|
101 |
+
bins.numel() * sizeof(int),
|
102 |
+
cudaMemcpyDeviceToDevice,
|
103 |
+
stream));
|
104 |
+
|
105 |
+
// Get temporary buffer size.
|
106 |
+
size_t scratchpad_bytes = 0;
|
107 |
+
CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr,
|
108 |
+
scratchpad_bytes,
|
109 |
+
grad.data_ptr<c10::Half>(),
|
110 |
+
out.data_ptr<c10::Half>(),
|
111 |
+
bins.numel(),
|
112 |
+
offsets.data_ptr<int>(),
|
113 |
+
offsets.data_ptr<int>() + 1,
|
114 |
+
stream));
|
115 |
+
|
116 |
+
// Allocate scratchpad.
|
117 |
+
auto options = torch::TensorOptions()
|
118 |
+
.dtype(torch::kInt8)
|
119 |
+
.device(grad.device());
|
120 |
+
torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
|
121 |
+
|
122 |
+
// Run the kernel for each batch item.
|
123 |
+
for (int i = 0; i < grad.size(0); ++i) {
|
124 |
+
int num_bins = out.size(1);
|
125 |
+
int num_values = grad.size(1);
|
126 |
+
CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr<int8_t>(),
|
127 |
+
scratchpad_bytes,
|
128 |
+
grad.data_ptr<c10::Half>() + i * num_values,
|
129 |
+
out.data_ptr<c10::Half>() + i * num_bins,
|
130 |
+
bins.numel(),
|
131 |
+
offsets.data_ptr<int>(),
|
132 |
+
offsets.data_ptr<int>() + 1,
|
133 |
+
stream));
|
134 |
+
}
|
135 |
+
}
|
136 |
+
|
137 |
+
} // namespace replicate
|
138 |
+
|
139 |
+
void replicate_forward(torch::Tensor x,
|
140 |
+
torch::Tensor bins,
|
141 |
+
torch::Tensor out) {
|
142 |
+
// Validate the inputs.
|
143 |
+
TORCH_CHECK(x.is_cuda());
|
144 |
+
TORCH_CHECK(x.ndimension() == 2);
|
145 |
+
TORCH_CHECK(x.scalar_type() == torch::kFloat16 ||
|
146 |
+
x.scalar_type() == torch::kInt16 ||
|
147 |
+
x.scalar_type() == torch::kInt32);
|
148 |
+
TORCH_CHECK(bins.is_cuda());
|
149 |
+
TORCH_CHECK(bins.ndimension() == 1);
|
150 |
+
TORCH_CHECK(bins.scalar_type() == torch::kInt);
|
151 |
+
TORCH_CHECK(out.is_cuda());
|
152 |
+
TORCH_CHECK(out.ndimension() == 2);
|
153 |
+
TORCH_CHECK(out.scalar_type() == x.scalar_type());
|
154 |
+
|
155 |
+
// Batch dimensions should match for input/output.
|
156 |
+
TORCH_CHECK(x.size(0) == out.size(0));
|
157 |
+
|
158 |
+
// One input for each bin (in each batch).
|
159 |
+
TORCH_CHECK(x.size(1) == bins.size(0));
|
160 |
+
|
161 |
+
// Exit early if there is no work to do.
|
162 |
+
if (out.numel() == 0) return;
|
163 |
+
|
164 |
+
switch (x.scalar_type()) {
|
165 |
+
case torch::kFloat16:
|
166 |
+
CUDA_CALL(replicate::ReplicateForward(x.data_ptr<c10::Half>(),
|
167 |
+
x.size(0),
|
168 |
+
x.size(1),
|
169 |
+
bins.data_ptr<int>(),
|
170 |
+
out.data_ptr<c10::Half>(),
|
171 |
+
out.size(1),
|
172 |
+
c10::cuda::getCurrentCUDAStream()));
|
173 |
+
return;
|
174 |
+
case torch::kInt32:
|
175 |
+
CUDA_CALL(replicate::ReplicateForward(x.data_ptr<int>(),
|
176 |
+
x.size(0),
|
177 |
+
x.size(1),
|
178 |
+
bins.data_ptr<int>(),
|
179 |
+
out.data_ptr<int>(),
|
180 |
+
out.size(1),
|
181 |
+
c10::cuda::getCurrentCUDAStream()));
|
182 |
+
return;
|
183 |
+
}
|
184 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt16);
|
185 |
+
CUDA_CALL(replicate::ReplicateForward(x.data_ptr<short>(),
|
186 |
+
x.size(0),
|
187 |
+
x.size(1),
|
188 |
+
bins.data_ptr<int>(),
|
189 |
+
out.data_ptr<short>(),
|
190 |
+
out.size(1),
|
191 |
+
c10::cuda::getCurrentCUDAStream()));
|
192 |
+
}
|
193 |
+
|
194 |
+
void replicate_backward(torch::Tensor grad,
|
195 |
+
torch::Tensor bins,
|
196 |
+
torch::Tensor out) {
|
197 |
+
// Validate the inputs.
|
198 |
+
TORCH_CHECK(grad.is_cuda());
|
199 |
+
TORCH_CHECK(grad.ndimension() == 2);
|
200 |
+
TORCH_CHECK(grad.scalar_type() == torch::kFloat16);
|
201 |
+
TORCH_CHECK(bins.is_cuda());
|
202 |
+
TORCH_CHECK(bins.ndimension() == 1);
|
203 |
+
TORCH_CHECK(bins.scalar_type() == torch::kInt);
|
204 |
+
TORCH_CHECK(out.is_cuda());
|
205 |
+
TORCH_CHECK(out.ndimension() == 2);
|
206 |
+
TORCH_CHECK(out.scalar_type() == torch::kFloat16);
|
207 |
+
|
208 |
+
// Batch dimensions should match for input/output.
|
209 |
+
TORCH_CHECK(grad.size(0) == out.size(0));
|
210 |
+
|
211 |
+
// One output for each bin (in each batch).
|
212 |
+
TORCH_CHECK(out.size(1) == bins.size(0));
|
213 |
+
|
214 |
+
replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream());
|
215 |
+
}
|
216 |
+
|
217 |
+
} // namespace megablocks
|
218 |
+
|
219 |
+
#undef CUDA_CALL
|
220 |
+
#undef CUB_WRAPPED_NAMESPACE
|
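A hedged usage sketch for the two entry points above. The `megablocks.ops` import path is assumed; shapes and dtypes follow the TORCH_CHECKs: x is [batch, num_bins], bins holds cumulative int32 bin ends, out is [batch, columns].

import torch
from megablocks import ops  # assumed binding path

bins = torch.tensor([3, 5, 8], dtype=torch.int32, device='cuda')  # cumulative bin ends
x = torch.randn(2, 3, dtype=torch.float16, device='cuda')         # one value per bin per batch row
out = torch.empty(2, 8, dtype=torch.float16, device='cuda')
ops.replicate_forward(x, bins, out)    # out[:, 0:3] <- x[:, 0:1], out[:, 3:5] <- x[:, 1:2], out[:, 5:8] <- x[:, 2:3]

grad = torch.randn(2, 8, dtype=torch.float16, device='cuda')
dx = torch.empty(2, 3, dtype=torch.float16, device='cuda')
ops.replicate_backward(grad, bins, dx)  # segmented sum of grad back to one value per bin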
csrc/new_replicate.h
ADDED
@@ -0,0 +1,17 @@
#pragma once

#include <torch/all.h>

namespace megablocks {

// Forward pass: replicate values from x according to bin sizes
void replicate_forward(torch::Tensor x,
                       torch::Tensor bins,
                       torch::Tensor out);

// Backward pass: reduce gradients back to bins using segmented reduction
void replicate_backward(torch::Tensor grad,
                        torch::Tensor bins,
                        torch::Tensor out);

} // namespace megablocks
csrc/new_sort.cu
ADDED
@@ -0,0 +1,90 @@
1 |
+
#undef CUB_WRAPPED_NAMESPACE
|
2 |
+
#define CUB_WRAPPED_NAMESPACE megablocks
|
3 |
+
|
4 |
+
#include "new_sort.h"
|
5 |
+
#include <cstdint>
|
6 |
+
#include <cub/cub.cuh>
|
7 |
+
#include <c10/cuda/CUDAStream.h>
|
8 |
+
|
9 |
+
#define CUDA_CALL(code) \
|
10 |
+
do { \
|
11 |
+
cudaError_t status = code; \
|
12 |
+
std::string err = cudaGetErrorString(status); \
|
13 |
+
TORCH_CHECK(status == cudaSuccess, err); \
|
14 |
+
} while (0)
|
15 |
+
|
16 |
+
namespace megablocks {
|
17 |
+
|
18 |
+
template <typename T>
|
19 |
+
void cub_radix_sort(torch::Tensor x,
|
20 |
+
int end_bit,
|
21 |
+
torch::Tensor x_out,
|
22 |
+
torch::Tensor iota_out) {
|
23 |
+
// Get iota for values in sort.
|
24 |
+
torch::Tensor iota = torch::arange(0, x.numel(), x.options());
|
25 |
+
|
26 |
+
// Get temporary buffer size.
|
27 |
+
size_t scratchpad_bytes = 0;
|
28 |
+
CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr,
|
29 |
+
scratchpad_bytes,
|
30 |
+
x.data_ptr<T>(),
|
31 |
+
x_out.data_ptr<T>(),
|
32 |
+
iota.data_ptr<T>(),
|
33 |
+
iota_out.data_ptr<T>(),
|
34 |
+
x.numel(),
|
35 |
+
/*begin_bit*/0,
|
36 |
+
/*end_bit=*/end_bit,
|
37 |
+
c10::cuda::getCurrentCUDAStream()));
|
38 |
+
|
39 |
+
// Allocate scratchpad.
|
40 |
+
auto options = torch::TensorOptions()
|
41 |
+
.dtype(torch::kInt8)
|
42 |
+
.device(x.device());
|
43 |
+
torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
|
44 |
+
|
45 |
+
// Run the kernel.
|
46 |
+
CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(),
|
47 |
+
scratchpad_bytes,
|
48 |
+
x.data_ptr<T>(),
|
49 |
+
x_out.data_ptr<T>(),
|
50 |
+
iota.data_ptr<T>(),
|
51 |
+
iota_out.data_ptr<T>(),
|
52 |
+
x.numel(),
|
53 |
+
/*begin_bit=*/0,
|
54 |
+
/*end_bit=*/end_bit,
|
55 |
+
c10::cuda::getCurrentCUDAStream()));
|
56 |
+
}
|
57 |
+
|
58 |
+
void sort(torch::Tensor x,
|
59 |
+
int end_bit,
|
60 |
+
torch::Tensor x_out,
|
61 |
+
torch::Tensor iota_out) {
|
62 |
+
TORCH_CHECK(x.is_cuda());
|
63 |
+
TORCH_CHECK(x.ndimension() == 1);
|
64 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
|
65 |
+
x.scalar_type() == torch::kInt32 ||
|
66 |
+
x.scalar_type() == torch::kInt64);
|
67 |
+
TORCH_CHECK(x_out.is_cuda());
|
68 |
+
TORCH_CHECK(x_out.ndimension() == 1);
|
69 |
+
TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
|
70 |
+
TORCH_CHECK(iota_out.is_cuda());
|
71 |
+
TORCH_CHECK(iota_out.ndimension() == 1);
|
72 |
+
TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());
|
73 |
+
|
74 |
+
// Exit early if there is no work to do.
|
75 |
+
if (x_out.numel() == 0) return;
|
76 |
+
|
77 |
+
switch (x.scalar_type()) {
|
78 |
+
case torch::kInt16:
|
79 |
+
return cub_radix_sort<short>(x, end_bit, x_out, iota_out);
|
80 |
+
case torch::kInt32:
|
81 |
+
return cub_radix_sort<int>(x, end_bit, x_out, iota_out);
|
82 |
+
}
|
83 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt64);
|
84 |
+
return cub_radix_sort<long>(x, end_bit, x_out, iota_out);
|
85 |
+
}
|
86 |
+
|
87 |
+
} // namespace megablocks
|
88 |
+
|
89 |
+
#undef CUDA_CALL
|
90 |
+
#undef CUB_WRAPPED_NAMESPACE
|
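A hedged usage sketch for sort() above (assumed `megablocks.ops` import path). end_bit bounds the number of radix passes, and iota_out returns the original positions of the sorted keys, i.e. an argsort.

import torch
from megablocks import ops  # assumed binding path

x = torch.randint(0, 8, (1024,), dtype=torch.int32, device='cuda')
x_out = torch.empty_like(x)
iota_out = torch.empty_like(x)
ops.sort(x, 3, x_out, iota_out)            # 3 bits cover keys in [0, 8)
assert torch.equal(x_out, x[iota_out.long()])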
csrc/new_sort.h
ADDED
@@ -0,0 +1,13 @@
#pragma once

#include <torch/all.h>

namespace megablocks {

// Public interface function for radix sorting with indices
void sort(torch::Tensor x,
          int end_bit,
          torch::Tensor x_out,
          torch::Tensor iota_out);

} // namespace megablocks
csrc/replicate.h
ADDED
@@ -0,0 +1,211 @@
1 |
+
#undef CUB_WRAPPED_NAMESPACE
|
2 |
+
#define CUB_WRAPPED_NAMESPACE megablocks
|
3 |
+
|
4 |
+
#include <cstdint>
|
5 |
+
|
6 |
+
#include <cub/cub.cuh>
|
7 |
+
#include <c10/util/Half.h>
|
8 |
+
#include <c10/cuda/CUDAStream.h>
|
9 |
+
// #include <torch/extension.h>
|
10 |
+
|
11 |
+
#define CUDA_CALL(code) \
|
12 |
+
do { \
|
13 |
+
cudaError_t status = code; \
|
14 |
+
std::string err = cudaGetErrorString(status); \
|
15 |
+
TORCH_CHECK(status == cudaSuccess, err); \
|
16 |
+
} while (0)
|
17 |
+
|
18 |
+
namespace megablocks {
|
19 |
+
namespace replicate {
|
20 |
+
|
21 |
+
template <typename T, int kThreadsPerBlock>
|
22 |
+
__global__ void __launch_bounds__(kThreadsPerBlock)
|
23 |
+
ReplicateForwardKernel(T * __restrict__ x,
|
24 |
+
int * __restrict__ bins,
|
25 |
+
T * __restrict__ out,
|
26 |
+
int columns) {
|
27 |
+
// Offset to this thread block's batch.
|
28 |
+
//
|
29 |
+
// x is [batch_size, num_bins]
|
30 |
+
// out is [batch_size, columns]
|
31 |
+
// bins is [num_bins]
|
32 |
+
int batch_idx = blockIdx.y;
|
33 |
+
int num_bins = gridDim.x;
|
34 |
+
x += batch_idx * num_bins;
|
35 |
+
out += batch_idx * columns;
|
36 |
+
|
37 |
+
// Load the start/end for this bin.
|
38 |
+
int bin_idx = blockIdx.x;
|
39 |
+
int start = 0;
|
40 |
+
if (bin_idx > 0) start = __ldg(bins + bin_idx - 1);
|
41 |
+
int end = __ldg(bins + bin_idx);
|
42 |
+
|
43 |
+
// Load the value to replicate.
|
44 |
+
T value = __ldg((T*)x + bin_idx);
|
45 |
+
|
46 |
+
// Offset to this thread block's bin and this thread's
|
47 |
+
// offset within the bin.
|
48 |
+
int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x;
|
49 |
+
out += start + bin_offset;
|
50 |
+
|
51 |
+
// Replicate the value to the output.
|
52 |
+
//
|
53 |
+
// TODO(tgale): Vectorize these stores.
|
54 |
+
int num_elements = end - start;
|
55 |
+
const int kElementsPerLoop = gridDim.z * kThreadsPerBlock;
|
56 |
+
T *out_ptr = (T*)out;
|
57 |
+
for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) {
|
58 |
+
*out_ptr = value;
|
59 |
+
out_ptr += kElementsPerLoop;
|
60 |
+
}
|
61 |
+
}
|
62 |
+
|
63 |
+
template <typename T>
|
64 |
+
cudaError_t ReplicateForward(T *x,
|
65 |
+
int batch_size,
|
66 |
+
int num_bins,
|
67 |
+
int *bins,
|
68 |
+
T *out,
|
69 |
+
int columns,
|
70 |
+
cudaStream_t stream) {
|
71 |
+
const int kThreadsPerBlock = 64;
|
72 |
+
dim3 block_dim(kThreadsPerBlock, 1, 1);
|
73 |
+
int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock));
|
74 |
+
dim3 grid_dim(num_bins, batch_size, group_size);
|
75 |
+
ReplicateForwardKernel<T, kThreadsPerBlock><<<
|
76 |
+
grid_dim, block_dim, 0, stream>>>(x, bins, out, columns);
|
77 |
+
return cudaGetLastError();
|
78 |
+
}
|
79 |
+
|
80 |
+
void cub_segmented_reduce(torch::Tensor grad,
|
81 |
+
torch::Tensor bins,
|
82 |
+
torch::Tensor out,
|
83 |
+
cudaStream_t stream) {
|
84 |
+
// Append a zero to the bin boundaries for CUB.
|
85 |
+
torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options());
|
86 |
+
CUDA_CALL(cudaMemsetAsync(offsets.data_ptr<int>(),
|
87 |
+
0,
|
88 |
+
offsets.numel() * sizeof(int),
|
89 |
+
stream));
|
90 |
+
CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr<int>() + 1,
|
91 |
+
bins.data_ptr<int>(),
|
92 |
+
bins.numel() * sizeof(int),
|
93 |
+
cudaMemcpyDeviceToDevice,
|
94 |
+
stream));
|
95 |
+
|
96 |
+
// Get temporary buffer size.
|
97 |
+
size_t scratchpad_bytes = 0;
|
98 |
+
CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr,
|
99 |
+
scratchpad_bytes,
|
100 |
+
grad.data_ptr<c10::Half>(),
|
101 |
+
out.data_ptr<c10::Half>(),
|
102 |
+
bins.numel(),
|
103 |
+
offsets.data_ptr<int>(),
|
104 |
+
offsets.data_ptr<int>() + 1,
|
105 |
+
stream));
|
106 |
+
|
107 |
+
// Allocate scratchpad.
|
108 |
+
auto options = torch::TensorOptions()
|
109 |
+
.dtype(torch::kInt8)
|
110 |
+
.device(grad.device());
|
111 |
+
torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
|
112 |
+
|
113 |
+
// Run the kernel for each batch item.
|
114 |
+
for (int i = 0; i < grad.size(0); ++i) {
|
115 |
+
int num_bins = out.size(1);
|
116 |
+
int num_values = grad.size(1);
|
117 |
+
CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr<int8_t>(),
|
118 |
+
scratchpad_bytes,
|
119 |
+
grad.data_ptr<c10::Half>() + i * num_values,
|
120 |
+
out.data_ptr<c10::Half>() + i * num_bins,
|
121 |
+
bins.numel(),
|
122 |
+
offsets.data_ptr<int>(),
|
123 |
+
offsets.data_ptr<int>() + 1,
|
124 |
+
stream));
|
125 |
+
}
|
126 |
+
}
|
127 |
+
|
128 |
+
} // namespace replicate
|
129 |
+
|
130 |
+
void replicate_forward(torch::Tensor x,
|
131 |
+
torch::Tensor bins,
|
132 |
+
torch::Tensor out) {
|
133 |
+
// Validate the inputs.
|
134 |
+
TORCH_CHECK(x.is_cuda());
|
135 |
+
TORCH_CHECK(x.ndimension() == 2);
|
136 |
+
TORCH_CHECK(x.scalar_type() == torch::kFloat16 ||
|
137 |
+
x.scalar_type() == torch::kInt16 ||
|
138 |
+
x.scalar_type() == torch::kInt32);
|
139 |
+
TORCH_CHECK(bins.is_cuda());
|
140 |
+
TORCH_CHECK(bins.ndimension() == 1);
|
141 |
+
TORCH_CHECK(bins.scalar_type() == torch::kInt);
|
142 |
+
TORCH_CHECK(out.is_cuda());
|
143 |
+
TORCH_CHECK(out.ndimension() == 2);
|
144 |
+
TORCH_CHECK(out.scalar_type() == x.scalar_type());
|
145 |
+
|
146 |
+
// Batch dimensions should match for input/output.
|
147 |
+
TORCH_CHECK(x.size(0) == out.size(0));
|
148 |
+
|
149 |
+
// One input for each bin (in each batch).
|
150 |
+
TORCH_CHECK(x.size(1) == bins.size(0));
|
151 |
+
|
152 |
+
// Exit early if there is no work to do.
|
153 |
+
if (out.numel() == 0) return;
|
154 |
+
|
155 |
+
switch (x.scalar_type()) {
|
156 |
+
case torch::kFloat16:
|
157 |
+
CUDA_CALL(replicate::ReplicateForward(x.data_ptr<c10::Half>(),
|
158 |
+
x.size(0),
|
159 |
+
x.size(1),
|
160 |
+
bins.data_ptr<int>(),
|
161 |
+
out.data_ptr<c10::Half>(),
|
162 |
+
out.size(1),
|
163 |
+
c10::cuda::getCurrentCUDAStream()));
|
164 |
+
return;
|
165 |
+
case torch::kInt32:
|
166 |
+
CUDA_CALL(replicate::ReplicateForward(x.data_ptr<int>(),
|
167 |
+
x.size(0),
|
168 |
+
x.size(1),
|
169 |
+
bins.data_ptr<int>(),
|
170 |
+
out.data_ptr<int>(),
|
171 |
+
out.size(1),
|
172 |
+
c10::cuda::getCurrentCUDAStream()));
|
173 |
+
return;
|
174 |
+
}
|
175 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt16);
|
176 |
+
CUDA_CALL(replicate::ReplicateForward(x.data_ptr<short>(),
|
177 |
+
x.size(0),
|
178 |
+
x.size(1),
|
179 |
+
bins.data_ptr<int>(),
|
180 |
+
out.data_ptr<short>(),
|
181 |
+
out.size(1),
|
182 |
+
c10::cuda::getCurrentCUDAStream()));
|
183 |
+
}
|
184 |
+
|
185 |
+
void replicate_backward(torch::Tensor grad,
|
186 |
+
torch::Tensor bins,
|
187 |
+
torch::Tensor out) {
|
188 |
+
// Validate the inputs.
|
189 |
+
TORCH_CHECK(grad.is_cuda());
|
190 |
+
TORCH_CHECK(grad.ndimension() == 2);
|
191 |
+
TORCH_CHECK(grad.scalar_type() == torch::kFloat16);
|
192 |
+
TORCH_CHECK(bins.is_cuda());
|
193 |
+
TORCH_CHECK(bins.ndimension() == 1);
|
194 |
+
TORCH_CHECK(bins.scalar_type() == torch::kInt);
|
195 |
+
TORCH_CHECK(out.is_cuda());
|
196 |
+
TORCH_CHECK(out.ndimension() == 2);
|
197 |
+
TORCH_CHECK(out.scalar_type() == torch::kFloat16);
|
198 |
+
|
199 |
+
// Batch dimensions should match for input/output.
|
200 |
+
TORCH_CHECK(grad.size(0) == out.size(0));
|
201 |
+
|
202 |
+
// One output for each bin (in each batch).
|
203 |
+
TORCH_CHECK(out.size(1) == bins.size(0));
|
204 |
+
|
205 |
+
replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream());
|
206 |
+
}
|
207 |
+
|
208 |
+
} // namespace megablocks
|
209 |
+
|
210 |
+
#undef CUDA_CALL
|
211 |
+
#undef CUB_WRAPPED_NAMESPACE
|
csrc/sort.h
ADDED
@@ -0,0 +1,91 @@
1 |
+
#undef CUB_WRAPPED_NAMESPACE
|
2 |
+
#define CUB_WRAPPED_NAMESPACE megablocks
|
3 |
+
|
4 |
+
#include <cstdint>
|
5 |
+
|
6 |
+
#include <cub/cub.cuh>
|
7 |
+
#include <c10/cuda/CUDAStream.h>
|
8 |
+
// #include <torch/extension.h>
|
9 |
+
|
10 |
+
#define CUDA_CALL(code) \
|
11 |
+
do { \
|
12 |
+
cudaError_t status = code; \
|
13 |
+
std::string err = cudaGetErrorString(status); \
|
14 |
+
TORCH_CHECK(status == cudaSuccess, err); \
|
15 |
+
} while (0)
|
16 |
+
|
17 |
+
namespace megablocks {
|
18 |
+
|
19 |
+
template <typename T>
|
20 |
+
void cub_radix_sort(torch::Tensor x,
|
21 |
+
int end_bit,
|
22 |
+
torch::Tensor x_out,
|
23 |
+
torch::Tensor iota_out) {
|
24 |
+
// Get iota for values in sort.
|
25 |
+
torch::Tensor iota = torch::arange(0, x.numel(), x.options());
|
26 |
+
|
27 |
+
// Get temporary buffer size.
|
28 |
+
size_t scratchpad_bytes = 0;
|
29 |
+
CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr,
|
30 |
+
scratchpad_bytes,
|
31 |
+
x.data_ptr<T>(),
|
32 |
+
x_out.data_ptr<T>(),
|
33 |
+
iota.data_ptr<T>(),
|
34 |
+
iota_out.data_ptr<T>(),
|
35 |
+
x.numel(),
|
36 |
+
/*begin_bit*/0,
|
37 |
+
/*end_bit=*/end_bit,
|
38 |
+
c10::cuda::getCurrentCUDAStream()));
|
39 |
+
|
40 |
+
// Allocate scratchpad.
|
41 |
+
auto options = torch::TensorOptions()
|
42 |
+
.dtype(torch::kInt8)
|
43 |
+
.device(x.device());
|
44 |
+
torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
|
45 |
+
|
46 |
+
// Run the kernel.
|
47 |
+
CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(),
|
48 |
+
scratchpad_bytes,
|
49 |
+
x.data_ptr<T>(),
|
50 |
+
x_out.data_ptr<T>(),
|
51 |
+
iota.data_ptr<T>(),
|
52 |
+
iota_out.data_ptr<T>(),
|
53 |
+
x.numel(),
|
54 |
+
/*begin_bit=*/0,
|
55 |
+
/*end_bit=*/end_bit,
|
56 |
+
c10::cuda::getCurrentCUDAStream()));
|
57 |
+
}
|
58 |
+
|
59 |
+
void sort(torch::Tensor x,
|
60 |
+
int end_bit,
|
61 |
+
torch::Tensor x_out,
|
62 |
+
torch::Tensor iota_out) {
|
63 |
+
TORCH_CHECK(x.is_cuda());
|
64 |
+
TORCH_CHECK(x.ndimension() == 1);
|
65 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
|
66 |
+
x.scalar_type() == torch::kInt32 ||
|
67 |
+
x.scalar_type() == torch::kInt64);
|
68 |
+
TORCH_CHECK(x_out.is_cuda());
|
69 |
+
TORCH_CHECK(x_out.ndimension() == 1);
|
70 |
+
TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
|
71 |
+
TORCH_CHECK(iota_out.is_cuda());
|
72 |
+
TORCH_CHECK(iota_out.ndimension() == 1);
|
73 |
+
TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());
|
74 |
+
|
75 |
+
// Exit early if there is no work to do.
|
76 |
+
if (x_out.numel() == 0) return;
|
77 |
+
|
78 |
+
switch (x.scalar_type()) {
|
79 |
+
case torch::kInt16:
|
80 |
+
return cub_radix_sort<short>(x, end_bit, x_out, iota_out);
|
81 |
+
case torch::kInt32:
|
82 |
+
return cub_radix_sort<int>(x, end_bit, x_out, iota_out);
|
83 |
+
}
|
84 |
+
TORCH_CHECK(x.scalar_type() == torch::kInt64);
|
85 |
+
return cub_radix_sort<long>(x, end_bit, x_out, iota_out);
|
86 |
+
}
|
87 |
+
|
88 |
+
} // namespace megablocks
|
89 |
+
|
90 |
+
#undef CUDA_CALL
|
91 |
+
#undef CUB_WRAPPED_NAMESPACE
|
flake.lock
ADDED
@@ -0,0 +1,168 @@
1 |
+
{
|
2 |
+
"nodes": {
|
3 |
+
"flake-compat": {
|
4 |
+
"locked": {
|
5 |
+
"lastModified": 1747046372,
|
6 |
+
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
|
7 |
+
"owner": "edolstra",
|
8 |
+
"repo": "flake-compat",
|
9 |
+
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
10 |
+
"type": "github"
|
11 |
+
},
|
12 |
+
"original": {
|
13 |
+
"owner": "edolstra",
|
14 |
+
"repo": "flake-compat",
|
15 |
+
"type": "github"
|
16 |
+
}
|
17 |
+
},
|
18 |
+
"flake-compat_2": {
|
19 |
+
"locked": {
|
20 |
+
"lastModified": 1733328505,
|
21 |
+
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
|
22 |
+
"owner": "edolstra",
|
23 |
+
"repo": "flake-compat",
|
24 |
+
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
25 |
+
"type": "github"
|
26 |
+
},
|
27 |
+
"original": {
|
28 |
+
"owner": "edolstra",
|
29 |
+
"repo": "flake-compat",
|
30 |
+
"type": "github"
|
31 |
+
}
|
32 |
+
},
|
33 |
+
"flake-utils": {
|
34 |
+
"inputs": {
|
35 |
+
"systems": "systems"
|
36 |
+
},
|
37 |
+
"locked": {
|
38 |
+
"lastModified": 1731533236,
|
39 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
40 |
+
"owner": "numtide",
|
41 |
+
"repo": "flake-utils",
|
42 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
43 |
+
"type": "github"
|
44 |
+
},
|
45 |
+
"original": {
|
46 |
+
"owner": "numtide",
|
47 |
+
"repo": "flake-utils",
|
48 |
+
"type": "github"
|
49 |
+
}
|
50 |
+
},
|
51 |
+
"flake-utils_2": {
|
52 |
+
"inputs": {
|
53 |
+
"systems": "systems_2"
|
54 |
+
},
|
55 |
+
"locked": {
|
56 |
+
"lastModified": 1731533236,
|
57 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
58 |
+
"owner": "numtide",
|
59 |
+
"repo": "flake-utils",
|
60 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
61 |
+
"type": "github"
|
62 |
+
},
|
63 |
+
"original": {
|
64 |
+
"owner": "numtide",
|
65 |
+
"repo": "flake-utils",
|
66 |
+
"type": "github"
|
67 |
+
}
|
68 |
+
},
|
69 |
+
"hf-nix": {
|
70 |
+
"inputs": {
|
71 |
+
"flake-compat": "flake-compat_2",
|
72 |
+
"flake-utils": "flake-utils_2",
|
73 |
+
"nixpkgs": "nixpkgs"
|
74 |
+
},
|
75 |
+
"locked": {
|
76 |
+
"lastModified": 1751968576,
|
77 |
+
"narHash": "sha256-cmKrlWpNTG/hq1bCaHXfbdm9T+Y6V+5//EHAVc1TLBE=",
|
78 |
+
"owner": "huggingface",
|
79 |
+
"repo": "hf-nix",
|
80 |
+
"rev": "3fcd1e1b46da91b6691261640ffd6b7123d0cb9e",
|
81 |
+
"type": "github"
|
82 |
+
},
|
83 |
+
"original": {
|
84 |
+
"owner": "huggingface",
|
85 |
+
"repo": "hf-nix",
|
86 |
+
"type": "github"
|
87 |
+
}
|
88 |
+
},
|
89 |
+
"kernel-builder": {
|
90 |
+
"inputs": {
|
91 |
+
"flake-compat": "flake-compat",
|
92 |
+
"flake-utils": "flake-utils",
|
93 |
+
"hf-nix": "hf-nix",
|
94 |
+
"nixpkgs": [
|
95 |
+
"kernel-builder",
|
96 |
+
"hf-nix",
|
97 |
+
"nixpkgs"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
"locked": {
|
101 |
+
"lastModified": 1753256281,
|
102 |
+
"narHash": "sha256-CfL3Fyf2ih7OtyL7ScZUCwOeCj+gjlRyPykhR6Zbt3I=",
|
103 |
+
"owner": "huggingface",
|
104 |
+
"repo": "kernel-builder",
|
105 |
+
"rev": "dcbbdf2d3c8e78b27321b205b2c9d67ffce6a706",
|
106 |
+
"type": "github"
|
107 |
+
},
|
108 |
+
"original": {
|
109 |
+
"owner": "huggingface",
|
110 |
+
"repo": "kernel-builder",
|
111 |
+
"type": "github"
|
112 |
+
}
|
113 |
+
},
|
114 |
+
"nixpkgs": {
|
115 |
+
"locked": {
|
116 |
+
"lastModified": 1747820358,
|
117 |
+
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
|
118 |
+
"owner": "danieldk",
|
119 |
+
"repo": "nixpkgs",
|
120 |
+
"rev": "d3c1681180717528068082103bf323147de6ab0b",
|
121 |
+
"type": "github"
|
122 |
+
},
|
123 |
+
"original": {
|
124 |
+
"owner": "danieldk",
|
125 |
+
"ref": "cudatoolkit-12.9-kernel-builder",
|
126 |
+
"repo": "nixpkgs",
|
127 |
+
"type": "github"
|
128 |
+
}
|
129 |
+
},
|
130 |
+
"root": {
|
131 |
+
"inputs": {
|
132 |
+
"kernel-builder": "kernel-builder"
|
133 |
+
}
|
134 |
+
},
|
135 |
+
"systems": {
|
136 |
+
"locked": {
|
137 |
+
"lastModified": 1681028828,
|
138 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
139 |
+
"owner": "nix-systems",
|
140 |
+
"repo": "default",
|
141 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
142 |
+
"type": "github"
|
143 |
+
},
|
144 |
+
"original": {
|
145 |
+
"owner": "nix-systems",
|
146 |
+
"repo": "default",
|
147 |
+
"type": "github"
|
148 |
+
}
|
149 |
+
},
|
150 |
+
"systems_2": {
|
151 |
+
"locked": {
|
152 |
+
"lastModified": 1681028828,
|
153 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
154 |
+
"owner": "nix-systems",
|
155 |
+
"repo": "default",
|
156 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
157 |
+
"type": "github"
|
158 |
+
},
|
159 |
+
"original": {
|
160 |
+
"owner": "nix-systems",
|
161 |
+
"repo": "default",
|
162 |
+
"type": "github"
|
163 |
+
}
|
164 |
+
}
|
165 |
+
},
|
166 |
+
"root": "root",
|
167 |
+
"version": 7
|
168 |
+
}
|
flake.nix
ADDED
@@ -0,0 +1,24 @@
{
  description = "Flake for megablocks_moe kernel";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;

      pythonCheckInputs = pkgs: with pkgs; [
        tqdm
        py-cpuinfo
        importlib-metadata
        torchmetrics
      ];
    };
}
tests/__init__.py
ADDED
File without changes
|
tests/conftest.py
ADDED
@@ -0,0 +1,110 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import os
|
5 |
+
from typing import List, Optional
|
6 |
+
|
7 |
+
import pytest
|
8 |
+
# from composer.utils import reproducibility
|
9 |
+
|
10 |
+
# Allowed options for pytest.mark.world_size()
|
11 |
+
WORLD_SIZE_OPTIONS = (1, 2)
|
12 |
+
|
13 |
+
# Enforce deterministic mode before any tests start.
|
14 |
+
# reproducibility.configure_deterministic_mode()
|
15 |
+
|
16 |
+
# TODO: allow plugins when deps are resolved
|
17 |
+
|
18 |
+
# Add the path of any pytest fixture files you want to make global
|
19 |
+
pytest_plugins = [
|
20 |
+
# 'tests.fixtures.autouse',
|
21 |
+
'tests.fixtures.fixtures',
|
22 |
+
]
|
23 |
+
|
24 |
+
|
25 |
+
def _get_world_size(item: pytest.Item):
|
26 |
+
"""Returns the world_size of a test, defaults to 1."""
|
27 |
+
_default = pytest.mark.world_size(1).mark
|
28 |
+
return item.get_closest_marker('world_size', default=_default).args[0]
|
29 |
+
|
30 |
+
|
31 |
+
def _get_option(
|
32 |
+
config: pytest.Config,
|
33 |
+
name: str,
|
34 |
+
default: Optional[str] = None,
|
35 |
+
) -> str: # type: ignore
|
36 |
+
val = config.getoption(name)
|
37 |
+
if val is not None:
|
38 |
+
assert isinstance(val, str)
|
39 |
+
return val
|
40 |
+
val = config.getini(name)
|
41 |
+
if val == []:
|
42 |
+
val = None
|
43 |
+
if val is None:
|
44 |
+
if default is None:
|
45 |
+
pytest.fail(f'Config option {name} is not specified but is required',)
|
46 |
+
val = default
|
47 |
+
assert isinstance(val, str)
|
48 |
+
return val
|
49 |
+
|
50 |
+
|
51 |
+
def _add_option(
|
52 |
+
parser: pytest.Parser,
|
53 |
+
name: str,
|
54 |
+
help: str,
|
55 |
+
choices: Optional[list[str]] = None,
|
56 |
+
):
|
57 |
+
parser.addoption(
|
58 |
+
f'--{name}',
|
59 |
+
default=None,
|
60 |
+
type=str,
|
61 |
+
choices=choices,
|
62 |
+
help=help,
|
63 |
+
)
|
64 |
+
parser.addini(
|
65 |
+
name=name,
|
66 |
+
help=help,
|
67 |
+
type='string',
|
68 |
+
default=None,
|
69 |
+
)
|
70 |
+
|
71 |
+
|
72 |
+
def pytest_collection_modifyitems(
|
73 |
+
config: pytest.Config,
|
74 |
+
items: List[pytest.Item],
|
75 |
+
) -> None:
|
76 |
+
"""Filter tests by world_size (for multi-GPU tests)"""
|
77 |
+
world_size = int(os.environ.get('WORLD_SIZE', '1'))
|
78 |
+
print(f'world_size={world_size}')
|
79 |
+
|
80 |
+
conditions = [
|
81 |
+
lambda item: _get_world_size(item) == world_size,
|
82 |
+
]
|
83 |
+
|
84 |
+
# keep items that satisfy all conditions
|
85 |
+
remaining = []
|
86 |
+
deselected = []
|
87 |
+
for item in items:
|
88 |
+
if all(condition(item) for condition in conditions):
|
89 |
+
remaining.append(item)
|
90 |
+
else:
|
91 |
+
deselected.append(item)
|
92 |
+
|
93 |
+
if deselected:
|
94 |
+
config.hook.pytest_deselected(items=deselected)
|
95 |
+
items[:] = remaining
|
96 |
+
|
97 |
+
|
98 |
+
def pytest_addoption(parser: pytest.Parser) -> None:
|
99 |
+
_add_option(
|
100 |
+
parser,
|
101 |
+
'seed',
|
102 |
+
help="""\
|
103 |
+
Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked
|
104 |
+
before each test.""",
|
105 |
+
)
|
106 |
+
|
107 |
+
|
108 |
+
def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
|
109 |
+
if exitstatus == 5:
|
110 |
+
session.exitstatus = 0 # Ignore no-test-ran errors
|
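For reference, a sketch of how the world_size filter above is consumed by a test. The test name and body are illustrative only; the world_size marker and the WORLD_SIZE environment variable are the pieces defined in this conftest.

import pytest

@pytest.mark.world_size(2)   # collected only when the session runs with WORLD_SIZE=2
@pytest.mark.gpu
def test_runs_on_two_ranks():
    ...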
tests/fixtures/autouse.py
ADDED
@@ -0,0 +1,107 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import gc
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
|
8 |
+
import composer
|
9 |
+
import pytest
|
10 |
+
import torch
|
11 |
+
from composer.devices import DeviceCPU, DeviceGPU
|
12 |
+
from composer.utils import dist, reproducibility
|
13 |
+
|
14 |
+
|
15 |
+
@pytest.fixture(autouse=True)
|
16 |
+
def clear_cuda_cache(request: pytest.FixtureRequest):
|
17 |
+
"""Clear memory between GPU tests."""
|
18 |
+
marker = request.node.get_closest_marker('gpu')
|
19 |
+
if marker is not None and torch.cuda.is_available():
|
20 |
+
torch.cuda.empty_cache()
|
21 |
+
gc.collect() # Only gc on GPU tests as it 2x slows down CPU tests
|
22 |
+
|
23 |
+
|
24 |
+
@pytest.fixture(autouse=True)
|
25 |
+
def reset_mlflow_tracking_dir():
|
26 |
+
"""Reset MLFlow tracking dir so it doesn't persist across tests."""
|
27 |
+
try:
|
28 |
+
import mlflow
|
29 |
+
mlflow.set_tracking_uri(None) # type: ignore
|
30 |
+
except ModuleNotFoundError:
|
31 |
+
# MLFlow not installed
|
32 |
+
pass
|
33 |
+
|
34 |
+
|
35 |
+
@pytest.fixture(scope='session')
|
36 |
+
def cleanup_dist():
|
37 |
+
"""Ensure all dist tests clean up resources properly."""
|
38 |
+
yield
|
39 |
+
# Avoid race condition where a test is still writing to a file on one rank
|
40 |
+
# while the file system is being torn down on another rank.
|
41 |
+
dist.barrier()
|
42 |
+
|
43 |
+
|
44 |
+
@pytest.fixture(autouse=True, scope='session')
|
45 |
+
def configure_dist(request: pytest.FixtureRequest):
|
46 |
+
# Configure dist globally when the world size is greater than 1,
|
47 |
+
# so individual tests that do not use the trainer
|
48 |
+
# do not need to worry about manually configuring dist.
|
49 |
+
|
50 |
+
if dist.get_world_size() == 1:
|
51 |
+
return
|
52 |
+
|
53 |
+
device = None
|
54 |
+
|
55 |
+
for item in request.session.items:
|
56 |
+
device = DeviceCPU() if item.get_closest_marker('gpu') is None else DeviceGPU()
|
57 |
+
break
|
58 |
+
|
59 |
+
assert device is not None
|
60 |
+
|
61 |
+
if not dist.is_initialized():
|
62 |
+
dist.initialize_dist(device, timeout=300.0)
|
63 |
+
# Hold PyTest until all ranks have reached this barrier. Ensure that no rank starts
|
64 |
+
# any test before other ranks are ready to start it, which could be a cause of random timeouts
|
65 |
+
# (e.g. rank 1 starts the next test while rank 0 is finishing up the previous test).
|
66 |
+
dist.barrier()
|
67 |
+
|
68 |
+
|
69 |
+
@pytest.fixture(autouse=True)
|
70 |
+
def set_log_levels():
|
71 |
+
"""Ensures all log levels are set to DEBUG."""
|
72 |
+
logging.basicConfig()
|
73 |
+
logging.getLogger(composer.__name__).setLevel(logging.DEBUG)
|
74 |
+
|
75 |
+
|
76 |
+
@pytest.fixture(autouse=True)
|
77 |
+
def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch):
|
78 |
+
"""Monkeypatch reproducibility.
|
79 |
+
|
80 |
+
Make get_random_seed to always return the rank zero seed, and set the random seed before each test to the rank local
|
81 |
+
seed.
|
82 |
+
"""
|
83 |
+
monkeypatch.setattr(
|
84 |
+
reproducibility,
|
85 |
+
'get_random_seed',
|
86 |
+
lambda: rank_zero_seed,
|
87 |
+
)
|
88 |
+
reproducibility.seed_all(rank_zero_seed + dist.get_global_rank())
|
89 |
+
|
90 |
+
|
91 |
+
@pytest.fixture(autouse=True)
|
92 |
+
def remove_run_name_env_var():
|
93 |
+
# Remove environment variables for run names in unit tests
|
94 |
+
composer_run_name = os.environ.get('COMPOSER_RUN_NAME')
|
95 |
+
run_name = os.environ.get('RUN_NAME')
|
96 |
+
|
97 |
+
if 'COMPOSER_RUN_NAME' in os.environ:
|
98 |
+
del os.environ['COMPOSER_RUN_NAME']
|
99 |
+
if 'RUN_NAME' in os.environ:
|
100 |
+
del os.environ['RUN_NAME']
|
101 |
+
|
102 |
+
yield
|
103 |
+
|
104 |
+
if composer_run_name is not None:
|
105 |
+
os.environ['COMPOSER_RUN_NAME'] = composer_run_name
|
106 |
+
if run_name is not None:
|
107 |
+
os.environ['RUN_NAME'] = run_name
|
tests/fixtures/fixtures.py
ADDED
@@ -0,0 +1,13 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import pytest

from tests.conftest import _get_option


@pytest.fixture
def rank_zero_seed(pytestconfig: pytest.Config) -> int:
    """Read the rank_zero_seed from the CLI option."""
    seed = _get_option(pytestconfig, 'seed', default='0')
    return int(seed)
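A tiny sketch of consuming this fixture from a test (illustrative test name; the fixture itself is defined above and reads the --seed option declared in conftest.py):

def test_uses_rank_zero_seed(rank_zero_seed: int):
    assert rank_zero_seed >= 0  # defaults to 0 unless --seed is passed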
tests/layer_test.py
ADDED
@@ -0,0 +1,53 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
from collections import namedtuple
|
4 |
+
|
5 |
+
|
6 |
+
def test_megablocks_moe_mlp_import():
|
7 |
+
"""Test if MegaBlocksMoeMLP can be imported."""
|
8 |
+
from megablocks.layers import MegaBlocksMoeMLP
|
9 |
+
|
10 |
+
assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."
|
11 |
+
|
12 |
+
|
13 |
+
def test_megablocks_moe_mlp_functionality():
|
14 |
+
"""Test the functionality of MegaBlocksMoeMLP."""
|
15 |
+
from megablocks.layers import MegaBlocksMoeMLP
|
16 |
+
|
17 |
+
# Create a simple instance of MegaBlocksMoeMLP
|
18 |
+
model = MegaBlocksMoeMLP()
|
19 |
+
|
20 |
+
# add experts attribute to the model
|
21 |
+
model.experts = namedtuple(
|
22 |
+
"Experts",
|
23 |
+
[
|
24 |
+
"gate_up_proj",
|
25 |
+
"gate_down_proj",
|
26 |
+
"down_proj",
|
27 |
+
"hidden_size",
|
28 |
+
],
|
29 |
+
)
|
30 |
+
|
31 |
+
num_experts = 128
|
32 |
+
hidden_size = 1152
|
33 |
+
intermediate_size = 3072
|
34 |
+
|
35 |
+
# Shorter names for reading convenience
|
36 |
+
ne, hs, isz = num_experts, hidden_size, intermediate_size
|
37 |
+
|
38 |
+
model.router = torch.nn.Linear(hs, ne).cuda()
|
39 |
+
model.router.weight.data.fill_(1)
|
40 |
+
|
41 |
+
e = model.experts
|
42 |
+
e.gate_up_proj = torch.nn.Parameter(torch.ones(ne, hs, isz, device="cuda"))
|
43 |
+
e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device="cuda"))
|
44 |
+
e.down_proj = torch.nn.Parameter(torch.ones(ne, 1536, hs, device="cuda"))
|
45 |
+
e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device="cuda"))
|
46 |
+
e.hidden_size = hs
|
47 |
+
|
48 |
+
# Create dummy input data
|
49 |
+
x = torch.randn(1, 1, 1152).to(torch.device("cuda"))
|
50 |
+
output, expert_weights_out = model(x)
|
51 |
+
|
52 |
+
# print("Output shape:", output.shape)
|
53 |
+
assert output.shape == (1, 1, 1152), "Output shape mismatch."
|
tests/layers/architectures.py
ADDED
@@ -0,0 +1,53 @@
1 |
+
# Copyright 2024 Databricks
|
2 |
+
# SPDX-License-Identifier: Apache-2.0
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from megablocks._layers.arguments import Arguments
|
8 |
+
|
9 |
+
|
10 |
+
class FFN(torch.nn.Module):
|
11 |
+
|
12 |
+
def __init__(self, args: Arguments):
|
13 |
+
super().__init__()
|
14 |
+
self.w1 = torch.nn.Parameter(
|
15 |
+
torch.empty(
|
16 |
+
args.hidden_size,
|
17 |
+
args.ffn_hidden_size,
|
18 |
+
device=args.device,
|
19 |
+
dtype=torch.float16 if args.fp16 else torch.float32,
|
20 |
+
),
|
21 |
+
)
|
22 |
+
self.w2 = torch.nn.Parameter(
|
23 |
+
torch.empty(
|
24 |
+
args.ffn_hidden_size,
|
25 |
+
args.hidden_size,
|
26 |
+
device=args.device,
|
27 |
+
dtype=torch.float16 if args.fp16 else torch.float32,
|
28 |
+
),
|
29 |
+
)
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
return torch.matmul(
|
33 |
+
F.gelu(torch.matmul(x, self.w1), approximate='tanh'),
|
34 |
+
self.w2,
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
class GLU(FFN):
|
39 |
+
|
40 |
+
def __init__(self, args: Arguments):
|
41 |
+
super().__init__(args)
|
42 |
+
self.v1 = torch.nn.Parameter(
|
43 |
+
torch.empty(
|
44 |
+
args.hidden_size,
|
45 |
+
args.ffn_hidden_size,
|
46 |
+
device=args.device,
|
47 |
+
dtype=torch.float16 if args.fp16 else torch.float32,
|
48 |
+
),
|
49 |
+
)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x1 = F.gelu(torch.matmul(x, self.w1), approximate='tanh') * torch.matmul(x, self.v1)
|
53 |
+
return torch.matmul(x1, self.w2)
|
tests/layers/moe_test.py
ADDED
@@ -0,0 +1,199 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from functools import partial

import pytest
import torch

from megablocks._layers.arguments import Arguments
from megablocks._layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
from megablocks._layers.router import batched_router_zloss, clear_router_zloss
from tests.layers.architectures import FFN

_FORWARD_TESTS = (
    (16, 1024, 512, 1, 1),
    (16, 1024, 512, 2, 1),
    (16, 1024, 512, 4, 1),
    (16, 1024, 512, 8, 1),
    (8, 2048, 512, 1, 1),
    (8, 2048, 512, 2, 1),
    (8, 2048, 512, 4, 1),
    (16, 1024, 512, 2, 2),
    (16, 1024, 512, 4, 2),
    (16, 1024, 512, 4, 4),
    (16, 1024, 512, 8, 2),
    (16, 1024, 512, 8, 4),
    (16, 1024, 512, 8, 8),
)

_DENSE_TESTS = (
    (16, 1024, 512),
    (8, 2048, 512),
)


def construct_moe(
    hidden_size: int,
    ffn_hidden_size: int,
    moe_num_experts: int = 1,
    moe_capacity_factor: int = 1,
    moe_top_k: int = 1,
    moe_zloss_weight: float = 0,
):
    # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
    # TODO: Remove this once sparse is supported with triton >=3.2.0
    try:
        import triton
        if triton.__version__ >= '3.2.0':
            pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
    except ImportError:
        pass

    init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
    args = Arguments(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        moe_num_experts=moe_num_experts,
        moe_capacity_factor=moe_capacity_factor,
        moe_top_k=moe_top_k,
        init_method=init_method,
        moe_zloss_weight=moe_zloss_weight,
    )

    mlp = FFN(args)
    moe_mlp = MoE(args)

    mlp.cuda(torch.cuda.current_device()).half()
    moe_mlp.cuda(torch.cuda.current_device()).half()

    # Set the baseline parameters to match exactly.
    if moe_num_experts == 1:
        with torch.no_grad():
            mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze())
            mlp.w2.copy_(moe_mlp.experts.mlp.w2.squeeze())
    return args, mlp, moe_mlp


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int):
    x = torch.randn(sl, bs, hs).half().cuda()

    _, _, layer = construct_moe(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
    )

    out, _ = layer(x)
    assert out.shape == x.shape
    clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward_backward(
    bs: int,
    sl: int,
    hs: int,
    num_experts: int,
    top_k: int,
):
    x = torch.randn(sl, bs, hs).half().cuda()
    x.requires_grad_(True)

    args, _, layer = construct_moe(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
    )

    out, _ = layer(x)
    assert out.shape == x.shape

    loss = out.sum() + batched_load_balancing_loss(args)
    loss.backward()
    layer.zero_grad(set_to_none=True)
    x.grad = None
    clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward_backward_with_zloss(
    bs: int,
    sl: int,
    hs: int,
    num_experts: int,
    top_k: int,
):
    x = torch.randn(sl, bs, hs).half().cuda()
    x.requires_grad_(True)

    args, _, layer = construct_moe(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
        moe_zloss_weight=1e-3,
    )

    out, _ = layer(x)
    assert out.shape == x.shape

    loss = out.sum() + batched_load_balancing_loss(args)
    loss.backward()
    layer.zero_grad(set_to_none=True)
    x.grad = None
    clear_load_balancing_loss()
    clear_router_zloss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
def test_moe_forward_vs_dense(bs: int, sl: int, hs: int):
    x = torch.randn(sl, bs, hs).half().cuda()

    _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)

    expected_out = mlp(x)
    out, _ = moe_mlp(x)
    assert out.shape == x.shape == expected_out.shape
    assert torch.allclose(out, expected_out)
    clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int):
    x = torch.randn(sl, bs, hs).half().cuda()
    x.requires_grad_(True)

    _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)

    out, _ = moe_mlp(x)
    loss = out.sum()
    loss.backward()
    w1_grad = moe_mlp.experts.mlp.w1.grad.detach().squeeze()
    w2_grad = moe_mlp.experts.mlp.w2.grad.detach().squeeze()
    moe_mlp.zero_grad(set_to_none=True)
    x.grad = None
    clear_load_balancing_loss()

    expected_out = mlp(x)
    expected_loss = expected_out.sum()
    expected_loss.backward()
    expected_w1_grad = mlp.w1.grad.detach()
    expected_w2_grad = mlp.w2.grad.detach()
    mlp.zero_grad(set_to_none=True)
    x.grad = None

    # Verify the gradients match.
    assert w1_grad.shape == expected_w1_grad.shape
    assert w2_grad.shape == expected_w2_grad.shape
    assert torch.allclose(w1_grad, expected_w1_grad)
    assert torch.allclose(w2_grad, expected_w2_grad)
    clear_load_balancing_loss()
tests/ops/binned_gather_test.py
ADDED
@@ -0,0 +1,71 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

BINNED_GATHER_TESTS = (
    (4, 2, 2, 1),
    (4, 2, 2, 2),
    (4, 2, 2, 4),
    (1024, 1536, 4, 1),
    (1024, 1536, 4, 2),
    (1024, 1536, 4, 4),
    (1024, 1536, 64, 1),
    (1024, 1536, 64, 2),
    (1024, 1536, 64, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 2),
    (1024, 1536, 128, 4),
    (16384, 768, 4, 1),
    (16384, 768, 4, 2),
    (16384, 768, 4, 4),
    (16384, 768, 64, 1),
    (16384, 768, 64, 2),
    (16384, 768, 64, 4),
    (16384, 768, 128, 1),
    (16384, 768, 128, 2),
    (16384, 768, 128, 4),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), BINNED_GATHER_TESTS)
def test_binned_gather(sl: int, hs: int, ne: int, top_k: int):
    # NOTE: Capacity factor == 1.
    ec = (sl * top_k) // ne

    # Create the data and indices.
    x = torch.randn((sl, hs)).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    _, indices = ops.sort(top_expert)
    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)

    def binned_gather(
        x: torch.Tensor,
        indices: torch.Tensor,
        bins: torch.Tensor,
        ec: int,
        top_k: int,
    ):
        x = x.cpu().numpy()
        indices = indices.cpu().numpy()
        bins = bins.cpu().numpy()
        start = 0
        out = np.zeros((ne, ec, hs))
        for i in range(ne):
            end = bins[i]
            for j in range(min(ec, end - start)):
                index = indices[start + j] // top_k
                out[i, j, :] = x[index, :]
            start = end
        return torch.from_numpy(out).cuda().half()

    out = ops.binned_gather(x, indices, bins, ec, top_k)
    expected_out = binned_gather(x, indices, bins, ec, top_k)
    assert torch.all(torch.eq(out, expected_out))
tests/ops/binned_scatter_test.py
ADDED
@@ -0,0 +1,87 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

_BINNED_SCATTER_TESTS = (
    (4, 2, 2, 1),
    (4, 2, 2, 2),
    (4, 2, 2, 4),
    (1024, 1536, 4, 1),
    (1024, 1536, 4, 2),
    (1024, 1536, 4, 4),
    (1024, 1536, 64, 1),
    (1024, 1536, 64, 2),
    (1024, 1536, 64, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 2),
    (1024, 1536, 128, 4),
    (16384, 768, 4, 1),
    (16384, 768, 4, 2),
    (16384, 768, 4, 4),
    (16384, 768, 64, 1),
    (16384, 768, 64, 2),
    (16384, 768, 64, 4),
    (16384, 768, 128, 1),
    (16384, 768, 128, 2),
    (16384, 768, 128, 4),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), _BINNED_SCATTER_TESTS)
def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int):
    # NOTE: Capacity factor == 1.
    ec = (sl * top_k) // ne

    # Create the data and indices.
    x = torch.randn((sl, hs)).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    _, indices = ops.sort(top_expert)
    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)

    # Sample weights for the scatter reduce.
    weights = torch.rand((sl * top_k,)).cuda().half()

    x = ops.binned_gather(x, indices, bins, ec, top_k)

    def binned_scatter(
        x: torch.Tensor,
        indices: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        x = x.cpu().numpy()
        indices = indices.cpu().numpy()
        weights = weights.cpu().numpy()
        bins = bins.cpu().numpy()
        start = 0
        out = np.zeros((sl, hs))
        for i in range(ne):
            end = bins[i]
            for j in range(min(ec, end - start)):
                index = indices[start + j]
                scale = weights[index]
                index //= top_k

                out[index, :] += scale * x[i, j, :]
            start = end
        return torch.from_numpy(out).cuda().half()

    out = ops.binned_scatter(x, indices, weights, bins, top_k)
    expected_out = binned_scatter(x, indices, weights, bins, top_k)

    # NOTE: We need to check approximate equality because the
    # scatter reduce uses atomics.
    assert np.testing.assert_allclose(
        out.cpu(),
        expected_out.cpu(),
        rtol=5e-3,
    ) is None
tests/ops/cumsum_test.py
ADDED
@@ -0,0 +1,44 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from megablocks import ops

CUMSUM_TESTS = (
    (1, 32),
    (2, 32),
    (2, 1024),
    (4, 1024),
    (8, 1024),
    (16, 1024),
    (32, 1024),
    (64, 1024),
    (128, 1024),
    (2, 16384),
    (4, 16384),
    (8, 16384),
    (16, 16384),
    (32, 16384),
    (64, 16384),
    (128, 16384),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('n', 'm'), CUMSUM_TESTS)
def test_exclusive_cumsum(n: int, m: int):
    x = torch.randint(0, 2, (n, m)).long().cuda()
    out = ops.exclusive_cumsum(x, 1) * x
    expected_out = (torch.cumsum(x, dim=1) - 1) * x
    assert torch.all(torch.eq(out, expected_out))


@pytest.mark.gpu
@pytest.mark.parametrize(('n', 'm'), CUMSUM_TESTS)
def test_inclusive_cumsum(n: int, m: int):
    x = torch.randint(0, 2, (n, m)).long().cuda()
    out = ops.inclusive_cumsum(x, 1)
    expected_out = torch.cumsum(x, dim=1)
    assert torch.all(torch.eq(out, expected_out))
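Note (reviewer sketch, not part of the diff): the exclusive-cumsum test above multiplies both sides by `x` because `x` holds only zeros and ones, so the comparison is restricted to positions where `x == 1`, where the exclusive cumsum equals the inclusive cumsum minus one. A minimal CPU illustration of that identity, assuming nothing beyond stock PyTorch:

import torch

x = torch.tensor([[0, 1, 1, 0, 1]])
inclusive = torch.cumsum(x, dim=1)   # [[0, 1, 2, 2, 3]]
exclusive = inclusive - x            # [[0, 0, 1, 2, 2]]
# At positions where x == 1, exclusive == inclusive - 1, which is what
# `ops.exclusive_cumsum(x, 1) * x` is checked against above.
assert torch.equal(exclusive * x, (inclusive - 1) * x)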
tests/ops/histogram_test.py
ADDED
@@ -0,0 +1,82 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from megablocks import ops

_HISTOGRAM_TESTS = (
    (1, 32, torch.int16, 128),
    (1, 1024, torch.int16, 128),
    (1, 16384, torch.int16, 128),
    (1, 32, torch.int32, 128),
    (1, 1024, torch.int32, 128),
    (1, 16384, torch.int32, 128),
    (1, 32, torch.int64, 128),
    (1, 1024, torch.int64, 128),
    (1, 16384, torch.int64, 128),
    (1, 32, torch.int16, 1024),
    (1, 1024, torch.int16, 1024),
    (1, 16384, torch.int16, 1024),
    (1, 32, torch.int32, 1024),
    (1, 1024, torch.int32, 1024),
    (1, 16384, torch.int32, 1024),
    (1, 32, torch.int64, 1024),
    (1, 1024, torch.int64, 1024),
    (1, 16384, torch.int64, 1024),
    (2, 32, torch.int16, 128),
    (2, 1024, torch.int16, 128),
    (2, 16384, torch.int16, 128),
    (2, 32, torch.int32, 128),
    (2, 1024, torch.int32, 128),
    (2, 16384, torch.int32, 128),
    (2, 32, torch.int64, 128),
    (2, 1024, torch.int64, 128),
    (2, 16384, torch.int64, 128),
    (2, 32, torch.int16, 1024),
    (2, 1024, torch.int16, 1024),
    (2, 16384, torch.int16, 1024),
    (2, 32, torch.int32, 1024),
    (2, 1024, torch.int32, 1024),
    (2, 16384, torch.int32, 1024),
    (2, 32, torch.int64, 1024),
    (2, 1024, torch.int64, 1024),
    (2, 16384, torch.int64, 1024),
    (8, 32, torch.int16, 128),
    (8, 1024, torch.int16, 128),
    (8, 16384, torch.int16, 128),
    (8, 32, torch.int32, 128),
    (8, 1024, torch.int32, 128),
    (8, 16384, torch.int32, 128),
    (8, 32, torch.int64, 128),
    (8, 1024, torch.int64, 128),
    (8, 16384, torch.int64, 128),
    (8, 32, torch.int16, 1024),
    (8, 1024, torch.int16, 1024),
    (8, 16384, torch.int16, 1024),
    (8, 32, torch.int32, 1024),
    (8, 1024, torch.int32, 1024),
    (8, 16384, torch.int32, 1024),
    (8, 32, torch.int64, 1024),
    (8, 1024, torch.int64, 1024),
    (8, 16384, torch.int64, 1024),
)


# Override the seed_all fixture in autouse.py because
# _histc_cuda does not have a deterministic implementation
@pytest.fixture()
def seed_all():
    torch.use_deterministic_algorithms(False)
    return


@pytest.mark.gpu
@pytest.mark.parametrize(('m', 'n', 'dtype', 'max_val'), _HISTOGRAM_TESTS)
def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int):
    x = torch.randint(0, max_val, (m, n)).cuda().to(dtype)

    out = ops.histogram(x, max_val)
    expected_out = torch.stack([torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)])
    assert torch.all(torch.eq(out, expected_out))
tests/ops/padded_gather_test.py
ADDED
@@ -0,0 +1,94 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

PADDED_GATHER_TESTS = (
    (4, 2, 2, 1),
    (4, 2, 2, 2),
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 64, 1),
    (1024, 1, 64, 2),
    (1024, 1, 64, 4),
    (1024, 1, 128, 1),
    (1024, 1, 128, 2),
    (1024, 1, 128, 4),
    (1024, 1536, 4, 1),
    (1024, 1536, 4, 2),
    (1024, 1536, 4, 4),
    (1024, 1536, 64, 1),
    (1024, 1536, 64, 2),
    (1024, 1536, 64, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 2),
    (1024, 1536, 128, 4),
    (16384, 768, 4, 1),
    (16384, 768, 4, 2),
    (16384, 768, 4, 4),
    (16384, 768, 64, 1),
    (16384, 768, 64, 2),
    (16384, 768, 64, 4),
    (16384, 768, 128, 1),
    (16384, 768, 128, 2),
    (16384, 768, 128, 4),
    (16384, 1, 4, 1),
    (16384, 1, 4, 2),
    (16384, 1, 4, 4),
    (16384, 1, 64, 1),
    (16384, 1, 64, 2),
    (16384, 1, 64, 4),
    (16384, 1, 128, 1),
    (16384, 1, 128, 2),
    (16384, 1, 128, 4),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), PADDED_GATHER_TESTS)
def testPaddedGather(sl: int, hs: int, ne: int, top_k: int):
    # Create the data and indices.
    x = torch.randn((sl, hs)).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    def padded_gather(
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        x = x.cpu().numpy()
        indices = indices.cpu().numpy()
        bin_ids = bin_ids.cpu().numpy()
        bins = bins.cpu().numpy()
        padded_bins = padded_bins.cpu().numpy()

        out = np.zeros((padded_bins[-1], hs))
        in_idx = 0
        for i, end in enumerate(bins):
            out_idx = 0 if i == 0 else padded_bins[i - 1]
            end = bins[i]
            while in_idx < end:
                load_idx = indices[in_idx] // top_k
                out[out_idx, :] = x[load_idx, :]
                in_idx += 1
                out_idx += 1
        return torch.from_numpy(out).cuda().half()

    out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
    expected_out = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
    assert torch.all(torch.eq(out, expected_out))
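Note (reviewer sketch, not part of the diff): the padded layout exercised above rounds each expert's token count up to a multiple of the block size (128) before the cumulative sum, so every expert's rows start on a block boundary of the gathered buffer. A small illustration with made-up counts, using only the same ops the test already calls:

import torch
from megablocks import ops

tokens_per_expert = torch.tensor([70, 200, 5], device='cuda').int()
padded = ops.round_up(tokens_per_expert, 128)       # [128, 256, 128]
padded_bins = ops.inclusive_cumsum(padded, 0)       # [128, 384, 512]
bins = ops.inclusive_cumsum(tokens_per_expert, 0)   # [70, 270, 275]
# Expert i's tokens occupy rows [padded_bins[i-1], padded_bins[i-1] + tokens_per_expert[i])
# of the padded buffer; the remaining rows up to padded_bins[i] are padding.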
tests/ops/padded_scatter_test.py
ADDED
@@ -0,0 +1,155 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

PADDED_SCATTER_TESTS = [
    (4, 2, 2, 2),
    (4, 2, 2, 1),
    (4, 2, 2, 1),
    (4, 2, 2, 1),
    (4, 2, 2, 2),
    (4, 2, 2, 2),
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 64, 1),
    (1024, 1, 64, 2),
    (1024, 1, 64, 4),
    (1024, 1, 128, 1),
    (1024, 1, 128, 2),
    (1024, 1, 128, 4),
    (1024, 1536, 4, 1),
    (1024, 1536, 4, 2),
    (1024, 1536, 4, 4),
    (1024, 1536, 4, 4),
    (1024, 1536, 4, 4),
    (1024, 1536, 64, 1),
    (1024, 1536, 64, 2),
    (1024, 1536, 64, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 2),
    (1024, 1536, 128, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 1),
    (16384, 768, 4, 1),
    (16384, 768, 4, 2),
    (16384, 768, 4, 4),
    (16384, 768, 64, 1),
    (16384, 768, 64, 2),
    (16384, 768, 64, 4),
    (16384, 768, 128, 1),
    (16384, 768, 128, 2),
    (16384, 768, 128, 4),
    (16384, 1, 4, 1),
    (16384, 1, 4, 2),
    (16384, 1, 4, 4),
    (16384, 1, 64, 1),
    (16384, 1, 64, 2),
    (16384, 1, 64, 4),
    (16384, 1, 128, 1),
    (16384, 1, 128, 2),
    (16384, 1, 128, 4),
    (16384, 1, 128, 2),
    (16384, 1, 128, 2),
]


def _to_numpy(x: torch.Tensor) -> np.ndarray:
    return x.detach().cpu().numpy()


@pytest.mark.gpu
@pytest.mark.parametrize((
    'sl',
    'hs',
    'ne',
    'top_k',
), PADDED_SCATTER_TESTS)
def testPaddedScatter(sl: int, hs: int, ne: int, top_k: int):
    # Create the data and indices.
    x = torch.randn((sl, hs), requires_grad=True).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    # Sample weights for the scatter reduce.
    weights = torch.rand((sl * top_k,), requires_grad=True).cuda().half()

    # Gather the data to prepare for backwards.
    x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

    def padded_scatter(
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        x = x.detach().cpu().numpy()
        indices: np.ndarray = _to_numpy(indices)
        bin_ids: np.ndarray = _to_numpy(bin_ids)
        weights: np.ndarray = _to_numpy(weights)
        bins: np.ndarray = _to_numpy(bins)
        padded_bins: np.ndarray = _to_numpy(padded_bins)

        out = np.zeros((indices.shape[0] // top_k, hs))
        out_idx = 0
        for i in range(len(bins)):
            in_idx = 0 if i == 0 else padded_bins[i - 1]
            end = bins[i]
            while out_idx < end:
                store_idx = indices[out_idx]
                scale = weights[store_idx]
                store_idx //= top_k

                out[store_idx, :] += scale * x[in_idx, :]
                out_idx += 1
                in_idx += 1
        return torch.from_numpy(out).cuda().half()

    out = ops.padded_scatter(
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        top_k,
    )
    expected_out = padded_scatter(
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        top_k,
    )

    out.backward(torch.randn_like(out))  # sanity check backward pass

    # NOTE: We need to check approximate equality because the scatter reduce uses atomics.
    # np.testing.assert_allclose returns `None` if no error and raises an AssertionError if an error exists
    assert np.testing.assert_allclose(
        _to_numpy(out),
        _to_numpy(expected_out),
        rtol=5e-3,
    ) is None
tests/ops/replicate_test.py
ADDED
@@ -0,0 +1,108 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

try:
    from megablocks._ops import ops as backend  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e

from megablocks import ops


def promote_scalar(x: torch.Tensor) -> torch.Tensor:
    return x.view(1) if not len(x.size()) else x


REPLICATE_TESTS = [
    (8, 1, 1),
    (8, 2, 1),
    (8, 4, 1),
    (8, 8, 1),
    (8, 2, 2),
    (8, 4, 2),
    (8, 8, 2),
    (8, 2, 4),
    (8, 4, 4),
    (8, 8, 4),
    (8, 2, 8),
    (8, 4, 8),
    (8, 8, 8),
    (16384, 2, 1),
    (16384, 4, 1),
    (16384, 8, 1),
    (16384, 16, 1),
    (16384, 32, 1),
    (16384, 64, 1),
    (16384, 128, 1),
    (16384, 2, 2),
    (16384, 4, 2),
    (16384, 8, 2),
    (16384, 16, 2),
    (16384, 32, 2),
    (16384, 64, 2),
    (16384, 128, 2),
    (16384, 2, 4),
    (16384, 4, 4),
    (16384, 8, 4),
    (16384, 16, 4),
    (16384, 32, 4),
    (16384, 64, 4),
    (16384, 128, 4),
    (16384, 2, 8),
    (16384, 4, 8),
    (16384, 8, 8),
    (16384, 16, 8),
    (16384, 32, 8),
    (16384, 64, 8),
    (16384, 128, 8),
]


@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate(tokens: int, num_centers: int, top_k: int):
    tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
    bins = ops.inclusive_cumsum(tokens_per_center, 0)
    bins = promote_scalar(bins)
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    def replicate(x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
        x = x.cpu().numpy()
        bins = bins.cpu().numpy()
        out = np.zeros((x.shape[0], num_outputs))
        for batch_idx in range(x.shape[0]):
            start = 0
            for i, end in enumerate(bins):
                value = x[batch_idx, i]
                while start < end:
                    out[batch_idx, start] = value
                    start += 1
        return torch.from_numpy(out).cuda().half()

    out = ops.replicate(center_weights, bins, tokens)
    expected_out = replicate(center_weights, bins, tokens)
    assert torch.all(torch.eq(out, expected_out))


@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate_backward(tokens: int, num_centers: int, top_k: int):
    tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
    bins = ops.inclusive_cumsum(tokens_per_center, 0)
    bins = promote_scalar(bins)
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    grad = ops.replicate(center_weights, bins, tokens)

    out = torch.empty_like(center_weights)
    backend.replicate_backward(grad, bins, out)
    expected_out = center_weights * tokens_per_center.view([1, num_centers])

    # NOTE: This floating-point reduction could be a problem for training stability and accuracy.
    assert torch.allclose(out, expected_out, rtol=1e-2)
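Note (reviewer sketch, not part of the diff): `replicate_backward` is a segment sum over each bin, so when the replicated weights themselves are fed back in as the gradient (as in the test above), each weight is accumulated `tokens_per_center[i]` times, which is why the expected value is `center_weights * tokens_per_center`. A tiny numeric check with plain PyTorch:

import torch

center_weights = torch.tensor([[2.0, -1.0]])   # top_k = 1, two centers
tokens_per_center = torch.tensor([3, 1])
replicated = torch.repeat_interleave(center_weights, tokens_per_center, dim=1)  # [[2, 2, 2, -1]]
# Summing the replicated values back over each bin yields weights * counts.
segment_sum = torch.stack([replicated[:, :3].sum(dim=1), replicated[:, 3:].sum(dim=1)], dim=1)
assert torch.equal(segment_sum, center_weights * tokens_per_center)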
tests/ops/sort_test.py
ADDED
@@ -0,0 +1,65 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Dict, Optional, Union

import numpy as np
import pytest
import torch

from megablocks import ops

SORT_TESTS = [
    (32, torch.int16, None),
    (1024, torch.int16, None),
    (16384, torch.int16, None),
    (32, torch.int32, None),
    (1024, torch.int32, None),
    (16384, torch.int32, None),
    (32, torch.int64, None),
    (1024, torch.int64, None),
    (16384, torch.int64, None),
    (32, torch.int16, 128),
    (1024, torch.int16, 128),
    (16384, torch.int16, 128),
    (32, torch.int32, 128),
    (1024, torch.int32, 128),
    (16384, torch.int32, 128),
    (32, torch.int64, 128),
    (1024, torch.int64, 128),
    (16384, torch.int64, 128),
]


def torch_to_numpy_dtype(dtype: torch.dtype,) -> Union[np.int16, np.int32, np.int64]:
    types: Dict[torch.dtype, Union[np.int16, np.int32, np.int64]] = {
        torch.int16: np.int16,
        torch.int32: np.int32,
        torch.int64: np.int64,
    }
    return types[dtype]


@pytest.mark.gpu
@pytest.mark.parametrize(
    ('n', 'dtype', 'max_val'),
    SORT_TESTS,
)
def test_sort(n: int, dtype: torch.dtype, max_val: Optional[int]):
    if max_val is None:
        max_val = np.iinfo(torch_to_numpy_dtype(dtype)).max
    end_bit = int(np.ceil(np.log2(max_val)))
    x = torch.randint(0, max_val, (n,)).cuda().to(dtype)

    out, indices = ops.sort(x, end_bit)
    expected_out, expected_indices = torch.sort(x)
    assert torch.all(torch.eq(out, expected_out))

    # NOTE: The indices can be in different order depending
    # on sort stability if multiple values in the array are
    # equal.
    data = torch.empty_like(x)
    data.scatter_(0, indices.long(), out)
    expected_data = torch.empty_like(x)
    expected_data.scatter_(0, expected_indices, expected_out)
    assert torch.all(torch.eq(data, expected_data))
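Note (reviewer sketch, not part of the diff): `end_bit` above bounds how many low-order key bits the radix sort has to examine; `ceil(log2(max_val))` bits are enough for any key strictly below `max_val`. A quick arithmetic check of that bound:

import numpy as np

max_val = 128
end_bit = int(np.ceil(np.log2(max_val)))  # 7 bits cover keys 0..127
assert (max_val - 1) < 2 ** end_bit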
tests/ops/topology_test.py
ADDED
@@ -0,0 +1,81 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

TOPOLOGY_TESTS = (
    (1024, 1536, 2),
    (1024, 1536, 4),
    (1024, 1536, 8),
    (1024, 1536, 16),
    (1024, 1536, 32),
    (1024, 1536, 64),
    (1024, 1536, 128),
    (1024, 1536, 256),
    (1024, 1536, 512),
    (16384, 768, 2),
    (16384, 768, 4),
    (16384, 768, 8),
    (16384, 768, 16),
    (16384, 768, 32),
    (16384, 768, 64),
    (16384, 768, 128),
    (16384, 768, 256),
    (16384, 768, 512),
    (16384, 768, 1024),
    (8, 14336, 8),
)


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne'), TOPOLOGY_TESTS)
def test_topology(sl: int, hs: int, ne: int):
    # Create the data and indices.
    blocking = 128
    assert hs % blocking == 0

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, blocking)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)

    # Dimensions for the output indices.
    output_block_rows = int(padded_bins[-1]) // blocking
    output_block_columns = hs // blocking

    def topology(
        padded_bins: torch.Tensor,
        blocking: torch.Tensor,
        rows: int,
        columns: int,
    ):
        padded_bins = padded_bins.cpu().numpy()

        out = np.zeros([rows * columns])
        start = 0
        for i in range(padded_bins.shape[0]):
            end = padded_bins[i] // blocking
            while start < end:
                for j in range(columns):
                    out[start * columns + j] = j + i * columns
                start += 1
        return torch.from_numpy(out).cuda().short()

    out = ops.topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    expected_out = topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    assert torch.all(torch.eq(out, expected_out))
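Note (reviewer sketch, not part of the diff): the reference `topology` above writes, for every padded row block owned by expert `i`, the block-column indices `i * columns` through `i * columns + columns - 1`, i.e. the column layout of the block-sparse expert matrix. A hand-computed example with assumed sizes:

# Assume blocking = 128 and hs = 256, so columns = 2, and
# padded_bins = [128, 384]: expert 0 owns 1 row block, expert 1 owns 2.
#   row block 0 (expert 0): [0, 1]
#   row block 1 (expert 1): [2, 3]
#   row block 2 (expert 1): [2, 3]
expected = [0, 1, 2, 3, 2, 3]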
tests/ops_test.py
ADDED
@@ -0,0 +1,171 @@
import torch
import megablocks

import unittest
from absl.testing import parameterized

# import itertools
# import numpy as np


def allclose(x, y, pct=2.0):
    mask = torch.isclose(x, y, rtol=1e-5)
    pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
    if pct_diff > pct:
        print(x[torch.logical_not(mask)], y[torch.logical_not(mask)])
        print("{:.2f}% of values not close.".format(pct_diff))
        return False
    return True


def add_flags(x):
    out = []
    for y in x:
        for trans_b in (False, True):
            out.append(y + (trans_b, False))

            # TODO: Revisit enabling batch_sizes_on_device
            # for batch_sizes_on_device in (False, True):
            #     out.append(y + (trans_b, batch_sizes_on_device))
    return out


_TEST_PROBLEMS = add_flags((
    (1, 128, 128, 128),
    (8, 128, 128, 128),
    (16, 128, 128, 128),
    (1, 128, 256, 512),
    (8, 128, 256, 512),
    (16, 128, 256, 512),
))


def randn(bs, x, y):
    out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    batch_sizes = batch_sizes.cpu().numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start:start + size, :] @ rhs)
        start += size
    return torch.cat(out)


@parameterized.parameters(*_TEST_PROBLEMS)
class OpsTest(parameterized.TestCase):

    def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)
        batch_sizes = torch.tensor([m] * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        # out = ops.gmm(a, b, batch_sizes, trans_b)
        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        # print("out", out)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

    def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)

        dist = torch.rand(z, )
        dist /= dist.sum()
        batch_sizes = (dist * m).to(torch.long)
        error = m * z - batch_sizes.sum()
        batch_sizes[-1] += error
        assert batch_sizes.sum() == (m * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))

        # TODO: Review to ensure that the gradients are correct.
        # self.assertTrue(allclose(b.grad, b_ref.grad))


# @parameterized.parameters(False, True)
@parameterized.parameters(False, False)
class EdgeCasesTest(unittest.TestCase):

    def testGroupedGemm_ZeroSize(self, batch_sizes_on_device):
        torch.manual_seed(0)
        m = 16384
        k = 4096
        n = 14336
        num_experts = 8

        a = randn(num_experts, m // num_experts, k).view(-1, k)
        b = randn(num_experts, k, n)
        batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes)
        expected_out = gmm(a_ref, b_ref, batch_sizes)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

    def testGroupedGemm_ZeroK(self, batch_sizes_on_device):
        sz = 128
        total_tokens = 192

        a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
        batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
        self.assertTrue((c[0] == 0).all())
        self.assertTrue((c[1] == 128).all())
        self.assertTrue((c[2] == 0).all())
        self.assertTrue((c[3] == 64).all())


if __name__ == '__main__':
    unittest.main()
tests/parallel_layer_test.py
ADDED
@@ -0,0 +1,94 @@
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os


def test_megablocks_moe_mlp_import():
    from megablocks.layers import MegaBlocksMoeMLP

    assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."


def run_distributed_test(rank, world_size):
    from megablocks.layers import MegaBlocksMoeMLP

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    expert_parallel_group = torch.distributed.new_group(
        range(torch.distributed.get_world_size())
    )

    model = MegaBlocksMoeMLP()
    model.expert_parallel_group = expert_parallel_group

    class Experts:
        def __init__(self):
            self.gate_up_proj = None
            self.gate_up_proj_bias = None
            self.down_proj = None
            self.down_proj_bias = None
            self.hidden_size = None

    model.experts = Experts()

    num_experts = 128
    hidden_size = 1152
    intermediate_size = 3072

    ne, hs, isz = num_experts, hidden_size, intermediate_size

    experts_per_rank = ne // world_size

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model.router = torch.nn.Linear(hs, ne).to(device)
    model.router.weight.data.fill_(1)

    e = model.experts
    e.gate_up_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, hs, isz, device=device)
    )
    e.gate_up_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, isz, device=device)
    )
    e.down_proj = torch.nn.Parameter(
        torch.ones(experts_per_rank, 1536, hs, device=device)
    )
    e.down_proj_bias = torch.nn.Parameter(
        torch.zeros(experts_per_rank, hs, device=device)
    )
    e.hidden_size = hs

    x = torch.randn(1, 1, 1152).to(device)
    output, expert_weights_out = model(x)

    assert output.shape == (1, 1, 1152), f"Output shape mismatch on rank {rank}."

    print(f"Rank {rank}: Test passed! Output shape: {output.shape}")

    dist.destroy_process_group()


def test_megablocks_moe_mlp_functionality():
    world_size = 2

    mp.spawn(run_distributed_test, args=(world_size,), nprocs=world_size, join=True)

    print("Multi-process test completed successfully!")


if __name__ == "__main__":
    test_megablocks_moe_mlp_import()
    print("Import test passed!")

    test_megablocks_moe_mlp_functionality()
tests/test_gg.py
ADDED
@@ -0,0 +1,57 @@
import torch
import megablocks


def randn(bs, x, y):
    out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    batch_sizes = batch_sizes.cpu().numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start : start + size, :] @ rhs)
        start += size
    return torch.cat(out)


def test_gmm():
    z = 1
    m = 128
    n = 128
    k = 128
    trans_b = False
    batch_sizes_on_device = False
    # TODO: fix to enable batch_sizes_on_device
    # batch_sizes_on_device = True

    torch.manual_seed(0)
    a = randn(z, m, k).view(-1, k)
    b = randn(z, n, k) if trans_b else randn(z, k, n)
    batch_sizes = torch.tensor([m] * z)
    if batch_sizes_on_device:
        batch_sizes = batch_sizes.cuda()

    a.requires_grad_(True)
    b.requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    # out = ops.gmm(a, b, batch_sizes, trans_b)
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    print("out", out)

    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)

    assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"

    out.sum().backward()

    expected_out.sum().backward()
    assert torch.allclose(a.grad, a_ref.grad, atol=1e-3), f"Expected {a_ref.grad}, got {a.grad}"
    assert torch.allclose(b.grad, b_ref.grad, atol=1e-3), f"Expected {b_ref.grad}, got {b.grad}"
    print("Test passed successfully!")
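Note (reviewer sketch, not part of the diff): `batch_sizes` in the reference `gmm` above partitions the rows of `a` into per-expert groups, and each group is multiplied by its own `b[i]`. A minimal shape illustration using plain PyTorch:

import torch

k, n = 4, 3
a = torch.randn(6, k)                 # 6 rows split into groups of 2 and 4
b = torch.randn(2, k, n)
batch_sizes = torch.tensor([2, 4])

# Equivalent to the reference gmm for these sizes:
out = torch.cat([a[0:2] @ b[0], a[2:6] @ b[1]])
assert out.shape == (6, n)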
tests/test_mb_moe.py
ADDED
@@ -0,0 +1,48 @@
import torch
import megablocks

def test_import():
    """Simple test to check if the module can be imported."""
    print("megablocks_moe module imported successfully.")
    print("Available functions:", dir(megablocks))

    expected_functions = [
        "Arguments", "MLP", "MoE", "ParallelDroplessMLP", "ParallelMLP",
        "SparseGLU", "SparseMLP", "argsort",
        "backend", "cumsum", "dMoE", "exclusive_cumsum",
        "get_load_balancing_loss", "grouped_gemm_util", "histogram",
        "inclusive_cumsum", "indices", "layers", "ops", "replicate_backward",
        "replicate_forward", "sort", "torch"
    ]

    # Check if all expected functions are available
    for func in expected_functions:
        assert func in dir(megablocks), f"Missing function: {func}"

# exclusive_cumsum
def test_exclusive_cumsum():
    """Test exclusive cumulative sum."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.exclusive_cumsum(x, 0, out)
    expected = torch.tensor([0, 1, 3, 6], dtype=torch.float32).cuda()
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
    print("cumsum output:", out)

# inclusive_cumsum
def test_inclusive_cumsum():
    """Test inclusive cumulative sum."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.inclusive_cumsum(x, dim=0, out=out)
    expected = torch.tensor([1, 3, 6, 10], dtype=torch.float32).cuda()
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"

# histogram
def test_histogram():
    """Test histogram operation."""
    x = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
    num_bins = 3
    hist = megablocks.histogram(x, num_bins)
    expected_hist = torch.tensor([1, 2, 3], dtype=torch.int32).cuda()
    assert torch.equal(hist, expected_hist), f"Expected {expected_hist}, got {hist}"
tests/test_mb_moe_shared_expert.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import megablocks
|
3 |
+
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
|
4 |
+
|
5 |
+
|
6 |
+
def test_megablocks_moe_mlp_with_shared_expert_import():
|
7 |
+
mlp = MegaBlocksMoeMLPWithSharedExpert()
|
8 |
+
assert hasattr(mlp, 'shared_up_proj_weight')
|
9 |
+
assert hasattr(mlp, 'shared_down_proj_weight')
|
10 |
+
assert hasattr(mlp, 'set_shared_expert_weights')
|
11 |
+
|
12 |
+
|
13 |
+
def test_set_shared_expert_weights():
|
14 |
+
mlp = MegaBlocksMoeMLPWithSharedExpert()
|
15 |
+
|
16 |
+
hidden_size = 128
|
17 |
+
shared_expert_hidden_size = 256
|
18 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
19 |
+
dtype = torch.float32
|
20 |
+
|
21 |
    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
    up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
    down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight,
        up_proj_bias=up_proj_bias,
        down_proj_bias=down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu
    )

    assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
    assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
    assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
    assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
    assert mlp.shared_expert_weighted_sum == True
    assert mlp.shared_activation_fn == torch.nn.functional.gelu


def test_create_shared_expert_weights():
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float32

    def init_method(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=init_method
    )

    assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
    assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
    assert up_proj_weight.device.type == device.type
    assert down_proj_weight.device.type == device.type
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype
    assert up_proj_bias is None
    assert down_proj_bias is None


def test_shared_expert_weights_none_by_default():
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    assert mlp.shared_up_proj_weight is None
    assert mlp.shared_down_proj_weight is None
    assert mlp.shared_up_proj_bias is None
    assert mlp.shared_down_proj_bias is None
    assert mlp.shared_expert_weighted_sum == False
    assert mlp.shared_activation_fn is None


def test_inheritance_from_megablocks_moe_mlp():
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    from megablocks.layers import MegaBlocksMoeMLP
    assert isinstance(mlp, MegaBlocksMoeMLP)
    assert hasattr(mlp, 'forward')


def test_shared_expert_weights_custom_init():
    hidden_size = 64
    shared_expert_hidden_size = 128
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.float16

    def custom_init(tensor):
        torch.nn.init.constant_(tensor, 0.5)

    def custom_output_init(tensor):
        torch.nn.init.constant_(tensor, 0.1)

    up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=device,
        dtype=dtype,
        init_method=custom_init,
        output_layer_init_method=custom_output_init
    )

    assert torch.all(up_proj_weight == 0.5)
    assert torch.all(down_proj_weight == 0.1)
    assert up_proj_weight.dtype == dtype
    assert down_proj_weight.dtype == dtype


def test_shared_expert_weights_dimensions():
    mlp = MegaBlocksMoeMLPWithSharedExpert()

    batch_size = 4
    seq_len = 16
    hidden_size = 128
    shared_expert_hidden_size = 256
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
    down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)

    mlp.set_shared_expert_weights(
        up_proj_weight=up_proj_weight,
        down_proj_weight=down_proj_weight
    )

    x = torch.randn(seq_len, batch_size, hidden_size, device=device)

    expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
    expected_down_output_shape = (seq_len, batch_size, hidden_size)

    assert up_proj_weight.shape[1] == x.shape[-1]
    assert down_proj_weight.shape[0] == x.shape[-1]
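For reference, the weight shapes asserted above correspond to a standard up-project / activation / down-project MLP. Below is a minimal plain-PyTorch sketch of the computation those shapes imply; it is not the layer's actual forward pass, and shared_expert_forward is a hypothetical helper used only for illustration:

import torch

def shared_expert_forward(x, up_w, down_w, up_b=None, down_b=None, act=torch.nn.functional.gelu):
    # Up-project: (..., hidden_size) -> (..., shared_expert_hidden_size)
    h = torch.nn.functional.linear(x, up_w, up_b)
    h = act(h)
    # Down-project back to the model width: (..., shared_expert_hidden_size) -> (..., hidden_size)
    return torch.nn.functional.linear(h, down_w, down_b)

x = torch.randn(16, 4, 128)
up_w = torch.randn(256, 128)     # (shared_expert_hidden_size, hidden_size)
down_w = torch.randn(128, 256)   # (hidden_size, shared_expert_hidden_size)
out = shared_expert_forward(x, up_w, down_w)
assert out.shape == x.shape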
tests/test_mb_moe_shared_expert_multi.py
ADDED
@@ -0,0 +1,200 @@
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import pytest
from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights


def run_distributed_shared_expert_test(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12356"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    model = MegaBlocksMoeMLPWithSharedExpert()

    hidden_size = 128
    shared_expert_hidden_size = 192
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def simple_init(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_expert_hidden_size,
        device=torch.device(device),
        dtype=torch.float32,
        init_method=simple_init
    )

    model.set_shared_expert_weights(
        up_proj_weight=shared_up_proj_weight,
        down_proj_weight=shared_down_proj_weight,
        up_proj_bias=shared_up_proj_bias,
        down_proj_bias=shared_down_proj_bias,
        weighted_sum=True,
        activation_fn=torch.nn.functional.gelu
    )

    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
    assert model.shared_expert_weighted_sum == True, f"Weighted sum not set correctly on rank {rank}"

    print(f"Rank {rank}: Shared expert setup test passed!")

    dist.destroy_process_group()


def run_distributed_shared_expert_weighted_sum_test(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12357"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    dist.init_process_group(
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )

    model = MegaBlocksMoeMLPWithSharedExpert()

    hidden_size = 64
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def simple_init(tensor):
        torch.nn.init.xavier_uniform_(tensor)

    shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=96,
        device=torch.device(device),
        dtype=torch.float32,
        init_method=simple_init
    )

    model.set_shared_expert_weights(
        up_proj_weight=shared_up_proj_weight,
        down_proj_weight=shared_down_proj_weight,
        weighted_sum=False,
        activation_fn=torch.nn.functional.relu
    )

    assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
    assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
    assert model.shared_expert_weighted_sum == False, f"Weighted sum not set correctly on rank {rank}"
    assert model.shared_activation_fn == torch.nn.functional.relu, f"Activation function not set correctly on rank {rank}"

    print(f"Rank {rank}: Weighted sum setup test passed!")

    dist.destroy_process_group()


@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
def test_shared_expert_distributed_functionality(world_size):
    if world_size == 1:
        # Single process test
        model = MegaBlocksMoeMLPWithSharedExpert()

        hidden_size = 128
        shared_expert_hidden_size = 192
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def simple_init(tensor):
            torch.nn.init.xavier_uniform_(tensor)

        shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
            hidden_size=hidden_size,
            shared_expert_hidden_size=shared_expert_hidden_size,
            device=torch.device(device),
            dtype=torch.float32,
            init_method=simple_init
        )

        model.set_shared_expert_weights(
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            up_proj_bias=shared_up_proj_bias,
            down_proj_bias=shared_down_proj_bias,
            weighted_sum=True,
            activation_fn=torch.nn.functional.gelu
        )

        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
        assert model.shared_expert_weighted_sum == True, "Weighted sum not set correctly"

        print("Single process shared expert setup test passed!")
    else:
        # Multi-process test
        mp.spawn(run_distributed_shared_expert_test, args=(world_size,), nprocs=world_size, join=True)
        print("Multi-process shared expert test completed successfully!")


@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
def test_shared_expert_distributed_weighted_sum(world_size):
    if world_size == 1:
        # Single process test
        model = MegaBlocksMoeMLPWithSharedExpert()

        hidden_size = 64
        device = "cuda" if torch.cuda.is_available() else "cpu"

        def simple_init(tensor):
            torch.nn.init.xavier_uniform_(tensor)

        shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
            hidden_size=hidden_size,
            shared_expert_hidden_size=96,
            device=torch.device(device),
            dtype=torch.float32,
            init_method=simple_init
        )

        model.set_shared_expert_weights(
            up_proj_weight=shared_up_proj_weight,
            down_proj_weight=shared_down_proj_weight,
            weighted_sum=False,
            activation_fn=torch.nn.functional.relu
        )

        assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
        assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
        assert model.shared_expert_weighted_sum == False, "Weighted sum not set correctly"
        assert model.shared_activation_fn == torch.nn.functional.relu, "Activation function not set correctly"

        print("Single process weighted sum setup test passed!")
    else:
        # Multi-process test
        mp.spawn(run_distributed_shared_expert_weighted_sum_test, args=(world_size,), nprocs=world_size, join=True)
        print("Multi-process shared expert weighted sum test completed successfully!")


def test_shared_expert_single_process():
    model = MegaBlocksMoeMLPWithSharedExpert()

    assert model.shared_up_proj_weight is None
    assert model.shared_down_proj_weight is None
    assert hasattr(model, 'set_shared_expert_weights')

    print("Single process shared expert basic test passed!")


if __name__ == "__main__":
    test_shared_expert_single_process()
    print("Single process test passed!")

    os.environ['WORLD_SIZE'] = '2'
    # The parametrized tests take an explicit world_size when invoked directly.
    test_shared_expert_distributed_functionality(2)
    print("Distributed functionality test passed!")

    test_shared_expert_distributed_weighted_sum(2)
    print("Distributed weighted sum test passed!")
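Every distributed case above follows the same spawn pattern: each worker sets the rendezvous environment, joins a gloo group (so the test also runs on CPU-only machines), builds the layer, and tears the group down. A stripped-down sketch of that harness, with an arbitrary port and the model construction elided:

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12400"  # any free port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # ... build MegaBlocksMoeMLPWithSharedExpert and call set_shared_expert_weights here ...
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2, join=True)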
torch-ext/megablocks/__init__.py
ADDED
@@ -0,0 +1,202 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch

from ._ops import ops

#from .grouped_gemm import backend as gg_backend
#from .grouped_gemm import ops as gg_ops


from ._layers.arguments import Arguments
from ._layers.dmoe import ParallelDroplessMLP, dMoE
from ._layers.glu import SparseGLU
from ._layers.mlp import MLP, SparseMLP
from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss

from . import layers

# This section contains the direct kernel exports (not included in the original code)
def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute exclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    result = ops.exclusive_cumsum(x, dim)
    out.copy_(result)
    return out


def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
    """
    Compute inclusive cumulative sum along the specified dimension.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    result = ops.inclusive_cumsum(x, dim)
    out.copy_(result)
    return out


def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    """
    Compute histogram of input tensor values.

    Args:
        x: Input tensor
        num_bins: Number of histogram bins

    Returns:
        Histogram tensor with counts for each bin
    """
    return ops.histogram(x, num_bins)


def indices(
    padded_bins: torch.Tensor,
    block_size: int,
    output_block_rows: int,
    output_block_columns: int,
) -> torch.Tensor:
    """
    Construct indices from padded bins for sparse operations.

    Args:
        padded_bins: Tensor containing bin boundaries
        block_size: Size of each block
        output_block_rows: Number of rows in output blocks
        output_block_columns: Number of columns in output blocks

    Returns:
        Tensor containing constructed indices
    """
    return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)


def replicate_forward(
    x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Forward pass of replicate operation - replicate values according to bin sizes.

    Args:
        x: Input tensor with values to replicate
        bins: Tensor containing bin sizes
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.replicate_forward(x, bins, out)


def replicate_backward(
    grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
) -> torch.Tensor:
    """
    Backward pass of replicate operation - reduce gradients back to bins.

    Args:
        grad: Gradient tensor to reduce
        bins: Tensor containing bin sizes
        out: Output tensor (modified in-place)

    Returns:
        The output tensor
    """
    return ops.replicate_backward(grad, bins, out)


def sort(
    x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
) -> torch.Tensor:
    """
    Radix sort with index tracking.

    Args:
        x: Input tensor to sort
        end_bit: Number of bits to consider in sorting
        x_out: Output tensor for sorted values
        iota_out: Output tensor for sorted indices

    Returns:
        The sorted values tensor
    """
    return ops.sort(x, end_bit, x_out, iota_out)


# Convenience functions for common use cases
def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
    """
    Compute cumulative sum with automatic output allocation.

    Args:
        x: Input tensor
        dim: Dimension along which to compute cumsum (default: last dimension)
        exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum

    Returns:
        New tensor containing the cumulative sum
    """
    out = torch.empty_like(x)
    if exclusive:
        return exclusive_cumsum(x, dim, out)
    else:
        return inclusive_cumsum(x, dim, out)


def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Sort tensor and return both sorted values and indices.

    Args:
        x: Input tensor to sort
        end_bit: Number of bits to consider in sorting

    Returns:
        Tuple of (sorted_values, sorted_indices)
    """
    x_out = torch.empty_like(x)
    iota_out = torch.empty_like(x)
    sort(x, end_bit, x_out, iota_out)
    return x_out, iota_out


# Export public API
__all__ = [
    # Direct kernel exports
    "exclusive_cumsum",
    "inclusive_cumsum",
    "histogram",
    "indices",
    "replicate_forward",
    "replicate_backward",
    "sort",
    "cumsum",
    "argsort",
    # Original exports
    "Arguments",
    "ParallelDroplessMLP",
    "dMoE",
    "SparseGLU",
    "MLP",
    "SparseMLP",
    "MoE",
    "ParallelMLP",
    "get_load_balancing_loss",
]
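The difference between the two cumsum wrappers is easiest to see against plain PyTorch; this is only a CPU reference for the values they produce, not a call into the compiled ops:

import torch

x = torch.tensor([1, 2, 3, 4])
inclusive = torch.cumsum(x, dim=0)  # tensor([ 1,  3,  6, 10]) -- includes the current element
exclusive = inclusive - x           # tensor([0, 1, 3, 6])     -- sum of the preceding elements only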
torch-ext/megablocks/_layers/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

# from megablocks.layers.dmoe import dMoE
from .moe import MoE

__all__ = [
    'MoE',
    # 'dMoE',
]
torch-ext/megablocks/_layers/activation_fn.py
ADDED
@@ -0,0 +1,33 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Union

import torch
from ..stk import Matrix


def act_fn(
    x: Matrix,
    function: Callable,
    return_grad_fn: bool = False,
    **kwargs,
) -> Union[tuple[Matrix, Any], Matrix]:
    assert isinstance(x, Matrix)
    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
        if return_grad_fn:
            x.data.requires_grad = True
        out = function(x.data, **kwargs)
        y = Matrix(
            x.size(),
            out,
            x.row_indices,
            x.column_indices,
            x.offsets,
            x.column_indices_t,
            x.offsets_t,
            x.block_offsets_t,
        )
        if return_grad_fn:
            return y, out.backward
        return y
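act_fn applies `function` to the Matrix's underlying `data` tensor and, when `return_grad_fn=True`, hands back `out.backward` so the caller can drive the activation's backward pass later. The same pattern on a plain dense tensor (standing in for `Matrix.data`; a sketch, not the block-sparse path itself) looks like this:

import torch

data = torch.randn(4, 8, requires_grad=True)
out = torch.nn.functional.gelu(data)
grad_fn = out.backward                 # kept around, mirroring `return y, out.backward`

grad_fn(torch.ones_like(out))          # later: push an upstream gradient through the activation
assert data.grad is not None and data.grad.shape == data.shape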
torch-ext/megablocks/_layers/all_to_all.py
ADDED
@@ -0,0 +1,54 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.distributed as dist


class AllToAllOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
        out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)

        ctx.input_shape = x.shape
        ctx.output_split_sizes = output_split_sizes
        ctx.input_split_sizes = input_split_sizes
        ctx.group = group
        handle = dist.all_to_all_single(
            out,
            x,
            output_split_sizes=output_split_sizes,
            input_split_sizes=input_split_sizes,
            group=group,
            async_op=async_op,
        )
        return out, handle

    @staticmethod
    def backward(ctx, grad, _):
        if ctx.needs_input_grad[0]:
            out = torch.empty(
                ctx.input_shape,
                device=grad.device,
                dtype=grad.dtype,
            )
            dist.all_to_all_single(
                out,
                grad,
                output_split_sizes=ctx.input_split_sizes,
                input_split_sizes=ctx.output_split_sizes,
                group=ctx.group,
            )
            return out, None, None, None, None
        return None, None, None, None, None


def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
    return AllToAllOp.apply(
        x,
        output_split_sizes,
        input_split_sizes,
        group,
        async_op,
    )
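A usage sketch for the wrapper above, assuming a process group has already been initialized on a backend that implements all_to_all_single (NCCL or MPI; gloo does not), and assuming the module is importable as megablocks._layers.all_to_all given the layout in this commit; exchange_rows is a hypothetical helper:

import torch
import torch.distributed as dist
# Import path assumed from the package layout above.
from megablocks._layers.all_to_all import all_to_all

def exchange_rows(x: torch.Tensor, group=None) -> torch.Tensor:
    # Symmetric case: every rank sends an equal slice of its rows to every rank,
    # so input and output split sizes are identical.
    world_size = dist.get_world_size(group)
    splits = [x.shape[0] // world_size] * world_size
    out, handle = all_to_all(x, splits, splits, group, async_op=False)
    # With async_op=False the collective has completed; `handle` is unused.
    return out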