Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_shared_tile.h
/***************************************************************************************************
 * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 * * Redistributions of source code must retain the above copyright notice, this list of
 * conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright notice, this list of
 * conditions and the following disclaimer in the documentation and/or other materials
 * provided with the distribution.
 * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 * to endorse or promote products derived from this software without specific prior written
 * permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/wmma_matrix.h"
#ifdef CUTLASS_USE_WMMA_API

#include "cutlass/gemm/gemm_operand.h"
#include "cutlass/reshape_tile.h"

namespace cutlass {
namespace gemm {

////////////////////////////////////////////////////////////////////////////////////////////////////
template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename Tile_,
          typename Warps_,
          int kWarpStride_,
          typename Iterations_,
          typename Delta_,
          typename WmmaShape_>
struct WmmaGemmSharedLoadTileATraits {
  /// The operand.
  static GemmOperand::Kind const kOperand = GemmOperand::kA;
  /// The layout.
  static MatrixLayout::Kind const kLayout = kLayout_;
  /// The scalar.
  typedef Scalar_ Scalar;
  /// The pointer.
  typedef Scalar const* Pointer;
  /// The number of scalars per access.
  static int const kAccessSize = 1;
  /// The tile shape.
  typedef Tile_ Tile;
  /// The arrangement of warps in the threadblock.
  typedef Warps_ Warps;
  /// The stride (in scalars) between the starting offsets of consecutive warps.
  static int const kWarpStride = kWarpStride_;
  /// The number of iterations.
  typedef Iterations_ Iterations;
  /// The strides between accesses.
  typedef Delta_ Delta;
  typedef Delta_ ImmediateOffsetStrides;
  /// The shape of a single WMMA operation.
  typedef WmmaShape_ WmmaShape;
  /// The tile is stored in shared memory.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// Computes the offset of a thread's warp within the tile.
  struct ThreadOffset {
    CUTLASS_DEVICE Coord<4> operator()() const {
      // The warp id.
      int const warp = threadIdx.x / kWarpSize;
      // The offset: operand A is partitioned among warps along the W dimension of the warp grid.
      int const offset = warp % Warps::kW * kWarpStride;
      return make_Coord(0, 0, offset, 0);
    }
  };
};
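// Illustrative sketch (not part of the original header): a hypothetical instantiation of the
// operand-A traits above, showing how the template parameters line up. The typedef name, the
// scalar type and every concrete shape are assumptions chosen only for this example; the real
// values are derived by the WMMA GEMM kernel traits.
typedef WmmaGemmSharedLoadTileATraits<
    MatrixLayout::kColumnMajor,  // kLayout_: layout of A in shared memory
    half,                        // Scalar_: element type staged in shared memory
    Shape<1, 16, 64>,            // Tile_: shared-memory tile (D, H, W)
    Shape<1, 2, 4>,              // Warps_: 2 warps along H, 4 warps along W
    16,                          // kWarpStride_: scalars between adjacent warps' starting offsets
    Shape<1, 1, 1>,              // Iterations_: WMMA loads per warp
    Shape<1, 1, 1>,              // Delta_: stride between consecutive loads
    Shape<16, 16, 16> >          // WmmaShape_: shape of one WMMA operation
    ExampleLoadATraits;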

////////////////////////////////////////////////////////////////////////////////////////////////////

template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename Tile_,
          typename Warps_,
          int kWarpStride_,
          typename Iterations_,
          typename Delta_,
          typename WmmaShape_>
struct WmmaGemmSharedLoadTileBTraits {
  /// The operand.
  static GemmOperand::Kind const kOperand = GemmOperand::kB;
  /// The layout.
  static MatrixLayout::Kind const kLayout = kLayout_;
  /// The scalar.
  typedef Scalar_ Scalar;
  /// The pointer.
  typedef Scalar const* Pointer;
  /// The number of scalars per access.
  static int const kAccessSize = 1;
  /// The tile shape.
  typedef Tile_ Tile;
  /// The arrangement of warps in the threadblock.
  typedef Warps_ Warps;
  /// The stride (in scalars) between the starting offsets of consecutive warps.
  static int const kWarpStride = kWarpStride_;
  /// The number of iterations.
  typedef Iterations_ Iterations;
  /// The strides between accesses.
  typedef Delta_ Delta;
  typedef Delta_ ImmediateOffsetStrides;
  /// The shape of a single WMMA operation.
  typedef WmmaShape_ WmmaShape;
  /// The tile is stored in shared memory.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// Computes the offset of a thread's warp within the tile.
  struct ThreadOffset {
    CUTLASS_DEVICE Coord<4> operator()() const {
      // The warp id.
      int const warp = threadIdx.x / kWarpSize;
      // The offset: operand B is partitioned among warps along the H dimension of the warp grid.
      int const offset = warp / Warps::kW * kWarpStride;
      return make_Coord(0, 0, offset, 0);
    }
  };
};
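// Illustrative sketch (not part of the original header): the A and B ThreadOffset functors above
// jointly decompose the linear warp index into a 2D coordinate of the warp arrangement, so that
// each warp loads the A strip for its column and the B strip for its row. The helper below is a
// hypothetical function written only to spell out that decomposition; its name is an assumption.
template <typename Warps_>
CUTLASS_DEVICE void example_warp_coordinates(int& warp_h, int& warp_w) {
  // The linear warp index within the threadblock.
  int const warp = threadIdx.x / kWarpSize;
  // Column of the warp grid: selects the warp's offset into the A tile (warp % Warps::kW).
  warp_w = warp % Warps_::kW;
  // Row of the warp grid: selects the warp's offset into the B tile (warp / Warps::kW).
  warp_h = warp / Warps_::kW;
}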

////////////////////////////////////////////////////////////////////////////////////////////////////

template <MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename OutputTile_,
          typename Warps_,
          typename WmmaShape_,
          int kSkew_ = 0>
struct WmmaGemmSharedStoreTileDTraits {
  /// The operand.
  static GemmOperand::Kind const kOperand = GemmOperand::kC;
  /// The layout.
  static MatrixLayout::Kind const kLayout = kLayout_;
  /// The scalar.
  typedef Scalar_ Scalar;
  /// The number of scalars per access.
  static int const kAccessSize = 1;
  /// The pointer.
  typedef Scalar* Pointer;
  /// The arrangement of warps in the threadblock.
  typedef Warps_ Warps;
  /// The shape of a single WMMA operation.
  typedef WmmaShape_ WmmaShape;
  /// The skew (in scalars) added to each row to avoid shared memory bank conflicts.
  static int const kSkew = kSkew_;
  /// The tile is stored in shared memory.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;
  /// The staging tile: one row of WMMA fragments per warp row, with optional skew per row.
  typedef Shape<1, Warps_::kH * WmmaShape_::kH, OutputTile_::kW + kSkew_> Tile;
  /// The number of WMMA stores each warp performs along the W dimension.
  typedef Shape<1, 1, OutputTile_::kW / Warps::kW / WmmaShape_::kW> Iterations;
  /// The strides between consecutive stores.
  typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> Delta;
  typedef Shape<0, 0, Warps::kW * WmmaShape_::kW, 0> ImmediateOffsetStrides;

  /// Computes the offset of a thread's warp within the tile.
  struct ThreadOffset {
    CUTLASS_DEVICE Coord<4> operator()() const {
      // The warp id.
      int const warp = threadIdx.x / kWarpSize;
      // The warp's starting position in the H dimension.
      int const h = warp / Warps::kW * WmmaShape::kH;
      // The warp's starting position in the W dimension.
      int const w = warp % Warps::kW * WmmaShape::kW;
      // The linear offset into the tile.
      int const offset = h * Tile::kW + w;
      return make_Coord(0, 0, offset, 0);
    }
  };
};
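// Illustrative sketch (not part of the original header): spells out the shapes the store-D traits
// derive for one assumed configuration: a 64x64 output tile, a 2x4 arrangement of warps, 16x16x16
// WMMA fragments and no skew. The typedef name and all concrete values are assumptions chosen only
// for this example.
typedef WmmaGemmSharedStoreTileDTraits<MatrixLayout::kColumnMajor,
                                       float,               // accumulators staged as float
                                       Shape<1, 64, 64>,    // OutputTile_
                                       Shape<1, 2, 4>,      // Warps_: 2 along H, 4 along W
                                       Shape<16, 16, 16> >  // WmmaShape_
    ExampleStoreDTraits;
// The staging tile spans Warps::kH * WmmaShape::kH = 32 rows of OutputTile::kW = 64 columns...
static_assert(ExampleStoreDTraits::Tile::kH == 32 && ExampleStoreDTraits::Tile::kW == 64,
              "unexpected staging tile shape");
// ...and each warp performs OutputTile::kW / Warps::kW / WmmaShape::kW = 1 store per pass, with
// warps spaced Warps::kW * WmmaShape::kW = 64 scalars apart.
static_assert(ExampleStoreDTraits::Iterations::kW == 1 && ExampleStoreDTraits::Delta::kW == 64,
              "unexpected iteration shape");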

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Scalar_, typename Tile_, typename Threads_, int kScalarsPerLds_, int kLdsPerAccess_ = 1>
struct WmmaGemmSharedLoadTileDTraits {
  /// The scalar.
  typedef Scalar_ Scalar;
  /// The pointer.
  typedef Scalar const* Pointer;
  /// The number of scalars loaded per LDS instruction.
  static int const kAccessSize = kScalarsPerLds_;
  /// The tile, reshaped so that each access covers kScalarsPerLds_ contiguous scalars.
  typedef typename ReshapeTile<Tile_, kScalarsPerLds_>::Tile Tile;
  /// The threads, reshaped to cover the tile.
  typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
  /// The strides used to compute the base position of each thread.
  typedef Shape<1, Tile::kW * Tile::kC, Tile::kC> ThreadsStrides;
  /// The tile is stored in shared memory.
  static MemorySpace::Kind const kMemorySpace = MemorySpace::kShared;

  /// The strides (in scalars) between consecutive accesses.
  typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_> Delta;
  typedef Shape<0, Threads::kH * ShapeCount<Tile>::kWc, Threads::kW * kScalarsPerLds_, kScalarsPerLds_>
      ImmediateOffsetStrides;
  /// The number of iterations needed to traverse the tile.
  typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kScalarsPerLds_>
      Iterations;

  /// Computes the offset of a thread within the tile.
  struct ThreadOffset {
    CUTLASS_DEVICE Coord<4> operator()() const {
      // The offset of the thread, computed from its linear index and ThreadsStrides.
      int const offset = ComputeThreadOffsetFromStrides<Threads, ThreadsStrides>::get();
      return make_Coord(0, 0, offset, 0);
    }
  };
};
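// Illustrative sketch (not part of the original header): the nested loop a shared-memory tile-load
// iterator built from traits like the ones above would execute. The function name, the raw pointer
// argument and the flat `fragment` indexing are assumptions made only for this example; the real
// CUTLASS iterators wrap this pattern in dedicated iterator classes.
template <typename Traits_, typename Fragment_>
CUTLASS_DEVICE void example_load_tile(typename Traits_::Pointer pointer, Fragment_& fragment) {
  typedef typename Traits_::Iterations Iterations;
  typedef typename Traits_::ImmediateOffsetStrides Strides;
  // Start from the per-thread (or per-warp) offset computed by the traits.
  typename Traits_::ThreadOffset thread_offset_func;
  pointer += thread_offset_func()[2];
  int idx = 0;
  for (int h = 0; h < Iterations::kH; ++h) {
    for (int w = 0; w < Iterations::kW; ++w) {
      for (int c = 0; c < Iterations::kC; ++c) {
        // Each access reads Traits_::kAccessSize contiguous scalars.
        for (int s = 0; s < Traits_::kAccessSize; ++s) {
          fragment[idx++] = pointer[h * Strides::kH + w * Strides::kW + c * Strides::kC + s];
        }
      }
    }
  }
}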

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace gemm
} // namespace cutlass

#endif // defined CUTLASS_USE_WMMA_API