Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
wmma_matrix.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700)
31 #define CUTLASS_USE_WMMA_API
32 
33 #if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750)
34 #define CUTLASS_USE_SUBBYTE_WMMA
35 #endif
36 
37 #include "stdio.h"
38 
39 #if __CUDACC_VER_MAJOR__ >= 10
40 #include <mma.h>
41 #else
42 #include <crt/mma.h>
43 #endif
44 #include "cutlass/fragment.h"
45 #include "cutlass/matrix_traits.h"
46 #include "cutlass/shape.h"
47 #include "cutlass/vector.h"
48 
49 namespace cutlass {
50 
52 
/// Trait mapping a cutlass MatrixLayout kind onto the corresponding
/// nvcuda::wmma layout tag. The general case resolves to column-major;
/// the row-major case is covered by an explicit specialization.
template <MatrixLayout::Kind kLayout_>
struct WmmaLayout {
  typedef nvcuda::wmma::col_major Layout;
};
58 
/// Specialization: cutlass row-major maps to nvcuda::wmma::row_major.
template <>
struct WmmaLayout<MatrixLayout::kRowMajor> {
  typedef nvcuda::wmma::row_major Layout;
};
64 
66 
/// Trait mapping a cutlass scalar type onto the element type the WMMA API
/// expects. In the general case the type passes through unchanged.
template <typename Type_>
struct WmmaDataType {
  typedef Type_ Type;
};
72 
#ifdef CUTLASS_USE_SUBBYTE_WMMA
/// Vectors of 1-bit elements map to the experimental b1 WMMA precision.
template <>
struct WmmaDataType<Vector<bin1_t, 32> > {
  typedef nvcuda::wmma::experimental::precision::b1 Type;
};

/// Vectors of signed 4-bit elements map to the experimental s4 WMMA precision.
template <>
struct WmmaDataType<Vector<int4_t, 8> > {
  typedef nvcuda::wmma::experimental::precision::s4 Type;
};

/// Vectors of unsigned 4-bit elements map to the experimental u4 WMMA precision.
template <>
struct WmmaDataType<Vector<uint4_t, 8> > {
  typedef nvcuda::wmma::experimental::precision::u4 Type;
};
#endif
92 
94 
/// Wrapper around an nvcuda::wmma fragment, keyed on the GEMM operand it
/// holds. The primary template is intentionally empty; each operand kind
/// (A, B, accumulator) is implemented by a partial specialization.
template <GemmOperand::Kind kOperand_,
          MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename WmmaShape_>
struct WmmaMatrix {};
101 
103 
/// WMMA fragment holding the A operand of a GEMM tile.
/// The fragment dimensions come from WmmaShape_ (kW/kH/kD) and the element
/// type and layout are translated through the WmmaDataType / WmmaLayout traits.
/// NOTE(review): stride arguments are presumably in elements, per the WMMA API
/// convention -- confirm against callers.
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_>
    : public nvcuda::wmma::fragment<
          nvcuda::wmma::matrix_a,
          WmmaShape_::kW,
          WmmaShape_::kH,
          WmmaShape_::kD,
          typename WmmaDataType<Scalar_>::Type,
          typename WmmaLayout<kLayout_>::Layout> {
  /// This class.
  typedef WmmaMatrix<GemmOperand::kA, kLayout_, Scalar_, WmmaShape_> This_;

  /// Broadcast-assign: fills every element of the fragment with x.
  CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
    nvcuda::wmma::fill_fragment(*this, x);
    return *this;
  }

  /// Loads the fragment from memory with the given leading-dimension stride.
  CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
    nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
  }

  /// Stores the fragment to memory with the given leading-dimension stride.
  CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
    nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
  }
};
137 
139 
/// WMMA fragment holding the B operand of a GEMM tile.
/// Mirrors the A-operand specialization except for the matrix_b fragment tag.
/// NOTE(review): stride arguments are presumably in elements, per the WMMA API
/// convention -- confirm against callers.
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_>
    : public nvcuda::wmma::fragment<
          nvcuda::wmma::matrix_b,
          WmmaShape_::kW,
          WmmaShape_::kH,
          WmmaShape_::kD,
          typename WmmaDataType<Scalar_>::Type,
          typename WmmaLayout<kLayout_>::Layout> {
  /// This class.
  typedef WmmaMatrix<GemmOperand::kB, kLayout_, Scalar_, WmmaShape_> This_;

  /// Broadcast-assign: fills every element of the fragment with x.
  CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
    nvcuda::wmma::fill_fragment(*this, x);
    return *this;
  }

  /// Loads the fragment from memory with the given leading-dimension stride.
  CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
    nvcuda::wmma::load_matrix_sync(*this, pointer, stride);
  }

  /// Stores the fragment to memory with the given leading-dimension stride.
  CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
    nvcuda::wmma::store_matrix_sync(pointer, *this, stride);
  }
};
173 
175 
/// WMMA fragment holding the C / accumulator operand of a GEMM tile.
/// Unlike the A/B operands, the accumulator fragment carries no compile-time
/// layout: the memory layout is supplied at load/store time, derived here
/// from the kLayout_ template parameter.
template <MatrixLayout::Kind kLayout_, typename Scalar_, typename WmmaShape_>
struct WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_>
    : public nvcuda::wmma::fragment<
          nvcuda::wmma::accumulator,
          WmmaShape_::kW,
          WmmaShape_::kH,
          WmmaShape_::kD,
          Scalar_> {
  /// This class.
  typedef WmmaMatrix<GemmOperand::kC, kLayout_, Scalar_, WmmaShape_> This_;
  /// Memory layout used when the accumulator is loaded or stored.
  static MatrixLayout::Kind const kLayout = kLayout_;

  /// Broadcast-assign: fills every element of the fragment with x.
  CUTLASS_DEVICE This_& operator=(Scalar_ const& x) {
    nvcuda::wmma::fill_fragment(*this, x);
    return *this;
  }

  /// Loads the accumulator from memory, selecting the runtime WMMA layout
  /// flag that matches kLayout.
  CUTLASS_DEVICE void load(Scalar_ const* pointer, int const stride) {
    nvcuda::wmma::layout_t const layout = (kLayout == MatrixLayout::kRowMajor)
                                              ? nvcuda::wmma::mem_row_major
                                              : nvcuda::wmma::mem_col_major;
    nvcuda::wmma::load_matrix_sync(*this, pointer, stride, layout);
  }

  /// Stores the accumulator to memory, selecting the runtime WMMA layout
  /// flag that matches kLayout.
  CUTLASS_DEVICE void store(Scalar_* pointer, int const stride) const {
    nvcuda::wmma::layout_t const layout = (kLayout == MatrixLayout::kRowMajor)
                                              ? nvcuda::wmma::mem_row_major
                                              : nvcuda::wmma::mem_col_major;
    nvcuda::wmma::store_matrix_sync(pointer, *this, stride, layout);
  }
};
219 
221 
// WmmaMatrix cannot be used in a union and thus it cannot be used in our Vector
// implementation. The only use of WmmaMatrix in combination with Vectorize has
// kLanes == 1, so a single-lane "vector" can safely degenerate to the scalar
// WmmaMatrix type itself.
template <GemmOperand::Kind kOperand_,
          MatrixLayout::Kind kLayout_,
          typename Scalar_,
          typename WmmaShape_>
struct Vectorize<WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaShape_>, 1> {
  typedef WmmaMatrix<kOperand_, kLayout_, Scalar_, WmmaShape_> Type;
};
232 
234 }
235 
236 #endif // defined CUTLASS_USE_WMMA_API
Definition: convert.h:33
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
Definition: matrix_traits.h:159
Vector< Element_, kLanes_ > Type
Definition: vector.h:271
Defines a 1D vector of elements held in the registers of each thread.
Kind
Definition: matrix_traits.h:357
Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Defines Fragment, a statically-sized array for storing parts of matrices within a thread&#39;s registers...