43 typename ThreadBlockTile_,
int Threads,
int ScalarsPerInst,
typename Index_ = int,
44 typename DestinationSkew_ = Shape<0, 0, 0, 0> >
68 typedef typename ShapeDiv<DestinationSkew_,
Shape<ScalarsPerInst, ScalarsPerInst, ScalarsPerInst,
81 typedef TileTraitsDefault<VectorizedTile, kThreads>
TileTraits;
97 template <
typename Traits_>
100 typename Traits_::TileTraits,
101 TileLoadIterator<typename Traits_::TileTraits, typename Traits_::Scalar,
102 Traits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
103 : IteratorAdvance::kW,
104 MemorySpace::kGlobal, typename Traits_::Index>,
105 TileStoreIterator<typename Traits_::TileTraits, typename Traits_::Scalar,
106 Traits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
107 : IteratorAdvance::kW,
108 MemorySpace::kShared, typename Traits_::Index, typename Traits_::Scalar,
109 IteratorFragment::kScalar, typename Traits_::DestinationSkew> > {
114 typedef typename Traits::FragmentStream
Base;
145 typedef typename Traits::Index
Index;
152 typedef typename Traits::Scalar
const *
Pointer;
161 template <
typename GemmDesc_>
163 typename Traits::Scalar
const *pointer,
Index ldm) {
164 return this->load_params.initialize(pointer, ldm * Traits::MultiplicandTraits::Shape::kH, ldm,
165 Traits::kAccessSize);
212 Base::initialize_predicates(
nv_std::conditional< kKstrided, Shape< 1, ThreadBlockTile::kD, GetExtent< Usage, ThreadBlockTile >::kExtent >, Shape< 1, GetExtent< Usage, ThreadBlockTile >::kExtent, ThreadBlockTile::kD > >::type Shape
Map the ThreadBlockShape onto (kH, kW) dimensions for the A and B operands.
Definition: gemm_operand.h:86
static bool const kKstrided
Definition: gemm_operand.h:81
Scalar_ Scalar
Scalar data type.
Definition: gemm_fragment_stream.h:50
GemmMultiplicandTraits< ThreadBlockTile, kUsage, kLayout > MultiplicandTraits
Traits of multiplicand.
Definition: gemm_fragment_stream.h:72
static int const kAccessSize
Scalars per instruction.
Definition: gemm_fragment_stream.h:62
Definition: load_store.h:42
Base::StoreIterator StoreIterator
Defines the store iterator.
Definition: gemm_fragment_stream.h:127
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc, typename Traits::Scalar const *pointer, Index ldm)
Initializes parameters.
Definition: gemm_fragment_stream.h:162
Defines structural properties of complete GEMM computation.
Traits::FragmentStream Base
Base class.
Definition: gemm_fragment_stream.h:114
An abstraction for implementing a stream loading a tile and storing a tile using a pair of tile itera...
Definition: load_store.h:43
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
ReshapeTile< ScalarTile, kAccessSize >::Tile VectorizedTile
Reshape for vectorized access.
Definition: gemm_fragment_stream.h:78
Traits::Index Index
Index type.
Definition: gemm_fragment_stream.h:145
FragmentStream< TileTraits, TileLoadIterator< TileTraits, Scalar, MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kGlobal, Index >, TileStoreIterator< TileTraits, Scalar, MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kShared, Index, Scalar, IteratorFragment::kScalar, DestinationSkew > > FragmentStream
Define the tile stream.
Definition: gemm_fragment_stream.h:93
Traits_ Traits
Traits.
Definition: gemm_fragment_stream.h:111
cutlass::FragmentStream< Traits_::TileTraits, TileLoadIterator< Traits_::TileTraits, Traits_::Scalar, Traits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kGlobal, Traits_::Index >, TileStoreIterator< Traits_::TileTraits, Traits_::Scalar, Traits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kShared, Traits_::Index, Traits_::Scalar, IteratorFragment::kScalar, Traits_::DestinationSkew > >::fetch Fragment fetch
Fragment fetched by load iterator.
Definition: fragment_stream.h:135
Definition: tile_iterator.h:97
CUTLASS_DEVICE GemmFragmentStream()
Definition: gemm_fragment_stream.h:181
TileTraitsDefault< VectorizedTile, kThreads > TileTraits
Define structure of stripmined tile.
Definition: gemm_fragment_stream.h:81
MultiplicandTraits::Shape ScalarTile
Scalar tile shape.
Definition: gemm_fragment_stream.h:75
static CUTLASS_DEVICE void shared_store_fence()
The memory fence for shared stores.
Definition: gemm_fragment_stream.h:174
Defines a FragmentStream by mapping GEMM dimensions onto contiguous and strided dimensions.
Definition: gemm_fragment_stream.h:45
CUTLASS_DEVICE GemmFragmentStream(Params const &params, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Constructor - bounds and block offset are aligned to GEMM coordinates (K, N, M)
Definition: gemm_fragment_stream.h:185
Base::Fragment Fragment
Loaded fragment type.
Definition: gemm_fragment_stream.h:133
GEMM Fragment Stream.
Definition: gemm_fragment_stream.h:98
Traits::Scalar const * Pointer
The pointer.
Definition: gemm_fragment_stream.h:152
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:308
CUTLASS_DEVICE void commit()
Commits the fragment.
Definition: gemm_fragment_stream.h:199
Base::Storage Storage
Destination storage.
Definition: gemm_fragment_stream.h:139
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Manages a pair of iterators to stream data from global memory to shared.
Definition: fragment_stream.h:50
Definition: gemm_operand.h:66
static MatrixLayout::Kind const kLayout
Layout of the operand.
Definition: gemm_fragment_stream.h:53
Traits::Scalar Scalar
Scalar type.
Definition: gemm_fragment_stream.h:121
Parameters object.
Definition: gemm_fragment_stream.h:155
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:37
Index_ Index
Index type.
Definition: gemm_fragment_stream.h:65
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:63
Base::Convert Convert
Converts between tiles.
Definition: gemm_fragment_stream.h:130
Definition: gemm_operand.h:94
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:48
static int const kThreads
Number of threads.
Definition: gemm_fragment_stream.h:59
Kind
Definition: matrix_traits.h:36
Base::StoreFragment StoreFragment
Stored fragment type.
Definition: gemm_fragment_stream.h:136
Base::LoadIterator LoadIterator
Defines the load iterator.
Definition: gemm_fragment_stream.h:124
static GemmOperand::Kind const kUsage
Indicates identity of multiplicand.
Definition: gemm_fragment_stream.h:47
Tile_ Tile
Definition: tile.h:43
Definition: tile_iterator.h:97
ShapeDiv< DestinationSkew_, Shape< ScalarsPerInst, ScalarsPerInst, ScalarsPerInst, 1 > >::Shape DestinationSkew
Skew added to shared memory tile.
Definition: gemm_fragment_stream.h:69
Kind
Definition: matrix_traits.h:43
CUTLASS_DEVICE void residue(Coord< 3 > const &bounds, Coord< 3 > const &block_offset)
TODO - Recomputes predicates and clears fetch registers.
Definition: gemm_fragment_stream.h:203
ThreadBlockTile_ ThreadBlockTile
Shape of the thread block tile (K, N, M)
Definition: gemm_fragment_stream.h:56
Defines properties of matrices used to denote layout and operands to GEMM kernels.
CUTLASS_DEVICE void initialize_predicates(Coord< 3 > const &bounds, Coord< 3 > const &block_offset)
Recomputes predicates aligned to GEMM coordinates (K, N, M)
Definition: gemm_fragment_stream.h:211
Definition: tile_iterator.h:102
CUTLASS_DEVICE void load()
Loads the fragment.
Definition: gemm_fragment_stream.h:195
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:556