Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_config.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/shape.h"
31 
32 namespace cutlass {
33 namespace gemm {
34 
36 
37 template <
38  /// The scalar type for A.
39  typename ScalarA_,
40  /// The scalar type for B.
41  typename ScalarB_,
42  /// The scalar type for C.
43  typename ScalarC_,
44  /// The scalar type for D.
45  typename ScalarD_,
46  /// The output tile.
47  typename OutputTile_,
48  /// The functor to do D = A*B + C.
49  typename MultiplyAdd_,
50  /// The number of scalars per LDG for A.
51  int kScalarsPerLdgA_,
52  /// The number of scalars per STS for A.
53  int kScalarsPerStsA_,
54  /// The number of scalars per LDS for A.
55  int kScalarsPerLdsA_,
56  /// The number of scalars per LDG for B.
57  int kScalarsPerLdgB_,
58  /// The number of scalars per STS for B.
59  int kScalarsPerStsB_,
60  /// The number of scalars per LDS for B.
61  int kScalarsPerLdsB_,
62  /// The number of scalars per LDG for C and per STG for D (one value is used for both).
63  int kScalarsPerLdgCAndStgD_,
64  /// The number of scalars per STS for D.
65  int kScalarsPerStsD_,
66  /// The number of scalars per LDS for D.
67  int kScalarsPerLdsD_,
68  /// The number of stages in shared memory to implement double, triple, more-buffering.
69  int kStages_,
70  /// If true, mainloop is instantiated twice (see kResidueSeparate). Defaults to false.
71  bool kResidueSeparate_ = false,
72  /// If true, residue is computed in the prologue. Defaults to false.
73  bool kResidueInProlog_ = false,
74  /// If true, kernel is launched with launch bounds specified. Defaults to true.
75  bool kLaunchBounds_ = true>
76 struct GemmConfig {
77  //
78  /// The scalar for A.
79  typedef ScalarA_ ScalarA;
80  /// The scalar for B.
81  typedef ScalarB_ ScalarB;
82  /// The scalar for C.
83  typedef ScalarC_ ScalarC;
84  /// The scalar for D.
85  typedef ScalarD_ ScalarD;
86 
87  /// The tile.
88  typedef OutputTile_ OutputTile;
89  /// The functor to do D = A*B + C.
90  typedef MultiplyAdd_ MultiplyAdd;
91  /// The shape of the instruction.
92  typedef typename MultiplyAdd::InstructionShape InstructionShape;
93  /// The shape of warp-level GEMM.
94  typedef typename MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp;
95  /// The accumulators.
96  typedef typename MultiplyAdd::Accumulators Accumulators;
97 
98  /// The number of warps.
99  typedef typename ShapeDiv<OutputTile, AccumulatorsPerWarp>::Shape Warps;
100  /// The default warp size (32 threads per warp).
101  static int const kWarpSize = cutlass::kWarpSize;
102  /// The number of threads.
103  // NOTE(review): initializer reconstructed as warps-in-tile times warp size; confirm
103  // ShapeCount<>::kCount against cutlass/shape.h.
103  static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
104 
105  /// The number of scalars per LDG/STS/LDS for A.
106  static int const kScalarsPerLdgA = kScalarsPerLdgA_;
107  static int const kScalarsPerStsA = kScalarsPerStsA_;
108  static int const kScalarsPerLdsA = kScalarsPerLdsA_;
109 
110  /// The number of scalars per LDG/STS/LDS for B.
111  static int const kScalarsPerLdgB = kScalarsPerLdgB_;
112  static int const kScalarsPerStsB = kScalarsPerStsB_;
113  static int const kScalarsPerLdsB = kScalarsPerLdsB_;
114 
115  /// The number of scalars per LDG for C.
116  static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_;
117 
118  /// The number of scalars per STS/LDS/STG for D.
119  static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_;
120  static int const kScalarsPerStsD = kScalarsPerStsD_;
121  static int const kScalarsPerLdsD = kScalarsPerLdsD_;
122 
123  /// The number of accumulators that are going to be fed from one LDS A/B
123  /// (scalars per LDS divided by the instruction shape's kD extent).
124  static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD;
125  static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD;
126 
127  /// The number of stages in shared memory to implement double, triple, more-buffering.
128  static int const kStages = kStages_;
129 
130  /// If true, mainloop is instantiated twice. The first instantiation contains no predicate
131  // updates and is more efficient for some kernels. If false, only a single mainloop is
132  // instantiated.
133  static bool const kResidueSeparate = kResidueSeparate_;
134 
135  /// If true, residue is computed in the prologue.
136  static bool const kResidueInProlog = kResidueInProlog_;
137 
138  /// If true, kernel is launched with launch bounds specified.
139  static bool const kLaunchBounds = kLaunchBounds_;
140 };
141 
143 
144 } // namespace gemm
145 } // namespace cutlass
Definition: convert.h:33
static int const kThreads
The number of threads.
Definition: gemm_config.h:103
ShapeDiv< OutputTile, AccumulatorsPerWarp >::Shape Warps
The number of warps.
Definition: gemm_config.h:99
MultiplyAdd::InstructionShape InstructionShape
The shape of the instruction.
Definition: gemm_config.h:92
static int const kWarpSize
The default warp size (32 threads per warp).
Definition: gemm_config.h:101
static int const kScalarsPerLdsD
Definition: gemm_config.h:121
static int const kScalarsPerStgD
The number of scalars per STS/LDS/STG for D.
Definition: gemm_config.h:119
A template defining Fragment Concept.
Definition: fragment.h:99
static int const kScalarsPerLdgB
The number of scalars per LDG/STS/LDS for B.
Definition: gemm_config.h:111
ScalarC_ ScalarC
The scalar for C.
Definition: gemm_config.h:83
MultiplyAdd::Accumulators Accumulators
The accumulators.
Definition: gemm_config.h:96
static int const kStages
The number of stages in shared memory to implement double, triple, more-buffering.
Definition: gemm_config.h:128
ShapeMul< ThreadGemmShape, ThreadsPerWarp >::Shape AccumulatorsPerWarp
The number of accumulators per warp.
Definition: thread_multiply_add.h:54
static bool const kResidueInProlog
If true, residue is computed in the prologue.
Definition: gemm_config.h:136
static bool const kLaunchBounds
If true, kernel is launched with launch bounds specified.
Definition: gemm_config.h:139
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_config.h:90
static int const kAccumulatorsPerLdsB
Definition: gemm_config.h:125
Shape< A_::kD/B_::kD, A_::kH/B_::kH, A_::kW/B_::kW, A_::kC/B_::kC > Shape
Definition: shape.h:126
ScalarA_ ScalarA
The scalar for A.
Definition: gemm_config.h:79
static bool const kResidueSeparate
If true, mainloop is instantiated twice. The first instantiation contains no predicate.
Definition: gemm_config.h:133
Definition: gemm_config.h:76
MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp
The shape of warp-level GEMM.
Definition: gemm_config.h:94
static int const kScalarsPerLdsB
Definition: gemm_config.h:113
static int const kScalarsPerLdgC
The number of scalars per LDG for C.
Definition: gemm_config.h:116
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
static int const kScalarsPerLdsA
Definition: gemm_config.h:108
static int const kAccumulatorsPerLdsA
The number of accumulators that are going to be fed from one LDS A/B.
Definition: gemm_config.h:124
static int const kScalarsPerStsA
Definition: gemm_config.h:107
static int const kScalarsPerStsB
Definition: gemm_config.h:112
static int const kScalarsPerLdgA
The number of scalars per LDG/STS/LDS for A.
Definition: gemm_config.h:106
static int const kScalarsPerStsD
Definition: gemm_config.h:120
ScalarD_ ScalarD
The scalar for D.
Definition: gemm_config.h:85
Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
Computes derived counts for a class implementing the Layout Concept.
Definition: shape.h:79
ScalarB_ ScalarB
The scalar for B.
Definition: gemm_config.h:81
OutputTile_ OutputTile
The tile.
Definition: gemm_config.h:88