/// NOTE(review): this listing carries embedded source-listing line numbers (the leading
/// integers) and several listing lines are absent, so some typedefs below are visibly
/// truncated. The comments describe only what can be read here.
///
/// Struct template that collects, as nested typedefs plus one compile-time constant, types
/// derived from GemmConfig_, Accumulator_ and EpilogueFunctor_. Index_ defaults to int.
/// Presumably these are the building blocks of a WMMA GEMM epilogue -- confirm against the
/// complete header.
31 #ifdef CUTLASS_USE_WMMA_API 48 template <
typename GemmConfig_,
typename Accumulator_,
typename EpilogueFunctor_,
typename Index_ =
int>
49 struct WmmaGemmEpilogueTraitsHelper {
/// Scalar type taken from the epilogue functor.
51 typedef typename EpilogueFunctor_::Scalar Scalar;
/// Output tile type taken from the GEMM configuration.
53 typedef typename GemmConfig_::OutputTile OutputTile;
/// Integer quotient of the per-warp accumulator extent kH by the instruction-shape extent kH.
56 static int const kWmmasPerH =
57 GemmConfig_::AccumulatorsPerWarp::kH / GemmConfig_::InstructionShape::kH;
/// Shape whose last listed extent is kWmmasPerH; the first two are 1.
59 typedef Shape<1, 1, kWmmasPerH> Iterations;
/// All-zero shape.
61 typedef Shape<0, 0, 0> Delta;
/// Alias for the epilogue functor type.
63 typedef EpilogueFunctor_ Functor;
/// Traits built from the functor's scalar type and the configuration's output tile,
/// warps and instruction shape.
66 typedef WmmaGemmSharedStoreTileDTraits<
70 typename Functor::Scalar,
72 typename GemmConfig_::OutputTile,
74 typename GemmConfig_::Warps,
76 typename GemmConfig_::InstructionShape>
77 SharedStoreTileTraits;
/// NOTE(review): the next line is the tail of a typedef whose beginning is not present in
/// this listing.
82 typename GemmConfig_::InstructionShape>
/// NOTE(review): the remaining arguments and the name of this TileStoreIterator typedef are
/// not present in this listing; SharedStoreIteratorD is referenced below, so it is presumably
/// the name defined here -- confirm.
86 typedef TileStoreIterator<SharedStoreTileTraits,
87 typename SharedStoreTileTraits::Scalar,
/// Copy instantiated on the Fragment type of SharedStoreIteratorD.
96 typedef Copy<typename SharedStoreIteratorD::Fragment> SharedStoreTransformerD;
/// Traits built from the functor's scalar type, the Tile type of SharedStoreIteratorD, a
/// Shape of <1, number of warps, warp size>, kScalarsPerLdsD, and the size ratio
/// sizeof(Accumulator_) / sizeof(ScalarD).
/// NOTE(review): the closing '>' of the argument list is not present in this listing.
99 typedef WmmaGemmSharedLoadTileDTraits<
101 typename Functor::Scalar,
103 typename SharedStoreIteratorD::Tile,
105 Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
107 GemmConfig_::kScalarsPerLdsD,
109 sizeof(Accumulator_) /
sizeof(
typename GemmConfig_::ScalarD)
111 SharedLoadTileTraits;
/// NOTE(review): the remaining arguments and the name of this TileLoadIterator typedef are
/// not present in this listing; SharedLoadIteratorD is referenced below, so it is presumably
/// the name defined here -- confirm.
114 typedef TileLoadIterator<SharedLoadTileTraits,
115 typename SharedLoadTileTraits::Scalar,
/// SharedLoadStream instantiated on SharedLoadIteratorD.
121 typedef SharedLoadStream<SharedLoadIteratorD> SharedLoadStreamD;
/// Traits for the C-side iterator: const ScalarC, a Shape of <1, number of warps, warp size>,
/// and kScalarsPerLdgC.
/// NOTE(review): the beginning of the argument that ends in "OutputTile::kW>" is not present
/// in this listing.
124 typedef WmmaGemmGlobalIteratorCdTraits<
126 typename GemmConfig_::ScalarC
const,
130 GemmConfig_::OutputTile::kW>,
132 Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
134 GemmConfig_::kScalarsPerLdgC>
135 GlobalLoadTileTraits;
/// Iterator over C built from GlobalLoadTileTraits and Index_.
138 typedef WmmaGemmGlobalIteratorCd<GlobalLoadTileTraits, Index_> GlobalLoadIteratorC;
/// Copy instantiated on the Fragment type of GlobalLoadIteratorC.
140 typedef Copy<typename GlobalLoadIteratorC::Fragment> GlobalTransformerC;
/// Traits for the D-side iterator: ScalarD, a Shape of <1, number of warps, warp size>,
/// and kScalarsPerStgD.
/// NOTE(review): the beginning of the argument that ends in "OutputTile::kW>" is not present
/// in this listing.
143 typedef WmmaGemmGlobalIteratorCdTraits<
145 typename GemmConfig_::ScalarD,
149 GemmConfig_::OutputTile::kW>,
151 Shape<1, ShapeCount<typename GemmConfig_::Warps>::kCount, GemmConfig_::kWarpSize>,
153 GemmConfig_::kScalarsPerStgD>
154 GlobalStoreTileTraits;
/// Iterator over D built from GlobalStoreTileTraits and Index_.
157 typedef WmmaGemmGlobalIteratorCd<GlobalStoreTileTraits, Index_> GlobalStoreIteratorD;
/// Copy instantiated on the Fragment type of GlobalStoreIteratorD.
159 typedef Copy<typename GlobalStoreIteratorD::Fragment> GlobalTransformerD;
167 #endif // defined CUTLASS_USE_WMMA_API
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:41
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
Implements the BLAS linear scaling function alpha*AB + beta*C.
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Definition: load_store.h:48
Definition: tile_iterator.h:65
Definition: matrix_traits.h:357
Defines a type for restructuring a tile.
Definition: matrix_traits.h:159
Defines tile iterator traits for loading thread block-level tile from global memory.
static int const kCount
The number of elements in the 4D space.
Definition: shape.h:91
Implements efficient loading of the thread block-level tile from global memory and storing to shared memory.
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEMM pipeline.
Defines conversion operations among Fragments of different base type.
Defines iterator traits for efficiently loading and storing fragment to and from shared memory, specialized for WMMA-based GEMM.