Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_gemm_global_tile.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
31 
32 namespace cutlass {
33 namespace gemm {
34 
36 
37 template <typename Scalar_, typename Tile_, typename Threads_, int kAccessSize_>
38 struct WmmaGemmGlobalIteratorCdTraits : public GemmGlobalTileTraits<GemmOperand::kC,
39  MatrixLayout::kColumnMajor,
40  Scalar_,
41  Tile_,
42  Threads_,
43  kAccessSize_> {
47  Scalar_,
48  Tile_,
49  Threads_,
50  kAccessSize_>
52 
55 
/// Functor that computes the calling thread's offset in (H, W) from threadIdx.x.
57  struct ThreadOffset {
/// Returns a 4-component coordinate of the form (0, h, w, 0).
59  Coord<4> operator()() const {
// H component: threadIdx.x divided by Base::Threads::kW.
60  int thread_offset_h = threadIdx.x / Base::Threads::kW;
// W component: threadIdx.x modulo Base::Threads::kW, then multiplied by Base::ThreadsDelta::kW.
61  int thread_offset_w = threadIdx.x % Base::Threads::kW * Base::ThreadsDelta::kW;
62 
// The first and last components are always zero.
63  return make_Coord(0, thread_offset_h, thread_offset_w, 0);
64  }
65  };
66 };
67 
69 
70 template <typename TileTraits_, typename Index_ = int>
71 struct WmmaGemmGlobalIteratorCd : public GemmGlobalIteratorCd<TileTraits_, Index_> {
75  typedef TileTraits_ Traits;
81  static MatrixLayout::Kind const kLayout = TileTraits_::kLayout;
82 
84  typedef typename TileTraits_::Scalar Scalar;
86  typedef typename TileTraits_::Pointer Pointer;
88  typedef typename TileTraits_::Threads Threads;
90  typedef Index_ Index;
92  typedef typename TileTraits_::ThreadOffset ThreadOffset;
94  typedef typename Base::Params BaseParams;
95 
97  struct Params : public BaseParams {
100  long long batch_stride,
101  Index ldm,
102  Index n,
103  Index epilogue_stride_w,
104  Index epilogue_delta_w) {
105  // The pointer.
106  this->pointer = pointer;
107  // Stride between GEMMs
108  this->stride_d = batch_stride;
109  // Setup the base stride. One "group of threads" per column.
110  this->stride_h = ldm;
111  // Each thread outputs 1 column per iteration.
112  this->inc_h = ldm * TileTraits_::Threads::kH;
113  this->inc_advance = this->inc_h + epilogue_stride_w;
114 
115  this->predicate_offset = n;
116  this->predicate_inc_h = TileTraits_::Threads::kH;
117  this->predicate_inc_advance = this->predicate_inc_h + epilogue_delta_w;
118 
119  return 0;
120  }
121  };
122 
/// Ctor. Performs no work of its own: every argument, including the defaulted
/// pointer_offset, pred_offset and thread_offset_func, is forwarded unchanged to Base.
124  CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const& params,
125  const Coord<3>& bounds,
126  const Coord<3>& block,
127  int const pointer_offset = 0,
128  int const pred_offset = 0,
129  ThreadOffset thread_offset_func = ThreadOffset())
130 
131  : Base(params, bounds, block, pointer_offset, pred_offset, thread_offset_func) {}
132 
/// Loads a single fragment element from memory.
/// Delegates to Base::load_element with the same (value, d, h, w, c) arguments.
134  CUTLASS_DEVICE void load_element(
135  typename Base::AccessType& value, int d, int h, int w, int c) const {
136  Base::load_element(value, d, h, w, c);
137  }
138 
140  CUTLASS_DEVICE void store_element(
141  typename Base::AccessType const& value, int d, int h, int w, int c) {
142  int const offset =
144  Store<Scalar,
148  typename Base::FragmentElement,
149  Base::Tile::kW>::store(value, Base::params.pointer, offset);
150  }
151 
152  public:
/// Loads into the given fragment by forwarding it to Base::load_post_increment.
153  template <typename Fragment>
154  CUTLASS_DEVICE void load_post_increment(Fragment& fragment) {
155  Base::load_post_increment(fragment);
156  }
157 
/// Stores the given fragment by forwarding it to Base::store_post_increment.
158  template <typename Fragment>
159  CUTLASS_DEVICE void store_post_increment(Fragment& fragment) {
160  Base::store_post_increment(fragment);
161  }
162 };
163 
165 
166 } // namespace gemm
167 } // namespace cutlass
TileTraits_::Threads Threads
The threads.
Definition: wmma_gemm_global_tile.h:88
Index inc_advance
The strides to increment the pointer.
Definition: gemm_global_tile.h:406
Definition: convert.h:33
Defines iterators for efficiently loading and storing to global memory.
Index_ Index
The index.
Definition: gemm_global_tile.h:391
CUTLASS_HOST_DEVICE int initialize(Pointer pointer, long long batch_stride, Index ldm, Index n, Index epilogue_stride_w, Index epilogue_delta_w)
Setup the params.
Definition: wmma_gemm_global_tile.h:99
Definition: gemm_global_tile.h:70
GemmGlobalIteratorCd< Traits, Index_ > Base
The base class.
Definition: wmma_gemm_global_tile.h:77
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
CUTLASS_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: wmma_gemm_global_tile.h:134
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:199
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Base::Params BaseParams
Base parameters.
Definition: wmma_gemm_global_tile.h:94
Index predicate_inc_h
Definition: gemm_global_tile.h:408
Index_ Index
The index.
Definition: wmma_gemm_global_tile.h:90
TileTraits_::Scalar Scalar
The scalar.
Definition: wmma_gemm_global_tile.h:84
CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment)
Loads and increments iterator.
Definition: gemm_global_tile.h:549
CUTLASS_DEVICE void load_post_increment(Fragment &fragment)
Definition: wmma_gemm_global_tile.h:154
Definition: matrix_traits.h:357
Definition: load_store.h:178
Index predicate_inc_advance
The strides to increment the predicate offset.
Definition: gemm_global_tile.h:408
The params.
Definition: wmma_gemm_global_tile.h:97
static FragmentElementType::Kind const kFragmentElementType
Specifies iterator storage fragment type (Scalar or WmmaMatrix)
Definition: tile_iterator.h:158
CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
Definition: wmma_gemm_global_tile.h:59
Definition: wmma_gemm_global_tile.h:71
TileTraits_::Pointer Pointer
The pointer.
Definition: wmma_gemm_global_tile.h:86
CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment)
Definition: gemm_global_tile.h:580
Index stride_h
The stride in the H dimension to setup the thread in the block.
Definition: gemm_global_tile.h:404
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
Loads a single fragment element from memory.
Definition: gemm_global_tile.h:512
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > Delta
Override the strides in each dimension between different loads/stores.
Definition: wmma_gemm_global_tile.h:54
Definition: vector.h:62
TileTraits_::ThreadOffset ThreadOffset
The thread offset.
Definition: gemm_global_tile.h:395
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: matrix_traits.h:159
CUTLASS_DEVICE void store_post_increment(Fragment &fragment)
Definition: wmma_gemm_global_tile.h:159
Index inc_h
Definition: gemm_global_tile.h:406
static MemorySpace::Kind const kMemorySpace
Source or destination memory space.
Definition: tile_iterator.h:161
static MatrixLayout::Kind const kLayout
The layout.
Definition: wmma_gemm_global_tile.h:81
Pointer pointer
The pointer.
Definition: gemm_global_tile.h:400
Definition: wmma_gemm_global_tile.h:38
Iterator for accessing a stripmined tile in memory.
Definition: tile_iterator.h:144
TileTraits_::ThreadOffset ThreadOffset
The thread offset functor.
Definition: wmma_gemm_global_tile.h:92
Params params
Parameters.
Definition: gemm_global_tile.h:456
Definition: gemm_global_tile.h:366
static int const kW
The width of the cube.
Definition: shape.h:70
static int const kAccessSize
The number of scalars accessed per load/store.
Definition: tile_iterator.h:188
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ > Base
The base class.
Definition: wmma_gemm_global_tile.h:51
long long stride_d
The stride in the D dimension.
Definition: gemm_global_tile.h:402
CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const &params, const Coord< 3 > &bounds, const Coord< 3 > &block, int const pointer_offset=0, int const pred_offset=0, ThreadOffset thread_offset_func=ThreadOffset())
Ctor.
Definition: wmma_gemm_global_tile.h:124
WmmaGemmGlobalIteratorCd< TileTraits_, Index_ > This_
This class.
Definition: wmma_gemm_global_tile.h:73
TileTraits_::Pointer Pointer
The pointer.
Definition: gemm_global_tile.h:387
Shape< 0, 0, Base::Delta::kW, Base::Delta::kC > ImmediateOffsetStrides
Override the strides in each dimension between different loads/stores.
Definition: wmma_gemm_global_tile.h:79
Computes the thread offset in (H, W) based on thread ID.
Definition: wmma_gemm_global_tile.h:57
The params.
Definition: gemm_global_tile.h:398
TileTraits_ Traits
The traits.
Definition: wmma_gemm_global_tile.h:75
Index predicate_offset
The column offset to compute the predicate for the columns.
Definition: gemm_global_tile.h:410
CUTLASS_DEVICE void store_element(typename Base::AccessType const &value, int d, int h, int w, int c)
Stores a single fragment element into memory.
Definition: wmma_gemm_global_tile.h:140