49 template <
typename StreamA_,
typename StreamB_,
bool kResidueInProlog_>
75 Params(
typename StreamA::Params
const &_params_A,
typename StreamB::Params
const &_params_B)
80 typedef typename StreamA::Index
Index;
84 typename StreamB::ThreadblockTileStorage>
118 threadblock_tile_ref.first,
123 threadblock_tile_ref.second,
136 stream_a.add_batch_offset(batch_id);
137 stream_b.add_batch_offset(batch_id);
161 if (kResidueInProlog_) {
162 stream_a.move_to_residue(k, kTileK);
163 stream_b.move_to_residue(k, kTileK);
164 }
else if (k < kTileK) {
171 if (kResidueInProlog_ && kRollback) {
179 template <
typename StreamA_,
typename StreamB_>
202 typename StreamB::TensorRef >
225 CUTLASS_DEVICE
void copy(
int step) {
238 typename StreamA::TransformedFragment
const &
fragment_a(
int step)
const {
244 typename StreamB::TransformedFragment
const &
fragment_b(
int step)
const {
CUTLASS_HOST_DEVICE Params(typename StreamA::Params const &_params_A, typename StreamB::Params const &_params_B)
Constructs a global load stream pair Params object.
Definition: gemm_stream_pair.h:75
CUTLASS_DEVICE GlobalLoadStreamPair & operator+=(Coord< 3 > const offset)
Definition: gemm_stream_pair.h:128
StreamA_ StreamA
Stream for A multiplicand.
Definition: gemm_stream_pair.h:186
CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK)
Move to residue.
Definition: gemm_stream_pair.h:160
StreamB::SharedStorage stream_b
Definition: gemm_stream_pair.h:93
CUTLASS_DEVICE GlobalLoadStreamPair & add_batch_offset(int batch_id)
Definition: gemm_stream_pair.h:135
Definition: zip_tensor_ref.h:38
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
StreamA::Params stream_a
Definition: gemm_stream_pair.h:194
CUTLASS_DEVICE GlobalLoadStreamPair(Params const &params, SharedStorage &shared_storage, ThreadblockTileRef const &threadblock_tile_ref, Coord< 3 > const bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Ctor.
Definition: gemm_stream_pair.h:111
Defines functors for mapping blockIdx to partitions of the GEMM computation.
Defines a structure containing shared storage for each pair.
Definition: gemm_stream_pair.h:91
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
ZipTileAllocation< typename StreamA::ThreadblockTileStorage, typename StreamB::ThreadblockTileStorage > ThreadblockTileStorage
Shared memory allocation for threadblock-scoped GEMM tile.
Definition: gemm_stream_pair.h:85
CUTLASS_DEVICE void residue(Index k, bool skip_clear=false)
Execute the residue code.
Definition: gemm_stream_pair.h:154
ThreadblockTileStorage::TensorRef ThreadblockTileRef
ZipTensorRef to threadblock tiles.
Definition: gemm_stream_pair.h:88
CUTLASS_DEVICE SharedStreamPair(Params const &params, ThreadblockTileRef const &threadblock_tile_ref)
Construct with the composable structure.
Definition: gemm_stream_pair.h:220
StreamB_ StreamB
Stream for B multiplicand.
Definition: gemm_stream_pair.h:59
Collect the global load streams for multiplicands.
Definition: gemm_stream_pair.h:180
CUTLASS_DEVICE StreamB::TransformedFragment const & fragment_b(int step) const
The fragment B.
Definition: gemm_stream_pair.h:244
StreamB::Params stream_b
Parameters object for StreamB.
Definition: gemm_stream_pair.h:67
CUTLASS_DEVICE void rollback(bool kRollback)
Rollback to beginning of first tile.
Definition: gemm_stream_pair.h:170
StreamA stream_a
The stream for A.
Definition: gemm_stream_pair.h:210
CUTLASS_DEVICE StreamA::TransformedFragment const & fragment_a(int step) const
The fragment A.
Definition: gemm_stream_pair.h:238
Collect the global load streams for multiplicands.
Definition: gemm_stream_pair.h:50
CUTLASS_DEVICE void copy(int step)
Trigger the copies from shared memory to registers.
Definition: gemm_stream_pair.h:225
Defines a fragment based on a Shape<> template.
Parameters object.
Definition: gemm_stream_pair.h:62
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Defines abstractions for efficiently clearing accumulator tiles.
StreamA_ StreamA
Stream for A multiplicand.
Definition: gemm_stream_pair.h:56
StreamA stream_a
Stream for A multiplicand.
Definition: gemm_stream_pair.h:101
ZipTensorRef< typename StreamA::TensorRef, typename StreamB::TensorRef > ThreadblockTileRef
ZipTensorRef to threadblock tiles.
Definition: gemm_stream_pair.h:203
StreamA::Index Index
Assumes the A stream defines the index type.
Definition: gemm_stream_pair.h:80
Manages a pair of tile allocations as if they are one allocation.
Definition: tile_allocation.h:125
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Defines properties of GEMM computation that impose some constraints on caller.
CUTLASS_HOST_DEVICE Params()
Default constructor.
Definition: gemm_stream_pair.h:71
StreamB_ StreamB
Stream for B multiplicand.
Definition: gemm_stream_pair.h:189
StreamA::SharedStorage stream_a
Definition: gemm_stream_pair.h:92
CUTLASS_DEVICE void commit(int step)
Commit the data.
Definition: gemm_stream_pair.h:231
CUTLASS_DEVICE void commit()
Commit the data.
Definition: gemm_stream_pair.h:148
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
StreamB stream_b
The stream for B.
Definition: gemm_stream_pair.h:213
CUTLASS_DEVICE void inc_stage()
Increment the stage.
Definition: gemm_stream_pair.h:249
Parameters object passed to load iterators.
Definition: gemm_stream_pair.h:192
StreamB::Params stream_b
Definition: gemm_stream_pair.h:197
Defines properties of matrices used to denote layout and operands to GEMM kernels.
CUTLASS_DEVICE void copy()
Trigger the copies from shared memory to registers.
Definition: gemm_stream_pair.h:142
StreamA::Params stream_a
Parameters object for StreamA.
Definition: gemm_stream_pair.h:64
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
StreamB stream_b
Stream for B multiplicand.
Definition: gemm_stream_pair.h:104
Defines conversion operations among Fragments of different base type.