Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
gemm_traits.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/convert.h"
31 #include "cutlass/matrix_traits.h"
32 #include "cutlass/reshape_tile.h"
34 #include "cutlass/tile_iterator.h"
35 #include "cutlass/kernel_launch.h"
36 
39 #include "cutlass/gemm/gemm_desc.h"
45 #include "cutlass/gemm/gemm.h"
46 namespace cutlass {
47 namespace gemm {
48 
50 
51 template <enum MatrixLayout::Kind, typename GemmConfig_>
53 
55 
56 template <typename GemmConfig_>
57 struct GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
60 
62  typedef typename GemmConfig_::ScalarA Scalar;
64  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
65 
67  typedef GemmGlobalTileTraits<
68  // That's A.
70  // A is column-major.
72  // The pointer is float const.
73  Scalar const,
74  // The tile has size KxM in GEMM's terminology.
76  // The threads are distributed as warps x 32 (the traits may reorganize).
78  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
79  GemmConfig_::kScalarsPerLdgA>
81 
84  // The pointer is float.
86  // The tile has size KxM in GEMM's terminology.
87  Shape<GemmConfig_::kStages,
88  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
89  GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
90  // The threads are distributed as warps x 32 (the traits may reorganize).
92  // The number of scalars per STS (STS.32 or STS.128, etc).
93  GemmConfig_::kScalarsPerStsA>
95 
98  // The pointer is float const.
99  MultiplyAddScalar const,
100  // The output tile size.
101  typename GemmConfig_::OutputTile,
102  // The number of warps.
103  typename GemmConfig_::Warps,
104  // The number of threads per warp.
105  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
106  // The shape of the FMA instruction.
107  typename GemmConfig_::InstructionShape,
108  // The number of stages.
109  GemmConfig_::kStages,
110  // The number of scalars per LDS.
111  GemmConfig_::kScalarsPerLdsA,
112  // The skew.
113  0>
115 };
116 
118 
119 template <typename GemmConfig_>
120 struct GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
123 
125  typedef typename GemmConfig_::ScalarA Scalar;
127  typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar;
128 
130  typedef GemmGlobalTileTraits<
131  // That's A.
133  // A is row-major.
135  // The pointer is float const.
136  Scalar const,
137  // The tile has size MxK in GEMM's terminology.
139  // The threads are distributed as (threads / K) x K (the traits may reorganize).
140  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
141  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
142  GemmConfig_::kScalarsPerLdgA>
144 
146  static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
148  static int const kSkewA = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA /
149  GlobalTileTraits::Threads::kW * kScalarsIn4B;
150 
153  // The pointer is float.
155  // The tile has size KxM in GEMM's terminology.
156  Shape<GemmConfig_::kStages,
157  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
158  GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
159  // The threads are distributed as (threads / K) x K (the traits may reorganize).
160  typename GlobalTileTraits::Threads,
161  // The number of scalars per STS.
162  GemmConfig_::kScalarsPerStsA,
163  // The skew to avoid bank conflicts added in the tile W dimension.
164  kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
165  SharedStoreTileTraits;
166 
169  // The pointer is float const.
170  MultiplyAddScalar const,
171  // The output tile size.
172  typename GemmConfig_::OutputTile,
173  // The number of warps.
174  typename GemmConfig_::Warps,
175  // The number of threads per warp.
176  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
177  // The shape of the FMA instruction.
178  typename GemmConfig_::InstructionShape,
179  // The number of stages.
180  GemmConfig_::kStages,
181  // The number of scalars per LDS.
182  GemmConfig_::kScalarsPerLdsA,
183  // The skew.
184  SharedStoreTileTraits::kSkew>
185  SharedLoadTileTraits;
186 };
187 
189 
190 template <enum MatrixLayout::Kind, typename GemmConfig_>
192 
194 
195 template <typename GemmConfig_>
196 struct GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
199 
201  typedef typename GemmConfig_::ScalarB Scalar;
203  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
204 
206  typedef GemmGlobalTileTraits<
207  // That's B.
209  // B is column-major.
211  // The pointer is float const.
212  Scalar const,
213  // The tile has size MxK in GEMM's terminology.
215  // The threads are distributed as (threads / K) x K (the traits may reorganize).
216  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
217  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
218  GemmConfig_::kScalarsPerLdgB>
220 
222  static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar);
224  static int const kSkewB = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB /
225  GlobalTileTraits::Threads::kW * kScalarsIn4B;
226 
229  // The pointer is float.
231  // The tile has size KxN in GEMM's terminology.
232  Shape<GemmConfig_::kStages,
233  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
234  GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
235  // The threads are distributed as (threads / K) x K (the traits may reorganize).
236  typename GlobalTileTraits::Threads,
237  // The number of scalars per STS.
238  GemmConfig_::kScalarsPerStsB,
239  // The skew to avoid bank conflicts added in the tile W dimension.
240  kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
241  SharedStoreTileTraits;
242 
245  // The pointer is float const.
246  MultiplyAddScalar const,
247  // The output tile size.
248  typename GemmConfig_::OutputTile,
249  // The number of warps.
250  typename GemmConfig_::Warps,
251  // The number of threads per warp.
252  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
253  // The shape of the FMA instruction.
254  typename GemmConfig_::InstructionShape,
255  // The number of stages.
256  GemmConfig_::kStages,
257  // The number of scalars per LDS.
258  GemmConfig_::kScalarsPerLdsB,
259  // The skew.
260  SharedStoreTileTraits::kSkew>
261  SharedLoadTileTraits;
262 };
263 
265 
266 template <typename GemmConfig_>
267 struct GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
270 
272  typedef typename GemmConfig_::ScalarB Scalar;
274  typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar;
275 
277  typedef GemmGlobalTileTraits<
278  // That's B.
280  // B is row-major.
282  // The pointer is float const.
283  Scalar const,
284  // The tile has size KxN in GEMM's terminology.
286  // The threads are distributed as warps x 32 (the traits may reorganize).
288  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
289  GemmConfig_::kScalarsPerLdgB>
291 
294  // The pointer is float.
296  // The tile has size KxN in GEMM's terminology.
297  Shape<GemmConfig_::kStages,
298  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
299  GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
300  // The threads are distributed as warps x 32 (the traits may reorganize).
301  typename GlobalTileTraits::Threads,
302  // The number of scalars per STS (STS.32 or STS.128, etc).
303  GemmConfig_::kScalarsPerStsB>
305 
308  // The pointer is float const.
309  MultiplyAddScalar const,
310  // The output tile size.
311  typename GemmConfig_::OutputTile,
312  // The number of warps.
313  typename GemmConfig_::Warps,
314  // The number of threads per warp.
315  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
316  // The shape of the FMA instruction.
317  typename GemmConfig_::InstructionShape,
318  // The number of stages.
319  GemmConfig_::kStages,
320  // The number of scalars per LDS.
321  GemmConfig_::kScalarsPerLdsB,
322  // The skew.
323  0>
325 };
326 
328 
329 template <
331  typename GemmConfig_,
333  typename GlobalLoadStreamA_,
335  typename GlobalLoadStreamB_,
337  typename SharedLoadStreamA_,
339  typename SharedLoadStreamB_,
341  typename Epilogue_,
343  typename BlockSwizzle_ = IdentityBlockSwizzle,
345  typename Index_ = int,
348 
349 struct GemmTraits {
351  typedef GemmTraits<GemmConfig_,
352  GlobalLoadStreamA_,
353  GlobalLoadStreamB_,
354  SharedLoadStreamA_,
355  SharedLoadStreamB_,
356  Epilogue_,
357  BlockSwizzle_,
358  Index_,
359  ClearAccumulators_> This_;
360 
363 
365  typedef GemmConfig_ GemmConfig;
368 
370  typedef GlobalLoadStreamA_ GlobalLoadStreamA;
372  static MatrixLayout::Kind const kLayoutA = GlobalLoadStreamA::kLayout;
374  typedef typename GlobalLoadStreamA_::Scalar ScalarA;
375 
377  typedef GlobalLoadStreamB_ GlobalLoadStreamB;
379  static MatrixLayout::Kind const kLayoutB = GlobalLoadStreamB::kLayout;
381  typedef typename GlobalLoadStreamB_::Scalar ScalarB;
382 
384  typedef SharedLoadStreamA_ SharedLoadStreamA;
386  typedef SharedLoadStreamB_ SharedLoadStreamB;
387 
391  typedef Epilogue_ Epilogue;
393  typedef typename Epilogue::ScalarC ScalarC;
394  typedef typename Epilogue::ScalarD ScalarD;
395 
397  typedef BlockSwizzle_ BlockSwizzle;
399  typedef Index_ Index;
401  typedef ClearAccumulators_ ClearAccumulators;
402 
408 
411 
414 
417 
420 
423 
426 
429 
431  typename Epilogue::Params epilogue;
432 
434  template <typename GemmDesc_>
435  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
436  // Set the problem size.
437  problem_size = desc.problem_size;
438 
439  // there is no partitionK in the default case
441  // Compute grid dimensions
442  BlockSwizzle block_swizzle;
443  this->block = dim3(GemmConfig::kThreads);
444  this->grid = block_swizzle.get_grid_layout(
445  problem_size,
446  make_Coord_from_shape<OutputTile>());
447 
448  // Compute offset to residue.
449  // partitionK_range <= problem_size[0]
450  Index gemm_k = problem_size[0];
451  Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
452  Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
453 
454  // Initialize parameters objects for
455  int error_code = global_to_shared_stream.stream_a.initialize(
456  desc.A.data(),
457  desc.batch_stride_A,
458  desc.A.leading_dim(),
459  offset_to_residue,
460  offset_to_residue_last_partition
461  );
462  if (error_code) {
463  return error_code;
464  }
465 
466  error_code = global_to_shared_stream.stream_b.initialize(
467  desc.B.data(),
468  desc.batch_stride_B,
469  desc.B.leading_dim(),
470  offset_to_residue,
471  offset_to_residue_last_partition
472  );
473 
474  if (error_code) {
475  return error_code;
476  }
477 
478  // The epilogue.
479  return epilogue.initialize(desc);
480  }
481 
484  Index n,
485  Index k,
486  typename Epilogue::Scalar alpha,
487  ScalarA const* d_a,
488  Index lda,
489  ScalarB const* d_b,
490  Index ldb,
491  typename Epilogue::Scalar beta,
492  ScalarC const* d_c,
493  Index ldc,
494  ScalarD* d_d,
495  Index ldd) {
497  GemmCoord(k, n, m, 1),
498  alpha,
499  TensorRef<ScalarA const, 2>(d_a, lda),
500  TensorRef<ScalarB const, 2>(d_b, ldb),
501  beta,
502  TensorRef<ScalarC const, 2>(d_c, ldc),
503  TensorRef<ScalarD, 2>(d_d, ldd)
504  );
505 
506  return this->initialize(desc);
507  }
508 
511  Index n,
512  Index k,
513  typename Epilogue::Scalar alpha,
514  ScalarA const* d_a,
515  Index lda,
516  long long int batch_stride_A,
517  ScalarB const* d_b,
518  Index ldb,
519  long long int batch_stride_B,
520  typename Epilogue::Scalar beta,
521  ScalarC const* d_c,
522  Index ldc,
523  long long int batch_stride_C,
524  ScalarD* d_d,
525  Index ldd,
526  long long int batch_stride_D,
527  Index batch_count) {
529  GemmCoord(k, n, m, batch_count),
530  alpha,
531  TensorRef<ScalarA const, 2>(d_a, lda),
532  batch_stride_A,
533  TensorRef<ScalarB const, 2>(d_b, ldb),
534  batch_stride_B,
535  beta,
536  TensorRef<ScalarC const, 2>(d_c, ldc),
537  batch_stride_C,
538  TensorRef<ScalarD, 2>(d_d, ldd),
539  batch_stride_D
540  );
541 
542  return this->initialize(desc);
543  }
544 
546  template <typename GemmDesc_>
547  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& partitonK_desc, Index partitionK_count_) {
548  // partitionK GEMM is a specialized batched strided GEMM with different K ranges per batch
549  // the problem_size of each batch is (lastK_size, n, m)
550  // add more comments here
551  // the K range for every batch except the last one
552  //assert(partitionK_count_ > 0);
553  partitionK_range = partitonK_desc.problem_size.k() / partitionK_count_;
554  // the k range of the last batch
555  // int lastK_range = (partitonK_desc.problem_size.k() % partitionK_range) + partitionK_range;
556  int lastK_range = partitonK_desc.problem_size.k() - partitionK_range * (partitionK_count_ - 1);
557  int k_size = lastK_range;
558  int lda = partitonK_desc.A.stride(0);
559  int ldb = partitonK_desc.B.stride(0);
560  int ldc = partitonK_desc.C.stride(0);
561  int ldd = partitonK_desc.D.stride(0);
562  int n = partitonK_desc.problem_size.n();
563 
564 
565  long long int batch_stride_A = (kLayoutA == cutlass::MatrixLayout::kColumnMajor) ? lda * partitionK_range : partitionK_range;
566  long long int batch_stride_B = (kLayoutB == cutlass::MatrixLayout::kColumnMajor) ? partitionK_range : partitionK_range * ldb;
567  long long int batch_stride_C = ldc * n;
568  long long int batch_stride_D = ldd * n;
569 
571  //we pass lastK_size as per batch K. there is also a range that will match partitionK_size
572  GemmCoord(k_size, partitonK_desc.problem_size.n(), partitonK_desc.problem_size.m(), partitionK_count_),
573  partitonK_desc.alpha,
574  partitonK_desc.A,
575  batch_stride_A,
576  partitonK_desc.B,
577  batch_stride_B,
578  partitonK_desc.beta,
579  partitonK_desc.C,
580  batch_stride_C,
581  partitonK_desc.D,
582  batch_stride_D
583  );
584 
585  // Set the problem size.
586  problem_size = desc.problem_size;
587 
588  // Compute grid dimensions
589  BlockSwizzle block_swizzle;
590  this->block = dim3(GemmConfig::kThreads);
591  this->grid = block_swizzle.get_grid_layout(
592  problem_size,
593  make_Coord_from_shape<OutputTile>());
594 
595  // Compute offset to residue.
596  // partitionK_range <= problem_size[0]
597  Index gemm_k = problem_size[0];
598  Index offset_to_residue_last_partition = (gemm_k % OutputTile::kD) ? gemm_k - (gemm_k % OutputTile::kD) : 0;
599  Index offset_to_residue = (partitionK_range % OutputTile::kD) ? partitionK_range - (partitionK_range % OutputTile::kD) : 0;
600 
601  // Initialize parameters objects for
602  int error_code = global_to_shared_stream.stream_a.initialize(
603  desc.A.data(),
604  desc.batch_stride_A,
605  desc.A.leading_dim(),
606  offset_to_residue,
607  offset_to_residue_last_partition
608  );
609  if (error_code) {
610  return error_code;
611  }
612 
613  error_code = global_to_shared_stream.stream_b.initialize(
614  desc.B.data(),
615  desc.batch_stride_B,
616  desc.B.leading_dim(),
617  offset_to_residue,
618  offset_to_residue_last_partition
619  );
620 
621  if (error_code) {
622  return error_code;
623  }
624 
625  // The epilogue.
626  return epilogue.initialize(desc);
627  }
628 
629 
632  Index n,
633  Index k,
634  typename Epilogue::Scalar alpha,
635  ScalarA const* d_a,
636  Index lda,
637  ScalarB const* d_b,
638  Index ldb,
639  typename Epilogue::Scalar beta,
640  ScalarC const* d_c,
641  Index ldc,
642  ScalarD* d_d,
643  Index ldd,
644  Index partitionK_count_) {
645 
647  GemmCoord(k, n, m, 1),
648  alpha,
649  TensorRef<ScalarA const, 2>(d_a, lda),
650  TensorRef<ScalarB const, 2>(d_b, ldb),
651  beta,
652  TensorRef<ScalarC const, 2>(d_c, ldc),
653  TensorRef<ScalarD, 2>(d_d, ldd)
654  );
655 
656 
657  return this->initialize(desc, partitionK_count_);
658  }
659  };
660 
661  // The storage for the main loop + prologue.
665 
668 
671  };
672 
675  // The storage for the main loop.
677  // The storage for the epilogue.
678  typename Epilogue::SharedStorage epilogue;
679  };
680 
682  static CUTLASS_DEVICE void shared_load_fence(bool in_loop) {
683  if (SharedLoadStreamA::Iterator::kRequiresLoadFence ||
684  SharedLoadStreamB::Iterator::kRequiresLoadFence) {
685  __syncthreads();
686  }
687  }
688 
690  static CUTLASS_DEVICE void shared_store_fence(bool in_loop) {
691  __syncthreads();
692  }
693 };
694 
696 
697 template <typename GemmTileTraitsHelperA_, typename GemmTileTraitsHelperB_, typename Index_>
705  typedef TileStoreIterator<typename GemmTileTraitsHelperA_::SharedStoreTileTraits,
706  typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar,
716 
723  typedef TileStoreIterator<typename GemmTileTraitsHelperB_::SharedStoreTileTraits,
724  typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar,
734 
736  typedef TileLoadIterator<typename GemmTileTraitsHelperA_::SharedLoadTileTraits,
737  typename GemmTileTraitsHelperA_::Scalar,
744  typedef TileLoadIterator<typename GemmTileTraitsHelperB_::SharedLoadTileTraits,
745  typename GemmTileTraitsHelperB_::Scalar,
751 };
752 
754 
755 template <
757  MatrixLayout::Kind kLayoutA_,
759  MatrixLayout::Kind kLayoutB_,
761  typename GemmConfig_,
763  typename Epilogue_,
765  typename Index_ = int,
766  // The configuration for the A matrix.
767  typename GemmTileTraitsHelperA_ = GemmTileTraitsHelperA<kLayoutA_, GemmConfig_>,
768  // The configuration for the B matrix.
769  typename GemmTileTraitsHelperB_ = GemmTileTraitsHelperB<kLayoutB_, GemmConfig_>,
770  // The helper class to create the streams and iterators.
771  typename Helper_ =
774  // The config.
775  GemmConfig_,
776  // The stream to load A from global memory to shared memory.
777  typename Helper_::GlobalLoadStreamA,
778  // The stream to load B from global memory to shared memory.
779  typename Helper_::GlobalLoadStreamB,
780  // The stream to load A from shared memory.
781  typename Helper_::SharedLoadStreamA,
782  // The stream to load B from shared memory.
783  typename Helper_::SharedLoadStreamB,
784  // The epilogue.
785  Epilogue_,
786  // The block swizzle to reorganize the grid.
787  IdentityBlockSwizzle,
788  // The index.
789  Index_,
790  // The tool used to clear accumulators.
791  ClearAccumulators<typename GemmConfig_::Accumulators::Element> > {
792 };
793 
795 
796 } // namespace gemm
797 } // namespace cutlass
Epilogue::SharedStorage epilogue
Definition: gemm_traits.h:678
GEMM problem description.
Definition: gemm_desc.h:50
GlobalLoadStreamA_ GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: gemm_traits.h:370
GlobalStoreIteratorD::Scalar ScalarD
The scalar for D.
Definition: gemm_epilogue.h:85
GlobalLoadStream< GemmOperand::kA, GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
The stream to load A from global memory to shared memory.
Definition: gemm_traits.h:715
Definition: load_store.h:41
SharedLoadStreamA_ SharedLoadStreamA
The iterator for A to load from shared memory.
Definition: gemm_traits.h:384
Definition: convert.h:33
Definition: gemm_shared_tile.h:128
GlobalLoadStreamB_ GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: gemm_traits.h:377
static int const kThreads
The number of threads.
Definition: gemm_config.h:103
TileStoreIterator< typename GemmTileTraitsHelperA_::SharedStoreTileTraits, typename GemmTileTraitsHelperA_::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
The iterator to store A to shared memory.
Definition: gemm_traits.h:709
Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
Epilogue::ScalarD ScalarD
Definition: gemm_traits.h:394
CUTLASS_HOST_DEVICE int initialize(Index m, Index n, Index k, typename Epilogue::Scalar alpha, ScalarA const *d_a, Index lda, ScalarB const *d_b, Index ldb, typename Epilogue::Scalar beta, ScalarC const *d_c, Index ldc, ScalarD *d_d, Index ldd)
Helper to construct a GEMM params using a BLAS-like API.
Definition: gemm_traits.h:483
The storage in shared memory.
Definition: gemm_traits.h:674
SharedLoadStream< SharedLoadIteratorB > SharedLoadStreamB
The stream to load B from shared memory.
Definition: gemm_traits.h:750
Definition: gemm_global_tile.h:70
TensorRefA A
The source matrix A.
Definition: gemm_desc.h:96
Defines functors for mapping blockIdx to partitions of the GEMM computation.
GemmCoord problem_size
The dimensions of the GEMM.
Definition: gemm_desc.h:90
Defines a structure containing shared storage for each pair.
Definition: gemm_stream_pair.h:91
GlobalLoadStream< GemmOperand::kB, GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
The stream to load B from global memory to shared memory.
Definition: gemm_traits.h:733
GemmConfig_::ScalarB Scalar
The input scalar.
Definition: gemm_traits.h:201
GemmSharedStoreTileAbTraits< MultiplyAddScalar, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/GemmConfig_::InstructionShape::kD, GemmConfig_::OutputTile::kH *GemmConfig_::InstructionShape::kD >, typename GlobalTileTraits::Threads, GemmConfig_::kScalarsPerStsB > SharedStoreTileTraits
The traits class to build the iterator to store data to shared memory for B^T.
Definition: gemm_traits.h:304
Definition: gemm_coord.h:43
GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ > This_
This traits.
Definition: gemm_traits.h:359
SharedLoadStreamB_ SharedLoadStreamB
The iterator for B to load from shared memory.
Definition: gemm_traits.h:386
int partitionK_range
The K range for every partition except the last one.
Definition: gemm_traits.h:422
Defines structures and helpers to launch CUDA kernels within CUTLASS.
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
GlobalLoadStreamPair< GlobalLoadStreamA, GlobalLoadStreamB, GemmConfig::kResidueInProlog > GlobalLoadStream
Assemble the global load streams for A/B.
Definition: gemm_traits.h:407
GemmGlobalTileTraits< GemmOperand::kB, MatrixLayout::kColumnMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^N.
Definition: gemm_traits.h:219
Definition: convert.h:69
ThreadblockTileStorage threadblock_tile
Stores the threadblock tile.
Definition: gemm_traits.h:664
SharedLoadStream< SharedLoadIteratorA > SharedLoadStreamA
The stream to load A from shared memory.
Definition: gemm_traits.h:742
Definition: gemm_shared_tile.h:38
TensorRefB B
The source matrix B.
Definition: gemm_desc.h:102
GemmSharedLoadTileATraits< MultiplyAddScalar const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, GemmConfig_::kScalarsPerLdsA, 0 > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for A^N.
Definition: gemm_traits.h:114
Epilogue_ Epilogue
The epilogue.
Definition: gemm_traits.h:391
GlobalLoadStreamA_::Scalar ScalarA
The scalar for A.
Definition: gemm_traits.h:374
Definition: tile_iterator.h:65
GemmGlobalTileTraits< GemmOperand::kA, MatrixLayout::kColumnMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^N.
Definition: gemm_traits.h:80
GlobalLoadStream::Params global_to_shared_stream
Parameters object for the global load stream.
Definition: gemm_traits.h:425
GemmConfig_::ScalarA Scalar
The input scalar.
Definition: gemm_traits.h:62
Definition: gemm_shared_tile.h:200
Definition: gemm_global_tile.h:163
Epilogue::ScalarC ScalarC
The scalars in the epilogue.
Definition: gemm_traits.h:393
GemmConfig::MultiplyAdd MultiplyAdd
The multiply-add functor.
Definition: gemm_traits.h:389
static CUTLASS_DEVICE void shared_load_fence(bool in_loop)
The memory fence for shared loads.
Definition: gemm_traits.h:682
GemmConfig_ GemmConfig
The configuration.
Definition: gemm_traits.h:365
Definition: gemm_global_stream.h:52
Definition: gemm_traits.h:191
Definition: clear_accumulators.h:38
Parameters object constructable on the host.
Definition: gemm_traits.h:416
Collect the global load streams for multiplicands.
Definition: gemm_stream_pair.h:180
Copy< typename GlobalLoadIteratorB::Fragment > GlobalTransformerB
The data converter for B before storing to shared memory.
Definition: gemm_traits.h:721
GemmConfig_::ScalarB Scalar
The input scalar.
Definition: gemm_traits.h:272
StreamB::Params stream_b
Parameters object for StreamB.
Definition: gemm_stream_pair.h:67
long long batch_stride_A
batch stride for A operand
Definition: gemm_desc.h:99
Definition: gemm.h:92
GemmGlobalIteratorAb< typename GemmTileTraitsHelperB_::GlobalTileTraits, Index_ > GlobalLoadIteratorB
The global iterator to load B from global memory.
Definition: gemm_traits.h:719
static bool const kResidueInProlog
If true, residue is computed in the prologue.
Definition: gemm_config.h:136
An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
Definition: tile_iterator.h:402
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &partitonK_desc, Index partitionK_count_)
Helper to construct a partitionedK GEMM params.
Definition: gemm_traits.h:547
Collect the global load streams for multiplicands.
Definition: gemm_stream_pair.h:50
MultiplyAdd_ MultiplyAdd
The functor to do D = A*B + C.
Definition: gemm_config.h:90
Defines a fragment based on a Shape<> template.
Structure containing the basic launch configuration of a CUDA kernel.
Definition: kernel_launch.h:38
ClearAccumulators_ ClearAccumulators
Clear the accumulators.
Definition: gemm_traits.h:401
Definition: gemm_shared_stream.h:45
GemmGlobalTileTraits< GemmOperand::kA, MatrixLayout::kRowMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
The traits class to build the iterator to load data from global memory for A^T.
Definition: gemm_traits.h:143
Parameters object.
Definition: gemm_stream_pair.h:62
Defines a type for restructuring a tile.
Defines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory...
GemmCoord problem_size
GEMM problem size.
Definition: gemm_traits.h:419
Implements a software-pipelined efficient GEMM.
CUTLASS_HOST_DEVICE int initialize(Index m, Index n, Index k, typename Epilogue::Scalar alpha, ScalarA const *d_a, Index lda, long long int batch_stride_A, ScalarB const *d_b, Index ldb, long long int batch_stride_B, typename Epilogue::Scalar beta, ScalarC const *d_c, Index ldc, long long int batch_stride_C, ScalarD *d_d, Index ldd, long long int batch_stride_D, Index batch_count)
Helper to construct a batched GEMM params.
Definition: gemm_traits.h:510
Defines abstractions for efficiently clearing accumulator tiles.
Definition: tensor_ref.h:131
SharedStreamPair< SharedLoadStreamA, SharedLoadStreamB > SharedStream
Assemble the shared load streams for A/B.
Definition: gemm_traits.h:413
static CUTLASS_DEVICE void shared_store_fence(bool in_loop)
The memory fence for shared stores.
Definition: gemm_traits.h:690
GemmConfig_::ScalarA Scalar
The input scalar.
Definition: gemm_traits.h:125
CUTLASS_HOST_DEVICE int initialize(Index m, Index n, Index k, typename Epilogue::Scalar alpha, ScalarA const *d_a, Index lda, ScalarB const *d_b, Index ldb, typename Epilogue::Scalar beta, ScalarC const *d_c, Index ldc, ScalarD *d_d, Index ldd, Index partitionK_count_)
Helper to construct a partitionedK GEMM params.
Definition: gemm_traits.h:631
Manages a pair of tile allocations as if they are one allocation.
Definition: tile_allocation.h:125
Definition: gemm_traits.h:52
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: gemm_traits.h:435
Definition: matrix_traits.h:357
Definition: gemm/threadblock_swizzle.h:65
GemmSharedStoreTileAbTraits< MultiplyAddScalar, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/GemmConfig_::InstructionShape::kD, GemmConfig_::OutputTile::kW *GemmConfig_::InstructionShape::kD >, typename GlobalTileTraits::Threads, GemmConfig_::kScalarsPerStsA > SharedStoreTileTraits
The traits class to build the iterator to store data to shared memory for A^N.
Definition: gemm_traits.h:94
GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:274
GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:203
GlobalLoadStreamB_::Scalar ScalarB
The scalar for B.
Definition: gemm_traits.h:381
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Defines properties of GEMM computation that impose some constraints on caller.
Definition: gemm_traits.h:349
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
Definition: matrix_traits.h:159
cutlass::gemm::Gemm< This_ > KernelClass
The struct that consumes this Traits.
Definition: gemm_traits.h:362
Definition: matrix_traits.h:159
SharedStream::Params shared_stream
Parameters object for the shared load stream.
Definition: gemm_traits.h:428
ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
The threads shape.
Definition: gemm_global_tile.h:88
long long batch_stride_B
batch stride for B operand
Definition: gemm_desc.h:105
BlockSwizzle_ BlockSwizzle
The block swizzle to reorganize the grid.
Definition: gemm_traits.h:397
TileLoadIterator< typename GemmTileTraitsHelperA_::SharedLoadTileTraits, typename GemmTileTraitsHelperA_::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
The iterator to load A from shared memory.
Definition: gemm_traits.h:740
TileLoadIterator< typename GemmTileTraitsHelperB_::SharedLoadTileTraits, typename GemmTileTraitsHelperB_::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
The iterator to load B from shared memory.
Definition: gemm_traits.h:748
GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage
Memory needed to store the threadblock-scoped GEMM tile.
Definition: gemm_traits.h:410
dim3 block
CUDA threadblock dimensions.
Definition: kernel_launch.h:44
GlobalLoadIteratorC::Scalar ScalarC
The scalar for C.
Definition: gemm_epilogue.h:83
Index_ Index
The index.
Definition: gemm_traits.h:399
GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:64
TileStoreIterator< typename GemmTileTraitsHelperB_::SharedStoreTileTraits, typename GemmTileTraitsHelperB_::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
The iterator to store B to shared memory.
Definition: gemm_traits.h:727
Epilogue::Params epilogue
The params for the epilogue.
Definition: gemm_traits.h:431
Defines a pair of GEMM tile streams.
The shared storage.
Definition: clear_accumulators.h:40
Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
MainLoopSharedStorage main_loop
Definition: gemm_traits.h:676
static MatrixLayout::Kind const kLayoutA
The layout of A.
Definition: gemm_traits.h:372
dim3 grid
CUDA grid dimensions.
Definition: kernel_launch.h:41
Definition: matrix_traits.h:357
GlobalLoadStream::SharedStorage global_to_shared_stream
Storage for GEMM global stream.
Definition: gemm_traits.h:667
Parameters object passed to load iterators.
Definition: gemm_stream_pair.h:192
Definition: gemm_traits.h:698
Implements a software-pipelined efficient GEMM.
GemmGlobalIteratorAb< typename GemmTileTraitsHelperA_::GlobalTileTraits, Index_ > GlobalLoadIteratorA
The global iterator to load A from global memory.
Definition: gemm_traits.h:701
GemmConfig::OutputTile OutputTile
The output tile.
Definition: gemm_traits.h:367
Defines properties of matrices used to denote layout and operands to GEMM kernels.
Copy< typename GlobalLoadIteratorA::Fragment > GlobalTransformerA
The data converter for A before storing to shared memory.
Definition: gemm_traits.h:703
GemmSharedLoadTileBTraits< MultiplyAddScalar const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, GemmConfig_::kScalarsPerLdsB, 0 > SharedLoadTileTraits
The traits class to build the iterator to load from shared memory for B^T.
Definition: gemm_traits.h:324
ClearAccumulators::SharedStorage clear
Storage for clearing accumulators.
Definition: gemm_traits.h:670
StreamA::Params stream_a
Parameters object for StreamA.
Definition: gemm_stream_pair.h:64
GemmGlobalTileTraits< GemmOperand::kB, MatrixLayout::kRowMajor, Scalar const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
The traits class to build the iterator to load data from global memory for B^T.
Definition: gemm_traits.h:290
Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
Defines conversion operations among Fragments of different base type.
Definition: gemm_traits.h:773
OutputTile_ OutputTile
The tile.
Definition: gemm_config.h:88
static MatrixLayout::Kind const kLayoutB
The layout of B.
Definition: gemm_traits.h:379
An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
Definition: tile_iterator.h:841
GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar
The scalar stored in shared memory.
Definition: gemm_traits.h:127