// Generic WmmaGemmMultiplyAdd: a warp-level multiply-add functor built on nvcuda::wmma::mma_sync.
// NOTE(review): this text appears to be a rendered listing; the leading integers are the original
// file's line numbers and several original lines (the start of the template parameter list, the
// `Iterations` typedef, two multiply_add parameters, closing braces) are not visible here.
31 #ifdef CUTLASS_USE_WMMA_API 45 typename WarpGemmShape_,
46 typename InstructionShape_>
47 struct WmmaGemmMultiplyAdd {
// Shape of a single WMMA instruction, taken directly from the template argument.
49 typedef InstructionShape_ InstructionShape;
// Thread arrangement expressed as Shape<1, kH, kW> of the instruction shape.
51 typedef Shape<1, InstructionShape_::kH, InstructionShape_::kW> ThreadsPerWarp;
// WarpGemmShape and AccumulatorsPerWarp are both aliases of the same template argument.
53 typedef WarpGemmShape_ WarpGemmShape;
55 typedef WarpGemmShape_ AccumulatorsPerWarp;
// Scalar types of the A, B and C operands, forwarded from the template arguments.
57 typedef ScalarA_ ScalarA;
59 typedef ScalarB_ ScalarB;
61 typedef ScalarC_ ScalarC;
// A operand: one WmmaMatrix per element, Fragment parameterized with Iterations::kW.
// `Iterations` is declared on a line that is not visible in this listing.
66 typedef WmmaMatrix<GemmOperand::kA, kLayoutA_, ScalarA, InstructionShape> ElementA;
68 typedef Fragment<ElementA, Iterations::kW> FragmentA;
// B operand: one WmmaMatrix per element, Fragment parameterized with Iterations::kH.
71 typedef WmmaMatrix<GemmOperand::kB, kLayoutB_, ScalarB, InstructionShape> ElementB;
73 typedef Fragment<ElementB, Iterations::kH> FragmentB;
// C operand / accumulators: Fragment parameterized with Iterations::kH * Iterations::kW.
76 typedef WmmaMatrix<GemmOperand::kC, kLayoutC_, ScalarC, InstructionShape> ElementC;
78 typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
// Default constructor with an empty body.
81 CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
// multiply_add: for every (j, i) pair issues one mma_sync on a[i], b[j], c[j * kW + i],
// writing d[j * kW + i]. The body uses `b` and `d`, whose parameter declarations are on
// lines missing from this listing.
84 CUTLASS_DEVICE
void multiply_add(FragmentA
const& a,
86 Accumulators
const& c,
// j walks the Iterations::kH dimension (selects the B element and the accumulator row).
88 for (
int j = 0; j < Iterations::kH; ++j) {
// i walks the Iterations::kW dimension (selects the A element and the accumulator column).
89 for (
int i = 0; i < Iterations::kW; ++i) {
// Inputs for this (j, i) pair; the accumulator index is j * Iterations::kW + i.
91 ElementA
const& elt_a = a[i];
92 ElementB
const& elt_b = b[j];
93 ElementC
const& elt_c = c[j * Iterations::kW + i];
// Output element, at the same flattened index as the input accumulator.
96 ElementC& elt_d = d[j * Iterations::kW + i];
// WMMA multiply-accumulate; per the CUDA WMMA API this computes elt_d = elt_a * elt_b + elt_c.
99 nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
// Specialization of WmmaGemmMultiplyAdd guarded by CUTLASS_USE_SUBBYTE_WMMA, using
// Vector<bin1_t, 32> scalars and nvcuda::wmma::bmma_sync.
// NOTE(review): only the first specialization argument (MatrixLayout::kRowMajor) is visible;
// the rest of the argument list and several other original lines are missing from this listing.
107 #ifdef CUTLASS_USE_SUBBYTE_WMMA 108 template<
typename WarpGemmShape_>
110 struct WmmaGemmMultiplyAdd <MatrixLayout::
kRowMajor,
// Fixed instruction shape and thread arrangement for this specialization.
119 typedef Shape<128, 8, 8> InstructionShape;
121 typedef Shape<1, 4, 8> ThreadsPerWarp;
// WarpGemmShape and AccumulatorsPerWarp are both aliases of the same template argument.
123 typedef WarpGemmShape_ WarpGemmShape;
125 typedef WarpGemmShape_ AccumulatorsPerWarp;
// A and B scalars are vectors of 32 bin1_t values.
127 typedef Vector<bin1_t, 32> ScalarA;
129 typedef Vector<bin1_t, 32> ScalarB;
// Tail of the ElementA typedef (its beginning is not visible), then the A fragment.
139 InstructionShape> ElementA;
141 typedef Fragment<ElementA, Iterations::kW> FragmentA;
// Tail of the ElementB typedef (its beginning is not visible), then the B fragment.
147 InstructionShape> ElementB;
149 typedef Fragment<ElementB, Iterations::kH> FragmentB;
// Tail of the ElementC typedef (its beginning is not visible), then the accumulators.
155 InstructionShape> ElementC;
157 typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
// Default constructor with an empty body.
160 CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
// multiply_add: for every (j, i) pair issues one bmma_sync targeting d[j * kW + i].
// The body uses `b` and `d`, whose parameter declarations are on lines missing from this listing.
163 CUTLASS_DEVICE
void multiply_add(FragmentA
const& a,
165 Accumulators
const& c,
// j walks Iterations::kH, i walks Iterations::kW.
167 for (
int j = 0; j < Iterations::kH; ++j) {
168 for (
int i = 0; i < Iterations::kW; ++i) {
// Inputs for this (j, i) pair; the accumulator index is j * Iterations::kW + i.
170 ElementA
const& elt_a = a[i];
171 ElementB
const& elt_b = b[j];
172 ElementC
const& elt_c = c[j * Iterations::kW + i];
// Output element, at the same flattened index as the input accumulator.
175 ElementC& elt_d = d[j * Iterations::kW + i];
// Binary MMA call; the middle arguments are on lines missing from this listing. The two
// visible trailing arguments select the XOR bit operation and the POPC accumulate operation.
178 nvcuda::wmma::bmma_sync(elt_d,
182 nvcuda::wmma::experimental::bmmaBitOpXOR,
183 nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
// Specialization of WmmaGemmMultiplyAdd guarded by CUTLASS_USE_SUBBYTE_WMMA, using
// Vector<int4_t, 8> scalars and nvcuda::wmma::mma_sync.
// NOTE(review): only the first specialization argument (MatrixLayout::kRowMajor) is visible;
// the rest of the argument list and several other original lines are missing from this listing.
192 #ifdef CUTLASS_USE_SUBBYTE_WMMA 193 template<
typename WarpGemmShape_>
195 struct WmmaGemmMultiplyAdd <MatrixLayout::
kRowMajor,
// Fixed instruction shape and thread arrangement for this specialization.
204 typedef Shape<32, 8, 8> InstructionShape;
206 typedef Shape<1, 4, 8> ThreadsPerWarp;
// WarpGemmShape and AccumulatorsPerWarp are both aliases of the same template argument.
208 typedef WarpGemmShape_ WarpGemmShape;
210 typedef WarpGemmShape_ AccumulatorsPerWarp;
// A and B scalars are vectors of 8 int4_t values.
212 typedef Vector<int4_t, 8> ScalarA;
214 typedef Vector<int4_t, 8> ScalarB;
// Tail of the ElementA typedef (its beginning is not visible), then the A fragment.
224 InstructionShape> ElementA;
226 typedef Fragment<ElementA, Iterations::kW> FragmentA;
// Tail of the ElementB typedef (its beginning is not visible), then the B fragment.
232 InstructionShape> ElementB;
234 typedef Fragment<ElementB, Iterations::kH> FragmentB;
// Tail of the ElementC typedef (its beginning is not visible), then the accumulators.
240 InstructionShape> ElementC;
242 typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
// Default constructor with an empty body.
245 CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
// multiply_add: for every (j, i) pair issues one mma_sync on a[i], b[j], c[j * kW + i],
// writing d[j * kW + i]. The body uses `b` and `d`, whose parameter declarations are on
// lines missing from this listing.
248 CUTLASS_DEVICE
void multiply_add(FragmentA
const& a,
250 Accumulators
const& c,
// j walks Iterations::kH, i walks Iterations::kW.
252 for (
int j = 0; j < Iterations::kH; ++j) {
253 for (
int i = 0; i < Iterations::kW; ++i) {
// Inputs for this (j, i) pair; the accumulator index is j * Iterations::kW + i.
255 ElementA
const& elt_a = a[i];
256 ElementB
const& elt_b = b[j];
257 ElementC
const& elt_c = c[j * Iterations::kW + i];
// Output element, at the same flattened index as the input accumulator.
260 ElementC& elt_d = d[j * Iterations::kW + i];
// WMMA multiply-accumulate; per the CUDA WMMA API this computes elt_d = elt_a * elt_b + elt_c.
263 nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
// Specialization of WmmaGemmMultiplyAdd guarded by CUTLASS_USE_SUBBYTE_WMMA, using
// Vector<uint4_t, 8> scalars and nvcuda::wmma::mma_sync.
// NOTE(review): only the first specialization argument (MatrixLayout::kRowMajor) is visible;
// the rest of the argument list and several other original lines are missing from this listing.
272 #ifdef CUTLASS_USE_SUBBYTE_WMMA 273 template<
typename WarpGemmShape_>
275 struct WmmaGemmMultiplyAdd <MatrixLayout::
kRowMajor,
// Fixed instruction shape and thread arrangement for this specialization.
284 typedef Shape<32, 8, 8> InstructionShape;
286 typedef Shape<1, 4, 8> ThreadsPerWarp;
// WarpGemmShape and AccumulatorsPerWarp are both aliases of the same template argument.
288 typedef WarpGemmShape_ WarpGemmShape;
290 typedef WarpGemmShape_ AccumulatorsPerWarp;
// A and B scalars are vectors of 8 uint4_t values.
292 typedef Vector<uint4_t, 8> ScalarA;
294 typedef Vector<uint4_t, 8> ScalarB;
// Tail of the ElementA typedef (its beginning is not visible), then the A fragment.
304 InstructionShape> ElementA;
306 typedef Fragment<ElementA, Iterations::kW> FragmentA;
// Tail of the ElementB typedef (its beginning is not visible), then the B fragment.
312 InstructionShape> ElementB;
314 typedef Fragment<ElementB, Iterations::kH> FragmentB;
// Tail of the ElementC typedef (its beginning is not visible), then the accumulators.
320 InstructionShape> ElementC;
322 typedef Fragment<ElementC, Iterations::kH * Iterations::kW> Accumulators;
// Default constructor with an empty body.
325 CUTLASS_DEVICE WmmaGemmMultiplyAdd() {}
// multiply_add: for every (j, i) pair issues one mma_sync on a[i], b[j], c[j * kW + i],
// writing d[j * kW + i]. The body uses `b` and `d`, whose parameter declarations are on
// lines missing from this listing.
328 CUTLASS_DEVICE
void multiply_add(FragmentA
const& a,
330 Accumulators
const& c,
// j walks Iterations::kH, i walks Iterations::kW.
332 for (
int j = 0; j < Iterations::kH; ++j) {
333 for (
int i = 0; i < Iterations::kW; ++i) {
// Inputs for this (j, i) pair; the accumulator index is j * Iterations::kW + i.
335 ElementA
const& elt_a = a[i];
336 ElementB
const& elt_b = b[j];
337 ElementC
const& elt_c = c[j * Iterations::kW + i];
// Output element, at the same flattened index as the input accumulator.
340 ElementC& elt_d = d[j * Iterations::kW + i];
// WMMA multiply-accumulate; per the CUDA WMMA API this computes elt_d = elt_a * elt_b + elt_c.
343 nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c);
355 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Definition: matrix_traits.h:357
Shape< A_::kD/B_::kD, A_::kH/B_::kH, A_::kW/B_::kW, A_::kC/B_::kC > Shape
Definition: shape.h:126
Definition: matrix_traits.h:357
Definition: matrix_traits.h:159
Definition: matrix_traits.h:159
Definition: matrix_traits.h:357
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...