Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
shape.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/cutlass.h"
31 
32 namespace cutlass {
33 
35 
/// A Shape implementing Layout Concept describing the dimensions of a cube.
///
/// All four extents are compile-time integer constants; each defaults to 1 when omitted.
template <int kD_ = 1, int kH_ = 1, int kW_ = 1, int kC_ = 1>
struct Shape {
  /// The depth of the cube.
  static int const kD = kD_;
  /// The height of the cube.
  static int const kH = kH_;
  /// The width of the cube.
  static int const kW = kW_;
  /// The number of scalars per element.
  static int const kC = kC_;
};
74 
/// Compute derived counts of a Layout Concept based class.
///
/// The template argument must expose integral constants kD, kH, kW and kC.
template <typename Shape>
struct ShapeCount {
  /// The number of elements per row.
  static int const kWc = Shape::kW * Shape::kC;
  /// The number of pixels per image.
  static int const kHw = Shape::kH * Shape::kW;
  /// The number of elements per image.
  static int const kHwc = Shape::kH * kWc;
  /// The number of pixels per cube.
  static int const kDhw = Shape::kD * kHw;
  /// The number of elements in the 4D space.
  static int const kDhwc = Shape::kD * kHwc;
  /// The number of elements in the 4D space (alias of kDhwc).
  static int const kCount = kDhwc;
};
93 
95 
96 template <typename A_, int kScale_>
97 struct ShapeScale {
99 };
100 
102 
103 template <typename A_, typename B_>
104 struct ShapeAdd {
106 };
107 
109 
110 template <typename A_, typename B_>
111 struct ShapeSub {
112  typedef Shape<A_::kD - B_::kD, A_::kH - B_::kH, A_::kW - B_::kW, A_::kC - B_::kC> Shape;
113 };
114 
116 
117 template <typename A_, typename B_>
118 struct ShapeMul {
120 };
121 
123 
124 template <typename A_, typename B_>
125 struct ShapeDiv {
126  typedef Shape<A_::kD / B_::kD, A_::kH / B_::kH, A_::kW / B_::kW, A_::kC / B_::kC> Shape;
127 };
128 
130 
131 template <typename A_, typename B_>
133  typedef Shape<(A_::kD + B_::kD - 1) / B_::kD,
134  (A_::kH + B_::kH - 1) / B_::kH,
135  (A_::kW + B_::kW - 1) / B_::kW,
136  (A_::kC + B_::kC - 1) / B_::kC>
138 };
139 
141 
142 template <typename A_, typename B_>
143 struct ShapeMax {
144  typedef Shape<(A_::kD > B_::kD ? A_::kD : B_::kD),
145  (A_::kH > B_::kH ? A_::kH : B_::kH),
146  (A_::kW > B_::kW ? A_::kW : B_::kW),
147  (A_::kC > B_::kC ? A_::kC : B_::kC)>
149 };
150 
152 
153 template <typename A_, typename B_>
154 struct ShapeMin {
155  typedef Shape<(A_::kD < B_::kD ? A_::kD : B_::kD),
156  (A_::kH < B_::kH ? A_::kH : B_::kH),
157  (A_::kW < B_::kW ? A_::kW : B_::kW),
158  (A_::kC < B_::kC ? A_::kC : B_::kC)>
160 };
161 
163 
164 template <typename Shape_, int elementsPerAccess>
165 struct ShapeStrides {
166  typedef Shape<Shape_::kH * Shape_::kW * Shape_::kC,
167  Shape_::kW * Shape_::kC,
168  Shape_::kC,
169  elementsPerAccess>
171 };
172 
174 
179 template <typename Shape_>
181  static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c) {
182  // clang-format off
183  return d * Shape_::kH * Shape_::kW * Shape_::kC +
184  h * Shape_::kW * Shape_::kC +
185  w * Shape_::kC +
186  c;
187  // clang-format on
188  }
189 };
190 
192 
197 template <typename Strides_>
199  static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c) {
200  return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
201  }
202 };
203 
205 
212 template <typename Threads_, typename Strides_>
214  static CUTLASS_DEVICE int get() {
215  // Decompose the thread index.
216  int c = threadIdx.x % Threads_::kC;
217  int w = threadIdx.x / Threads_::kC % Threads_::kW;
218  int h = threadIdx.x / Threads_::kC / Threads_::kW % Threads_::kH;
219  int d = threadIdx.x / Threads_::kC / Threads_::kW / Threads_::kH;
220 
221  // Compute the offset.
222  return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC;
223  }
224 };
225 
227 
230 template <int T_h_, int T_w_, int T_c_, int S_h_, int S_w_, int S_c_>
231 struct ComputeThreadOffsetFromStrides<Shape<1, T_h_, T_w_, T_c_>, Shape<1, S_h_, S_w_, S_c_> > {
232  static CUTLASS_DEVICE int get() {
233  // Decompose the thread index.
234  int c = threadIdx.x % T_c_;
235  int w = threadIdx.x / T_c_ % T_w_;
236  int h = threadIdx.x / T_c_ / T_w_ % T_h_;
237 
238  // Compute the offset.
239  return h * S_h_ + w * S_w_ + c * S_c_;
240  }
241 };
242 
244 
248 template <int T_h_, int T_w_, int S_h_, int S_w_>
249 struct ComputeThreadOffsetFromStrides<Shape<1, T_h_, T_w_, 1>, Shape<1, S_h_, S_w_, 1> > {
250  static CUTLASS_DEVICE int get() {
251  // Decompose the thread index.
252  int w = threadIdx.x % T_w_;
253  int h = threadIdx.x / T_w_;
254 
255  // Compute the offset.
256  return h * S_h_ + w * S_w_;
257  }
258 };
259 
261 
262 } // namespace cutlass
Decompose threadIdx.x into the coordinates of a cube whose dimensions are specified by Threads_. Afterwards, compute the offset of those coordinates using Strides_.
Definition: shape.h:213
static int const kWc
The number of elements per row.
Definition: shape.h:81
Definition: convert.h:33
Shape< Shape_::kH *Shape_::kW *Shape_::kC, Shape_::kW *Shape_::kC, Shape_::kC, elementsPerAccess > Shape
Definition: shape.h:170
Shape< A_::kD+B_::kD, A_::kH+B_::kH, A_::kW+B_::kW, A_::kC+B_::kC > Shape
Definition: shape.h:105
Shape<(A_::kD+B_::kD - 1)/B_::kD,(A_::kH+B_::kH - 1)/B_::kH,(A_::kW+B_::kW - 1)/B_::kW,(A_::kC+B_::kC - 1)/B_::kC > Shape
Definition: shape.h:137
Shape< A_::kD *kScale_, A_::kH *kScale_, A_::kW *kScale_, A_::kC *kScale_ > Shape
Definition: shape.h:98
Shape< A_::kD *B_::kD, A_::kH *B_::kH, A_::kW *B_::kW, A_::kC *B_::kC > Shape
Definition: shape.h:119
Shape< A_::kD - B_::kD, A_::kH - B_::kH, A_::kW - B_::kW, A_::kC - B_::kC > Shape
Definition: shape.h:112
Definition: shape.h:111
static int const kH
The height of the cube.
Definition: shape.h:68
static int const kC
The number of scalars per element.
Definition: shape.h:72
Definition: shape.h:97
Compute the offset for the given coordinates in a cube.
Definition: shape.h:180
Shape< A_::kD/B_::kD, A_::kH/B_::kH, A_::kW/B_::kW, A_::kC/B_::kC > Shape
Definition: shape.h:126
static int const kDhw
The number of pixels per cube.
Definition: shape.h:87
Definition: shape.h:118
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Definition: shape.h:125
Compute the offset for the given coordinates in a cube.
Definition: shape.h:198
Definition: shape.h:132
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: shape.h:143
Definition: shape.h:104
static int const kCount
The number of elements in the 4D space.
Definition: shape.h:91
static int const kDhwc
The number of elements in the 4D space.
Definition: shape.h:89
static int const kW
The width of the cube.
Definition: shape.h:70
Definition: shape.h:154
static int const kHw
The number of pixels per image.
Definition: shape.h:83
static int const kD
The depth of the cube.
Definition: shape.h:66
Definition: shape.h:165
Shape<(A_::kD > B_::kD ? A_::kD :B_::kD),(A_::kH > B_::kH ? A_::kH :B_::kH),(A_::kW > B_::kW ? A_::kW :B_::kW),(A_::kC > B_::kC ? A_::kC :B_::kC)> Shape
Definition: shape.h:148
Basic include for CUTLASS macros.
Shape<(A_::kD< B_::kD ? A_::kD :B_::kD),(A_::kH< B_::kH ? A_::kH :B_::kH),(A_::kW< B_::kW ? A_::kW :B_::kW),(A_::kC< B_::kC ? A_::kC :B_::kC)> Shape
Definition: shape.h:159
Compute derived counts of a Layout Concept based class.
Definition: shape.h:79
static int const kHwc
The number of elements per image.
Definition: shape.h:85