46 typename LoadIterator_,
48 typename StoreIterator_,
50 typename Transformer_>
78 typedef typename LoadIterator::Scalar
Scalar;
80 typedef typename LoadIterator::Pointer
Pointer;
82 typedef typename LoadIterator::Index
Index;
86 typedef typename LoadIterator::Tile
Tile;
116 Index offset_to_residue_,
117 Index offset_to_residue_last_partition_) {
119 int error_code =
load_iterator.initialize(pointer, ldm, ldm);
132 if (blockIdx.z == gridDim.z - 1) {
152 bool const kKstrided =
156 tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
165 Coord<3> const& _threadblock_offset)
199 Index kResidue = k % kTileK;
215 load_iterator.add_pointer_offset(-(this_offset_residue + kBlock) *
ThreadblockTileStorage::TensorRef ThreadblockTileRef
Tensor reference to threadblock tile.
Definition: gemm_global_stream.h:93
LoadIterator::Pointer Pointer
The pointer.
Definition: gemm_global_stream.h:80
LoadIterator load_iterator
The iterator.
Definition: gemm_global_stream.h:242
static CUTLASS_HOST_DEVICE Coord< 3 > project_coordinate(Coord< 3 > const &coord, Index d_offset=0)
Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory.
Definition: gemm_global_stream.h:151
StoreIterator store_iterator
The store iterator.
Definition: gemm_global_stream.h:250
Params params
Parameters.
Definition: gemm_global_stream.h:236
Defines iterators for efficiently loading and storing to global memory.
TensorRef< Scalar, 4 > TensorRef
Defines the tensor reference for this allocation.
Definition: tile_allocation.h:63
static GemmOperand::Kind const kOperand
Indicates the type of GEMM operand.
Definition: gemm_global_stream.h:54
CUTLASS_DEVICE GlobalLoadStream & operator+=(Coord< 3 > const &offset)
Adds a Coord<3> to the underlying global load iterator.
Definition: gemm_global_stream.h:220
CUTLASS_DEVICE void copy()
Load the data from global memory to the fetch fragment.
Definition: gemm_global_stream.h:178
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Coord< 3 > multiplicand_bounds
Multiplicand bounds.
Definition: gemm_global_stream.h:240
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
static MatrixLayout::Kind const kLayout
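The matrix layout (a MatrixLayout::Kind value). Note: the description extracted here, "Make sure the transformed fragment is the same as the store fragment.", appears to belong to an adjacent compile-time check rather than to kLayout.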
Definition: gemm_global_stream.h:76
StoreIterator::Params store_iterator
Definition: gemm_global_stream.h:104
LoadIterator::LongIndex LongIndex
The long index type (used for the batch stride).
Definition: gemm_global_stream.h:84
FetchedFragment fetched_fragment
The fragment to fetch from global memory.
Definition: gemm_global_stream.h:244
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, LongIndex batch_stride_, Index ldm, Index offset_to_residue_, Index offset_to_residue_last_partition_)
Setup the params.
Definition: gemm_global_stream.h:113
Definition: gemm_global_stream.h:52
Definition: gemm_global_stream.h:144
LoadIterator::Scalar Scalar
The scalar type of the iterator.
Definition: gemm_global_stream.h:78
CUTLASS_DEVICE void residue(Index k, bool skip_clear=false)
Execute the residue code.
Definition: gemm_global_stream.h:190
TransformedFragment transformed_fragment
The fragment to convert the data after it has been fetched from shared memory.
Definition: gemm_global_stream.h:248
Defines a fragment based on a Shape<> template.
Index offset_to_residue
Definition: gemm_global_stream.h:107
TransformedFragment Fragment
Make sure the fragments match.
Definition: gemm_global_stream.h:68
CUTLASS_DEVICE GlobalLoadStream & add_batch_offset(int batch_id)
Adds an offset based on batch stride.
Definition: gemm_global_stream.h:226
LoadIterator_ LoadIterator
The load iterator.
Definition: gemm_global_stream.h:56
Definition: gemm_operand.h:67
Index offset_to_residue_last_partition
Definition: gemm_global_stream.h:110
CUTLASS_DEVICE void commit()
Commit the data.
Definition: gemm_global_stream.h:183
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Class for storing a tile in memory and accessing it through a tensor ref.
Definition: tile_allocation.h:42
Transformer transformer
The transformer.
Definition: gemm_global_stream.h:246
LongIndex batch_stride
Batch stride in global memory.
Definition: gemm_global_stream.h:101
Definition: matrix_traits.h:159
Definition: gemm_operand.h:96
Definition: matrix_traits.h:159
StoreIterator_ StoreIterator
The store iterator to write to shared memory.
Definition: gemm_global_stream.h:60
TileAllocation< typename StoreIterator::Scalar, typename StoreIterator::Tile > ThreadblockTileStorage
Shared memory allocation for the tile.
Definition: gemm_global_stream.h:90
LoadIterator::Params load_iterator
Definition: gemm_global_stream.h:98
The params.
Definition: gemm_global_stream.h:96
Transformer_ Transformer
The transformer.
Definition: gemm_global_stream.h:58
Coord< 3 > threadblock_offset
Threadblock offset.
Definition: gemm_global_stream.h:238
LoadIterator::Index Index
The index.
Definition: gemm_global_stream.h:82
Kind
Definition: matrix_traits.h:357
Definition: matrix_traits.h:357
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK)
Move to the residue portion.
Definition: gemm_global_stream.h:198
LoadIterator::Fragment FetchedFragment
The fragment that is copied from global memory (the load iterator's fragment type).
Definition: gemm_global_stream.h:63
Transformer::OutputFragment TransformedFragment
The fragment that is obtained after the transformation by the transformer.
Definition: gemm_global_stream.h:65
CUTLASS_DEVICE GlobalLoadStream(Params const &_params, SharedStorage &shared_storage, ThreadblockTileRef const &threadblock_tile_ref, Coord< 3 > const bounds, Coord< 3 > const &_threadblock_offset)
Ctor.
Definition: gemm_global_stream.h:160
Defines conversion operations among Fragments of different base type.
CUTLASS_DEVICE Index get_offset_to_residue()
Definition: gemm_global_stream.h:131
LoadIterator::Tile Tile
The tile.
Definition: gemm_global_stream.h:86
CUTLASS_DEVICE void rollback(void)
Rollback to the beginning of the first tile.
Definition: gemm_global_stream.h:208