Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
matrix_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/coord.h"
31 
32 namespace cutlass {
33 
35 
38 struct MatrixCoord : public Coord<2, int> {
39 
41  typedef int Index;
42 
45 
47  static int const kRow = 0;
48 
50  static int const kColumn = 1;
51 
52  //
53  // Methods
54  //
55 
59 
62  MatrixCoord(Coord<2, Index> const &coord): Base(coord) { }
63 
67 
70  Index const & row() const { return this->at(kRow); }
71 
74  Index & row() { return this->at(kRow); }
75 
78  Index const & column() const { return this->at(kColumn); }
79 
82  Index & column() { return this->at(kColumn); }
83 
84  //
85  // Coord operators
86  //
87 
90  MatrixCoord operator+(Base const& b) const {
91  return MatrixCoord(Base::operator+(b));
92  }
93 
96  MatrixCoord operator-(Base const& b) const {
97  return MatrixCoord(Base::operator-(b));
98  }
99 
102  MatrixCoord operator*(Base const& b) const {
103  return MatrixCoord(Base::operator*(b));
104  }
105 
108  MatrixCoord operator/(Base const& b) const {
109  return MatrixCoord(Base::operator/(b));
110  }
111 
115  Base::operator+=(b);
116  return *this;
117  }
118 
122  Base::operator-=(b);
123  return *this;
124  }
125 
129  Base::operator*=(b);
130  return *this;
131  }
132 
136  Base::operator/=(b);
137  return *this;
138  }
139 };
140 
142 
144 //
145 // The following define classes satisfying the TensorRefMapFunc concept. These must support the
146 // following operations, where func is an instance of type TensorRefMapFunc.
147 //
148 // Coord<TensorRefMapFunc::kStorageRank> = func(Coord<kRank>);
149 //
150 // Though not required to be usable by TensorRef, each of the following also defines a helper
151 // function to map the "leading dimension" to an appropriate stride vector. Implementations
152 // following this convention should also implement the following static method:
153 //
154 // Coord<TensorRefMapFunc::kStorageRank> stride = TensorRefMapFunc::stride(leading_dim);
155 //
156 namespace MatrixLayout {
157 
160 
161  //
162  // TensorRefMapFunc definitions for common layouts
163  //
164 
166  struct RowMajor {
167  static int const kStorageRank = 2;
171  return coord;
172  }
173  };
174 
176  struct ColumnMajor {
177  static int const kStorageRank = 2;
181  return make_Coord(coord.column(), coord.row());
182  }
183  };
184 
187  template <int Interleave>
189 
191  static int const kStorageRank = 3;
192 
194  static int const kInterleave = Interleave;
195 
199  return make_Coord(
200  coord.row() / kInterleave,
201  coord.column(),
202  coord.row() % kInterleave
203  );
204  }
205 
208  static Coord<kStorageRank> stride(int ldm) {
209  return make_Coord(
210  ldm * kInterleave,
211  kInterleave,
212  1
213  );
214  }
215  };
216 
219  template <int Interleave>
221 
223  static int const kStorageRank = 3;
224 
226  static int const kInterleave = Interleave;
227 
231  return make_Coord(
232  coord.column() / kInterleave,
233  coord.row(),
234  coord.column() % kInterleave
235  );
236  }
237 
240  static Coord<kStorageRank> stride(int ldm) {
241  return make_Coord(
242  ldm * kInterleave,
243  kInterleave,
244  1
245  );
246  }
247  };
248 
253  static int const kStorageRank = 3;
254 
256  static int const kRow = 0;
257 
259  static int const kColumn = 1;
260 
265  return make_Coord(coord.row(), coord.column(), 0);
266  }
267 
271  if (layout == MatrixLayout::kRowMajor) {
272  return make_Coord(ldm, 1, 1);
273  }
274  return make_Coord(1, ldm, 1);
275  }
276  };
277 
280  template <int BlockRows, int BlockColumns>
282 
284  static int const kStorageRank = 4;
285 
287  static int const kBlockRows = BlockRows;
288 
290  static int const kBlockColumns = BlockColumns;
291 
295  return make_Coord(
296  coord.column() / kBlockColumns,
297  coord.row() / kBlockRows,
298  coord.column() % kBlockColumns,
299  coord.row() % kBlockRows
300  );
301  }
302 
305  static Coord<kStorageRank> stride(int ldm) {
306  return make_Coord(
307  ldm * kBlockRows * kBlockColumns,
309  kBlockRows,
310  1
311  );
312  }
313  };
314 
317  template <int BlockRows, int BlockColumns>
319 
321  static int const kStorageRank = 4;
322 
324  static int const kBlockRows = BlockRows;
325 
327  static int const kBlockColumns = BlockColumns;
328 
332  return make_Coord(
333  coord.row() / kBlockRows,
334  coord.column() / kBlockColumns,
335  coord.row() % kBlockRows,
336  coord.column() % kBlockColumns
337  );
338  }
339 
342  static Coord<kStorageRank> stride(int ldm) {
343  return make_Coord(
344  ldm * kBlockRows * kBlockColumns,
347  1
348  );
349  }
350  };
351 };
352 
354 
356 struct GemmOperand {  // Tags naming the operands of a GEMM of the form D = A * B + C.
357  enum Kind { kA, kB, kC, kD };  // One enumerator per operand of the expression above (A, B, C, D).
358 };
359 
361 
364  enum Kind {
367  };
368 };
369 
371 
372 } // namespace cutlass
int Index
Integer-valued index.
Definition: matrix_traits.h:41
Mapping function for column-major matrices.
Definition: matrix_traits.h:176
static int const kBlockColumns
Interleaving size in columns dimension.
Definition: matrix_traits.h:327
Definition: convert.h:33
CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
Maps (row, col) to (col / kInterleave, row, col % kInterleave)
Definition: matrix_traits.h:230
CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
Maps (i, j) to (i, j)
Definition: matrix_traits.h:170
Transformation applied to matrix operands.
Definition: matrix_traits.h:363
Definition: matrix_traits.h:188
static int const kBlockColumns
Interleaving size in columns dimension.
Definition: matrix_traits.h:290
Definition: matrix_traits.h:365
Definition: matrix_traits.h:281
Definition: matrix_traits.h:220
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
no operation
Definition: matrix_traits.h:366
CUTLASS_HOST_DEVICE MatrixCoord & operator/=(Base const &b)
In-place division.
Definition: matrix_traits.h:135
static int const kStorageRank
Definition: matrix_traits.h:167
Definition: matrix_traits.h:251
CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
Maps (i, j) to (j, i)
Definition: matrix_traits.h:180
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Kind
Definition: matrix_traits.h:364
Coord< 2, Index > Base
Base type is a Coord of rank=2.
Definition: matrix_traits.h:44
CUTLASS_HOST_DEVICE MatrixCoord operator+(Base const &b) const
Element-wise addition.
Definition: matrix_traits.h:90
CUTLASS_HOST_DEVICE Coord & operator*=(Coord const &b)
In-place multiplication.
Definition: coord.h:197
Definition: matrix_traits.h:357
static int const kRow
Dimension of rows.
Definition: matrix_traits.h:256
static int const kStorageRank
Definition: matrix_traits.h:177
static int const kBlockRows
Interleaving size in rows dimension.
Definition: matrix_traits.h:287
static int const kInterleave
Interleaving size.
Definition: matrix_traits.h:194
CUTLASS_HOST_DEVICE Index const & column() const
Returns the column of the coordinate.
Definition: matrix_traits.h:78
CUTLASS_HOST_DEVICE MatrixCoord(Index row, Index column)
Helper to construct from a row and column.
Definition: matrix_traits.h:66
static CUTLASS_HOST_DEVICE Coord< kStorageRank > stride(int ldm)
Helper to compute stride vector from leading dimension.
Definition: matrix_traits.h:208
static int const kColumn
Dimension of columns.
Definition: matrix_traits.h:259
static int const kStorageRank
Rank of storage n-D array.
Definition: matrix_traits.h:191
CUTLASS_HOST_DEVICE Coord & operator-=(Coord const &b)
In-place subtraction.
Definition: coord.h:188
static int const kStorageRank
Arbitrary storage rank.
Definition: matrix_traits.h:253
Definition: matrix_traits.h:357
CUTLASS_HOST_DEVICE Coord & operator+=(Coord const &b)
In-place addition.
Definition: coord.h:179
CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
Maps (row, col) to (row / kInterleave, col, row % kInterleave)
Definition: matrix_traits.h:198
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
static int const kBlockRows
Interleaving size in rows dimension.
Definition: matrix_traits.h:324
CUTLASS_HOST_DEVICE Index const & row() const
Returns the row of the coordinate.
Definition: matrix_traits.h:70
CUTLASS_HOST_DEVICE Index & at()
Gets the index of a given Coord element.
Definition: coord.h:240
CUTLASS_HOST_DEVICE Coord & operator/=(Coord const &b)
In-place division.
Definition: coord.h:206
Definition: matrix_traits.h:159
CUTLASS_HOST_DEVICE MatrixCoord operator-(Base const &b) const
Element-wise subtraction.
Definition: matrix_traits.h:96
CUTLASS_HOST_DEVICE MatrixCoord(Coord< 2, Index > const &coord)
Constructs from Coord<2>
Definition: matrix_traits.h:62
static int const kStorageRank
Rank of storage n-D array.
Definition: matrix_traits.h:321
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:49
Definition: matrix_traits.h:159
Gemm operand - D = A * B + C.
Definition: matrix_traits.h:356
static CUTLASS_HOST_DEVICE Coord< kStorageRank > stride(int ldm)
Helper to compute stride vector from leading dimension.
Definition: matrix_traits.h:342
static int const kRow
Rows dimension.
Definition: matrix_traits.h:47
CUTLASS_HOST_DEVICE MatrixCoord & operator-=(Base const &b)
In-place subtraction.
Definition: matrix_traits.h:121
CUTLASS_HOST_DEVICE MatrixCoord operator*(Base const &b) const
Element-wise multiplication.
Definition: matrix_traits.h:102
static CUTLASS_HOST_DEVICE Coord< kStorageRank > stride(int ldm)
Helper to compute stride vector from leading dimension.
Definition: matrix_traits.h:240
CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
Definition: matrix_traits.h:264
CUTLASS_HOST_DEVICE Index & row()
Returns the row of the coordinate.
Definition: matrix_traits.h:74
static int const kStorageRank
Rank of storage n-D array.
Definition: matrix_traits.h:284
static int const kInterleave
Interleaving size.
Definition: matrix_traits.h:226
CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
Maps (row, col) to (row / kBlockRows, col / kBlockColumns, row % kBlockRows, col % kBlockColumns)
Definition: matrix_traits.h:331
Kind
Definition: matrix_traits.h:357
Definition: matrix_traits.h:357
CUTLASS_HOST_DEVICE Index & column()
Returns the column of the coordinate.
Definition: matrix_traits.h:82
CUTLASS_HOST_DEVICE MatrixCoord & operator*=(Base const &b)
In-place multiplication.
Definition: matrix_traits.h:128
static int const kStorageRank
Rank of storage n-D array.
Definition: matrix_traits.h:223
Definition: matrix_traits.h:318
CUTLASS_HOST_DEVICE MatrixCoord & operator+=(Base const &b)
In-place addition.
Definition: matrix_traits.h:114
static CUTLASS_HOST_DEVICE Coord< kStorageRank > stride(int ldm)
Helper to compute stride vector from leading dimension.
Definition: matrix_traits.h:305
CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
Maps (row, col) to (col / kBlockColumns, row / kBlockRows, col % kBlockColumns, row % kBlockRows)
Definition: matrix_traits.h:294
static CUTLASS_HOST_DEVICE Coord< kStorageRank > stride(MatrixLayout::Kind layout, int ldm)
Helper to construct a stride vector based on contiguous matrix layout and leading dimension...
Definition: matrix_traits.h:270
Definition: matrix_traits.h:38
CUTLASS_HOST_DEVICE MatrixCoord operator/(Base const &b) const
Element-wise division.
Definition: matrix_traits.h:108
static int const kColumn
Columns dimension.
Definition: matrix_traits.h:50
CUTLASS_HOST_DEVICE MatrixCoord()
Default ctor.
Definition: matrix_traits.h:58
Definition: matrix_traits.h:357
Mapping function for row-major matrices.
Definition: matrix_traits.h:166