52 template <
typename Tile_,
typename Threads_,
bool = (Tile_::kW < Threads_::kW)>
53 struct ReshapeThreads {
54 typedef Threads_ Threads;
57 template <
typename Tile_,
typename Threads_>
59 typedef Shape<Threads_::kD, Threads_::kH * Threads_::kW / Tile_::kW, Tile_::kW, 1>
Threads;
98 VectorizedTile::kH / Threads::kH,
99 VectorizedTile::kW / Threads::kW,
112 return make_Coord(0, thread_offset_h, thread_offset_w, 0);
119 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kStr
ideH_,
int kAccessSize_>
121 MatrixLayout::kColumnMajor,
155 return make_Coord(0, thread_offset_h, thread_offset_w, 0);
162 template <
typename TileTraits_,
typename Index_ =
int>
165 typename TileTraits_::Scalar,
166 TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
167 : IteratorAdvance::kW,
168 MemorySpace::kGlobal,
173 typename TileTraits_::Scalar,
182 typedef typename TileTraits_::Tile
Tile;
186 typedef typename TileTraits_::Scalar
Scalar;
188 typedef typename TileTraits_::Threads
Threads;
224 for (
int d = 0; d < Base::Iterations::kD; ++d) {
225 for (
int h = 0; h < Base::Iterations::kH; ++h) {
226 for (
int w = 0; w < Base::Iterations::kW; ++w) {
227 for (
int c = 0; c < Base::Iterations::kC; ++c) {
228 bool flag = w * Base::Delta::kW +
thread_offset[2] + block_offset[2] < bounds[2];
232 (h * Base::Delta::kH + d * Base::Delta::kD) +
thread_offset[1] + block_offset[1] <
235 flag = flag && (h * Base::Delta::kH) +
thread_offset[1] + block_offset[1] < bounds[1];
283 for (
int d = 0; d < Base::Iterations::kD; ++d) {
284 for (
int h = 0; h < Base::Iterations::kH; ++h) {
285 for (
int w = 0; w < Base::Iterations::kW; ++w) {
286 for (
int c = 0; c < Base::Iterations::kC; ++c) {
289 offset +=
thread_offset[1] + h * Base::Delta::kH + d * Base::Delta::kD;
313 LongIndex _offset = offset.template dot<LongIndex>(
331 template <
typename Fragment>
334 for (
int d = 0; d < Base::Iterations::kD; ++d) {
335 for (
int h = 0; h < Base::Iterations::kH; ++h) {
336 for (
int w = 0; w < Base::Iterations::kW; ++w) {
337 for (
int c = 0; c < Base::Iterations::kC; ++c) {
338 if (
valid(d, h, w, c)) {
340 reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
347 if (w < Base::Iterations::kW - 1) {
351 if (h < Base::Iterations::kH - 1) {
355 if (d < Base::Iterations::kD - 1) {
365 template <
typename TileTraits_,
typename Index_ =
int>
367 typename TileTraits_::Scalar,
369 MemorySpace::kGlobal,
375 typename TileTraits_::Scalar,
385 typedef typename TileTraits_::Scalar
Scalar;
387 typedef typename TileTraits_::Pointer
Pointer;
389 typedef typename TileTraits_::Threads
Threads;
417 Index epilogue_stride_w,
418 Index epilogue_delta_w) {
424 stride_h = TileTraits_::ThreadsDelta::kH * ldm;
427 inc_h = ldm * TileTraits_::kStrideH;
429 (ldm - ldm * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
434 -((TileTraits_::kStrideH * (Base::Iterations::kH - 1) - 1) + epilogue_delta_w);
441 Index _predicate_offset) {
479 for (
int i = 0; i < Base::Iterations::kW; ++i) {
504 LongIndex _offset = offset.template dot<LongIndex>(
548 template <
typename Fragment>
551 for (
int d = 0; d < Base::Iterations::kD; ++d) {
552 for (
int h = 0; h < Base::Iterations::kH; ++h) {
553 for (
int w = 0; w < Base::Iterations::kW; ++w) {
554 for (
int c = 0; c < Base::Iterations::kC; ++c) {
555 if (
valid(d, h, w, c)) {
557 reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
564 if (w < Base::Iterations::kW - 1) {
568 if (h < Base::Iterations::kH - 1) {
572 if (d < Base::Iterations::kD - 1) {
579 template <
typename Fragment>
582 for (
int d = 0; d < Base::Iterations::kD; ++d) {
583 for (
int h = 0; h < Base::Iterations::kH; ++h) {
584 for (
int w = 0; w < Base::Iterations::kW; ++w) {
585 for (
int c = 0; c < Base::Iterations::kC; ++c) {
586 if (
valid(d, h, w, c)) {
588 reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
595 if (w < Base::Iterations::kW - 1) {
599 if (h < Base::Iterations::kH - 1) {
603 if (d < Base::Iterations::kD - 1) {
Definition: gemm_global_tile.h:120
Shape< 0, Threads::kH, Threads::kW *kAccessSize > Delta
The strides in each dimension between different loads/stores.
Definition: gemm_global_tile.h:92
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: gemm_global_tile.h:503
CUTLASS_HOST_DEVICE void add_pointer_offset(LongIndex offset)
Adds an offset to the pointer.
Definition: gemm_global_tile.h:545
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, long long _stride_d, Index _stride_h, Index _inc_advance, Index _inc_h, Index _predicate_inc_advance, Index _predicate_inc_h, Index _predicate_offset)
Definition: gemm_global_tile.h:439
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Definition: gemm_global_tile.h:332
Index inc_advance
The strides to increment the pointer.
Definition: gemm_global_tile.h:406
cutlass::PredicateVector< ShapeCount< typename Base::Iterations >::kCount > PredicateVector
Definition: gemm_global_tile.h:198
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:180
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:467
Base::Params BaseParams
Iterator parameters type.
Definition: gemm_global_tile.h:201
CUTLASS_HOST_DEVICE void inc_c()
Increment the pointer in the C dimension.
Definition: gemm_global_tile.h:486
Index_ Index
The index.
Definition: gemm_global_tile.h:391
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
ReshapeTile< Tile_, kAccessSize_ >::Tile VectorizedTile
The vectorized tile shape.
Definition: gemm_global_tile.h:86
GemmGlobalIteratorCd< TileTraits_, Index_ > This_
This class.
Definition: gemm_global_tile.h:372
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:382
Definition: gemm_global_tile.h:70
Scalar_ * Pointer
The pointer.
Definition: gemm_global_tile.h:78
FragmentIterator< Fragment, Iterations, AccessType > FragmentIterator
The fragment iterator.
Definition: tile_iterator.h:202
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Kind
Definition: tile_iterator.h:65
CUTLASS_HOST_DEVICE int initialize(Scalar const *ptr, Index stride_d, Index stride_h)
Initializes params to load a strip-mined tile, given pointer and stride_h.
Definition: gemm_global_tile.h:205
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:357
Definition: load_store.h:42
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Shape< 1, 1, VectorizedTile::kC > ThreadsDelta
The relative offset between two elements in the H/W dimension in adjacent threads.
Definition: gemm_global_tile.h:90
GemmMultiplicandTraits< Tile, kOperand, kLayout > MultiplicandTraits
Definition: gemm_global_tile.h:103
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:428
static MemorySpace::Kind const kMemorySpace
The memory space.
Definition: gemm_global_tile.h:82
long long LongIndex
Long index.
Definition: gemm_global_tile.h:192
CUTLASS_HOST_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: gemm_global_tile.h:267
TileIteratorBase< TileTraits_, typename TileTraits_::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: gemm_global_tile.h:379
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:199
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > Delta
Override the strides in each dimension between different loads/stores.
Definition: gemm_global_tile.h:138
Index predicate_inc_h
Definition: gemm_global_tile.h:408
CUTLASS_HOST_DEVICE int initialize()
Definition: tile_iterator.h:590
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:437
Tile_ Tile
The tile shape.
Definition: gemm_global_tile.h:84
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:470
Definition: tile_iterator.h:65
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads and increments iterator.
Definition: gemm_global_tile.h:549
TileLoadIterator< TileTraits_, typename TileTraits_::Scalar, TileTraits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kGlobal, Index_ > Base
The base class.
Definition: gemm_global_tile.h:178
Definition: gemm_global_tile.h:203
Index inc_d
Definition: tile_iterator.h:226
CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const &_params, const Coord< 3 > &bounds, const Coord< 3 > &block, int offset=0, int pred_offset=0, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:463
Definition: matrix_traits.h:357
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:181
Definition: gemm_global_tile.h:163
GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ > Base
The base class.
Definition: gemm_global_tile.h:133
Kind
Definition: load_store.h:39
CUTLASS_HOST_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: gemm_global_tile.h:262
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: gemm_global_tile.h:196
TileTraits_::Threads Threads
The threads.
Definition: gemm_global_tile.h:188
static int const kStrideH
The stride in the H dimension.
Definition: gemm_global_tile.h:136
static int const kH
The height of the cube.
Definition: shape.h:68
Definition: load_store.h:178
Shape< Threads_::kD, Threads_::kH *Threads_::kW/Tile_::kW, Tile_::kW, 1 > Threads
Definition: gemm_global_tile.h:59
Index predicate_inc_advance
The strides to increment the predicate offset.
Definition: gemm_global_tile.h:408
static GemmOperand::Kind const kOperand
Identity of the operand.
Definition: gemm_global_tile.h:72
Index stride_h
Definition: tile_iterator.h:223
Defines container classes and iterators for managing a statically sized vector of boolean predicates...
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Tests the validity of the access at (d, h, w, c).
Definition: gemm_global_tile.h:540
static FragmentElementType::Kind const kFragmentElementType
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:158
CUTLASS_HOST_DEVICE void inc_d()
Increment the pointer in the D dimension.
Definition: gemm_global_tile.h:495
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
Definition: gemm_global_tile.h:58
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb(Params const &_params, const Coord< 3 > &threadblock_offset, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: gemm_global_tile.h:246
PredicateVector predicates
The predicates.
Definition: gemm_global_tile.h:217
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_global_tile.h:76
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
CUTLASS_HOST_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: gemm_global_tile.h:260
CUTLASS_HOST_DEVICE GemmGlobalIteratorAb & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the iterator.
Definition: gemm_global_tile.h:311
Base::Fragment Fragment
Fragment type loaded by the iterator.
Definition: gemm_global_tile.h:184
TileTraits_::Threads Threads
The threads.
Definition: gemm_global_tile.h:389
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_global_tile.h:151
Definition: gemm_operand.h:67
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_global_tile.h:106
CUTLASS_HOST_DEVICE void inc_w()
Increment the pointer in the W dimension.
Definition: gemm_global_tile.h:258
CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment without advancing the iterator.
Definition: tile_iterator.h:776
CUTLASS_HOST_DEVICE void fill(bool value=true)
Fills all predicates with a given value.
Definition: predicate_vector.h:344
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment)
Definition: gemm_global_tile.h:580
Base::Threads Threads
Definition: gemm_global_tile.h:142
Index stride_h
The stride in the H dimension to setup the thread in the block.
Definition: gemm_global_tile.h:404
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: gemm_global_tile.h:264
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: gemm_global_tile.h:108
CUTLASS_HOST_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: gemm_global_tile.h:512
Index inc_h
Definition: tile_iterator.h:227
Index stride_d
Definition: tile_iterator.h:222
Shape< 0, 0, Threads::kW *ThreadsDelta::kW, kAccessSize > ImmediateOffsetStrides
Strides for immediate offset computation.
Definition: gemm_global_tile.h:95
Statically sized array of bits implementing a vector of boolean predicates.
Definition: predicate_vector.h:105
Definition: load_store.h:60
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:395
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: matrix_traits.h:159
Base::ImmediateOffsetStrides ImmediateOffsetStrides
Definition: gemm_global_tile.h:146
long long LongIndex
The index.
Definition: gemm_global_tile.h:393
TileTraits_::Scalar Scalar
The scalar.
Definition: gemm_global_tile.h:385
Index inc_h
Definition: gemm_global_tile.h:406
cutlass::PredicateVector< Base::Iterations::kW > predicates
The predicates for the row.
Definition: gemm_global_tile.h:460
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:161
Pointer pointer
The pointer.
Definition: gemm_global_tile.h:400
GemmGlobalIteratorAb< TileTraits_, Index_ > This_
This class.
Definition: gemm_global_tile.h:171
TileTraits_::Tile Tile
The tile.
Definition: gemm_global_tile.h:182
ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:88
Shape< 1, VectorizedTile::kH/Threads::kH, VectorizedTile::kW/Threads::kW, VectorizedTile::kC/kAccessSize > Iterations
The number of iterations needed to load/store the tile.
Definition: gemm_global_tile.h:101
CUTLASS_HOST_DEVICE void inc_advance()
Increment the pointer to move to the next iteration.
Definition: gemm_global_tile.h:497
CUTLASS_HOST_DEVICE void inc_h()
Increment the pointer in the H dimension.
Definition: gemm_global_tile.h:490
CUTLASS_HOST_DEVICE void inc_w()
Increment the pointer in the W dimension.
Definition: gemm_global_tile.h:488
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:144
Scalar const * pointer
Pointer to memory.
Definition: tile_iterator.h:499
Params params
Parameters.
Definition: gemm_global_tile.h:456
Definition: gemm_global_tile.h:366
static IteratorAdvance::Kind const kAdvance
Specifies in which dimension post-increment accesses advance.
Definition: tile_iterator.h:431
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:458
FragmentElement_ FragmentElement
Fragment element.
Definition: tile_iterator.h:152
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:194
static int const kW
The width of the cube.
Definition: shape.h:70
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:365
CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset)
Definition: gemm_global_tile.h:321
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:188
Parameters.
Definition: tile_iterator.h:497
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
Computes the thread offset in (H, W) based on thread ID.
Definition: gemm_global_tile.h:149
long long stride_d
The stride in the D dimension.
Definition: gemm_global_tile.h:402
CUTLASS_HOST_DEVICE Index stride_advance(void)
Definition: gemm_global_tile.h:323
CUTLASS_HOST_DEVICE void inc_w()
Increment in the W dimension.
Definition: tile_iterator.h:686
static int const kAccessSize
The number of scalars per LDG/STG.
Definition: gemm_global_tile.h:80
Tile_ Tile
Definition: reshape_tile.h:43
Definition: tile_iterator.h:65
Base::Iterations Iterations
Definition: gemm_global_tile.h:140
Index_ Index
The index.
Definition: gemm_global_tile.h:190
TileTraits_::Pointer Pointer
The pointer.
Definition: gemm_global_tile.h:387
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the access at (d, h, w, c) valid?
Definition: gemm_global_tile.h:305
static FragmentElementType::Kind const kFragmentElementType
Specifies type of iterator fragment storage (Scalar or WmmaMatrix)
Definition: tile_iterator.h:434
Kind
Definition: matrix_traits.h:357
TileTraits_::Scalar Scalar
The scalar.
Definition: gemm_global_tile.h:186
Index inc_advance
Definition: tile_iterator.h:230
Threads_ Threads
Definition: gemm_global_tile.h:54
Params params
The parameters.
Definition: gemm_global_tile.h:215
Defines properties of matrices used to denote layout and operands to GEMM kernels.
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, int stride_d_, Index ldm, Index bound, Index epilogue_stride_w, Index epilogue_delta_w)
Setup the params.
Definition: gemm_global_tile.h:413
CUTLASS_HOST_DEVICE void initialize_predicates(const Coord< 3 > &bounds, const Coord< 3 > &block_offset)
Definition: gemm_global_tile.h:219
The params.
Definition: gemm_global_tile.h:398
Base::ThreadsDelta ThreadsDelta
Definition: gemm_global_tile.h:144
Coord< 4 > thread_offset
Offset of an individual lane from the start of the tile.
Definition: gemm_global_tile.h:213
Base::FragmentIterator FragmentIterator
Fragment iterator definition.
Definition: tile_iterator.h:473
CUTLASS_HOST_DEVICE void residue(Index k)
That's the residue! Update the predicates.
Definition: gemm_global_tile.h:281
Index predicate_offset
The column offset to compute the predicate for the columns.
Definition: gemm_global_tile.h:410
static MatrixLayout::Kind const kLayout
The layout.
Definition: gemm_global_tile.h:74
CUTLASS_HOST_DEVICE void store_element(typename Base::AccessType const &value, int d, int h, int w, int c)
Stores a single fragment element into memory.
Definition: gemm_global_tile.h:526
Index stride_w
Definition: tile_iterator.h:224