37 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kAccessSize_>
39 MatrixLayout::kColumnMajor,
60 int thread_offset_h = threadIdx.x / Base::Threads::kW;
63 return make_Coord(0, thread_offset_h, thread_offset_w, 0);
70 template <
typename TileTraits_,
typename Index_ =
int>
84 typedef typename TileTraits_::Scalar
Scalar;
86 typedef typename TileTraits_::Pointer
Pointer;
88 typedef typename TileTraits_::Threads
Threads;
100 long long batch_stride,
103 Index epilogue_stride_w,
104 Index epilogue_delta_w) {
112 this->
inc_h = ldm * TileTraits_::Threads::kH;
127 int const pointer_offset = 0,
128 int const pred_offset = 0,
131 :
Base(
params, bounds, block, pointer_offset, pred_offset, thread_offset_func) {}
149 Base::Tile::kW>::store(value,
Base::params.pointer, offset);
153 template <
typename Fragment>
158 template <
typename Fragment>
TileTraits_::Threads Threads
The threads.
Definition: wmma_gemm_global_tile.h:88
Index inc_advance
The strides to increment the pointer.
Definition: gemm_global_tile.h:406
Defines iterators for efficiently loading and storing to global memory.
Index_ Index
The index.
Definition: gemm_global_tile.h:391
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, long long batch_stride, Index ldm, Index n, Index epilogue_stride_w, Index epilogue_delta_w)
Sets up the params.
Definition: wmma_gemm_global_tile.h:99
Definition: gemm_global_tile.h:70
GemmGlobalIteratorCd< Traits, Index_ > Base
The base class.
Definition: wmma_gemm_global_tile.h:77
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
CUTLASS_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: wmma_gemm_global_tile.h:134
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:199
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Base::Params BaseParams
Base parameters.
Definition: wmma_gemm_global_tile.h:94
Index predicate_inc_h
Definition: gemm_global_tile.h:408
Index_ Index
The index.
Definition: wmma_gemm_global_tile.h:90
TileTraits_::Scalar Scalar
The scalar.
Definition: wmma_gemm_global_tile.h:84
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads and increments iterator.
Definition: gemm_global_tile.h:549
CUTLASS_DEVICE void load_post_increment(Fragment &fragment)
Definition: wmma_gemm_global_tile.h:154
Definition: matrix_traits.h:357
Definition: load_store.h:178
Index predicate_inc_advance
The strides to increment the predicate offset.
Definition: gemm_global_tile.h:408
The params.
Definition: wmma_gemm_global_tile.h:97
static FragmentElementType::Kind const kFragmentElementType
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:158
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: wmma_gemm_global_tile.h:59
Definition: wmma_gemm_global_tile.h:71
TileTraits_::Pointer Pointer
The pointer.
Definition: wmma_gemm_global_tile.h:86
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment)
Definition: gemm_global_tile.h:580
Index stride_h
The stride in the H dimension used to set up the thread in the block.
Definition: gemm_global_tile.h:404
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: gemm_global_tile.h:512
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > Delta
Override the strides in each dimension between different loads/stores.
Definition: wmma_gemm_global_tile.h:54
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:395
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: matrix_traits.h:159
CUTLASS_DEVICE void store_post_increment(Fragment &fragment)
Definition: wmma_gemm_global_tile.h:159
Index inc_h
Definition: gemm_global_tile.h:406
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:161
static MatrixLayout::Kind const kLayout
The layout.
Definition: wmma_gemm_global_tile.h:81
Pointer pointer
The pointer.
Definition: gemm_global_tile.h:400
Definition: wmma_gemm_global_tile.h:38
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:144
TileTraits_::ThreadOffset ThreadOffset
The thread offset functor.
Definition: wmma_gemm_global_tile.h:92
Params params
Parameters.
Definition: gemm_global_tile.h:456
Definition: gemm_global_tile.h:366
TileTraits_::Scalar FragmentElement
Fragment element.
Definition: tile_iterator.h:152
static int const kW
The width of the cube.
Definition: shape.h:70
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:188
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ > Base
The base class.
Definition: wmma_gemm_global_tile.h:51
long long stride_d
The stride in the D dimension.
Definition: gemm_global_tile.h:402
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const &params, const Coord< 3 > &bounds, const Coord< 3 > &block, int const pointer_offset=0, int const pred_offset=0, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: wmma_gemm_global_tile.h:124
WmmaGemmGlobalIteratorCd< TileTraits_, Index_ > This_
This class.
Definition: wmma_gemm_global_tile.h:73
TileTraits_::Pointer Pointer
The pointer.
Definition: gemm_global_tile.h:387
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > ImmediateOffsetStrides
Override the strides in each dimension between different loads/stores.
Definition: wmma_gemm_global_tile.h:79
Computes the thread offset in (H, W) based on thread ID.
Definition: wmma_gemm_global_tile.h:57
The params.
Definition: gemm_global_tile.h:398
TileTraits_ Traits
The traits.
Definition: wmma_gemm_global_tile.h:75
Index predicate_offset
The column offset to compute the predicate for the columns.
Definition: gemm_global_tile.h:410
CUTLASS_DEVICE void store_element(typename Base::AccessType const &value, int d, int h, int w, int c)
Stores a single fragment element into memory.
Definition: wmma_gemm_global_tile.h:140