problem_size(_problem_size[0], _problem_size[1], _problem_size[2], 1),
problem_size(_problem_size.k(), _problem_size.n(), _problem_size.m(), 1),
assert(_problem_size.batch() == 1);
long long _batch_stride_A,
long long _batch_stride_B,
long long _batch_stride_C,
long long _batch_stride_D
GEMM problem description.
Definition: gemm_desc.h:50
TensorRef< CType const, 2 > TensorRefC
Tensor reference to C operand.
Definition: gemm_desc.h:74
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE GemmDesc(GemmCoord _problem_size, SType _alpha, TensorRefA const &_A, long long _batch_stride_A, TensorRefB const &_B, long long _batch_stride_B, SType _beta, TensorRefC const &_C, long long _batch_stride_C, TensorRefD const &_D, long long _batch_stride_D)
Constructor for strided batched GEMM.
Definition: gemm_desc.h:179
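The batch strides taken by this constructor can be illustrated with a small host-side reference loop. The sketch below is not CUTLASS code: `StridedOperand` and `strided_batched_gemm_ref` are hypothetical names, the layout is assumed to be row-major, and each batch stride is assumed to be the element distance between the matrices of consecutive batches. It applies the usual GEMM update, D = alpha * A * B + beta * C, once per batch.

```cpp
#include <cassert>

// Hypothetical stand-in for one operand (NOT the CUTLASS TensorRef type):
// a base pointer, a leading dimension, and a batch stride in elements.
struct StridedOperand {
  float *data;             // first matrix of the batch
  int ld;                  // leading dimension (row-major: elements per row)
  long long batch_stride;  // elements between matrix b and matrix b + 1
};

// Reference loop: D_b = alpha * A_b * B_b + beta * C_b for each batch b.
// A is m x k, B is k x n, C and D are m x n, all row-major (an assumption).
inline void strided_batched_gemm_ref(int m, int n, int k, int batch_count,
                                     float alpha, StridedOperand A, StridedOperand B,
                                     float beta, StridedOperand C, StridedOperand D) {
  for (int b = 0; b < batch_count; ++b) {
    // Each operand for batch b starts at base + b * batch_stride.
    const float *a = A.data + b * A.batch_stride;
    const float *bm = B.data + b * B.batch_stride;
    const float *c = C.data + b * C.batch_stride;
    float *d = D.data + b * D.batch_stride;
    for (int i = 0; i < m; ++i) {
      for (int j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int p = 0; p < k; ++p) {
          acc += a[i * A.ld + p] * bm[p * B.ld + j];
        }
        d[i * D.ld + j] = alpha * acc + beta * c[i * C.ld + j];
      }
    }
  }
}
```

With densely packed 2 x 2 matrices, for example, every operand would use a leading dimension of 2 and a batch stride of 4. A batch stride of 0 would make every batch reuse the same matrix.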
TensorRefC C
The source matrix C.
Definition: gemm_desc.h:111
SType alpha
The alpha scaling value.
Definition: gemm_desc.h:93
TensorRefA A
The source matrix A.
Definition: gemm_desc.h:96
GemmCoord problem_size
The dimensions of the GEMM.
Definition: gemm_desc.h:90
Definition: gemm_coord.h:43
long long batch_stride_D
Batch stride for the D operand.
Definition: gemm_desc.h:120
TensorRefB B
The source matrix B.
Definition: gemm_desc.h:102
TensorRef< AType const, 2 > TensorRefA
Tensor reference to A operand.
Definition: gemm_desc.h:62
CUTLASS_HOST_DEVICE GemmDesc(GemmCoord _problem_size, SType _alpha, TensorRefA const &_A, TensorRefB const &_B, SType _beta, TensorRefC const &_C, TensorRefD const &_D)
Constructor for basic GEMM with batch count = 1.
Definition: gemm_desc.h:154
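As a rough illustration of what a single-GEMM constructor of this kind does, the sketch below uses a simplified, non-templated descriptor. `Coord4` and `SimpleGemmDesc` are hypothetical names, `float` stands in for the template element types, and zeroing the batch strides is an assumption rather than something this listing shows. The parts taken from the listing are the forced batch count of 1 and the assertion on the incoming batch coordinate.

```cpp
#include <cassert>

// Hypothetical 4-D problem size in (k, n, m, batch) order.
struct Coord4 {
  int k, n, m, batch;
};

// Hypothetical, simplified descriptor; not the CUTLASS GemmDesc template.
struct SimpleGemmDesc {
  Coord4 problem_size;
  float alpha;
  float beta;
  const float *A;
  const float *B;
  const float *C;
  float *D;
  long long batch_stride_A;
  long long batch_stride_B;
  long long batch_stride_C;
  long long batch_stride_D;

  // Single-GEMM form: the batch count is fixed to 1.
  SimpleGemmDesc(Coord4 size, float alpha_, const float *A_, const float *B_,
                 float beta_, const float *C_, float *D_)
      : problem_size{size.k, size.n, size.m, 1},
        alpha(alpha_),
        beta(beta_),
        A(A_),
        B(B_),
        C(C_),
        D(D_),
        batch_stride_A(0),  // assumption: strides are unused when batch == 1
        batch_stride_B(0),
        batch_stride_C(0),
        batch_stride_D(0) {
    // Mirrors the assertion in the listing: callers must not pass a batched size.
    assert(size.batch == 1);
  }
};
```

A caller would pass something like `Coord4{k, n, m, 1}` together with the scalars and operand pointers; passing a batch count other than 1 trips the assertion in debug builds.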
DType_ DType
Destination accumulator type.
Definition: gemm_desc.h:77
long long batch_stride_A
Batch stride for the A operand.
Definition: gemm_desc.h:99
SType beta
The beta scaling value.
Definition: gemm_desc.h:108
SType_ SType
Scalar type for alpha and beta.
Definition: gemm_desc.h:83
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
AType_ AType
Element type of the A operand.
Definition: gemm_desc.h:59
CType_ CType
Source accumulator matrix type.
Definition: gemm_desc.h:71
TensorRef< BType const, 2 > TensorRefB
Tensor reference to B operand.
Definition: gemm_desc.h:68
long long batch_stride_B
Batch stride for the B operand.
Definition: gemm_desc.h:105
TensorRefD D
The destination matrix D.
Definition: gemm_desc.h:117
TensorRef< DType, 2 > TensorRefD
Tensor reference to D operand.
Definition: gemm_desc.h:80
BType_ BType
Element type of the B operand.
Definition: gemm_desc.h:65
CUTLASS_HOST_DEVICE Index const & batch() const
Returns the GEMM batch coordinate.
Definition: gemm_coord.h:113
CUTLASS_HOST_DEVICE GemmDesc()
Default constructor.
Definition: gemm_desc.h:128
CUTLASS_HOST_DEVICE GemmDesc(Coord< 3 > _problem_size, SType _alpha, TensorRefA const &_A, TensorRefB const &_B, SType _beta, TensorRefC const &_C, TensorRefD const &_D)
Constructor for basic GEMM with batch count = 1.
Definition: gemm_desc.h:132
long long batch_stride_C
Batch stride for the C operand.
Definition: gemm_desc.h:114
Index_ Index
Index type for dimensions and strides.
Definition: gemm_desc.h:56
GemmCoord is a structure derived from Coord<4> that specifies a location within the coordinate system...
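To make the coordinate ordering concrete, here is a minimal stand-in for a four-dimensional GEMM coordinate. `ProblemCoord` and `make_single_batch` are hypothetical names, and the (k, n, m, batch) storage order is an assumption inferred from the `problem_size(k, n, m, 1)` constructor call in this listing; the real GemmCoord derives from Coord<4> and may differ in detail.

```cpp
#include <cassert>

// Hypothetical stand-in for a 4-D GEMM coordinate; not the CUTLASS GemmCoord.
// Storage order (k, n, m, batch) is assumed from the listing's constructor call.
struct ProblemCoord {
  int idx[4];

  ProblemCoord(int k, int n, int m, int batch) {
    idx[0] = k;
    idx[1] = n;
    idx[2] = m;
    idx[3] = batch;
  }

  int k() const { return idx[0]; }
  int n() const { return idx[1]; }
  int m() const { return idx[2]; }
  int batch() const { return idx[3]; }
};

// Builds a single-batch problem size from the more familiar (m, n, k) order.
inline ProblemCoord make_single_batch(int m, int n, int k) {
  return ProblemCoord(k, n, m, 1);
}
```

Named accessors such as `m()` and `batch()` keep call sites readable even though the storage order differs from the conventional (m, n, k) way of writing GEMM sizes.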