32 #ifdef CUTLASS_USE_WMMA_API 50 struct WmmaGemmSharedLoadTileATraits {
56 typedef Scalar_ Scalar;
58 typedef Scalar
const* Pointer;
60 static int const kAccessSize = 1;
66 static int const kWarpStride = kWarpStride_;
68 typedef Iterations_ Iterations;
72 typedef Delta_ ImmediateOffsetStrides;
74 typedef WmmaShape_ WmmaShape;
80 Coord<4> operator()()
const {
82 int const warp = threadIdx.x / kWarpSize;
84 int const offset = warp % Warps::kW * kWarpStride;
100 struct WmmaGemmSharedLoadTileBTraits {
106 typedef Scalar_ Scalar;
108 typedef Scalar
const* Pointer;
110 static int const kAccessSize = 1;
114 typedef Warps_ Warps;
116 static int const kWarpStride = kWarpStride_;
118 typedef Iterations_ Iterations;
120 typedef Delta_ Delta;
122 typedef Delta_ ImmediateOffsetStrides;
124 typedef WmmaShape_ WmmaShape;
128 struct ThreadOffset {
130 Coord<4> operator()()
const {
132 int const warp = threadIdx.x / kWarpSize;
134 int const offset = warp / Warps::kW * kWarpStride;
144 typename OutputTile_,
148 struct WmmaGemmSharedStoreTileDTraits {
154 typedef Scalar_ Scalar;
156 static int const kAccessSize = 1;
158 typedef Scalar* Pointer;
160 typedef Warps_ Warps;
162 typedef WmmaShape_ WmmaShape;
164 static int const kSkew = kSkew_;
168 typedef Shape<1, Warps_::kH * WmmaShape_::kH, OutputTile_::kW + kSkew_> Tile;
170 typedef Shape<1, 1, OutputTile_::kW / Warps::kW / WmmaShape_::kW> Iterations;
172 typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> Delta;
174 typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> ImmediateOffsetStrides;
178 struct ThreadOffset {
180 Coord<4> operator()()
const {
182 int const warp = threadIdx.x / kWarpSize;
184 int const h = warp / Warps::kW * WmmaShape::kH;
186 int const w = warp % Warps::kW * WmmaShape::kW;
188 int const offset = h * Tile::kW + w;
196 template <
typename Scalar_,
typename Tile_,
typename Threads_,
int kScalarsPerLds_,
int kLdsPerAccess_ = 1>
197 struct WmmaGemmSharedLoadTileDTraits {
199 typedef Scalar_ Scalar;
201 typedef Scalar
const* Pointer;
203 static int const kAccessSize = kScalarsPerLds_;
209 typedef Shape<1, Tile::kW * Tile::kC, Tile::kC> ThreadsStrides;
214 typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_> Delta;
216 typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_, kScalarsPerLds_>
217 ImmediateOffsetStrides;
219 typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kScalarsPerLds_>
224 struct ThreadOffset {
226 Coord<4> operator()()
const {
239 #endif // defined CUTLASS_USE_WMMA_API static CUTLASS_DEVICE int get()
Definition: shape.h:214
Abstractions for loading and storing matrices using the CUDA WMMA API.
Definition: load_store.h:41
Tile_ Tile
Definition: reshape_tile.h:59
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Definition: matrix_traits.h:357
Kind
Definition: load_store.h:39
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Definition: matrix_traits.h:357
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Kind
Definition: matrix_traits.h:357
Definition: matrix_traits.h:357
Threads_ Threads
Definition: gemm_global_tile.h:54