Add cutlass_w8a8
- LICENSE +201 -0
- README.md +9 -3
- build.toml +41 -0
- cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp +497 -0
- cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +447 -0
- cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +317 -0
- cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +315 -0
- cutlass_w8a8/Epilogues.md +147 -0
- cutlass_w8a8/common.hpp +27 -0
- cutlass_w8a8/scaled_mm_c2x.cu +199 -0
- cutlass_w8a8/scaled_mm_c2x.cuh +219 -0
- cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh +123 -0
- cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh +139 -0
- cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh +368 -0
- cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh +353 -0
- cutlass_w8a8/scaled_mm_c3x.cu +453 -0
- cutlass_w8a8/scaled_mm_entry.cu +222 -0
- ext-torch/__init__.py +44 -0
- ext-torch/registration.h +27 -0
- ext-torch/torch_binding.cpp +32 -0
- ext-torch/torch_binding.h +18 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

(Standard, unmodified text of the Apache License 2.0: Sections 1–9 — Definitions; Grant of Copyright License; Grant of Patent License; Redistribution; Submission of Contributions; Trademarks; Disclaimer of Warranty; Limitation of Liability; Accepting Warranty or Additional Liability — followed by the Appendix "How to apply the Apache License to your work" with the usual "Copyright [yyyy] [name of copyright owner] ... Licensed under the Apache License, Version 2.0" boilerplate notice.)
README.md
CHANGED
@@ -1,3 +1,9 @@
----
-license: apache-2.0
----
+---
+license: apache-2.0
+---
+
+## Quantization
+
+Quantization kernels from [vLLM](https://github.com/vllm-project/vllm/blob/main/csrc/quantization).
+
+Copyright 2023-2024, the vLLM team.
build.toml
ADDED
@@ -0,0 +1,41 @@
[general]
version = "0.0.1"

[torch]
name = "quantization"
src = [
  "ext-torch/registration.h",
  "ext-torch/torch_binding.cpp",
  "ext-torch/torch_binding.h"
]
pysrc = [
  "ext-torch/__init__.py"
]

[kernel.cutlass_w8a8]
capabilities = [ "7.5", "8.0", "8.6", "8.7", "8.9", "9.0", "9.0a" ]
src = [
  "cutlass_w8a8/common.hpp",
  "cutlass_w8a8/scaled_mm_c2x.cu",
  "cutlass_w8a8/scaled_mm_c2x.cuh",
  "cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh",
  "cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh",
  "cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh",
  "cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh",
  "cutlass_w8a8/scaled_mm_entry.cu",
  "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp",
  "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp",
]
include = [ "." ]
depends = [ "cutlass", "torch" ]

[kernel.cutlass_w8a8_hopper]
capabilities = [ "9.0", "9.0a" ]
src = [
  "cutlass_w8a8/common.hpp",
  "cutlass_w8a8/scaled_mm_c3x.cu",
  "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp",
  "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp",
]
include = [ "." ]
depends = [ "cutlass", "torch" ]
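The build config above splits the kernels into two groups: cutlass_w8a8 is compiled for SM 7.5 through 9.0 using the CUTLASS 2.x path (scaled_mm_c2x.cu plus the per-architecture dispatch headers), while cutlass_w8a8_hopper is compiled only for SM 9.0/9.0a using the CUTLASS 3.x path (scaled_mm_c3x.cu). The actual runtime selection lives in cutlass_w8a8/scaled_mm_entry.cu, which is not reproduced here; the sketch below is only a minimal, self-contained illustration of how such a dispatch by compute capability could look. The function and path names are illustrative, not the repository's real symbols.

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical helper: report which kernel family the entry point would pick
// for the current GPU, based on its compute capability (e.g. 90 on H100).
const char* pick_scaled_mm_path() {
  int device = 0;
  cudaGetDevice(&device);
  cudaDeviceProp props{};
  cudaGetDeviceProperties(&props, device);
  int const cc = props.major * 10 + props.minor;

  if (cc >= 90) return "cutlass_w8a8_hopper (CUTLASS 3.x, scaled_mm_c3x.cu)";
  if (cc == 89) return "cutlass_w8a8 sm89 FP8/INT8 dispatch (CUTLASS 2.x)";
  if (cc >= 80) return "cutlass_w8a8 sm80 dispatch (CUTLASS 2.x)";
  if (cc >= 75) return "cutlass_w8a8 sm75 dispatch (CUTLASS 2.x)";
  return "unsupported GPU for cutlass_w8a8";
}

int main() {
  std::printf("scaled_mm path: %s\n", pick_scaled_mm_path());
  return 0;
}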
cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
ADDED
@@ -0,0 +1,497 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either
// row/column or scalar broadcasting where the tensor being loaded from is
// always passed in via a device pointer. This lets one compiled kernel handle
// all cases of per-tensor or per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cute/tensor.hpp"

namespace cutlass::epilogue::threadblock {

using namespace cute;
using namespace detail;

template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrScalarBroadcast {

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Global load type
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    int n;

    // This function is modified from VisitorRowBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->row_broadcast) {
        // In this case we are loading from a row vector and broadcasting
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // In this case we are loading from a scalar and broadcasting
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = *(params_ptr->ptr_row);
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Generate the pred tensor
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }

};

/////////////////////////////////////////////////////////////////////////////////////////////////

// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null
template<
  class ThreadMap,
  class Element,
  class StrideMNL
>
struct VisitorRowOrZeroBroadcast {

  // This struct has been modified to remove null_default (because it's always 0)
  struct Arguments {
    Element const* ptr_row = nullptr;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage {};

  // Global load type
  static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value;
  using VecType = uint_bit_t<cute::min(128, vec_bits)>;
  static int constexpr VecLength = sizeof(VecType) / sizeof(Element);

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gRow,
      RTensor&& tC_rRow,
      CTensor&& tC_cRow,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gRow(cute::forward<GTensor>(tC_gRow)),
      tC_rRow(cute::forward<RTensor>(tC_rRow)),
      tC_cRow(cute::forward<CTensor>(tC_cRow)),
      n(get<1>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gRow;
    RTensor tC_rRow;
    CTensor tC_cRow;
    Params const* params_ptr;
    int n;

    // This function is modified from VisitorRowBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rRow);
      auto src_v = filter(tC_gRow);
      auto coord_v = filter(tC_cRow);
      auto dst_v = filter(tC_rRow);

      if (params_ptr->ptr_row != nullptr) {
        // In this case we are loading from a row vector and broadcasting
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          bool guard = get<1>(coord_v(i)) < n;
          cutlass::arch::global_load<VecType, sizeof(VecType)>(
              dst_v(i), (void const*)&src_v(i), guard);
        }
      } else {
        // In this case we are broadcasting 0
        VecType filled_vec;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < VecLength; i++) {
          reinterpret_cast<Element*>(&filled_vec)[i] = Element{0};
        }

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(src_v); ++i) {
          if (get<1>(coord_v(i)) < n) {
            dst_v(i) = filled_vec;
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Tensor rRow_frg = recast<Array<Element, FragmentSize>>(coalesce(tC_rRow));
      return rRow_frg(column_idx);
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mRow = make_tensor(
      make_gmem_ptr(params_ptr->ptr_row),
      problem_shape,
      params_ptr->dRow);

    // VECTOR, FRAGMENT_COLUMN
    Tensor tC_gRow = recast<VecType>(
      ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset)
    )(_,_,_0{},_0{},_0{},_0{});
    Tensor tC_rRow = make_tensor_like(tC_gRow);

    // Generate the pred tensor
    Tensor cRow = make_identity_tensor(mRow.shape());
    Tensor tC_cRow = outer_partition(
      ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}),
      Shape<Int<VecLength>>{},
      (_0{})
    );

    return Callbacks<
      decltype(tC_gRow), decltype(tC_rRow),
      decltype(tC_cRow), ProblemShape>(
      cute::move(tC_gRow),
      cute::move(tC_rRow),
      cute::move(tC_cRow),
      problem_shape,
      params_ptr
    );
  }

};


/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  class ThreadMap,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>
>
struct VisitorColOrScalarBroadcast {

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast.
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  struct SharedStorage { };

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  VisitorColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
    : params_ptr(&params) { }

  Params const* params_ptr;

  template <class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct Callbacks : EmptyCallbacks {
    CUTLASS_DEVICE
    Callbacks(
      GTensor&& tC_gCol,
      RTensor&& tC_rCol,
      CTensor&& tC_cCol,
      ProblemShape problem_shape,
      Params const* params_ptr
    ):
      tC_gCol(cute::forward<GTensor>(tC_gCol)),
      tC_rCol(cute::forward<RTensor>(tC_rCol)),
      tC_cCol(cute::forward<CTensor>(tC_cCol)),
      m(get<0>(problem_shape)),
      params_ptr(params_ptr) { }

    GTensor tC_gCol;
    RTensor tC_rCol;
    CTensor tC_cCol;
    Params const* params_ptr;
    int m;

    // This function is modified from VisitorColBroadcast
    CUTLASS_DEVICE void
    begin_epilogue() {
      clear(tC_rCol);

      Tensor pred = make_tensor<bool>(shape(tC_gCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tC_cCol(i)) < m;
      }

      if (params_ptr->col_broadcast) {
        // In this case we are loading from a column vector and broadcasting
        copy_if(pred, tC_gCol, tC_rCol);
      } else {
        // In this case we are loading from a scalar and broadcasting
        auto dst_v = filter(tC_rCol);

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < size(dst_v); ++i) {
          if (pred(i)) {
            dst_v(i) = *(params_ptr->ptr_col);
          }
        }
      }
    }

    template <class ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE auto // returns an Array
    visit(int iter_idx, int row_idx, int column_idx, int frg_idx,
          Array<ElementAccumulator, FragmentSize> const& frg_acc) {
      Array<Element, FragmentSize> frg_col;
      frg_col.fill(tC_rCol(row_idx,iter_idx));
      return frg_col;
    }
  };

  template <class ProblemShape>
  CUTLASS_DEVICE auto
  get_callbacks(
    gemm::GemmCoord threadblock_tile_offset,
    int thread_idx,
    ProblemShape problem_shape
  ) {
    Tensor mCol = make_tensor(
      make_gmem_ptr(params_ptr->ptr_col),
      problem_shape,
      params_ptr->dCol);

    // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
    Tensor tC_gCol = group_modes<1,4>(
      ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));
    Tensor tC_rCol = make_tensor_like(tC_gCol);

    // Generate the pred tensor
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tC_cCol = group_modes<1,4>(
      ThreadMap::partition(cCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_));

    return Callbacks<
      decltype(tC_gCol), decltype(tC_rCol),
      decltype(tC_cCol), ProblemShape>(
      cute::move(tC_gCol),
      cute::move(tC_rCol),
      cute::move(tC_cCol),
      problem_shape,
      params_ptr
    );
  }
};

}
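The header comment above is the key idea of this file: one compiled kernel handles per-tensor, per-channel, and per-token quantization because the epilogue always receives a device pointer plus a bool (row_broadcast / col_broadcast) that says whether that pointer is a full scale vector or a single scalar to splat. The following standalone CPU sketch mirrors the selection logic of begin_epilogue() for the row case; it is an illustration only, not the CUTLASS visitor itself, and the helper name is invented for this example.

#include <cstddef>
#include <vector>

// Materialize the scale values one output tile row will multiply by: either a
// slice of a per-column scale vector (row_broadcast == true) or a single
// per-tensor scalar broadcast to every column, with the same bounds predicate
// the visitor uses ("guard": column < N).
template <typename Element>
std::vector<Element> materialize_row(const Element* ptr_row,  // device pointer in the real kernel
                                     bool row_broadcast,
                                     std::size_t tile_col0,   // first column covered by the tile
                                     std::size_t tile_cols,   // tile width
                                     std::size_t n) {         // problem size N
  std::vector<Element> row(tile_cols, Element(0));
  for (std::size_t i = 0; i < tile_cols; ++i) {
    std::size_t const col = tile_col0 + i;
    if (col >= n) continue;                       // out-of-bounds lanes stay zero
    row[i] = row_broadcast ? ptr_row[col]         // per-channel / per-token scales
                           : ptr_row[0];          // per-tensor scale, broadcast
  }
  return row;
}

In the real epilogue the same decision is made per tile with vectorized, predicated global loads (cutlass::arch::global_load), and VisitorRowOrZeroBroadcast replaces the bool with a null-pointer check that broadcasts zero, which is useful for optional bias terms.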
cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
ADDED
@@ -0,0 +1,447 @@
/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
 *reserved. SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 *POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
// from https://github.com/NVIDIA/cutlass v3.5.0
// It has been modified to support either row/column or scalar broadcasting
// where the tensor being loaded from is always passed in via a device pointer.
// This lets one compiled kernel handle all cases of per-tensor or
// per-channel/per-token quantization.
//
// This interface also allows the scales to be passed in as tensors that
// consistently reside on the device, which avoids an issue with a previous
// implementation where scalars needed to be on the CPU since they
// were passed in via float values. This created a potential performance hazard
// if scales were initially on the device, and caused torch.compile graph
// breaks when moving scales to the CPU.
//
#pragma once

// Turn off clang-format for the entire file to keep it close to upstream
// clang-format off

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"

#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"

namespace cutlass::epilogue::fusion {

using namespace cute;
using namespace detail;

// Row vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
  static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
  static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
  static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});

  struct SharedStorage {
    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
  };

  // This struct has been modified to have a bool indicating that ptr_row is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_row is null.
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
      : params(params)
      , smem(const_cast<Element*>(shared_storage.smem.data())) { }

  Params params;
  Element *smem = nullptr;

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.row_broadcast && *(params.ptr_row) == Element(0));
  }

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
        GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
        GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
        SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
        CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
      : tGS_gRow(tGS_gRow_)
      , tGS_sRow(tGS_sRow_)
      , tGS_cRow(tGS_cRow_)
      , tiled_G2S(tiled_g2s_)
      , tSR_sRow(tSR_sRow_)
      , tSR_rRow(tSR_rRow_)
      , tCcRow(tCcRow_)
      , residue_tCcRow(residue_tCcRow_)
      , params(params_) {}

    GS_GTensor tGS_gRow;  // (CPY,CPY_M,CPY_N)
    GS_STensor tGS_sRow;  // (CPY,CPY_M,CPY_N)
    GS_CTensor tGS_cRow;  // (CPY,CPY_M,CPY_N)
    Tiled_G2S tiled_G2S;

    SR_STensor tSR_sRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    SR_RTensor tSR_rRow;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    CTensor tCcRow;       // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    ThrResidue residue_tCcRow;  // (m, n)
    ThrNum thr_num;
    Params const& params;

    CUTLASS_DEVICE void
    begin() {
      if (!params.row_broadcast) {
        fill(tSR_rRow, *(params.ptr_row));
        return;
      }

      auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
      Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
      Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
      Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));

      for (int i = 0; i < size(tGS_gRow_flt); ++i) {
        if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
          continue; // OOB of SMEM,
        }
        if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
          tGS_sRow_flt(i) = tGS_gRow_flt(i);
        }
        else {
          tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
        }
      }
      synchronize();
    }

    CUTLASS_DEVICE void
    begin_loop(int epi_m, int epi_n) {
      if (epi_m == 0) { // Assumes M-major subtile loop
        if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
        Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
        Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
        copy(tSR_sRow_flt, tSR_rRow_flt);
      }
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_row;

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
      }

      return frg_row;
    }
  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
    auto [M, N, K, L] = args.problem_shape_mnkl;
    auto [m, n, k, l] = args.tile_coord_mnkl;
    using ThreadCount = decltype(size(args.tiled_copy));

    Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
    Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n));  // (CTA_M, CTA_N)
    Tensor sRow = make_tensor(make_smem_ptr(smem),
        make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{}));  // (CTA_M, CTA_N)
    //// G2S: Gmem to Smem
    auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                                     Layout< Shape<_1, ThreadCount>,
                                             Stride<_0, _1>>{},
                                     Layout<_1>{});
    auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
    Tensor tGS_gRow = thr_g2s.partition_S(gRow);
    Tensor tGS_sRow = thr_g2s.partition_D(sRow);

    //// G2S: Coord
    auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
    Tensor tGS_cRow = thr_g2s.partition_S(cRow);

    //// S2R: Smem to Reg
    Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow));  // (CPY,CPY_M,CPY_N)

    return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
      tGS_gRow,
      tGS_sRow,
      tGS_cRow, tiled_g2s,
      tSR_sRow,
      tSR_rRow,
      args.tCcD,
      args.residue_cD,
      ThreadCount{},
      params);
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// Column vector broadcast
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
  static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet");
  static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
  static_assert(
    (cute::is_same_v<StrideMNL, Stride<_1,_0, _0>>) ||  // col vector broadcast, e.g. per-row alpha/bias
    (cute::is_same_v<StrideMNL, Stride<_1,_0,int>>));   // batched col vector broadcast, e.g. batched per-row bias

  // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
  struct SharedStorage { };

  // This struct has been modified to have a bool indicating that ptr_col is a
  // scalar that must be broadcast, instead of containing a scalar that is
  // valid if ptr_col is null.
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
    return args;
  }

  template <class ProblemShape>
  static bool
  can_implement(ProblemShape const& problem_shape, Arguments const& args) {
    return true;
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return cutlass::Status::kSuccess;
  }

  CUTLASS_DEVICE bool
  is_producer_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_C_load_needed() const {
    return false;
  }

  CUTLASS_DEVICE bool
  is_zero() const {
    return (!params.col_broadcast && *(params.ptr_col) == Element(0));
  }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast() { }

  CUTLASS_HOST_DEVICE
  Sm90ColOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
      : params(params) { }

  Params params;

  template <class... Args>
  CUTLASS_DEVICE auto
  get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
    return EmptyProducerLoadCallbacks{};
  }

  template<class GTensor, class RTensor, class CTensor, class ProblemShape>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE
    ConsumerStoreCallbacks(
      GTensor&& tCgCol,
      RTensor&& tCrCol,
      CTensor&& tCcCol,
      ProblemShape problem_shape,
      Params const& params
    ):
      tCgCol(cute::forward<GTensor>(tCgCol)),
      tCrCol(cute::forward<RTensor>(tCrCol)),
      tCcCol(cute::forward<CTensor>(tCcCol)),
      m(get<0>(problem_shape)),
      params(params) {}

    GTensor tCgCol;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    RTensor tCrCol;
    CTensor tCcCol;  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
    Params const& params;
    int m;

    CUTLASS_DEVICE void
    begin() {
      Tensor pred = make_tensor<bool>(shape(tCgCol));
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < size(pred); ++i) {
        pred(i) = get<0>(tCcCol(i)) < m;
      }

      if (!params.col_broadcast) {
        fill(tCrCol, *(params.ptr_col));
        return;
      }

      // Filter so we don't issue redundant copies over stride-0 modes
      // (only works if 0-strides are in same location, which is by construction)
      copy_if(pred, filter(tCgCol), filter(tCrCol));
    }

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
      Array<Element, FragmentSize> frg_col;
      Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < FragmentSize; ++i) {
        frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
      }

      return frg_col;
    }

  };

  template <
    bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
    class... Args
  >
  CUTLASS_DEVICE auto
  get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {

    auto [M, N, K, L] = args.problem_shape_mnkl;
    Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
    Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
    Tensor tCrCol = make_tensor_like(tCgCol);                   // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)

    // Generate an identity tensor matching the shape of the global tensor and
    // partition the same way, this will be used to generate the predicate
    // tensor for loading
    Tensor cCol = make_identity_tensor(mCol.shape());
    Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>(  // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
      cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);

    return ConsumerStoreCallbacks(
      cute::move(tCgCol),
      cute::move(tCrCol),
      cute::move(tCcCol),
      args.problem_shape_mnkl,
      params
    );
  }
};

}
cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
ADDED
@@ -0,0 +1,317 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"

/*
   This file defines custom epilogues for fusing channel scales, token scales,
   bias, and activation zero-points onto a GEMM operation using the
   CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.

   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace vllm::c2x {

using namespace cute;

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

  template <typename T>
  using ColOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad =
      cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast<
      OutputTileThreadMap, T, Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast<
      OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  template <typename T>
  using RowOrZeroLoad =
      cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast<
          OutputTileThreadMap, T, Stride<Int<0>, Int<1>, Int<0>>>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      // it would technically work but no use case as data_ptr is never nullptr
      static_assert(!std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    static_assert(std::is_same_v<Descriptor, RowOrZeroLoad<T>>);
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    return Arguments{data_ptr};
  }
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
   per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 protected:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;
  using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
                                                             EVTCompute0, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzp
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBiasAzpToken
    : protected ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
 private:
  using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowOrZeroLoad<ElementD>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleB, ScaleB,
                                              EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::threadblock::Sm80EVT<ComputeScaleBiasA, ScaleA,
                                              EVTComputeScaleB, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

};  // namespace vllm::c2x
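For intuition: per output element, the EVT trees above first multiply the accumulator by the B scale (`Compute0`: ScaleB x Accum) and then multiply by the A scale, optionally fusing the bias via `multiply_add` (`Compute1`). A minimal plain-C++ sketch of that arithmetic follows; it is illustrative only, not CUTLASS code, and the function name is made up for this note.

```
#include <cstdint>
#include <vector>

// Scalar reference of the ScaledEpilogue / ScaledEpilogueBias arithmetic.
// acc:      m x n int32 accumulators (A_hat @ B_hat)
// a_scales: size 1 (per-tensor) or m (per-token)
// b_scales: size 1 (per-tensor) or n (per-channel)
// bias:     size n, or empty for the no-bias epilogue
std::vector<float> scaled_epilogue_reference(
    const std::vector<int32_t>& acc, int m, int n,
    const std::vector<float>& a_scales, const std::vector<float>& b_scales,
    const std::vector<float>& bias) {
  std::vector<float> out(static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float sa = a_scales.size() == 1 ? a_scales[0] : a_scales[i];
      float sb = b_scales.size() == 1 ? b_scales[0] : b_scales[j];
      // Compute0: ScaleB * Accum, then Compute1: ScaleA * (...) [+ Bias]
      float v = sa * (sb * static_cast<float>(acc[i * n + j]));
      if (!bias.empty()) v += bias[j];
      out[i * n + j] = v;
    }
  }
  return out;
}
```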
cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
ADDED
@@ -0,0 +1,315 @@
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"

/*
   This file defines custom epilogues for fusing channel scales, token scales,
   bias, and activation zero-points onto a GEMM operation using the
   CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later.

   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace vllm::c3x {

using namespace cute;

/*
 * This class provides the common load descriptors for the
 * ScaledEpilogue[...] classes
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
 protected:
  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  template <typename T>
  using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<1>, Int<0>, Int<0>>>;

  template <typename T>
  using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<0>, Int<1>, Int<0>>>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // Don't want to support nullptr by default
  template <typename T, bool EnableNullPtr = false>
  using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
      Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;

  // This utility function constructs the arguments for the load descriptors
  // from a tensor. It can handle both row and column, as well as row/column or
  // scalar cases.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(torch::Tensor const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = static_cast<T*>(tensor.data_ptr());
    if constexpr (std::is_same_v<Descriptor, ColOrScalarLoad<T>> ||
                  std::is_same_v<Descriptor, RowOrScalarLoad<T>>) {
      return Arguments{data_ptr, tensor.numel() != 1};
    } else {
      static_assert(!std::is_same_v<Descriptor, ColLoad<T, true>> &&
                    !std::is_same_v<Descriptor, RowLoad<T, true>>);
      return Arguments{data_ptr};
    }
  }

  // This overload handles the case where there might not be a tensor, in which
  // case a nullptr is passed and a constant (0) is used.
  template <typename Descriptor, typename T>
  static auto args_from_tensor(c10::optional<torch::Tensor> const& tensor) {
    using Arguments = typename Descriptor::Arguments;
    auto* data_ptr = tensor ? static_cast<T*>(tensor->data_ptr()) : nullptr;
    static_assert(std::is_same_v<Descriptor, ColLoad<T, true>> ||
                  std::is_same_v<Descriptor, RowLoad<T, true>>);
    return Arguments{data_ptr};
  }
};

/*
   This epilogue function defines a quantized GEMM operation similar to
   torch._scaled_mm.

   A and B may be both either int8 or fp8_e4m3. A can be
   quantized per-tensor or per-row. B can be quantized per-tensor or per-column.
   Any combination of per-tensor and per-row or column is supported.
   A and B must have symmetric quantization (zero point == 0).

   So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
   scales are applied elementwise with numpy-style broadcasting.

   ScaleA and ScaleB define the epilogue functions that apply the scales for
   the A and B operands respectively. These scales may be either per-tensor or
   per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args};
  }
};

/*
 * This epilogue performs the same operation as ScaledEpilogue, but adds a bias.
 * This bias can also be used in the per-tensor azp case, where the activation
 * zero point (azp) is used to compute an azp correction term,
 * which is folded into the bias.
 *
 * The bias tensor must be per-output channel.
 * ScaleA and ScaleB can be per-tensor or per-token/per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD>;

  using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTCompute0 =
      cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;

  using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;

  using ArgumentType = typename EVTCompute::Arguments;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);

    typename EVTCompute0::Arguments evt0_args{b_args};
    return ArgumentType{a_args, evt0_args, bias_args};
  }
};

/*
 * This epilogue directly supports per-tensor azp in int32 form.
 * As opposed to the per-token epilogue below, this epilogue only has an azp_adj
 * term, which should already be multiplied with the scalar azp.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzp
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // This is the full AZP term, azp * J @ B, shape (1,n)
  using AzpWithAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute float(accum - azp_adj), both operands are int32_t
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Accum, AzpWithAdj>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpWithAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

/*
 * This epilogue supports per-token azp by computing and applying
 * the correction term using a rank-1 update. If the term were materialized,
 * it would require O(m*n) space, and this way it only requires O(m+n) space.
 * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero
 * point for each row of A.
 * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B.
 *
 * This epilogue also supports bias, which remains per-channel.
 */
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBiasAzpToken
    : private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
 private:
  using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
  using Accum = typename SUPER::Accum;
  using ScaleA = typename SUPER::template ColOrScalarLoad<float>;
  using ScaleB = typename SUPER::template RowOrScalarLoad<float>;
  using Bias = typename SUPER::template RowLoad<ElementD, true>;

  // Per-token azp term, shape (m,1)
  using Azp = typename SUPER::template ColLoad<int32_t>;

  // This is the AZP adjustment term, J @ B, shape (1,n)
  using AzpAdj = typename SUPER::template RowLoad<int32_t>;

  // Compute azp * azp_adj
  using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, int32_t, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAzp =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAzp, Azp, AzpAdj>;

  // Compute float(accum - azp*azp_adj), all operands are int32_t
  using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::minus, float, int32_t,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeAcc =
      cutlass::epilogue::fusion::Sm90EVT<ComputeAcc, Accum, EVTComputeAzp>;

  using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiplies, float, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

  using EVTComputeScaleB =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;

  using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
      cutlass::multiply_add, ElementD, float,
      cutlass::FloatRoundStyle::round_to_nearest>;

 public:
  using EVTCompute =
      cutlass::epilogue::fusion::Sm90EVT<ComputeScaleBiasA, ScaleA,
                                         EVTComputeScaleB, Bias>;
  using ArgumentType = typename EVTCompute::Arguments;

  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   c10::optional<torch::Tensor> const& bias) {
    auto a_args = SUPER::template args_from_tensor<ScaleA, float>(a_scales);
    auto b_args = SUPER::template args_from_tensor<ScaleB, float>(b_scales);
    auto bias_args = SUPER::template args_from_tensor<Bias, ElementD>(bias);
    auto azp_args = SUPER::template args_from_tensor<Azp, int32_t>(azp);
    auto azp_adj_args =
        SUPER::template args_from_tensor<AzpAdj, int32_t>(azp_adj);

    typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args};
    typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args};
    typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args};
    return ArgumentType{a_args, evt_scale_b_args, bias_args};
  }
};

};  // namespace vllm::c3x
cutlass_w8a8/Epilogues.md
ADDED
@@ -0,0 +1,147 @@
# CUTLASS Epilogues

## Introduction
This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs.

Currently, we only support symmetric quantization for weights,
and symmetric and asymmetric quantization for activations.
Both can be quantized per-tensor or per-channel (weights) / per-token (activations).

There are 4 epilogues:
1. ScaledEpilogue: symmetric quantization for activations, no bias.
1. ScaledEpilogueBias: symmetric quantization for activations, supports bias.
1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias.
1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias.

We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size.
Instead, if no bias is passed, the epilogue will use 0 as the bias.
That induces a redundant addition operation (and runtime check), but the performance impact is minor.

## Underlying Linear Algebra

More details are available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975).

If $` \widehat X `$ is the quantized $` X `$, our matrices become the following

```math
A = s_a (\widehat A - J_a z_a)
```
```math
B = s_b \widehat B
```
```math
D = A B + C
```
```math
D = s_a s_b \widehat D + C
```

Here, D is the output of the GEMM, and C is the bias.
A is the activations and supports asymmetric quantization,
and B is the weights and only supports symmetric quantization.
$` s_a `$ and $` s_b `$ are the scales for activations and weights, respectively.
$` z_a `$ is the zero-point for activations, and $` J_a `$ is the matrix of all ones with the dimensions of A.
Additional epilogues would be required to support asymmetric quantization for weights.

Expanding further, we can calculate $` \widehat D `$ as follows:

```math
A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B
```
```math
A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right)
```
```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```

Note that $` \widehat A \widehat B `$ is the raw output of the GEMM,
and $` J_a \widehat B `$ is known ahead of time.
Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$.

## Epilogues

### ScaledEpilogue
This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B
```
```math
D = s_a s_b \widehat D
```
```math
D = s_a s_b \widehat A \widehat B
```

Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).

### ScaledEpilogueBias
This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B
```
```math
D = s_a s_b \widehat D + C
```
```math
D = s_a s_b \widehat A \widehat B + C
```

Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `bias` is the bias, and is always per-channel (row-vector).

### ScaledEpilogueAzp
This epilogue computes the asymmetric per-tensor quantization for activations with bias.
The output of the GEMM is:

```math
\widehat D = \widehat A \widehat B - z_a J_a \widehat B
```
```math
D = s_a s_b \widehat D + C
```
```math
D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C
```

Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 \widehat B `$.
That is precomputed and stored in `azp_with_adj` as a row-vector.

Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
  - Generally this will be per-tensor, as the zero-points are per-tensor.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), and is per-channel (row-vector).
- `bias` is the bias, and is always per-channel (row-vector).

To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel.

### ScaledEpilogueAzpPerToken
This epilogue computes the asymmetric per-token quantization for activations with bias.

The output of the GEMM is the same as above, but $` z_a `$ is a column-vector.
That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$.

Epilogue parameters:
- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector).
  - Generally this will be per-token, as the zero-points are per-token.
- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector).
- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), and is per-channel (row-vector).
- `azp` is the zero-point (`z_a`), and is per-token (column-vector).
- `bias` is the bias, and is always per-channel (row-vector).

To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel.

The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM):
```
out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
```
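As a concrete illustration of that formula, here is a plain-C++ reference of the per-token correction. It is illustrative only: the helper names (`column_sums`, `azp_per_token_reference`) are made up for this note, and in the real kernels `azp_adj` is precomputed offline while the subtraction and scaling are fused into the GEMM epilogue rather than applied on the host.

```
#include <cstdint>
#include <vector>

// azp_adj = 1^T * B_hat: column sums of the quantized weight (k x n).
std::vector<int32_t> column_sums(const std::vector<int8_t>& b_q, int k, int n) {
  std::vector<int32_t> azp_adj(n, 0);
  for (int kk = 0; kk < k; ++kk)
    for (int j = 0; j < n; ++j) azp_adj[j] += b_q[kk * n + j];
  return azp_adj;
}

// out = scale_a * scale_b * (Dq - azp_adj * azp) + bias
// Dq:      m x n int32 raw GEMM output (A_hat @ B_hat)
// scale_a: size 1 or m; scale_b: size 1 or n
// azp:     size m (per-token zero points); azp_adj: size n; bias: size n or empty
std::vector<float> azp_per_token_reference(
    const std::vector<int32_t>& Dq, int m, int n,
    const std::vector<float>& scale_a, const std::vector<float>& scale_b,
    const std::vector<int32_t>& azp, const std::vector<int32_t>& azp_adj,
    const std::vector<float>& bias) {
  std::vector<float> out(static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float sa = scale_a.size() == 1 ? scale_a[0] : scale_a[i];
      float sb = scale_b.size() == 1 ? scale_b[0] : scale_b[j];
      // Rank-1 correction: only azp[i] and azp_adj[j] are stored, never m*n terms.
      float corrected = static_cast<float>(Dq[i * n + j] - azp[i] * azp_adj[j]);
      out[i * n + j] = sa * sb * corrected + (bias.empty() ? 0.f : bias[j]);
    }
  }
  return out;
}
```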
cutlass_w8a8/common.hpp
ADDED
@@ -0,0 +1,27 @@
#pragma once

#include "cutlass/cutlass.h"
#include <climits>

/**
 * Helper function for checking CUTLASS errors
 */
#define CUTLASS_CHECK(status)                         \
  {                                                   \
    TORCH_CHECK(status == cutlass::Status::kSuccess,  \
                cutlassGetStatusString(status))       \
  }

inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
  int max_shared_mem_per_block_opt_in = 0;
  cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device);
  return max_shared_mem_per_block_opt_in;
}
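The rounding behaviour of `next_pow_2` (identity on powers of two, round up otherwise) can be checked in isolation; the following is a hypothetical standalone test, not part of the commit, with the helper copied so it compiles without the CUDA/torch headers.

```
#include <cassert>
#include <climits>
#include <cstdint>

// Copy of the helper above, for a standalone check (GCC/Clang builtin).
inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

int main() {
  assert(next_pow_2(0) == 0);    // values below 2 pass through
  assert(next_pow_2(1) == 1);
  assert(next_pow_2(2) == 2);    // powers of two map to themselves
  assert(next_pow_2(3) == 4);    // everything else rounds up
  assert(next_pow_2(48) == 64);
  assert(next_pow_2(64) == 64);
  assert(next_pow_2(65) == 128);
  return 0;
}
```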
cutlass_w8a8/scaled_mm_c2x.cu
ADDED
@@ -0,0 +1,199 @@
#include <stddef.h>
#include <torch/all.h>
#include "cutlass/cutlass.h"

#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp"

using namespace vllm;

/*
   This file defines quantized GEMM operations using the CUTLASS 2.x API, for
   NVIDIA GPUs with SM versions prior to sm90 (Hopper).
*/

template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm75_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  if (out.dtype() == torch::kBFloat16) {
    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
  }
}

void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm80_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}

template <template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
                                     torch::Tensor const& b,
                                     EpilogueArgs&&... epilogue_args) {
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(b.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      assert(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else {
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
    TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue>(
          out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    }
  }
}

void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  if (azp) {
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzpToken>(
        out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
  } else {
    return cutlass_scaled_mm_sm89_epilogue<c2x::ScaledEpilogueBiasAzp>(
        out, a, b, a_scales, b_scales, azp_adj, bias);
  }
}
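A hypothetical host-side caller, sketched with libtorch, showing the dtypes and scale shapes these entry points expect per the TORCH_CHECKs above (int8 A and B, fp32 scales, fp16 or bf16 output). The column-major layout for `b` and the exact {m,1}/{1,n} scale shapes are assumptions for illustration, inferred from `ldb = b.stride(1)` and the `numel() != 1` check in the shared caller and epilogues.

```
#include <torch/torch.h>

// Declaration matching the definition in scaled_mm_c2x.cu above.
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);

int main() {
  const int64_t m = 16, n = 32, k = 64;
  auto i8 = torch::TensorOptions().dtype(torch::kInt8).device(torch::kCUDA);
  auto f32 = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);

  // Row-major int8 activations; column-major int8 weights (assumed layout).
  torch::Tensor a = torch::randint(-128, 128, {m, k}, i8);
  torch::Tensor b = torch::randint(-128, 128, {k, n}, i8).t().contiguous().t();

  torch::Tensor a_scales = torch::rand({m, 1}, f32);  // per-token
  torch::Tensor b_scales = torch::rand({1, n}, f32);  // per-channel
  torch::Tensor out = torch::empty(
      {m, n}, torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA));

  // No bias: dispatches to c2x::ScaledEpilogue.
  cutlass_scaled_mm_sm80(out, a, b, a_scales, b_scales, /*bias=*/c10::nullopt);
  return 0;
}
```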
cutlass_w8a8/scaled_mm_c2x.cuh
ADDED
@@ -0,0 +1,219 @@
#pragma once
#include <stddef.h>
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"

#include "common.hpp"
// clang-format on

using namespace cute;

/*
   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
   Epilogues must contain a public type named EVTCompute of type Sm80EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace vllm {

// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
    Kernel::invoke(std::forward<Args>(args)...);
#endif
  }
};

template <typename Arch, template <typename> typename ArchGuard,
          typename ElementAB_, typename ElementD_,
          template <typename, typename> typename Epilogue_, typename TileShape,
          typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
          typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;

  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using Operator =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>,
                                cutlass::arch::OpMultiplyAddSaturate,
                                FP8MathOperator>::type;

  using OutputTileThreadMap =
      cutlass::epilogue::threadblock::OutputTileThreadLayout<
          TileShape, WarpShape, float, 4, 1 /* epilogue stages */
          >;

  using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
  using EVTCompute = typename Epilogue::EVTCompute;

  using D = cutlass::epilogue::threadblock::VisitorAuxStore<
      OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
      Stride<int64_t, Int<1>, Int<0>>>;

  using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;

  // clang-format off
  using RowMajor = typename cutlass::layout::RowMajor;
  using ColumnMajor = typename cutlass::layout::ColumnMajor;
  using KernelType =
    ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
      ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
      ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
      float, cutlass::layout::RowMajor, 4,
      ElementAcc, float, cutlass::arch::OpClassTensorOp,
      Arch,
      TileShape, WarpShape, InstructionShape,
      EVTD,
      cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
      MainLoopStages, Operator,
      1 /* epilogue stages */
      >::GemmKernel>;
  // clang-format on

  using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};

template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& b,
                                EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);
  cutlass::gemm::GemmCoord problem_size{m, n, k};

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideC = Stride<int64_t, Int<1>, Int<0>>;
  StrideC c_stride{ldc, Int<1>{}, Int<0>{}};

  auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
  auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
  auto c_ptr = static_cast<ElementD*>(out.data_ptr());

  typename Gemm::D::Arguments d_args{c_ptr, c_stride};

  using Epilogue = typename Gemm::Epilogue;
  auto evt_args =
      Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);

  typename Gemm::EVTD::Arguments epilogue_args{
      evt_args,
      d_args,
  };

  typename Gemm::Op::Arguments args{
      cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel,  // universal mode
      problem_size,                                           // problem size
      1,                                                      // batch count
      epilogue_args,
      a_ptr,
      b_ptr,
      nullptr,
      nullptr,
      0,
      0,
      0,
      0,
      lda,
      ldb,
      ldc,
      ldc};

  // Launch the CUTLASS GEMM kernel.
  typename Gemm::Op gemm_op;
  size_t workspace_size = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  CUTLASS_CHECK(gemm_op.can_implement(args));
  cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
  CUTLASS_CHECK(status);
}

template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
                                         torch::Tensor const& a,
                                         torch::Tensor const& b,
                                         EpilogueArgs&&... args) {
// In some cases, the GPU isn't able to accommodate the
|
198 |
+
// shared memory requirements of the Gemm. In such cases, use
|
199 |
+
// the FallbackGemm instead.
|
200 |
+
static const int max_shared_mem_per_block_opt_in =
|
201 |
+
get_cuda_max_shared_memory_per_block_opt_in(0);
|
202 |
+
|
203 |
+
size_t const gemm_shared_mem_size =
|
204 |
+
sizeof(typename Gemm::KernelType::SharedStorage);
|
205 |
+
size_t const fallback_gemm_shared_mem_size =
|
206 |
+
sizeof(typename FallbackGemm::KernelType::SharedStorage);
|
207 |
+
|
208 |
+
if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
|
209 |
+
return cutlass_gemm_caller<Gemm>(out, a, b,
|
210 |
+
std::forward<EpilogueArgs>(args)...);
|
211 |
+
} else {
|
212 |
+
TORCH_CHECK(fallback_gemm_shared_mem_size <=
|
213 |
+
max_shared_mem_per_block_opt_in);
|
214 |
+
return cutlass_gemm_caller<FallbackGemm>(
|
215 |
+
out, a, b, std::forward<EpilogueArgs>(args)...);
|
216 |
+
}
|
217 |
+
}
|
218 |
+
|
219 |
+
} // namespace vllm
|
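As a minimal sketch of how the pieces in this header compose (not part of the commit): the ScaledEpilogue name is assumed from cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp elsewhere in this commit (its exact namespace is not visible in this hunk), and the tile/warp/instruction shapes are arbitrary illustrative choices.

// Illustrative only; assumes vllm::ScaledEpilogue and int8 inputs.
void example_int8_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales) {
  using Gemm = vllm::cutlass_2x_gemm<
      cutlass::arch::Sm80, vllm::enable_sm80_to_sm89, int8_t, cutlass::half_t,
      vllm::ScaledEpilogue, cutlass::gemm::GemmShape<128, 128, 64>,
      cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>,
      5>;
  // a_scales / b_scales are forwarded to ScaledEpilogue::prepare_args(...),
  // which builds the EVTCompute::Arguments consumed by the epilogue visitor.
  vllm::cutlass_gemm_caller<Gemm>(out, a, b, a_scales, b_scales);
}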
cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
ADDED
@@ -0,0 +1,123 @@
#pragma once

#include "scaled_mm_c2x.cuh"

/**
 * This file defines Gemm kernel configurations for SM75 based on the Gemm
 * shape.
 */

namespace vllm {

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm75_config_default {
  // This config is used in 2 cases,
  // - M in (256, inf)
  // - M in (64, 128]
  // Shared memory required by this Gemm: 32768 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm75_config_M256 {
  // M in (128, 256]
  // Shared memory required by this Gemm: 65536 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm75_config_M64 {
  // M in (32, 64]
  // Shared memory required by this Gemm: 49152 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm75_config_M32 {
  // M in [1, 32]
  // Shared memory required by this Gemm: 49152 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
                                       torch::Tensor const& a,
                                       torch::Tensor const& b,
                                       EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass2xGemmDefault =
      typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM256 =
      typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
  using Cutlass2xGemmM64 =
      typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM32 =
      typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;

  // Due to shared memory requirements, some Gemms may fail to run on some
  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
  // in such cases.
  // sm75_config_default has the least shared-memory requirements.
  using FallbackGemm = Cutlass2xGemmDefault;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2
  if (mp2 <= 32) {
    // M in [1, 32]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // M in (128, 256]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // M in (256, inf)
    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

} // namespace vllm
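The dispatcher above buckets M by its next power of two via next_pow_2(), which is not defined in this hunk (it comes from common.hpp in this commit). A plausible shape of that helper is sketched below purely to make the bucketing concrete; it is an assumption, not a copy of the actual code.

#include <cstdint>

// Assumed implementation of the helper used above (the real one is in common.hpp).
static inline uint32_t next_pow_2(uint32_t v) {
  if (v <= 1) return v;
  return 1u << (32 - __builtin_clz(v - 1));  // round up to the next power of two
}
// Example: m = 96 -> next_pow_2(96) = 128, so (with the clamp to >= 32) the
// dispatcher above selects the "M in (64, 128]" branch, i.e. Cutlass2xGemmM128.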
cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
ADDED
@@ -0,0 +1,139 @@
#pragma once

#include "scaled_mm_c2x.cuh"

/**
 * This file defines Gemm kernel configurations for SM80 based on the Gemm
 * shape.
 */

namespace vllm {

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_default {
  // This config is used in 2 cases,
  // - M in (128, inf)
  // - M in (64, 128] and N >= 8192
  // Shared Memory required by this Gemm - 81920 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M64 {
  // This config is used in 2 cases,
  // - M in (32, 64]
  // - M in (64, 128] and N < 8192
  // Shared Memory required by this Gemm - 122880 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
  // M in (16, 32]
  // Shared Memory required by this Gemm - 61440 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
  // M in [1, 16]
  // Shared Memory required by this Gemm - 51200 bytes
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
                                       torch::Tensor const& a,
                                       torch::Tensor const& b,
                                       EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  using Cutlass2xGemmDefault =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128BigN =
      typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM128SmallN =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM64 =
      typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM32 =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
  using Cutlass2xGemmM16 =
      typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;

  // Due to shared memory requirements, some Gemms may fail to run on some
  // GPUs. As the name indicates, the Fallback Gemm is used as an alternative
  // in such cases.
  // sm80_config_M16 has the least shared-memory requirement. However,
  // based on some profiling, we select sm80_config_M32 as a better alternative
  // performance wise.
  using FallbackGemm =
      typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(16), next_pow_2(m));  // next power of 2
  if (mp2 <= 16) {
    // M in [1, 16]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    uint32_t const n = out.size(1);
    bool const small_n = n < 8192;
    if (small_n) {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
                                          FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
          out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  } else {
    // M in (128, inf)
    return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

} // namespace vllm
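The "Shared Memory required" figures in these configs can be roughly reproduced from the tile shape and stage count: each multistage pipeline stage buffers an MxK tile of A and an NxK tile of B (1 byte per int8 element). The sketch below is a back-of-the-envelope check, not code from the commit; it ignores epilogue storage and alignment, so it does not match every config (the small SM75 tiles, for instance). The helper that supplies the opt-in limit compared against these numbers lives in common.hpp and is only sketched here as an assumption.

#include <cstddef>
#include <cuda_runtime.h>

// Approximate CUTLASS 2.x mainloop shared-memory footprint.
constexpr size_t c2x_mainloop_smem_bytes(size_t m, size_t n, size_t k,
                                         size_t stages, size_t elem_bytes) {
  return stages * (m * k + n * k) * elem_bytes;
}
static_assert(c2x_mainloop_smem_bytes(128, 128, 64, 5, 1) == 81920,
              "sm80_config_default");
static_assert(c2x_mainloop_smem_bytes(64, 128, 128, 5, 1) == 122880,
              "sm80_config_M64");
static_assert(c2x_mainloop_smem_bytes(32, 64, 128, 5, 1) == 61440,
              "sm80_config_M32");
static_assert(c2x_mainloop_smem_bytes(16, 64, 128, 5, 1) == 51200,
              "sm80_config_M16");

// Assumed shape of the query used by fallback_cutlass_gemm_caller (the real
// helper is in common.hpp): the 122880-byte config only runs on GPUs whose
// opt-in shared-memory limit is at least that large.
static int get_cuda_max_shared_memory_per_block_opt_in(int device) {
  int max_smem = 0;
  cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device);
  return max_smem;
}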
cutlass_w8a8/scaled_mm_c2x_sm89_fp8_dispatch.cuh
ADDED
@@ -0,0 +1,368 @@
#pragma once

#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"

/**
 * This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
 * shape.
 */

namespace vllm {

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm89_fp8_fallback_gemm {
  // Shared Memory required by this Gemm - 61440 bytes
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
  using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5,
                      FP8MathOperator>;
};

struct sm89_fp8_config_default {
  // M in (256, inf)
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);

    } else {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

struct sm89_fp8_config_M256 {
  // M in (128, 256]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

struct sm89_fp8_config_M128 {
  // M in (64, 128]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);

    } else if (np2 <= 16384) {
      using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

struct sm89_fp8_config_M64 {
  // M in (32, 64]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8196) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
      using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

struct sm89_fp8_config_M32 {
  // M in (16, 32]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 4, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5, FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

struct sm89_fp8_config_M16 {
  // M in [1, 16]
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
  static const int32_t MainLoopStages = 5;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
    TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);

    using FallbackGemm =
        typename sm89_fp8_fallback_gemm<InType, OutType,
                                        Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, MainLoopStages,
                                FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 24576) {
      using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, MainLoopStages,
                                FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, MainLoopStages,
                                FP8MathOperator>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 16) {
    // M in [1, 16]
    return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // M in (128, 256]
    return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // M in (256, inf)
    return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

} // namespace vllm
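The fallback config above (64x128x64 tiles, 5 stages, plain OpMultiplyAdd) is the one expected to fit the default 64 KB opt-in shared-memory limit, while some of the larger N-dependent tiles exceed it; that is exactly the case fallback_cutlass_gemm_caller handles. A rough check, reusing the approximation sketched after the sm80 dispatch file (an editorial assumption, not code from this commit; fp8 elements are 1 byte like int8):

static_assert(c2x_mainloop_smem_bytes(64, 128, 64, 5, 1) == 61440,
              "sm89_fp8_fallback_gemm fits a 64 KB carve-out");
static_assert(c2x_mainloop_smem_bytes(256, 128, 64, 3, 1) == 73728,
              "the 256x128x64 / 3-stage tile needs a larger opt-in limit");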
cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
ADDED
@@ -0,0 +1,353 @@
#pragma once

#include "scaled_mm_c2x.cuh"

/**
 * This file defines Gemm kernel configurations for SM89 (int8) based on the
 * Gemm shape.
 */

namespace vllm {

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue>
struct sm89_int8_fallback_gemm {
  // Shared mem requirement : 61440
  static_assert(std::is_same<InType, int8_t>());
  using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
  using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
  static int32_t const MainLoopStages = 5;

  using Cutlass2xGemm =
      cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
                      Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};

struct sm89_int8_config_default {
  // M in (256, inf)
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

struct sm89_int8_config_M256 {
  // M in (128, 256]
  using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 4096) {
      using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

struct sm89_int8_config_M128 {
  // M in (64, 128]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
      using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else if (np2 <= 16384) {
      using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
      using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

struct sm89_int8_config_M64 {
  // M in (32, 64]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
      using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
      using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 3>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

struct sm89_int8_config_M32 {
  // M in (16, 32]
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
      using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
      using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 4>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

struct sm89_int8_config_M16 {
  // M in [1, 16]
  using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
  using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;

  template <typename InType, typename OutType,
            template <typename, typename> typename Epilogue,
            typename... EpilogueArgs>
  static void dispatch(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, EpilogueArgs&&... args) {
    static_assert(std::is_same<InType, int8_t>());
    TORCH_CHECK(a.dtype() == torch::kInt8);

    using FallbackGemm =
        typename sm89_int8_fallback_gemm<InType, OutType,
                                         Epilogue>::Cutlass2xGemm;

    uint32_t const n = out.size(1);
    uint32_t const np2 = next_pow_2(n);

    if (np2 <= 8192) {
      using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 5>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    } else {
      using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;

      return vllm::fallback_cutlass_gemm_caller<
          vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
                                InType, OutType, Epilogue, TileShape, WarpShape,
                                InstructionShape, 4>,
          FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
    }
  }
};

template <typename InType, typename OutType,
          template <typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
                                            EpilogueArgs&&... args) {
  static_assert(std::is_same<InType, int8_t>());
  TORCH_CHECK(a.dtype() == torch::kInt8);
  TORCH_CHECK(b.dtype() == torch::kInt8);

  uint32_t const m = a.size(0);
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 16) {
    // M in [1, 16]
    return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 32) {
    // M in (16, 32]
    return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 64) {
    // M in (32, 64]
    return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 128) {
    // M in (64, 128]
    return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else if (mp2 <= 256) {
    // M in (128, 256]
    return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  } else {
    // M in (256, inf)
    return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(args)...);
  }
}

} // namespace vllm
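The int8 and fp8 SM89 dispatchers above are siblings selected per input dtype by the callers elsewhere in this commit (scaled_mm_c2x.cu / scaled_mm_entry.cu, not shown in this hunk). The routine below is a hypothetical routing sketch only, with the output type and epilogue left as assumptions, included to show how the two entry points are expected to be chosen.

// Hypothetical routing sketch; not code from the commit.
template <template <typename, typename> typename Epilogue, typename... Args>
void sm89_dispatch_by_dtype(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b, Args&&... args) {
  if (a.dtype() == torch::kInt8) {
    vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
        out, a, b, std::forward<Args>(args)...);
  } else {
    vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
                                         cutlass::half_t, Epilogue>(
        out, a, b, std::forward<Args>(args)...);
  }
}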
cutlass_w8a8/scaled_mm_c3x.cu
ADDED
@@ -0,0 +1,453 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>

#include <iostream>
#include <sstream>
#include <vector>

#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"

#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "common.hpp"
// clang-format on

using namespace cute;
using namespace vllm;

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm90a (Hopper) or later.

   Epilogue functions can be defined to post-process the output before it is
   written to GPU memory.
   Epilogues must contain a public type named EVTCompute of type Sm90EVT,
   as well as a static prepare_args function that constructs an
   EVTCompute::Arguments struct.
*/

namespace {

// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm {
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  using ElementAcc =
      typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
                                float>::type;

  using EpilogueDescriptor =
      cutlass::epilogue::collective::detail::EpilogueDescriptor<
          TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
          ElementD, EpilogueSchedule>;

  using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;

  using StrideD = Stride<int64_t, Int<1>, Int<0>>;
  using ElementC = void;
  using StrideC = StrideD;

  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
          ElementAB, cutlass::layout::RowMajor, 16,
          ElementAB, cutlass::layout::ColumnMajor, 16,
          ElementAcc, TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};
};

template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                         torch::Tensor const& b,
                         EpilogueArgs&&... epilogue_params) {
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;

  int32_t m = a.size(0);
  int32_t n = b.size(1);
  int32_t k = a.size(1);

  int64_t lda = a.stride(0);
  int64_t ldb = b.stride(1);
  int64_t ldc = out.stride(0);

  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideB = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = typename Gemm::StrideC;
132 |
+
|
133 |
+
StrideA a_stride{lda, Int<1>{}, 0};
|
134 |
+
StrideB b_stride{ldb, Int<1>{}, 0};
|
135 |
+
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
136 |
+
|
137 |
+
using GemmKernel = typename Gemm::GemmKernel;
|
138 |
+
typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
|
139 |
+
|
140 |
+
auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
|
141 |
+
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
|
142 |
+
typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
|
143 |
+
b_stride};
|
144 |
+
|
145 |
+
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
146 |
+
typename GemmKernel::EpilogueArguments epilogue_args{
|
147 |
+
Gemm::Epilogue::prepare_args(
|
148 |
+
std::forward<EpilogueArgs>(epilogue_params)...),
|
149 |
+
c_ptr, c_stride, c_ptr, c_stride};
|
150 |
+
|
151 |
+
typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
|
152 |
+
prob_shape, mainloop_args, epilogue_args};
|
153 |
+
|
154 |
+
// Launch the CUTLASS GEMM kernel.
|
155 |
+
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
156 |
+
GemmOp gemm_op;
|
157 |
+
CUTLASS_CHECK(gemm_op.can_implement(args));
|
158 |
+
|
159 |
+
size_t workspace_size = gemm_op.get_workspace_size(args);
|
160 |
+
auto const workspace_options =
|
161 |
+
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
162 |
+
auto workspace = torch::empty(workspace_size, workspace_options);
|
163 |
+
|
164 |
+
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
165 |
+
|
166 |
+
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
167 |
+
CUTLASS_CHECK(status);
|
168 |
+
}
|
169 |
+
|
170 |
+
template <typename InType, typename OutType,
|
171 |
+
template <typename, typename, typename> typename Epilogue>
|
172 |
+
struct sm90_fp8_config_default {
|
173 |
+
// M in (128, inf)
|
174 |
+
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
175 |
+
using KernelSchedule =
|
176 |
+
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
177 |
+
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
178 |
+
using TileShape = Shape<_128, _128, _128>;
|
179 |
+
using ClusterShape = Shape<_2, _1, _1>;
|
180 |
+
using Cutlass3xGemm =
|
181 |
+
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
182 |
+
KernelSchedule, EpilogueSchedule>;
|
183 |
+
};
|
184 |
+
|
185 |
+
template <typename InType, typename OutType,
|
186 |
+
template <typename, typename, typename> typename Epilogue>
|
187 |
+
struct sm90_fp8_config_M128 {
|
188 |
+
// M in (64, 128]
|
189 |
+
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
190 |
+
using KernelSchedule =
|
191 |
+
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
192 |
+
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
193 |
+
using TileShape = Shape<_64, _128, _128>;
|
194 |
+
using ClusterShape = Shape<_2, _1, _1>;
|
195 |
+
using Cutlass3xGemm =
|
196 |
+
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
197 |
+
KernelSchedule, EpilogueSchedule>;
|
198 |
+
};
|
199 |
+
|
200 |
+
template <typename InType, typename OutType,
|
201 |
+
template <typename, typename, typename> typename Epilogue>
|
202 |
+
struct sm90_fp8_config_M64 {
|
203 |
+
// M in [1, 64]
|
204 |
+
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
205 |
+
using KernelSchedule =
|
206 |
+
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
|
207 |
+
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
208 |
+
using TileShape = Shape<_64, _64, _128>;
|
209 |
+
using ClusterShape = Shape<_1, _8, _1>;
|
210 |
+
|
211 |
+
using Cutlass3xGemm =
|
212 |
+
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
213 |
+
KernelSchedule, EpilogueSchedule>;
|
214 |
+
};
|
215 |
+
|
216 |
+
template <typename InType, typename OutType,
|
217 |
+
template <typename, typename, typename> typename Epilogue>
|
218 |
+
struct sm90_int8_config_default {
|
219 |
+
// For M > 128 and any N
|
220 |
+
static_assert(std::is_same<InType, int8_t>());
|
221 |
+
using KernelSchedule =
|
222 |
+
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
223 |
+
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
224 |
+
using TileShape = Shape<_128, _128, _128>;
|
225 |
+
using ClusterShape = Shape<_2, _1, _1>;
|
226 |
+
using Cutlass3xGemm =
|
227 |
+
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
228 |
+
KernelSchedule, EpilogueSchedule>;
|
229 |
+
};
|
230 |
+
|
231 |
+
template <typename InType, typename OutType,
|
232 |
+
template <typename, typename, typename> typename Epilogue>
|
233 |
+
struct sm90_int8_config_M128 {
|
234 |
+
// For M in (64, 128] and any N
|
235 |
+
static_assert(std::is_same<InType, int8_t>());
|
236 |
+
using KernelSchedule =
|
237 |
+
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
|
238 |
+
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
239 |
+
using TileShape = Shape<_64, _128, _128>;
|
240 |
+
using ClusterShape = Shape<_2, _1, _1>;
|
241 |
+
using Cutlass3xGemm =
|
242 |
+
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
243 |
+
KernelSchedule, EpilogueSchedule>;
|
244 |
+
};
|
245 |
+
|
246 |
+
template <typename InType, typename OutType,
|
247 |
+
template <typename, typename, typename> typename Epilogue>
|
248 |
+
struct sm90_int8_config_M64 {
|
249 |
+
// For M in (32, 64] and any N
|
250 |
+
static_assert(std::is_same<InType, int8_t>());
|
251 |
+
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
252 |
+
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
253 |
+
using TileShape = Shape<_64, _64, _256>;
|
254 |
+
using ClusterShape = Shape<_1, _1, _1>;
|
255 |
+
using Cutlass3xGemm =
|
256 |
+
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
257 |
+
KernelSchedule, EpilogueSchedule>;
|
258 |
+
};
|
259 |
+
|
260 |
+
template <typename InType, typename OutType,
|
261 |
+
template <typename, typename, typename> typename Epilogue>
|
262 |
+
struct sm90_int8_config_M32_NBig {
|
263 |
+
// For M in [1, 32] and N >= 8192
|
264 |
+
static_assert(std::is_same<InType, int8_t>());
|
265 |
+
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
266 |
+
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
267 |
+
using TileShape = Shape<_64, _128, _256>;
|
268 |
+
using ClusterShape = Shape<_1, _4, _1>;
|
269 |
+
using Cutlass3xGemm =
|
270 |
+
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
271 |
+
KernelSchedule, EpilogueSchedule>;
|
272 |
+
};
|
273 |
+
|
274 |
+
template <typename InType, typename OutType,
|
275 |
+
template <typename, typename, typename> typename Epilogue>
|
276 |
+
struct sm90_int8_config_M32_NSmall {
|
277 |
+
// For M in [1, 32] and N < 8192
|
278 |
+
static_assert(std::is_same<InType, int8_t>());
|
279 |
+
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
|
280 |
+
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
|
281 |
+
using TileShape = Shape<_64, _64, _256>;
|
282 |
+
using ClusterShape = Shape<_1, _8, _1>;
|
283 |
+
using Cutlass3xGemm =
|
284 |
+
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
285 |
+
KernelSchedule, EpilogueSchedule>;
|
286 |
+
};
|
287 |
+
|
288 |
+
} // namespace
|
289 |
+
|
290 |
+
template <typename InType, typename OutType,
|
291 |
+
template <typename, typename, typename> typename Epilogue,
|
292 |
+
typename... EpilogueArgs>
|
293 |
+
void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
294 |
+
torch::Tensor const& b,
|
295 |
+
EpilogueArgs&&... args) {
|
296 |
+
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
297 |
+
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
298 |
+
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
299 |
+
|
300 |
+
using Cutlass3xGemmDefault =
|
301 |
+
typename sm90_fp8_config_default<InType, OutType,
|
302 |
+
Epilogue>::Cutlass3xGemm;
|
303 |
+
using Cutlass3xGemmM64 =
|
304 |
+
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
305 |
+
using Cutlass3xGemmM128 =
|
306 |
+
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
307 |
+
|
308 |
+
uint32_t const m = a.size(0);
|
309 |
+
uint32_t const mp2 =
|
310 |
+
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
311 |
+
|
312 |
+
if (mp2 <= 64) {
|
313 |
+
// m in [1, 64]
|
314 |
+
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
315 |
+
out, a, b, std::forward<EpilogueArgs>(args)...);
|
316 |
+
} else if (mp2 <= 128) {
|
317 |
+
// m in (64, 128]
|
318 |
+
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
319 |
+
out, a, b, std::forward<EpilogueArgs>(args)...);
|
320 |
+
} else {
|
321 |
+
// m in (128, inf)
|
322 |
+
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
323 |
+
out, a, b, std::forward<EpilogueArgs>(args)...);
|
324 |
+
}
|
325 |
+
}
|
326 |
+
|
327 |
+
template <typename InType, typename OutType,
|
328 |
+
template <typename, typename, typename> typename Epilogue,
|
329 |
+
typename... EpilogueArgs>
|
330 |
+
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
331 |
+
torch::Tensor const& b,
|
332 |
+
EpilogueArgs&&... args) {
|
333 |
+
static_assert(std::is_same<InType, int8_t>());
|
334 |
+
TORCH_CHECK(a.dtype() == torch::kInt8);
|
335 |
+
TORCH_CHECK(b.dtype() == torch::kInt8);
|
336 |
+
|
337 |
+
using Cutlass3xGemmDefault =
|
338 |
+
typename sm90_int8_config_default<InType, OutType,
|
339 |
+
Epilogue>::Cutlass3xGemm;
|
340 |
+
using Cutlass3xGemmM128 =
|
341 |
+
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
342 |
+
using Cutlass3xGemmM64 =
|
343 |
+
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
344 |
+
using Cutlass3xGemmM32NBig =
|
345 |
+
typename sm90_int8_config_M32_NBig<InType, OutType,
|
346 |
+
Epilogue>::Cutlass3xGemm;
|
347 |
+
using Cutlass3xGemmM32NSmall =
|
348 |
+
typename sm90_int8_config_M32_NSmall<InType, OutType,
|
349 |
+
Epilogue>::Cutlass3xGemm;
|
350 |
+
|
351 |
+
uint32_t const n = out.size(1);
|
352 |
+
bool const is_small_n = n < 8192;
|
353 |
+
|
354 |
+
uint32_t const m = a.size(0);
|
355 |
+
uint32_t const mp2 =
|
356 |
+
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
357 |
+
|
358 |
+
if (mp2 <= 32) {
|
359 |
+
// m in [1, 32]
|
360 |
+
if (is_small_n) {
|
361 |
+
return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
|
362 |
+
out, a, b, std::forward<EpilogueArgs>(args)...);
|
363 |
+
} else {
|
364 |
+
return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
|
365 |
+
out, a, b, std::forward<EpilogueArgs>(args)...);
|
366 |
+
}
|
367 |
+
} else if (mp2 <= 64) {
|
368 |
+
// m in (32, 64]
|
369 |
+
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
370 |
+
out, a, b, std::forward<EpilogueArgs>(args)...);
|
371 |
+
} else if (mp2 <= 128) {
|
372 |
+
// m in (64, 128]
|
373 |
+
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
374 |
+
out, a, b, std::forward<EpilogueArgs>(args)...);
|
375 |
+
} else {
|
376 |
+
// m in (128, inf)
|
377 |
+
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
378 |
+
out, a, b, std::forward<EpilogueArgs>(args)...);
|
379 |
+
}
|
380 |
+
}
|
381 |
+
|
382 |
+
template <template <typename, typename, typename> typename Epilogue,
|
383 |
+
typename... EpilogueArgs>
|
384 |
+
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
385 |
+
torch::Tensor const& b,
|
386 |
+
EpilogueArgs&&... epilogue_args) {
|
387 |
+
if (a.dtype() == torch::kInt8) {
|
388 |
+
TORCH_CHECK(b.dtype() == torch::kInt8);
|
389 |
+
|
390 |
+
if (out.dtype() == torch::kBFloat16) {
|
391 |
+
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
392 |
+
Epilogue>(
|
393 |
+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
394 |
+
} else {
|
395 |
+
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
396 |
+
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
397 |
+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
398 |
+
}
|
399 |
+
} else {
|
400 |
+
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
401 |
+
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
402 |
+
|
403 |
+
if (out.dtype() == torch::kBFloat16) {
|
404 |
+
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
405 |
+
cutlass::bfloat16_t, Epilogue>(
|
406 |
+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
407 |
+
} else {
|
408 |
+
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
409 |
+
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
|
410 |
+
cutlass::half_t, Epilogue>(
|
411 |
+
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
412 |
+
}
|
413 |
+
}
|
414 |
+
}
|
415 |
+
|
416 |
+
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
|
417 |
+
torch::Tensor const& b,
|
418 |
+
torch::Tensor const& a_scales,
|
419 |
+
torch::Tensor const& b_scales,
|
420 |
+
c10::optional<torch::Tensor> const& bias) {
|
421 |
+
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
422 |
+
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
423 |
+
if (bias) {
|
424 |
+
TORCH_CHECK(bias->dtype() == c.dtype(),
|
425 |
+
"currently bias dtype must match output dtype ", c.dtype());
|
426 |
+
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBias>(
|
427 |
+
c, a, b, a_scales, b_scales, *bias);
|
428 |
+
} else {
|
429 |
+
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogue>(
|
430 |
+
c, a, b, a_scales, b_scales);
|
431 |
+
}
|
432 |
+
}
|
433 |
+
|
434 |
+
void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
|
435 |
+
torch::Tensor const& b,
|
436 |
+
torch::Tensor const& a_scales,
|
437 |
+
torch::Tensor const& b_scales,
|
438 |
+
torch::Tensor const& azp_adj,
|
439 |
+
c10::optional<torch::Tensor> const& azp,
|
440 |
+
c10::optional<torch::Tensor> const& bias) {
|
441 |
+
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
442 |
+
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
443 |
+
|
444 |
+
if (azp) {
|
445 |
+
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzpToken>(
|
446 |
+
out, a, b, a_scales, b_scales, azp_adj, *azp, bias);
|
447 |
+
} else {
|
448 |
+
return cutlass_scaled_mm_sm90_epilogue<c3x::ScaledEpilogueBiasAzp>(
|
449 |
+
out, a, b, a_scales, b_scales, azp_adj, bias);
|
450 |
+
}
|
451 |
+
}
|
452 |
+
|
453 |
+
#endif
|
cutlass_w8a8/scaled_mm_entry.cu
ADDED
@@ -0,0 +1,222 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales,
                            c10::optional<torch::Tensor> const& bias);
#endif

void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm80(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp_sm89(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_azp_sm90(torch::Tensor& c, torch::Tensor const& a,
                                torch::Tensor const& b,
                                torch::Tensor const& a_scales,
                                torch::Tensor const& b_scales,
                                torch::Tensor const& azp_adj,
                                c10::optional<torch::Tensor> const& azp,
                                c10::optional<torch::Tensor> const& bias);
#endif

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
  // CUTLASS FP8 kernels need at least
  //   CUDA 12.0 on SM90 systems (Hopper)
  //   CUDA 12.4 on SM89 systems (Lovelace)

#if defined CUDA_VERSION
  if (cuda_device_capability >= 90) {
    return CUDA_VERSION >= 12000;
  } else if (cuda_device_capability >= 89) {
    return CUDA_VERSION >= 12040;
  }
#endif

  return false;
}

int32_t get_sm_version_num() {
  int32_t major_capability, minor_capability;
  cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
                         0);
  cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
                         0);
  int32_t version_num = major_capability * 10 + minor_capability;
  return version_num;
}

void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();
  // Hopper

  // Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (version_num >= 90) {
    cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
    return;
  }

  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
    return;
  }

  if (version_num >= 75) {
    // Turing
    cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
    return;
  }

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}

void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
              b.size(1) == c.size(1));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
  TORCH_CHECK(c.stride(0) % 16 == 0 &&
              b.stride(1) % 16 == 0);  // 16 Byte Alignment
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // bias, azp, azp_adj are all 1d
  // bias and azp_adj have n elements, azp has m elements
  if (bias) {
    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
  }
  if (azp) {
    TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
  }
  TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());

  // azp & bias types
  TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
  TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
  TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
              "currently bias dtype must match output dtype ", c.dtype());

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));

  int32_t version_num = get_sm_version_num();

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
  if (version_num >= 90) {
    cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }
#endif

  if (version_num == 89) {
    // Ada Lovelace
    cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  if (version_num >= 80) {
    // Ampere
    cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
    return;
  }

  // Turing
  TORCH_CHECK(version_num >= 75);
  cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
  return;

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_mm_azp for a compute capability less than "
      "CUDA device capability: ",
      version_num);
}
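
The conformality and stride checks in cutlass_scaled_mm above translate into concrete caller requirements: a and c row-major, b column-major, inner strides of 1, leading dimensions satisfying the % 16 stride checks, and contiguous float32 scales that are either scalar (per-tensor) or sized m / n (per-row of a / per-column of b). A small PyTorch sketch of int8 inputs that pass those checks; the shapes are illustrative only:

import torch

m, n, k = 32, 4096, 4096

# Row-major int8 activations; the weight is a column-major view obtained by
# transposing a row-major (n, k) tensor, giving shape (k, n) with stride(0) == 1.
a = torch.randint(-8, 8, (m, k), dtype=torch.int8, device="cuda")
b = torch.randint(-8, 8, (n, k), dtype=torch.int8, device="cuda").t()

# Per-row scale for a (m values) and per-column scale for b (n values),
# float32 and contiguous, as required by the TORCH_CHECKs above.
a_scales = torch.rand(m, 1, dtype=torch.float32, device="cuda")
b_scales = torch.rand(1, n, dtype=torch.float32, device="cuda")

# Row-major output buffer in fp16 or bf16.
c = torch.empty(m, n, dtype=torch.bfloat16, device="cuda")

assert a.stride(1) == 1 and c.stride(1) == 1             # row-major
assert b.stride(0) == 1                                  # column-major
assert c.stride(0) % 16 == 0 and b.stride(1) % 16 == 0   # alignment checks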
ext-torch/__init__.py
ADDED
@@ -0,0 +1,44 @@
from typing import Optional

import torch

try:
    from ._ops import ops
except ImportError as e:
    # Fallback for local development.
    try:
        import _quantization
        ops = torch.ops._quantization
    except ImportError:
        raise e

def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)

def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype

    m = a.shape[0]
    n = b.shape[1]

    #if current_platform.is_rocm():
    #    triton_scaled_mm_module = importlib.import_module(
    #        "vllm.model_executor.layers.quantization.compressed_tensors."
    #        "triton_scaled_mm")
    #    triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
    #    return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    out = torch.empty((m, n), dtype=out_dtype, device=a.device)

    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

    return out
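
Assuming the extension has been built and is importable (the package name quantization below is hypothetical), the wrapper above can be exercised end to end with int8 inputs; the plain-PyTorch line at the bottom shows the per-row/per-column scaling semantics the kernel implements:

import torch
import quantization as q  # hypothetical name for the built ext-torch package

m, n, k = 16, 4096, 4096

a = torch.randint(-8, 8, (m, k), dtype=torch.int8, device="cuda")
b = torch.randint(-8, 8, (n, k), dtype=torch.int8, device="cuda").t()  # column-major

scale_a = torch.rand(m, 1, dtype=torch.float32, device="cuda")  # per-row of a
scale_b = torch.rand(1, n, dtype=torch.float32, device="cuda")  # per-column of b

out = q.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)  # torch.Size([16, 4096]) torch.bfloat16

# Rough reference in plain PyTorch: scale the accumulator by the row and
# column scales, then cast to the requested output dtype.
ref = (scale_a * (a.float() @ b.float()) * scale_b).to(torch.bfloat16)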
ext-torch/registration.h
ADDED
@@ -0,0 +1,27 @@
#pragma once

#include <Python.h>

#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)

#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)

// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)

// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
  TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME)                                               \
  PyMODINIT_FUNC CONCAT(PyInit_, NAME)() {                                     \
    static struct PyModuleDef module = {PyModuleDef_HEAD_INIT,                 \
                                        STRINGIFY(NAME), nullptr, 0, nullptr}; \
    return PyModule_Create(&module);                                           \
  }
ext-torch/torch_binding.cpp
ADDED
@@ -0,0 +1,32 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

  // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
  // quantization, as well as bias
  ops.def(
      "cutlass_scaled_mm(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

  // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
  // quantization.
  ops.def(
      "cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
      "                  Tensor b, Tensor a_scales,"
      "                  Tensor b_scales, Tensor azp_adj,"
      "                  Tensor? azp, Tensor? bias) -> ()");
  ops.impl("cutlass_scaled_mm_azp", torch::kCUDA, &cutlass_scaled_mm_azp);

  // Check if cutlass scaled_mm is supported for CUDA devices of the given
  // capability
  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);

}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
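
Because the schemas above are registered through TORCH_LIBRARY (via TORCH_LIBRARY_EXPAND) and REGISTER_EXTENSION only supplies a PyInit entry point, the kernels surface under the torch.ops namespace once the compiled module is loaded. A short sketch, assuming the extension name resolves to _quantization as in the fallback branch of ext-torch/__init__.py:

import torch           # torch must be imported first so its C++ runtime is loaded
import _quantization   # importing the built module runs the TORCH_LIBRARY registration

ops = torch.ops._quantization

# Per the "(Tensor! out, ...) -> ()" schema, cutlass_scaled_mm writes into a
# pre-allocated `out` tensor and returns nothing.
print(ops.cutlass_scaled_mm_supports_fp8(90))  # True when built with CUDA >= 12.0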
ext-torch/torch_binding.h
ADDED
@@ -0,0 +1,18 @@
#pragma once

#include <torch/torch.h>

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                       torch::Tensor const& b, torch::Tensor const& a_scales,
                       torch::Tensor const& b_scales,
                       c10::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
                           torch::Tensor const& a_scales,
                           torch::Tensor const& b_scales,
                           torch::Tensor const& azp_adj,
                           c10::optional<torch::Tensor> const& azp,
                           c10::optional<torch::Tensor> const& bias);