41 template <
typename ThreadGemmShape_,
typename ThreadsPerWarp_>
67 static_assert(AccumulatorsPerThread::kH % 2 == 0,
"Invalid size");
68 static_assert(AccumulatorsPerThread::kW % 2 == 0,
"Invalid size");
78 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 80 __half2
const* a_half2 =
reinterpret_cast<__half2 const*
>(&a[0]);
81 __half2
const* b_half2 =
reinterpret_cast<__half2 const*
>(&b[0]);
82 __half2
const* c_half2 =
reinterpret_cast<__half2 const*
>(&c[0]);
85 __half2* d_half2 =
reinterpret_cast<__half2*
>(&d[0]);
87 for (
int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {
88 for (
int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
90 int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
91 int const k1 = (2 * j + 1) * (AccumulatorsPerThread::kW / 2) + i;
94 d_half2[k0] = __hfma2(a_half2[i], __low2half2(b_half2[j]), c_half2[k0]);
96 d_half2[k1] = __hfma2(a_half2[i], __high2half2(b_half2[j]), c_half2[k1]);
CUTLASS_DEVICE ThreadMultiplyAdd()
Make sure there's an even number of elements in both dimensions.
Definition: hgemm_multiply_add.h:71
half ScalarC
The type for C and D.
Definition: hgemm_multiply_add.h:62
Fragment< ScalarB, AccumulatorsPerThread::kH > FragmentB
The fragment for B.
Definition: hgemm_multiply_add.h:60
ThreadGemmShape_ ThreadGemmShape
The number of accumulators per thread.
Definition: hgemm_multiply_add.h:46
half ScalarB
The type for B.
Definition: hgemm_multiply_add.h:58
Shape< A_::kD *B_::kD, A_::kH *B_::kH, A_::kW *B_::kW, A_::kC *B_::kC > Shape
Definition: shape.h:119
A template defining the Fragment concept.
Definition: fragment.h:99
Template implementing matrix multiply-add operations on fragments.
Shape< 1, 1, 2, 1 > InstructionShape
The shape of the instruction.
Definition: hgemm_multiply_add.h:44
ShapeMul< ThreadGemmShape, ThreadsPerWarp >::Shape AccumulatorsPerWarp
The number of accumulators per warp.
Definition: hgemm_multiply_add.h:52
Fragment< ScalarA, AccumulatorsPerThread::kW > FragmentA
The fragment for A.
Definition: hgemm_multiply_add.h:56
Fragment< half, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW > Accumulators
The accumulators.
Definition: hgemm_multiply_add.h:64
CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
Multiply-add: d = a*b + c.
Definition: hgemm_multiply_add.h:74
half ScalarA
The type for A.
Definition: hgemm_multiply_add.h:54
A Shape implementing the Layout concept, describing the dimensions of a cube.
Definition: shape.h:64
Template performing matrix multiply-add operation within a thread.
Definition: thread_multiply_add.h:44
ThreadGemmShape AccumulatorsPerThread
Alias retained for backward compatibility; it will be removed in CUTLASS v2.0.
Definition: hgemm_multiply_add.h:48
ThreadsPerWarp_ ThreadsPerWarp
The number of threads per warp.
Definition: hgemm_multiply_add.h:50
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...