Akos Hadnagy committed
Commit 1e1ffe8 · 1 Parent(s): ff615fc
This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. build.toml +35 -0
  2. csrc/bak.ops.cu +21 -0
  3. csrc/cuda_util.h +62 -0
  4. csrc/cumsum.h +163 -0
  5. csrc/grouped_gemm/fill_arguments.cuh +141 -0
  6. csrc/grouped_gemm/grouped_gemm.cu +567 -0
  7. csrc/grouped_gemm/grouped_gemm.h +20 -0
  8. csrc/grouped_gemm/ops.cu +11 -0
  9. csrc/histogram.h +86 -0
  10. csrc/indices.h +95 -0
  11. csrc/new_cumsum.cu +161 -0
  12. csrc/new_cumsum.h +11 -0
  13. csrc/new_histogram.cu +85 -0
  14. csrc/new_histogram.h +10 -0
  15. csrc/new_indices.cu +97 -0
  16. csrc/new_indices.h +14 -0
  17. csrc/new_replicate.cu +220 -0
  18. csrc/new_replicate.h +17 -0
  19. csrc/new_sort.cu +90 -0
  20. csrc/new_sort.h +13 -0
  21. csrc/replicate.h +211 -0
  22. csrc/sort.h +91 -0
  23. flake.lock +168 -0
  24. flake.nix +24 -0
  25. tests/__init__.py +0 -0
  26. tests/conftest.py +110 -0
  27. tests/fixtures/autouse.py +107 -0
  28. tests/fixtures/fixtures.py +13 -0
  29. tests/layer_test.py +53 -0
  30. tests/layers/architectures.py +53 -0
  31. tests/layers/moe_test.py +199 -0
  32. tests/ops/binned_gather_test.py +71 -0
  33. tests/ops/binned_scatter_test.py +87 -0
  34. tests/ops/cumsum_test.py +44 -0
  35. tests/ops/histogram_test.py +82 -0
  36. tests/ops/padded_gather_test.py +94 -0
  37. tests/ops/padded_scatter_test.py +155 -0
  38. tests/ops/replicate_test.py +108 -0
  39. tests/ops/sort_test.py +65 -0
  40. tests/ops/topology_test.py +81 -0
  41. tests/ops_test.py +171 -0
  42. tests/parallel_layer_test.py +94 -0
  43. tests/test_gg.py +57 -0
  44. tests/test_mb_moe.py +48 -0
  45. tests/test_mb_moe_shared_expert.py +139 -0
  46. tests/test_mb_moe_shared_expert_multi.py +200 -0
  47. torch-ext/megablocks/__init__.py +202 -0
  48. torch-ext/megablocks/_layers/__init__.py +10 -0
  49. torch-ext/megablocks/_layers/activation_fn.py +33 -0
  50. torch-ext/megablocks/_layers/all_to_all.py +54 -0
build.toml ADDED
@@ -0,0 +1,35 @@
1
+ [general]
2
+ name = "megablocks"
3
+ universal = false
4
+
5
+ [torch]
6
+ src = [
7
+ "torch-ext/torch_binding.cpp",
8
+ "torch-ext/torch_binding.h"
9
+ ]
10
+
11
+ [kernel.megablocks]
12
+ backend = "rocm"
13
+ rocm-archs = [
14
+ "gfx942",
15
+ "gfx1030",
16
+ "gfx1100",
17
+ "gfx1101",
18
+ ]
19
+ depends = ["torch"]
20
+ src = [
21
+ "csrc/new_cumsum.h",
22
+ "csrc/new_cumsum.cu",
23
+ "csrc/new_histogram.h",
24
+ "csrc/new_histogram.cu",
25
+ "csrc/new_indices.h",
26
+ "csrc/new_indices.cu",
27
+ "csrc/new_replicate.cu",
28
+ "csrc/new_replicate.h",
29
+ "csrc/new_sort.h",
30
+ "csrc/new_sort.cu",
31
+ # vendored grouped gemm
32
+ #"csrc/grouped_gemm/fill_arguments.cuh",
33
+ #"csrc/grouped_gemm/grouped_gemm.cu",
34
+ #"csrc/grouped_gemm/grouped_gemm.h",
35
+ ]
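
The [kernel.megablocks] table above drives the ROCm build. As a quick sanity check of the config structure, the file can be read with Python's standard tomllib (Python 3.11+); this is only an illustrative sketch:

import tomllib  # standard library, Python 3.11+

with open("build.toml", "rb") as f:
    cfg = tomllib.load(f)

kernel = cfg["kernel"]["megablocks"]
print(kernel["backend"])       # "rocm"
print(kernel["rocm-archs"])    # ["gfx942", "gfx1030", "gfx1100", "gfx1101"]
print(len(kernel["src"]))      # number of kernel sources listed above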
csrc/bak.ops.cu ADDED
@@ -0,0 +1,21 @@
1
+ #include "cumsum.h"
2
+ #include "histogram.h"
3
+ #include "indices.h"
4
+ #include "replicate.h"
5
+ #include "sort.h"
6
+
7
+ #include <torch/extension.h>
8
+
9
+ namespace megablocks {
10
+
11
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
12
+ m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum.");
13
+ m.def("histogram", &histogram, "even width histogram.");
14
+ m.def("inclusive_cumsum", &inclusive_cumsum, "batched inclusive cumsum");
15
+ m.def("indices", &indices, "indices construction for sparse matrix.");
16
+ m.def("replicate_forward", &replicate_forward, "(fwd) replicate a vector dynamically.");
17
+ m.def("replicate_backward", &replicate_backward, "(bwd) replicate a vector dynamically.");
18
+ m.def("sort", &sort, "key/value sort.");
19
+ }
20
+
21
+ } // namespace megablocks
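
The module above exposes the ops through pybind11. A hypothetical usage sketch once the extension is built (the import name megablocks_ops is an assumption, since it depends on TORCH_EXTENSION_NAME at build time; only signatures shown in the headers are used):

import torch
import megablocks_ops as ops  # hypothetical module name

x = torch.randint(0, 8, (4, 16), dtype=torch.int32, device="cuda")

out = torch.empty_like(x)
ops.exclusive_cumsum(x, 1, out)   # batched exclusive cumsum along dim 1

counts = ops.histogram(x, 8)      # even-width histogram with 8 bins per row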
csrc/cuda_util.h ADDED
@@ -0,0 +1,62 @@
1
+ #ifndef BLOCKPARTY_CSRC_CUDA_UTIL_H_
2
+ #define BLOCKPARTY_CSRC_CUDA_UTIL_H_
3
+
4
+ #include <cuda_fp16.h>
5
+ #include <cuda_runtime.h>
6
+ // #include <torch/extension.h>
7
+
8
+ namespace megablocks {
9
+
10
+ typedef __half2 half2;
11
+
12
+ struct __align__(8) half4 {
13
+ half2 x, y;
14
+ };
15
+
16
+ struct __align__(16) half8 {
17
+ half2 x, y, z, w;
18
+ };
19
+
20
+ template <class To, class From>
21
+ __device__ __forceinline__ To BitCast(const From& src) noexcept {
22
+ To dst;
23
+ std::memcpy(&dst, &src, sizeof(To));
24
+ return dst;
25
+ }
26
+
27
+ template <typename T>
28
+ __device__ __forceinline__ void Store(const T& value, T* ptr) {
29
+ *ptr = value;
30
+ }
31
+
32
+ template <typename T>
33
+ __device__ __forceinline__ T Load(const T* address) {
34
+ return __ldg(address);
35
+ }
36
+
37
+ __device__ __forceinline__ half4 Load(const half4* address) {
38
+ float2 x = __ldg(reinterpret_cast<const float2*>(address));
39
+ return BitCast<half4>(x);
40
+ }
41
+
42
+ __device__ __forceinline__ half8 Load(const half8* address) {
43
+ float4 x = __ldg(reinterpret_cast<const float4*>(address));
44
+ return BitCast<half8>(x);
45
+ }
46
+
47
+ template <typename T>
48
+ __device__ __forceinline__ T Zero() { return 0; };
49
+
50
+ template <>
51
+ __device__ __forceinline__ half2 Zero<half2>() {
52
+ return {(c10::Half)0., (c10::Half)0.};
53
+ };
54
+
55
+ template <>
56
+ __device__ __forceinline__ half4 Zero<half4>() {
57
+ return {Zero<half2>(), Zero<half2>()};
58
+ };
59
+
60
+ } // namespace megablocks
61
+
62
+ #endif // BLOCKPARTY_CSRC_CUDA_UTIL_H_
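
The half4/half8 loads above fetch eight packed fp16 values through a single 128-bit float4 read, and BitCast is a byte-for-byte reinterpretation. The same idea, sketched host-side with PyTorch dtype views:

import torch

h = torch.randn(8, dtype=torch.float16)
f = h.view(torch.float32)                     # the same 16 bytes seen as four floats
assert f.numel() == 4
assert torch.equal(f.view(torch.float16), h)  # round-trips exactly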
csrc/cumsum.h ADDED
@@ -0,0 +1,163 @@
1
+ #define CUB_IGNORE_DEPRECATED_API
2
+
3
+ #undef CUB_WRAPPED_NAMESPACE
4
+ #define CUB_WRAPPED_NAMESPACE megablocks
5
+
6
+ #include <cstdint>
7
+
8
+ #include <cub/cub.cuh>
9
+ #include <c10/cuda/CUDAStream.h>
10
+ #include <torch/all.h>
11
+ // #include <torch/extension.h>
12
+
13
+ #define CUDA_CALL(code) \
14
+ do { \
15
+ cudaError_t status = code; \
16
+ std::string err = cudaGetErrorString(status); \
17
+ TORCH_CHECK(status == cudaSuccess, err); \
18
+ } while (0)
19
+
20
+ namespace megablocks {
21
+
22
+ struct Inclusive {};
23
+ struct Exclusive {};
24
+
25
+ template <typename Type> struct Cumsum {
26
+
27
+ template<
28
+ typename InputIteratorT,
29
+ typename OutputIteratorT>
30
+ static void Run(void * d_temp_storage,
31
+ size_t & temp_storage_bytes,
32
+ InputIteratorT d_in,
33
+ OutputIteratorT d_out,
34
+ int num_items,
35
+ cudaStream_t stream = 0,
36
+ bool debug_synchronous = false) {
37
+ CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage,
38
+ temp_storage_bytes,
39
+ d_in,
40
+ d_out,
41
+ num_items,
42
+ stream));//,
43
+ //debug_synchronous));
44
+ }
45
+ };
46
+
47
+ template <> struct Cumsum<Inclusive> {
48
+ template<
49
+ typename InputIteratorT,
50
+ typename OutputIteratorT>
51
+ static void Run(void * d_temp_storage,
52
+ size_t & temp_storage_bytes,
53
+ InputIteratorT d_in,
54
+ OutputIteratorT d_out,
55
+ int num_items,
56
+ cudaStream_t stream = 0,
57
+ bool debug_synchronous = false) {
58
+ CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage,
59
+ temp_storage_bytes,
60
+ d_in,
61
+ d_out,
62
+ num_items,
63
+ stream));//,
64
+ //debug_synchronous));
65
+ }
66
+ };
67
+
68
+ template <typename SumType, typename T>
69
+ void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
70
+ // Get temporary storage size.
71
+ size_t scratchpad_bytes = 0;
72
+ Cumsum<SumType>::Run(nullptr,
73
+ scratchpad_bytes,
74
+ x.data_ptr<T>(),
75
+ out.data_ptr<T>(),
76
+ x.size(1),
77
+ c10::cuda::getCurrentCUDAStream());
78
+
79
+ // Allocate scratchpad.
80
+ //
81
+ // NOTE: We scale for the batch dimension so we can run in parallel.
82
+ auto options = torch::TensorOptions()
83
+ .dtype(torch::kInt8)
84
+ .device(x.device());
85
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
86
+ options);
87
+
88
+ // Run the kernel.
89
+ //
90
+ // NOTE: Using different streams for each issue does not appear to
91
+ // yield performance gains for our problem set. The overhead of
92
+ // event/stream synchronization appears to outweigh the benefits.
93
+ // We could write a true batched cumsum, but this would require
94
+ // significant code duplication from cub and we might move away
95
+ // from this formulation anyways.
96
+ for (int i = 0; i < x.size(0); ++i) {
97
+ void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
98
+ Cumsum<SumType>::Run(scratchpad_ptr,
99
+ scratchpad_bytes,
100
+ x.data_ptr<T>() + x.size(1) * i,
101
+ out.data_ptr<T>() + x.size(1) * i,
102
+ x.size(1),
103
+ c10::cuda::getCurrentCUDAStream());
104
+ }
105
+ }
106
+
107
+ void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
108
+ // Validate the input matrix.
109
+ TORCH_CHECK(x.is_cuda());
110
+ TORCH_CHECK(x.ndimension() == 2);
111
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
112
+ x.scalar_type() == torch::kInt32 ||
113
+ x.scalar_type() == torch::kInt64);
114
+ TORCH_CHECK(out.is_cuda());
115
+ TORCH_CHECK(out.ndimension() == 2);
116
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
117
+
118
+ // NOTE: We currently only support contraction across the contiguous
119
+ // dimension in the matrix.
120
+ TORCH_CHECK(dim == 1);
121
+
122
+ switch (x.scalar_type()) {
123
+ case torch::kInt16:
124
+ cub_cumsum<Exclusive, short>(x, dim, out);
125
+ return;
126
+ case torch::kInt32:
127
+ cub_cumsum<Exclusive, int>(x, dim, out);
128
+ return;
129
+ }
130
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
131
+ cub_cumsum<Exclusive, long>(x, dim, out);
132
+ }
133
+
134
+ void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
135
+ // Validate the input matrix.
136
+ TORCH_CHECK(x.is_cuda());
137
+ TORCH_CHECK(x.ndimension() == 2);
138
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
139
+ x.scalar_type() == torch::kInt32 ||
140
+ x.scalar_type() == torch::kInt64);
141
+ TORCH_CHECK(out.is_cuda());
142
+ TORCH_CHECK(out.ndimension() == 2);
143
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
144
+
145
+ // NOTE: We currently only support contraction across the contiguous
146
+ // dimension in the matrix.
147
+ TORCH_CHECK(dim == 1);
148
+
149
+ switch (x.scalar_type()) {
150
+ case torch::kInt16:
151
+ cub_cumsum<Inclusive, short>(x, dim, out);
152
+ return;
153
+ case torch::kInt32:
154
+ cub_cumsum<Inclusive, int>(x, dim, out);
155
+ return;
156
+ }
157
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
158
+ cub_cumsum<Inclusive, long>(x, dim, out);
159
+ }
160
+
161
+ } // namespace megablocks
162
+
163
+ #undef CUB_WRAPPED_NAMESPACE
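
The scans above run row by row over dim 1 of a 2-D integer tensor. Reference semantics in plain PyTorch (illustrative only):

import torch

def inclusive_cumsum_ref(x: torch.Tensor) -> torch.Tensor:
    return torch.cumsum(x, dim=1)

def exclusive_cumsum_ref(x: torch.Tensor) -> torch.Tensor:
    # The exclusive scan is the inclusive scan shifted by one element.
    return torch.cumsum(x, dim=1) - x

x = torch.randint(0, 10, (3, 7), dtype=torch.int64)
assert torch.equal(exclusive_cumsum_ref(x) + x, inclusive_cumsum_ref(x))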
csrc/grouped_gemm/fill_arguments.cuh ADDED
@@ -0,0 +1,141 @@
1
+ #pragma once
2
+
3
+ #include <ATen/cuda/detail/KernelUtils.h>
4
+ #include <cub/cub.cuh>
5
+ #include <cutlass/bfloat16.h>
6
+ #include <cutlass/gemm_coord.h>
7
+
8
+ namespace grouped_gemm {
9
+
10
+ constexpr int kDynamicDim = -1;
11
+ constexpr int kMaxExperts = 512;
12
+
13
+ struct GemmProblem {
14
+ ::cutlass::gemm::GemmCoord dims;
15
+ int64_t lda, ldb, ldc;
16
+ // All offsets are in elements.
17
+ int64_t a_offset, b_offset, c_offset;
18
+ };
19
+
20
+ // TODO: revisit `ExtractGemmProblemK` struct
21
+ // struct ExtractGemmProblemK {
22
+ // __device__ ::cuda::std::tuple<int&> operator()(GemmProblem& problem) const {
23
+ // return {problem.dims.k()};
24
+ // }
25
+ // };
26
+
27
+ template <
28
+ // If `k` is dynamic, we sort the problems by `k` in descending order.
29
+ // Otherwise, `m` is dynamic, and no sorting happens.
30
+ bool kDynamicK,
31
+ typename ElementA, typename ElementB, typename ElementC,
32
+ typename LayoutA, typename LayoutB, typename LayoutC,
33
+ typename Args
34
+ >
35
+ __global__ void FillArguments(
36
+ int num_experts, const int64_t* batch_sizes,
37
+ ElementA* ptr_a, ElementB* ptr_b, ElementC* ptr_c,
38
+ Args args, ::cutlass::gemm::GemmCoord dims
39
+ ) {
40
+ const int expert_idx = threadIdx.x;
41
+ const int batch_size = expert_idx < num_experts ? batch_sizes[expert_idx] : -1;
42
+
43
+ if (kDynamicK) {
44
+ assert(dims.k() == kDynamicDim);
45
+ dims.k() = batch_size;
46
+ } else {
47
+ assert(dims.m() == kDynamicDim);
48
+ dims.m() = batch_size;
49
+ }
50
+
51
+ using BlockScan = cub::BlockScan<int, kMaxExperts>;
52
+ using BlockSort = cub::BlockRadixSort<int, kMaxExperts, 1, GemmProblem>;
53
+
54
+ union SharedMemory {
55
+ typename BlockScan::TempStorage scan_storage;
56
+ typename BlockSort::TempStorage sort_storage;
57
+ };
58
+ __shared__ SharedMemory shared_memory;
59
+
60
+ int dynamic_dim = kDynamicK ? dims.k() : dims.m();
61
+ int dynamic_dim_cumsum;
62
+ BlockScan(shared_memory.scan_storage).ExclusiveSum(dynamic_dim, dynamic_dim_cumsum);
63
+ __syncthreads();
64
+
65
+ // We have to use `GemmProblem[1]` here instead of just `GemmProblem` because `SortDescending()` expects
66
+ // `KeyT (&)[ITEMS_PER_THREAD]` for the `keys` argument (i.e., `GemmProblem (&keys)[1]` in our case).
67
+ GemmProblem problem[1] = {
68
+ GemmProblem {
69
+ .dims = dims,
70
+ .lda = LayoutA::packed({dims.m(), dims.k()}).stride(0),
71
+ .ldb = LayoutB::packed({dims.k(), dims.n()}).stride(0),
72
+ .ldc = LayoutC::packed({dims.m(), dims.n()}).stride(0),
73
+ .a_offset = kDynamicK
74
+ ? (dims.m() * dynamic_dim_cumsum)
75
+ : (dynamic_dim_cumsum * dims.k()),
76
+ .b_offset = (kDynamicK ? dynamic_dim_cumsum : expert_idx * dims.k()) * dims.n(),
77
+ .c_offset = (kDynamicK ? expert_idx * dims.m() : dynamic_dim_cumsum) * dims.n(),
78
+ },
79
+ };
80
+
81
+ if constexpr (kDynamicK) {
82
+ // Sort by k dimension in descending order
83
+ // We need to extract the key (k value) for sorting
84
+ int k_keys[1] = { problem[0].dims.k() };
85
+
86
+ BlockSort(shared_memory.sort_storage).SortDescending(k_keys, problem);
87
+
88
+ // TODO: revisit original impl without `__syncthreads()`
89
+ // BlockSort(shared_memory.sort_storage).SortDescending(problem, ExtractGemmProblemK{});
90
+ // Quoting the CUB documentation (https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockRadixSort.html):
91
+ // > A subsequent __syncthreads() threadblock barrier should be invoked after calling this method if the collective’s temporary storage [...]
92
+ // > is **to be reused or repurposed**.
93
+ // We don't need `__syncthreads()` here, since we don't do either of these things.
94
+ }
95
+
96
+ if (expert_idx < num_experts) {
97
+ args.problem_sizes[expert_idx] = problem[0].dims;
98
+ args.lda[expert_idx] = problem[0].lda;
99
+ args.ldb[expert_idx] = problem[0].ldb;
100
+ args.ldc[expert_idx] = problem[0].ldc;
101
+
102
+ args.ptr_A[expert_idx] = ptr_a + problem[0].a_offset;
103
+ args.ptr_B[expert_idx] = ptr_b + problem[0].b_offset;
104
+ args.ptr_C[expert_idx] = ptr_c + problem[0].c_offset;
105
+ }
106
+ }
107
+
108
+ template <typename Args>
109
+ __global__ void ZeroOutK0Outputs(int num_experts, Args args) {
110
+ const int64_t start_idx = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
111
+ const int64_t delta = (int64_t)gridDim.x * blockDim.x;
112
+ for (int ei = 0; ei < num_experts; ++ei) {
113
+ auto& dims = args.problem_sizes[ei];
114
+ // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
115
+ // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
116
+ // * (here) set the output to zero
117
+ // * (in `IgnoreK0Problems`) make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
118
+ if (dims.k() == 0) {
119
+ // Assume packed layout, run a grid-strided loop over the output.
120
+ int64_t total_elems = (int64_t)dims.m() * dims.n();
121
+ auto* out = args.ptr_C[ei];
122
+ for (int64_t idx = start_idx; idx < total_elems; idx += delta) {
123
+ out[idx] = {};
124
+ }
125
+ }
126
+ }
127
+ }
128
+
129
+ template <typename Args>
130
+ __global__ void IgnoreK0Problems(int num_experts, Args args) {
131
+ const int expert_idx = threadIdx.x;
132
+ if (expert_idx < num_experts) {
133
+ auto& dims = args.problem_sizes[expert_idx];
134
+ if (dims.k() == 0) {
135
+ dims.m() = 0;
136
+ dims.n() = 0;
137
+ }
138
+ }
139
+ }
140
+
141
+ } // namespace grouped_gemm
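
FillArguments derives each expert's GemmCoord and element offsets from an exclusive prefix sum of the dynamic dimension, then (for dynamic k) sorts the problems by k in descending order. A host-side sketch of the same offset math, in plain Python; offsets are in elements, and leading dimensions are omitted since they depend on the chosen layouts:

def fill_arguments_ref(batch_sizes, m, n, k, dynamic_k):
    # Mirrors the offset arithmetic in FillArguments above.
    problems, cumsum = [], 0              # exclusive prefix sum of the dynamic dim
    for e, bs in enumerate(batch_sizes):
        M = m if dynamic_k else bs
        K = bs if dynamic_k else k
        N = n
        a_off = M * cumsum if dynamic_k else cumsum * K
        b_off = (cumsum if dynamic_k else e * K) * N
        c_off = (e * M if dynamic_k else cumsum) * N
        problems.append(((M, N, K), a_off, b_off, c_off))
        cumsum += bs
    if dynamic_k:                         # the kernel also sorts by K, descending
        problems.sort(key=lambda p: p[0][2], reverse=True)
    return problems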
csrc/grouped_gemm/grouped_gemm.cu ADDED
@@ -0,0 +1,567 @@
1
+ #include "grouped_gemm.h"
2
+ #include "fill_arguments.cuh"
3
+
4
+ #include <ATen/cuda/CUDAContext.h>
5
+ #include <ATen/cuda/detail/KernelUtils.h>
6
+ #include <c10/util/BFloat16.h>
7
+ #include <c10/cuda/CUDAStream.h>
8
+ #include <cub/cub.cuh>
9
+ #include <torch/torch.h>
10
+
11
+ #include "cutlass/bfloat16.h"
12
+ #include "cutlass/complex.h"
13
+ #include "cutlass/gemm/kernel/gemm_grouped.h"
14
+ #include "cutlass/gemm/kernel/default_gemm_grouped.h"
15
+ #include "cutlass/gemm/device/gemm_grouped.h"
16
+
17
+ #include <type_traits>
18
+
19
+ namespace grouped_gemm {
20
+
21
+ #define CUDA_CALL(code) \
22
+ do { \
23
+ cudaError_t status = code; \
24
+ std::string err = cudaGetErrorString(status); \
25
+ TORCH_CHECK(status == cudaSuccess, err); \
26
+ } while (0)
27
+
28
+ #define CUBLAS_CALL(code) \
29
+ do { \
30
+ cublasStatus_t status = code; \
31
+ TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "CuBLAS Error"); \
32
+ } while (0)
33
+
34
+ #define GROUPED_GEMM_STRINGIFY_HELPER(x) #x
35
+ #define GROUPED_GEMM_STRINGIFY(x) \
36
+ GROUPED_GEMM_STRINGIFY_HELPER(x)
37
+
38
+ template <bool trans>
39
+ using GroupedGemmInputLayout = std::conditional_t<trans, ::cutlass::layout::ColumnMajor, ::cutlass::layout::RowMajor>;
40
+
41
+ using GroupedGemmConfig = ::cutlass::gemm::device::DefaultGemmConfiguration<
42
+ ::cutlass::arch::OpClassTensorOp,
43
+ ::cutlass::arch::Sm80,
44
+ ::cutlass::bfloat16_t,
45
+ ::cutlass::bfloat16_t,
46
+ ::cutlass::bfloat16_t,
47
+ float
48
+ >;
49
+
50
+ // TODO(tgale): Update this for SM90 when it's supported by CUTLASS.
51
+ template <bool trans_a, bool trans_b>
52
+ using GroupedGemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
53
+ // A operand.
54
+ ::cutlass::bfloat16_t,
55
+ GroupedGemmInputLayout<trans_a>,
56
+ ::cutlass::ComplexTransform::kNone,
57
+ GroupedGemmConfig::kAlignmentA,
58
+ // B operand.
59
+ ::cutlass::bfloat16_t,
60
+ GroupedGemmInputLayout<trans_b>,
61
+ ::cutlass::ComplexTransform::kNone,
62
+ GroupedGemmConfig::kAlignmentB,
63
+ // C operand.
64
+ ::cutlass::bfloat16_t,
65
+ ::cutlass::layout::RowMajor,
66
+ float,
67
+ ::cutlass::arch::OpClassTensorOp,
68
+ ::cutlass::arch::Sm80,
69
+ GroupedGemmConfig::ThreadblockShape,
70
+ GroupedGemmConfig::WarpShape,
71
+ GroupedGemmConfig::InstructionShape,
72
+ GroupedGemmConfig::EpilogueOutputOp,
73
+ // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels.
74
+ // This parameter is passed in at present to match the APIs of other kernels. The parameter
75
+ // is unused within the kernel.
76
+ ::cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,
77
+ // TODO(tgale): Tune this for SM90.
78
+ GroupedGemmConfig::kStages>::GemmKernel;
79
+
80
+ template <bool trans_a, bool trans_b>
81
+ using GemmGrouped = ::cutlass::gemm::device::GemmGrouped<GroupedGemmKernel<trans_a, trans_b>>;
82
+
83
+ template <typename T>
84
+ torch::Tensor CopyToDevice(const std::vector<T> &x, const torch::Device &device) {
85
+ size_t bytes = x.size() * sizeof(T);
86
+ auto options = torch::TensorOptions().dtype(torch::kInt8).device(device);
87
+ torch::Tensor out = torch::empty(bytes, options);
88
+
89
+ CUDA_CALL(cudaMemcpyAsync(out.data_ptr(),
90
+ x.data(), bytes,
91
+ cudaMemcpyHostToDevice,
92
+ c10::cuda::getCurrentCUDAStream()));
93
+ return out;
94
+ }
95
+
96
+ template <typename T>
97
+ static void ReorderArray(T* data, const std::vector<size_t>& indices) {
98
+ // For now, simply create a copy of the data and then copy over to the original.
99
+ std::vector<T> copy(data, data + indices.size());
100
+ for (size_t i = 0; i < indices.size(); ++i) {
101
+ data[i] = copy.at(indices[i]);
102
+ }
103
+ }
104
+
105
+ template <typename T>
106
+ torch::Tensor TypedEmpty(size_t numel, const torch::Device& device) {
107
+ return torch::empty(numel * sizeof(T), torch::dtype(torch::kInt8).device(device));
108
+ }
109
+
110
+ struct RawGemmArguments {
111
+ torch::Tensor lda, ldb, ldc, ptr_a, ptr_b, ptr_c, problem_sizes;
112
+ int threadblock_count{};
113
+ };
114
+
115
+ template <
116
+ typename Gemm,
117
+ typename ElementA, typename ElementB, typename ElementC
118
+ >
119
+ RawGemmArguments MakeArgumentsOnDevice(int num_experts, const torch::Device& device) {
120
+ TORCH_CHECK(
121
+ num_experts <= kMaxExperts,
122
+ "At most ", kMaxExperts,
123
+ " experts are supported when batch_sizes is a CUDA tensor, but got ", num_experts
124
+ );
125
+
126
+ return RawGemmArguments {
127
+ .lda = TypedEmpty<int64_t>(num_experts, device),
128
+ .ldb = TypedEmpty<int64_t>(num_experts, device),
129
+ .ldc = TypedEmpty<int64_t>(num_experts, device),
130
+ .ptr_a = TypedEmpty<ElementA*>(num_experts, device),
131
+ .ptr_b = TypedEmpty<ElementB*>(num_experts, device),
132
+ .ptr_c = TypedEmpty<ElementC*>(num_experts, device),
133
+ .problem_sizes = TypedEmpty<cutlass::gemm::GemmCoord>(num_experts, device),
134
+
135
+ // We don't know the problem dimensions on the host, so we just base the number of threadblocks on occupancy here.
136
+ .threadblock_count = Gemm::sufficient(),
137
+ };
138
+ }
139
+
140
+ template <
141
+ bool kDynamicK,
142
+ typename Gemm,
143
+ typename ElementA, typename ElementB, typename ElementC,
144
+ typename LayoutA, typename LayoutB, typename LayoutC
145
+ >
146
+ RawGemmArguments MakeArgumentsOnHost(torch::Tensor a,
147
+ torch::Tensor b,
148
+ torch::Tensor c,
149
+ torch::Tensor batch_sizes,
150
+ ::cutlass::gemm::GemmCoord coord_template,
151
+ int64_t num_experts) {
152
+ std::vector<::cutlass::gemm::GemmCoord> problem_sizes_host(num_experts);
153
+
154
+ // Create the host arrays of leading dimension data and pointer data.
155
+ std::vector<int64_t> lda_host(num_experts), ldb_host(num_experts), ldc_host(num_experts);
156
+ int64_t elements_a = 0, elements_b = 0, elements_c = 0;
157
+
158
+ std::vector<ElementA *> ptr_a_host(num_experts), ptr_b_host(num_experts), ptr_c_host(num_experts);
159
+
160
+ for (int i = 0; i < num_experts; ++i) {
161
+ auto& problem = problem_sizes_host[i];
162
+ problem = coord_template;
163
+ (kDynamicK ? problem.k() : problem.m()) = batch_sizes.data_ptr<int64_t>()[i];
164
+
165
+ lda_host[i] = LayoutA::packed({problem.m(), problem.k()}).stride(0);
166
+ ldb_host[i] = LayoutB::packed({problem.k(), problem.n()}).stride(0);
167
+ ldc_host[i] = LayoutC::packed({problem.m(), problem.n()}).stride(0);
168
+
169
+ ptr_a_host[i] = (ElementA*)a.data_ptr() + elements_a;
170
+ ptr_b_host[i] = (ElementB*)b.data_ptr() + elements_b;
171
+ ptr_c_host[i] = (ElementC*)c.data_ptr() + elements_c;
172
+
173
+ elements_a += problem.m() * problem.k();
174
+ elements_b += problem.k() * problem.n();
175
+ elements_c += problem.m() * problem.n();
176
+
177
+ if (problem.k() == 0) {
178
+ // CUTLASS doesn't handle problems with `k=0` correctly, see https://github.com/NVIDIA/cutlass/pull/1593.
179
+ // Until a fix is available on the CUTLASS side, handle these problems by ourselves:
180
+ // * set the output to zero with `cudaMemsetAsync()`
181
+ // * make this problem a no-op by setting `m=0` and `n=0` (CUTLASS can handle the outer dimensions being zero)
182
+ CUDA_CALL(cudaMemsetAsync(ptr_c_host[i],
183
+ 0,
184
+ problem.m() * problem.n() * sizeof(ElementC),
185
+ c10::cuda::getCurrentCUDAStream()));
186
+
187
+ problem.m() = 0;
188
+ problem.n() = 0;
189
+ }
190
+ }
191
+
192
+ // Only sort problems when the K dimensions differ
193
+ if (kDynamicK) {
194
+ std::vector<size_t> indices(num_experts);
195
+ std::iota(indices.begin(), indices.end(), 0);
196
+ std::stable_sort(indices.begin(), indices.end(), [&problem_sizes_host](size_t i, size_t j) {
197
+ return problem_sizes_host[i].k() > problem_sizes_host[j].k();
198
+ });
199
+
200
+ ReorderArray(problem_sizes_host.data(), indices);
201
+ ReorderArray(lda_host.data(), indices);
202
+ ReorderArray(ldb_host.data(), indices);
203
+ ReorderArray(ldc_host.data(), indices);
204
+ ReorderArray(ptr_a_host.data(), indices);
205
+ ReorderArray(ptr_b_host.data(), indices);
206
+ ReorderArray(ptr_c_host.data(), indices);
207
+ }
208
+
209
+ // Copy the problem sizes, pointers and leading dimension data to the device.
210
+ return RawGemmArguments {
211
+ .lda = CopyToDevice(lda_host, a.device()),
212
+ .ldb = CopyToDevice(ldb_host, a.device()),
213
+ .ldc = CopyToDevice(ldc_host, a.device()),
214
+ .ptr_a = CopyToDevice(ptr_a_host, a.device()),
215
+ .ptr_b = CopyToDevice(ptr_b_host, a.device()),
216
+ .ptr_c = CopyToDevice(ptr_c_host, a.device()),
217
+ .problem_sizes = CopyToDevice(problem_sizes_host, a.device()),
218
+
219
+ // We know the problem dimensions on the host, so we can calculate the number of threadblocks based on that.
220
+ .threadblock_count = Gemm::sufficient(problem_sizes_host.data(), num_experts),
221
+ };
222
+ }
223
+
224
+ template <
225
+ bool kDynamicK,
226
+ typename Gemm,
227
+ typename ElementA, typename ElementB, typename ElementC,
228
+ typename LayoutA, typename LayoutB, typename LayoutC
229
+ >
230
+ typename Gemm::Arguments MakeArguments(torch::Tensor a,
231
+ torch::Tensor b,
232
+ torch::Tensor c,
233
+ torch::Tensor batch_sizes,
234
+ ::cutlass::gemm::GemmCoord coord_template,
235
+ int64_t num_experts) {
236
+ RawGemmArguments raw_args;
237
+ if (batch_sizes.is_cuda()) {
238
+ raw_args = MakeArgumentsOnDevice<
239
+ Gemm, ElementA, ElementB, ElementC
240
+ >(num_experts, a.device());
241
+ } else {
242
+ raw_args = MakeArgumentsOnHost<
243
+ kDynamicK,
244
+ Gemm,
245
+ ElementA, ElementB, ElementC,
246
+ LayoutA, LayoutB, LayoutC
247
+ >(a, b, c, batch_sizes, coord_template, num_experts);
248
+ }
249
+
250
+ printf("Using %d threadblocks for grouped GEMM.\n", raw_args.threadblock_count);
251
+ // Validate the result.
252
+ if (!raw_args.threadblock_count) {
253
+ TORCH_CHECK(false, "Grouped GEMM execution not possible with HW");
254
+ }
255
+
256
+ typename Gemm::EpilogueOutputOp::Params epilogue_op(/*alpha=*/1.0f, /*beta=*/0.0f);
257
+ // We currently always use `GroupScheduleMode::kDeviceOnly`, which doesn't use `host_problem_sizes` at all,
258
+ // so we can safely pass `nullptr` for `host_problem_sizes`.
259
+ // TODO(tgale): Experiment with `GroupScheduleMode::kHostPrecompute` for `batch_sizes.is_cpu()`, where we
260
+ // know the problem dimensions on the host.
261
+ typename Gemm::Arguments arguments((cutlass::gemm::GemmCoord*)raw_args.problem_sizes.data_ptr(),
262
+ (int)num_experts,
263
+ (int)raw_args.threadblock_count,
264
+ epilogue_op,
265
+ (ElementA**)raw_args.ptr_a.data_ptr(),
266
+ (ElementB**)raw_args.ptr_b.data_ptr(),
267
+ (ElementC**)raw_args.ptr_c.data_ptr(),
268
+ (ElementC**)raw_args.ptr_c.data_ptr(),
269
+ /*lda=*/(int64_t*)raw_args.lda.data_ptr(),
270
+ /*ldb=*/(int64_t*)raw_args.ldb.data_ptr(),
271
+ /*ldc=*/(int64_t*)raw_args.ldc.data_ptr(),
272
+ /*ldd=*/(int64_t*)raw_args.ldc.data_ptr(),
273
+ /*host_problem_sizes=*/nullptr);
274
+ return arguments;
275
+ }
276
+
277
+ template <
278
+ bool trans_a,
279
+ typename ElementA, typename ElementB, typename ElementC,
280
+ typename LayoutA, typename LayoutB, typename LayoutC,
281
+ typename Arguments
282
+ >
283
+ void FillCutlassArguments(int num_experts,
284
+ torch::Tensor batch_sizes,
285
+ torch::Tensor a,
286
+ torch::Tensor b,
287
+ torch::Tensor c,
288
+ const Arguments& arguments,
289
+ ::cutlass::gemm::GemmCoord coord_template) {
290
+ // Convert the batch sizes to the format CUTLASS understands on the device.
291
+ // Use a single block here because:
292
+ // * the number of elements to process is microscopically small
293
+ // * we don't need any additional global memory
294
+ FillArguments<
295
+ /*kDynamicK*/trans_a,
296
+ ElementA, ElementB, ElementC,
297
+ LayoutA, LayoutB, LayoutC
298
+ ><<<1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()>>>(
299
+ num_experts, batch_sizes.data_ptr<int64_t>(),
300
+ (ElementA*)a.data_ptr(), (ElementB*)b.data_ptr(), (ElementC*)c.data_ptr(),
301
+ arguments, coord_template
302
+ );
303
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
304
+ }
305
+
306
+ template <typename Args>
307
+ void RemoveK0Problems(int num_experts, const Args& arguments) {
308
+ // For zeroing out the outputs (which might be arbitrarily large), we want to use
309
+ // as many threadblocks as possible in order to hit the maximum possible global memory bandwidth.
310
+ // `arguments.threadblock_count`, which we will use for the grouped GEMM proper,
311
+ // should be a good approximation for this.
312
+ // When the `k=0` case is fixed in CUTLASS, we can completely remove this function.
313
+ ZeroOutK0Outputs<><<<
314
+ arguments.threadblock_count, at::cuda::detail::CUDA_NUM_THREADS, 0, c10::cuda::getCurrentCUDAStream()
315
+ >>>(
316
+ num_experts, arguments
317
+ );
318
+ IgnoreK0Problems<><<<
319
+ 1, kMaxExperts, 0, c10::cuda::getCurrentCUDAStream()
320
+ >>>(
321
+ num_experts, arguments
322
+ );
323
+ }
324
+
325
+ template <bool trans_a, bool trans_b>
326
+ torch::Tensor CutlassGroupedGemm(torch::Tensor a,
327
+ torch::Tensor b,
328
+ torch::Tensor c,
329
+ torch::Tensor batch_sizes,
330
+ ::cutlass::gemm::GemmCoord coord_template) {
331
+ using Gemm = GemmGrouped<trans_a, trans_b>;
332
+ using LayoutA = typename Gemm::LayoutA;
333
+ using LayoutB = typename Gemm::LayoutB;
334
+ using LayoutC = typename Gemm::LayoutC;
335
+
336
+ using ElementA = typename Gemm::ElementA;
337
+ using ElementB = typename Gemm::ElementB;
338
+ using ElementC = typename Gemm::ElementC;
339
+
340
+ Gemm gemm;
341
+ int64_t num_experts = batch_sizes.size(0);
342
+ auto arguments = MakeArguments<
343
+ /*kDynamicK*/trans_a,
344
+ Gemm,
345
+ ElementA, ElementB, ElementC,
346
+ LayoutA, LayoutB, LayoutC
347
+ >(a, b, c, batch_sizes, coord_template, num_experts);
348
+ int64_t workspace_size = gemm.get_workspace_size(arguments);
349
+ auto options = torch::TensorOptions().dtype(torch::kInt8).device(a.device());
350
+ torch::Tensor workspace = torch::empty(workspace_size, options);
351
+
352
+ if (batch_sizes.is_cuda()) {
353
+ FillCutlassArguments<
354
+ trans_a,
355
+ ElementA, ElementB, ElementC,
356
+ LayoutA, LayoutB, LayoutC
357
+ >(num_experts, batch_sizes, a, b, c, arguments, coord_template);
358
+
359
+ RemoveK0Problems<>(num_experts, arguments);
360
+ }
361
+
362
+ // Initialize the kernel.
363
+ if(gemm.initialize(arguments, workspace.data_ptr()) != cutlass::Status::kSuccess) {
364
+ TORCH_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM");
365
+ }
366
+
367
+ // Execute the kernel in the current stream.
368
+ if(gemm.run(c10::cuda::getCurrentCUDAStream()) != cutlass::Status::kSuccess) {
369
+ TORCH_CHECK(false, "Failed to run CUTLASS Grouped GEMM");
370
+ }
371
+ return c;
372
+ }
373
+
374
+ void CublasGemm(c10::BFloat16 *a, int64_t a_rows, int64_t a_cols, bool trans_a,
375
+ c10::BFloat16 *b, int64_t b_rows, int64_t b_cols, bool trans_b,
376
+ c10::BFloat16 *c, int64_t c_rows, int64_t c_cols) {
377
+ int m = trans_b ? b_rows : b_cols;
378
+ int k = trans_b ? b_cols : b_rows;
379
+ int n = trans_a ? a_cols : a_rows;
380
+
381
+ int lda = trans_a ? n : k;
382
+ int ldb = trans_b ? k : m;
383
+ cublasOperation_t transpose_a = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
384
+ cublasOperation_t transpose_b = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
385
+
386
+ float alpha = 1.0, beta = 0.0;
387
+ CUBLAS_CALL(cublasGemmEx(at::cuda::getCurrentCUDABlasHandle(),
388
+ transpose_b, transpose_a,
389
+ m, n, k, &alpha,
390
+ b, CUDA_R_16BF, ldb,
391
+ a, CUDA_R_16BF, lda,
392
+ &beta,
393
+ c, CUDA_R_16BF, c_cols, CUDA_R_32F,
394
+ CUBLAS_GEMM_DEFAULT));
395
+ }
396
+
397
+ void CublasGroupedGemm(torch::Tensor a,
398
+ torch::Tensor b,
399
+ torch::Tensor c,
400
+ torch::Tensor batch_sizes,
401
+ bool trans_b) {
402
+ int64_t bs = batch_sizes.size(0), k = a.size(1);
403
+ int64_t n = trans_b ? b.size(1) : b.size(2);
404
+ int64_t b_rows = b.size(1), b_cols = b.size(2);
405
+ c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
406
+ c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
407
+ c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
408
+ for (int i = 0; i < bs; ++i) {
409
+ int64_t m = batch_sizes.data_ptr<int64_t>()[i];
410
+ CublasGemm(a_ptr, m, k, /*trans_a=*/false,
411
+ b_ptr, b_rows, b_cols, trans_b,
412
+ c_ptr, m, n);
413
+ a_ptr += m * k;
414
+ b_ptr += b_rows * b_cols;
415
+ c_ptr += m * n;
416
+ }
417
+ }
418
+
419
+ void CublasGroupedGemmVariableK(torch::Tensor a,
420
+ torch::Tensor b,
421
+ torch::Tensor c,
422
+ torch::Tensor batch_sizes) {
423
+ int64_t bs = batch_sizes.size(0), m = a.size(1), n = b.size(1);
424
+ c10::BFloat16* a_ptr = a.data_ptr<c10::BFloat16>();
425
+ c10::BFloat16* b_ptr = b.data_ptr<c10::BFloat16>();
426
+ c10::BFloat16* c_ptr = c.data_ptr<c10::BFloat16>();
427
+ for (int i = 0; i < bs; ++i) {
428
+ int64_t k = batch_sizes.data_ptr<int64_t>()[i];
429
+ CublasGemm(a_ptr, k, m, /*trans_a=*/true,
430
+ b_ptr, k, n, /*trans_b=*/false,
431
+ c_ptr, m, n);
432
+ a_ptr += k * m;
433
+ b_ptr += k * n;
434
+ c_ptr += m * n;
435
+ }
436
+ }
437
+
438
+ void GroupedGemmVariableK(torch::Tensor a,
439
+ torch::Tensor b,
440
+ torch::Tensor c,
441
+ torch::Tensor batch_sizes) {
442
+ // We expect a CUDA tensor with two dimensions and shape
443
+ // (tokens, hidden_out) for 'b'.
444
+ TORCH_CHECK(b.is_cuda());
445
+ TORCH_CHECK(b.ndimension() == 2);
446
+ TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
447
+
448
+ // Validate the dimensions.
449
+ int64_t tokens = a.size(0), num_experts = batch_sizes.size(0);
450
+ int64_t m = a.size(1), n = b.size(1);
451
+
452
+ // Validate that we have the same contraction dimension.
453
+ TORCH_CHECK(tokens == b.size(0));
454
+
455
+ // Validate the output shape.
456
+ TORCH_CHECK(c.is_cuda());
457
+ TORCH_CHECK(c.ndimension() == 3);
458
+ TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
459
+ TORCH_CHECK(c.size(0) == num_experts);
460
+ TORCH_CHECK(c.size(1) == m);
461
+ TORCH_CHECK(c.size(2) == n);
462
+
463
+ // Run the computation.
464
+ CublasGroupedGemmVariableK(a, b, c, batch_sizes);
465
+ }
466
+
467
+ // NOTE: We only support dynamic group sizes for the 'a' tensor. Tensor 'b' is
468
+ // assumed to be batched with fixed sized batches.
469
+ //
470
+ // TODO(tgale): Validate alignment is true for every batch element.
471
+ void GroupedGemm(torch::Tensor a,
472
+ torch::Tensor b,
473
+ torch::Tensor c,
474
+ torch::Tensor batch_sizes,
475
+ bool trans_a, bool trans_b) {
476
+ // NOTE: We only support 'trans_a' or 'trans_b', not both.
477
+ TORCH_CHECK(!(trans_a && trans_b));
478
+
479
+ #if !defined(GROUPED_GEMM_CUTLASS)
480
+ // No way to run cuBLAS kernels if the problem dimensions are not known on the host.
481
+ TORCH_CHECK(batch_sizes.is_cpu());
482
+ #else
483
+ // CUTLASS can handle both CPU- and CUDA-resident problem dimensions.
484
+ TORCH_CHECK(batch_sizes.is_cuda() || batch_sizes.is_cpu());
485
+ #endif
486
+ TORCH_CHECK(batch_sizes.ndimension() == 1);
487
+ TORCH_CHECK(batch_sizes.scalar_type() == torch::kInt64);
488
+
489
+ // We expect a CUDA tensor with two dimensions and shape
490
+ // (tokens, hidden_in) for 'a'.
491
+ TORCH_CHECK(a.is_cuda());
492
+ TORCH_CHECK(a.ndimension() == 2);
493
+ TORCH_CHECK(a.scalar_type() == torch::kBFloat16);
494
+
495
+ #if !defined(GROUPED_GEMM_CUTLASS)
496
+ if (trans_a) {
497
+ // If we can't use CUTLASS for the transposed cases, defer to the variable 'k' helper using cuBLAS
498
+ // for the rest of the op.
499
+ GroupedGemmVariableK(a, b, c, batch_sizes);
500
+ return;
501
+ }
502
+ #endif
503
+
504
+ TORCH_CHECK(b.is_cuda());
505
+ TORCH_CHECK(c.is_cuda());
506
+ TORCH_CHECK(b.scalar_type() == torch::kBFloat16);
507
+ TORCH_CHECK(c.scalar_type() == torch::kBFloat16);
508
+
509
+ // The expected shapes of 'b' and 'c' are:
510
+ // * when 'trans_a' is set: b=(tokens, hidden_out), c=(num_experts, hidden_in, hidden_out)
511
+ // * when 'trans_b' is set: b=(num_experts, hidden_out, hidden_in), c=(tokens, hidden_out)
512
+ // * otherwise: b=(num_experts, hidden_in, hidden_out), c=(tokens, hidden_out)
513
+ size_t hidden_in{}, hidden_out{};
514
+ if (trans_a) {
515
+ hidden_in = a.size(1);
516
+ hidden_out = b.size(1);
517
+
518
+ TORCH_CHECK(b.ndimension() == 2);
519
+ TORCH_CHECK(c.ndimension() == 3);
520
+ TORCH_CHECK(b.size(0) == a.size(0));
521
+ TORCH_CHECK(c.size(0) == batch_sizes.size(0));
522
+ TORCH_CHECK(c.size(1) == hidden_in);
523
+ TORCH_CHECK(c.size(2) == hidden_out);
524
+ } else {
525
+ TORCH_CHECK(b.ndimension() == 3);
526
+ TORCH_CHECK(c.ndimension() == 2);
527
+
528
+ // Validate the contraction dimensions match.
529
+ int64_t tokens = a.size(0), num_experts = b.size(0);
530
+ hidden_in = trans_b ? b.size(2) : b.size(1);
531
+ hidden_out = trans_b ? b.size(1) : b.size(2);
532
+ TORCH_CHECK(hidden_in == a.size(1));
533
+
534
+ // Validate that we have one size per expert.
535
+ TORCH_CHECK(batch_sizes.size(0) == num_experts);
536
+ }
537
+
538
+ // NOTE: We support transposition through the 'trans_b' flag.
539
+ TORCH_CHECK(a.is_contiguous());
540
+ TORCH_CHECK(b.is_contiguous());
541
+ TORCH_CHECK(c.is_contiguous());
542
+
543
+ #if !defined(GROUPED_GEMM_CUTLASS)
544
+ CublasGroupedGemm(a, b, c, batch_sizes, trans_b);
545
+ return;
546
+ #else
547
+ // The `coord_template` argument contains `kDynamicDim` as one of its dimensions
548
+ // as a placeholder. This placeholder is later expanded into the actual dimension
549
+ // for every element of the batch, either on the host or on the device
550
+ // (if we can't do it on the host).
551
+ const auto coord_template = trans_a
552
+ ? cutlass::gemm::GemmCoord(hidden_in, hidden_out, kDynamicDim)
553
+ : cutlass::gemm::GemmCoord(kDynamicDim, hidden_out, hidden_in);
554
+ if (trans_a) {
555
+ CutlassGroupedGemm<true, false>(a, b, c, batch_sizes, coord_template);
556
+ return;
557
+ }
558
+ if (trans_b) {
559
+ CutlassGroupedGemm<false, true>(a, b, c, batch_sizes, coord_template);
560
+ return;
561
+ }
562
+ CutlassGroupedGemm<false, false>(a, b, c, batch_sizes, coord_template);
563
+ return;
564
+ #endif
565
+ }
566
+
567
+ } // namespace grouped_gemm
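
GroupedGemm above dispatches to CUTLASS or cuBLAS, but its semantics are a per-expert loop of ordinary matmuls. A reference sketch in plain PyTorch (illustrative only; 'a' is (tokens, hidden_in), batch_sizes is a 1-D int64 tensor, and at most one of trans_a/trans_b may be set):

import torch

def grouped_gemm_ref(a, b, batch_sizes, trans_a=False, trans_b=False):
    assert not (trans_a and trans_b)
    sizes = batch_sizes.tolist()
    if trans_a:
        # Variable-k case: c[e] = a_e^T @ b_e over the e-th row group of 'a' and 'b'.
        out, start = [], 0
        for k in sizes:
            out.append(a[start:start + k].t() @ b[start:start + k])
            start += k
        return torch.stack(out)           # (num_experts, hidden_in, hidden_out)
    # Fixed-size expert batches in 'b': c is (tokens, hidden_out).
    out, start = [], 0
    for e, m in enumerate(sizes):
        w = b[e].t() if trans_b else b[e]
        out.append(a[start:start + m] @ w)
        start += m
    return torch.cat(out)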
csrc/grouped_gemm/grouped_gemm.h ADDED
@@ -0,0 +1,20 @@
1
+ #pragma once
2
+
3
+ // // Set default if not already defined
4
+ // #ifndef GROUPED_GEMM_CUTLASS
5
+ // #define GROUPED_GEMM_CUTLASS 0
6
+ // #endif
7
+
8
+ // #include <torch/extension.h>
9
+ #include <torch/torch.h>
10
+
11
+ namespace grouped_gemm {
12
+
13
+ void GroupedGemm(torch::Tensor a,
14
+ torch::Tensor b,
15
+ torch::Tensor c,
16
+ torch::Tensor batch_sizes,
17
+ bool trans_a, bool trans_b);
18
+
19
+ } // namespace grouped_gemm
20
+
csrc/grouped_gemm/ops.cu ADDED
@@ -0,0 +1,11 @@
1
+ #include "grouped_gemm.h"
2
+
3
+ #include <torch/extension.h>
4
+
5
+ namespace grouped_gemm {
6
+
7
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
8
+ m.def("gmm", &GroupedGemm, "Grouped GEMM.");
9
+ }
10
+
11
+ } // namespace grouped_gemm
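
A hypothetical call into the "gmm" binding registered above (the import name grouped_gemm_ops is an assumption; the exported name and argument order follow the m.def call and GroupedGemm's signature):

import torch
import grouped_gemm_ops as gg  # hypothetical module name

e, tokens, hidden_in, hidden_out = 4, 1024, 512, 256
batch_sizes = torch.full((e,), tokens // e, dtype=torch.int64)  # CPU tensor
a = torch.randn(tokens, hidden_in, dtype=torch.bfloat16, device="cuda")
b = torch.randn(e, hidden_in, hidden_out, dtype=torch.bfloat16, device="cuda")
c = torch.empty(tokens, hidden_out, dtype=torch.bfloat16, device="cuda")

gg.gmm(a, b, c, batch_sizes, False, False)  # trans_a=False, trans_b=False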
csrc/histogram.h ADDED
@@ -0,0 +1,86 @@
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include <cstdint>
5
+
6
+ #include <cub/cub.cuh>
7
+ #include <c10/cuda/CUDAStream.h>
8
+ // #include <torch/extension.h>
9
+
10
+ #define CUDA_CALL(code) \
11
+ do { \
12
+ cudaError_t status = code; \
13
+ std::string err = cudaGetErrorString(status); \
14
+ TORCH_CHECK(status == cudaSuccess, err); \
15
+ } while (0)
16
+
17
+ namespace megablocks {
18
+
19
+ template <typename T>
20
+ torch::Tensor cub_histogram(torch::Tensor x, int num_bins) {
21
+ // Allocate the count buffer.
22
+ auto options = torch::TensorOptions()
23
+ .dtype(torch::kInt32)
24
+ .device(x.device());
25
+ torch::Tensor out = torch::empty({x.size(0), num_bins}, options);
26
+
27
+ // Exit early if there is no work to do.
28
+ if (out.numel() == 0) return out;
29
+
30
+ // Get scratchpad size.
31
+ size_t scratchpad_bytes = 0;
32
+ CUDA_CALL(cub::DeviceHistogram::HistogramEven(nullptr,
33
+ scratchpad_bytes,
34
+ x.data_ptr<T>(),
35
+ out.data_ptr<int>(),
36
+ /*num_levels=*/num_bins + 1,
37
+ /*lower_level=*/0,
38
+ /*upper_level=*/num_bins,
39
+ /*num_samples=*/int(x.size(1)),
40
+ c10::cuda::getCurrentCUDAStream()));
41
+
42
+ // Allocate scratchpad.
43
+ options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
44
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
45
+
46
+ // Run the kernel.
47
+ for (int i = 0; i < x.size(0); ++i) {
48
+ CUDA_CALL(cub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(),
49
+ scratchpad_bytes,
50
+ x.data_ptr<T>() + x.size(1) * i,
51
+ out.data_ptr<int>() + out.size(1) * i,
52
+ /*num_levels=*/num_bins + 1,
53
+ /*lower_level=*/0,
54
+ /*upper_level=*/num_bins,
55
+ /*num_samples=*/int(x.size(1)),
56
+ c10::cuda::getCurrentCUDAStream()));
57
+ }
58
+ return out;
59
+ }
60
+
61
+ torch::Tensor histogram(torch::Tensor x, int num_bins) {
62
+ TORCH_CHECK(x.is_cuda());
63
+ TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2);
64
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
65
+ x.scalar_type() == torch::kInt32 ||
66
+ x.scalar_type() == torch::kInt64);
67
+ bool no_batch = x.ndimension() == 1;
68
+ if (no_batch) x = x.view({1, x.numel()});
69
+
70
+ if (x.scalar_type() == torch::kInt16) {
71
+ auto out = cub_histogram<short>(x, num_bins);
72
+ return no_batch ? out.flatten() : out;
73
+ } else if (x.scalar_type() == torch::kInt32) {
74
+ auto out = cub_histogram<int>(x, num_bins);
75
+ return no_batch ? out.flatten() : out;
76
+ } else {
77
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
78
+ auto out = cub_histogram<long>(x, num_bins);
79
+ return no_batch ? out.flatten() : out;
80
+ }
81
+ }
82
+
83
+ } // namespace megablocks
84
+
85
+ #undef CUDA_CALL
86
+ #undef CUB_WRAPPED_NAMESPACE
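
The op above computes a per-row even-width histogram over [0, num_bins) with unit-width bins; values outside that range are ignored. Reference semantics in plain PyTorch:

import torch

def histogram_ref(x: torch.Tensor, num_bins: int) -> torch.Tensor:
    no_batch = x.ndim == 1
    rows = x.view(1, -1) if no_batch else x
    out = torch.stack([
        torch.bincount(r[(r >= 0) & (r < num_bins)].long(), minlength=num_bins)
        for r in rows
    ]).to(torch.int32)
    return out.flatten() if no_batch else out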
csrc/indices.h ADDED
@@ -0,0 +1,95 @@
1
+ #include <cstdint>
2
+ #include <c10/util/Half.h>
3
+ // #include <torch/extension.h>
4
+ #include <c10/cuda/CUDAStream.h>
5
+
6
+ #define CUDA_CALL(code) \
7
+ do { \
8
+ cudaError_t status = code; \
9
+ std::string err = cudaGetErrorString(status); \
10
+ TORCH_CHECK(status == cudaSuccess, err); \
11
+ } while (0)
12
+
13
+ namespace megablocks {
14
+ namespace construct_indices {
15
+
16
+ // We expect the number of outputs per block to be small. For
17
+ // example, with ffn_hidden_size=4096, we only need to write
18
+ // 32 elements per block per iteration.
19
+ const int kThreadsPerBlock = 32;
20
+
21
+ __global__ void __launch_bounds__(kThreadsPerBlock)
22
+ ConstructIndicesKernel(short * __restrict__ indices,
23
+ int num_columns,
24
+ int block_size,
25
+ const int * __restrict__ padded_bins) {
26
+ // Load the offset for this bin's indices.
27
+ int start = 0;
28
+ if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1);
29
+ int end = __ldg(padded_bins + blockIdx.x);
30
+
31
+ // Divide the start and end into blocks.
32
+ start /= block_size;
33
+ end /= block_size;
34
+
35
+ // Offset the output buffer to the start of the bin.
36
+ indices += (start + blockIdx.y) * num_columns + threadIdx.x;
37
+
38
+ // Write the indices to the output.
39
+ int bin_offset = blockIdx.y;
40
+ int num_rows = end - start;
41
+ for (; bin_offset < num_rows; num_rows -= gridDim.y) {
42
+ short *out = indices;
43
+ for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
44
+ *out = bid + (blockIdx.x * num_columns);
45
+ out += kThreadsPerBlock;
46
+ }
47
+ indices += gridDim.y * num_columns;
48
+ }
49
+ }
50
+
51
+ cudaError_t ConstructIndices(short * __restrict__ indices,
52
+ int output_block_rows,
53
+ int output_block_columns,
54
+ int block_size,
55
+ const int * __restrict__ padded_bins,
56
+ int num_bins,
57
+ cudaStream_t stream) {
58
+ dim3 block_dim(kThreadsPerBlock);
59
+ dim3 grid_dim(num_bins, (int)std::ceil((float)output_block_rows / num_bins));
60
+ ConstructIndicesKernel<<<grid_dim, block_dim, 0, stream>>>(indices,
61
+ output_block_columns,
62
+ block_size,
63
+ padded_bins);
64
+ return cudaGetLastError();
65
+ }
66
+
67
+ } // namespace construct_indices
68
+
69
+ void indices(torch::Tensor padded_bins,
70
+ int block_size,
71
+ int output_block_rows,
72
+ int output_block_columns,
73
+ torch::Tensor out) {
74
+ TORCH_CHECK(padded_bins.is_cuda());
75
+ TORCH_CHECK(padded_bins.ndimension() == 1);
76
+ TORCH_CHECK(padded_bins.scalar_type() == torch::kInt);
77
+
78
+ TORCH_CHECK(out.is_cuda());
79
+ TORCH_CHECK(out.ndimension() == 1);
80
+ TORCH_CHECK(out.scalar_type() == torch::kInt16);
81
+ TORCH_CHECK(out.numel() == (output_block_rows * output_block_columns));
82
+
83
+ // Exit early if there is no work to do.
84
+ if (out.numel() == 0) return;
85
+
86
+ CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr<short>(),
87
+ output_block_rows,
88
+ output_block_columns,
89
+ block_size,
90
+ padded_bins.data_ptr<int>(),
91
+ padded_bins.numel(),
92
+ c10::cuda::getCurrentCUDAStream()));
93
+ }
94
+
95
+ } // namespace megablocks
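
ConstructIndices fills, for every bin, the block rows covered by that bin with that bin's block-column ids. A reference sketch in plain PyTorch (illustrative only; padded_bins is the inclusive cumulative bin boundary array, as in the kernel):

import torch

def indices_ref(padded_bins, block_size, output_block_rows, output_block_columns):
    out = torch.zeros(output_block_rows * output_block_columns, dtype=torch.int16)
    start = 0
    for b, end in enumerate(padded_bins.tolist()):
        row_start, row_end = start // block_size, end // block_size
        cols = torch.arange(b * output_block_columns,
                            (b + 1) * output_block_columns, dtype=torch.int16)
        for r in range(row_start, row_end):
            out[r * output_block_columns:(r + 1) * output_block_columns] = cols
        start = end
    return out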
csrc/new_cumsum.cu ADDED
@@ -0,0 +1,161 @@
1
+ #define CUB_IGNORE_DEPRECATED_API
2
+
3
+ #undef CUB_WRAPPED_NAMESPACE
4
+ #define CUB_WRAPPED_NAMESPACE megablocks
5
+
6
+ #include "new_cumsum.h"
7
+ #include <cstdint>
8
+ #include <hipcub/hipcub.hpp>
9
+ #include <c10/cuda/CUDAStream.h>
10
+
11
+ #define CUDA_CALL(code) \
12
+ do { \
13
+ cudaError_t status = code; \
14
+ std::string err = cudaGetErrorString(status); \
15
+ TORCH_CHECK(status == cudaSuccess, err); \
16
+ } while (0)
17
+
18
+ namespace megablocks {
19
+
20
+ struct Inclusive {};
21
+ struct Exclusive {};
22
+
23
+ template <typename Type> struct Cumsum {
24
+
25
+ template<
26
+ typename InputIteratorT,
27
+ typename OutputIteratorT>
28
+ static void Run(void * d_temp_storage,
29
+ size_t & temp_storage_bytes,
30
+ InputIteratorT d_in,
31
+ OutputIteratorT d_out,
32
+ int num_items,
33
+ cudaStream_t stream = 0,
34
+ bool debug_synchronous = false) {
35
+ CUDA_CALL(hipcub::DeviceScan::ExclusiveSum(d_temp_storage,
36
+ temp_storage_bytes,
37
+ d_in,
38
+ d_out,
39
+ num_items,
40
+ stream));//,
41
+ //debug_synchronous));
42
+ }
43
+ };
44
+
45
+ template <> struct Cumsum<Inclusive> {
46
+ template<
47
+ typename InputIteratorT,
48
+ typename OutputIteratorT>
49
+ static void Run(void * d_temp_storage,
50
+ size_t & temp_storage_bytes,
51
+ InputIteratorT d_in,
52
+ OutputIteratorT d_out,
53
+ int num_items,
54
+ cudaStream_t stream = 0,
55
+ bool debug_synchronous = false) {
56
+ CUDA_CALL(hipcub::DeviceScan::InclusiveSum(d_temp_storage,
57
+ temp_storage_bytes,
58
+ d_in,
59
+ d_out,
60
+ num_items,
61
+ stream));//,
62
+ //debug_synchronous));
63
+ }
64
+ };
65
+
66
+ template <typename SumType, typename T>
67
+ void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
68
+ // Get temporary storage size.
69
+ size_t scratchpad_bytes = 0;
70
+ Cumsum<SumType>::Run(nullptr,
71
+ scratchpad_bytes,
72
+ x.data_ptr<T>(),
73
+ out.data_ptr<T>(),
74
+ x.size(1),
75
+ c10::cuda::getCurrentCUDAStream());
76
+
77
+ // Allocate scratchpad.
78
+ //
79
+ // NOTE: We scale for the batch dimension so we can run in parallel.
80
+ auto options = torch::TensorOptions()
81
+ .dtype(torch::kInt8)
82
+ .device(x.device());
83
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
84
+ options);
85
+
86
+ // Run the kernel.
87
+ //
88
+ // NOTE: Using different streams for each issue does not appear to
89
+ // yield performance gains for our problem set. The overhead of
90
+ // event/stream synchronization appears to outweigh the benefits.
91
+ // We could write a true batched cumsum, but this would require
92
+ // significant code duplication from cub and we might move away
93
+ // from this formulation anyways.
94
+ for (int i = 0; i < x.size(0); ++i) {
95
+ void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
96
+ Cumsum<SumType>::Run(scratchpad_ptr,
97
+ scratchpad_bytes,
98
+ x.data_ptr<T>() + x.size(1) * i,
99
+ out.data_ptr<T>() + x.size(1) * i,
100
+ x.size(1),
101
+ c10::cuda::getCurrentCUDAStream());
102
+ }
103
+ }
104
+
105
+ void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
106
+ // Validate the input matrix.
107
+ TORCH_CHECK(x.is_cuda());
108
+ TORCH_CHECK(x.ndimension() == 2);
109
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
110
+ x.scalar_type() == torch::kInt32 ||
111
+ x.scalar_type() == torch::kInt64);
112
+ TORCH_CHECK(out.is_cuda());
113
+ TORCH_CHECK(out.ndimension() == 2);
114
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
115
+
116
+ // NOTE: We currently only support contraction across the contiguous
117
+ // dimension in the matrix.
118
+ TORCH_CHECK(dim == 1);
119
+
120
+ switch (x.scalar_type()) {
121
+ case torch::kInt16:
122
+ cub_cumsum<Exclusive, short>(x, dim, out);
123
+ return;
124
+ case torch::kInt32:
125
+ cub_cumsum<Exclusive, int>(x, dim, out);
126
+ return;
127
+ }
128
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
129
+ cub_cumsum<Exclusive, long>(x, dim, out);
130
+ }
131
+
132
+ void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
133
+ // Validate the input matrix.
134
+ TORCH_CHECK(x.is_cuda());
135
+ TORCH_CHECK(x.ndimension() == 2);
136
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
137
+ x.scalar_type() == torch::kInt32 ||
138
+ x.scalar_type() == torch::kInt64);
139
+ TORCH_CHECK(out.is_cuda());
140
+ TORCH_CHECK(out.ndimension() == 2);
141
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
142
+
143
+ // NOTE: We currently only support contraction across the contiguous
144
+ // dimension in the matrix.
145
+ TORCH_CHECK(dim == 1);
146
+
147
+ switch (x.scalar_type()) {
148
+ case torch::kInt16:
149
+ cub_cumsum<Inclusive, short>(x, dim, out);
150
+ return;
151
+ case torch::kInt32:
152
+ cub_cumsum<Inclusive, int>(x, dim, out);
153
+ return;
154
+ }
155
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
156
+ cub_cumsum<Inclusive, long>(x, dim, out);
157
+ }
158
+
159
+ } // namespace megablocks
160
+
161
+ #undef CUB_WRAPPED_NAMESPACE
csrc/new_cumsum.h ADDED
@@ -0,0 +1,11 @@
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace megablocks {
6
+
7
+ // Forward declarations for the public interface functions
8
+ void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);
9
+ void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);
10
+
11
+ } // namespace megablocks
csrc/new_histogram.cu ADDED
@@ -0,0 +1,85 @@
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include "new_histogram.h"
5
+ #include <cstdint>
6
+ #include <hipcub/hipcub.hpp>
7
+ #include <c10/cuda/CUDAStream.h>
8
+
9
+ #define CUDA_CALL(code) \
10
+ do { \
11
+ cudaError_t status = code; \
12
+ std::string err = cudaGetErrorString(status); \
13
+ TORCH_CHECK(status == cudaSuccess, err); \
14
+ } while (0)
15
+
16
+ namespace megablocks {
17
+
18
+ template <typename T>
19
+ torch::Tensor cub_histogram(torch::Tensor x, int num_bins) {
20
+ // Allocate the count buffer.
21
+ auto options = torch::TensorOptions()
22
+ .dtype(torch::kInt32)
23
+ .device(x.device());
24
+ torch::Tensor out = torch::empty({x.size(0), num_bins}, options);
25
+
26
+ // Exit early if there is no work to do.
27
+ if (out.numel() == 0) return out;
28
+
29
+ // Get scratchpad size.
30
+ size_t scratchpad_bytes = 0;
31
+ CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(nullptr,
32
+ scratchpad_bytes,
33
+ x.data_ptr<T>(),
34
+ out.data_ptr<int>(),
35
+ /*num_levels=*/num_bins + 1,
36
+ /*lower_level=*/0,
37
+ /*upper_level=*/num_bins,
38
+ /*num_samples=*/int(x.size(1)),
39
+ c10::cuda::getCurrentCUDAStream()));
40
+
41
+ // Allocate scratchpad.
42
+ options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
43
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
44
+
45
+ // Run the kernel.
46
+ for (int i = 0; i < x.size(0); ++i) {
47
+ CUDA_CALL(hipcub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(),
48
+ scratchpad_bytes,
49
+ x.data_ptr<T>() + x.size(1) * i,
50
+ out.data_ptr<int>() + out.size(1) * i,
51
+ /*num_levels=*/num_bins + 1,
52
+ /*lower_level=*/0,
53
+ /*upper_level=*/num_bins,
54
+ /*num_samples=*/int(x.size(1)),
55
+ c10::cuda::getCurrentCUDAStream()));
56
+ }
57
+ return out;
58
+ }
59
+
60
+ torch::Tensor histogram(torch::Tensor x, int num_bins) {
61
+ TORCH_CHECK(x.is_cuda());
62
+ TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2);
63
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
64
+ x.scalar_type() == torch::kInt32 ||
65
+ x.scalar_type() == torch::kInt64);
66
+ bool no_batch = x.ndimension() == 1;
67
+ if (no_batch) x = x.view({1, x.numel()});
68
+
69
+ if (x.scalar_type() == torch::kInt16) {
70
+ auto out = cub_histogram<short>(x, num_bins);
71
+ return no_batch ? out.flatten() : out;
72
+ } else if (x.scalar_type() == torch::kInt32) {
73
+ auto out = cub_histogram<int>(x, num_bins);
74
+ return no_batch ? out.flatten() : out;
75
+ } else {
76
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
77
+ auto out = cub_histogram<long>(x, num_bins);
78
+ return no_batch ? out.flatten() : out;
79
+ }
80
+ }
81
+
82
+ } // namespace megablocks
83
+
84
+ #undef CUDA_CALL
85
+ #undef CUB_WRAPPED_NAMESPACE
csrc/new_histogram.h ADDED
@@ -0,0 +1,10 @@
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace megablocks {
6
+
7
+ // Public interface function for computing histograms
8
+ torch::Tensor histogram(torch::Tensor x, int num_bins);
9
+
10
+ } // namespace megablocks
csrc/new_indices.cu ADDED
@@ -0,0 +1,97 @@
1
+ #include "new_indices.h"
2
+ #include <cstdint>
3
+ #include <c10/util/Half.h>
4
+ #include <c10/cuda/CUDAStream.h>
5
+
6
+ #define CUDA_CALL(code) \
7
+ do { \
8
+ cudaError_t status = code; \
9
+ std::string err = cudaGetErrorString(status); \
10
+ TORCH_CHECK(status == cudaSuccess, err); \
11
+ } while (0)
12
+
13
+ namespace megablocks {
14
+ namespace construct_indices {
15
+
16
+ // We expect the number of outputs per block to be small. For
17
+ // example, with ffn_hidden_size=4096, we only need to write
18
+ // 32 elements per block per iteration.
19
+ const int kThreadsPerBlock = 32;
20
+
21
+ __global__ void __launch_bounds__(kThreadsPerBlock)
22
+ ConstructIndicesKernel(short * __restrict__ indices,
23
+ int num_columns,
24
+ int block_size,
25
+ const int * __restrict__ padded_bins) {
26
+ // Load the offset for this bin's indices.
27
+ int start = 0;
28
+ if (blockIdx.x > 0) start = __ldg(padded_bins + blockIdx.x - 1);
29
+ int end = __ldg(padded_bins + blockIdx.x);
30
+
31
+ // Divide the start and end into blocks.
32
+ start /= block_size;
33
+ end /= block_size;
34
+
35
+ // Offset the output buffer to the start of the bin.
36
+ indices += (start + blockIdx.y) * num_columns + threadIdx.x;
37
+
38
+ // Write the indices to the output.
39
+ int bin_offset = blockIdx.y;
40
+ int num_rows = end - start;
41
+ for (; bin_offset < num_rows; num_rows -= gridDim.y) {
42
+ short *out = indices;
43
+ for (int bid = threadIdx.x; bid < num_columns; bid += kThreadsPerBlock) {
44
+ *out = bid + (blockIdx.x * num_columns);
45
+ out += kThreadsPerBlock;
46
+ }
47
+ indices += gridDim.y * num_columns;
48
+ }
49
+ }
50
+
51
+ cudaError_t ConstructIndices(short * __restrict__ indices,
52
+ int output_block_rows,
53
+ int output_block_columns,
54
+ int block_size,
55
+ const int * __restrict__ padded_bins,
56
+ int num_bins,
57
+ cudaStream_t stream) {
58
+ dim3 block_dim(kThreadsPerBlock);
59
+ dim3 grid_dim(num_bins, (int)std::ceil((float)output_block_rows / num_bins));
60
+ ConstructIndicesKernel<<<grid_dim, block_dim, 0, stream>>>(indices,
61
+ output_block_columns,
62
+ block_size,
63
+ padded_bins);
64
+ return cudaGetLastError();
65
+ }
66
+
67
+ } // namespace construct_indices
68
+
69
+ void indices(torch::Tensor padded_bins,
70
+ int block_size,
71
+ int output_block_rows,
72
+ int output_block_columns,
73
+ torch::Tensor out) {
74
+ TORCH_CHECK(padded_bins.is_cuda());
75
+ TORCH_CHECK(padded_bins.ndimension() == 1);
76
+ TORCH_CHECK(padded_bins.scalar_type() == torch::kInt);
77
+
78
+ TORCH_CHECK(out.is_cuda());
79
+ TORCH_CHECK(out.ndimension() == 1);
80
+ TORCH_CHECK(out.scalar_type() == torch::kInt16);
81
+ TORCH_CHECK(out.numel() == (output_block_rows * output_block_columns));
82
+
83
+ // Exit early if there is no work to do.
84
+ if (out.numel() == 0) return;
85
+
86
+ CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr<short>(),
87
+ output_block_rows,
88
+ output_block_columns,
89
+ block_size,
90
+ padded_bins.data_ptr<int>(),
91
+ padded_bins.numel(),
92
+ c10::cuda::getCurrentCUDAStream()));
93
+ }
94
+
95
+ } // namespace megablocks
96
+
97
+ #undef CUDA_CALL
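As a rough reference for what ConstructIndices produces, an illustrative NumPy sketch (not the binding itself); padded_bins holds inclusive, block_size-aligned cumulative token counts per bin:

import numpy as np

def construct_indices_ref(padded_bins, block_size, rows, columns):
    # Every output row that falls inside bin b holds the column ids of
    # bin b's blocks, i.e. b * columns .. (b + 1) * columns - 1.
    out = np.zeros((rows, columns), dtype=np.int16)
    start = 0
    for b, end in enumerate(np.asarray(padded_bins) // block_size):
        out[start:end, :] = np.arange(columns) + b * columns
        start = end
    return out.flatten()

# Example: padded_bins = [256, 384] (inclusive cumulative), block_size=128, 8 block columns.
print(construct_indices_ref([256, 384], 128, 3, 8))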
csrc/new_indices.h ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace megablocks {
6
+
7
+ // Public interface function for constructing indices from padded bins
8
+ void indices(torch::Tensor padded_bins,
9
+ int block_size,
10
+ int output_block_rows,
11
+ int output_block_columns,
12
+ torch::Tensor out);
13
+
14
+ } // namespace megablocks
csrc/new_replicate.cu ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Modifications: Copyright Advanced Micro Devices, Inc. SPDX License: MIT.
2
+
3
+ #undef CUB_WRAPPED_NAMESPACE
4
+ #define CUB_WRAPPED_NAMESPACE megablocks
5
+
6
+ #include "new_replicate.h"
7
+
8
+ #include <cstdint>
9
+
10
+ #include <cub/cub.cuh>
11
+ #include <c10/util/Half.h>
12
+ #include <c10/cuda/CUDAStream.h>
13
+
14
+ #ifndef USE_ROCM
15
+ #define _LDG(arg) __ldg(arg)
16
+ #else
17
+ #define _LDG(arg) *(arg)
18
+ #endif
19
+
20
+ #define CUDA_CALL(code) \
21
+ do { \
22
+ cudaError_t status = code; \
23
+ std::string err = cudaGetErrorString(status); \
24
+ TORCH_CHECK(status == cudaSuccess, err); \
25
+ } while (0)
26
+
27
+ namespace megablocks {
28
+ namespace replicate {
29
+
30
+ template <typename T, int kThreadsPerBlock>
31
+ __global__ void __launch_bounds__(kThreadsPerBlock)
32
+ ReplicateForwardKernel(T * __restrict__ x,
33
+ int * __restrict__ bins,
34
+ T * __restrict__ out,
35
+ int columns) {
36
+ // Offset to this threadblock's batch.
37
+ //
38
+ // x is [batch_size, num_bins]
39
+ // out is [batch_size, columns]
40
+ // bins is [num_bins]
41
+ int batch_idx = blockIdx.y;
42
+ int num_bins = gridDim.x;
43
+ x += batch_idx * num_bins;
44
+ out += batch_idx * columns;
45
+
46
+ // Load the start/end for this bin.
47
+ int bin_idx = blockIdx.x;
48
+ int start = 0;
49
+ if (bin_idx > 0) start = _LDG(bins + bin_idx - 1);
50
+ int end = _LDG(bins + bin_idx);
51
+
52
+ // Load the value to replicate.
53
+ T value = _LDG((T*)x + bin_idx);
54
+
55
+ // Offset to this threadblock's bin and this thread's
56
+ // offset within the bin.
57
+ int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x;
58
+ out += start + bin_offset;
59
+
60
+ // Replicate the value to the output.
61
+ //
62
+ // TODO(tgale): Vectorize these stores.
63
+ int num_elements = end - start;
64
+ const int kElementsPerLoop = gridDim.z * kThreadsPerBlock;
65
+ T *out_ptr = (T*)out;
66
+ for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) {
67
+ *out_ptr = value;
68
+ out_ptr += kElementsPerLoop;
69
+ }
70
+ }
71
+
72
+ template <typename T>
73
+ cudaError_t ReplicateForward(T *x,
74
+ int batch_size,
75
+ int num_bins,
76
+ int *bins,
77
+ T *out,
78
+ int columns,
79
+ cudaStream_t stream) {
80
+ const int kThreadsPerBlock = 64;
81
+ dim3 block_dim(kThreadsPerBlock, 1, 1);
82
+ int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock));
83
+ dim3 grid_dim(num_bins, batch_size, group_size);
84
+ ReplicateForwardKernel<T, kThreadsPerBlock><<<
85
+ grid_dim, block_dim, 0, stream>>>(x, bins, out, columns);
86
+ return cudaGetLastError();
87
+ }
88
+
89
+ void cub_segmented_reduce(torch::Tensor grad,
90
+ torch::Tensor bins,
91
+ torch::Tensor out,
92
+ cudaStream_t stream) {
93
+ // Prepend a zero to the bin boundaries for CUB.
94
+ torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options());
95
+ CUDA_CALL(cudaMemsetAsync(offsets.data_ptr<int>(),
96
+ 0,
97
+ offsets.numel() * sizeof(int),
98
+ stream));
99
+ CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr<int>() + 1,
100
+ bins.data_ptr<int>(),
101
+ bins.numel() * sizeof(int),
102
+ cudaMemcpyDeviceToDevice,
103
+ stream));
104
+
105
+ // Get temporary buffer size.
106
+ size_t scratchpad_bytes = 0;
107
+ CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr,
108
+ scratchpad_bytes,
109
+ grad.data_ptr<c10::Half>(),
110
+ out.data_ptr<c10::Half>(),
111
+ bins.numel(),
112
+ offsets.data_ptr<int>(),
113
+ offsets.data_ptr<int>() + 1,
114
+ stream));
115
+
116
+ // Allocate scratchpad.
117
+ auto options = torch::TensorOptions()
118
+ .dtype(torch::kInt8)
119
+ .device(grad.device());
120
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
121
+
122
+ // Run the kernel for each batch item.
123
+ for (int i = 0; i < grad.size(0); ++i) {
124
+ int num_bins = out.size(1);
125
+ int num_values = grad.size(1);
126
+ CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr<int8_t>(),
127
+ scratchpad_bytes,
128
+ grad.data_ptr<c10::Half>() + i * num_values,
129
+ out.data_ptr<c10::Half>() + i * num_bins,
130
+ bins.numel(),
131
+ offsets.data_ptr<int>(),
132
+ offsets.data_ptr<int>() + 1,
133
+ stream));
134
+ }
135
+ }
136
+
137
+ } // namespace replicate
138
+
139
+ void replicate_forward(torch::Tensor x,
140
+ torch::Tensor bins,
141
+ torch::Tensor out) {
142
+ // Validate the inputs.
143
+ TORCH_CHECK(x.is_cuda());
144
+ TORCH_CHECK(x.ndimension() == 2);
145
+ TORCH_CHECK(x.scalar_type() == torch::kFloat16 ||
146
+ x.scalar_type() == torch::kInt16 ||
147
+ x.scalar_type() == torch::kInt32);
148
+ TORCH_CHECK(bins.is_cuda());
149
+ TORCH_CHECK(bins.ndimension() == 1);
150
+ TORCH_CHECK(bins.scalar_type() == torch::kInt);
151
+ TORCH_CHECK(out.is_cuda());
152
+ TORCH_CHECK(out.ndimension() == 2);
153
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
154
+
155
+ // Batch dimensions should match for input/output.
156
+ TORCH_CHECK(x.size(0) == out.size(0));
157
+
158
+ // One input for each bin (in each batch).
159
+ TORCH_CHECK(x.size(1) == bins.size(0));
160
+
161
+ // Exit early if there is no work to do.
162
+ if (out.numel() == 0) return;
163
+
164
+ switch (x.scalar_type()) {
165
+ case torch::kFloat16:
166
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<c10::Half>(),
167
+ x.size(0),
168
+ x.size(1),
169
+ bins.data_ptr<int>(),
170
+ out.data_ptr<c10::Half>(),
171
+ out.size(1),
172
+ c10::cuda::getCurrentCUDAStream()));
173
+ return;
174
+ case torch::kInt32:
175
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<int>(),
176
+ x.size(0),
177
+ x.size(1),
178
+ bins.data_ptr<int>(),
179
+ out.data_ptr<int>(),
180
+ out.size(1),
181
+ c10::cuda::getCurrentCUDAStream()));
182
+ return;
183
+ }
184
+ TORCH_CHECK(x.scalar_type() == torch::kInt16);
185
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<short>(),
186
+ x.size(0),
187
+ x.size(1),
188
+ bins.data_ptr<int>(),
189
+ out.data_ptr<short>(),
190
+ out.size(1),
191
+ c10::cuda::getCurrentCUDAStream()));
192
+ }
193
+
194
+ void replicate_backward(torch::Tensor grad,
195
+ torch::Tensor bins,
196
+ torch::Tensor out) {
197
+ // Validate the inputs.
198
+ TORCH_CHECK(grad.is_cuda());
199
+ TORCH_CHECK(grad.ndimension() == 2);
200
+ TORCH_CHECK(grad.scalar_type() == torch::kFloat16);
201
+ TORCH_CHECK(bins.is_cuda());
202
+ TORCH_CHECK(bins.ndimension() == 1);
203
+ TORCH_CHECK(bins.scalar_type() == torch::kInt);
204
+ TORCH_CHECK(out.is_cuda());
205
+ TORCH_CHECK(out.ndimension() == 2);
206
+ TORCH_CHECK(out.scalar_type() == torch::kFloat16);
207
+
208
+ // Batch dimensions should match for input/output.
209
+ TORCH_CHECK(grad.size(0) == out.size(0));
210
+
211
+ // One output for each bin (in each batch).
212
+ TORCH_CHECK(out.size(1) == bins.size(0));
213
+
214
+ replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream());
215
+ }
216
+
217
+ } // namespace megablocks
218
+
219
+ #undef CUDA_CALL
220
+ #undef CUB_WRAPPED_NAMESPACE
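To make the forward/backward contract concrete, an illustrative PyTorch restatement of the kernel semantics (a sketch only, not the exported binding):

import torch

def replicate_forward_ref(x, bins, columns):
    # out[b, start:end] = x[b, i] for bin i, where start/end come from the
    # inclusive cumulative boundaries stored in `bins`.
    out = x.new_empty(x.size(0), columns)
    start = 0
    for i in range(bins.numel()):
        end = int(bins[i])
        out[:, start:end] = x[:, i:i + 1]
        start = end
    return out

def replicate_backward_ref(grad, bins):
    # The backward pass is a segmented sum of grad over the same boundaries.
    out = grad.new_empty(grad.size(0), bins.numel())
    start = 0
    for i in range(bins.numel()):
        end = int(bins[i])
        out[:, i] = grad[:, start:end].sum(dim=1)
        start = end
    return out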
csrc/new_replicate.h ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace megablocks {
6
+
7
+ // Forward pass: replicate values from x according to bin sizes
8
+ void replicate_forward(torch::Tensor x,
9
+ torch::Tensor bins,
10
+ torch::Tensor out);
11
+
12
+ // Backward pass: reduce gradients back to bins using segmented reduction
13
+ void replicate_backward(torch::Tensor grad,
14
+ torch::Tensor bins,
15
+ torch::Tensor out);
16
+
17
+ } // namespace megablocks
csrc/new_sort.cu ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include "new_sort.h"
5
+ #include <cstdint>
6
+ #include <cub/cub.cuh>
7
+ #include <c10/cuda/CUDAStream.h>
8
+
9
+ #define CUDA_CALL(code) \
10
+ do { \
11
+ cudaError_t status = code; \
12
+ std::string err = cudaGetErrorString(status); \
13
+ TORCH_CHECK(status == cudaSuccess, err); \
14
+ } while (0)
15
+
16
+ namespace megablocks {
17
+
18
+ template <typename T>
19
+ void cub_radix_sort(torch::Tensor x,
20
+ int end_bit,
21
+ torch::Tensor x_out,
22
+ torch::Tensor iota_out) {
23
+ // Get iota for values in sort.
24
+ torch::Tensor iota = torch::arange(0, x.numel(), x.options());
25
+
26
+ // Get temporary buffer size.
27
+ size_t scratchpad_bytes = 0;
28
+ CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr,
29
+ scratchpad_bytes,
30
+ x.data_ptr<T>(),
31
+ x_out.data_ptr<T>(),
32
+ iota.data_ptr<T>(),
33
+ iota_out.data_ptr<T>(),
34
+ x.numel(),
35
+ /*begin_bit*/0,
36
+ /*end_bit=*/end_bit,
37
+ c10::cuda::getCurrentCUDAStream()));
38
+
39
+ // Allocate scratchpad.
40
+ auto options = torch::TensorOptions()
41
+ .dtype(torch::kInt8)
42
+ .device(x.device());
43
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
44
+
45
+ // Run the kernel.
46
+ CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(),
47
+ scratchpad_bytes,
48
+ x.data_ptr<T>(),
49
+ x_out.data_ptr<T>(),
50
+ iota.data_ptr<T>(),
51
+ iota_out.data_ptr<T>(),
52
+ x.numel(),
53
+ /*begin_bit=*/0,
54
+ /*end_bit=*/end_bit,
55
+ c10::cuda::getCurrentCUDAStream()));
56
+ }
57
+
58
+ void sort(torch::Tensor x,
59
+ int end_bit,
60
+ torch::Tensor x_out,
61
+ torch::Tensor iota_out) {
62
+ TORCH_CHECK(x.is_cuda());
63
+ TORCH_CHECK(x.ndimension() == 1);
64
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
65
+ x.scalar_type() == torch::kInt32 ||
66
+ x.scalar_type() == torch::kInt64);
67
+ TORCH_CHECK(x_out.is_cuda());
68
+ TORCH_CHECK(x_out.ndimension() == 1);
69
+ TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
70
+ TORCH_CHECK(iota_out.is_cuda());
71
+ TORCH_CHECK(iota_out.ndimension() == 1);
72
+ TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());
73
+
74
+ // Exit early if there is no work to do.
75
+ if (x_out.numel() == 0) return;
76
+
77
+ switch (x.scalar_type()) {
78
+ case torch::kInt16:
79
+ return cub_radix_sort<short>(x, end_bit, x_out, iota_out);
80
+ case torch::kInt32:
81
+ return cub_radix_sort<int>(x, end_bit, x_out, iota_out);
82
+ }
83
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
84
+ return cub_radix_sort<long>(x, end_bit, x_out, iota_out);
85
+ }
86
+
87
+ } // namespace megablocks
88
+
89
+ #undef CUDA_CALL
90
+ #undef CUB_WRAPPED_NAMESPACE
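The typical call pattern from Python, as exercised in tests/ops/padded_gather_test.py (the one-argument form assumes the wrapper picks a default end_bit; that default is not spelled out here):

import torch
from megablocks import ops

# Assign 4096 tokens to 64 experts, then group token positions by expert.
num_experts = 64
top_expert = torch.randint(0, num_experts, (4096,), device="cuda", dtype=torch.int32)

# bin_ids: the sorted expert ids; indices: the permutation that groups tokens by expert.
bin_ids, indices = ops.sort(top_expert)
bins = ops.inclusive_cumsum(ops.histogram(top_expert, num_experts), 0)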
csrc/new_sort.h ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <torch/all.h>
4
+
5
+ namespace megablocks {
6
+
7
+ // Public interface function for radix sorting with indices
8
+ void sort(torch::Tensor x,
9
+ int end_bit,
10
+ torch::Tensor x_out,
11
+ torch::Tensor iota_out);
12
+
13
+ } // namespace megablocks
csrc/replicate.h ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include <cstdint>
5
+
6
+ #include <cub/cub.cuh>
7
+ #include <c10/util/Half.h>
8
+ #include <c10/cuda/CUDAStream.h>
9
+ // #include <torch/extension.h>
10
+
11
+ #define CUDA_CALL(code) \
12
+ do { \
13
+ cudaError_t status = code; \
14
+ std::string err = cudaGetErrorString(status); \
15
+ TORCH_CHECK(status == cudaSuccess, err); \
16
+ } while (0)
17
+
18
+ namespace megablocks {
19
+ namespace replicate {
20
+
21
+ template <typename T, int kThreadsPerBlock>
22
+ __global__ void __launch_bounds__(kThreadsPerBlock)
23
+ ReplicateForwardKernel(T * __restrict__ x,
24
+ int * __restrict__ bins,
25
+ T * __restrict__ out,
26
+ int columns) {
27
+ // Offset to this threadblocks batch.
28
+ //
29
+ // x is [batch_size, num_bins]
30
+ // out is [batch_size, columns]
31
+ // bins is [num_bins]
32
+ int batch_idx = blockIdx.y;
33
+ int num_bins = gridDim.x;
34
+ x += batch_idx * num_bins;
35
+ out += batch_idx * columns;
36
+
37
+ // Load the start/end for this bin.
38
+ int bin_idx = blockIdx.x;
39
+ int start = 0;
40
+ if (bin_idx > 0) start = __ldg(bins + bin_idx - 1);
41
+ int end = __ldg(bins + bin_idx);
42
+
43
+ // Load the value to replicate.
44
+ T value = __ldg((T*)x + bin_idx);
45
+
46
+ // Offset to this threadblocks bin and this threads
47
+ // offset within the bin.
48
+ int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x;
49
+ out += start + bin_offset;
50
+
51
+ // Replicate the value to the output.
52
+ //
53
+ // TODO(tgale): Vectorize these stores.
54
+ int num_elements = end - start;
55
+ const int kElementsPerLoop = gridDim.z * kThreadsPerBlock;
56
+ T *out_ptr = (T*)out;
57
+ for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) {
58
+ *out_ptr = value;
59
+ out_ptr += kElementsPerLoop;
60
+ }
61
+ }
62
+
63
+ template <typename T>
64
+ cudaError_t ReplicateForward(T *x,
65
+ int batch_size,
66
+ int num_bins,
67
+ int *bins,
68
+ T *out,
69
+ int columns,
70
+ cudaStream_t stream) {
71
+ const int kThreadsPerBlock = 64;
72
+ dim3 block_dim(kThreadsPerBlock, 1, 1);
73
+ int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock));
74
+ dim3 grid_dim(num_bins, batch_size, group_size);
75
+ ReplicateForwardKernel<T, kThreadsPerBlock><<<
76
+ grid_dim, block_dim, 0, stream>>>(x, bins, out, columns);
77
+ return cudaGetLastError();
78
+ }
79
+
80
+ void cub_segmented_reduce(torch::Tensor grad,
81
+ torch::Tensor bins,
82
+ torch::Tensor out,
83
+ cudaStream_t stream) {
84
+ // Append a zero to the bin boundaries for CUB.
85
+ torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options());
86
+ CUDA_CALL(cudaMemsetAsync(offsets.data_ptr<int>(),
87
+ 0,
88
+ offsets.numel() * sizeof(int),
89
+ stream));
90
+ CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr<int>() + 1,
91
+ bins.data_ptr<int>(),
92
+ bins.numel() * sizeof(int),
93
+ cudaMemcpyDeviceToDevice,
94
+ stream));
95
+
96
+ // Get temporary buffer size.
97
+ size_t scratchpad_bytes = 0;
98
+ CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr,
99
+ scratchpad_bytes,
100
+ grad.data_ptr<c10::Half>(),
101
+ out.data_ptr<c10::Half>(),
102
+ bins.numel(),
103
+ offsets.data_ptr<int>(),
104
+ offsets.data_ptr<int>() + 1,
105
+ stream));
106
+
107
+ // Allocate scratchpad.
108
+ auto options = torch::TensorOptions()
109
+ .dtype(torch::kInt8)
110
+ .device(grad.device());
111
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
112
+
113
+ // Run the kernel for each batch item.
114
+ for (int i = 0; i < grad.size(0); ++i) {
115
+ int num_bins = out.size(1);
116
+ int num_values = grad.size(1);
117
+ CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr<int8_t>(),
118
+ scratchpad_bytes,
119
+ grad.data_ptr<c10::Half>() + i * num_values,
120
+ out.data_ptr<c10::Half>() + i * num_bins,
121
+ bins.numel(),
122
+ offsets.data_ptr<int>(),
123
+ offsets.data_ptr<int>() + 1,
124
+ stream));
125
+ }
126
+ }
127
+
128
+ } // namespace replicate
129
+
130
+ void replicate_forward(torch::Tensor x,
131
+ torch::Tensor bins,
132
+ torch::Tensor out) {
133
+ // Validate the inputs.
134
+ TORCH_CHECK(x.is_cuda());
135
+ TORCH_CHECK(x.ndimension() == 2);
136
+ TORCH_CHECK(x.scalar_type() == torch::kFloat16 ||
137
+ x.scalar_type() == torch::kInt16 ||
138
+ x.scalar_type() == torch::kInt32);
139
+ TORCH_CHECK(bins.is_cuda());
140
+ TORCH_CHECK(bins.ndimension() == 1);
141
+ TORCH_CHECK(bins.scalar_type() == torch::kInt);
142
+ TORCH_CHECK(out.is_cuda());
143
+ TORCH_CHECK(out.ndimension() == 2);
144
+ TORCH_CHECK(out.scalar_type() == x.scalar_type());
145
+
146
+ // Batch dimensions should match for input/output.
147
+ TORCH_CHECK(x.size(0) == out.size(0));
148
+
149
+ // One input for each bin (in each batch).
150
+ TORCH_CHECK(x.size(1) == bins.size(0));
151
+
152
+ // Exit early if there is no work to do.
153
+ if (out.numel() == 0) return;
154
+
155
+ switch (x.scalar_type()) {
156
+ case torch::kFloat16:
157
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<c10::Half>(),
158
+ x.size(0),
159
+ x.size(1),
160
+ bins.data_ptr<int>(),
161
+ out.data_ptr<c10::Half>(),
162
+ out.size(1),
163
+ c10::cuda::getCurrentCUDAStream()));
164
+ return;
165
+ case torch::kInt32:
166
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<int>(),
167
+ x.size(0),
168
+ x.size(1),
169
+ bins.data_ptr<int>(),
170
+ out.data_ptr<int>(),
171
+ out.size(1),
172
+ c10::cuda::getCurrentCUDAStream()));
173
+ return;
174
+ }
175
+ TORCH_CHECK(x.scalar_type() == torch::kInt16);
176
+ CUDA_CALL(replicate::ReplicateForward(x.data_ptr<short>(),
177
+ x.size(0),
178
+ x.size(1),
179
+ bins.data_ptr<int>(),
180
+ out.data_ptr<short>(),
181
+ out.size(1),
182
+ c10::cuda::getCurrentCUDAStream()));
183
+ }
184
+
185
+ void replicate_backward(torch::Tensor grad,
186
+ torch::Tensor bins,
187
+ torch::Tensor out) {
188
+ // Validate the inputs.
189
+ TORCH_CHECK(grad.is_cuda());
190
+ TORCH_CHECK(grad.ndimension() == 2);
191
+ TORCH_CHECK(grad.scalar_type() == torch::kFloat16);
192
+ TORCH_CHECK(bins.is_cuda());
193
+ TORCH_CHECK(bins.ndimension() == 1);
194
+ TORCH_CHECK(bins.scalar_type() == torch::kInt);
195
+ TORCH_CHECK(out.is_cuda());
196
+ TORCH_CHECK(out.ndimension() == 2);
197
+ TORCH_CHECK(out.scalar_type() == torch::kFloat16);
198
+
199
+ // Batch dimensions should match for input/output.
200
+ TORCH_CHECK(grad.size(0) == out.size(0));
201
+
202
+ // One output for each bin (in each batch).
203
+ TORCH_CHECK(out.size(1) == bins.size(0));
204
+
205
+ replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream());
206
+ }
207
+
208
+ } // namespace megablocks
209
+
210
+ #undef CUDA_CALL
211
+ #undef CUB_WRAPPED_NAMESPACE
csrc/sort.h ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #undef CUB_WRAPPED_NAMESPACE
2
+ #define CUB_WRAPPED_NAMESPACE megablocks
3
+
4
+ #include <cstdint>
5
+
6
+ #include <cub/cub.cuh>
7
+ #include <c10/cuda/CUDAStream.h>
8
+ // #include <torch/extension.h>
9
+
10
+ #define CUDA_CALL(code) \
11
+ do { \
12
+ cudaError_t status = code; \
13
+ std::string err = cudaGetErrorString(status); \
14
+ TORCH_CHECK(status == cudaSuccess, err); \
15
+ } while (0)
16
+
17
+ namespace megablocks {
18
+
19
+ template <typename T>
20
+ void cub_radix_sort(torch::Tensor x,
21
+ int end_bit,
22
+ torch::Tensor x_out,
23
+ torch::Tensor iota_out) {
24
+ // Get iota for values in sort.
25
+ torch::Tensor iota = torch::arange(0, x.numel(), x.options());
26
+
27
+ // Get temporary buffer size.
28
+ size_t scratchpad_bytes = 0;
29
+ CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr,
30
+ scratchpad_bytes,
31
+ x.data_ptr<T>(),
32
+ x_out.data_ptr<T>(),
33
+ iota.data_ptr<T>(),
34
+ iota_out.data_ptr<T>(),
35
+ x.numel(),
36
+ /*begin_bit*/0,
37
+ /*end_bit=*/end_bit,
38
+ c10::cuda::getCurrentCUDAStream()));
39
+
40
+ // Allocate scratchpad.
41
+ auto options = torch::TensorOptions()
42
+ .dtype(torch::kInt8)
43
+ .device(x.device());
44
+ torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
45
+
46
+ // Run the kernel.
47
+ CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(),
48
+ scratchpad_bytes,
49
+ x.data_ptr<T>(),
50
+ x_out.data_ptr<T>(),
51
+ iota.data_ptr<T>(),
52
+ iota_out.data_ptr<T>(),
53
+ x.numel(),
54
+ /*begin_bit=*/0,
55
+ /*end_bit=*/end_bit,
56
+ c10::cuda::getCurrentCUDAStream()));
57
+ }
58
+
59
+ void sort(torch::Tensor x,
60
+ int end_bit,
61
+ torch::Tensor x_out,
62
+ torch::Tensor iota_out) {
63
+ TORCH_CHECK(x.is_cuda());
64
+ TORCH_CHECK(x.ndimension() == 1);
65
+ TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
66
+ x.scalar_type() == torch::kInt32 ||
67
+ x.scalar_type() == torch::kInt64);
68
+ TORCH_CHECK(x_out.is_cuda());
69
+ TORCH_CHECK(x_out.ndimension() == 1);
70
+ TORCH_CHECK(x_out.scalar_type() == x.scalar_type());
71
+ TORCH_CHECK(iota_out.is_cuda());
72
+ TORCH_CHECK(iota_out.ndimension() == 1);
73
+ TORCH_CHECK(iota_out.scalar_type() == x.scalar_type());
74
+
75
+ // Exit early if there is no work to do.
76
+ if (x_out.numel() == 0) return;
77
+
78
+ switch (x.scalar_type()) {
79
+ case torch::kInt16:
80
+ return cub_radix_sort<short>(x, end_bit, x_out, iota_out);
81
+ case torch::kInt32:
82
+ return cub_radix_sort<int>(x, end_bit, x_out, iota_out);
83
+ }
84
+ TORCH_CHECK(x.scalar_type() == torch::kInt64);
85
+ return cub_radix_sort<long>(x, end_bit, x_out, iota_out);
86
+ }
87
+
88
+ } // namespace megablocks
89
+
90
+ #undef CUDA_CALL
91
+ #undef CUB_WRAPPED_NAMESPACE
flake.lock ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1733328505,
21
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1751968576,
77
+ "narHash": "sha256-cmKrlWpNTG/hq1bCaHXfbdm9T+Y6V+5//EHAVc1TLBE=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "3fcd1e1b46da91b6691261640ffd6b7123d0cb9e",
81
+ "type": "github"
82
+ },
83
+ "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
86
+ "type": "github"
87
+ }
88
+ },
89
+ "kernel-builder": {
90
+ "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
+ "nixpkgs": [
95
+ "kernel-builder",
96
+ "hf-nix",
97
+ "nixpkgs"
98
+ ]
99
+ },
100
+ "locked": {
101
+ "lastModified": 1753256281,
102
+ "narHash": "sha256-CfL3Fyf2ih7OtyL7ScZUCwOeCj+gjlRyPykhR6Zbt3I=",
103
+ "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "dcbbdf2d3c8e78b27321b205b2c9d67ffce6a706",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "huggingface",
110
+ "repo": "kernel-builder",
111
+ "type": "github"
112
+ }
113
+ },
114
+ "nixpkgs": {
115
+ "locked": {
116
+ "lastModified": 1747820358,
117
+ "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
118
+ "owner": "danieldk",
119
+ "repo": "nixpkgs",
120
+ "rev": "d3c1681180717528068082103bf323147de6ab0b",
121
+ "type": "github"
122
+ },
123
+ "original": {
124
+ "owner": "danieldk",
125
+ "ref": "cudatoolkit-12.9-kernel-builder",
126
+ "repo": "nixpkgs",
127
+ "type": "github"
128
+ }
129
+ },
130
+ "root": {
131
+ "inputs": {
132
+ "kernel-builder": "kernel-builder"
133
+ }
134
+ },
135
+ "systems": {
136
+ "locked": {
137
+ "lastModified": 1681028828,
138
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
+ "owner": "nix-systems",
140
+ "repo": "default",
141
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
+ "type": "github"
143
+ },
144
+ "original": {
145
+ "owner": "nix-systems",
146
+ "repo": "default",
147
+ "type": "github"
148
+ }
149
+ },
150
+ "systems_2": {
151
+ "locked": {
152
+ "lastModified": 1681028828,
153
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
+ "owner": "nix-systems",
155
+ "repo": "default",
156
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
+ "type": "github"
158
+ },
159
+ "original": {
160
+ "owner": "nix-systems",
161
+ "repo": "default",
162
+ "type": "github"
163
+ }
164
+ }
165
+ },
166
+ "root": "root",
167
+ "version": 7
168
+ }
flake.nix ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ description = "Flake for megablocks_moe kernel";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ path = ./.;
15
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
+
17
+ pythonCheckInputs = pkgs: with pkgs; [
18
+ tqdm
19
+ py-cpuinfo
20
+ importlib-metadata
21
+ torchmetrics
22
+ ];
23
+ };
24
+ }
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import os
5
+ from typing import List, Optional
6
+
7
+ import pytest
8
+ # from composer.utils import reproducibility
9
+
10
+ # Allowed options for pytest.mark.world_size()
11
+ WORLD_SIZE_OPTIONS = (1, 2)
12
+
13
+ # Enforce deterministic mode before any tests start.
14
+ # reproducibility.configure_deterministic_mode()
15
+
16
+ # TODO: allow plugins when deps resolved
17
+
18
+ # Add the path of any pytest fixture files you want to make global
19
+ pytest_plugins = [
20
+ # 'tests.fixtures.autouse',
21
+ 'tests.fixtures.fixtures',
22
+ ]
23
+
24
+
25
+ def _get_world_size(item: pytest.Item):
26
+ """Returns the world_size of a test, defaults to 1."""
27
+ _default = pytest.mark.world_size(1).mark
28
+ return item.get_closest_marker('world_size', default=_default).args[0]
29
+
30
+
31
+ def _get_option(
32
+ config: pytest.Config,
33
+ name: str,
34
+ default: Optional[str] = None,
35
+ ) -> str: # type: ignore
36
+ val = config.getoption(name)
37
+ if val is not None:
38
+ assert isinstance(val, str)
39
+ return val
40
+ val = config.getini(name)
41
+ if val == []:
42
+ val = None
43
+ if val is None:
44
+ if default is None:
45
+ pytest.fail(f'Config option {name} is not specified but is required',)
46
+ val = default
47
+ assert isinstance(val, str)
48
+ return val
49
+
50
+
51
+ def _add_option(
52
+ parser: pytest.Parser,
53
+ name: str,
54
+ help: str,
55
+ choices: Optional[list[str]] = None,
56
+ ):
57
+ parser.addoption(
58
+ f'--{name}',
59
+ default=None,
60
+ type=str,
61
+ choices=choices,
62
+ help=help,
63
+ )
64
+ parser.addini(
65
+ name=name,
66
+ help=help,
67
+ type='string',
68
+ default=None,
69
+ )
70
+
71
+
72
+ def pytest_collection_modifyitems(
73
+ config: pytest.Config,
74
+ items: List[pytest.Item],
75
+ ) -> None:
76
+ """Filter tests by world_size (for multi-GPU tests)"""
77
+ world_size = int(os.environ.get('WORLD_SIZE', '1'))
78
+ print(f'world_size={world_size}')
79
+
80
+ conditions = [
81
+ lambda item: _get_world_size(item) == world_size,
82
+ ]
83
+
84
+ # keep items that satisfy all conditions
85
+ remaining = []
86
+ deselected = []
87
+ for item in items:
88
+ if all(condition(item) for condition in conditions):
89
+ remaining.append(item)
90
+ else:
91
+ deselected.append(item)
92
+
93
+ if deselected:
94
+ config.hook.pytest_deselected(items=deselected)
95
+ items[:] = remaining
96
+
97
+
98
+ def pytest_addoption(parser: pytest.Parser) -> None:
99
+ _add_option(
100
+ parser,
101
+ 'seed',
102
+ help="""\
103
+ Rank zero seed to use. `reproducibility.seed_all(seed + dist.get_global_rank())` will be invoked
104
+ before each test.""",
105
+ )
106
+
107
+
108
+ def pytest_sessionfinish(session: pytest.Session, exitstatus: int):
109
+ if exitstatus == 5:
110
+ session.exitstatus = 0 # Ignore no-test-ran errors
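For context, a test opts into this world-size filtering by declaring the marker; a hypothetical example (only the values in WORLD_SIZE_OPTIONS are intended, and the test is collected only when the WORLD_SIZE environment variable matches):

import pytest

@pytest.mark.world_size(2)
@pytest.mark.gpu
def test_runs_on_two_ranks():
    # Deselected unless WORLD_SIZE=2 is set for the pytest invocation.
    assert True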
tests/fixtures/autouse.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import gc
5
+ import logging
6
+ import os
7
+
8
+ import composer
9
+ import pytest
10
+ import torch
11
+ from composer.devices import DeviceCPU, DeviceGPU
12
+ from composer.utils import dist, reproducibility
13
+
14
+
15
+ @pytest.fixture(autouse=True)
16
+ def clear_cuda_cache(request: pytest.FixtureRequest):
17
+ """Clear memory between GPU tests."""
18
+ marker = request.node.get_closest_marker('gpu')
19
+ if marker is not None and torch.cuda.is_available():
20
+ torch.cuda.empty_cache()
21
+ gc.collect() # Only gc on GPU tests as it 2x slows down CPU tests
22
+
23
+
24
+ @pytest.fixture(autouse=True)
25
+ def reset_mlflow_tracking_dir():
26
+ """Reset MLFlow tracking dir so it doesn't persist across tests."""
27
+ try:
28
+ import mlflow
29
+ mlflow.set_tracking_uri(None) # type: ignore
30
+ except ModuleNotFoundError:
31
+ # MLFlow not installed
32
+ pass
33
+
34
+
35
+ @pytest.fixture(scope='session')
36
+ def cleanup_dist():
37
+ """Ensure all dist tests clean up resources properly."""
38
+ yield
39
+ # Avoid race condition where a test is still writing to a file on one rank
40
+ # while the file system is being torn down on another rank.
41
+ dist.barrier()
42
+
43
+
44
+ @pytest.fixture(autouse=True, scope='session')
45
+ def configure_dist(request: pytest.FixtureRequest):
46
+ # Configure dist globally when the world size is greater than 1,
47
+ # so individual tests that do not use the trainer
48
+ # do not need to worry about manually configuring dist.
49
+
50
+ if dist.get_world_size() == 1:
51
+ return
52
+
53
+ device = None
54
+
55
+ for item in request.session.items:
56
+ device = DeviceCPU() if item.get_closest_marker('gpu') is None else DeviceGPU()
57
+ break
58
+
59
+ assert device is not None
60
+
61
+ if not dist.is_initialized():
62
+ dist.initialize_dist(device, timeout=300.0)
63
+ # Hold PyTest until all ranks have reached this barrier. Ensure that no rank starts
64
+ # any test before other ranks are ready to start it, which could be a cause of random timeouts
65
+ # (e.g. rank 1 starts the next test while rank 0 is finishing up the previous test).
66
+ dist.barrier()
67
+
68
+
69
+ @pytest.fixture(autouse=True)
70
+ def set_log_levels():
71
+ """Ensures all log levels are set to DEBUG."""
72
+ logging.basicConfig()
73
+ logging.getLogger(composer.__name__).setLevel(logging.DEBUG)
74
+
75
+
76
+ @pytest.fixture(autouse=True)
77
+ def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch):
78
+ """Monkeypatch reproducibility.
79
+
80
+ Make get_random_seed always return the rank zero seed, and set the random seed before each test to the rank-local
81
+ seed.
82
+ """
83
+ monkeypatch.setattr(
84
+ reproducibility,
85
+ 'get_random_seed',
86
+ lambda: rank_zero_seed,
87
+ )
88
+ reproducibility.seed_all(rank_zero_seed + dist.get_global_rank())
89
+
90
+
91
+ @pytest.fixture(autouse=True)
92
+ def remove_run_name_env_var():
93
+ # Remove environment variables for run names in unit tests
94
+ composer_run_name = os.environ.get('COMPOSER_RUN_NAME')
95
+ run_name = os.environ.get('RUN_NAME')
96
+
97
+ if 'COMPOSER_RUN_NAME' in os.environ:
98
+ del os.environ['COMPOSER_RUN_NAME']
99
+ if 'RUN_NAME' in os.environ:
100
+ del os.environ['RUN_NAME']
101
+
102
+ yield
103
+
104
+ if composer_run_name is not None:
105
+ os.environ['COMPOSER_RUN_NAME'] = composer_run_name
106
+ if run_name is not None:
107
+ os.environ['RUN_NAME'] = run_name
tests/fixtures/fixtures.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pytest
5
+
6
+ from tests.conftest import _get_option
7
+
8
+
9
+ @pytest.fixture
10
+ def rank_zero_seed(pytestconfig: pytest.Config) -> int:
11
+ """Read the rank_zero_seed from the CLI option."""
12
+ seed = _get_option(pytestconfig, 'seed', default='0')
13
+ return int(seed)
tests/layer_test.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from collections import namedtuple
4
+
5
+
6
+ def test_megablocks_moe_mlp_import():
7
+ """Test if MegaBlocksMoeMLP can be imported."""
8
+ from megablocks.layers import MegaBlocksMoeMLP
9
+
10
+ assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."
11
+
12
+
13
+ def test_megablocks_moe_mlp_functionality():
14
+ """Test the functionality of MegaBlocksMoeMLP."""
15
+ from megablocks.layers import MegaBlocksMoeMLP
16
+
17
+ # Create a simple instance of MegaBlocksMoeMLP
18
+ model = MegaBlocksMoeMLP()
19
+
20
+ # add experts attribute to the model
21
+ model.experts = namedtuple(
22
+ "Experts",
23
+ [
24
+ "gate_up_proj",
25
+ "gate_down_proj",
26
+ "down_proj",
27
+ "hidden_size",
28
+ ],
29
+ )
30
+
31
+ num_experts = 128
32
+ hidden_size = 1152
33
+ intermediate_size = 3072
34
+
35
+ # Shorter names for reading convenience
36
+ ne, hs, isz = num_experts, hidden_size, intermediate_size
37
+
38
+ model.router = torch.nn.Linear(hs, ne).cuda()
39
+ model.router.weight.data.fill_(1)
40
+
41
+ e = model.experts
42
+ e.gate_up_proj = torch.nn.Parameter(torch.ones(ne, hs, isz, device="cuda"))
43
+ e.gate_up_proj_bias = torch.nn.Parameter(torch.zeros(ne, isz, device="cuda"))
44
+ e.down_proj = torch.nn.Parameter(torch.ones(ne, 1536, hs, device="cuda"))
45
+ e.down_proj_bias = torch.nn.Parameter(torch.zeros(ne, hs, device="cuda"))
46
+ e.hidden_size = hs
47
+
48
+ # Create dummy input data
49
+ x = torch.randn(1, 1, 1152).to(torch.device("cuda"))
50
+ output, expert_weights_out = model(x)
51
+
52
+ # print("Output shape:", output.shape)
53
+ assert output.shape == (1, 1, 1152), "Output shape mismatch."
tests/layers/architectures.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ from megablocks._layers.arguments import Arguments
8
+
9
+
10
+ class FFN(torch.nn.Module):
11
+
12
+ def __init__(self, args: Arguments):
13
+ super().__init__()
14
+ self.w1 = torch.nn.Parameter(
15
+ torch.empty(
16
+ args.hidden_size,
17
+ args.ffn_hidden_size,
18
+ device=args.device,
19
+ dtype=torch.float16 if args.fp16 else torch.float32,
20
+ ),
21
+ )
22
+ self.w2 = torch.nn.Parameter(
23
+ torch.empty(
24
+ args.ffn_hidden_size,
25
+ args.hidden_size,
26
+ device=args.device,
27
+ dtype=torch.float16 if args.fp16 else torch.float32,
28
+ ),
29
+ )
30
+
31
+ def forward(self, x):
32
+ return torch.matmul(
33
+ F.gelu(torch.matmul(x, self.w1), approximate='tanh'),
34
+ self.w2,
35
+ )
36
+
37
+
38
+ class GLU(FFN):
39
+
40
+ def __init__(self, args: Arguments):
41
+ super().__init__(args)
42
+ self.v1 = torch.nn.Parameter(
43
+ torch.empty(
44
+ args.hidden_size,
45
+ args.ffn_hidden_size,
46
+ device=args.device,
47
+ dtype=torch.float16 if args.fp16 else torch.float32,
48
+ ),
49
+ )
50
+
51
+ def forward(self, x):
52
+ x1 = F.gelu(torch.matmul(x, self.w1), approximate='tanh') * torch.matmul(x, self.v1)
53
+ return torch.matmul(x1, self.w2)
tests/layers/moe_test.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from functools import partial
5
+
6
+ import pytest
7
+ import torch
8
+
9
+ from megablocks._layers.arguments import Arguments
10
+ from megablocks._layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
11
+ from megablocks._layers.router import batched_router_zloss, clear_router_zloss
12
+ from tests.layers.architectures import FFN
13
+
14
+ _FORWARD_TESTS = (
15
+ (16, 1024, 512, 1, 1),
16
+ (16, 1024, 512, 2, 1),
17
+ (16, 1024, 512, 4, 1),
18
+ (16, 1024, 512, 8, 1),
19
+ (8, 2048, 512, 1, 1),
20
+ (8, 2048, 512, 2, 1),
21
+ (8, 2048, 512, 4, 1),
22
+ (16, 1024, 512, 2, 2),
23
+ (16, 1024, 512, 4, 2),
24
+ (16, 1024, 512, 4, 4),
25
+ (16, 1024, 512, 8, 2),
26
+ (16, 1024, 512, 8, 4),
27
+ (16, 1024, 512, 8, 8),
28
+ )
29
+
30
+ _DENSE_TESTS = (
31
+ (16, 1024, 512),
32
+ (8, 2048, 512),
33
+ )
34
+
35
+
36
+ def construct_moe(
37
+ hidden_size: int,
38
+ ffn_hidden_size: int,
39
+ moe_num_experts: int = 1,
40
+ moe_capacity_factor: int = 1,
41
+ moe_top_k: int = 1,
42
+ moe_zloss_weight: float = 0,
43
+ ):
44
+ # All tests are skipped if triton >=3.2.0 is installed since sparse is not supported
45
+ # TODO: Remove this once sparse is supported with triton >=3.2.0
46
+ try:
47
+ import triton
48
+ if triton.__version__ >= '3.2.0':
49
+ pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
50
+ except ImportError:
51
+ pass
52
+
53
+ init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
54
+ args = Arguments(
55
+ hidden_size=hidden_size,
56
+ ffn_hidden_size=ffn_hidden_size,
57
+ moe_num_experts=moe_num_experts,
58
+ moe_capacity_factor=moe_capacity_factor,
59
+ moe_top_k=moe_top_k,
60
+ init_method=init_method,
61
+ moe_zloss_weight=moe_zloss_weight,
62
+ )
63
+
64
+ mlp = FFN(args)
65
+ moe_mlp = MoE(args)
66
+
67
+ mlp.cuda(torch.cuda.current_device()).half()
68
+ moe_mlp.cuda(torch.cuda.current_device()).half()
69
+
70
+ # Set the baseline parameters to match exactly.
71
+ if moe_num_experts == 1:
72
+ with torch.no_grad():
73
+ mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze())
74
+ mlp.w2.copy_(moe_mlp.experts.mlp.w2.squeeze())
75
+ return args, mlp, moe_mlp
76
+
77
+
78
+ @pytest.mark.gpu
79
+ @pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
80
+ def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int):
81
+ x = torch.randn(sl, bs, hs).half().cuda()
82
+
83
+ _, _, layer = construct_moe(
84
+ hidden_size=hs,
85
+ ffn_hidden_size=hs * 2,
86
+ moe_num_experts=num_experts,
87
+ moe_top_k=top_k,
88
+ )
89
+
90
+ out, _ = layer(x)
91
+ assert out.shape == x.shape
92
+ clear_load_balancing_loss()
93
+
94
+
95
+ @pytest.mark.gpu
96
+ @pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
97
+ def test_moe_forward_backward(
98
+ bs: int,
99
+ sl: int,
100
+ hs: int,
101
+ num_experts: int,
102
+ top_k: int,
103
+ ):
104
+ x = torch.randn(sl, bs, hs).half().cuda()
105
+ x.requires_grad_(True)
106
+
107
+ args, _, layer = construct_moe(
108
+ hidden_size=hs,
109
+ ffn_hidden_size=hs * 2,
110
+ moe_num_experts=num_experts,
111
+ moe_top_k=top_k,
112
+ )
113
+
114
+ out, _ = layer(x)
115
+ assert out.shape == x.shape
116
+
117
+ loss = out.sum() + batched_load_balancing_loss(args)
118
+ loss.backward()
119
+ layer.zero_grad(set_to_none=True)
120
+ x.grad = None
121
+ clear_load_balancing_loss()
122
+
123
+
124
+ @pytest.mark.gpu
125
+ @pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
126
+ def test_moe_forward_backward_with_zloss(
127
+ bs: int,
128
+ sl: int,
129
+ hs: int,
130
+ num_experts: int,
131
+ top_k: int,
132
+ ):
133
+ x = torch.randn(sl, bs, hs).half().cuda()
134
+ x.requires_grad_(True)
135
+
136
+ args, _, layer = construct_moe(
137
+ hidden_size=hs,
138
+ ffn_hidden_size=hs * 2,
139
+ moe_num_experts=num_experts,
140
+ moe_top_k=top_k,
141
+ moe_zloss_weight=1e-3,
142
+ )
143
+
144
+ out, _ = layer(x)
145
+ assert out.shape == x.shape
146
+
147
+ loss = out.sum() + batched_load_balancing_loss(args)
148
+ loss.backward()
149
+ layer.zero_grad(set_to_none=True)
150
+ x.grad = None
151
+ clear_load_balancing_loss()
152
+ clear_router_zloss()
153
+
154
+
155
+ @pytest.mark.gpu
156
+ @pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
157
+ def test_moe_forward_vs_dense(bs: int, sl: int, hs: int):
158
+ x = torch.randn(sl, bs, hs).half().cuda()
159
+
160
+ _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)
161
+
162
+ expected_out = mlp(x)
163
+ out, _ = moe_mlp(x)
164
+ assert out.shape == x.shape == expected_out.shape
165
+ assert torch.allclose(out, expected_out)
166
+ clear_load_balancing_loss()
167
+
168
+
169
+ @pytest.mark.gpu
170
+ @pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
171
+ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int):
172
+ x = torch.randn(sl, bs, hs).half().cuda()
173
+ x.requires_grad_(True)
174
+
175
+ _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)
176
+
177
+ out, _ = moe_mlp(x)
178
+ loss = out.sum()
179
+ loss.backward()
180
+ w1_grad = moe_mlp.experts.mlp.w1.grad.detach().squeeze()
181
+ w2_grad = moe_mlp.experts.mlp.w2.grad.detach().squeeze()
182
+ moe_mlp.zero_grad(set_to_none=True)
183
+ x.grad = None
184
+ clear_load_balancing_loss()
185
+
186
+ expected_out = mlp(x)
187
+ expected_loss = expected_out.sum()
188
+ expected_loss.backward()
189
+ expected_w1_grad = mlp.w1.grad.detach()
190
+ expected_w2_grad = mlp.w2.grad.detach()
191
+ mlp.zero_grad(set_to_none=True)
192
+ x.grad = None
193
+
194
+ # Verify the gradients match.
195
+ assert w1_grad.shape == expected_w1_grad.shape
196
+ assert w2_grad.shape == expected_w2_grad.shape
197
+ assert torch.allclose(w1_grad, expected_w1_grad)
198
+ assert torch.allclose(w2_grad, expected_w2_grad)
199
+ clear_load_balancing_loss()
tests/ops/binned_gather_test.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ from megablocks import ops
9
+
10
+ BINNED_GATHER_TESTS = (
11
+ (4, 2, 2, 1),
12
+ (4, 2, 2, 2),
13
+ (4, 2, 2, 4),
14
+ (1024, 1536, 4, 1),
15
+ (1024, 1536, 4, 2),
16
+ (1024, 1536, 4, 4),
17
+ (1024, 1536, 64, 1),
18
+ (1024, 1536, 64, 2),
19
+ (1024, 1536, 64, 4),
20
+ (1024, 1536, 128, 1),
21
+ (1024, 1536, 128, 2),
22
+ (1024, 1536, 128, 4),
23
+ (16384, 768, 4, 1),
24
+ (16384, 768, 4, 2),
25
+ (16384, 768, 4, 4),
26
+ (16384, 768, 64, 1),
27
+ (16384, 768, 64, 2),
28
+ (16384, 768, 64, 4),
29
+ (16384, 768, 128, 1),
30
+ (16384, 768, 128, 2),
31
+ (16384, 768, 128, 4),
32
+ )
33
+
34
+
35
+ @pytest.mark.gpu
36
+ @pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), BINNED_GATHER_TESTS)
37
+ def test_binned_gather(sl: int, hs: int, ne: int, top_k: int):
38
+ # NOTE: Capacity factor == 1.
39
+ ec = (sl * top_k) // ne
40
+
41
+ # Create the data and indices.
42
+ x = torch.randn((sl, hs)).cuda().half()
43
+
44
+ # Randomly assign tokens to experts.
45
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
46
+ _, indices = ops.sort(top_expert)
47
+ bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)
48
+
49
+ def binned_gather(
50
+ x: torch.Tensor,
51
+ indices: torch.Tensor,
52
+ bins: torch.Tensor,
53
+ ec: int,
54
+ top_k: int,
55
+ ):
56
+ x = x.cpu().numpy()
57
+ indices = indices.cpu().numpy()
58
+ bins = bins.cpu().numpy()
59
+ start = 0
60
+ out = np.zeros((ne, ec, hs))
61
+ for i in range(ne):
62
+ end = bins[i]
63
+ for j in range(min(ec, end - start)):
64
+ index = indices[start + j] // top_k
65
+ out[i, j, :] = x[index, :]
66
+ start = end
67
+ return torch.from_numpy(out).cuda().half()
68
+
69
+ out = ops.binned_gather(x, indices, bins, ec, top_k)
70
+ expected_out = binned_gather(x, indices, bins, ec, top_k)
71
+ assert torch.all(torch.eq(out, expected_out))
tests/ops/binned_scatter_test.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ from megablocks import ops
9
+
10
+ _BINNED_SCATTER_TESTS = (
11
+ (4, 2, 2, 1),
12
+ (4, 2, 2, 2),
13
+ (4, 2, 2, 4),
14
+ (1024, 1536, 4, 1),
15
+ (1024, 1536, 4, 2),
16
+ (1024, 1536, 4, 4),
17
+ (1024, 1536, 64, 1),
18
+ (1024, 1536, 64, 2),
19
+ (1024, 1536, 64, 4),
20
+ (1024, 1536, 128, 1),
21
+ (1024, 1536, 128, 2),
22
+ (1024, 1536, 128, 4),
23
+ (16384, 768, 4, 1),
24
+ (16384, 768, 4, 2),
25
+ (16384, 768, 4, 4),
26
+ (16384, 768, 64, 1),
27
+ (16384, 768, 64, 2),
28
+ (16384, 768, 64, 4),
29
+ (16384, 768, 128, 1),
30
+ (16384, 768, 128, 2),
31
+ (16384, 768, 128, 4),
32
+ )
33
+
34
+
35
+ @pytest.mark.gpu
36
+ @pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), _BINNED_SCATTER_TESTS)
37
+ def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int):
38
+ # NOTE: Capacity factor == 1.
39
+ ec = (sl * top_k) // ne
40
+
41
+ # Create the data and indices.
42
+ x = torch.randn((sl, hs)).cuda().half()
43
+
44
+ # Randomly assign tokens to experts.
45
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
46
+ _, indices = ops.sort(top_expert)
47
+ bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)
48
+
49
+ # Sample weights for the scatter reduce.
50
+ weights = torch.rand((sl * top_k,)).cuda().half()
51
+
52
+ x = ops.binned_gather(x, indices, bins, ec, top_k)
53
+
54
+ def binned_scatter(
55
+ x: torch.Tensor,
56
+ indices: torch.Tensor,
57
+ weights: torch.Tensor,
58
+ bins: torch.Tensor,
59
+ top_k: int,
60
+ ):
61
+ x = x.cpu().numpy()
62
+ indices = indices.cpu().numpy()
63
+ weights = weights.cpu().numpy()
64
+ bins = bins.cpu().numpy()
65
+ start = 0
66
+ out = np.zeros((sl, hs))
67
+ for i in range(ne):
68
+ end = bins[i]
69
+ for j in range(min(ec, end - start)):
70
+ index = indices[start + j]
71
+ scale = weights[index]
72
+ index //= top_k
73
+
74
+ out[index, :] += scale * x[i, j, :]
75
+ start = end
76
+ return torch.from_numpy(out).cuda().half()
77
+
78
+ out = ops.binned_scatter(x, indices, weights, bins, top_k)
79
+ expected_out = binned_scatter(x, indices, weights, bins, top_k)
80
+
81
+ # NOTE: We need to check approximate equality because the
82
+ # scatter reduce uses atomics.
83
+ assert np.testing.assert_allclose(
84
+ out.cpu(),
85
+ expected_out.cpu(),
86
+ rtol=5e-3,
87
+ ) is None
tests/ops/cumsum_test.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pytest
5
+ import torch
6
+
7
+ from megablocks import ops
8
+
9
+ CUMSUM_TESTS = (
10
+ (1, 32),
11
+ (2, 32),
12
+ (2, 1024),
13
+ (4, 1024),
14
+ (8, 1024),
15
+ (16, 1024),
16
+ (32, 1024),
17
+ (64, 1024),
18
+ (128, 1024),
19
+ (2, 16384),
20
+ (4, 16384),
21
+ (8, 16384),
22
+ (16, 16384),
23
+ (32, 16384),
24
+ (64, 16384),
25
+ (128, 16384),
26
+ )
27
+
28
+
29
+ @pytest.mark.gpu
30
+ @pytest.mark.parametrize(('n', 'm'), CUMSUM_TESTS)
31
+ def test_exclusive_cumsum(n: int, m: int):
32
+ x = torch.randint(0, 2, (n, m)).long().cuda()
33
+ out = ops.exclusive_cumsum(x, 1) * x
34
+ expected_out = (torch.cumsum(x, dim=1) - 1) * x
35
+ assert torch.all(torch.eq(out, expected_out))
36
+
37
+
38
+ @pytest.mark.gpu
39
+ @pytest.mark.parametrize(('n', 'm'), CUMSUM_TESTS)
40
+ def test_inclusive_cumsum(n: int, m: int):
41
+ x = torch.randint(0, 2, (n, m)).long().cuda()
42
+ out = ops.inclusive_cumsum(x, 1)
43
+ expected_out = torch.cumsum(x, dim=1)
44
+ assert torch.all(torch.eq(out, expected_out))
tests/ops/histogram_test.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import pytest
5
+ import torch
6
+
7
+ from megablocks import ops
8
+
9
+ _HISTOGRAM_TESTS = (
10
+ (1, 32, torch.int16, 128),
11
+ (1, 1024, torch.int16, 128),
12
+ (1, 16384, torch.int16, 128),
13
+ (1, 32, torch.int32, 128),
14
+ (1, 1024, torch.int32, 128),
15
+ (1, 16384, torch.int32, 128),
16
+ (1, 32, torch.int64, 128),
17
+ (1, 1024, torch.int64, 128),
18
+ (1, 16384, torch.int64, 128),
19
+ (1, 32, torch.int16, 1024),
20
+ (1, 1024, torch.int16, 1024),
21
+ (1, 16384, torch.int16, 1024),
22
+ (1, 32, torch.int32, 1024),
23
+ (1, 1024, torch.int32, 1024),
24
+ (1, 16384, torch.int32, 1024),
25
+ (1, 32, torch.int64, 1024),
26
+ (1, 1024, torch.int64, 1024),
27
+ (1, 16384, torch.int64, 1024),
28
+ (2, 32, torch.int16, 128),
29
+ (2, 1024, torch.int16, 128),
30
+ (2, 16384, torch.int16, 128),
31
+ (2, 32, torch.int32, 128),
32
+ (2, 1024, torch.int32, 128),
33
+ (2, 16384, torch.int32, 128),
34
+ (2, 32, torch.int64, 128),
35
+ (2, 1024, torch.int64, 128),
36
+ (2, 16384, torch.int64, 128),
37
+ (2, 32, torch.int16, 1024),
38
+ (2, 1024, torch.int16, 1024),
39
+ (2, 16384, torch.int16, 1024),
40
+ (2, 32, torch.int32, 1024),
41
+ (2, 1024, torch.int32, 1024),
42
+ (2, 16384, torch.int32, 1024),
43
+ (2, 32, torch.int64, 1024),
44
+ (2, 1024, torch.int64, 1024),
45
+ (2, 16384, torch.int64, 1024),
46
+ (8, 32, torch.int16, 128),
47
+ (8, 1024, torch.int16, 128),
48
+ (8, 16384, torch.int16, 128),
49
+ (8, 32, torch.int32, 128),
50
+ (8, 1024, torch.int32, 128),
51
+ (8, 16384, torch.int32, 128),
52
+ (8, 32, torch.int64, 128),
53
+ (8, 1024, torch.int64, 128),
54
+ (8, 16384, torch.int64, 128),
55
+ (8, 32, torch.int16, 1024),
56
+ (8, 1024, torch.int16, 1024),
57
+ (8, 16384, torch.int16, 1024),
58
+ (8, 32, torch.int32, 1024),
59
+ (8, 1024, torch.int32, 1024),
60
+ (8, 16384, torch.int32, 1024),
61
+ (8, 32, torch.int64, 1024),
62
+ (8, 1024, torch.int64, 1024),
63
+ (8, 16384, torch.int64, 1024),
64
+ )
65
+
66
+
67
+ # Override the seed_all fixture in autouse.py because
68
+ # _histc_cuda does not have a deterministic implementation
69
+ @pytest.fixture()
70
+ def seed_all():
71
+ torch.use_deterministic_algorithms(False)
72
+ return
73
+
74
+
75
+ @pytest.mark.gpu
76
+ @pytest.mark.parametrize(('m', 'n', 'dtype', 'max_val'), _HISTOGRAM_TESTS)
77
+ def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int):
78
+ x = torch.randint(0, max_val, (m, n)).cuda().to(dtype)
79
+
80
+ out = ops.histogram(x, max_val)
81
+ expected_out = torch.stack([torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)])
82
+ assert torch.all(torch.eq(out, expected_out))
tests/ops/padded_gather_test.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ from megablocks import ops
9
+
10
+ PADDED_GATHER_TESTS = (
11
+ (4, 2, 2, 1),
12
+ (4, 2, 2, 2),
13
+ (1024, 1, 4, 1),
14
+ (1024, 1, 4, 2),
15
+ (1024, 1, 4, 4),
16
+ (1024, 1, 64, 1),
17
+ (1024, 1, 64, 2),
18
+ (1024, 1, 64, 4),
19
+ (1024, 1, 128, 1),
20
+ (1024, 1, 128, 2),
21
+ (1024, 1, 128, 4),
22
+ (1024, 1536, 4, 1),
23
+ (1024, 1536, 4, 2),
24
+ (1024, 1536, 4, 4),
25
+ (1024, 1536, 64, 1),
26
+ (1024, 1536, 64, 2),
27
+ (1024, 1536, 64, 4),
28
+ (1024, 1536, 128, 1),
29
+ (1024, 1536, 128, 2),
30
+ (1024, 1536, 128, 4),
31
+ (16384, 768, 4, 1),
32
+ (16384, 768, 4, 2),
33
+ (16384, 768, 4, 4),
34
+ (16384, 768, 64, 1),
35
+ (16384, 768, 64, 2),
36
+ (16384, 768, 64, 4),
37
+ (16384, 768, 128, 1),
38
+ (16384, 768, 128, 2),
39
+ (16384, 768, 128, 4),
40
+ (16384, 1, 4, 1),
41
+ (16384, 1, 4, 2),
42
+ (16384, 1, 4, 4),
43
+ (16384, 1, 64, 1),
44
+ (16384, 1, 64, 2),
45
+ (16384, 1, 64, 4),
46
+ (16384, 1, 128, 1),
47
+ (16384, 1, 128, 2),
48
+ (16384, 1, 128, 4),
49
+ )
50
+
51
+
52
+ @pytest.mark.gpu
53
+ @pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), PADDED_GATHER_TESTS)
54
+ def testPaddedGather(sl: int, hs: int, ne: int, top_k: int):
55
+ # Create the data and indices.
56
+ x = torch.randn((sl, hs)).cuda().half()
57
+
58
+ # Randomly assign tokens to experts.
59
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
60
+ bin_ids, indices = ops.sort(top_expert)
61
+ tokens_per_expert = ops.histogram(top_expert, ne)
62
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
63
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
64
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
65
+
66
+ def padded_gather(
67
+ x: torch.Tensor,
68
+ indices: torch.Tensor,
69
+ bin_ids: torch.Tensor,
70
+ bins: torch.Tensor,
71
+ padded_bins: torch.Tensor,
72
+ top_k: int,
73
+ ):
74
+ x = x.cpu().numpy()
75
+ indices = indices.cpu().numpy()
76
+ bin_ids = bin_ids.cpu().numpy()
77
+ bins = bins.cpu().numpy()
78
+ padded_bins = padded_bins.cpu().numpy()
79
+
80
+ out = np.zeros((padded_bins[-1], hs))
81
+ in_idx = 0
82
+ for i, end in enumerate(bins):
83
+ out_idx = 0 if i == 0 else padded_bins[i - 1]
84
+ end = bins[i]
85
+ while in_idx < end:
86
+ load_idx = indices[in_idx] // top_k
87
+ out[out_idx, :] = x[load_idx, :]
88
+ in_idx += 1
89
+ out_idx += 1
90
+ return torch.from_numpy(out).cuda().half()
91
+
92
+ out = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
93
+ expected_out = padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
94
+ assert torch.all(torch.eq(out, expected_out))
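
Note: the bin bookkeeping above (histogram, round_up, and the two inclusive cumsums) can be traced by hand. The following is a minimal CPU-only sketch that mirrors the reference logic in the test with plain torch, without calling the megablocks CUDA kernels; the toy sizes and blocking factor are illustrative only.

    import torch

    tokens_per_expert = torch.tensor([3, 1, 6])    # histogram of expert assignments
    blocking = 4                                   # pad each expert's bin to a multiple of 4
    padded = (tokens_per_expert + blocking - 1) // blocking * blocking  # round_up -> [4, 4, 8]
    bins = torch.cumsum(tokens_per_expert, 0)      # inclusive cumsum -> [3, 4, 10]
    padded_bins = torch.cumsum(padded, 0)          # inclusive cumsum -> [4, 8, 16]
    # padded_gather writes expert i's tokens starting at padded_bins[i - 1] (0 for i == 0),
    # so the gathered output has padded_bins[-1] == 16 rows, 6 of which are zero padding.
    assert bins.tolist() == [3, 4, 10] and padded_bins.tolist() == [4, 8, 16]
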
tests/ops/padded_scatter_test.py ADDED
@@ -0,0 +1,155 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ from megablocks import ops
9
+
10
+ PADDED_SCATTER_TESTS = [
11
+ (4, 2, 2, 2),
12
+ (4, 2, 2, 1),
13
+ (4, 2, 2, 1),
14
+ (4, 2, 2, 1),
15
+ (4, 2, 2, 2),
16
+ (4, 2, 2, 2),
17
+ (1024, 1, 4, 1),
18
+ (1024, 1, 4, 2),
19
+ (1024, 1, 4, 4),
20
+ (1024, 1, 4, 1),
21
+ (1024, 1, 4, 2),
22
+ (1024, 1, 4, 4),
23
+ (1024, 1, 4, 1),
24
+ (1024, 1, 4, 2),
25
+ (1024, 1, 4, 4),
26
+ (1024, 1, 64, 1),
27
+ (1024, 1, 64, 2),
28
+ (1024, 1, 64, 4),
29
+ (1024, 1, 128, 1),
30
+ (1024, 1, 128, 2),
31
+ (1024, 1, 128, 4),
32
+ (1024, 1536, 4, 1),
33
+ (1024, 1536, 4, 2),
34
+ (1024, 1536, 4, 4),
35
+ (1024, 1536, 4, 4),
36
+ (1024, 1536, 4, 4),
37
+ (1024, 1536, 64, 1),
38
+ (1024, 1536, 64, 2),
39
+ (1024, 1536, 64, 4),
40
+ (1024, 1536, 128, 1),
41
+ (1024, 1536, 128, 2),
42
+ (1024, 1536, 128, 4),
43
+ (1024, 1536, 128, 1),
44
+ (1024, 1536, 128, 1),
45
+ (16384, 768, 4, 1),
46
+ (16384, 768, 4, 2),
47
+ (16384, 768, 4, 4),
48
+ (16384, 768, 64, 1),
49
+ (16384, 768, 64, 2),
50
+ (16384, 768, 64, 4),
51
+ (16384, 768, 128, 1),
52
+ (16384, 768, 128, 2),
53
+ (16384, 768, 128, 4),
54
+ (16384, 1, 4, 1),
55
+ (16384, 1, 4, 2),
56
+ (16384, 1, 4, 4),
57
+ (16384, 1, 64, 1),
58
+ (16384, 1, 64, 2),
59
+ (16384, 1, 64, 4),
60
+ (16384, 1, 128, 1),
61
+ (16384, 1, 128, 2),
62
+ (16384, 1, 128, 4),
63
+ (16384, 1, 128, 2),
64
+ (16384, 1, 128, 2),
65
+ ]
66
+
67
+
68
+ def _to_numpy(x: torch.Tensor) -> np.ndarray:
69
+ return x.detach().cpu().numpy()
70
+
71
+
72
+ @pytest.mark.gpu
73
+ @pytest.mark.parametrize((
74
+ 'sl',
75
+ 'hs',
76
+ 'ne',
77
+ 'top_k',
78
+ ), PADDED_SCATTER_TESTS)
79
+ def testPaddedScatter(sl: int, hs: int, ne: int, top_k: int):
80
+ # Create the data and indices.
81
+ x = torch.randn((sl, hs), requires_grad=True).cuda().half()
82
+
83
+ # Randomly assign tokens to experts.
84
+ top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
85
+ bin_ids, indices = ops.sort(top_expert)
86
+ tokens_per_expert = ops.histogram(top_expert, ne)
87
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
88
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
89
+ bins = ops.inclusive_cumsum(tokens_per_expert, 0)
90
+
91
+ # Sample weights for the scatter reduce.
92
+ weights = torch.rand((sl * top_k,), requires_grad=True).cuda().half()
93
+
94
+ # Gather the data to prepare for backwards.
95
+ x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
96
+
97
+ def padded_scatter(
98
+ x: torch.Tensor,
99
+ indices: torch.Tensor,
100
+ bin_ids: torch.Tensor,
101
+ weights: torch.Tensor,
102
+ bins: torch.Tensor,
103
+ padded_bins: torch.Tensor,
104
+ top_k: int,
105
+ ):
106
+ x = x.detach().cpu().numpy()
107
+ indices: np.ndarray = _to_numpy(indices)
108
+ bin_ids: np.ndarray = _to_numpy(bin_ids)
109
+ weights: np.ndarray = _to_numpy(weights)
110
+ bins: np.ndarray = _to_numpy(bins)
111
+ padded_bins: np.ndarray = _to_numpy(padded_bins)
112
+
113
+ out = np.zeros((indices.shape[0] // top_k, hs))
114
+ out_idx = 0
115
+ for i in range(len(bins)):
116
+ in_idx = 0 if i == 0 else padded_bins[i - 1]
117
+ end = bins[i]
118
+ while out_idx < end:
119
+ store_idx = indices[out_idx]
120
+ scale = weights[store_idx]
121
+ store_idx //= top_k
122
+
123
+ out[store_idx, :] += scale * x[in_idx, :]
124
+ out_idx += 1
125
+ in_idx += 1
126
+ return torch.from_numpy(out).cuda().half()
127
+
128
+ out = ops.padded_scatter(
129
+ x,
130
+ indices,
131
+ bin_ids,
132
+ weights,
133
+ bins,
134
+ padded_bins,
135
+ top_k,
136
+ )
137
+ expected_out = padded_scatter(
138
+ x,
139
+ indices,
140
+ bin_ids,
141
+ weights,
142
+ bins,
143
+ padded_bins,
144
+ top_k,
145
+ )
146
+
147
+ out.backward(torch.randn_like(out)) # sanity check backward pass
148
+
149
+ # NOTE: We need to check approximate equality because the scatter reduce uses atomics.
150
+ # np.testing.assert_allclose returns `None` if no error and raises an AssertionError if an error exists
151
+ assert np.testing.assert_allclose(
152
+ _to_numpy(out),
153
+ _to_numpy(expected_out),
154
+ rtol=5e-3,
155
+ ) is None
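
Note: a property implied by the two reference implementations above is that, for top_k == 1 and unit weights, padded_scatter undoes padded_gather exactly (each output row is written once, so the atomic adds introduce no reduction error). A minimal sketch of that round trip, assuming a CUDA device and the compiled megablocks ops used elsewhere in these tests:

    import torch
    from megablocks import ops

    sl, hs, ne = 1024, 16, 4
    x = torch.randn(sl, hs).cuda().half()
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_bins = ops.inclusive_cumsum(ops.round_up(tokens_per_expert, 128), 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    weights = torch.ones(sl).cuda().half()
    gathered = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, 1)
    restored = ops.padded_scatter(gathered, indices, bin_ids, weights, bins, padded_bins, 1)
    assert torch.all(torch.eq(restored, x))
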
tests/ops/replicate_test.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ try:
9
+ from megablocks._ops import ops as backend # type: ignore
10
+ except ModuleNotFoundError as e:
11
+ raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
12
+
13
+ from megablocks import ops
14
+
15
+
16
+ def promote_scalar(x: torch.Tensor) -> torch.Tensor:
17
+ return x.view(1) if not len(x.size()) else x
18
+
19
+
20
+ REPLICATE_TESTS = [
21
+ (8, 1, 1),
22
+ (8, 2, 1),
23
+ (8, 4, 1),
24
+ (8, 8, 1),
25
+ (8, 2, 2),
26
+ (8, 4, 2),
27
+ (8, 8, 2),
28
+ (8, 2, 4),
29
+ (8, 4, 4),
30
+ (8, 8, 4),
31
+ (8, 2, 8),
32
+ (8, 4, 8),
33
+ (8, 8, 8),
34
+ (16384, 2, 1),
35
+ (16384, 4, 1),
36
+ (16384, 8, 1),
37
+ (16384, 16, 1),
38
+ (16384, 32, 1),
39
+ (16384, 64, 1),
40
+ (16384, 128, 1),
41
+ (16384, 2, 2),
42
+ (16384, 4, 2),
43
+ (16384, 8, 2),
44
+ (16384, 16, 2),
45
+ (16384, 32, 2),
46
+ (16384, 64, 2),
47
+ (16384, 128, 2),
48
+ (16384, 2, 4),
49
+ (16384, 4, 4),
50
+ (16384, 8, 4),
51
+ (16384, 16, 4),
52
+ (16384, 32, 4),
53
+ (16384, 64, 4),
54
+ (16384, 128, 4),
55
+ (16384, 2, 8),
56
+ (16384, 4, 8),
57
+ (16384, 8, 8),
58
+ (16384, 16, 8),
59
+ (16384, 32, 8),
60
+ (16384, 64, 8),
61
+ (16384, 128, 8),
62
+ ]
63
+
64
+
65
+ @pytest.mark.gpu
66
+ @pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
67
+ def test_replicate(tokens: int, num_centers: int, top_k: int):
68
+ tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
69
+ tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
70
+ bins = ops.inclusive_cumsum(tokens_per_center, 0)
71
+ bins = promote_scalar(bins)
72
+ center_weights = torch.randn(top_k, num_centers).cuda().half()
73
+
74
+ def replicate(x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
75
+ x = x.cpu().numpy()
76
+ bins = bins.cpu().numpy()
77
+ out = np.zeros((x.shape[0], num_outputs))
78
+ for batch_idx in range(x.shape[0]):
79
+ start = 0
80
+ for i, end in enumerate(bins):
81
+ value = x[batch_idx, i]
82
+ while start < end:
83
+ out[batch_idx, start] = value
84
+ start += 1
85
+ return torch.from_numpy(out).cuda().half()
86
+
87
+ out = ops.replicate(center_weights, bins, tokens)
88
+ expected_out = replicate(center_weights, bins, tokens)
89
+ assert torch.all(torch.eq(out, expected_out))
90
+
91
+
92
+ @pytest.mark.gpu
93
+ @pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
94
+ def test_replicate_backward(tokens: int, num_centers: int, top_k: int):
95
+ tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
96
+ tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
97
+ bins = ops.inclusive_cumsum(tokens_per_center, 0)
98
+ bins = promote_scalar(bins)
99
+ center_weights = torch.randn(top_k, num_centers).cuda().half()
100
+
101
+ grad = ops.replicate(center_weights, bins, tokens)
102
+
103
+ out = torch.empty_like(center_weights)
104
+ backend.replicate_backward(grad, bins, out)
105
+ expected_out = center_weights * tokens_per_center.view([1, num_centers])
106
+
107
+ # NOTE: This floating-point reduction could be a problem for training stability and accuracy.
108
+ assert torch.allclose(out, expected_out, rtol=1e-2)
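
Note: the backward assertion above relies on replicate_backward reducing the incoming gradient over each bin, so replicating and then reducing scales each center's weight by its bin width (tokens_per_center). A small CPU-only illustration of that identity, using plain torch rather than the kernel:

    import torch

    center_weights = torch.tensor([[2.0, -1.0]])       # (top_k=1, num_centers=2)
    tokens_per_center = torch.tensor([3, 5])
    bins = torch.cumsum(tokens_per_center, 0)           # [3, 8]

    # Forward: column i of center_weights is copied across bin i.
    grad = torch.cat([
        center_weights[:, 0].repeat(1, 3),
        center_weights[:, 1].repeat(1, 5),
    ], dim=1)                                            # shape (1, 8)

    # Backward: sum the gradient over each bin.
    reduced = torch.stack([
        grad[:, :bins[0]].sum(dim=1),
        grad[:, bins[0]:bins[1]].sum(dim=1),
    ], dim=1)
    assert torch.equal(reduced, center_weights * tokens_per_center)  # [[6.0, -5.0]]
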
tests/ops/sort_test.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Dict, Optional, Union
5
+
6
+ import numpy as np
7
+ import pytest
8
+ import torch
9
+
10
+ from megablocks import ops
11
+
12
+ SORT_TESTS = [
13
+ (32, torch.int16, None),
14
+ (1024, torch.int16, None),
15
+ (16384, torch.int16, None),
16
+ (32, torch.int32, None),
17
+ (1024, torch.int32, None),
18
+ (16384, torch.int32, None),
19
+ (32, torch.int64, None),
20
+ (1024, torch.int64, None),
21
+ (16384, torch.int64, None),
22
+ (32, torch.int16, 128),
23
+ (1024, torch.int16, 128),
24
+ (16384, torch.int16, 128),
25
+ (32, torch.int32, 128),
26
+ (1024, torch.int32, 128),
27
+ (16384, torch.int32, 128),
28
+ (32, torch.int64, 128),
29
+ (1024, torch.int64, 128),
30
+ (16384, torch.int64, 128),
31
+ ]
32
+
33
+
34
+ def torch_to_numpy_dtype(dtype: torch.dtype,) -> Union[np.int16, np.int32, np.int64]:
35
+ types: Dict[torch.dtype, Union[np.int16, np.int32, np.int64]] = {
36
+ torch.int16: np.int16,
37
+ torch.int32: np.int32,
38
+ torch.int64: np.int64,
39
+ }
40
+ return types[dtype]
41
+
42
+
43
+ @pytest.mark.gpu
44
+ @pytest.mark.parametrize(
45
+ ('n', 'dtype', 'max_val'),
46
+ SORT_TESTS,
47
+ )
48
+ def test_sort(n: int, dtype: torch.dtype, max_val: Optional[int]):
49
+ if max_val is None:
50
+ max_val = np.iinfo(torch_to_numpy_dtype(dtype)).max
51
+ end_bit = int(np.ceil(np.log2(max_val)))
52
+ x = torch.randint(0, max_val, (n,)).cuda().to(dtype)
53
+
54
+ out, indices = ops.sort(x, end_bit)
55
+ expected_out, expected_indices = torch.sort(x)
56
+ assert torch.all(torch.eq(out, expected_out))
57
+
58
+ # NOTE: The indices can be in different order depending
59
+ # on sort stability if multiple values in the array are
60
+ # equal.
61
+ data = torch.empty_like(x)
62
+ data.scatter_(0, indices.long(), out)
63
+ expected_data = torch.empty_like(x)
64
+ expected_data.scatter_(0, expected_indices, expected_out)
65
+ assert torch.all(torch.eq(data, expected_data))
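
Note: end_bit bounds how many low-order bits the radix sort inspects, which is why the test derives it from max_val; keys are drawn from [0, max_val), so ceil(log2(max_val)) bits are sufficient. Mirroring that computation:

    import numpy as np

    max_val = 128                              # keys drawn from [0, 128)
    end_bit = int(np.ceil(np.log2(max_val)))   # 7 bits cover values 0..127
    assert end_bit == 7
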
tests/ops/topology_test.py ADDED
@@ -0,0 +1,81 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import numpy as np
5
+ import pytest
6
+ import torch
7
+
8
+ from megablocks import ops
9
+
10
+ TOPOLOGY_TESTS = (
11
+ (1024, 1536, 2),
12
+ (1024, 1536, 4),
13
+ (1024, 1536, 8),
14
+ (1024, 1536, 16),
15
+ (1024, 1536, 32),
16
+ (1024, 1536, 64),
17
+ (1024, 1536, 128),
18
+ (1024, 1536, 256),
19
+ (1024, 1536, 512),
20
+ (16384, 768, 2),
21
+ (16384, 768, 4),
22
+ (16384, 768, 8),
23
+ (16384, 768, 16),
24
+ (16384, 768, 32),
25
+ (16384, 768, 64),
26
+ (16384, 768, 128),
27
+ (16384, 768, 256),
28
+ (16384, 768, 512),
29
+ (16384, 768, 1024),
30
+ (8, 14336, 8),
31
+ )
32
+
33
+
34
+ @pytest.mark.gpu
35
+ @pytest.mark.parametrize(('sl', 'hs', 'ne'), TOPOLOGY_TESTS)
36
+ def test_topology(sl: int, hs: int, ne: int):
37
+ # Create the data and indices.
38
+ blocking = 128
39
+ assert hs % blocking == 0
40
+
41
+ # Randomly assign tokens to experts.
42
+ top_expert = torch.randint(0, ne, (sl,)).cuda().int()
43
+ tokens_per_expert = ops.histogram(top_expert, ne)
44
+ padded_tokens_per_expert = ops.round_up(tokens_per_expert, blocking)
45
+ padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
46
+
47
+ # Dimensions for the output indices.
48
+ output_block_rows = int(padded_bins[-1]) // blocking
49
+ output_block_columns = hs // blocking
50
+
51
+ def topology(
52
+ padded_bins: torch.Tensor,
53
+ blocking: torch.Tensor,
54
+ rows: int,
55
+ columns: int,
56
+ ):
57
+ padded_bins = padded_bins.cpu().numpy()
58
+
59
+ out = np.zeros([rows * columns])
60
+ start = 0
61
+ for i in range(padded_bins.shape[0]):
62
+ end = padded_bins[i] // blocking
63
+ while start < end:
64
+ for j in range(columns):
65
+ out[start * columns + j] = j + i * columns
66
+ start += 1
67
+ return torch.from_numpy(out).cuda().short()
68
+
69
+ out = ops.topology(
70
+ padded_bins,
71
+ blocking,
72
+ output_block_rows,
73
+ output_block_columns,
74
+ )
75
+ expected_out = topology(
76
+ padded_bins,
77
+ blocking,
78
+ output_block_rows,
79
+ output_block_columns,
80
+ )
81
+ assert torch.all(torch.eq(out, expected_out))
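
Note: the reference above builds the column indices of the block-sparse topology, one run of block rows per expert; block rows owned by expert i store the column indices i*columns .. (i+1)*columns - 1. A small CPU-only walk-through with two experts and two block columns:

    import torch

    blocking = 128
    hs = 256                                   # 2 block columns
    padded_bins = torch.tensor([256, 512])     # 2 block rows per expert
    columns = hs // blocking

    out, start = [], 0
    for i, end in enumerate((padded_bins // blocking).tolist()):
        for _row in range(start, end):
            out.extend(j + i * columns for j in range(columns))
        start = end
    # Expert 0 owns block rows 0-1 (columns 0,1); expert 1 owns block rows 2-3 (columns 2,3).
    assert out == [0, 1, 0, 1, 2, 3, 2, 3]
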
tests/ops_test.py ADDED
@@ -0,0 +1,171 @@
1
+ import torch
2
+ import megablocks
3
+
4
+ import unittest
5
+ from absl.testing import parameterized
6
+
7
+ # import itertools
8
+ # import numpy as np
9
+
10
+
11
+ def allclose(x, y, pct=2.0):
12
+ mask = torch.isclose(x, y, rtol=1e-5)
13
+ pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
14
+ if pct_diff > pct:
15
+ print(x[torch.logical_not(mask)], y[torch.logical_not(mask)])
16
+ print("{:.2f}% of values not close.".format(pct_diff))
17
+ return False
18
+ return True
19
+
20
+
21
+ def add_flags(x):
22
+ out = []
23
+ for y in x:
24
+ for trans_b in (False, True):
25
+ out.append(y + (trans_b, False))
26
+
27
+ # TODO: Revisit enabling batch_sizes_on_device
28
+ # for batch_sizes_on_device in (False, True):
29
+ # out.append(y + (trans_b, batch_sizes_on_device))
30
+ return out
31
+
32
+
33
+ _TEST_PROBLEMS = add_flags((
34
+ (1, 128, 128, 128),
35
+ (8, 128, 128, 128),
36
+ (16, 128, 128, 128),
37
+ (1, 128, 256, 512),
38
+ (8, 128, 256, 512),
39
+ (16, 128, 256, 512),
40
+ ))
41
+
42
+
43
+ def randn(bs, x, y):
44
+ out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
45
+ return out.cuda().to(torch.bfloat16)
46
+
47
+
48
+ def gmm(a, b, batch_sizes, trans_b=False):
49
+ batch_sizes = batch_sizes.cpu().numpy()
50
+
51
+ out = []
52
+ start = 0
53
+ for i, size in enumerate(batch_sizes):
54
+ rhs = b[i, :, :].t() if trans_b else b[i, :, :]
55
+ out.append(a[start:start + size, :] @ rhs)
56
+ start += size
57
+ return torch.cat(out)
58
+
59
+
60
+ @parameterized.parameters(*_TEST_PROBLEMS)
61
+ class OpsTest(parameterized.TestCase):
62
+
63
+ def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
64
+ torch.manual_seed(0)
65
+ a = randn(z, m, k).view(-1, k)
66
+ b = randn(z, n, k) if trans_b else randn(z, k, n)
67
+ batch_sizes = torch.tensor([m] * z)
68
+ if batch_sizes_on_device:
69
+ batch_sizes = batch_sizes.cuda()
70
+
71
+ a.requires_grad_(True)
72
+ b.requires_grad_(True)
73
+ a_ref = a.detach().clone().requires_grad_(True)
74
+ b_ref = b.detach().clone().requires_grad_(True)
75
+
76
+ # out = ops.gmm(a, b, batch_sizes, trans_b)
77
+ out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
78
+ # print("out", out)
79
+ expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
80
+ self.assertTrue(allclose(out, expected_out))
81
+
82
+ # Check gradients.
83
+ out.sum().backward()
84
+ expected_out.sum().backward()
85
+ self.assertTrue(allclose(a.grad, a_ref.grad))
86
+ self.assertTrue(allclose(b.grad, b_ref.grad))
87
+
88
+ def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
89
+ torch.manual_seed(0)
90
+ a = randn(z, m, k).view(-1, k)
91
+ b = randn(z, n, k) if trans_b else randn(z, k, n)
92
+
93
+ dist = torch.rand(z, )
94
+ dist /= dist.sum()
95
+ batch_sizes = (dist * m).to(torch.long)
96
+ error = m * z - batch_sizes.sum()
97
+ batch_sizes[-1] += error
98
+ assert batch_sizes.sum() == (m * z)
99
+ if batch_sizes_on_device:
100
+ batch_sizes = batch_sizes.cuda()
101
+
102
+ a.requires_grad_(True)
103
+ b.requires_grad_(True)
104
+ a_ref = a.detach().clone().requires_grad_(True)
105
+ b_ref = b.detach().clone().requires_grad_(True)
106
+
107
+ out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
108
+ expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
109
+ self.assertTrue(allclose(out, expected_out))
110
+
111
+ # Check gradients.
112
+ out.sum().backward()
113
+ expected_out.sum().backward()
114
+ self.assertTrue(allclose(a.grad, a_ref.grad))
115
+
116
+ # TODO: Review to ensure that the gradients are correct.
117
+ # self.assertTrue(allclose(b.grad, b_ref.grad))
118
+
119
+
120
+ # @parameterized.parameters(False, True)
121
+ @parameterized.parameters(False, False)
122
+ class EdgeCasesTest(unittest.TestCase):
123
+
124
+ def testGroupedGemm_ZeroSize(self, batch_sizes_on_device):
125
+ torch.manual_seed(0)
126
+ m = 16384
127
+ k = 4096
128
+ n = 14336
129
+ num_experts = 8
130
+
131
+ a = randn(num_experts, m // num_experts, k).view(-1, k)
132
+ b = randn(num_experts, k, n)
133
+ batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
134
+ if batch_sizes_on_device:
135
+ batch_sizes = batch_sizes.cuda()
136
+
137
+ a.requires_grad_(True)
138
+ b.requires_grad_(True)
139
+ a_ref = a.detach().clone().requires_grad_(True)
140
+ b_ref = b.detach().clone().requires_grad_(True)
141
+
142
+ out = megablocks.gg_ops.gmm(a, b, batch_sizes)
143
+ expected_out = gmm(a_ref, b_ref, batch_sizes)
144
+ self.assertTrue(allclose(out, expected_out))
145
+
146
+ # Check gradients.
147
+ out.sum().backward()
148
+ expected_out.sum().backward()
149
+ self.assertTrue(allclose(a.grad, a_ref.grad))
150
+ self.assertTrue(allclose(b.grad, b_ref.grad))
151
+
152
+ def testGroupedGemm_ZeroK(self, batch_sizes_on_device):
153
+ sz = 128
154
+ total_tokens = 192
155
+
156
+ a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
157
+ b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
158
+ c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
159
+ batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
160
+ if batch_sizes_on_device:
161
+ batch_sizes = batch_sizes.cuda()
162
+
163
+ megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
164
+ self.assertTrue((c[0] == 0).all())
165
+ self.assertTrue((c[1] == 128).all())
166
+ self.assertTrue((c[2] == 0).all())
167
+ self.assertTrue((c[3] == 64).all())
168
+
169
+
170
+ if __name__ == '__main__':
171
+ unittest.main()
tests/parallel_layer_test.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ import torch.multiprocessing as mp
4
+ import os
5
+
6
+
7
+ def test_megablocks_moe_mlp_import():
8
+ from megablocks.layers import MegaBlocksMoeMLP
9
+
10
+ assert MegaBlocksMoeMLP is not None, "MegaBlocksMoeMLP import failed."
11
+
12
+
13
+ def run_distributed_test(rank, world_size):
14
+ from megablocks.layers import MegaBlocksMoeMLP
15
+
16
+ os.environ["MASTER_ADDR"] = "localhost"
17
+ os.environ["MASTER_PORT"] = "12355"
18
+ os.environ["RANK"] = str(rank)
19
+ os.environ["WORLD_SIZE"] = str(world_size)
20
+
21
+ dist.init_process_group(
22
+ backend="gloo",
23
+ rank=rank,
24
+ world_size=world_size,
25
+ )
26
+
27
+ expert_parallel_group = torch.distributed.new_group(
28
+ range(torch.distributed.get_world_size())
29
+ )
30
+
31
+ model = MegaBlocksMoeMLP()
32
+ model.expert_parallel_group = expert_parallel_group
33
+
34
+ class Experts:
35
+ def __init__(self):
36
+ self.gate_up_proj = None
37
+ self.gate_up_proj_bias = None
38
+ self.down_proj = None
39
+ self.down_proj_bias = None
40
+ self.hidden_size = None
41
+
42
+ model.experts = Experts()
43
+
44
+ num_experts = 128
45
+ hidden_size = 1152
46
+ intermediate_size = 3072
47
+
48
+ ne, hs, isz = num_experts, hidden_size, intermediate_size
49
+
50
+ experts_per_rank = ne // world_size
51
+
52
+ device = "cuda" if torch.cuda.is_available() else "cpu"
53
+
54
+ model.router = torch.nn.Linear(hs, ne).to(device)
55
+ model.router.weight.data.fill_(1)
56
+
57
+ e = model.experts
58
+ e.gate_up_proj = torch.nn.Parameter(
59
+ torch.ones(experts_per_rank, hs, isz, device=device)
60
+ )
61
+ e.gate_up_proj_bias = torch.nn.Parameter(
62
+ torch.zeros(experts_per_rank, isz, device=device)
63
+ )
64
+ e.down_proj = torch.nn.Parameter(
65
+ torch.ones(experts_per_rank, 1536, hs, device=device)
66
+ )
67
+ e.down_proj_bias = torch.nn.Parameter(
68
+ torch.zeros(experts_per_rank, hs, device=device)
69
+ )
70
+ e.hidden_size = hs
71
+
72
+ x = torch.randn(1, 1, 1152).to(device)
73
+ output, expert_weights_out = model(x)
74
+
75
+ assert output.shape == (1, 1, 1152), f"Output shape mismatch on rank {rank}."
76
+
77
+ print(f"Rank {rank}: Test passed! Output shape: {output.shape}")
78
+
79
+ dist.destroy_process_group()
80
+
81
+
82
+ def test_megablocks_moe_mlp_functionality():
83
+ world_size = 2
84
+
85
+ mp.spawn(run_distributed_test, args=(world_size,), nprocs=world_size, join=True)
86
+
87
+ print("Multi-process test completed successfully!")
88
+
89
+
90
+ if __name__ == "__main__":
91
+ test_megablocks_moe_mlp_import()
92
+ print("Import test passed!")
93
+
94
+ test_megablocks_moe_mlp_functionality()
tests/test_gg.py ADDED
@@ -0,0 +1,57 @@
1
+ import torch
2
+ import megablocks
3
+
4
+
5
+ def randn(bs, x, y):
6
+ out = (torch.rand(bs, x, y) - 0.5 * 2) / (y * x)
7
+ return out.cuda().to(torch.bfloat16)
8
+
9
+
10
+ def gmm(a, b, batch_sizes, trans_b=False):
11
+ batch_sizes = batch_sizes.cpu().numpy()
12
+
13
+ out = []
14
+ start = 0
15
+ for i, size in enumerate(batch_sizes):
16
+ rhs = b[i, :, :].t() if trans_b else b[i, :, :]
17
+ out.append(a[start : start + size, :] @ rhs)
18
+ start += size
19
+ return torch.cat(out)
20
+
21
+
22
+ def test_gmm():
23
+ z = 1
24
+ m = 128
25
+ n = 128
26
+ k = 128
27
+ trans_b = False
28
+ batch_sizes_on_device = False
29
+ # TODO: fix to enable batch_sizes_on_device
30
+ # batch_sizes_on_device = True
31
+
32
+ torch.manual_seed(0)
33
+ a = randn(z, m, k).view(-1, k)
34
+ b = randn(z, n, k) if trans_b else randn(z, k, n)
35
+ batch_sizes = torch.tensor([m] * z)
36
+ if batch_sizes_on_device:
37
+ batch_sizes = batch_sizes.cuda()
38
+
39
+ a.requires_grad_(True)
40
+ b.requires_grad_(True)
41
+ a_ref = a.detach().clone().requires_grad_(True)
42
+ b_ref = b.detach().clone().requires_grad_(True)
43
+
44
+ # out = ops.gmm(a, b, batch_sizes, trans_b)
45
+ out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
46
+ print("out", out)
47
+
48
+ expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
49
+
50
+ assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"
51
+
52
+ out.sum().backward()
53
+
54
+ expected_out.sum().backward()
55
+ assert torch.allclose(a.grad, a_ref.grad, atol=1e-3), f"Expected {a_ref.grad}, got {a.grad}"
56
+ assert torch.allclose(b.grad, b_ref.grad, atol=1e-3), f"Expected {b_ref.grad}, got {b.grad}"
57
+ print("Test passed successfully!")
tests/test_mb_moe.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+ import megablocks
3
+
4
+ def test_import():
5
+ """Simple test to check if the module can be imported."""
6
+ print("megablocks_moe module imported successfully.")
7
+ print("Available functions:", dir(megablocks))
8
+
9
+ expected_functions = [
10
+ "Arguments", "MLP", "MoE", "ParallelDroplessMLP", "ParallelMLP",
11
+ "SparseGLU", "SparseMLP", "argsort",
12
+ "backend", "cumsum", "dMoE", "exclusive_cumsum",
13
+ "get_load_balancing_loss", "grouped_gemm_util", "histogram",
14
+ "inclusive_cumsum", "indices", "layers", "ops", "replicate_backward",
15
+ "replicate_forward", "sort", "torch"
16
+ ]
17
+
18
+ # Check if all expected functions are available
19
+ for func in expected_functions:
20
+ assert func in dir(megablocks), f"Missing function: {func}"
21
+
22
+ # exclusive_cumsum
23
+ def test_exclusive_cumsum():
24
+ """Test exclusive cumulative sum."""
25
+ x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
26
+ out = torch.empty_like(x)
27
+ megablocks.exclusive_cumsum(x, 0, out)
28
+ expected = torch.tensor([0, 1, 3, 6], dtype=torch.int16).cuda()
29
+ assert torch.equal(out, expected), f"Expected {expected}, got {out}"
30
+ print("cumsum output:", out)
31
+
32
+ # inclusive_cumsum
33
+ def test_inclusive_cumsum():
34
+ """Test inclusive cumulative sum."""
35
+ x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
36
+ out = torch.empty_like(x)
37
+ megablocks.inclusive_cumsum(x, dim=0, out=out)
38
+ expected = torch.tensor([1, 3, 6, 10], dtype=torch.int16).cuda()
39
+ assert torch.equal(out, expected), f"Expected {expected}, got {out}"
40
+
41
+ # histogram
42
+ def test_histogram():
43
+ """Test histogram operation."""
44
+ x = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
45
+ num_bins = 3
46
+ hist = megablocks.histogram(x, num_bins)
47
+ expected_hist = torch.tensor([1, 2, 3], dtype=torch.int32).cuda()
48
+ assert torch.equal(hist, expected_hist), f"Expected {expected_hist}, got {hist}"
tests/test_mb_moe_shared_expert.py ADDED
@@ -0,0 +1,139 @@
1
+ import torch
2
+ import megablocks
3
+ from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
4
+
5
+
6
+ def test_megablocks_moe_mlp_with_shared_expert_import():
7
+ mlp = MegaBlocksMoeMLPWithSharedExpert()
8
+ assert hasattr(mlp, 'shared_up_proj_weight')
9
+ assert hasattr(mlp, 'shared_down_proj_weight')
10
+ assert hasattr(mlp, 'set_shared_expert_weights')
11
+
12
+
13
+ def test_set_shared_expert_weights():
14
+ mlp = MegaBlocksMoeMLPWithSharedExpert()
15
+
16
+ hidden_size = 128
17
+ shared_expert_hidden_size = 256
18
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
19
+ dtype = torch.float32
20
+
21
+ up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device, dtype=dtype)
22
+ down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device, dtype=dtype)
23
+ up_proj_bias = torch.randn(shared_expert_hidden_size, device=device, dtype=dtype)
24
+ down_proj_bias = torch.randn(hidden_size, device=device, dtype=dtype)
25
+
26
+ mlp.set_shared_expert_weights(
27
+ up_proj_weight=up_proj_weight,
28
+ down_proj_weight=down_proj_weight,
29
+ up_proj_bias=up_proj_bias,
30
+ down_proj_bias=down_proj_bias,
31
+ weighted_sum=True,
32
+ activation_fn=torch.nn.functional.gelu
33
+ )
34
+
35
+ assert torch.equal(mlp.shared_up_proj_weight, up_proj_weight)
36
+ assert torch.equal(mlp.shared_down_proj_weight, down_proj_weight)
37
+ assert torch.equal(mlp.shared_up_proj_bias, up_proj_bias)
38
+ assert torch.equal(mlp.shared_down_proj_bias, down_proj_bias)
39
+ assert mlp.shared_expert_weighted_sum == True
40
+ assert mlp.shared_activation_fn == torch.nn.functional.gelu
41
+
42
+
43
+ def test_create_shared_expert_weights():
44
+ hidden_size = 128
45
+ shared_expert_hidden_size = 256
46
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
47
+ dtype = torch.float32
48
+
49
+ def init_method(tensor):
50
+ torch.nn.init.xavier_uniform_(tensor)
51
+
52
+ up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
53
+ hidden_size=hidden_size,
54
+ shared_expert_hidden_size=shared_expert_hidden_size,
55
+ device=device,
56
+ dtype=dtype,
57
+ init_method=init_method
58
+ )
59
+
60
+ assert up_proj_weight.shape == (shared_expert_hidden_size, hidden_size)
61
+ assert down_proj_weight.shape == (hidden_size, shared_expert_hidden_size)
62
+ assert up_proj_weight.device.type == device.type
63
+ assert down_proj_weight.device.type == device.type
64
+ assert up_proj_weight.dtype == dtype
65
+ assert down_proj_weight.dtype == dtype
66
+ assert up_proj_bias is None
67
+ assert down_proj_bias is None
68
+
69
+
70
+ def test_shared_expert_weights_none_by_default():
71
+ mlp = MegaBlocksMoeMLPWithSharedExpert()
72
+
73
+ assert mlp.shared_up_proj_weight is None
74
+ assert mlp.shared_down_proj_weight is None
75
+ assert mlp.shared_up_proj_bias is None
76
+ assert mlp.shared_down_proj_bias is None
77
+ assert mlp.shared_expert_weighted_sum == False
78
+ assert mlp.shared_activation_fn is None
79
+
80
+
81
+ def test_inheritance_from_megablocks_moe_mlp():
82
+ mlp = MegaBlocksMoeMLPWithSharedExpert()
83
+
84
+ from megablocks.layers import MegaBlocksMoeMLP
85
+ assert isinstance(mlp, MegaBlocksMoeMLP)
86
+ assert hasattr(mlp, 'forward')
87
+
88
+
89
+ def test_shared_expert_weights_custom_init():
90
+ hidden_size = 64
91
+ shared_expert_hidden_size = 128
92
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
93
+ dtype = torch.float16
94
+
95
+ def custom_init(tensor):
96
+ torch.nn.init.constant_(tensor, 0.5)
97
+
98
+ def custom_output_init(tensor):
99
+ torch.nn.init.constant_(tensor, 0.1)
100
+
101
+ up_proj_weight, down_proj_weight, up_proj_bias, down_proj_bias = create_shared_expert_weights(
102
+ hidden_size=hidden_size,
103
+ shared_expert_hidden_size=shared_expert_hidden_size,
104
+ device=device,
105
+ dtype=dtype,
106
+ init_method=custom_init,
107
+ output_layer_init_method=custom_output_init
108
+ )
109
+
110
+ assert torch.all(up_proj_weight == 0.5)
111
+ assert torch.all(down_proj_weight == 0.1)
112
+ assert up_proj_weight.dtype == dtype
113
+ assert down_proj_weight.dtype == dtype
114
+
115
+
116
+ def test_shared_expert_weights_dimensions():
117
+ mlp = MegaBlocksMoeMLPWithSharedExpert()
118
+
119
+ batch_size = 4
120
+ seq_len = 16
121
+ hidden_size = 128
122
+ shared_expert_hidden_size = 256
123
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
124
+
125
+ up_proj_weight = torch.randn(shared_expert_hidden_size, hidden_size, device=device)
126
+ down_proj_weight = torch.randn(hidden_size, shared_expert_hidden_size, device=device)
127
+
128
+ mlp.set_shared_expert_weights(
129
+ up_proj_weight=up_proj_weight,
130
+ down_proj_weight=down_proj_weight
131
+ )
132
+
133
+ x = torch.randn(seq_len, batch_size, hidden_size, device=device)
134
+
135
+ expected_up_output_shape = (seq_len, batch_size, shared_expert_hidden_size)
136
+ expected_down_output_shape = (seq_len, batch_size, hidden_size)
137
+
138
+ assert up_proj_weight.shape[1] == x.shape[-1]
139
+ assert down_proj_weight.shape[0] == x.shape[-1]
tests/test_mb_moe_shared_expert_multi.py ADDED
@@ -0,0 +1,200 @@
1
+ import torch
2
+ import torch.distributed as dist
3
+ import torch.multiprocessing as mp
4
+ import os
5
+ import pytest
6
+ from megablocks.layers import MegaBlocksMoeMLPWithSharedExpert, create_shared_expert_weights
7
+
8
+
9
+ def run_distributed_shared_expert_test(rank, world_size):
10
+ os.environ["MASTER_ADDR"] = "localhost"
11
+ os.environ["MASTER_PORT"] = "12356"
12
+ os.environ["RANK"] = str(rank)
13
+ os.environ["WORLD_SIZE"] = str(world_size)
14
+
15
+ dist.init_process_group(
16
+ backend="gloo",
17
+ rank=rank,
18
+ world_size=world_size,
19
+ )
20
+
21
+ model = MegaBlocksMoeMLPWithSharedExpert()
22
+
23
+ hidden_size = 128
24
+ shared_expert_hidden_size = 192
25
+ device = "cuda" if torch.cuda.is_available() else "cpu"
26
+
27
+ def simple_init(tensor):
28
+ torch.nn.init.xavier_uniform_(tensor)
29
+
30
+ shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
31
+ hidden_size=hidden_size,
32
+ shared_expert_hidden_size=shared_expert_hidden_size,
33
+ device=torch.device(device),
34
+ dtype=torch.float32,
35
+ init_method=simple_init
36
+ )
37
+
38
+ model.set_shared_expert_weights(
39
+ up_proj_weight=shared_up_proj_weight,
40
+ down_proj_weight=shared_down_proj_weight,
41
+ up_proj_bias=shared_up_proj_bias,
42
+ down_proj_bias=shared_down_proj_bias,
43
+ weighted_sum=True,
44
+ activation_fn=torch.nn.functional.gelu
45
+ )
46
+
47
+ assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
48
+ assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
49
+ assert model.shared_expert_weighted_sum == True, f"Weighted sum not set correctly on rank {rank}"
50
+
51
+ print(f"Rank {rank}: Shared expert setup test passed!")
52
+
53
+ dist.destroy_process_group()
54
+
55
+
56
+ def run_distributed_shared_expert_weighted_sum_test(rank, world_size):
57
+ os.environ["MASTER_ADDR"] = "localhost"
58
+ os.environ["MASTER_PORT"] = "12357"
59
+ os.environ["RANK"] = str(rank)
60
+ os.environ["WORLD_SIZE"] = str(world_size)
61
+
62
+ dist.init_process_group(
63
+ backend="gloo",
64
+ rank=rank,
65
+ world_size=world_size,
66
+ )
67
+
68
+ model = MegaBlocksMoeMLPWithSharedExpert()
69
+
70
+ hidden_size = 64
71
+ device = "cuda" if torch.cuda.is_available() else "cpu"
72
+
73
+ def simple_init(tensor):
74
+ torch.nn.init.xavier_uniform_(tensor)
75
+
76
+ shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
77
+ hidden_size=hidden_size,
78
+ shared_expert_hidden_size=96,
79
+ device=torch.device(device),
80
+ dtype=torch.float32,
81
+ init_method=simple_init
82
+ )
83
+
84
+ model.set_shared_expert_weights(
85
+ up_proj_weight=shared_up_proj_weight,
86
+ down_proj_weight=shared_down_proj_weight,
87
+ weighted_sum=False,
88
+ activation_fn=torch.nn.functional.relu
89
+ )
90
+
91
+ assert model.shared_up_proj_weight is not None, f"Shared up proj weight not set on rank {rank}"
92
+ assert model.shared_down_proj_weight is not None, f"Shared down proj weight not set on rank {rank}"
93
+ assert model.shared_expert_weighted_sum == False, f"Weighted sum not set correctly on rank {rank}"
94
+ assert model.shared_activation_fn == torch.nn.functional.relu, f"Activation function not set correctly on rank {rank}"
95
+
96
+ print(f"Rank {rank}: Weighted sum setup test passed!")
97
+
98
+ dist.destroy_process_group()
99
+
100
+
101
+ @pytest.mark.parametrize("world_size", [1, 2, 4, 8])
102
+ def test_shared_expert_distributed_functionality(world_size):
103
+ if world_size == 1:
104
+ # Single process test
105
+ model = MegaBlocksMoeMLPWithSharedExpert()
106
+
107
+ hidden_size = 128
108
+ shared_expert_hidden_size = 192
109
+ device = "cuda" if torch.cuda.is_available() else "cpu"
110
+
111
+ def simple_init(tensor):
112
+ torch.nn.init.xavier_uniform_(tensor)
113
+
114
+ shared_up_proj_weight, shared_down_proj_weight, shared_up_proj_bias, shared_down_proj_bias = create_shared_expert_weights(
115
+ hidden_size=hidden_size,
116
+ shared_expert_hidden_size=shared_expert_hidden_size,
117
+ device=torch.device(device),
118
+ dtype=torch.float32,
119
+ init_method=simple_init
120
+ )
121
+
122
+ model.set_shared_expert_weights(
123
+ up_proj_weight=shared_up_proj_weight,
124
+ down_proj_weight=shared_down_proj_weight,
125
+ up_proj_bias=shared_up_proj_bias,
126
+ down_proj_bias=shared_down_proj_bias,
127
+ weighted_sum=True,
128
+ activation_fn=torch.nn.functional.gelu
129
+ )
130
+
131
+ assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
132
+ assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
133
+ assert model.shared_expert_weighted_sum == True, "Weighted sum not set correctly"
134
+
135
+ print("Single process shared expert setup test passed!")
136
+ else:
137
+ # Multi-process test
138
+ mp.spawn(run_distributed_shared_expert_test, args=(world_size,), nprocs=world_size, join=True)
139
+ print("Multi-process shared expert test completed successfully!")
140
+
141
+
142
+ @pytest.mark.parametrize("world_size", [1, 2, 4, 8])
143
+ def test_shared_expert_distributed_weighted_sum(world_size):
144
+ if world_size == 1:
145
+ # Single process test
146
+ model = MegaBlocksMoeMLPWithSharedExpert()
147
+
148
+ hidden_size = 64
149
+ device = "cuda" if torch.cuda.is_available() else "cpu"
150
+
151
+ def simple_init(tensor):
152
+ torch.nn.init.xavier_uniform_(tensor)
153
+
154
+ shared_up_proj_weight, shared_down_proj_weight, _, _ = create_shared_expert_weights(
155
+ hidden_size=hidden_size,
156
+ shared_expert_hidden_size=96,
157
+ device=torch.device(device),
158
+ dtype=torch.float32,
159
+ init_method=simple_init
160
+ )
161
+
162
+ model.set_shared_expert_weights(
163
+ up_proj_weight=shared_up_proj_weight,
164
+ down_proj_weight=shared_down_proj_weight,
165
+ weighted_sum=False,
166
+ activation_fn=torch.nn.functional.relu
167
+ )
168
+
169
+ assert model.shared_up_proj_weight is not None, "Shared up proj weight not set"
170
+ assert model.shared_down_proj_weight is not None, "Shared down proj weight not set"
171
+ assert model.shared_expert_weighted_sum == False, "Weighted sum not set correctly"
172
+ assert model.shared_activation_fn == torch.nn.functional.relu, "Activation function not set correctly"
173
+
174
+ print("Single process weighted sum setup test passed!")
175
+ else:
176
+ # Multi-process test
177
+ mp.spawn(run_distributed_shared_expert_weighted_sum_test, args=(world_size,), nprocs=world_size, join=True)
178
+ print("Multi-process shared expert weighted sum test completed successfully!")
179
+
180
+
181
+ def test_shared_expert_single_process():
182
+ model = MegaBlocksMoeMLPWithSharedExpert()
183
+
184
+ assert model.shared_up_proj_weight is None
185
+ assert model.shared_down_proj_weight is None
186
+ assert hasattr(model, 'set_shared_expert_weights')
187
+
188
+ print("Single process shared expert basic test passed!")
189
+
190
+
191
+ if __name__ == "__main__":
192
+ test_shared_expert_single_process()
193
+ print("Single process test passed!")
194
+
195
+ os.environ['WORLD_SIZE'] = '2'
196
+ test_shared_expert_distributed_functionality(world_size=2)
197
+ print("Distributed functionality test passed!")
198
+
199
+ test_shared_expert_distributed_weighted_sum(world_size=2)
200
+ print("Distributed weighted sum test passed!")
torch-ext/megablocks/__init__.py ADDED
@@ -0,0 +1,202 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+
6
+ from ._ops import ops
7
+
8
+ #from .grouped_gemm import backend as gg_backend
9
+ #from .grouped_gemm import ops as gg_ops
10
+
11
+
12
+ from ._layers.arguments import Arguments
13
+ from ._layers.dmoe import ParallelDroplessMLP, dMoE
14
+ from ._layers.glu import SparseGLU
15
+ from ._layers.mlp import MLP, SparseMLP
16
+ from ._layers.moe import MoE, ParallelMLP, get_load_balancing_loss
17
+
18
+ from . import layers
19
+
20
+ # This section contains the direct kernel exports (not included in the original code)
21
+ def exclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
22
+ """
23
+ Compute exclusive cumulative sum along the specified dimension.
24
+
25
+ Args:
26
+ x: Input tensor
27
+ dim: Dimension along which to compute cumsum
28
+ out: Output tensor (modified in-place)
29
+
30
+ Returns:
31
+ The output tensor
32
+ """
33
+ result = ops.exclusive_cumsum(x, dim)
34
+ out.copy_(result)
35
+ return out
36
+
37
+
38
+ def inclusive_cumsum(x: torch.Tensor, dim: int, out: torch.Tensor) -> torch.Tensor:
39
+ """
40
+ Compute inclusive cumulative sum along the specified dimension.
41
+
42
+ Args:
43
+ x: Input tensor
44
+ dim: Dimension along which to compute cumsum
45
+ out: Output tensor (modified in-place)
46
+
47
+ Returns:
48
+ The output tensor
49
+ """
50
+ result = ops.inclusive_cumsum(x, dim)
51
+ out.copy_(result)
52
+ return out
53
+
54
+
55
+ def histogram(x: torch.Tensor, num_bins: int) -> torch.Tensor:
56
+ """
57
+ Compute histogram of input tensor values.
58
+
59
+ Args:
60
+ x: Input tensor
61
+ num_bins: Number of histogram bins
62
+
63
+ Returns:
64
+ Histogram tensor with counts for each bin
65
+ """
66
+ return ops.histogram(x, num_bins)
67
+
68
+
69
+ def indices(
70
+ padded_bins: torch.Tensor,
71
+ block_size: int,
72
+ output_block_rows: int,
73
+ output_block_columns: int,
74
+ ) -> torch.Tensor:
75
+ """
76
+ Construct indices from padded bins for sparse operations.
77
+
78
+ Args:
79
+ padded_bins: Tensor containing bin boundaries
80
+ block_size: Size of each block
81
+ output_block_rows: Number of rows in output blocks
82
+ output_block_columns: Number of columns in output blocks
83
+
84
+ Returns:
85
+ Tensor containing constructed indices
86
+ """
87
+ return ops.indices(padded_bins, block_size, output_block_rows, output_block_columns)
88
+
89
+
90
+ def replicate_forward(
91
+ x: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
92
+ ) -> torch.Tensor:
93
+ """
94
+ Forward pass of replicate operation - replicate values according to bin sizes.
95
+
96
+ Args:
97
+ x: Input tensor with values to replicate
98
+ bins: Tensor containing bin sizes
99
+ out: Output tensor (modified in-place)
100
+
101
+ Returns:
102
+ The output tensor
103
+ """
104
+ return ops.replicate_forward(x, bins, out)
105
+
106
+
107
+ def replicate_backward(
108
+ grad: torch.Tensor, bins: torch.Tensor, out: torch.Tensor
109
+ ) -> torch.Tensor:
110
+ """
111
+ Backward pass of replicate operation - reduce gradients back to bins.
112
+
113
+ Args:
114
+ grad: Gradient tensor to reduce
115
+ bins: Tensor containing bin sizes
116
+ out: Output tensor (modified in-place)
117
+
118
+ Returns:
119
+ The output tensor
120
+ """
121
+ return ops.replicate_backward(grad, bins, out)
122
+
123
+
124
+ def sort(
125
+ x: torch.Tensor, end_bit: int, x_out: torch.Tensor, iota_out: torch.Tensor
126
+ ) -> torch.Tensor:
127
+ """
128
+ Radix sort with index tracking.
129
+
130
+ Args:
131
+ x: Input tensor to sort
132
+ end_bit: Number of bits to consider in sorting
133
+ x_out: Output tensor for sorted values
134
+ iota_out: Output tensor for sorted indices
135
+
136
+ Returns:
137
+ The sorted values tensor
138
+ """
139
+ return ops.sort(x, end_bit, x_out, iota_out)
140
+
141
+
142
+ # Convenience functions for common use cases
143
+ def cumsum(x: torch.Tensor, dim: int = -1, exclusive: bool = False) -> torch.Tensor:
144
+ """
145
+ Compute cumulative sum with automatic output allocation.
146
+
147
+ Args:
148
+ x: Input tensor
149
+ dim: Dimension along which to compute cumsum (default: last dimension)
150
+ exclusive: Whether to compute exclusive (True) or inclusive (False) cumsum
151
+
152
+ Returns:
153
+ New tensor containing the cumulative sum
154
+ """
155
+ out = torch.empty_like(x)
156
+ if exclusive:
157
+ return exclusive_cumsum(x, dim, out)
158
+ else:
159
+ return inclusive_cumsum(x, dim, out)
160
+
161
+
162
+ def argsort(x: torch.Tensor, end_bit: int = 32) -> tuple[torch.Tensor, torch.Tensor]:
163
+ """
164
+ Sort tensor and return both sorted values and indices.
165
+
166
+ Args:
167
+ x: Input tensor to sort
168
+ end_bit: Number of bits to consider in sorting
169
+
170
+ Returns:
171
+ Tuple of (sorted_values, sorted_indices)
172
+ """
173
+ x_out = torch.empty_like(x)
174
+ iota_out = torch.empty_like(x)
175
+ sort(x, end_bit, x_out, iota_out)
176
+ return x_out, iota_out
177
+
178
+
179
+ # Export public API
180
+ __all__ = [
181
+ "MyReplacementLayer",
182
+ # Direct kernel exports
183
+ "exclusive_cumsum",
184
+ "inclusive_cumsum",
185
+ "histogram",
186
+ "indices",
187
+ "replicate_forward",
188
+ "replicate_backward",
189
+ "sort",
190
+ "cumsum",
191
+ "argsort",
192
+ # Original exports
193
+ "Arguments",
194
+ "ParallelDroplessMLP",
195
+ "dMoE",
196
+ "SparseGLU",
197
+ "MLP",
198
+ "SparseMLP",
199
+ "MoE",
200
+ "ParallelMLP",
201
+ "get_load_balancing_loss",
202
+ ]
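
Note: a short usage sketch of the convenience wrappers exported above (assumes a CUDA device and the compiled _ops extension; the expected values follow the cumsum and histogram semantics exercised in the tests):

    import torch
    import megablocks

    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    print(megablocks.cumsum(x, dim=0, exclusive=True))    # 0, 1, 3, 6
    print(megablocks.cumsum(x, dim=0, exclusive=False))   # 1, 3, 6, 10

    assignments = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
    print(megablocks.histogram(assignments, 3))           # per-bin counts: 1, 2, 3

    sorted_vals, sorted_idx = megablocks.argsort(assignments, end_bit=2)
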
torch-ext/megablocks/_layers/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ # from megablocks.layers.dmoe import dMoE
5
+ from .moe import MoE
6
+
7
+ __all__ = [
8
+ 'MoE',
9
+ # 'dMoE',
10
+ ]
torch-ext/megablocks/_layers/activation_fn.py ADDED
@@ -0,0 +1,33 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Any, Callable, Union
5
+
6
+ import torch
7
+ from ..stk import Matrix
8
+
9
+
10
+ def act_fn(
11
+ x: Matrix,
12
+ function: Callable,
13
+ return_grad_fn: bool = False,
14
+ **kwargs,
15
+ ) -> Union[tuple[Matrix, Any] | Matrix]:
16
+ assert isinstance(x, Matrix)
17
+ with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
18
+ if return_grad_fn:
19
+ x.data.requires_grad = True
20
+ out = function(x.data, **kwargs)
21
+ y = Matrix(
22
+ x.size(),
23
+ out,
24
+ x.row_indices,
25
+ x.column_indices,
26
+ x.offsets,
27
+ x.column_indices_t,
28
+ x.offsets_t,
29
+ x.block_offsets_t,
30
+ )
31
+ if return_grad_fn:
32
+ return y, out.backward
33
+ return y
torch-ext/megablocks/_layers/all_to_all.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright 2024 Databricks
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+
8
+ class AllToAllOp(torch.autograd.Function):
9
+
10
+ @staticmethod
11
+ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op):
12
+ out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype)
13
+
14
+ ctx.input_shape = x.shape
15
+ ctx.output_split_sizes = output_split_sizes
16
+ ctx.input_split_sizes = input_split_sizes
17
+ ctx.group = group
18
+ handle = dist.all_to_all_single(
19
+ out,
20
+ x,
21
+ output_split_sizes=output_split_sizes,
22
+ input_split_sizes=input_split_sizes,
23
+ group=group,
24
+ async_op=async_op,
25
+ )
26
+ return out, handle
27
+
28
+ @staticmethod
29
+ def backward(ctx, grad, _):
30
+ if ctx.needs_input_grad[0]:
31
+ out = torch.empty(
32
+ ctx.input_shape,
33
+ device=grad.device,
34
+ dtype=grad.dtype,
35
+ )
36
+ dist.all_to_all_single(
37
+ out,
38
+ grad,
39
+ output_split_sizes=ctx.input_split_sizes,
40
+ input_split_sizes=ctx.output_split_sizes,
41
+ group=ctx.group,
42
+ )
43
+ return out, None, None, None, None
44
+ return None, None, None, None, None
45
+
46
+
47
+ def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False):
48
+ return AllToAllOp.apply(
49
+ x,
50
+ output_split_sizes,
51
+ input_split_sizes,
52
+ group,
53
+ async_op,
54
+ )
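
Note: a minimal usage sketch for the wrapper above, with a hypothetical equal split across ranks; it assumes torch.distributed has already been initialized (e.g. under torchrun with the nccl backend) and that the module is importable as megablocks._layers.all_to_all.

    import torch
    import torch.distributed as dist
    from megablocks._layers.all_to_all import all_to_all

    def exchange(x_local: torch.Tensor) -> torch.Tensor:
        # Send an equal slice of x_local to every rank and receive one slice from each.
        world_size = dist.get_world_size()
        split = x_local.shape[0] // world_size
        output_split_sizes = [split] * world_size
        input_split_sizes = [split] * world_size
        out, handle = all_to_all(
            x_local,
            output_split_sizes,
            input_split_sizes,
            group=dist.group.WORLD,
            async_op=False,   # handle is None for the synchronous call
        )
        return out
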