danieldk committed
Commit b4cad21 · 1 Parent(s): e87d8e6

Add cutlass_w8a8

LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,9 @@
1
- ---
2
- license: apache-2.0
3
- ---
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ ## Quantization
6
+
7
+ Quantization kernels from [vLLM](https://github.com/vllm-project/vllm/blob/main/csrc/quantization).
8
+
9
+ Copyright 2023-2024, the vLLM team.
build.toml ADDED
@@ -0,0 +1,41 @@
1
+ [general]
2
+ version = "0.0.1"
3
+
4
+ [torch]
5
+ name = "quantization"
6
+ src = [
7
+ "ext-torch/registration.h",
8
+ "ext-torch/torch_binding.cpp",
9
+ "ext-torch/torch_binding.h"
10
+ ]
11
+ pysrc = [
12
+ "ext-torch/__init__.py"
13
+ ]
14
+
15
+ [kernel.cutlass_w8a8]
16
+ capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
17
+ src = [
18
+ "cutlass_w8a8/common.hpp",
19
+ "cutlass_w8a8/scaled_mm_c2x.cu",
20
+ "cutlass_w8a8/scaled_mm_c2x.cuh",
21
+ "cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh",
22
+ "cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh",
23
+ "cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh",
24
+ "cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh",
25
+ "cutlass_w8a8/scaled_mm_entry.cu",
26
+ "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp",
27
+ "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp",
28
+ ]
29
+ include = [ "." ]
30
+ depends = [ "cutlass", "torch" ]
31
+
32
+ [kernel.cutlass_w8a8_hopper]
33
+ capabilities = [ "9.0", "9.0a" ]
34
+ src = [
35
+ "cutlass_w8a8/common.hpp",
36
+ "cutlass_w8a8/scaled_mm_c3x.cu",
37
+ "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp",
38
+ "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp",
39
+ ]
40
+ include = [ "." ]
41
+ depends = [ "cutlass", "torch" ]
cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp ADDED
@@ -0,0 +1,497 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
3
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice,
9
+ *this list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29
+ *POSSIBILITY OF SUCH DAMAGE.
30
+ *
31
+ **************************************************************************************************/
32
+
33
+ //
34
+ // This file is a modified excerpt of
35
+ // include/cutlass/epilogue/fusion/visitor_load.hpp from
36
+ // https://github.com/NVIDIA/cutlass v3.5.0
37
+ // It has been modified to support either
38
+ // row/column or scalar broadcasting where the tensor being loaded from is
39
+ // always passed in via a device pointer. This lets one compiled kernel handle
40
+ // all cases of per-tensor or per-channel/per-token quantization.
41
+ //
42
+ // This interface also allows the scales to be passed in as tensors that
43
+ // consistently reside on the device, which avoids an issue with a previous
44
+ // implementation where scalars needed to be on the CPU since they
45
+ // were passed in via float values. This created a potential performance hazard
46
+ // if scales were initially on the device, and caused torch.compile graph
47
+ // breaks when moving scales to the CPU.
48
+ //
49
+ #pragma once
50
+
51
+ // Turn off clang-format for the entire file to keep it close to upstream
52
+ // clang-format off
53
+
54
+ #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
55
+ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
56
+ #include "cute/tensor.hpp"
57
+
58
+ namespace cutlass::epilogue::threadblock {
59
+
60
+ using namespace cute;
61
+ using namespace detail;
62
+
63
+ template<
64
+ class ThreadMap,
65
+ class Element,
66
+ class StrideMNL
67
+ >
68
+ struct VisitorRowOrScalarBroadcast {
69
+
70
+ // This struct has been modified to have a bool indicating that ptr_row is a
71
+ // scalar that must be broadcast.
72
+ struct Arguments {
73
+ Element const* ptr_row = nullptr;
74
+ bool row_broadcast = true;
75
+ StrideMNL dRow = {};
76
+ };
77
+
78
+ using Params = Arguments;
79
+
80
+ template <class ProblemShape>
81
+ static constexpr Params
82
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
83
+ return args;
84
+ }
85
+
86
+ template <class ProblemShape>
87
+ static size_t
88
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
89
+ return 0;
90
+ }
91
+
92
+ struct SharedStorage {};
93
+
94
+ // Global load type
95
+ static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
96
+ using VecType = uint_bit_t<cute::min(128, vec_bits)>;
97
+ static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
98
+
99
+ CUTLASS_HOST_DEVICE
100
+ VisitorRowOrScalarBroadcast() { }
101
+
102
+ CUTLASS_HOST_DEVICE
103
+ VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
104
+ : params_ptr(&params) { }
105
+
106
+ Params const* params_ptr;
107
+
108
+ template <class GTensor, class RTensor, class CTensor, class ProblemShape>
109
+ struct Callbacks : EmptyCallbacks {
110
+ CUTLASS_DEVICE
111
+ Callbacks(
112
+ GTensor&& tC_gRow,
113
+ RTensor&& tC_rRow,
114
+ CTensor&& tC_cRow,
115
+ ProblemShape problem_shape,
116
+ Params const* params_ptr
117
+ ):
118
+ tC_gRow(cute::forward<GTensor>(tC_gRow)),
119
+ tC_rRow(cute::forward<RTensor>(tC_rRow)),
120
+ tC_cRow(cute::forward<CTensor>(tC_cRow)),
121
+ n(get<1>(problem_shape)),
122
+ params_ptr(params_ptr) { }
123
+
124
+ GTensor tC_gRow;
125
+ RTensor tC_rRow;
126
+ CTensor tC_cRow;
127
+ Params const* params_ptr;
128
+ int n;
129
+
130
+ // This function is modified from VisitorRowBroadcast
131
+ CUTLASS_DEVICE void
132
+ begin_epilogue() {
133
+ clear(tC_rRow);
134
+ auto src_v = filter(tC_gRow);
135
+ auto coord_v = filter(tC_cRow);
136
+ auto dst_v = filter(tC_rRow);
137
+
138
+ if (params_ptr->row_broadcast) {
139
+ // In this case we are loading from a row vector and broadcasting
140
+ CUTLASS_PRAGMA_UNROLL
141
+ for (int i = 0; i < size(src_v); ++i) {
142
+ bool guard = get<1>(coord_v(i)) < n;
143
+ cutlass::arch::global_load<VecType, sizeof(VecType)>(
144
+ dst_v(i), (void const*)&src_v(i), guard);
145
+ }
146
+ } else {
147
+ // In this case we are loading from a scalar and broadcasting
148
+ VecType filled_vec;
149
+ CUTLASS_PRAGMA_UNROLL
150
+ for (int i = 0; i < VecLength; i++) {
151
+ reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
152
+ }
153
+
154
+ CUTLASS_PRAGMA_UNROLL
155
+ for (int i = 0; i < size(src_v); ++i) {
156
+ if (get<1>(coord_v(i)) < n) {
157
+ dst_v(i) = filled_vec;
158
+ }
159
+ }
160
+ }
161
+ }
162
+
163
+ template <class ElementAccumulator, int FragmentSize>
164
+ CUTLASS_DEVICE auto // returns an Array
165
+ visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
166
+ Array<ElementAccumulator, FragmentSize> const& frg_acc) {
167
+ Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
168
+ return rRow_frg(column_idx);
169
+ }
170
+ };
171
+
172
+ template <class ProblemShape>
173
+ CUTLASS_DEVICE auto
174
+ get_callbacks(
175
+ gemm::GemmCoord threadblock_tile_offset,
176
+ int thread_idx,
177
+ ProblemShape problem_shape
178
+ ) {
179
+ Tensor mRow = make_tensor(
180
+ make_gmem_ptr(params_ptr->ptr_row),
181
+ problem_shape,
182
+ params_ptr->dRow);
183
+
184
+ // VECTOR, FRAGMENT_COLUMN
185
+ Tensor tC_gRow = recast<VecType>(
186
+ ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
187
+ )(_,_,_0{},_0{},_0{},_0{});
188
+ Tensor tC_rRow = make_tensor_like(tC_gRow);
189
+
190
+ // Generate the pred tensor
191
+ Tensor cRow = make_identity_tensor(mRow.shape());
192
+ Tensor tC_cRow = outer_partition(
193
+ ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
194
+ Shape<Int<VecLength>>{},
195
+ (_0{})
196
+ );
197
+
198
+ return Callbacks<
199
+ decltype(tC_gRow), decltype(tC_rRow),
200
+ decltype(tC_cRow), ProblemShape>(
201
+ cute::move(tC_gRow),
202
+ cute::move(tC_rRow),
203
+ cute::move(tC_cRow),
204
+ problem_shape,
205
+ params_ptr
206
+ );
207
+ }
208
+
209
+ };
210
+
211
+ /////////////////////////////////////////////////////////////////////////////////////////////////
212
+
213
+ // This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
214
+ template<
215
+ class ThreadMap,
216
+ class Element,
217
+ class StrideMNL
218
+ >
219
+ struct VisitorRowOrZeroBroadcast {
220
+
221
+ // This struct has been modified to remove null_default (because it's always 0)
222
+ struct Arguments {
223
+ Element const* ptr_row = nullptr;
224
+ StrideMNL dRow = {};
225
+ };
226
+
227
+ using Params = Arguments;
228
+
229
+ template <class ProblemShape>
230
+ static constexpr Params
231
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
232
+ return args;
233
+ }
234
+
235
+ template <class ProblemShape>
236
+ static size_t
237
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
238
+ return 0;
239
+ }
240
+
241
+ struct SharedStorage {};
242
+
243
+ // Global load type
244
+ static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
245
+ using VecType = uint_bit_t<cute::min(128, vec_bits)>;
246
+ static int constexpr VecLength = sizeof(VecType) / sizeof(Element);
247
+
248
+ CUTLASS_HOST_DEVICE
249
+ VisitorRowOrZeroBroadcast() { }
250
+
251
+ CUTLASS_HOST_DEVICE
252
+ VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
253
+ : params_ptr(&params) { }
254
+
255
+ Params const* params_ptr;
256
+
257
+ template <class GTensor, class RTensor, class CTensor, class ProblemShape>
258
+ struct Callbacks : EmptyCallbacks {
259
+ CUTLASS_DEVICE
260
+ Callbacks(
261
+ GTensor&& tC_gRow,
262
+ RTensor&& tC_rRow,
263
+ CTensor&& tC_cRow,
264
+ ProblemShape problem_shape,
265
+ Params const* params_ptr
266
+ ):
267
+ tC_gRow(cute::forward<GTensor>(tC_gRow)),
268
+ tC_rRow(cute::forward<RTensor>(tC_rRow)),
269
+ tC_cRow(cute::forward<CTensor>(tC_cRow)),
270
+ n(get<1>(problem_shape)),
271
+ params_ptr(params_ptr) { }
272
+
273
+ GTensor tC_gRow;
274
+ RTensor tC_rRow;
275
+ CTensor tC_cRow;
276
+ Params const* params_ptr;
277
+ int n;
278
+
279
+ // This function is modified from VisitorRowBroadcast
280
+ CUTLASS_DEVICE void
281
+ begin_epilogue() {
282
+ clear(tC_rRow);
283
+ auto src_v = filter(tC_gRow);
284
+ auto coord_v = filter(tC_cRow);
285
+ auto dst_v = filter(tC_rRow);
286
+
287
+ if (params_ptr->ptr_row != nullptr) {
288
+ // In this case we are loading from a row vector and broadcasting
289
+ CUTLASS_PRAGMA_UNROLL
290
+ for (int i = 0; i < size(src_v); ++i) {
291
+ bool guard = get<1>(coord_v(i)) < n;
292
+ cutlass::arch::global_load<VecType, sizeof(VecType)>(
293
+ dst_v(i), (void const*)&src_v(i), guard);
294
+ }
295
+ } else {
296
+ // In this case we are broadcasting 0
297
+ VecType filled_vec;
298
+ CUTLASS_PRAGMA_UNROLL
299
+ for (int i = 0; i < VecLength; i++) {
300
+ reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
301
+ }
302
+
303
+ CUTLASS_PRAGMA_UNROLL
304
+ for (int i = 0; i < size(src_v); ++i) {
305
+ if (get<1>(coord_v(i)) < n) {
306
+ dst_v(i) = filled_vec;
307
+ }
308
+ }
309
+ }
310
+ }
311
+
312
+ template <class ElementAccumulator, int FragmentSize>
313
+ CUTLASS_DEVICE auto // returns an Array
314
+ visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
315
+ Array<ElementAccumulator, FragmentSize> const& frg_acc) {
316
+ Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
317
+ return rRow_frg(column_idx);
318
+ }
319
+ };
320
+
321
+ template <class ProblemShape>
322
+ CUTLASS_DEVICE auto
323
+ get_callbacks(
324
+ gemm::GemmCoord threadblock_tile_offset,
325
+ int thread_idx,
326
+ ProblemShape problem_shape
327
+ ) {
328
+ Tensor mRow = make_tensor(
329
+ make_gmem_ptr(params_ptr->ptr_row),
330
+ problem_shape,
331
+ params_ptr->dRow);
332
+
333
+ // VECTOR, FRAGMENT_COLUMN
334
+ Tensor tC_gRow = recast<VecType>(
335
+ ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
336
+ )(_,_,_0{},_0{},_0{},_0{});
337
+ Tensor tC_rRow = make_tensor_like(tC_gRow);
338
+
339
+ // Generate the pred tensor
340
+ Tensor cRow = make_identity_tensor(mRow.shape());
341
+ Tensor tC_cRow = outer_partition(
342
+ ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
343
+ Shape<Int<VecLength>>{},
344
+ (_0{})
345
+ );
346
+
347
+ return Callbacks<
348
+ decltype(tC_gRow), decltype(tC_rRow),
349
+ decltype(tC_cRow), ProblemShape>(
350
+ cute::move(tC_gRow),
351
+ cute::move(tC_rRow),
352
+ cute::move(tC_cRow),
353
+ problem_shape,
354
+ params_ptr
355
+ );
356
+ }
357
+
358
+ };
359
+
360
+
361
+ /////////////////////////////////////////////////////////////////////////////////////////////////
362
+
363
+ // Column vector broadcast
364
+ template<
365
+ class ThreadMap,
366
+ class Element,
367
+ class StrideMNL = Stride<_1,_0,_0>
368
+ >
369
+ struct VisitorColOrScalarBroadcast {
370
+
371
+ // This struct has been modified to have a bool indicating that ptr_col is a
372
+ // scalar that must be broadcast.
373
+ struct Arguments {
374
+ Element const* ptr_col = nullptr;
375
+ bool col_broadcast = true;
376
+ StrideMNL dCol = {};
377
+ };
378
+
379
+ using Params = Arguments;
380
+
381
+ template <class ProblemShape>
382
+ static constexpr Params
383
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
384
+ return args;
385
+ }
386
+
387
+ template <class ProblemShape>
388
+ static size_t
389
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
390
+ return 0;
391
+ }
392
+
393
+ struct SharedStorage { };
394
+
395
+ CUTLASS_HOST_DEVICE
396
+ VisitorColOrScalarBroadcast() { }
397
+
398
+ CUTLASS_HOST_DEVICE
399
+ VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
400
+ : params_ptr(&params) { }
401
+
402
+ Params const* params_ptr;
403
+
404
+ template <class GTensor, class RTensor, class CTensor, class ProblemShape>
405
+ struct Callbacks : EmptyCallbacks {
406
+ CUTLASS_DEVICE
407
+ Callbacks(
408
+ GTensor&& tC_gCol,
409
+ RTensor&& tC_rCol,
410
+ CTensor&& tC_cCol,
411
+ ProblemShape problem_shape,
412
+ Params const* params_ptr
413
+ ):
414
+ tC_gCol(cute::forward<GTensor>(tC_gCol)),
415
+ tC_rCol(cute::forward<RTensor>(tC_rCol)),
416
+ tC_cCol(cute::forward<CTensor>(tC_cCol)),
417
+ m(get<0>(problem_shape)),
418
+ params_ptr(params_ptr) { }
419
+
420
+ GTensor tC_gCol;
421
+ RTensor tC_rCol;
422
+ CTensor tC_cCol;
423
+ Params const* params_ptr;
424
+ int m;
425
+
426
+ // This function is modified from VisitorColBroadcast
427
+ CUTLASS_DEVICE void
428
+ begin_epilogue() {
429
+ clear(tC_rCol);
430
+
431
+ Tensor pred = make_tensor<bool>(shape(tC_gCol));
432
+ CUTLASS_PRAGMA_UNROLL
433
+ for (int i = 0; i < size(pred); ++i) {
434
+ pred(i) = get<0>(tC_cCol(i)) < m;
435
+ }
436
+
437
+ if (params_ptr->col_broadcast) {
438
+ // In this case we are loading from a column vector and broadcasting
439
+ copy_if(pred, tC_gCol, tC_rCol);
440
+ } else {
441
+ // In this case we are loading from a scalar and broadcasting
442
+ auto dst_v = filter(tC_rCol);
443
+
444
+ CUTLASS_PRAGMA_UNROLL
445
+ for (int i = 0; i < size(dst_v); ++i) {
446
+ if (pred(i)) {
447
+ dst_v(i) = *(params_ptr->ptr_col);
448
+ }
449
+ }
450
+ }
451
+ }
452
+
453
+ template <class ElementAccumulator, int FragmentSize>
454
+ CUTLASS_DEVICE auto // returns an Array
455
+ visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
456
+ Array<ElementAccumulator, FragmentSize> const& frg_acc) {
457
+ Array<Element, FragmentSize> frg_col;
458
+ frg_col.fill(tC_rCol(row_idx,iter_idx));
459
+ return frg_col;
460
+ }
461
+ };
462
+
463
+ template <class ProblemShape>
464
+ CUTLASS_DEVICE auto
465
+ get_callbacks(
466
+ gemm::GemmCoord threadblock_tile_offset,
467
+ int thread_idx,
468
+ ProblemShape problem_shape
469
+ ) {
470
+ Tensor mCol = make_tensor(
471
+ make_gmem_ptr(params_ptr->ptr_col),
472
+ problem_shape,
473
+ params_ptr->dCol);
474
+
475
+ // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
476
+ Tensor tC_gCol = group_modes<1,4>(
477
+ ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
478
+ Tensor tC_rCol = make_tensor_like(tC_gCol);
479
+
480
+ // Generate the pred tensor
481
+ Tensor cCol = make_identity_tensor(mCol.shape());
482
+ Tensor tC_cCol = group_modes<1,4>(
483
+ ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
484
+
485
+ return Callbacks<
486
+ decltype(tC_gCol), decltype(tC_rCol),
487
+ decltype(tC_cCol), ProblemShape>(
488
+ cute::move(tC_gCol),
489
+ cute::move(tC_rCol),
490
+ cute::move(tC_cCol),
491
+ problem_shape,
492
+ params_ptr
493
+ );
494
+ }
495
+ };
496
+
497
+ }
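
For orientation only (not part of this commit): the broadcast behavior that VisitorRowOrScalarBroadcast and VisitorColOrScalarBroadcast implement on-device corresponds to ordinary NumPy/PyTorch broadcasting of the scale tensors, and VisitorRowOrZeroBroadcast falls back to broadcasting zero when its pointer is null (e.g. an absent bias). A minimal PyTorch sketch with made-up shapes:

```python
import torch

m, n, k = 4, 8, 16
a = torch.randn(m, k)
b = torch.randn(k, n)

# Per-tensor scales: a single value broadcast everywhere
# (the visitor configured with row_broadcast/col_broadcast == False).
a_scale = torch.tensor([[0.5]])
b_scale = torch.tensor([[0.25]])
d_per_tensor = (a_scale * a) @ (b_scale * b)

# Per-token scale for A (column vector) and per-channel scale for B
# (row vector): the row_broadcast/col_broadcast == True configuration.
a_scale_token = torch.rand(m, 1)
b_scale_channel = torch.rand(1, n)
d_per_token = (a_scale_token * a) @ (b_scale_channel * b)
```
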
cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp ADDED
@@ -0,0 +1,447 @@
1
+ /***************************************************************************************************
2
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
3
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * 1. Redistributions of source code must retain the above copyright notice,
9
+ *this list of conditions and the following disclaimer.
10
+ *
11
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * 3. Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22
+ *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23
+ *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24
+ *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25
+ *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26
+ *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27
+ *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28
+ *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29
+ *POSSIBILITY OF SUCH DAMAGE.
30
+ *
31
+ **************************************************************************************************/
32
+
33
+ //
34
+ // This file is a modified excerpt of
35
+ // include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
36
+ // from https://github.com/NVIDIA/cutlass v3.5.0
37
+ // It has been modified to support either row/column or scalar broadcasting
38
+ // where the tensor being loaded from is always passed in via a device pointer.
39
+ // This lets one compiled kernel handle all cases of per-tensor or
40
+ // per-channel/per-token quantization.
41
+ //
42
+ // This interface also allows the scales to be passed in as tensors that
43
+ // consistently reside on the device, which avoids an issue with a previous
44
+ // implementation where scalars needed to be on the CPU since they
45
+ // were passed in via float values. This created a potential performance hazard
46
+ // if scales were initially on the device, and caused torch.compile graph
47
+ // breaks when moving scales to the CPU.
48
+ //
49
+ #pragma once
50
+
51
+ // Turn off clang-format for the entire file to keep it close to upstream
52
+ // clang-format off
53
+
54
+ #include "cutlass/cutlass.h"
55
+ #include "cutlass/arch/barrier.h"
56
+
57
+ #include "cute/tensor.hpp"
58
+ #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
59
+
60
+ namespace cutlass::epilogue::fusion {
61
+
62
+ using namespace cute;
63
+ using namespace detail;
64
+
65
+ // Row vector broadcast
66
+ template<
67
+ int Stages,
68
+ class CtaTileShapeMNK,
69
+ class Element,
70
+ class StrideMNL = Stride<_0,_1,_0>,
71
+ int Alignment = 128 / sizeof_bits_v<Element>
72
+ >
73
+ struct Sm90RowOrScalarBroadcast {
74
+ static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
75
+ static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
76
+ static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
77
+
78
+ struct SharedStorage {
79
+ array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
80
+ };
81
+
82
+ // This struct has been modified to have a bool indicating that ptr_row is a
83
+ // scalar that must be broadcast, instead of containing a scalar that is
84
+ // valid if ptr_row is null.
85
+ struct Arguments {
86
+ Element const* ptr_row = nullptr;
87
+ bool row_broadcast = true;
88
+ StrideMNL dRow = {};
89
+ };
90
+
91
+ using Params = Arguments;
92
+
93
+ template <class ProblemShape>
94
+ static constexpr Params
95
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
96
+ return args;
97
+ }
98
+
99
+ template <class ProblemShape>
100
+ static bool
101
+ can_implement(ProblemShape const& problem_shape, Arguments const& args) {
102
+ return true;
103
+ }
104
+
105
+ template <class ProblemShape>
106
+ static size_t
107
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
108
+ return 0;
109
+ }
110
+
111
+ template <class ProblemShape>
112
+ static cutlass::Status
113
+ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
114
+ CudaHostAdapter* cuda_adapter = nullptr) {
115
+ return cutlass::Status::kSuccess;
116
+ }
117
+
118
+ CUTLASS_HOST_DEVICE
119
+ Sm90RowOrScalarBroadcast() { }
120
+
121
+ CUTLASS_HOST_DEVICE
122
+ Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
123
+ : params(params)
124
+ , smem(const_cast<Element*>(shared_storage.smem.data())) { }
125
+
126
+ Params params;
127
+ Element *smem = nullptr;
128
+
129
+ CUTLASS_DEVICE bool
130
+ is_producer_load_needed() const {
131
+ return false;
132
+ }
133
+
134
+ CUTLASS_DEVICE bool
135
+ is_C_load_needed() const {
136
+ return false;
137
+ }
138
+
139
+ CUTLASS_DEVICE bool
140
+ is_zero() const {
141
+ return (!params.row_broadcast && *(params.ptr_row) == Element(0));
142
+ }
143
+
144
+ template <class... Args>
145
+ CUTLASS_DEVICE auto
146
+ get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
147
+ return EmptyProducerLoadCallbacks{};
148
+ }
149
+
150
+ template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
151
+ struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
152
+ CUTLASS_DEVICE
153
+ ConsumerStoreCallbacks(
154
+ GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
155
+ GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
156
+ SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
157
+ CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
158
+ : tGS_gRow(tGS_gRow_)
159
+ , tGS_sRow(tGS_sRow_)
160
+ , tGS_cRow(tGS_cRow_)
161
+ , tiled_G2S(tiled_g2s_)
162
+ , tSR_sRow(tSR_sRow_)
163
+ , tSR_rRow(tSR_rRow_)
164
+ , tCcRow(tCcRow_)
165
+ , residue_tCcRow(residue_tCcRow_)
166
+ , params(params_) {}
167
+
168
+ GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
169
+ GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
170
+ GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
171
+ Tiled_G2S tiled_G2S;
172
+
173
+ SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
174
+ SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
175
+
176
+ CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
177
+ ThrResidue residue_tCcRow; // (m, n)
178
+ ThrNum thr_num;
179
+ Params const& params;
180
+
181
+ CUTLASS_DEVICE void
182
+ begin() {
183
+ if (!params.row_broadcast) {
184
+ fill(tSR_rRow, *(params.ptr_row));
185
+ return;
186
+ }
187
+
188
+ auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
189
+ Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
190
+ Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
191
+ Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
192
+
193
+ for (int i = 0; i < size(tGS_gRow_flt); ++i) {
194
+ if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
195
+ continue; // OOB of SMEM,
196
+ }
197
+ if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
198
+ tGS_sRow_flt(i) = tGS_gRow_flt(i);
199
+ }
200
+ else {
201
+ tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
202
+ }
203
+ }
204
+ synchronize();
205
+ }
206
+
207
+ CUTLASS_DEVICE void
208
+ begin_loop(int epi_m, int epi_n) {
209
+ if (epi_m == 0) { // Assumes M-major subtile loop
210
+ if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
211
+ Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
212
+ Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
213
+ copy(tSR_sRow_flt, tSR_rRow_flt);
214
+ }
215
+ }
216
+
217
+ template <typename ElementAccumulator, int FragmentSize>
218
+ CUTLASS_DEVICE Array<Element, FragmentSize>
219
+ visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
220
+ Array<Element, FragmentSize> frg_row;
221
+
222
+ CUTLASS_PRAGMA_UNROLL
223
+ for (int i = 0; i < FragmentSize; ++i) {
224
+ frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
225
+ }
226
+
227
+ return frg_row;
228
+ }
229
+ };
230
+
231
+ template <
232
+ bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
233
+ class... Args
234
+ >
235
+ CUTLASS_DEVICE auto
236
+ get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
237
+ auto [M, N, K, L] = args.problem_shape_mnkl;
238
+ auto [m, n, k, l] = args.tile_coord_mnkl;
239
+ using ThreadCount = decltype(size(args.tiled_copy));
240
+
241
+ Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
242
+ Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
243
+ Tensor sRow = make_tensor(make_smem_ptr(smem),
244
+ make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
245
+ //// G2S: Gmem to Smem
246
+ auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
247
+ Layout< Shape<_1, ThreadCount>,
248
+ Stride<_0, _1>>{},
249
+ Layout<_1>{});
250
+ auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
251
+ Tensor tGS_gRow = thr_g2s.partition_S(gRow);
252
+ Tensor tGS_sRow = thr_g2s.partition_D(sRow);
253
+
254
+ //// G2S: Coord
255
+ auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
256
+ Tensor tGS_cRow = thr_g2s.partition_S(cRow);
257
+
258
+ //// S2R: Smem to Reg
259
+ Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
260
+ Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
261
+
262
+ return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
263
+ tGS_gRow,
264
+ tGS_sRow,
265
+ tGS_cRow, tiled_g2s,
266
+ tSR_sRow,
267
+ tSR_rRow,
268
+ args.tCcD,
269
+ args.residue_cD,
270
+ ThreadCount{},
271
+ params);
272
+ }
273
+ };
274
+
275
+ /////////////////////////////////////////////////////////////////////////////////////////////////
276
+
277
+ // Column vector broadcast
278
+ template<
279
+ int Stages,
280
+ class CtaTileShapeMNK,
281
+ class Element,
282
+ class StrideMNL = Stride<_1,_0,_0>,
283
+ int Alignment = 128 / sizeof_bits_v<Element>
284
+ >
285
+ struct Sm90ColOrScalarBroadcast {
286
+ static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
287
+ static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
288
+ static_assert(
289
+ (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) || // col vector broadcast, e.g. per-row alpha/bias
290
+ (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>)); // batched col vector broadcast, e.g. batched per-row bias
291
+
292
+ // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
293
+ struct SharedStorage { };
294
+
295
+ // This struct has been modified to have a bool indicating that ptr_col is a
296
+ // scalar that must be broadcast, instead of containing a scalar that is
297
+ // valid if ptr_col is null.
298
+ struct Arguments {
299
+ Element const* ptr_col = nullptr;
300
+ bool col_broadcast = true;
301
+ StrideMNL dCol = {};
302
+ };
303
+
304
+ using Params = Arguments;
305
+
306
+ template <class ProblemShape>
307
+ static constexpr Params
308
+ to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
309
+ return args;
310
+ }
311
+
312
+ template <class ProblemShape>
313
+ static bool
314
+ can_implement(ProblemShape const& problem_shape, Arguments const& args) {
315
+ return true;
316
+ }
317
+
318
+ template <class ProblemShape>
319
+ static size_t
320
+ get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
321
+ return 0;
322
+ }
323
+
324
+ template <class ProblemShape>
325
+ static cutlass::Status
326
+ initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
327
+ CudaHostAdapter* cuda_adapter = nullptr) {
328
+ return cutlass::Status::kSuccess;
329
+ }
330
+
331
+ CUTLASS_DEVICE bool
332
+ is_producer_load_needed() const {
333
+ return false;
334
+ }
335
+
336
+ CUTLASS_DEVICE bool
337
+ is_C_load_needed() const {
338
+ return false;
339
+ }
340
+
341
+ CUTLASS_DEVICE bool
342
+ is_zero() const {
343
+ return (!params.col_broadcast && *(params.ptr_col) == Element(0));
344
+ }
345
+
346
+ CUTLASS_HOST_DEVICE
347
+ Sm90ColOrScalarBroadcast() { }
348
+
349
+ CUTLASS_HOST_DEVICE
350
+ Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
351
+ : params(params) { }
352
+
353
+ Params params;
354
+
355
+ template <class... Args>
356
+ CUTLASS_DEVICE auto
357
+ get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
358
+ return EmptyProducerLoadCallbacks{};
359
+ }
360
+
361
+ template<class GTensor, class RTensor, class CTensor, class ProblemShape>
362
+ struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
363
+ CUTLASS_DEVICE
364
+ ConsumerStoreCallbacks(
365
+ GTensor&& tCgCol,
366
+ RTensor&& tCrCol,
367
+ CTensor&& tCcCol,
368
+ ProblemShape problem_shape,
369
+ Params const& params
370
+ ):
371
+ tCgCol(cute::forward<GTensor>(tCgCol)),
372
+ tCrCol(cute::forward<RTensor>(tCrCol)),
373
+ tCcCol(cute::forward<CTensor>(tCcCol)),
374
+ m(get<0>(problem_shape)),
375
+ params(params) {}
376
+
377
+ GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
378
+ RTensor tCrCol;
379
+ CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
380
+ Params const& params;
381
+ int m;
382
+
383
+ CUTLASS_DEVICE void
384
+ begin() {
385
+ Tensor pred = make_tensor<bool>(shape(tCgCol));
386
+ CUTLASS_PRAGMA_UNROLL
387
+ for (int i = 0; i < size(pred); ++i) {
388
+ pred(i) = get<0>(tCcCol(i)) < m;
389
+ }
390
+
391
+ if (!params.col_broadcast) {
392
+ fill(tCrCol, *(params.ptr_col));
393
+ return;
394
+ }
395
+
396
+ // Filter so we don't issue redundant copies over stride-0 modes
397
+ // (only works if 0-strides are in same location, which is by construction)
398
+ copy_if(pred, filter(tCgCol), filter(tCrCol));
399
+ }
400
+
401
+ template <typename ElementAccumulator, int FragmentSize>
402
+ CUTLASS_DEVICE Array<Element, FragmentSize>
403
+ visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
404
+ Array<Element, FragmentSize> frg_col;
405
+ Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
406
+
407
+ CUTLASS_PRAGMA_UNROLL
408
+ for (int i = 0; i < FragmentSize; ++i) {
409
+ frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
410
+ }
411
+
412
+ return frg_col;
413
+ }
414
+
415
+ };
416
+
417
+ template <
418
+ bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
419
+ class... Args
420
+ >
421
+ CUTLASS_DEVICE auto
422
+ get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
423
+
424
+ auto [M, N, K, L] = args.problem_shape_mnkl;
425
+ Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
426
+ Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
427
+ mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
428
+ Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
429
+
430
+ // Generate an identity tensor matching the shape of the global tensor and
431
+ // partition the same way, this will be used to generate the predicate
432
+ // tensor for loading
433
+ Tensor cCol = make_identity_tensor(mCol.shape());
434
+ Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
435
+ cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
436
+
437
+ return ConsumerStoreCallbacks(
438
+ cute::move(tCgCol),
439
+ cute::move(tCrCol),
440
+ cute::move(tCcCol),
441
+ args.problem_shape_mnkl,
442
+ params
443
+ );
444
+ }
445
+ };
446
+
447
+ }
cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp ADDED
@@ -0,0 +1,317 @@
1
+ #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
2
+
3
+ /*
4
+ This file defines custom epilogues for fusing channel scales, token scales,
5
+ bias, and activation zero-points onto a GEMM operation using the
6
+ CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
7
+
8
+ Epilogues must contain a public type named EVTCompute of type Sm80EVT,
9
+ as well as a static prepare_args function that constructs an
10
+ EVTCompute::Arguments struct.
11
+ */
12
+
13
+ namespace vllm::c2x {
14
+
15
+ using namespace cute;
16
+
17
+ /*
18
+ * This class provides the common load descriptors for the
19
+ * ScaledEpilogue[...] classes
20
+ */
21
+ template <typename ElementD, typename OutputTileThreadMap>
22
+ struct ScaledEpilogueBase {
23
+ protected:
24
+ using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
25
+
26
+ template <typename T>
27
+ using ColOrScalarLoad =
28
+ cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
29
+ OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
30
+
31
+ template <typename T>
32
+ using RowOrScalarLoad =
33
+ cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
34
+ OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
35
+
36
+ template <typename T>
37
+ using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
38
+ OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;
39
+
40
+ template <typename T>
41
+ using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
42
+ OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
43
+
44
+ template <typename T>
45
+ using RowOrZeroLoad =
46
+ cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
47
+ OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;
48
+
49
+ // This utility function constructs the arguments for the load descriptors
50
+ // from a tensor. It can handle both row and column, as well as row/column or
51
+ // scalar cases.
52
+ template <typename Descriptor, typename T>
53
+ static auto args_from_tensor(torch::Tensor const& tensor) {
54
+ using Arguments = typename Descriptor::Arguments;
55
+ auto* data_ptr = static_cast<T*>(tensor.data_ptr());
56
+ if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
57
+ std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
58
+ return Arguments{data_ptr, tensor.numel() != 1};
59
+ } else {
60
+ // it would technically work but no use case as data_ptr is never nullptr
61
+ static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
62
+ return Arguments{data_ptr};
63
+ }
64
+ }
65
+
66
+ // This overload handles the case where there might not be a tensor, in which
67
+ // case a nullptr is passed and a constant (0) is used.
68
+ template <typename Descriptor, typename T>
69
+ static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
70
+ static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
71
+ using Arguments = typename Descriptor::Arguments;
72
+ auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
73
+ return Arguments{data_ptr};
74
+ }
75
+ };
76
+
77
+ /*
78
+ This epilogue function defines a quantized GEMM operation similar to
79
+ torch._scaled_mm.
80
+
81
+ A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
82
+ per-row. B can be quantized per-tensor or per-column.
83
+ Any combination of per-tensor and per-row or column is supported.
84
+ A and B must have symmetric quantization (zero point == 0).
85
+
86
+ So the GEMM operation is D = (a_scales * A) @ (b_scales * B), where the
87
+ scales are applied elementwise with numpy-style broadcasting.
88
+
89
+ ScaleA and ScaleB define the epilogue functions that apply the scales for
90
+ the A and B operands respectively. These scales may be either per-tensor or
91
+ per row or column.
92
+ */
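
As a reference for the comment above, a PyTorch sketch of the intended computation (illustrative only, not part of this commit; the function name is editorial, and the optional bias belongs to the ScaledEpilogueBias variant defined below):

```python
import torch

def scaled_mm_ref(a_q, b_q, a_scales, b_scales, out_dtype=torch.bfloat16, bias=None):
    # D = (a_scales * A) @ (b_scales * B), scales broadcast numpy-style:
    # a_scales is (1, 1) per-tensor or (m, 1) per-token,
    # b_scales is (1, 1) per-tensor or (1, n) per-channel.
    d = (a_scales * a_q.to(torch.float32)) @ (b_scales * b_q.to(torch.float32))
    if bias is not None:
        d = d + bias  # per-output-channel bias (ScaledEpilogueBias)
    return d.to(out_dtype)

# Example: int8 inputs with per-token/per-channel scales.
a_q = torch.randint(-128, 128, (4, 16), dtype=torch.int8)
b_q = torch.randint(-128, 128, (16, 8), dtype=torch.int8)
d = scaled_mm_ref(a_q, b_q, torch.rand(4, 1), torch.rand(1, 8))
```
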
93
+ template <typename ElementD, typename OutputTileThreadMap>
94
+ struct ScaledEpilogue
95
+ : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
96
+ private:
97
+ using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
98
+ using Accum = typename SUPER::Accum;
99
+ using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
100
+ using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
101
+
102
+ using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
103
+ cutlass::multiplies, float, float,
104
+ cutlass::FloatRoundStyle::round_to_nearest>;
105
+
106
+ using EVTCompute0 =
107
+ cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
108
+
109
+ using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
110
+ cutlass::multiplies, ElementD, float,
111
+ cutlass::FloatRoundStyle::round_to_nearest>;
112
+
113
+ public:
114
+ using EVTCompute =
115
+ cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
116
+ using ArgumentType = typename EVTCompute::Arguments;
117
+
118
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
119
+ torch::Tensor const& b_scales) {
120
+ auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
121
+ auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
122
+
123
+ typename EVTCompute0::Arguments evt0_args{b_args};
124
+ return ArgumentType{a_args, evt0_args};
125
+ }
126
+ };
127
+
128
+ /*
129
+ * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
130
+ * This bias can also be used in the per-tensor azp case, where the activation
131
+ * zero point (azp) is used to compute an azp correction term,
132
+ * which is folded into the bias.
133
+ *
134
+ * The bias tensor must be per-output channel.
135
+ * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
136
+ */
137
+ template <typename ElementD, typename OutputTileThreadMap>
138
+ struct ScaledEpilogueBias
139
+ : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
140
+ protected:
141
+ using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
142
+ using Accum = typename SUPER::Accum;
143
+ using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
144
+ using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
145
+ using Bias = typename SUPER::template RowLoad<ElementD>;
146
+ using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
147
+ cutlass::multiplies, float, float,
148
+ cutlass::FloatRoundStyle::round_to_nearest>;
149
+
150
+ using EVTCompute0 =
151
+ cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
152
+
153
+ using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
154
+ cutlass::multiply_add, ElementD, float,
155
+ cutlass::FloatRoundStyle::round_to_nearest>;
156
+
157
+ public:
158
+ using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
159
+ EVTCompute0, Bias>;
160
+ using ArgumentType = typename EVTCompute::Arguments;
161
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
162
+ torch::Tensor const& b_scales,
163
+ torch::Tensor const& bias) {
164
+ auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
165
+ auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
166
+ auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
167
+
168
+ typename EVTCompute0::Arguments evt0_args{b_args};
169
+ return ArgumentType{a_args, evt0_args, bias_args};
170
+ }
171
+ };
172
+
173
+ /*
174
+ * This epilogue directly supports per-tensor azp in int32 form.
175
+ * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
176
+ * term, which should already be multiplied with the scalar azp.
177
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
178
+ *
179
+ * This epilogue also supports bias, which remains per-channel.
180
+ */
181
+ template <typename ElementD, typename OutputTileThreadMap>
182
+ struct ScaledEpilogueBiasAzp
183
+ : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
184
+ private:
185
+ using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
186
+ using Accum = typename SUPER::Accum;
187
+ using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
188
+ using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
189
+ using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
190
+
191
+ // This is the full AZP term, azp * J @ B, shape (1,n)
192
+ using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
193
+
194
+ // Compute float(accum - azp_adj), both operands are int32_t
195
+ using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
196
+ cutlass::minus, float, int32_t,
197
+ cutlass::FloatRoundStyle::round_to_nearest>;
198
+
199
+ using EVTComputeAzp =
200
+ cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;
201
+
202
+ using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
203
+ cutlass::multiplies, float, float,
204
+ cutlass::FloatRoundStyle::round_to_nearest>;
205
+
206
+ using EVTComputeScaleB =
207
+ cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
208
+ EVTComputeAzp>;
209
+
210
+ using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
211
+ cutlass::multiply_add, ElementD, float,
212
+ cutlass::FloatRoundStyle::round_to_nearest>;
213
+
214
+ public:
215
+ using EVTCompute =
216
+ cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
217
+ EVTComputeScaleB, Bias>;
218
+
219
+ using ArgumentType = typename EVTCompute::Arguments;
220
+
221
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
222
+ torch::Tensor const& b_scales,
223
+ torch::Tensor const& azp_adj,
224
+ c10::optional<torch::Tensor> const& bias) {
225
+ auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
226
+ auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
227
+ auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
228
+ auto azp_adj_args =
229
+ SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
230
+
231
+ typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
232
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
233
+ return ArgumentType{a_args, evt_scale_b_args, bias_args};
234
+ }
235
+ };
236
+
237
+ /*
238
+ * This epilogue supports per-token azp by computing and applying
239
+ * the correction term using a rank-1 update. If the term were materialized,
240
+ * it would require O(m*n) space, and this way it only requires O(m+n) space.
241
+ * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
242
+ * point for each row of A.
243
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
244
+ *
245
+ * This epilogue also supports bias, which remains per-channel.
246
+ */
247
+ template <typename ElementD, typename OutputTileThreadMap>
248
+ struct ScaledEpilogueBiasAzpToken
249
+ : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
250
+ private:
251
+ using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
252
+ using Accum = typename SUPER::Accum;
253
+ using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
254
+ using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
255
+ using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;
256
+
257
+ // Per-token azp term, shape (m,1)
258
+ using Azp = typename SUPER::template ColLoad<int32_t>;
259
+
260
+ // This is the AZP adjustment term, J @ B, shape (1,n)
261
+ using AzpAdj = typename SUPER::template RowLoad<int32_t>;
262
+
263
+ // Compute azp * azp_adj
264
+ using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
265
+ cutlass::multiplies, int32_t, int32_t,
266
+ cutlass::FloatRoundStyle::round_to_nearest>;
267
+
268
+ using EVTComputeAzp =
269
+ cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;
270
+
271
+ // Compute float(accum - azp*azp_adj), all operands are int32_t
272
+ using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
273
+ cutlass::minus, float, int32_t,
274
+ cutlass::FloatRoundStyle::round_to_nearest>;
275
+
276
+ using EVTComputeAcc =
277
+ cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;
278
+
279
+ using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
280
+ cutlass::multiplies, float, float,
281
+ cutlass::FloatRoundStyle::round_to_nearest>;
282
+
283
+ using EVTComputeScaleB =
284
+ cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
285
+ EVTComputeAcc>;
286
+
287
+ using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
288
+ cutlass::multiply_add, ElementD, float,
289
+ cutlass::FloatRoundStyle::round_to_nearest>;
290
+
291
+ public:
292
+ using EVTCompute =
293
+ cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
294
+ EVTComputeScaleB, Bias>;
295
+
296
+ using ArgumentType = typename EVTCompute::Arguments;
297
+
298
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
299
+ torch::Tensor const& b_scales,
300
+ torch::Tensor const& azp_adj,
301
+ torch::Tensor const& azp,
302
+ c10::optional<torch::Tensor> const& bias) {
303
+ auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
304
+ auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
305
+ auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
306
+ auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
307
+ auto azp_adj_args =
308
+ SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
309
+
310
+ typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
311
+ typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
312
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
313
+ return ArgumentType{a_args, evt_scale_b_args, bias_args};
314
+ }
315
+ };
316
+
317
+ }; // namespace vllm::c2x
cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp ADDED
@@ -0,0 +1,315 @@
1
+ #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
2
+
3
+ /*
4
+ This file defines custom epilogues for fusing channel scales, token scales,
5
+ bias, and activation zero-points onto a GEMM operation using the
6
+ CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.
7
+
8
+ Epilogues must contain a public type named EVTCompute of type Sm90EVT,
9
+ as well as a static prepare_args function that constructs an
10
+ EVTCompute::Arguments struct.
11
+ */
12
+
13
+ namespace vllm::c3x {
14
+
15
+ using namespace cute;
16
+
17
+ /*
18
+ * This class provides the common load descriptors for the
19
+ * ScaledEpilogue[...] classes
20
+ */
21
+ template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
22
+ struct ScaledEpilogueBase {
23
+ protected:
24
+ using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
25
+
26
+ template <typename T>
27
+ using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
28
+ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
29
+ Stride<Int<1>, Int<0>, Int<0>>>;
30
+
31
+ template <typename T>
32
+ using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
33
+ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
34
+ Stride<Int<0>, Int<1>, Int<0>>>;
35
+
36
+ // Don't want to support nullptr by default
37
+ template <typename T, bool EnableNullPtr = false>
38
+ using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
39
+ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
40
+ Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
41
+
42
+ // Don't want to support nullptr by default
43
+ template <typename T, bool EnableNullPtr = false>
44
+ using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
45
+ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
46
+ Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
47
+
48
+ // This utility function constructs the arguments for the load descriptors
49
+ // from a tensor. It can handle both row and column, as well as row/column or
50
+ // scalar cases.
51
+ template <typename Descriptor, typename T>
52
+ static auto args_from_tensor(torch::Tensor const& tensor) {
53
+ using Arguments = typename Descriptor::Arguments;
54
+ auto* data_ptr = static_cast<T*>(tensor.data_ptr());
55
+ if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
56
+ std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
57
+ return Arguments{data_ptr, tensor.numel() != 1};
58
+ } else {
59
+ static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
60
+ !std::is_same_v<Descriptor, RowLoad<T, true>>);
61
+ return Arguments{data_ptr};
62
+ }
63
+ }
64
+
65
+ // This overload handles the case where there might not be a tensor, in which
66
+ // case a nullptr is passed and a constant (0) is used.
67
+ template <typename Descriptor, typename T>
68
+ static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
69
+ using Arguments = typename Descriptor::Arguments;
70
+ auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
71
+ static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
72
+ std::is_same_v<Descriptor, RowLoad<T, true>>);
73
+ return Arguments{data_ptr};
74
+ }
75
+ };
76
+
77
+ /*
78
+ This epilogue function defines a quantized GEMM operation similar to
79
+ torch._scaled_mm.
80
+
81
+ A and B may be both either int8 or fp8_e4m3. A can be
82
+ quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
83
+ Any combination of per-tensor and per-row or column is supported.
84
+ A and B must have symmetric quantization (zero point == 0).
85
+
86
+ So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
87
+ scales are applied elementwise with numpy-style broadcasting.
88
+
89
+ ScaleA and ScaleB define the epilogue functions that apply the scales for
90
+ the A and B operands respectively. These scales may be either per-tensor or
91
+ per-row or per-column.
92
+ */
93
+ template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
94
+ struct ScaledEpilogue
95
+ : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
96
+ private:
97
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
98
+ using Accum = typename SUPER::Accum;
99
+ using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
100
+ using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
101
+
102
+ using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
103
+ cutlass::multiplies, float, float,
104
+ cutlass::FloatRoundStyle::round_to_nearest>;
105
+
106
+ using EVTCompute0 =
107
+ cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
108
+
109
+ using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
110
+ cutlass::multiplies, ElementD, float,
111
+ cutlass::FloatRoundStyle::round_to_nearest>;
112
+
113
+ public:
114
+ using EVTCompute =
115
+ cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
116
+ using ArgumentType = typename EVTCompute::Arguments;
117
+
118
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
119
+ torch::Tensor const& b_scales) {
120
+ auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
121
+ auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
122
+
123
+ typename EVTCompute0::Arguments evt0_args{b_args};
124
+ return ArgumentType{a_args, evt0_args};
125
+ }
126
+ };
127
+
128
+ /*
129
+ * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
130
+ * This bias can also be used in the per-tensor azp case, where the activation
131
+ * zero point (azp) is used to compute an azp correction term,
132
+ * which is folded into the bias.
133
+ *
134
+ * The bias tensor must be per-output channel.
135
+ * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
136
+ */
137
+ template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
138
+ struct ScaledEpilogueBias
139
+ : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
140
+ private:
141
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
142
+ using Accum = typename SUPER::Accum;
143
+ using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
144
+ using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
145
+ using Bias = typename SUPER::template RowLoad<ElementD>;
146
+
147
+ using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
148
+ cutlass::multiplies, float, float,
149
+ cutlass::FloatRoundStyle::round_to_nearest>;
150
+
151
+ using EVTCompute0 =
152
+ cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
153
+
154
+ using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
155
+ cutlass::multiply_add, ElementD, float,
156
+ cutlass::FloatRoundStyle::round_to_nearest>;
157
+
158
+ public:
159
+ using EVTCompute =
160
+ cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
161
+
162
+ using ArgumentType = typename EVTCompute::Arguments;
163
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
164
+ torch::Tensor const& b_scales,
165
+ torch::Tensor const& bias) {
166
+ auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
167
+ auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
168
+ auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
169
+
170
+ typename EVTCompute0::Arguments evt0_args{b_args};
171
+ return ArgumentType{a_args, evt0_args, bias_args};
172
+ }
173
+ };
174
+
175
+ /*
176
+ * This epilogue directly supports per-tensor azp in int32 form.
177
+ * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
178
+ * term, which should already be multiplied with the scalar azp.
179
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
180
+ *
181
+ * This epilogue also supports bias, which remains per-channel.
182
+ */
183
+ template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
184
+ struct ScaledEpilogueBiasAzp
185
+ : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
186
+ private:
187
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
188
+ using Accum = typename SUPER::Accum;
189
+ using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
190
+ using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
191
+ using Bias = typename SUPER::template RowLoad<ElementD, true>;
192
+
193
+ // This is the full AZP term, azp * J @ B, shape (1,n)
194
+ using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;
195
+
196
+ // Compute float(accum - azp_adj), both operands are int32_t
197
+ using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
198
+ cutlass::minus, float, int32_t,
199
+ cutlass::FloatRoundStyle::round_to_nearest>;
200
+
201
+ using EVTComputeAzp =
202
+ cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;
203
+
204
+ using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
205
+ cutlass::multiplies, float, float,
206
+ cutlass::FloatRoundStyle::round_to_nearest>;
207
+
208
+ using EVTComputeScaleB =
209
+ cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
210
+
211
+ using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
212
+ cutlass::multiply_add, ElementD, float,
213
+ cutlass::FloatRoundStyle::round_to_nearest>;
214
+
215
+ public:
216
+ using EVTCompute =
217
+ cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
218
+ EVTComputeScaleB, Bias>;
219
+ using ArgumentType = typename EVTCompute::Arguments;
220
+
221
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
222
+ torch::Tensor const& b_scales,
223
+ torch::Tensor const& azp_adj,
224
+ c10::optional<torch::Tensor> const& bias) {
225
+ auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
226
+ auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
227
+ auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
228
+ auto azp_adj_args =
229
+ SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);
230
+
231
+ typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
232
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
233
+ return ArgumentType{a_args, evt_scale_b_args, bias_args};
234
+ }
235
+ };
236
+
237
+ /*
238
+ * This epilogue supports per-token azp by computing and applying
239
+ * the correction term using a rank-1 update. If the term were materialized,
240
+ * it would require O(m*n) space, and this way it only requires O(m+n) space.
241
+ * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
242
+ * point for each row of A.
243
+ * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
244
+ *
245
+ * This epilogue also supports bias, which remains per-channel.
246
+ */
247
+ template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
248
+ struct ScaledEpilogueBiasAzpToken
249
+ : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
250
+ private:
251
+ using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
252
+ using Accum = typename SUPER::Accum;
253
+ using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
254
+ using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
255
+ using Bias = typename SUPER::template RowLoad<ElementD, true>;
256
+
257
+ // Per-token azp term, shape (m,1)
258
+ using Azp = typename SUPER::template ColLoad<int32_t>;
259
+
260
+ // This is the AZP adjustment term, J @ B, shape (1,n)
261
+ using AzpAdj = typename SUPER::template RowLoad<int32_t>;
262
+
263
+ // Compute azp * azp_adj
264
+ using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
265
+ cutlass::multiplies, int32_t, int32_t,
266
+ cutlass::FloatRoundStyle::round_to_nearest>;
267
+
268
+ using EVTComputeAzp =
269
+ cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;
270
+
271
+ // Compute float(accum - azp*azp_adj), all operands are int32_t
272
+ using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
273
+ cutlass::minus, float, int32_t,
274
+ cutlass::FloatRoundStyle::round_to_nearest>;
275
+
276
+ using EVTComputeAcc =
277
+ cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;
278
+
279
+ using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
280
+ cutlass::multiplies, float, float,
281
+ cutlass::FloatRoundStyle::round_to_nearest>;
282
+
283
+ using EVTComputeScaleB =
284
+ cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
285
+
286
+ using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
287
+ cutlass::multiply_add, ElementD, float,
288
+ cutlass::FloatRoundStyle::round_to_nearest>;
289
+
290
+ public:
291
+ using EVTCompute =
292
+ cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
293
+ EVTComputeScaleB, Bias>;
294
+ using ArgumentType = typename EVTCompute::Arguments;
295
+
296
+ static ArgumentType prepare_args(torch::Tensor const& a_scales,
297
+ torch::Tensor const& b_scales,
298
+ torch::Tensor const& azp_adj,
299
+ torch::Tensor const& azp,
300
+ c10::optional<torch::Tensor> const& bias) {
301
+ auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
302
+ auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
303
+ auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
304
+ auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
305
+ auto azp_adj_args =
306
+ SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);
307
+
308
+ typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
309
+ typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
310
+ typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
311
+ return ArgumentType{a_args, evt_scale_b_args, bias_args};
312
+ }
313
+ };
314
+
315
+ }; // namespace vllm::c3x
cutlass_w8a8/Epilogues.md ADDED
@@ -0,0 +1,147 @@
1
+ # CUTLASS Epilogues
2
+
3
+ ## Introduction
4
+ This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.
5
+
6
+ Currently, we only support symmetric quantization for weights,
7
+ and symmetric and asymmetric quantization for activations.
8
+ Both can be quantized per-tensor, or per-channel for weights and per-token for activations.
9
+
10
+ There are 4 epilogues:
11
+ 1. ScaledEpilogue: symmetric quantization for activations, no bias.
12
+ 1. ScaledEpilogueBias: symmetric quantization for activations, supports bias.
13
+ 1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias.
14
+ 1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias.
15
+
16
+ We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
17
+ Instead, if no bias is passed, the epilogue will use 0 as the bias.
18
+ That induces a redundant addition operation (and runtime check), but the performance impact is minor.
19
+
20
+ ## Underlying Linear Algebra
21
+
22
+ More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975).
23
+
24
+ If $` \widehat X `$ is the quantized $` X `$, our matrices become the following
25
+
26
+ ```math
27
+ A = s_a (\widehat A - J_a z_a)
28
+ ```
29
+ ```math
30
+ B = s_b \widehat B
31
+ ```
32
+ ```math
33
+ D = A B + C
34
+ ```
35
+ ```math
36
+ D = s_a s_b \widehat D + C
37
+ ```
38
+
39
+ Here, D is the output of the GEMM, and C is the bias.
40
+ A holds the activations and supports asymmetric quantization,
41
+ while B holds the weights and supports only symmetric quantization.
42
+ $` s_a `$ and $` s_b `$ are the scales for activations and weights, respectively.
43
+ $` z_a `$ is the zero-point for activations, and $` J_a `$ is the matrix of all ones with the dimensions of $` A `$.
44
+ Additional epilogues would be required to support asymmetric quantization for weights.
45
+
46
+ Expanding further, we can calculate $` \widehat D `$ as follows:
47
+
48
+ ```math
49
+ A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
50
+ ```
51
+ ```math
52
+ A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
53
+ ```
54
+ ```math
55
+ \widehat D = \widehat A \widehat B - z_a J_a \widehat B
56
+ ```
57
+
58
+ Note that $` \widehat A \widehat B `$ is the raw output of the GEMM,
59
+ and $` J_a \widehat B `$ is known ahead of time.
60
+ Each row of $` J_a \widehat B `$ is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of the column sums of $` \widehat B `$.
61
+
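+ As a quick sanity check of the identity above, the following is a minimal NumPy sketch (values and variable names are illustrative only, not the kernel code):
+
+ ```python
+ import numpy as np
+
+ rng = np.random.default_rng(0)
+ m, n, k = 4, 8, 16
+ A_q = rng.integers(-128, 128, size=(m, k), dtype=np.int32)  # quantized activations
+ B_q = rng.integers(-128, 128, size=(k, n), dtype=np.int32)  # quantized weights
+ s_a, s_b, z_a = 0.02, 0.05, 3                                # per-tensor scales and zero-point
+
+ # Reference: dequantize first, then multiply.
+ D_ref = (s_a * (A_q - z_a)) @ (s_b * B_q)
+
+ # Epilogue view: raw int32 GEMM, minus z_a times the column sums of B_q, then scales.
+ azp_adj = B_q.sum(axis=0)          # one row of J_a @ B_q
+ D_hat = A_q @ B_q - z_a * azp_adj  # broadcasts over the m rows
+ assert np.allclose(D_ref, s_a * s_b * D_hat)
+ ```
+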
62
+ ## Epilogues
63
+
64
+ ### ScaledEpilogue
65
+ This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
66
+ The output of the GEMM is:
67
+
68
+ ```math
69
+ \widehat D = \widehat A \widehat B
70
+ ```
71
+ ```math
72
+ D = s_a s_b \widehat D
73
+ ```
74
+ ```math
75
+ D = s_a s_b \widehat A \widehat B
76
+ ```
77
+
78
+ Epilogue parameters:
79
+ - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
80
+ - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
81
+
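+ The two scale layouts compose through plain broadcasting. A minimal NumPy sketch of the shapes (sizes and names are made up for illustration):
+
+ ```python
+ import numpy as np
+
+ m, n, k = 4, 8, 16
+ A_q = np.ones((m, k), dtype=np.int32)   # quantized activations
+ B_q = np.ones((k, n), dtype=np.int32)   # quantized weights
+
+ scale_a = np.full((m, 1), 0.02)         # per-token: one scale per row of A
+ scale_b = np.full((1, n), 0.05)         # per-channel: one scale per column of B
+ # Either (or both) could instead be a plain scalar for per-tensor scaling.
+
+ out = scale_a * scale_b * (A_q @ B_q)   # broadcasting yields an (m, n) result
+ assert out.shape == (m, n)
+ ```
+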
82
+ ### ScaledEpilogueBias
83
+ This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
84
+ The output of the GEMM is:
85
+
86
+ ```math
87
+ \widehat D = \widehat A \widehat B
88
+ ```
89
+ ```math
90
+ D = s_a s_b \widehat D + C
91
+ ```
92
+ ```math
93
+ D = s_a s_b \widehat A \widehat B + C
94
+ ```
95
+
96
+
97
+ Epilogue parameters:
98
+ - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
99
+ - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
100
+ - `bias` is the bias; it is always per-channel (row-vector).
101
+
102
+ ### ScaledEpilogueAzp
103
+ This epilogue computes the asymmetric per-tensor quantization for activations with bias.
104
+ The output of the GEMM is:
105
+
106
+ ```math
107
+ \widehat D = \widehat A \widehat B - z_a J_a \widehat B
108
+ ```
109
+ ```math
110
+ D = s_a s_b \widehat D + C
111
+ ```
112
+ ```math
113
+ D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
114
+ ```
115
+
116
+ Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 \widehat B `$.
117
+ That is precomputed and stored in `azp_with_adj` as a row-vector.
118
+
119
+ Epilogue parameters:
120
+ - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
121
+ - Generally this will be per-tensor as the zero-points are per-tensor.
122
+ - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
123
+ - `azp_with_adj` is the precomputed zero-point term ($` z_a \mathbf 1 \widehat B `$); it is per-channel (row-vector).
124
+ - `bias` is the bias; it is always per-channel (row-vector).
125
+
126
+ To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.
127
+
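+ A minimal NumPy sketch of that offline precomputation (variable names are illustrative; the epilogues above expect the result as an int32 row-vector):
+
+ ```python
+ import numpy as np
+
+ rng = np.random.default_rng(0)
+ n, k = 8, 16
+ B_q = rng.integers(-128, 128, size=(k, n), dtype=np.int32)  # quantized weights
+ z_a = 3                                                     # per-tensor activation zero-point
+
+ # One row of z_a * J_a @ B_q: the zero-point times the column sums of B_q.
+ azp_with_adj = (z_a * B_q.sum(axis=0)).astype(np.int32)     # shape (n,)
+ ```
+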
128
+ ### ScaledEpilogueAzpPerToken
129
+ This epilogue computes the asymmetric per-token quantization for activations with bias.
130
+
131
+ The output of the GEMM is the same as above, but $` z_a `$ is now a column-vector.
132
+ That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.
133
+
134
+ Epilogue parameters:
135
+ - `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
136
+ - Generally this will be per-token as the zero-points are per-token.
137
+ - `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
138
+ - `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$); it is per-channel (row-vector).
139
+ - `azp` is the zero-point (`z_a`); it is per-token (column-vector).
140
+ - `bias` is the bias; it is always per-channel (row-vector).
141
+
142
+ To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.
143
+
144
+ The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
145
+ ```
146
+ out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
147
+ ```
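+
+ A minimal NumPy sketch of the whole per-token path, checked against dequantize-then-matmul (all names and values are illustrative, not the kernel code):
+
+ ```python
+ import numpy as np
+
+ rng = np.random.default_rng(0)
+ m, n, k = 4, 8, 16
+ A_q = rng.integers(-128, 128, size=(m, k), dtype=np.int32)  # quantized activations
+ B_q = rng.integers(-128, 128, size=(k, n), dtype=np.int32)  # quantized weights
+ azp = rng.integers(0, 5, size=(m, 1), dtype=np.int32)       # per-token zero-points, shape (m, 1)
+ scale_a = rng.random((m, 1))                                # per-token scales
+ scale_b = rng.random((1, n))                                # per-channel scales
+ bias = rng.random((1, n))                                   # per-channel bias
+
+ azp_adj = B_q.sum(axis=0, keepdims=True)  # 1 @ B_q, shape (1, n), precomputed offline
+ Dq = A_q @ B_q                            # raw int32 GEMM output
+
+ # Rank-1 correction via broadcasting: azp (m, 1) times azp_adj (1, n).
+ out = scale_a * scale_b * (Dq - azp * azp_adj) + bias
+
+ # Reference: dequantize the activations and weights first, then multiply.
+ ref = (scale_a * (A_q - azp)) @ (scale_b * B_q) + bias
+ assert np.allclose(out, ref)
+ ```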
cutlass_w8a8/common.hpp ADDED
@@ -0,0 +1,27 @@
1
+ #pragma once
2
+
3
+ #include "cutlass/cutlass.h"
4
+ #include <climits>
5
+
6
+ /**
7
+ * Helper function for checking CUTLASS errors
8
+ */
9
+ #define CUTLASS_CHECK(status) \
10
+ { \
11
+ TORCH_CHECK(status == cutlass::Status::kSuccess, \
12
+ cutlassGetStatusString(status)) \
13
+ }
14
+
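+ // Rounds num up to the next power of two (e.g. 17 -> 32); powers of two and
+ // values <= 1 are returned unchanged.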
15
+ inline uint32_t next_pow_2(uint32_t const num) {
16
+ if (num <= 1) return num;
17
+ return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
18
+ }
19
+
20
+ inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
21
+ int max_shared_mem_per_block_opt_in = 0;
22
+ cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
23
+ cudaDevAttrMaxSharedMemoryPerBlockOptin,
24
+ device);
25
+ return max_shared_mem_per_block_opt_in;
26
+ }
27
+
cutlass_w8a8/scaled_mm_c2x.cu ADDED
@@ -0,0 +1,199 @@
1
+ #include <stddef.h>
2
+ #include <torch/all.h>
3
+ #include "cutlass/cutlass.h"
4
+
5
+ #include "scaled_mm_c2x.cuh"
6
+ #include "scaled_mm_c2x_sm75_dispatch.cuh"
7
+ #include "scaled_mm_c2x_sm80_dispatch.cuh"
8
+ #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
9
+ #include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
10
+
11
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"
12
+
13
+ using namespace vllm;
14
+
15
+ /*
16
+ This file defines quantized GEMM operations using the CUTLASS 2.x API, for
17
+ NVIDIA GPUs with SM versions prior to sm90 (Hopper).
18
+ */
19
+
20
+ template <template <typename, typename> typename Epilogue,
21
+ typename... EpilogueArgs>
22
+ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
23
+ torch::Tensor const& b,
24
+ EpilogueArgs&&... epilogue_args) {
25
+ TORCH_CHECK(a.dtype() == torch::kInt8);
26
+ TORCH_CHECK(b.dtype() == torch::kInt8);
27
+
28
+ if (out.dtype() == torch::kBFloat16) {
29
+ return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
30
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
31
+ } else {
32
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
33
+ return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
34
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
35
+ }
36
+ }
37
+
38
+ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
39
+ torch::Tensor const& b,
40
+ torch::Tensor const& a_scales,
41
+ torch::Tensor const& b_scales,
42
+ c10::optional<torch::Tensor> const& bias) {
43
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
44
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
45
+ if (bias) {
46
+ TORCH_CHECK(bias->dtype() == out.dtype(),
47
+ "currently bias dtype must match output dtype ", out.dtype());
48
+ return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
49
+ out, a, b, a_scales, b_scales, *bias);
50
+ } else {
51
+ return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
52
+ out, a, b, a_scales, b_scales);
53
+ }
54
+ }
55
+
56
+ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
57
+ torch::Tensor const& b,
58
+ torch::Tensor const& a_scales,
59
+ torch::Tensor const& b_scales,
60
+ torch::Tensor const& azp_adj,
61
+ c10::optional<torch::Tensor> const& azp,
62
+ c10::optional<torch::Tensor> const& bias) {
63
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
64
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
65
+
66
+ if (azp) {
67
+ return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
68
+ out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
69
+ } else {
70
+ return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
71
+ out, a, b, a_scales, b_scales, azp_adj, bias);
72
+ }
73
+ }
74
+
75
+ template <template <typename, typename> typename Epilogue,
76
+ typename... EpilogueArgs>
77
+ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
78
+ torch::Tensor const& b,
79
+ EpilogueArgs&&... epilogue_args) {
80
+ TORCH_CHECK(a.dtype() == torch::kInt8);
81
+ TORCH_CHECK(b.dtype() == torch::kInt8);
82
+
83
+ if (out.dtype() == torch::kBFloat16) {
84
+ return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
85
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
86
+ } else {
87
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
88
+ return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
89
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
90
+ }
91
+ }
92
+
93
+ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
94
+ torch::Tensor const& b,
95
+ torch::Tensor const& a_scales,
96
+ torch::Tensor const& b_scales,
97
+ c10::optional<torch::Tensor> const& bias) {
98
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
99
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
100
+ if (bias) {
101
+ TORCH_CHECK(bias->dtype() == out.dtype(),
102
+ "currently bias dtype must match output dtype ", out.dtype());
103
+ return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
104
+ out, a, b, a_scales, b_scales, *bias);
105
+ } else {
106
+ return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
107
+ out, a, b, a_scales, b_scales);
108
+ }
109
+ }
110
+
111
+ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
112
+ torch::Tensor const& b,
113
+ torch::Tensor const& a_scales,
114
+ torch::Tensor const& b_scales,
115
+ torch::Tensor const& azp_adj,
116
+ c10::optional<torch::Tensor> const& azp,
117
+ c10::optional<torch::Tensor> const& bias) {
118
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
119
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
120
+
121
+ if (azp) {
122
+ return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
123
+ out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
124
+ } else {
125
+ return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
126
+ out, a, b, a_scales, b_scales, azp_adj, bias);
127
+ }
128
+ }
129
+
130
+ template <template <typename, typename> typename Epilogue,
131
+ typename... EpilogueArgs>
132
+ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
133
+ torch::Tensor const& b,
134
+ EpilogueArgs&&... epilogue_args) {
135
+ if (a.dtype() == torch::kInt8) {
136
+ TORCH_CHECK(b.dtype() == torch::kInt8);
137
+
138
+ if (out.dtype() == torch::kBFloat16) {
139
+ return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
140
+ Epilogue>(
141
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
142
+ } else {
143
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
144
+ return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
145
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
146
+ }
147
+ } else {
148
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
149
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
150
+
151
+ if (out.dtype() == torch::kBFloat16) {
152
+ return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
153
+ cutlass::bfloat16_t, Epilogue>(
154
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
155
+ } else {
156
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
157
+ return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
158
+ cutlass::half_t, Epilogue>(
159
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
160
+ }
161
+ }
162
+ }
163
+
164
+ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
165
+ torch::Tensor const& b,
166
+ torch::Tensor const& a_scales,
167
+ torch::Tensor const& b_scales,
168
+ c10::optional<torch::Tensor> const& bias) {
169
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
170
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
171
+ if (bias) {
172
+ TORCH_CHECK(bias->dtype() == out.dtype(),
173
+ "currently bias dtype must match output dtype ", out.dtype());
174
+ return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
175
+ out, a, b, a_scales, b_scales, *bias);
176
+ } else {
177
+ return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
178
+ out, a, b, a_scales, b_scales);
179
+ }
180
+ }
181
+
182
+ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
183
+ torch::Tensor const& b,
184
+ torch::Tensor const& a_scales,
185
+ torch::Tensor const& b_scales,
186
+ torch::Tensor const& azp_adj,
187
+ c10::optional<torch::Tensor> const& azp,
188
+ c10::optional<torch::Tensor> const& bias) {
189
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
190
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
191
+
192
+ if (azp) {
193
+ return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
194
+ out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
195
+ } else {
196
+ return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
197
+ out, a, b, a_scales, b_scales, azp_adj, bias);
198
+ }
199
+ }
cutlass_w8a8/scaled_mm_c2x.cuh ADDED
@@ -0,0 +1,219 @@
1
+ #pragma once
2
+ #include <stddef.h>
3
+ #include <torch/all.h>
4
+
5
+ #include <ATen/cuda/CUDAContext.h>
6
+
7
+ // clang-format will break include orders
8
+ // clang-format off
9
+ #include "cute/tensor.hpp"
10
+ #include "cute/atom/mma_atom.hpp"
11
+ #include "cutlass/numeric_types.h"
12
+
13
+ #include "cutlass/cutlass.h"
14
+ #include "cutlass/gemm_coord.h"
15
+ #include "cutlass/arch/mma_sm75.h"
16
+ #include "cutlass/arch/arch.h"
17
+ #include "cutlass/arch/mma.h"
18
+ #include "cutlass/gemm/device/gemm.h"
19
+ #include "cutlass/gemm/device/gemm_universal_adapter.h"
20
+
21
+ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
22
+ #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
23
+
24
+ #include "common.hpp"
25
+ // clang-format on
26
+
27
+ using namespace cute;
28
+
29
+ /*
30
+ Epilogue functions can be defined to post-process the output before it is
31
+ written to GPU memory.
32
+ Epilogues must contain a public type named EVTCompute of type Sm80EVT,
33
+ as well as a static prepare_args function that constructs an
34
+ EVTCompute::Arguments struct.
35
+ */
36
+
37
+ namespace vllm {
38
+
39
+ // Wrappers for the GEMM kernel that are used to guard against compilation on
40
+ // architectures that will never use the kernel. The purpose of this is to
41
+ // reduce the size of the compiled binary.
42
+ // __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
43
+ // into code that will be executed on the device where it is defined.
44
+ template <typename Kernel>
45
+ struct enable_sm75_to_sm80 : Kernel {
46
+ template <typename... Args>
47
+ CUTLASS_DEVICE static void invoke(Args&&... args) {
48
+ #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
49
+ Kernel::invoke(std::forward<Args>(args)...);
50
+ #endif
51
+ }
52
+ };
53
+
54
+ template <typename Kernel>
55
+ struct enable_sm80_to_sm89 : Kernel {
56
+ template <typename... Args>
57
+ CUTLASS_DEVICE static void invoke(Args&&... args) {
58
+ #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
59
+ Kernel::invoke(std::forward<Args>(args)...);
60
+ #endif
61
+ }
62
+ };
63
+
64
+ template <typename Kernel>
65
+ struct enable_sm89_to_sm90 : Kernel {
66
+ template <typename... Args>
67
+ CUTLASS_DEVICE static void invoke(Args&&... args) {
68
+ #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
69
+ Kernel::invoke(std::forward<Args>(args)...);
70
+ #endif
71
+ }
72
+ };
73
+ template <typename Arch, template <typename> typename ArchGuard,
74
+ typename ElementAB_, typename ElementD_,
75
+ template <typename, typename> typename Epilogue_, typename TileShape,
76
+ typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
77
+ typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
78
+ struct cutlass_2x_gemm {
79
+ using ElementAB = ElementAB_;
80
+ using ElementD = ElementD_;
81
+
82
+ using ElementAcc =
83
+ typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
84
+ float>::type;
85
+
86
+ using Operator =
87
+ typename std::conditional<std::is_same_v<ElementAB, int8_t>,
88
+ cutlass::arch::OpMultiplyAddSaturate,
89
+ FP8MathOperator>::type;
90
+
91
+ using OutputTileThreadMap =
92
+ cutlass::epilogue::threadblock::OutputTileThreadLayout<
93
+ TileShape, WarpShape, float, 4, 1 /* epilogue stages */
94
+ >;
95
+
96
+ using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
97
+ using EVTCompute = typename Epilogue::EVTCompute;
98
+
99
+ using D = cutlass::epilogue::threadblock::VisitorAuxStore<
100
+ OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
101
+ Stride<int64_t, Int<1>, Int<0>>>;
102
+
103
+ using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
104
+
105
+ // clang-format off
106
+ using RowMajor = typename cutlass::layout::RowMajor;
107
+ using ColumnMajor = typename cutlass::layout::ColumnMajor;
108
+ using KernelType =
109
+ ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
110
+ ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
111
+ ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
112
+ float, cutlass::layout::RowMajor, 4,
113
+ ElementAcc, float, cutlass::arch::OpClassTensorOp,
114
+ Arch,
115
+ TileShape, WarpShape, InstructionShape,
116
+ EVTD,
117
+ cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
118
+ MainLoopStages, Operator,
119
+ 1 /* epilogue stages */
120
+ >::GemmKernel>;
121
+ // clang-format on
122
+
123
+ using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
124
+ };
125
+
126
+ template <typename Gemm, typename... EpilogueArgs>
127
+ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
128
+ torch::Tensor const& b,
129
+ EpilogueArgs&&... epilogue_params) {
130
+ using ElementAB = typename Gemm::ElementAB;
131
+ using ElementD = typename Gemm::ElementD;
132
+
133
+ int32_t m = a.size(0);
134
+ int32_t n = b.size(1);
135
+ int32_t k = a.size(1);
136
+ cutlass::gemm::GemmCoord problem_size{m, n, k};
137
+
138
+ int64_t lda = a.stride(0);
139
+ int64_t ldb = b.stride(1);
140
+ int64_t ldc = out.stride(0);
141
+
142
+ using StrideC = Stride<int64_t, Int<1>, Int<0>>;
143
+ StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
144
+
145
+ auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
146
+ auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
147
+ auto c_ptr = static_cast<ElementD*>(out.data_ptr());
148
+
149
+ typename Gemm::D::Arguments d_args{c_ptr, c_stride};
150
+
151
+ using Epilogue = typename Gemm::Epilogue;
152
+ auto evt_args =
153
+ Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
154
+
155
+ typename Gemm::EVTD::Arguments epilogue_args{
156
+ evt_args,
157
+ d_args,
158
+ };
159
+
160
+ typename Gemm::Op::Arguments args{
161
+ cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
162
+ problem_size, // problem size
163
+ 1, // batch count
164
+ epilogue_args,
165
+ a_ptr,
166
+ b_ptr,
167
+ nullptr,
168
+ nullptr,
169
+ 0,
170
+ 0,
171
+ 0,
172
+ 0,
173
+ lda,
174
+ ldb,
175
+ ldc,
176
+ ldc};
177
+
178
+ // Launch the CUTLASS GEMM kernel.
179
+ typename Gemm::Op gemm_op;
180
+ size_t workspace_size = gemm_op.get_workspace_size(args);
181
+ auto const workspace_options =
182
+ torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
183
+ auto workspace = torch::empty(workspace_size, workspace_options);
184
+
185
+ auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
186
+
187
+ CUTLASS_CHECK(gemm_op.can_implement(args));
188
+ cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
189
+ CUTLASS_CHECK(status);
190
+ }
191
+
192
+ template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
193
+ inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
194
+ torch::Tensor const& a,
195
+ torch::Tensor const& b,
196
+ EpilogueArgs&&... args) {
197
+ // In some cases, the GPU isn't able to accommodate the
198
+ // shared memory requirements of the Gemm. In such cases, use
199
+ // the FallbackGemm instead.
200
+ static const int max_shared_mem_per_block_opt_in =
201
+ get_cuda_max_shared_memory_per_block_opt_in(0);
202
+
203
+ size_t const gemm_shared_mem_size =
204
+ sizeof(typename Gemm::KernelType::SharedStorage);
205
+ size_t const fallback_gemm_shared_mem_size =
206
+ sizeof(typename FallbackGemm::KernelType::SharedStorage);
207
+
208
+ if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
209
+ return cutlass_gemm_caller<Gemm>(out, a, b,
210
+ std::forward<EpilogueArgs>(args)...);
211
+ } else {
212
+ TORCH_CHECK(fallback_gemm_shared_mem_size <=
213
+ max_shared_mem_per_block_opt_in);
214
+ return cutlass_gemm_caller<FallbackGemm>(
215
+ out, a, b, std::forward<EpilogueArgs>(args)...);
216
+ }
217
+ }
218
+
219
+ } // namespace vllm
cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh ADDED
@@ -0,0 +1,123 @@
1
+ #pragma once
2
+
3
+ #include "scaled_mm_c2x.cuh"
4
+
5
+ /**
6
+ * This file defines Gemm kernel configurations for SM75 based on the Gemm
7
+ * shape.
8
+ */
9
+
10
+ namespace vllm {
11
+
12
+ template <typename InType, typename OutType,
13
+ template <typename, typename> typename Epilogue>
14
+ struct sm75_config_default {
15
+ // This config is used in 2 cases,
16
+ // - M in (256, inf]
17
+ // - M in (64, 128]
18
+ // Shared memory required by this Gemm 32768
19
+ static_assert(std::is_same<InType, int8_t>());
20
+ using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
21
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
22
+ using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
23
+ using Cutlass2xGemm =
24
+ cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
25
+ Epilogue, TileShape, WarpShape, InstructionShape, 2>;
26
+ };
27
+
28
+ template <typename InType, typename OutType,
29
+ template <typename, typename> typename Epilogue>
30
+ struct sm75_config_M256 {
31
+ // M in (128, 256]
32
+ // Shared memory required by this Gemm 65536
33
+ static_assert(std::is_same<InType, int8_t>());
34
+ using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
35
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
36
+ using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
37
+ using Cutlass2xGemm =
38
+ cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
39
+ Epilogue, TileShape, WarpShape, InstructionShape, 2>;
40
+ };
41
+
42
+ template <typename InType, typename OutType,
43
+ template <typename, typename> typename Epilogue>
44
+ struct sm75_config_M64 {
45
+ // M in (32, 64]
46
+ // Shared memory required by this Gemm 49152
47
+ static_assert(std::is_same<InType, int8_t>());
48
+ using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
49
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
50
+ using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
51
+ using Cutlass2xGemm =
52
+ cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
53
+ Epilogue, TileShape, WarpShape, InstructionShape, 2>;
54
+ };
55
+
56
+ template <typename InType, typename OutType,
57
+ template <typename, typename> typename Epilogue>
58
+ struct sm75_config_M32 {
59
+ // M in [1, 32]
60
+ // Shared memory required by this Gemm 49152
61
+ static_assert(std::is_same<InType, int8_t>());
62
+ using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
63
+ using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
64
+ using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
65
+ using Cutlass2xGemm =
66
+ cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
67
+ Epilogue, TileShape, WarpShape, InstructionShape, 2>;
68
+ };
69
+
70
+ template <typename InType, typename OutType,
71
+ template <typename, typename> typename Epilogue,
72
+ typename... EpilogueArgs>
73
+ inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
74
+ torch::Tensor const& a,
75
+ torch::Tensor const& b,
76
+ EpilogueArgs&&... args) {
77
+ static_assert(std::is_same<InType, int8_t>());
78
+ TORCH_CHECK(a.dtype() == torch::kInt8);
79
+ TORCH_CHECK(b.dtype() == torch::kInt8);
80
+
81
+ using Cutlass2xGemmDefault =
82
+ typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
83
+ using Cutlass2xGemmM256 =
84
+ typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
85
+ using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
86
+ using Cutlass2xGemmM64 =
87
+ typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
88
+ using Cutlass2xGemmM32 =
89
+ typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
90
+
91
+ // Due to shared memory requirements, some Gemms may fail to run on some
92
+ // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
93
+ // in such cases.
94
+ // sm75_config_default has the least shared-memory requirements.
95
+ using FallbackGemm = Cutlass2xGemmDefault;
96
+
97
+ uint32_t const m = a.size(0);
98
+ uint32_t const mp2 =
99
+ std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
100
+ if (mp2 <= 32) {
101
+ // M in [1, 32]
102
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
103
+ out, a, b, std::forward<EpilogueArgs>(args)...);
104
+ } else if (mp2 <= 64) {
105
+ // M in (32, 64]
106
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
107
+ out, a, b, std::forward<EpilogueArgs>(args)...);
108
+ } else if (mp2 <= 128) {
109
+ // M in (64, 128]
110
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
111
+ out, a, b, std::forward<EpilogueArgs>(args)...);
112
+ } else if (mp2 <= 256) {
113
+ // M in (128, 256]
114
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
115
+ out, a, b, std::forward<EpilogueArgs>(args)...);
116
+ } else {
117
+ // M in (256, inf)
118
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
119
+ out, a, b, std::forward<EpilogueArgs>(args)...);
120
+ }
121
+ }
122
+
123
+ } // namespace vllm
cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh ADDED
@@ -0,0 +1,139 @@
1
+ #pragma once
2
+
3
+ #include "scaled_mm_c2x.cuh"
4
+
5
+ /**
6
+ * This file defines Gemm kernel configurations for SM80 based on the Gemm
7
+ * shape.
8
+ */
9
+
10
+ namespace vllm {
11
+
12
+ template <typename InType, typename OutType,
13
+ template <typename, typename> typename Epilogue>
14
+ struct sm80_config_default {
15
+ // This config is used in 2 cases,
16
+ // - M in (128, inf)
17
+ // - M in (64, 128] and N >= 8192
18
+ // Shared Memory required by this Gemm - 81920 bytes
19
+ static_assert(std::is_same<InType, int8_t>());
20
+ using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
21
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
22
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
23
+ using Cutlass2xGemm =
24
+ cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
25
+ Epilogue, TileShape, WarpShape, InstructionShape, 5>;
26
+ };
27
+
28
+ template <typename InType, typename OutType,
29
+ template <typename, typename> typename Epilogue>
30
+ struct sm80_config_M64 {
31
+ // This config is used in 2 cases,
32
+ // - M in (32, 64]
33
+ // - M in (64, 128] and N < 8192
34
+ // Shared Memory required by this Gemm - 122880 bytes
35
+ static_assert(std::is_same<InType, int8_t>());
36
+ using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
37
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
38
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
39
+ using Cutlass2xGemm =
40
+ cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
41
+ Epilogue, TileShape, WarpShape, InstructionShape, 5>;
42
+ };
43
+
44
+ template <typename InType, typename OutType,
45
+ template <typename, typename> typename Epilogue>
46
+ struct sm80_config_M32 {
47
+ // M in (16, 32]
48
+ // Shared Memory required by this Gemm - 61440 bytes
49
+ static_assert(std::is_same<InType, int8_t>());
50
+ using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
51
+ using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
52
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
53
+ using Cutlass2xGemm =
54
+ cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
55
+ Epilogue, TileShape, WarpShape, InstructionShape, 5>;
56
+ };
57
+
58
+ template <typename InType, typename OutType,
59
+ template <typename, typename> typename Epilogue>
60
+ struct sm80_config_M16 {
61
+ // M in [1, 16]
62
+ // Shared Memory required by this Gemm - 51200 bytes
63
+ static_assert(std::is_same<InType, int8_t>());
64
+ using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
65
+ using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
66
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
67
+ using Cutlass2xGemm =
68
+ cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
69
+ Epilogue, TileShape, WarpShape, InstructionShape, 5>;
70
+ };
71
+
72
+ template <typename InType, typename OutType,
73
+ template <typename, typename> typename Epilogue,
74
+ typename... EpilogueArgs>
75
+ inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
76
+ torch::Tensor const& a,
77
+ torch::Tensor const& b,
78
+ EpilogueArgs&&... args) {
79
+ static_assert(std::is_same<InType, int8_t>());
80
+ TORCH_CHECK(a.dtype() == torch::kInt8);
81
+ TORCH_CHECK(b.dtype() == torch::kInt8);
82
+
83
+ using Cutlass2xGemmDefault =
84
+ typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
85
+ using Cutlass2xGemmM128BigN =
86
+ typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
87
+ using Cutlass2xGemmM128SmallN =
88
+ typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
89
+ using Cutlass2xGemmM64 =
90
+ typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
91
+ using Cutlass2xGemmM32 =
92
+ typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
93
+ using Cutlass2xGemmM16 =
94
+ typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
95
+
96
+ // Due to shared memory requirements, some Gemms may fail to run on some
97
+ // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
98
+ // in such cases.
99
+ // sm80_config_M16 has the least shared-memory requirement. However,
100
+ // based on some profiling, we select sm80_config_M32 as a better alternative
101
+ // performance wise.
102
+ using FallbackGemm =
103
+ typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
104
+
105
+ uint32_t const m = a.size(0);
106
+ uint32_t const mp2 =
107
+ std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
108
+ if (mp2 <= 16) {
109
+ // M in [1, 16]
110
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
111
+ out, a, b, std::forward<EpilogueArgs>(args)...);
112
+ } else if (mp2 <= 32) {
113
+ // M in (16, 32]
114
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
115
+ out, a, b, std::forward<EpilogueArgs>(args)...);
116
+ } else if (mp2 <= 64) {
117
+ // M in (32, 64]
118
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
119
+ out, a, b, std::forward<EpilogueArgs>(args)...);
120
+ } else if (mp2 <= 128) {
121
+ // M in (64, 128]
122
+ uint32_t const n = out.size(1);
123
+ bool const small_n = n < 8192;
124
+ if (small_n) {
125
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
126
+ FallbackGemm>(
127
+ out, a, b, std::forward<EpilogueArgs>(args)...);
128
+ } else {
129
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
130
+ out, a, b, std::forward<EpilogueArgs>(args)...);
131
+ }
132
+ } else {
133
+ // M in (128, inf)
134
+ return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
135
+ out, a, b, std::forward<EpilogueArgs>(args)...);
136
+ }
137
+ }
138
+
139
+ } // namespace vllm
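
The shared-memory figures quoted in the config comments above (61440 bytes for sm80_config_M32, 51200 bytes for sm80_config_M16) are consistent with a simple per-stage estimate in which every mainloop stage buffers an M-by-K tile of A and an N-by-K tile of B. The sketch below is only a back-of-the-envelope check under that assumption; it ignores epilogue and scheduler storage, so treat it as an approximation rather than how CUTLASS actually sizes its SharedStorage.

#include <cstdio>

// Rough CUTLASS 2.x mainloop shared-memory estimate: per pipeline stage, one
// tile_m x tile_k fragment of A plus one tile_n x tile_k fragment of B.
constexpr int smem_bytes(int tile_m, int tile_n, int tile_k, int stages,
                         int elem_bytes) {
  return stages * (tile_m + tile_n) * tile_k * elem_bytes;
}

// These reproduce the numbers quoted in the sm80 int8 config comments.
static_assert(smem_bytes(32, 64, 128, 5, 1) == 61440, "sm80_config_M32");
static_assert(smem_bytes(16, 64, 128, 5, 1) == 51200, "sm80_config_M16");

int main() {
  printf("sm80_config_M32: %d bytes\n", smem_bytes(32, 64, 128, 5, 1));
  printf("sm80_config_M16: %d bytes\n", smem_bytes(16, 64, 128, 5, 1));
  return 0;
}
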
cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh ADDED
@@ -0,0 +1,368 @@
1
+ #pragma once
2
+
3
+ #include "scaled_mm_c2x.cuh"
4
+ #include "cutlass/float8.h"
5
+
6
+ /**
7
+ * This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
8
+ * shape.
9
+ */
10
+
11
+ namespace vllm {
12
+
13
+ template <typename InType, typename OutType,
14
+ template <typename, typename> typename Epilogue>
15
+ struct sm89_fp8_fallback_gemm {
16
+ // Shared Memory required by this Gemm - 61440 bytes
17
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
18
+ using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
19
+ using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
20
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
21
+ using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
22
+ using Cutlass2xGemm =
23
+ cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
24
+ Epilogue, TileShape, WarpShape, InstructionShape, 5,
25
+ FP8MathOperator>;
26
+ };
27
+
28
+ struct sm89_fp8_config_default {
29
+ // M in (256, inf)
30
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
31
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
32
+ using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
33
+
34
+ template <typename InType, typename OutType,
35
+ template <typename, typename> typename Epilogue,
36
+ typename... EpilogueArgs>
37
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
38
+ torch::Tensor const& b, EpilogueArgs&&... args) {
39
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
40
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
41
+
42
+ using FallbackGemm =
43
+ typename sm89_fp8_fallback_gemm<InType, OutType,
44
+ Epilogue>::Cutlass2xGemm;
45
+
46
+ uint32_t const n = out.size(1);
47
+ uint32_t const np2 = next_pow_2(n);
48
+
49
+ if (np2 <= 4096) {
50
+ using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
51
+
52
+ return vllm::fallback_cutlass_gemm_caller<
53
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
54
+ InType, OutType, Epilogue, TileShape, WarpShape,
55
+ InstructionShape, 5, FP8MathOperator>,
56
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
57
+ } else if (np2 <= 8192) {
58
+ using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
59
+
60
+ return vllm::fallback_cutlass_gemm_caller<
61
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
62
+ InType, OutType, Epilogue, TileShape, WarpShape,
63
+ InstructionShape, 3, FP8MathOperator>,
64
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
65
+
66
+ } else {
67
+ using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
68
+
69
+ return vllm::fallback_cutlass_gemm_caller<
70
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
71
+ InType, OutType, Epilogue, TileShape, WarpShape,
72
+ InstructionShape, 5, FP8MathOperator>,
73
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
74
+ }
75
+ }
76
+ };
77
+
78
+ struct sm89_fp8_config_M256 {
79
+ // M in (128, 256]
80
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
81
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
82
+ using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
83
+
84
+ template <typename InType, typename OutType,
85
+ template <typename, typename> typename Epilogue,
86
+ typename... EpilogueArgs>
87
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
88
+ torch::Tensor const& b, EpilogueArgs&&... args) {
89
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
90
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
91
+
92
+ using FallbackGemm =
93
+ typename sm89_fp8_fallback_gemm<InType, OutType,
94
+ Epilogue>::Cutlass2xGemm;
95
+
96
+ uint32_t const n = out.size(1);
97
+ uint32_t const np2 = next_pow_2(n);
98
+
99
+ if (np2 <= 4096) {
100
+ using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
101
+
102
+ return vllm::fallback_cutlass_gemm_caller<
103
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
104
+ InType, OutType, Epilogue, TileShape, WarpShape,
105
+ InstructionShape, 3, FP8MathOperator>,
106
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
107
+ } else {
108
+ using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
109
+
110
+ return vllm::fallback_cutlass_gemm_caller<
111
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
112
+ InType, OutType, Epilogue, TileShape, WarpShape,
113
+ InstructionShape, 5, FP8MathOperator>,
114
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
115
+ }
116
+ }
117
+ };
118
+
119
+ struct sm89_fp8_config_M128 {
120
+ // M in (64, 128]
121
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
122
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
123
+ using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
124
+
125
+ template <typename InType, typename OutType,
126
+ template <typename, typename> typename Epilogue,
127
+ typename... EpilogueArgs>
128
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
129
+ torch::Tensor const& b, EpilogueArgs&&... args) {
130
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
131
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
132
+
133
+ using FallbackGemm =
134
+ typename sm89_fp8_fallback_gemm<InType, OutType,
135
+ Epilogue>::Cutlass2xGemm;
136
+
137
+ uint32_t const n = out.size(1);
138
+ uint32_t const np2 = next_pow_2(n);
139
+
140
+ if (np2 <= 8192) {
141
+ using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
142
+
143
+ return vllm::fallback_cutlass_gemm_caller<
144
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
145
+ InType, OutType, Epilogue, TileShape, WarpShape,
146
+ InstructionShape, 3, FP8MathOperator>,
147
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
148
+
149
+ } else if (np2 <= 16384) {
150
+ using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
151
+
152
+ return vllm::fallback_cutlass_gemm_caller<
153
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
154
+ InType, OutType, Epilogue, TileShape, WarpShape,
155
+ InstructionShape, 5, FP8MathOperator>,
156
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
157
+ } else {
158
+ using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
159
+
160
+ return vllm::fallback_cutlass_gemm_caller<
161
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
162
+ InType, OutType, Epilogue, TileShape, WarpShape,
163
+ InstructionShape, 3, FP8MathOperator>,
164
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
165
+ }
166
+ }
167
+ };
168
+
169
+ struct sm89_fp8_config_M64 {
170
+ // M in (32, 64]
171
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
172
+
173
+ template <typename InType, typename OutType,
174
+ template <typename, typename> typename Epilogue,
175
+ typename... EpilogueArgs>
176
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
177
+ torch::Tensor const& b, EpilogueArgs&&... args) {
178
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
179
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
180
+
181
+ using FallbackGemm =
182
+ typename sm89_fp8_fallback_gemm<InType, OutType,
183
+ Epilogue>::Cutlass2xGemm;
184
+
185
+ uint32_t const n = out.size(1);
186
+ uint32_t const np2 = next_pow_2(n);
187
+
188
+ if (np2 <= 8196) {
189
+ using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
190
+ using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
191
+ using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
192
+
193
+ return vllm::fallback_cutlass_gemm_caller<
194
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
195
+ InType, OutType, Epilogue, TileShape, WarpShape,
196
+ InstructionShape, 5, FP8MathOperator>,
197
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
198
+ } else if (np2 <= 16384) {
199
+ using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
200
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
201
+ using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
202
+
203
+ return vllm::fallback_cutlass_gemm_caller<
204
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
205
+ InType, OutType, Epilogue, TileShape, WarpShape,
206
+ InstructionShape, 3, FP8MathOperator>,
207
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
208
+ } else {
209
+ using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
210
+ using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
211
+ using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
212
+
213
+ return vllm::fallback_cutlass_gemm_caller<
214
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
215
+ InType, OutType, Epilogue, TileShape, WarpShape,
216
+ InstructionShape, 5, FP8MathOperator>,
217
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
218
+ }
219
+ }
220
+ };
221
+
222
+ struct sm89_fp8_config_M32 {
223
+ // M in (16, 32]
224
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
225
+ using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
226
+
227
+ template <typename InType, typename OutType,
228
+ template <typename, typename> typename Epilogue,
229
+ typename... EpilogueArgs>
230
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
231
+ torch::Tensor const& b, EpilogueArgs&&... args) {
232
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
233
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
234
+
235
+ using FallbackGemm =
236
+ typename sm89_fp8_fallback_gemm<InType, OutType,
237
+ Epilogue>::Cutlass2xGemm;
238
+
239
+ uint32_t const n = out.size(1);
240
+ uint32_t const np2 = next_pow_2(n);
241
+
242
+ if (np2 <= 8192) {
243
+ using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
244
+ using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
245
+
246
+ return vllm::fallback_cutlass_gemm_caller<
247
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
248
+ InType, OutType, Epilogue, TileShape, WarpShape,
249
+ InstructionShape, 5, FP8MathOperator>,
250
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
251
+ } else if (np2 <= 16384) {
252
+ using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
253
+ using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
254
+
255
+ return vllm::fallback_cutlass_gemm_caller<
256
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
257
+ InType, OutType, Epilogue, TileShape, WarpShape,
258
+ InstructionShape, 4, FP8MathOperator>,
259
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
260
+ } else {
261
+ using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
262
+ using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
263
+
264
+ return vllm::fallback_cutlass_gemm_caller<
265
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
266
+ InType, OutType, Epilogue, TileShape, WarpShape,
267
+ InstructionShape, 5, FP8MathOperator>,
268
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
269
+ }
270
+ }
271
+ };
272
+
273
+ struct sm89_fp8_config_M16 {
274
+ // M in [1, 16]
275
+ using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
276
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
277
+ using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
278
+ static const int32_t MainLoopStages = 5;
279
+
280
+ template <typename InType, typename OutType,
281
+ template <typename, typename> typename Epilogue,
282
+ typename... EpilogueArgs>
283
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
284
+ torch::Tensor const& b, EpilogueArgs&&... args) {
285
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
286
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
287
+
288
+ using FallbackGemm =
289
+ typename sm89_fp8_fallback_gemm<InType, OutType,
290
+ Epilogue>::Cutlass2xGemm;
291
+
292
+ uint32_t const n = out.size(1);
293
+ uint32_t const np2 = next_pow_2(n);
294
+
295
+ if (np2 <= 8192) {
296
+ using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
297
+
298
+ return vllm::fallback_cutlass_gemm_caller<
299
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
300
+ InType, OutType, Epilogue, TileShape, WarpShape,
301
+ InstructionShape, MainLoopStages,
302
+ FP8MathOperator>,
303
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
304
+ } else if (np2 <= 24576) {
305
+ using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
306
+
307
+ return vllm::fallback_cutlass_gemm_caller<
308
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
309
+ InType, OutType, Epilogue, TileShape, WarpShape,
310
+ InstructionShape, MainLoopStages,
311
+ FP8MathOperator>,
312
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
313
+ } else {
314
+ using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
315
+
316
+ return vllm::fallback_cutlass_gemm_caller<
317
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
318
+ InType, OutType, Epilogue, TileShape, WarpShape,
319
+ InstructionShape, MainLoopStages,
320
+ FP8MathOperator>,
321
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
322
+ }
323
+ }
324
+ };
325
+
326
+ template <typename InType, typename OutType,
327
+ template <typename, typename> typename Epilogue,
328
+ typename... EpilogueArgs>
329
+ inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
330
+ torch::Tensor const& a,
331
+ torch::Tensor const& b,
332
+ EpilogueArgs&&... args) {
333
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
334
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
335
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
336
+
337
+ uint32_t const m = a.size(0);
338
+ uint32_t const mp2 =
339
+ std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
340
+
341
+ if (mp2 <= 16) {
342
+ // M in [1, 16]
343
+ return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
344
+ out, a, b, std::forward<EpilogueArgs>(args)...);
345
+ } else if (mp2 <= 32) {
346
+ // M in (16, 32]
347
+ return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
348
+ out, a, b, std::forward<EpilogueArgs>(args)...);
349
+ } else if (mp2 <= 64) {
350
+ // M in (32, 64]
351
+ return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
352
+ out, a, b, std::forward<EpilogueArgs>(args)...);
353
+ } else if (mp2 <= 128) {
354
+ // M in (64, 128]
355
+ return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
356
+ out, a, b, std::forward<EpilogueArgs>(args)...);
357
+ } else if (mp2 <= 256) {
358
+ // M in (128, 256]
359
+ return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
360
+ out, a, b, std::forward<EpilogueArgs>(args)...);
361
+ } else {
362
+ // M in (256, inf)
363
+ return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
364
+ out, a, b, std::forward<EpilogueArgs>(args)...);
365
+ }
366
+ }
367
+
368
+ } // namespace vllm
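
Each sm89_fp8_config_* struct above additionally buckets on N inside its own dispatch; the outer selection in cutlass_gemm_sm89_fp8_dispatch looks only at M. The host-side sketch below mirrors that outer M bucketing. next_pow_2 is reimplemented here purely for the example (the kernels take it from the shared headers), and the GCC/Clang builtin __builtin_clz is an assumption of this sketch.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Illustration-only reimplementation; not the helper used by the kernels.
static uint32_t next_pow_2(uint32_t v) {
  return v <= 1 ? 1u : 1u << (32 - __builtin_clz(v - 1));
}

// Mirrors the M bucketing of cutlass_gemm_sm89_fp8_dispatch above.
static const char* sm89_fp8_bucket(uint32_t m) {
  uint32_t const mp2 = std::max<uint32_t>(32, next_pow_2(m));
  if (mp2 <= 16) return "sm89_fp8_config_M16";
  if (mp2 <= 32) return "sm89_fp8_config_M32";
  if (mp2 <= 64) return "sm89_fp8_config_M64";
  if (mp2 <= 128) return "sm89_fp8_config_M128";
  if (mp2 <= 256) return "sm89_fp8_config_M256";
  return "sm89_fp8_config_default";
}

int main() {
  // Because mp2 is clamped to at least 32, M <= 16 also lands in the M32
  // bucket, exactly as in the dispatch above.
  printf("M=8    -> %s\n", sm89_fp8_bucket(8));
  printf("M=200  -> %s\n", sm89_fp8_bucket(200));
  printf("M=4096 -> %s\n", sm89_fp8_bucket(4096));
  return 0;
}
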
cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh ADDED
@@ -0,0 +1,353 @@
1
+ #pragma once
2
+
3
+ #include "scaled_mm_c2x.cuh"
4
+
5
+ /**
6
+ * This file defines Gemm kernel configurations for SM89 (int8) based on the
7
+ * Gemm shape.
8
+ */
9
+
10
+ namespace vllm {
11
+
12
+ template <typename InType, typename OutType,
13
+ template <typename, typename> typename Epilogue>
14
+ struct sm89_int8_fallback_gemm {
15
+ // Shared Memory required by this Gemm - 61440 bytes
16
+ static_assert(std::is_same<InType, int8_t>());
17
+ using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
18
+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
19
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
20
+ static int32_t const MainLoopStages = 5;
21
+
22
+ using Cutlass2xGemm =
23
+ cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
24
+ Epilogue, TileShape, WarpShape, InstructionShape, MainLoopStages>;
25
+ };
26
+
27
+ struct sm89_int8_config_default {
28
+ // M in (256, inf)
29
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
30
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
31
+
32
+ template <typename InType, typename OutType,
33
+ template <typename, typename> typename Epilogue,
34
+ typename... EpilogueArgs>
35
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
36
+ torch::Tensor const& b, EpilogueArgs&&... args) {
37
+ static_assert(std::is_same<InType, int8_t>());
38
+ TORCH_CHECK(a.dtype() == torch::kInt8);
39
+
40
+ using FallbackGemm =
41
+ typename sm89_int8_fallback_gemm<InType, OutType,
42
+ Epilogue>::Cutlass2xGemm;
43
+
44
+ uint32_t const n = out.size(1);
45
+ uint32_t const np2 = next_pow_2(n);
46
+
47
+ if (np2 <= 4096) {
48
+ using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
49
+
50
+ return vllm::fallback_cutlass_gemm_caller<
51
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
52
+ InType, OutType, Epilogue, TileShape, WarpShape,
53
+ InstructionShape, 5>,
54
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
55
+ } else if (np2 <= 8192) {
56
+ using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
57
+
58
+ return vllm::fallback_cutlass_gemm_caller<
59
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
60
+ InType, OutType, Epilogue, TileShape, WarpShape,
61
+ InstructionShape, 3>,
62
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
63
+ } else if (np2 <= 16384) {
64
+ using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
65
+
66
+ return vllm::fallback_cutlass_gemm_caller<
67
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
68
+ InType, OutType, Epilogue, TileShape, WarpShape,
69
+ InstructionShape, 5>,
70
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
71
+ } else {
72
+ using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
73
+
74
+ return vllm::fallback_cutlass_gemm_caller<
75
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
76
+ InType, OutType, Epilogue, TileShape, WarpShape,
77
+ InstructionShape, 3>,
78
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
79
+ }
80
+ }
81
+ };
82
+
83
+ struct sm89_int8_config_M256 {
84
+ // M in (128, 256]
85
+ using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
86
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
87
+
88
+ template <typename InType, typename OutType,
89
+ template <typename, typename> typename Epilogue,
90
+ typename... EpilogueArgs>
91
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
92
+ torch::Tensor const& b, EpilogueArgs&&... args) {
93
+ static_assert(std::is_same<InType, int8_t>());
94
+ TORCH_CHECK(a.dtype() == torch::kInt8);
95
+
96
+ using FallbackGemm =
97
+ typename sm89_int8_fallback_gemm<InType, OutType,
98
+ Epilogue>::Cutlass2xGemm;
99
+
100
+ uint32_t const n = out.size(1);
101
+ uint32_t const np2 = next_pow_2(n);
102
+
103
+ if (np2 <= 4096) {
104
+ using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
105
+
106
+ return vllm::fallback_cutlass_gemm_caller<
107
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
108
+ InType, OutType, Epilogue, TileShape, WarpShape,
109
+ InstructionShape, 3>,
110
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
111
+ } else if (np2 <= 8192) {
112
+ using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
113
+
114
+ return vllm::fallback_cutlass_gemm_caller<
115
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
116
+ InType, OutType, Epilogue, TileShape, WarpShape,
117
+ InstructionShape, 5>,
118
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
119
+ } else if (np2 <= 16384) {
120
+ using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
121
+
122
+ return vllm::fallback_cutlass_gemm_caller<
123
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
124
+ InType, OutType, Epilogue, TileShape, WarpShape,
125
+ InstructionShape, 3>,
126
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
127
+ } else {
128
+ using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
129
+
130
+ return vllm::fallback_cutlass_gemm_caller<
131
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
132
+ InType, OutType, Epilogue, TileShape, WarpShape,
133
+ InstructionShape, 5>,
134
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
135
+ }
136
+ }
137
+ };
138
+
139
+ struct sm89_int8_config_M128 {
140
+ // M in (64, 128]
141
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
142
+
143
+ template <typename InType, typename OutType,
144
+ template <typename, typename> typename Epilogue,
145
+ typename... EpilogueArgs>
146
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
147
+ torch::Tensor const& b, EpilogueArgs&&... args) {
148
+ static_assert(std::is_same<InType, int8_t>());
149
+ TORCH_CHECK(a.dtype() == torch::kInt8);
150
+
151
+ using FallbackGemm =
152
+ typename sm89_int8_fallback_gemm<InType, OutType,
153
+ Epilogue>::Cutlass2xGemm;
154
+
155
+ uint32_t const n = out.size(1);
156
+ uint32_t const np2 = next_pow_2(n);
157
+
158
+ if (np2 <= 8192) {
159
+ using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
160
+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
161
+
162
+ return vllm::fallback_cutlass_gemm_caller<
163
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
164
+ InType, OutType, Epilogue, TileShape, WarpShape,
165
+ InstructionShape, 3>,
166
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
167
+ } else if (np2 <= 16384) {
168
+ using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
169
+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
170
+
171
+ return vllm::fallback_cutlass_gemm_caller<
172
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
173
+ InType, OutType, Epilogue, TileShape, WarpShape,
174
+ InstructionShape, 5>,
175
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
176
+ } else {
177
+ using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
178
+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
179
+
180
+ return vllm::fallback_cutlass_gemm_caller<
181
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
182
+ InType, OutType, Epilogue, TileShape, WarpShape,
183
+ InstructionShape, 5>,
184
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
185
+ }
186
+ }
187
+ };
188
+
189
+ struct sm89_int8_config_M64 {
190
+ // M in (32, 64]
191
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
192
+
193
+ template <typename InType, typename OutType,
194
+ template <typename, typename> typename Epilogue,
195
+ typename... EpilogueArgs>
196
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
197
+ torch::Tensor const& b, EpilogueArgs&&... args) {
198
+ static_assert(std::is_same<InType, int8_t>());
199
+ TORCH_CHECK(a.dtype() == torch::kInt8);
200
+
201
+ using FallbackGemm =
202
+ typename sm89_int8_fallback_gemm<InType, OutType,
203
+ Epilogue>::Cutlass2xGemm;
204
+
205
+ uint32_t const n = out.size(1);
206
+ uint32_t const np2 = next_pow_2(n);
207
+
208
+ if (np2 <= 8192) {
209
+ using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
210
+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
211
+
212
+ return vllm::fallback_cutlass_gemm_caller<
213
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
214
+ InType, OutType, Epilogue, TileShape, WarpShape,
215
+ InstructionShape, 5>,
216
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
217
+ } else {
218
+ using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
219
+ using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
220
+
221
+ return vllm::fallback_cutlass_gemm_caller<
222
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
223
+ InType, OutType, Epilogue, TileShape, WarpShape,
224
+ InstructionShape, 3>,
225
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
226
+ }
227
+ }
228
+ };
229
+
230
+ struct sm89_int8_config_M32 {
231
+ // M in (16, 32]
232
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
233
+
234
+ template <typename InType, typename OutType,
235
+ template <typename, typename> typename Epilogue,
236
+ typename... EpilogueArgs>
237
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
238
+ torch::Tensor const& b, EpilogueArgs&&... args) {
239
+ static_assert(std::is_same<InType, int8_t>());
240
+ TORCH_CHECK(a.dtype() == torch::kInt8);
241
+
242
+ using FallbackGemm =
243
+ typename sm89_int8_fallback_gemm<InType, OutType,
244
+ Epilogue>::Cutlass2xGemm;
245
+
246
+ uint32_t const n = out.size(1);
247
+ uint32_t const np2 = next_pow_2(n);
248
+
249
+ if (np2 <= 8192) {
250
+ using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
251
+ using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
252
+
253
+ return vllm::fallback_cutlass_gemm_caller<
254
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
255
+ InType, OutType, Epilogue, TileShape, WarpShape,
256
+ InstructionShape, 5>,
257
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
258
+ } else {
259
+ using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
260
+ using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
261
+
262
+ return vllm::fallback_cutlass_gemm_caller<
263
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
264
+ InType, OutType, Epilogue, TileShape, WarpShape,
265
+ InstructionShape, 4>,
266
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
267
+ }
268
+ }
269
+ };
270
+
271
+ struct sm89_int8_config_M16 {
272
+ // M in [1, 16]
273
+ using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
274
+ using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
275
+
276
+ template <typename InType, typename OutType,
277
+ template <typename, typename> typename Epilogue,
278
+ typename... EpilogueArgs>
279
+ static void dispatch(torch::Tensor& out, torch::Tensor const& a,
280
+ torch::Tensor const& b, EpilogueArgs&&... args) {
281
+ static_assert(std::is_same<InType, int8_t>());
282
+ TORCH_CHECK(a.dtype() == torch::kInt8);
283
+
284
+ using FallbackGemm =
285
+ typename sm89_int8_fallback_gemm<InType, OutType,
286
+ Epilogue>::Cutlass2xGemm;
287
+
288
+ uint32_t const n = out.size(1);
289
+ uint32_t const np2 = next_pow_2(n);
290
+
291
+ if (np2 <= 8192) {
292
+ using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
293
+
294
+ return vllm::fallback_cutlass_gemm_caller<
295
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
296
+ InType, OutType, Epilogue, TileShape, WarpShape,
297
+ InstructionShape, 5>,
298
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
299
+ } else {
300
+ using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
301
+
302
+ return vllm::fallback_cutlass_gemm_caller<
303
+ vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
304
+ InType, OutType, Epilogue, TileShape, WarpShape,
305
+ InstructionShape, 4>,
306
+ FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
307
+ }
308
+ }
309
+ };
310
+
311
+ template <typename InType, typename OutType,
312
+ template <typename, typename> typename Epilogue,
313
+ typename... EpilogueArgs>
314
+ inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
315
+ torch::Tensor const& a,
316
+ torch::Tensor const& b,
317
+ EpilogueArgs&&... args) {
318
+ static_assert(std::is_same<InType, int8_t>());
319
+ TORCH_CHECK(a.dtype() == torch::kInt8);
320
+ TORCH_CHECK(b.dtype() == torch::kInt8);
321
+
322
+ uint32_t const m = a.size(0);
323
+ uint32_t const mp2 =
324
+ std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
325
+
326
+ if (mp2 <= 16) {
327
+ // M in [1, 16]
328
+ return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
329
+ out, a, b, std::forward<EpilogueArgs>(args)...);
330
+ } else if (mp2 <= 32) {
331
+ // M in (16, 32]
332
+ return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
333
+ out, a, b, std::forward<EpilogueArgs>(args)...);
334
+ } else if (mp2 <= 64) {
335
+ // M in (32, 64]
336
+ return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
337
+ out, a, b, std::forward<EpilogueArgs>(args)...);
338
+ } else if (mp2 <= 128) {
339
+ // M in (64, 128]
340
+ return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
341
+ out, a, b, std::forward<EpilogueArgs>(args)...);
342
+ } else if (mp2 <= 256) {
343
+ // M in (128, 256]
344
+ return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
345
+ out, a, b, std::forward<EpilogueArgs>(args)...);
346
+ } else {
347
+ // M in (256, inf)
348
+ return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
349
+ out, a, b, std::forward<EpilogueArgs>(args)...);
350
+ }
351
+ }
352
+
353
+ } // namespace vllm
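
The numerical contract behind the int8 path, as the dtype and scale-shape checks in scaled_mm_entry.cu further below suggest, is C[i][j] = a_scale[i] * b_scale[j] * sum_k A_q[i][k] * B_q[k][j] with int32 accumulation, where a_scales is per tensor or per row of A and b_scales is per tensor or per column of B. The scalar sketch below spells that contract out in plain C++ as a readability aid; it is not the CUTLASS kernel, and the toy shapes in main are made up for illustration.

#include <cstdint>
#include <cstdio>
#include <vector>

// Scalar reference for the scaled int8 GEMM contract. A is m x k row-major,
// B is k x n column-major, accumulation is int32, and scaling happens after
// accumulation (per-tensor scales have one element, otherwise per-row for A
// and per-column for B).
std::vector<float> scaled_mm_ref(const std::vector<int8_t>& a,
                                 const std::vector<int8_t>& b,
                                 const std::vector<float>& a_scales,
                                 const std::vector<float>& b_scales,
                                 int m, int n, int k) {
  std::vector<float> c(static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      int32_t acc = 0;
      for (int p = 0; p < k; ++p) {
        acc += int32_t(a[i * k + p]) * int32_t(b[j * k + p]);
      }
      float const sa = a_scales.size() == 1 ? a_scales[0] : a_scales[i];
      float const sb = b_scales.size() == 1 ? b_scales[0] : b_scales[j];
      c[size_t(i) * n + j] = sa * sb * float(acc);
    }
  }
  return c;
}

int main() {
  // 2x2x2 toy example with per-tensor scales.
  std::vector<int8_t> a{1, 2, 3, 4};  // [[1, 2], [3, 4]], row-major
  std::vector<int8_t> b{5, 6, 7, 8};  // column-major: [[5, 7], [6, 8]]
  auto c = scaled_mm_ref(a, b, {0.5f}, {0.25f}, 2, 2, 2);
  printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]);
  return 0;
}
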
cutlass_w8a8/scaled_mm_c3x.cu ADDED
@@ -0,0 +1,453 @@
1
+ // clang-format will break include orders
2
+ // clang-format off
3
+ #include <cudaTypedefs.h>
4
+
5
+ #if defined CUDA_VERSION && CUDA_VERSION >= 12000
6
+
7
+ #include <torch/all.h>
8
+
9
+ #include <ATen/cuda/CUDAContext.h>
10
+
11
+ #include <iostream>
12
+ #include <sstream>
13
+ #include <vector>
14
+
15
+ #include "cutlass/cutlass.h"
16
+
17
+ #include "cute/tensor.hpp"
18
+ #include "cute/atom/mma_atom.hpp"
19
+ #include "cutlass/numeric_types.h"
20
+
21
+ #include "cutlass/gemm/device/gemm_universal_adapter.h"
22
+ #include "cutlass/gemm/kernel/gemm_universal.hpp"
23
+ #include "cutlass/epilogue/collective/collective_builder.hpp"
24
+ #include "cutlass/gemm/collective/collective_builder.hpp"
25
+
26
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
27
+ #include "common.hpp"
28
+ // clang-format on
29
+
30
+ using namespace cute;
31
+ using namespace vllm;
32
+
33
+ /*
34
+ This file defines quantized GEMM operations using the CUTLASS 3.x API, for
35
+ NVIDIA GPUs with sm90a (Hopper) or later.
36
+
37
+ Epilogue functions can be defined to post-process the output before it is
38
+ written to GPU memory.
39
+ Epilogues must contain a public type named EVTCompute of type Sm90EVT,
40
+ as well as a static prepare_args function that constructs an
41
+ EVTCompute::Arguments struct.
42
+ */
43
+
44
+ namespace {
45
+
46
+ // A wrapper for the GEMM kernel that is used to guard against compilation on
47
+ // architectures that will never use the kernel. The purpose of this is to
48
+ // reduce the size of the compiled binary.
49
+ // __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
50
+ // into code that will be executed on the device where it is defined.
51
+ template <typename Kernel>
52
+ struct enable_sm90_or_later : Kernel {
53
+ template <typename... Args>
54
+ CUTLASS_DEVICE void operator()(Args&&... args) {
55
+ #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
56
+ Kernel::operator()(std::forward<Args>(args)...);
57
+ #endif
58
+ }
59
+ };
60
+ template <typename ElementAB_, typename ElementD_,
61
+ template <typename, typename, typename> typename Epilogue_,
62
+ typename TileShape, typename ClusterShape, typename KernelSchedule,
63
+ typename EpilogueSchedule>
64
+ struct cutlass_3x_gemm {
65
+ using ElementAB = ElementAB_;
66
+ using ElementD = ElementD_;
67
+ using ElementAcc =
68
+ typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
69
+ float>::type;
70
+
71
+ using EpilogueDescriptor =
72
+ cutlass::epilogue::collective::detail::EpilogueDescriptor<
73
+ TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
74
+ ElementD, EpilogueSchedule>;
75
+
76
+ using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
77
+
78
+ using StrideD = Stride<int64_t, Int<1>, Int<0>>;
79
+ using ElementC = void;
80
+ using StrideC = StrideD;
81
+
82
+ using EVTCompute = typename Epilogue::EVTCompute;
83
+
84
+ using CollectiveEpilogue =
85
+ typename cutlass::epilogue::collective::CollectiveBuilder<
86
+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
87
+ ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
88
+ ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
89
+ EpilogueSchedule, EVTCompute>::CollectiveOp;
90
+
91
+ static constexpr size_t CEStorageSize =
92
+ sizeof(typename CollectiveEpilogue::SharedStorage);
93
+ using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
94
+ static_cast<int>(CEStorageSize)>;
95
+
96
+ // clang-format off
97
+ using CollectiveMainloop =
98
+ typename cutlass::gemm::collective::CollectiveBuilder<
99
+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
100
+ ElementAB, cutlass::layout::RowMajor, 16,
101
+ ElementAB, cutlass::layout::ColumnMajor, 16,
102
+ ElementAcc, TileShape, ClusterShape,
103
+ Stages,
104
+ KernelSchedule>::CollectiveOp;
105
+ // clang-format on
106
+
107
+ using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
108
+ cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
109
+ cutlass::gemm::PersistentScheduler>>;
110
+
111
+ struct GemmKernel : public KernelType {};
112
+ };
113
+
114
+ template <typename Gemm, typename... EpilogueArgs>
115
+ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
116
+ torch::Tensor const& b,
117
+ EpilogueArgs&&... epilogue_params) {
118
+ using ElementAB = typename Gemm::ElementAB;
119
+ using ElementD = typename Gemm::ElementD;
120
+
121
+ int32_t m = a.size(0);
122
+ int32_t n = b.size(1);
123
+ int32_t k = a.size(1);
124
+
125
+ int64_t lda = a.stride(0);
126
+ int64_t ldb = b.stride(1);
127
+ int64_t ldc = out.stride(0);
128
+
129
+ using StrideA = Stride<int64_t, Int<1>, int64_t>;
130
+ using StrideB = Stride<int64_t, Int<1>, int64_t>;
131
+ using StrideC = typename Gemm::StrideC;
132
+
133
+ StrideA a_stride{lda, Int<1>{}, 0};
134
+ StrideB b_stride{ldb, Int<1>{}, 0};
135
+ StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
136
+
137
+ using GemmKernel = typename Gemm::GemmKernel;
138
+ typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
139
+
140
+ auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
141
+ auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
142
+ typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
143
+ b_stride};
144
+
145
+ auto c_ptr = static_cast<ElementD*>(out.data_ptr());
146
+ typename GemmKernel::EpilogueArguments epilogue_args{
147
+ Gemm::Epilogue::prepare_args(
148
+ std::forward<EpilogueArgs>(epilogue_params)...),
149
+ c_ptr, c_stride, c_ptr, c_stride};
150
+
151
+ typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
152
+ prob_shape, mainloop_args, epilogue_args};
153
+
154
+ // Launch the CUTLASS GEMM kernel.
155
+ using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
156
+ GemmOp gemm_op;
157
+ CUTLASS_CHECK(gemm_op.can_implement(args));
158
+
159
+ size_t workspace_size = gemm_op.get_workspace_size(args);
160
+ auto const workspace_options =
161
+ torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
162
+ auto workspace = torch::empty(workspace_size, workspace_options);
163
+
164
+ auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
165
+
166
+ cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
167
+ CUTLASS_CHECK(status);
168
+ }
169
+
170
+ template <typename InType, typename OutType,
171
+ template <typename, typename, typename> typename Epilogue>
172
+ struct sm90_fp8_config_default {
173
+ // M in (128, inf)
174
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
175
+ using KernelSchedule =
176
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
177
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
178
+ using TileShape = Shape<_128, _128, _128>;
179
+ using ClusterShape = Shape<_2, _1, _1>;
180
+ using Cutlass3xGemm =
181
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
182
+ KernelSchedule, EpilogueSchedule>;
183
+ };
184
+
185
+ template <typename InType, typename OutType,
186
+ template <typename, typename, typename> typename Epilogue>
187
+ struct sm90_fp8_config_M128 {
188
+ // M in (64, 128]
189
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
190
+ using KernelSchedule =
191
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
192
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
193
+ using TileShape = Shape<_64, _128, _128>;
194
+ using ClusterShape = Shape<_2, _1, _1>;
195
+ using Cutlass3xGemm =
196
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
197
+ KernelSchedule, EpilogueSchedule>;
198
+ };
199
+
200
+ template <typename InType, typename OutType,
201
+ template <typename, typename, typename> typename Epilogue>
202
+ struct sm90_fp8_config_M64 {
203
+ // M in [1, 64]
204
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
205
+ using KernelSchedule =
206
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
207
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
208
+ using TileShape = Shape<_64, _64, _128>;
209
+ using ClusterShape = Shape<_1, _8, _1>;
210
+
211
+ using Cutlass3xGemm =
212
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
213
+ KernelSchedule, EpilogueSchedule>;
214
+ };
215
+
216
+ template <typename InType, typename OutType,
217
+ template <typename, typename, typename> typename Epilogue>
218
+ struct sm90_int8_config_default {
219
+ // For M > 128 and any N
220
+ static_assert(std::is_same<InType, int8_t>());
221
+ using KernelSchedule =
222
+ typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
223
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
224
+ using TileShape = Shape<_128, _128, _128>;
225
+ using ClusterShape = Shape<_2, _1, _1>;
226
+ using Cutlass3xGemm =
227
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
228
+ KernelSchedule, EpilogueSchedule>;
229
+ };
230
+
231
+ template <typename InType, typename OutType,
232
+ template <typename, typename, typename> typename Epilogue>
233
+ struct sm90_int8_config_M128 {
234
+ // For M in (64, 128] and any N
235
+ static_assert(std::is_same<InType, int8_t>());
236
+ using KernelSchedule =
237
+ typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
238
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
239
+ using TileShape = Shape<_64, _128, _128>;
240
+ using ClusterShape = Shape<_2, _1, _1>;
241
+ using Cutlass3xGemm =
242
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
243
+ KernelSchedule, EpilogueSchedule>;
244
+ };
245
+
246
+ template <typename InType, typename OutType,
247
+ template <typename, typename, typename> typename Epilogue>
248
+ struct sm90_int8_config_M64 {
249
+ // For M in (32, 64] and any N
250
+ static_assert(std::is_same<InType, int8_t>());
251
+ using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
252
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
253
+ using TileShape = Shape<_64, _64, _256>;
254
+ using ClusterShape = Shape<_1, _1, _1>;
255
+ using Cutlass3xGemm =
256
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
257
+ KernelSchedule, EpilogueSchedule>;
258
+ };
259
+
260
+ template <typename InType, typename OutType,
261
+ template <typename, typename, typename> typename Epilogue>
262
+ struct sm90_int8_config_M32_NBig {
263
+ // For M in [1, 32] and N >= 8192
264
+ static_assert(std::is_same<InType, int8_t>());
265
+ using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
266
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
267
+ using TileShape = Shape<_64, _128, _256>;
268
+ using ClusterShape = Shape<_1, _4, _1>;
269
+ using Cutlass3xGemm =
270
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
271
+ KernelSchedule, EpilogueSchedule>;
272
+ };
273
+
274
+ template <typename InType, typename OutType,
275
+ template <typename, typename, typename> typename Epilogue>
276
+ struct sm90_int8_config_M32_NSmall {
277
+ // For M in [1, 32] and N < 8192
278
+ static_assert(std::is_same<InType, int8_t>());
279
+ using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
280
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
281
+ using TileShape = Shape<_64, _64, _256>;
282
+ using ClusterShape = Shape<_1, _8, _1>;
283
+ using Cutlass3xGemm =
284
+ cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
285
+ KernelSchedule, EpilogueSchedule>;
286
+ };
287
+
288
+ } // namespace
289
+
290
+ template <typename InType, typename OutType,
291
+ template <typename, typename, typename> typename Epilogue,
292
+ typename... EpilogueArgs>
293
+ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
294
+ torch::Tensor const& b,
295
+ EpilogueArgs&&... args) {
296
+ static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
297
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
298
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
299
+
300
+ using Cutlass3xGemmDefault =
301
+ typename sm90_fp8_config_default<InType, OutType,
302
+ Epilogue>::Cutlass3xGemm;
303
+ using Cutlass3xGemmM64 =
304
+ typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
305
+ using Cutlass3xGemmM128 =
306
+ typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
307
+
308
+ uint32_t const m = a.size(0);
309
+ uint32_t const mp2 =
310
+ std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
311
+
312
+ if (mp2 <= 64) {
313
+ // m in [1, 64]
314
+ return cutlass_gemm_caller<Cutlass3xGemmM64>(
315
+ out, a, b, std::forward<EpilogueArgs>(args)...);
316
+ } else if (mp2 <= 128) {
317
+ // m in (64, 128]
318
+ return cutlass_gemm_caller<Cutlass3xGemmM128>(
319
+ out, a, b, std::forward<EpilogueArgs>(args)...);
320
+ } else {
321
+ // m in (128, inf)
322
+ return cutlass_gemm_caller<Cutlass3xGemmDefault>(
323
+ out, a, b, std::forward<EpilogueArgs>(args)...);
324
+ }
325
+ }
326
+
327
+ template <typename InType, typename OutType,
328
+ template <typename, typename, typename> typename Epilogue,
329
+ typename... EpilogueArgs>
330
+ void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
331
+ torch::Tensor const& b,
332
+ EpilogueArgs&&... args) {
333
+ static_assert(std::is_same<InType, int8_t>());
334
+ TORCH_CHECK(a.dtype() == torch::kInt8);
335
+ TORCH_CHECK(b.dtype() == torch::kInt8);
336
+
337
+ using Cutlass3xGemmDefault =
338
+ typename sm90_int8_config_default<InType, OutType,
339
+ Epilogue>::Cutlass3xGemm;
340
+ using Cutlass3xGemmM128 =
341
+ typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
342
+ using Cutlass3xGemmM64 =
343
+ typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
344
+ using Cutlass3xGemmM32NBig =
345
+ typename sm90_int8_config_M32_NBig<InType, OutType,
346
+ Epilogue>::Cutlass3xGemm;
347
+ using Cutlass3xGemmM32NSmall =
348
+ typename sm90_int8_config_M32_NSmall<InType, OutType,
349
+ Epilogue>::Cutlass3xGemm;
350
+
351
+ uint32_t const n = out.size(1);
352
+ bool const is_small_n = n < 8192;
353
+
354
+ uint32_t const m = a.size(0);
355
+ uint32_t const mp2 =
356
+ std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
357
+
358
+ if (mp2 <= 32) {
359
+ // m in [1, 32]
360
+ if (is_small_n) {
361
+ return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
362
+ out, a, b, std::forward<EpilogueArgs>(args)...);
363
+ } else {
364
+ return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
365
+ out, a, b, std::forward<EpilogueArgs>(args)...);
366
+ }
367
+ } else if (mp2 <= 64) {
368
+ // m in (32, 64]
369
+ return cutlass_gemm_caller<Cutlass3xGemmM64>(
370
+ out, a, b, std::forward<EpilogueArgs>(args)...);
371
+ } else if (mp2 <= 128) {
372
+ // m in (64, 128]
373
+ return cutlass_gemm_caller<Cutlass3xGemmM128>(
374
+ out, a, b, std::forward<EpilogueArgs>(args)...);
375
+ } else {
376
+ // m in (128, inf)
377
+ return cutlass_gemm_caller<Cutlass3xGemmDefault>(
378
+ out, a, b, std::forward<EpilogueArgs>(args)...);
379
+ }
380
+ }
381
+
382
+ template <template <typename, typename, typename> typename Epilogue,
383
+ typename... EpilogueArgs>
384
+ void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
385
+ torch::Tensor const& b,
386
+ EpilogueArgs&&... epilogue_args) {
387
+ if (a.dtype() == torch::kInt8) {
388
+ TORCH_CHECK(b.dtype() == torch::kInt8);
389
+
390
+ if (out.dtype() == torch::kBFloat16) {
391
+ return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
392
+ Epilogue>(
393
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
394
+ } else {
395
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
396
+ return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
397
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
398
+ }
399
+ } else {
400
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
401
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
402
+
403
+ if (out.dtype() == torch::kBFloat16) {
404
+ return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
405
+ cutlass::bfloat16_t, Epilogue>(
406
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
407
+ } else {
408
+ TORCH_CHECK(out.dtype() == torch::kFloat16);
409
+ return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
410
+ cutlass::half_t, Epilogue>(
411
+ out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
412
+ }
413
+ }
414
+ }
415
+
416
+ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
417
+ torch::Tensor const& b,
418
+ torch::Tensor const& a_scales,
419
+ torch::Tensor const& b_scales,
420
+ c10::optional<torch::Tensor> const& bias) {
421
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
422
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
423
+ if (bias) {
424
+ TORCH_CHECK(bias->dtype() == c.dtype(),
425
+ "currently bias dtype must match output dtype ", c.dtype());
426
+ return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
427
+ c, a, b, a_scales, b_scales, *bias);
428
+ } else {
429
+ return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
430
+ c, a, b, a_scales, b_scales);
431
+ }
432
+ }
433
+
434
+ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
435
+ torch::Tensor const& b,
436
+ torch::Tensor const& a_scales,
437
+ torch::Tensor const& b_scales,
438
+ torch::Tensor const& azp_adj,
439
+ c10::optional<torch::Tensor> const& azp,
440
+ c10::optional<torch::Tensor> const& bias) {
441
+ TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
442
+ TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
443
+
444
+ if (azp) {
445
+ return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
446
+ out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
447
+ } else {
448
+ return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
449
+ out, a, b, a_scales, b_scales, azp_adj, bias);
450
+ }
451
+ }
452
+
453
+ #endif
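
The enable_sm90_or_later wrapper near the top of this file relies on __CUDA_ARCH__ being defined only while device code is compiled for a concrete architecture. The standalone sketch below shows just that mechanism with an ordinary kernel; it is not the CUTLASS wrapper itself, and the kernel name is made up for illustration.

#include <cstdio>
#include <cuda_runtime.h>

// The body only survives device compilation for SM90+, so fatbins that also
// target older GPUs carry an empty stub instead of failing at compile time.
__global__ void sm90_only_kernel() {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
  printf("running on an SM90+ device\n");
#endif
}

int main() {
  sm90_only_kernel<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}
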
cutlass_w8a8/scaled_mm_entry.cu ADDED
@@ -0,0 +1,222 @@
1
+ #include <cudaTypedefs.h>
2
+
3
+ #include <c10/cuda/CUDAGuard.h>
4
+ #include <torch/all.h>
5
+
6
+ void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
7
+ torch::Tensor const& b,
8
+ torch::Tensor const& a_scales,
9
+ torch::Tensor const& b_scales,
10
+ c10::optional<torch::Tensor> const& bias);
11
+
12
+ void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
13
+ torch::Tensor const& b,
14
+ torch::Tensor const& a_scales,
15
+ torch::Tensor const& b_scales,
16
+ c10::optional<torch::Tensor> const& bias);
17
+
18
+ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
19
+ torch::Tensor const& b,
20
+ torch::Tensor const& a_scales,
21
+ torch::Tensor const& b_scales,
22
+ c10::optional<torch::Tensor> const& bias);
23
+
24
+ #if defined CUDA_VERSION && CUDA_VERSION >= 12000
25
+ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
26
+ torch::Tensor const& b,
27
+ torch::Tensor const& a_scales,
28
+ torch::Tensor const& b_scales,
29
+ c10::optional<torch::Tensor> const& bias);
30
+ #endif
31
+
32
+ void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
33
+ torch::Tensor const& b,
34
+ torch::Tensor const& a_scales,
35
+ torch::Tensor const& b_scales,
36
+ torch::Tensor const& azp_adj,
37
+ c10::optional<torch::Tensor> const& azp,
38
+ c10::optional<torch::Tensor> const& bias);
39
+
40
+ void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
41
+ torch::Tensor const& b,
42
+ torch::Tensor const& a_scales,
43
+ torch::Tensor const& b_scales,
44
+ torch::Tensor const& azp_adj,
45
+ c10::optional<torch::Tensor> const& azp,
46
+ c10::optional<torch::Tensor> const& bias);
47
+
48
+ void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
49
+ torch::Tensor const& b,
50
+                                 torch::Tensor const& a_scales,
+                                 torch::Tensor const& b_scales,
+                                 torch::Tensor const& azp_adj,
+                                 c10::optional<torch::Tensor> const& azp,
+                                 c10::optional<torch::Tensor> const& bias);
+
+ #if defined CUDA_VERSION && CUDA_VERSION >= 12000
+ void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
+                                 torch::Tensor const& b,
+                                 torch::Tensor const& a_scales,
+                                 torch::Tensor const& b_scales,
+                                 torch::Tensor const& azp_adj,
+                                 c10::optional<torch::Tensor> const& azp,
+                                 c10::optional<torch::Tensor> const& bias);
+ #endif
+
+ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
+   // CUTLASS FP8 kernels need at least
+   //   CUDA 12.0 on SM90 systems (Hopper)
+   //   CUDA 12.4 on SM89 systems (Lovelace)
+
+ #if defined CUDA_VERSION
+   if (cuda_device_capability >= 90) {
+     return CUDA_VERSION >= 12000;
+   } else if (cuda_device_capability >= 89) {
+     return CUDA_VERSION >= 12040;
+   }
+ #endif
+
+   return false;
+ }
+
+ int32_t get_sm_version_num() {
+   int32_t major_capability, minor_capability;
+   cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
+                          0);
+   cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
+                          0);
+   int32_t version_num = major_capability * 10 + minor_capability;
+   return version_num;
+ }
+
+ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
+                        torch::Tensor const& b, torch::Tensor const& a_scales,
+                        torch::Tensor const& b_scales,
+                        c10::optional<torch::Tensor> const& bias) {
+   // Checks for conformality
+   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
+   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
+               b.size(1) == c.size(1));
+   TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
+   TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
+
+   // Check for strides and alignment
+   TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
+   TORCH_CHECK(b.stride(0) == 1);                      // Column-major
+   TORCH_CHECK(c.stride(0) % 16 == 0 &&
+               b.stride(1) % 16 == 0);  // 16 Byte Alignment
+   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+
+   if (bias) {
+     TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
+                 bias->dim() == 1);
+   }
+
+   at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
+   int32_t version_num = get_sm_version_num();
+
+   // Guard against compilation issues for sm90 kernels
+ #if defined CUDA_VERSION && CUDA_VERSION >= 12000
+   if (version_num >= 90) {
+     // Hopper
+     cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
+     return;
+   }
+ #endif
+
+   if (version_num == 89) {
+     // Ada Lovelace
+     cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
+     return;
+   }
+
+   if (version_num >= 80) {
+     // Ampere
+     cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
+     return;
+   }
+
+   if (version_num >= 75) {
+     // Turing
+     cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
+     return;
+   }
+
+   TORCH_CHECK_NOT_IMPLEMENTED(
+       false,
+       "No compiled cutlass_scaled_mm for a compute capability less than "
+       "CUDA device capability: ",
+       version_num);
+ }
+
+ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales,
+                            torch::Tensor const& azp_adj,
+                            c10::optional<torch::Tensor> const& azp,
+                            c10::optional<torch::Tensor> const& bias) {
+   // Checks for conformality
+   TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
+   TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
+               b.size(1) == c.size(1));
+   TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
+   TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
+
+   // Check for strides and alignment
+   TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
+   TORCH_CHECK(b.stride(0) == 1);                      // Column-major
+   TORCH_CHECK(c.stride(0) % 16 == 0 &&
+               b.stride(1) % 16 == 0);  // 16 Byte Alignment
+   TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+
+   // bias, azp, azp_adj are all 1d
+   // bias and azp_adj have n elements, azp has m elements
+   if (bias) {
+     TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
+   }
+   if (azp) {
+     TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
+   }
+   TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
+
+   // azp & bias types
+   TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
+   TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
+   TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
+               "currently bias dtype must match output dtype ", c.dtype());
+
+   at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
+
+   int32_t version_num = get_sm_version_num();
+
+ #if defined CUDA_VERSION && CUDA_VERSION >= 12000
+   if (version_num >= 90) {
+     cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
+     return;
+   }
+ #endif
+
+   if (version_num == 89) {
+     // Ada Lovelace
+     cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
+     return;
+   }
+
+   if (version_num >= 80) {
+     // Ampere
+     cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
+     return;
+   }
+
+   if (version_num >= 75) {
+     // Turing
+     cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
+     return;
+   }
+
+   TORCH_CHECK_NOT_IMPLEMENTED(
+       false,
+       "No compiled cutlass_scaled_mm_azp for a compute capability less than "
+       "CUDA device capability: ",
+       version_num);
+ }
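The capability gate above is what callers are expected to consult before choosing an FP8 path. A minimal Python sketch of that check, assuming the ops end up under the `_quantization` namespace (the fallback name used by ext-torch/__init__.py below):

import torch

major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor

if torch.ops._quantization.cutlass_scaled_mm_supports_fp8(capability):
    # FP8 is usable: SM90 with CUDA >= 12.0, or SM89 with CUDA >= 12.4.
    gemm_dtype = torch.float8_e4m3fn
else:
    # Otherwise stay on the int8 w8a8 path covered by the sm75/sm80/sm89 kernels.
    gemm_dtype = torch.int8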
ext-torch/__init__.py ADDED
@@ -0,0 +1,44 @@
+ from typing import Optional
+
+ import torch
+
+ try:
+     from ._ops import ops
+ except ImportError as e:
+     # Fallback for local development.
+     try:
+         import _quantization
+         ops = torch.ops._quantization
+     except ImportError:
+         raise e
+
+ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
+     return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
+
+ def cutlass_scaled_mm(a: torch.Tensor,
+                       b: torch.Tensor,
+                       scale_a: torch.Tensor,
+                       scale_b: torch.Tensor,
+                       out_dtype: torch.dtype,
+                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+     assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
+     assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
+     assert bias is None or (bias.shape[0] == b.shape[1]
+                             and bias.dtype == out_dtype)
+
+     m = a.shape[0]
+     n = b.shape[1]
+
+     # if current_platform.is_rocm():
+     #     triton_scaled_mm_module = importlib.import_module(
+     #         "vllm.model_executor.layers.quantization.compressed_tensors."
+     #         "triton_scaled_mm")
+     #     triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
+     #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
+     out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+
+     ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+
+     return out
+
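A minimal usage sketch for the wrapper above, assuming int8 w8a8 inputs and that the package is importable as `quantization` (hypothetical; the real import path depends on how the extension is installed). Note that `b` must be column-major, which is easiest to get by transposing a contiguous (n, k) tensor:

import torch
from quantization import cutlass_scaled_mm  # hypothetical import path

m, n, k = 32, 64, 128
a = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
# Column-major b (stride(0) == 1): build (n, k) contiguous, then transpose.
b = torch.randint(-128, 127, (n, k), dtype=torch.int8, device="cuda").t()
scale_a = torch.rand(m, 1, dtype=torch.float32, device="cuda")  # per-row scales
scale_b = torch.rand(1, n, dtype=torch.float32, device="cuda")  # per-column scales

out = cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)
assert out.shape == (m, n)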
ext-torch/registration.h ADDED
@@ -0,0 +1,27 @@
+ #pragma once
+
+ #include <Python.h>
+
+ #define _CONCAT(A, B) A##B
+ #define CONCAT(A, B) _CONCAT(A, B)
+
+ #define _STRINGIFY(A) #A
+ #define STRINGIFY(A) _STRINGIFY(A)
+
+ // A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
+ // could be a macro instead of a literal token.
+ #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+ // A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
+ // could be a macro instead of a literal token.
+ #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
+   TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
+
+ // REGISTER_EXTENSION allows the shared library to be loaded and initialized
+ // via python's import statement.
+ #define REGISTER_EXTENSION(NAME)                                               \
+   PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
+     static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
+                                         STRINGIFY(NAME), nullptr, 0, nullptr}; \
+     return PyModule_Create(&module);                                           \
+   }
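What these macros buy, sketched from the Python side: REGISTER_EXTENSION generates the PyInit_<NAME> entry point so the built shared library can be imported directly, and TORCH_LIBRARY_EXPAND (used in torch_binding.cpp below) registers the ops under torch.ops.<TORCH_EXTENSION_NAME>. The `_quantization` name here is an assumption carried over from the fallback in ext-torch/__init__.py:

import torch
import _quantization  # the module created by REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

ops = torch.ops._quantization
print(ops.cutlass_scaled_mm_supports_fp8(89))  # True only if built with CUDA >= 12.4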
ext-torch/torch_binding.cpp ADDED
@@ -0,0 +1,32 @@
+ #include <torch/library.h>
+
+ #include "registration.h"
+ #include "torch_binding.h"
+
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+
+   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
+   // quantization, as well as bias
+   ops.def(
+       "cutlass_scaled_mm(Tensor! out, Tensor a,"
+       " Tensor b, Tensor a_scales,"
+       " Tensor b_scales, Tensor? bias) -> ()");
+   ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);
+
+   // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
+   // quantization.
+   ops.def(
+       "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
+       " Tensor b, Tensor a_scales,"
+       " Tensor b_scales, Tensor azp_adj,"
+       " Tensor? azp, Tensor? bias) -> ()");
+   ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);
+
+   // Check if cutlass scaled_mm is supported for CUDA devices of the given
+   // capability
+   ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
+   ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
+
+ }
+
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
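For the asymmetric (`_azp`) schema registered above, a hedged sketch of a direct call through torch.ops. The shapes and dtypes follow from the checks in cutlass_scaled_mm_azp, but the namespace and the way azp_adj is precomputed (here: per-column sums of b) are assumptions:

import torch

ops = torch.ops._quantization  # assumed namespace

m, n, k = 32, 64, 128
a = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 127, (n, k), dtype=torch.int8, device="cuda").t()  # column-major
scale_a = torch.rand(m, 1, dtype=torch.float32, device="cuda")
scale_b = torch.rand(1, n, dtype=torch.float32, device="cuda")
azp = torch.randint(-128, 128, (m,), dtype=torch.int32, device="cuda")  # per-token zero points: m int32 values
azp_adj = b.sum(dim=0, dtype=torch.int32)  # assumed correction term: n int32 values, one per output column
out = torch.empty(m, n, dtype=torch.float16, device="cuda")

ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, None)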
ext-torch/torch_binding.h ADDED
@@ -0,0 +1,18 @@
+ #pragma once
+
+ #include <torch/torch.h>
+
+ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+
+ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
+                        torch::Tensor const& b, torch::Tensor const& a_scales,
+                        torch::Tensor const& b_scales,
+                        c10::optional<torch::Tensor> const& bias);
+
+ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
+                            torch::Tensor const& b,
+                            torch::Tensor const& a_scales,
+                            torch::Tensor const& b_scales,
+                            torch::Tensor const& azp_adj,
+                            c10::optional<torch::Tensor> const& azp,
+                            c10::optional<torch::Tensor> const& bias);