Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_operand.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/matrix_traits.h"
32 #include "cutlass/reshape_tile.h"
33 #include "cutlass/util/platform.h"
34 
35 namespace cutlass {
36 namespace gemm {
37 
39 
// Helper to describe attributes of GEMM matrix operands.
// NOTE(review): listing line 42 -- the class-key line that opens this body -- is missing
// from this extract; restore it from the original header before compiling.
41 template <GemmOperand::Kind kOperand_, MatrixLayout::Kind kLayout_>
// Congruous is true when exactly one of "operand is A" and "layout is row-major" holds:
// `==` binds tighter than `^`, so the expression is an XOR of the two comparisons.
43  static const bool Congruous =
44  (kOperand_ == GemmOperand::kA ^ kLayout_ == MatrixLayout::kRowMajor);
45 };
46 
48 
49 template <typename GemmOperand::Kind kOperand_, typename Tile_>
50 struct GetExtent;
51 
52 template <typename Tile_>
53 struct GetExtent<GemmOperand::kA, Tile_> {
54  static const int kExtent = Tile_::kW;
55 };
56 
57 template <typename Tile_>
58 struct GetExtent<GemmOperand::kB, Tile_> {
59  static const int kExtent = Tile_::kH;
60 };
61 
63 
// Traits mapping a thread block tile onto a Shape for the A or B operand of a GEMM.
// NOTE(review): this extract dropped several listing lines of the struct: 67 (the
// class-key line that opens the body), 82 (the definition of kKstrided) and 87-88 (the
// two alternatives of the Shape typedef). Restore them from the original header before
// compiling; the notes below record what the symbol index on this page shows for them.
66 template <typename ThreadBlockTile_, GemmOperand::Kind Usage, MatrixLayout::Kind Layout>
68  // Only defined for A or B
69  static_assert(Usage == GemmOperand::kA || Usage == GemmOperand::kB,
70  "MultiplicandTileShape defined only for A or B operands.");
71 
// Shape of GEMM thread block tile (K, N, M)
73  typedef ThreadBlockTile_ ThreadBlockTile;
74 
// Identifies multiplicand
76  static GemmOperand::Kind const kUsage = Usage;
77 
// Layout of tile
79  static MatrixLayout::Kind const kLayout = Layout;
80 
81  // True if K is the strided dimension
// NOTE(review): the symbol index declares this member as `static bool const kKstrided`;
// its initializer is not visible anywhere in this extract.
83 
// Map the ThreadBlockShape onto (kH, kW) dimensions for A and B operand.
// NOTE(review): per the symbol index, the two types selected between are
//   Shape<1, ThreadBlockTile::kD, GetExtent<Usage, ThreadBlockTile>::kExtent>   (kKstrided)
//   Shape<1, GetExtent<Usage, ThreadBlockTile>::kExtent, ThreadBlockTile::kD>   (otherwise)
// and the resulting `::type` is exposed under the name `Shape`.
85  typedef typename platform::conditional<
86  kKstrided,
89 };
90 
92 
// Primary template header for the ProjectOperand specializations that follow, each of
// which exposes a static project(Coord<3> const &) function. Kstrided defaults to true.
// NOTE(review): listing line 96 -- the declaration this header introduces -- is missing
// from this extract. The specializations imply it declares `ProjectOperand`; restore it
// from the original header, otherwise this header attaches to the next declaration.
95 template <GemmOperand::Kind operand, bool Kstrided = true>
97 
99 template <bool Kstrided>
100 struct ProjectOperand<GemmOperand::kA, Kstrided> {
102  static Coord<3> project(Coord<3> const &coord) {
103  if (Kstrided) {
104  return make_Coord(0, coord[0], coord[2]);
105  } else {
106  return make_Coord(0, coord[2], coord[0]);
107  }
108  }
109 };
110 
112 template <bool Kstrided>
113 struct ProjectOperand<GemmOperand::kB, Kstrided> {
115  static Coord<3> project(Coord<3> const &coord) {
116  if (Kstrided) {
117  return make_Coord(0, coord[0], coord[1]);
118  } else {
119  return make_Coord(0, coord[1], coord[0]);
120  }
121  }
122 };
123 
125 template <>
126 struct ProjectOperand<GemmOperand::kC, true> {
128  static Coord<3> project(Coord<3> const &coord) { return make_Coord(0, coord[1], coord[2]); }
129 };
130 
132 template <>
133 struct ProjectOperand<GemmOperand::kD, true> {
135  static Coord<3> project(Coord<3> const &coord) { return make_Coord(0, coord[1], coord[2]); }
136 };
137 
139 
140 } // namespace gemm
141 } // namespace cutlass
static bool const kKstrided
Definition: gemm_operand.h:82
static CUTLASS_HOST_DEVICE Coord< 3 > project(Coord< 3 > const &coord)
Definition: gemm_operand.h:115
Definition: convert.h:33
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
static CUTLASS_HOST_DEVICE Coord< 3 > project(Coord< 3 > const &coord)
Definition: gemm_operand.h:128
C++ features that may be otherwise unimplemented for CUDA device functions.
ThreadBlockTile_ ThreadBlockTile
Shape of GEMM thread block tile (K, N, M)
Definition: gemm_operand.h:70
platform::conditional< kKstrided, Shape< 1, ThreadBlockTile::kD, GetExtent< Usage, ThreadBlockTile >::kExtent >, Shape< 1, GetExtent< Usage, ThreadBlockTile >::kExtent, ThreadBlockTile::kD > >::type Shape
Map the ThreadBlockShape onto (kH, kW) dimensions for A and B operand.
Definition: gemm_operand.h:88
Defines a type for restructuring a tile.
Definition: gemm_operand.h:67
static const bool Congruous
Definition: gemm_operand.h:43
Definition: matrix_traits.h:357
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
std::conditional (true specialization)
Definition: platform.h:351
#define static_assert(__e, __m)
Definition: platform.h:153
static MatrixLayout::Kind const kLayout
Layout of tile.
Definition: gemm_operand.h:79
static CUTLASS_HOST_DEVICE Coord< 3 > project(Coord< 3 > const &coord)
Definition: gemm_operand.h:102
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: gemm_operand.h:96
Definition: gemm_operand.h:50
Definition: matrix_traits.h:159
Gemm operand - D = A * B + C.
Definition: matrix_traits.h:356
static CUTLASS_HOST_DEVICE Coord< 3 > project(Coord< 3 > const &coord)
Definition: gemm_operand.h:135
Kind
Definition: matrix_traits.h:357
Definition: matrix_traits.h:357
static GemmOperand::Kind const kUsage
Identifies multiplicand.
Definition: gemm_operand.h:76
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Helper to describe attributes of GEMM matrix operands.
Definition: gemm_operand.h:42