Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_fragment_stream.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
31 #include <cutlass/matrix_traits.h>
32 
35 
36 namespace cutlass {
37 namespace gemm {
38 
40 
42 template <GemmOperand::Kind Usage, typename Scalar_, MatrixLayout::Kind Layout,
43  typename ThreadBlockTile_, int Threads, int ScalarsPerInst, typename Index_ = int,
44  typename DestinationSkew_ = Shape<0, 0, 0, 0> >
47  static GemmOperand::Kind const kUsage = Usage;
48 
50  typedef Scalar_ Scalar;
51 
53  static MatrixLayout::Kind const kLayout = Layout;
54 
56  typedef ThreadBlockTile_ ThreadBlockTile;
57 
59  static int const kThreads = Threads;
60 
62  static int const kAccessSize = ScalarsPerInst;
63 
65  typedef Index_ Index;
66 
68  typedef typename ShapeDiv<DestinationSkew_, Shape<ScalarsPerInst, ScalarsPerInst, ScalarsPerInst,
70 
73 
76 
79 
81  typedef TileTraitsDefault<VectorizedTile, kThreads> TileTraits;
82 
84  typedef FragmentStream<
85  TileTraits,
94 };
95 
97 template <typename Traits_>
99  : public FragmentStream<
100  typename Traits_::TileTraits,
101  TileLoadIterator<typename Traits_::TileTraits, typename Traits_::Scalar,
102  Traits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
103  : IteratorAdvance::kW,
104  MemorySpace::kGlobal, typename Traits_::Index>,
105  TileStoreIterator<typename Traits_::TileTraits, typename Traits_::Scalar,
106  Traits_::MultiplicandTraits::kKstrided ? IteratorAdvance::kH
107  : IteratorAdvance::kW,
108  MemorySpace::kShared, typename Traits_::Index, typename Traits_::Scalar,
109  IteratorFragment::kScalar, typename Traits_::DestinationSkew> > {
111  typedef Traits_ Traits;
112 
114  typedef typename Traits::FragmentStream Base;
115 
116  //
117  // FragmentStream concept
118  //
119 
121  typedef typename Traits::Scalar Scalar;
122 
124  typedef typename Base::LoadIterator LoadIterator;
125 
127  typedef typename Base::StoreIterator StoreIterator;
128 
130  typedef typename Base::Convert Convert;
131 
133  typedef typename Base::Fragment Fragment;
134 
136  typedef typename Base::StoreFragment StoreFragment;
137 
139  typedef typename Base::Storage Storage;
140 
141  // Parameters type
142  // typedef typename Base::Params BaseParams;
143 
145  typedef typename Traits::Index Index;
146 
147  //
148  // Nested class definitions
149  //
150 
152  typedef typename Traits::Scalar const *Pointer;
153 
155  struct Params : public Base::Params {
156  //
157  // Methods
158  //
159 
161  template <typename GemmDesc_>
162  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc,
163  typename Traits::Scalar const *pointer, Index ldm) {
164  return this->load_params.initialize(pointer, ldm * Traits::MultiplicandTraits::Shape::kH, ldm,
165  Traits::kAccessSize);
166  }
167  };
168 
169  //
170  // Static member functions
171  //
172 
174  static CUTLASS_DEVICE void shared_store_fence() { Base::shared_store_fence(); }
175 
176  //
177  // Methods
178  //
179 
180  CUTLASS_DEVICE
182 
184  CUTLASS_DEVICE
185  GemmFragmentStream(Params const &params, Coord<3> const &bounds,
186  Coord<3> const &block_offset = make_Coord(0, 0, 0))
187  : Base(params, ProjectOperand<Traits::kUsage, Traits::MultiplicandTraits::kKstrided>::project(
188  bounds) +
189  make_Coord(1, 0, 0),
190  ProjectOperand<Traits::kUsage, Traits::MultiplicandTraits::kKstrided>::project(
191  block_offset)) {}
192 
194  CUTLASS_DEVICE
195  void load() { Base::load(); }
196 
198  CUTLASS_DEVICE
199  void commit() { Base::commit(); }
200 
202  CUTLASS_DEVICE
203  void residue(Coord<3> const &bounds, Coord<3> const &block_offset) {
204  this->initialize_predicates(bounds, block_offset);
205 
206  this->fetch.clear();
207  }
208 
210  CUTLASS_DEVICE
211  void initialize_predicates(Coord<3> const &bounds, Coord<3> const &block_offset) {
212  Base::initialize_predicates(
214  make_Coord(1, 0, 0),
216  block_offset));
217  }
218 };
219 
221 }
222 }
nv_std::conditional< kKstrided, Shape< 1, ThreadBlockTile::kD, GetExtent< Usage, ThreadBlockTile >::kExtent >, Shape< 1, GetExtent< Usage, ThreadBlockTile >::kExtent, ThreadBlockTile::kD > >::type Shape
Maps the ThreadBlockShape onto the (kH, kW) dimensions for the A and B operands.
Definition: gemm_operand.h:86
static bool const kKstrided
Definition: gemm_operand.h:81
Scalar_ Scalar
Scalar data type.
Definition: gemm_fragment_stream.h:50
GemmMultiplicandTraits< ThreadBlockTile, kUsage, kLayout > MultiplicandTraits
Traits of multiplicand.
Definition: gemm_fragment_stream.h:72
static int const kAccessSize
Scalars per instruction.
Definition: gemm_fragment_stream.h:62
Definition: load_store.h:42
Definition: convert.h:34
Base::StoreIterator StoreIterator
Defines the store iterator.
Definition: gemm_fragment_stream.h:127
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc, typename Traits::Scalar const *pointer, Index ldm)
Initializes parameters.
Definition: gemm_fragment_stream.h:162
Defines structural properties of complete GEMM computation.
Traits::FragmentStream Base
Base class.
Definition: gemm_fragment_stream.h:114
An abstraction for implementing a stream loading a tile and storing a tile using a pair of tile itera...
Definition: load_store.h:43
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:241
ReshapeTile< ScalarTile, kAccessSize >::Tile VectorizedTile
Reshape for vectorized access.
Definition: gemm_fragment_stream.h:78
Traits::Index Index
Index type.
Definition: gemm_fragment_stream.h:145
FragmentStream< TileTraits, TileLoadIterator< TileTraits, Scalar, MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kGlobal, Index >, TileStoreIterator< TileTraits, Scalar, MultiplicandTraits::kKstrided ? IteratorAdvance::kH :IteratorAdvance::kW, MemorySpace::kShared, Index, Scalar, IteratorFragment::kScalar, DestinationSkew > > FragmentStream
Define the tile stream.
Definition: gemm_fragment_stream.h:93
Traits_ Traits
Traits.
Definition: gemm_fragment_stream.h:111
Definition: tile_iterator.h:97
CUTLASS_DEVICE GemmFragmentStream()
Definition: gemm_fragment_stream.h:181
TileTraitsDefault< VectorizedTile, kThreads > TileTraits
Define structure of stripmined tile.
Definition: gemm_fragment_stream.h:81
MultiplicandTraits::Shape ScalarTile
Scalar tile shape.
Definition: gemm_fragment_stream.h:75
static CUTLASS_DEVICE void shared_store_fence()
The memory fence for shared stores.
Definition: gemm_fragment_stream.h:174
Defines a FragmentStream by mapping GEMM dimensions onto contiguous and strided dimensions.
Definition: gemm_fragment_stream.h:45
CUTLASS_DEVICE GemmFragmentStream(Params const &params, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Constructor - bounds and block offset are aligned to GEMM coordinates (K, N, M)
Definition: gemm_fragment_stream.h:185
Base::Fragment Fragment
Loaded fragment type.
Definition: gemm_fragment_stream.h:133
GEMM Fragment Stream.
Definition: gemm_fragment_stream.h:98
Traits::Scalar const * Pointer
The pointer.
Definition: gemm_fragment_stream.h:152
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:308
CUTLASS_DEVICE void commit()
Commits the fragment.
Definition: gemm_fragment_stream.h:199
Base::Storage Storage
Destination storage.
Definition: gemm_fragment_stream.h:139
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
Manages a pair of iterators to stream data from global memory to shared.
Definition: fragment_stream.h:50
Definition: gemm_operand.h:66
static MatrixLayout::Kind const kLayout
Layout of the operand.
Definition: gemm_fragment_stream.h:53
Traits::Scalar Scalar
Scalar type.
Definition: gemm_fragment_stream.h:121
Parameters object.
Definition: gemm_fragment_stream.h:155
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:37
Index_ Index
Index type.
Definition: gemm_fragment_stream.h:65
Definition: shape.h:124
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:63
Base::Convert Convert
Converts between tiles.
Definition: gemm_fragment_stream.h:130
Definition: gemm_operand.h:94
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:48
static int const kThreads
Number of threads.
Definition: gemm_fragment_stream.h:59
Kind
Definition: matrix_traits.h:36
Base::StoreFragment StoreFragment
Stored fragment type.
Definition: gemm_fragment_stream.h:136
Base::LoadIterator LoadIterator
Defines the load iterator.
Definition: gemm_fragment_stream.h:124
static GemmOperand::Kind const kUsage
Indicates identity of multiplicand.
Definition: gemm_fragment_stream.h:47
Tile_ Tile
Definition: tile.h:43
Definition: tile_iterator.h:97
ShapeDiv< DestinationSkew_, Shape< ScalarsPerInst, ScalarsPerInst, ScalarsPerInst, 1 > >::Shape DestinationSkew
Skew added to shared memory tile.
Definition: gemm_fragment_stream.h:69
Kind
Definition: matrix_traits.h:43
CUTLASS_DEVICE void residue(Coord< 3 > const &bounds, Coord< 3 > const &block_offset)
TODO - Recomputes predicates and clears fetch registers.
Definition: gemm_fragment_stream.h:203
ThreadBlockTile_ ThreadBlockTile
Shape of the thread block tile (K, N, M)
Definition: gemm_fragment_stream.h:56
Defines properties of matrices used to denote layout and operands to GEMM kernels.
CUTLASS_DEVICE void initialize_predicates(Coord< 3 > const &bounds, Coord< 3 > const &block_offset)
Recomputes predicates aligned to GEMM coordinates (K, N, M)
Definition: gemm_fragment_stream.h:211
Definition: tile_iterator.h:102
CUTLASS_DEVICE void load()
Loads the fragment.
Definition: gemm_fragment_stream.h:195
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:556