Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
tile_coord.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/coord.h"
32 
33 namespace cutlass {
34 
36 
39 template <typename Index_ = int>
40 struct TileCoord : public Coord<4, Index_> {
41 
43  typedef Index_ Index;
44 
47 
49  static int kD = 0;
50 
52  static int kH = 1;
53 
55  static int kW = 2;
56 
58  static int kC = 3;
59 
60  //
61  // Methods
62  //
63 
66  TileCoord() { }
67 
70  TileCoord(Coord<3, Index> const &coord):
71  Base(make_Coord(coord[0], coord[1], coord[2], 0)) { }
72 
75  TileCoord(Coord<4, Index> const &coord): Base(coord) { }
76 
79  TileCoord(Index coord[4]): Base(coord) { }
80 
84 
87  Index const & d() const { return this->at(kD); }
88 
91  Index & d() { return this->at(kD); }
92 
95  Index const & h() const { return this->at(kH); }
96 
99  Index & h() { return this->at(kH); }
100 
103  Index const & w() const { return this->at(kW); }
104 
107  Index & w() { return this->at(kW); }
108 
111  Index const & c() const { return this->at(kC); }
112 
115  Index & c() { return this->at(kC); }
116 
119  Coord<2> hw() const {
120  return make_Coord(h(), w());
121  }
122 
125  Coord<3> hwc() const {
126  return make_Coord(h(), w(), c());
127  }
128 
131  Coord<3> dhw() const {
132  return make_Coord(d(), h(), w());
133  }
134 
135  //
136  // Coord operators
137  //
138 
141  TileCoord operator+(Base const& b) const {
142  return TileCoord(Base::operator+(b));
143  }
144 
147  TileCoord operator-(Base const& b) const {
148  return TileCoord(Base::operator-(b));
149  }
150 
153  TileCoord operator*(Base const& b) const {
154  return TileCoord(Base::operator*(b));
155  }
156 
159  TileCoord operator/(Base const& b) const {
160  return TileCoord(Base::operator/(b));
161  }
162 
165  TileCoord& operator+=(Base const& b) {
166  Base::operator+=(b);
167  return *this;
168  }
169 
172  TileCoord& operator-=(Base const& b) {
173  Base::operator-=(b);
174  return *this;
175  }
176 
179  TileCoord& operator*=(Base const& b) {
180  Base::operator*=(b);
181  return *this;
182  }
183 
186  TileCoord& operator/=(Base const& b) {
187  Base::operator/=(b);
188  return *this;
189  }
190 };
191 
193 
194 } // namespace cutlass
static int kC
C dimension.
Definition: tile_coord.h:58
static int kD
D dimension.
Definition: tile_coord.h:49
Definition: convert.h:33
static int kH
H dimension.
Definition: tile_coord.h:52
CUTLASS_HOST_DEVICE Coord< 3 > dhw() const
Gets D, H, and W dimensions as a Coord<3>
Definition: tile_coord.h:131
CUTLASS_HOST_DEVICE Index const & c() const
Returns the C element of the coordinate.
Definition: tile_coord.h:111
CUTLASS_HOST_DEVICE Index const & h() const
Returns the H element of the coordinate.
Definition: tile_coord.h:95
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Index_ Index
Index type.
Definition: tile_coord.h:43
CUTLASS_HOST_DEVICE Index const & w() const
Returns the W element of the coordinate.
Definition: tile_coord.h:103
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
CUTLASS_HOST_DEVICE Index const & d() const
Returns the D element of the coordinate.
Definition: tile_coord.h:87
CUTLASS_HOST_DEVICE TileCoord & operator-=(Base const &b)
In-place subtraction.
Definition: tile_coord.h:172
CUTLASS_HOST_DEVICE TileCoord(Index d, Index h, Index w, Index c)
Constructs from the D, H, W, and C elements, in that order.
Definition: tile_coord.h:83
CUTLASS_HOST_DEVICE TileCoord(Coord< 4, Index > const &coord)
Constructs from Coord<4>
Definition: tile_coord.h:75
CUTLASS_HOST_DEVICE TileCoord operator*(Base const &b) const
Element-wise multiplication.
Definition: tile_coord.h:153
CUTLASS_HOST_DEVICE Index & c()
Returns the C element of the coordinate.
Definition: tile_coord.h:115
CUTLASS_HOST_DEVICE Coord & operator*=(Coord const &b)
In-place multiplication.
Definition: coord.h:197
CUTLASS_HOST_DEVICE TileCoord(Coord< 3, Index > const &coord)
Constructs from Coord<3> and infers coord[kC] = 0.
Definition: tile_coord.h:70
CUTLASS_HOST_DEVICE TileCoord operator-(Base const &b) const
Element-wise subtraction.
Definition: tile_coord.h:147
CUTLASS_HOST_DEVICE TileCoord & operator*=(Base const &b)
In-place multiplication.
Definition: tile_coord.h:179
Definition: tile_coord.h:40
CUTLASS_HOST_DEVICE Coord< 3 > hwc() const
Gets H, W, and C dimensions as a Coord<3>
Definition: tile_coord.h:125
CUTLASS_HOST_DEVICE TileCoord()
Default ctor.
Definition: tile_coord.h:66
CUTLASS_HOST_DEVICE Index & w()
Returns the W element of the coordinate.
Definition: tile_coord.h:107
CUTLASS_HOST_DEVICE Coord< 2 > hw() const
Gets H and W dimensions as a Coord<2>
Definition: tile_coord.h:119
CUTLASS_HOST_DEVICE Coord & operator-=(Coord const &b)
In-place subtraction.
Definition: coord.h:188
CUTLASS_HOST_DEVICE TileCoord operator/(Base const &b) const
Element-wise division.
Definition: tile_coord.h:159
CUTLASS_HOST_DEVICE Coord & operator+=(Coord const &b)
In-place addition.
Definition: coord.h:179
Coord< 4, Index > Base
Underlying Coord<4>
Definition: tile_coord.h:46
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE Index & at()
Returns a reference to the Coord element at a given dimension.
Definition: coord.h:240
CUTLASS_HOST_DEVICE TileCoord & operator/=(Base const &b)
In-place division.
Definition: tile_coord.h:186
CUTLASS_HOST_DEVICE Coord & operator/=(Coord const &b)
In-place division.
Definition: coord.h:206
CUTLASS_HOST_DEVICE TileCoord(Index coord[4])
Constructs from an array of coordinate elements.
Definition: tile_coord.h:79
static int kW
W dimension.
Definition: tile_coord.h:55
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:49
CUTLASS_HOST_DEVICE Index & h()
Returns the H element of the coordinate.
Definition: tile_coord.h:99
CUTLASS_HOST_DEVICE TileCoord & operator+=(Base const &b)
In-place addition.
Definition: tile_coord.h:165
CUTLASS_HOST_DEVICE TileCoord operator+(Base const &b) const
Element-wise addition.
Definition: tile_coord.h:141
CUTLASS_HOST_DEVICE Index & d()
Returns the D element of the coordinate.
Definition: tile_coord.h:91