Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
igemm_epilogue.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/convert.h"
32 #include "cutlass/fragment.h"
36 #include "cutlass/reshape_tile.h"
37 #include "cutlass/tile_iterator.h"
38 
39 namespace cutlass {
40 namespace gemm {
41 
43 
44 template <int kElements_>
50 
51  // We are packing 4 floats into int32 registers so we need kElements to be multiple of 4.
52  static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
53 
55  CUTLASS_DEVICE IgemmFloatToInt8Converter() {}
56 
58  CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
59  transform(src, 0, dst);
60  }
61 
63  template <typename Fragment_>
64  CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
65  // The inputs.
66  float4 const* src_f4 = reinterpret_cast<float4 const*>(&src[0]);
67  // The outputs.
68  int* dst_int = reinterpret_cast<int*>(&dst[0]);
69 
70  // Iterate over the floats and pack them together to produce ints.
71  for (int i = 0; i < kElements_ / 4; ++i) {
72  // Read the float4.
73  float4 f4 = src_f4[i];
74 
75  // Clamp the 4 elements of the floats to the [-128, +127] range.
76  float x = fmaxf(-128.f, fminf(127.f, f4.x));
77  float y = fmaxf(-128.f, fminf(127.f, f4.y));
78  float z = fmaxf(-128.f, fminf(127.f, f4.z));
79  float w = fmaxf(-128.f, fminf(127.f, f4.w));
80 
81  // Convert to integers.
82  int ix = (int)x;
83  int iy = (int)y;
84  int iz = (int)z;
85  int iw = (int)w;
86 
87  // Extract the lower bytes to build an int32 with 4 int8.
88  asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(ix) : "r"(iy));
89  asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(iz) : "r"(iw));
90  asm volatile("prmt.b32 %0, %0, %1, 0x5410;" : "+r"(ix) : "r"(iz));
91 
92  // Store the int.
93  dst_int[i] = ix;
94  }
95  }
96 };
97 
99 
100 template <typename InputScalar_, typename OutputFragment_>
103 };
104 
105 template <int kElements_>
106 struct IgemmGlobalStoreTransformer<float, Fragment<int8_t, kElements_> > {
108 };
109 
111 
112 template <int kElements_>
118 
119  // We are unpacking 4 int8s from int32.
120  static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
121 
123  CUTLASS_DEVICE IgemmInt8ToFloatConverter() {}
124 
126  CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
127  transform(src, 0, dst);
128  }
129 
131  template <typename Fragment_>
132  CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
133  // The inputs.
134  int const* src_int = reinterpret_cast<int const*>(&src[0]);
135  // The outputs.
136  float4* dst_f4 = reinterpret_cast<float4*>(&dst[0]);
137 
138  // Iterate over the int8 and unpack them together to produce floats.
139  for (int i = 0; i < kElements_ / 4; ++i) {
140  // Read the int.
141  int ix, iy, iz, iw = src_int[i];
142 
143  // Extract the 4 bytes.
144  asm volatile("prmt.b32 %0, 0x0, %1, 0x4440;" : "=r"(ix) : "r"(iw));
145  asm volatile("prmt.b32 %0, 0x0, %1, 0x4441;" : "=r"(iy) : "r"(iw));
146  asm volatile("prmt.b32 %0, 0x0, %1, 0x4442;" : "=r"(iz) : "r"(iw));
147  asm volatile("prmt.b32 %0, 0x0, %1, 0x4443;" : "=r"(iw) : "r"(iw));
148 
149  // The floats.
150  float fx, fy, fz, fw;
151 
152  // Convert to floats (make sure we generate I2F.F32.S8).
153  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fx) : "r"(ix));
154  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fy) : "r"(iy));
155  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fz) : "r"(iz));
156  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fw) : "r"(iw));
157 
158  // Store the float4.
159  dst_f4[i] = make_float4(fx, fy, fz, fw);
160  }
161  }
162 };
163 
165 
166 template <typename InputFragment_, typename OutputScalar_>
169 };
170 
171 template <int kElements_>
172 struct IgemmGlobalLoadTransformer<Fragment<int8_t, kElements_>, float> {
174 };
175 
177 
178 template <typename InputScalar_, typename OutputFragment_>
181 };
182 
184 
185 template <typename IgemmConfig_, typename EpilogueFunctor_, typename Index_>
187  : public GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> {
191  typedef IgemmConfig_ IgemmConfig;
192 
194  typedef typename Base::Scalar Scalar;
196  typedef typename Base::Iterations Iterations;
198  typedef typename Base::Delta Delta;
199 
207  typedef
209 
217  typedef
219 
232  SharedStoreFragmentD>::Transformer
242 };
243 
245 
246 template <
248  typename IgemmConfig_,
250  typename EpilogueFunctor_,
252  typename Index_ = int,
256  // The output tile.
257  typename IgemmConfig_::OutputTile,
258  // The accumulators.
259  typename IgemmConfig_::Accumulators,
260  // The global iterator for C.
261  typename Helper_::GlobalLoadIteratorC,
262  // The transformer for C.
263  typename Helper_::GlobalTransformerC,
264  // The transformer for D.
265  typename Helper_::GlobalTransformerD,
266  // The global iterator for D.
267  typename Helper_::GlobalStoreIteratorD,
268  // The iterator to store D to shared memory.
269  typename Helper_::SharedStoreIteratorD,
270  // The shared store transformer for D.
271  typename Helper_::SharedStoreTransformerD,
272  // The stream to load D from shared memory.
273  typename Helper_::SharedLoadStreamD,
274  // The iterations.
275  typename Helper_::Iterations,
276  // The strides between iterations.
277  typename Helper_::Delta,
278  // The functor to be used in the epilogue.
279  EpilogueFunctor_,
280  // The index.
281  Index_> {
283  static bool const kInt8Output =
285 };
286 
288 
289 template <typename GemmEpilogueTraits_, bool = GemmEpilogueTraits_::kInt8Output>
290 struct IgemmEpilogue : public GemmEpilogue<GemmEpilogueTraits_> {
293 
295  CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
296  typename Base::SharedStorage& shared_storage_,
297  Coord<3> const& _problem_size)
298  : Base(params_, shared_storage_, _problem_size) {}
299 };
300 
302 
303 template <typename GemmEpilogueTraits_>
304 struct IgemmEpilogue<GemmEpilogueTraits_, true> : public GemmEpilogue<GemmEpilogueTraits_> {
307 
309  CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
310  typename Base::SharedStorage& shared_storage_,
311  Coord<3> const& _problem_size)
312  : Base(params_, shared_storage_, _problem_size) {}
313 };
314 
316 
317 } // namespace gemm
318 } // namespace cutlass
Definition: gemm_global_tile.h:120
Definition: igemm_epilogue.h:255
Definition: load_store.h:41
Base::Delta Delta
The strides between iterations.
Definition: igemm_epilogue.h:198
Base::SharedStoreTileTraits SharedStoreTileTraits
The traits class for the shared iterator to store D to shared memory.
Definition: igemm_epilogue.h:221
IgemmGlobalStoreTransformer< Scalar, GlobalFragmentD >::Transformer GlobalTransformerD
The transformer that converts the computed Scalar values into the fragment stored to global memory for D.
Definition: igemm_epilogue.h:218
Definition: convert.h:33
Base::SharedLoadTileTraits SharedLoadTileTraits
The traits class for the shared iterator to load D from shared memory.
Definition: igemm_epilogue.h:235
TileLoadIterator< SharedLoadTileTraits, typename SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorD
The shared iterator to load D from shared memory.
Definition: igemm_epilogue.h:241
Definition: gemm_epilogue_traits.h:203
GemmEpilogue< GemmEpilogueTraits_ > Base
The base class.
Definition: igemm_epilogue.h:292
Traits::Params Params
The params.
Definition: gemm_epilogue.h:46
Definition: gemm_epilogue.h:42
Definition: igemm_epilogue.h:167
std::is_same (false specialization)
Definition: platform.h:420
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
CUTLASS_DEVICE IgemmInt8ToFloatConverter()
Ctor.
Definition: igemm_epilogue.h:123
SharedStoreIteratorD::Fragment SharedStoreFragmentD
The fragment that needs to be passed to that store iterator.
Definition: igemm_epilogue.h:229
EpilogueFunctor_::Scalar Scalar
The scalar.
Definition: gemm_epilogue_traits.h:205
Definition: igemm_epilogue.h:186
Definition: load_store.h:42
Fragment< int8_t, kElements_ > InputFragment
The input fragment.
Definition: igemm_epilogue.h:115
Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
The fragment.
Definition: tile_iterator.h:199
Definition: igemm_epilogue.h:290
Definition: igemm_epilogue.h:45
CUTLASS_DEVICE void transform(Fragment_ const &src, int offset, OutputFragment &dst)
Transform a fragment.
Definition: igemm_epilogue.h:64
Traits::SharedStorage SharedStorage
The shared storage.
Definition: gemm_epilogue.h:48
A template defining Fragment Concept.
Definition: fragment.h:99
Definition: tile_iterator.h:65
CUTLASS_DEVICE void transform(InputFragment const &src, OutputFragment &dst)
Transform a fragment.
Definition: igemm_epilogue.h:126
Base::Scalar Scalar
The scalar type of the epilogue.
Definition: igemm_epilogue.h:194
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const &params_, typename Base::SharedStorage &shared_storage_, Coord< 3 > const &_problem_size)
Ctor.
Definition: igemm_epilogue.h:295
GlobalLoadIteratorC::Fragment GlobalFragmentC
The fragment that needs to be produced by the load iterator.
Definition: igemm_epilogue.h:205
Base::Fragment Fragment
Fragment definition.
Definition: tile_iterator.h:909
CUTLASS_DEVICE void transform(InputFragment const &src, OutputFragment &dst)
Transform a fragment.
Definition: igemm_epilogue.h:58
Fragment< int8_t, kElements_ > OutputFragment
The output fragment.
Definition: igemm_epilogue.h:49
GemmGlobalIteratorCd< GlobalStoreTileTraits > GlobalStoreIteratorD
The iterator to store D to global memory.
Definition: igemm_epilogue.h:213
IgemmSharedStoreTransformer< typename IgemmConfig::Accumulators::Element, SharedStoreFragmentD >::Transformer SharedStoreTransformerD
The transformer from accumulators to shared memory fragments.
Definition: igemm_epilogue.h:233
static bool const kInt8Output
Do we output in int8?
Definition: igemm_epilogue.h:283
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
Convert< Fragment< InputScalar_, OutputFragment_::kElements >, OutputFragment_ > Transformer
Definition: igemm_epilogue.h:180
GemmEpilogue< GemmEpilogueTraits_ > Base
The base class.
Definition: igemm_epilogue.h:306
Defines a type for restructuring a tile.
Base::GlobalLoadTileTraits GlobalLoadTileTraits
The traits class for the iterator.
Definition: igemm_epilogue.h:201
Fragment< float, kElements_ > OutputFragment
The output fragment.
Definition: igemm_epilogue.h:117
GemmEpilogueTraitsHelper< IgemmConfig_, EpilogueFunctor_, Index_ > Base
The base class.
Definition: igemm_epilogue.h:189
Definition: gemm_shared_tile.h:339
GlobalStoreIteratorD::Fragment GlobalFragmentD
The fragment that needs to be passed to that store iterator.
Definition: igemm_epilogue.h:215
GemmGlobalIteratorCd< GlobalLoadTileTraits > GlobalLoadIteratorC
The iterator to load C from global memory.
Definition: igemm_epilogue.h:203
#define static_assert(__e, __m)
Definition: platform.h:153
IgemmConfig_ IgemmConfig
The config.
Definition: igemm_epilogue.h:191
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
CUTLASS_DEVICE IgemmFloatToInt8Converter()
Ctor.
Definition: igemm_epilogue.h:55
Element_ Element
The element.
Definition: fragment.h:108
Fragment< float, kElements_ > InputFragment
The input fragment.
Definition: igemm_epilogue.h:47
Definition: gemm_epilogue_traits.h:70
Definition: gemm_global_tile.h:366
Definition: igemm_epilogue.h:179
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
Definition: convert.h:38
IgemmFloatToInt8Converter< kElements_ > Transformer
Definition: igemm_epilogue.h:107
Base::Iterations Iterations
The iterations.
Definition: igemm_epilogue.h:196
IgemmGlobalLoadTransformer< GlobalFragmentC, Scalar >::Transformer GlobalTransformerC
The transformer from loaded data to math fragment.
Definition: igemm_epilogue.h:208
Base::GlobalStoreTileTraits GlobalStoreTileTraits
The traits class for the iterator.
Definition: igemm_epilogue.h:211
Convert< InputFragment_, Fragment< OutputScalar_, InputFragment_::kElements > > Transformer
Definition: igemm_epilogue.h:168
Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers...
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:272
CUTLASS_DEVICE void transform(Fragment_ const &src, int offset, OutputFragment &dst)
Transform a fragment.
Definition: igemm_epilogue.h:132
Convert< Fragment< InputScalar_, OutputFragment_::kElements >, OutputFragment_ > Transformer
Definition: igemm_epilogue.h:102
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
Definition: igemm_epilogue.h:101
TileStoreIterator< SharedStoreTileTraits, typename SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal > SharedStoreIteratorD
The shared iterator to store D to shared memory.
Definition: igemm_epilogue.h:227
IgemmInt8ToFloatConverter< kElements_ > Transformer
Definition: igemm_epilogue.h:173
Defines conversion operations among Fragments of different base type.
Definition: igemm_epilogue.h:113
platform::remove_const< Scalar_ >::type Scalar
The scalar.
Definition: gemm_shared_tile.h:341
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const &params_, typename Base::SharedStorage &shared_storage_, Coord< 3 > const &_problem_size)
Ctor.
Definition: igemm_epilogue.h:309
Implements tile iterators to partition the thread block tile into 2D subtiles and efficiently load ea...
Definition: gemm_shared_tile.h:270
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:841