47 typename Accumulators_,
49 typename GlobalLoadIteratorC_,
51 typename GlobalTransformerC_,
53 typename GlobalTransformerD_,
55 typename GlobalStoreIteratorD_,
57 typename SharedStoreIteratorD_,
59 typename SharedStoreTransformerD_,
61 typename SharedLoadStreamD_,
69 typename Index_ =
int>
104 static_assert(Iterations::kD == 1 && Iterations::kC == 1,
"Unsupported 3D/4D shapes");
107 typedef typename Functor::Scalar
Scalar;
109 typedef typename GlobalLoadIteratorC::Scalar
ScalarC;
111 typedef typename GlobalStoreIteratorD::Scalar
ScalarD;
137 template <
typename GemmDesc_>
141 int error_code =
functor.initialize(desc);
147 this->stride_h = desc.D.leading_dim() * Delta::kH;
151 error_code =
iterator_c.initialize(desc.C.data(),
152 desc.C.leading_dim(),
153 desc.C.leading_dim(),
154 desc.problem_size[1],
165 error_code =
iterator_d.initialize(desc.D.data(),
166 desc.D.leading_dim(),
167 desc.D.leading_dim(),
168 desc.problem_size[1],
181 typename SharedStoreIteratorD::SharedStorage
store;
183 typename SharedLoadStreamD::SharedStorage
load;
202 template <
typename GemmConfig_,
typename EpilogueFunctor_,
typename Index_ =
int>
205 typedef typename EpilogueFunctor_::Scalar
Scalar;
211 GemmConfig_::MultiplyAdd::AccumulatorsPerThread::kH /
212 GemmConfig_::kAccumulatorsPerLdsB,
213 GemmConfig_::kAccumulatorsPerLdsB>
217 GemmConfig_::kAccumulatorsPerLdsB*(
218 GemmConfig_::Warps::kH* GemmConfig_::MultiplyAdd::ThreadsPerWarp::kH - 1),
230 typename Functor::ScalarAccum,
232 typename GemmConfig_::OutputTile,
234 typename GemmConfig_::Warps,
236 typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
238 GemmConfig_::kScalarsPerStsD,
242 128 /
sizeof(
typename GemmConfig_::ScalarD) / GemmConfig_::kScalarsPerStsD / 2 *
243 GemmConfig_::kScalarsPerStsD>
262 typename Functor::ScalarAccum,
264 typename GemmConfig_::OutputTile,
266 typename GemmConfig_::Warps,
268 typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
272 GemmConfig_::kScalarsPerLdsD,
289 typename GemmConfig_::ScalarC
const,
293 GemmConfig_::OutputTile::kW>,
299 GemmConfig_::kScalarsPerLdgC>
310 typename GemmConfig_::ScalarD,
314 GemmConfig_::OutputTile::kW>,
320 GemmConfig_::kScalarsPerStgD>
333 typename GemmConfig_,
335 typename EpilogueFunctor_,
337 typename Index_ = int,
342 typename GemmConfig_::OutputTile,
344 typename GemmConfig_::Accumulators,
346 typename Helper_::GlobalLoadIteratorC,
348 typename Helper_::GlobalTransformerC,
350 typename Helper_::GlobalTransformerD,
352 typename Helper_::GlobalStoreIteratorD,
354 typename Helper_::SharedStoreIteratorD,
356 typename Helper_::SharedStoreTransformerD,
358 typename Helper_::SharedLoadStreamD,
360 typename Helper_::Iterations,
362 typename Helper_::Delta,
Definition: gemm_global_tile.h:120
SharedStoreTransformerD_ SharedStoreTransformerD
The shared store transformer for D.
Definition: gemm_epilogue_traits.h:88
Iterations_ Iterations
typedef typename GemmConfig::EpilogueIterations Iterations;
Definition: gemm_epilogue_traits.h:92
CUTLASS_DEVICE ScalarD * data()
Definition: gemm_epilogue_traits.h:196
Definition: load_store.h:41
GemmGlobalTileCdTraits< typename GemmConfig_::ScalarC const, Shape< 1, GemmConfig_::OutputTile::kH/ShapeCount< Iterations >::kCount, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, Iterations::kW, GemmConfig_::kScalarsPerLdgC > GlobalLoadTileTraits
The traits class to build the iterator to load data from global memory for C^N.
Definition: gemm_epilogue_traits.h:300
GlobalLoadIteratorC_ GlobalLoadIteratorC
The iterator for C in global memory.
Definition: gemm_epilogue_traits.h:78
Definition: gemm_epilogue_traits.h:203
Functor::Params functor
The functor params.
Definition: gemm_epilogue_traits.h:134
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
SharedLoadStreamD::SharedStorage load
Definition: gemm_epilogue_traits.h:183
long long LongIndex
The long index.
Definition: gemm_epilogue_traits.h:101
Implements the BLAS linear scaling function alpha*AB + beta*C.
The shared memory storage to exchange data.
Definition: gemm_epilogue_traits.h:179
EpilogueFunctor_::Scalar Scalar
The scalar.
Definition: gemm_epilogue_traits.h:205
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
GlobalTransformerC_ GlobalTransformerC
The transformer for C.
Definition: gemm_epilogue_traits.h:80
GemmGlobalIteratorCd< GlobalLoadTileTraits, Index_ > GlobalLoadIteratorC
The iterator to load C.
Definition: gemm_epilogue_traits.h:303
GlobalTransformerD_ GlobalTransformerD
The transformer for D.
Definition: gemm_epilogue_traits.h:82
Definition: tile_iterator.h:65
TileStoreIterator< SharedStoreTileTraits, typename SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorD
The iterator to store D to shared memory.
Definition: gemm_epilogue_traits.h:251
SharedLoadStream< SharedLoadIteratorD > SharedLoadStreamD
The stream to load D.
Definition: gemm_epilogue_traits.h:284
GlobalStoreIteratorD::Scalar ScalarD
The scalar for D.
Definition: gemm_epilogue_traits.h:111
LongIndex batch_stride_C
Batch stride for C matrix.
Definition: gemm_epilogue_traits.h:121
GemmGlobalIteratorCd< GlobalStoreTileTraits, Index_ > GlobalStoreIteratorD
The iterator to store D.
Definition: gemm_epilogue_traits.h:324
GlobalStoreIteratorD::Params iterator_d
The params for the D global iterator.
Definition: gemm_epilogue_traits.h:124
SharedLoadStreamD::Params shared_load_stream_d
The params for the D shared load stream.
Definition: gemm_epilogue_traits.h:132
Copy< typename SharedStoreIteratorD::Fragment > SharedStoreTransformerD
The shared store transformer for D.
Definition: gemm_epilogue_traits.h:254
Shape< 1, GemmConfig_::MultiplyAdd::AccumulatorsPerThread::kH/GemmConfig_::kAccumulatorsPerLdsB, GemmConfig_::kAccumulatorsPerLdsB > Iterations
The number of iterations in the epilogue.
Definition: gemm_epilogue_traits.h:214
GemmGlobalTileCdTraits< typename GemmConfig_::ScalarD, Shape< 1, GemmConfig_::OutputTile::kH/ShapeCount< Iterations >::kCount, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, Iterations::kW, GemmConfig_::kScalarsPerStgD > GlobalStoreTileTraits
The traits class to build the iterator to store data to global memory for D^N.
Definition: gemm_epilogue_traits.h:321
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
Index stride_h
The strides for H and W in the different iterations of the epilogue.
Definition: gemm_epilogue_traits.h:116
Functor_ Functor
The functor in charge of the math.
Definition: gemm_epilogue_traits.h:97
Definition: gemm_shared_stream.h:45
Accumulators_ Accumulators
Definition: gemm_epilogue_traits.h:76
static int const kSkew
The skew.
Definition: gemm_shared_tile.h:284
Defines a type for restructuring a tile.
GlobalLoadIteratorC::Params iterator_c
The params for the C iterator.
Definition: gemm_epilogue_traits.h:118
LongIndex batch_stride_D
Batch stride for D matrix.
Definition: gemm_epilogue_traits.h:127
SharedStoreIteratorD::SharedStorage store
Definition: gemm_epilogue_traits.h:181
Index stride_w
Definition: gemm_epilogue_traits.h:116
SharedLoadStreamD_ SharedLoadStreamD
The stream to load D from shared memory.
Definition: gemm_epilogue_traits.h:90
OutputTile_ OutputTile
The output tile.
Definition: gemm_epilogue_traits.h:73
Definition: gemm_shared_tile.h:339
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: gemm_epilogue_traits.h:340
EpilogueFunctor_ Functor
The functor to do the math in the epilogue.
Definition: gemm_epilogue_traits.h:222
TileLoadIterator< SharedLoadTileTraits, typename SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorD
The iterator to load D from shared memory.
Definition: gemm_epilogue_traits.h:282
GemmConfig_::OutputTile OutputTile
The output tile.
Definition: gemm_epilogue_traits.h:207
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Setup the params.
Definition: gemm_epilogue_traits.h:138
StreamSharedStorage shared_stream
Definition: gemm_epilogue_traits.h:189
Index_ Index
The index.
Definition: gemm_epilogue_traits.h:99
Definition: gemm_epilogue_traits.h:70
GemmSharedLoadTileDTraits< typename Functor::ScalarAccum, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, GemmConfig_::OutputTile::kH/ShapeCount< Iterations >::kCount, GemmConfig_::kScalarsPerLdsD, SharedStoreTileTraits::kSkew > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for D.
Definition: gemm_epilogue_traits.h:275
Definition: gemm_global_tile.h:366
static int const kW
The width of the cube.
Definition: shape.h:70
Delta_ Delta
The iterations strides.
Definition: gemm_epilogue_traits.h:94
GlobalStoreIteratorD_ GlobalStoreIteratorD
The iterator for D in global memory.
Definition: gemm_epilogue_traits.h:84
Copy< typename GlobalStoreIteratorD::Fragment > GlobalTransformerD
The transformer for D.
Definition: gemm_epilogue_traits.h:326
GlobalLoadIteratorC::Scalar ScalarC
The scalar for C.
Definition: gemm_epilogue_traits.h:109
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
The params.
Definition: gemm_epilogue_traits.h:114
The shared memory to swizzle the data in the epilogue.
Definition: gemm_epilogue_traits.h:187
Copy< typename GlobalLoadIteratorC::Fragment > GlobalTransformerC
The transformer for C.
Definition: gemm_epilogue_traits.h:305
SharedStoreIteratorD::Params shared_store_iterator_d
The params for the D shared store iterator.
Definition: gemm_epilogue_traits.h:130
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:272
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
Compute derived counts of a Layout Concept based class.
Definition: shape.h:79
Defines conversion operations among Fragments of different base type.
Functor::Scalar Scalar
We do not support 3D or 4D shapes.
Definition: gemm_epilogue_traits.h:104
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:341
SharedStoreIteratorD_ SharedStoreIteratorD
The iterator to store D in shared memory.
Definition: gemm_epilogue_traits.h:86
GemmSharedStoreTileDTraits< typename Functor::ScalarAccum, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, GemmConfig_::kScalarsPerStsD, 128/sizeof(typename GemmConfig_::ScalarD)/GemmConfig_::kScalarsPerStsD/2 *GemmConfig_::kScalarsPerStsD > SharedStoreTileTraits
The traits class to build the iterator to store to shared memory for D.
Definition: gemm_epilogue_traits.h:244
Definition: gemm_shared_tile.h:270
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:841