Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
fragment_load_store.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/load_store.h>
31 #include <cutlass/vector.h>
32 
33 namespace cutlass {
34 
36 
37 template <IteratorFragment::Kind kIteratorFragment,
38  int kAccessSize,
39  typename Scalar_,
40  MemorySpace::Kind Memory_,
41  typename FragmentElement_,
42  int kStride>
43 struct FragmentLoad {};
44 
45 template <int kAccessSize,
46  typename Scalar_,
47  MemorySpace::Kind Memory_,
48  typename FragmentElement_,
49  int kStride>
50 struct FragmentLoad<IteratorFragment::kWmmaMatrix,
51  kAccessSize,
52  Scalar_,
53  Memory_,
54  FragmentElement_,
55  kStride> {
57  typedef FragmentElement_ AccessType;
58 
60  static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
61  value.load(&pointer[offset], kStride);
62  }
63 };
64 
65 template <int kAccessSize,
66  typename Scalar_,
67  MemorySpace::Kind Memory_,
68  typename FragmentElement_,
69  int kStride>
71  kAccessSize,
72  Scalar_,
73  Memory_,
74  FragmentElement_,
75  kStride> {
78 
80  static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
81  Load<Scalar_, kAccessSize, Memory_>::load(value, pointer, offset);
82  }
83 };
84 
85 template <IteratorFragment::Kind kIteratorFragment,
86  int kAccessSize,
87  typename Scalar_,
88  MemorySpace::Kind Memory_,
89  typename FragmentElement_,
90  int kStride>
91 struct FragmentStore {};
92 
93 template <int kAccessSize,
94  typename Scalar_,
95  MemorySpace::Kind Memory_,
96  typename FragmentElement_,
97  int kStride>
98 struct FragmentStore<IteratorFragment::kWmmaMatrix,
99  kAccessSize,
100  Scalar_,
101  Memory_,
102  FragmentElement_,
103  kStride> {
105  typedef FragmentElement_ AccessType;
106 
108  static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
109  value.store(&pointer[offset], kStride);
110  }
111 };
112 
113 template <int kAccessSize,
114  typename Scalar_,
115  MemorySpace::Kind Memory_,
116  typename FragmentElement_,
117  int kStride>
119  kAccessSize,
120  Scalar_,
121  Memory_,
122  FragmentElement_,
123  kStride> {
126 
128  static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
129  Store<Scalar_, kAccessSize, Memory_>::store(value, pointer, offset);
130  }
131 };
132 
134 
135 }
Definition: fragment_load_store.h:43
Vectorize< Scalar_, kAccessSize >::Type AccessType
The input type.
Definition: fragment_load_store.h:125
Definition: convert.h:33
Vectorize< Scalar_, kAccessSize >::Type AccessType
The output type.
Definition: fragment_load_store.h:77
static CUTLASS_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
The load function.
Definition: load_store.h:59
static CUTLASS_DEVICE void store(AccessType const &value, Scalar_ *pointer, int offset)
The store function.
Definition: fragment_load_store.h:108
static CUTLASS_DEVICE void store(AccessType const &value, Scalar_ *pointer, int offset)
The store function.
Definition: fragment_load_store.h:128
Kind
Definition: load_store.h:40
static CUTLASS_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
The store function.
Definition: load_store.h:136
Kind
Definition: tile_iterator.h:67
static CUTLASS_DEVICE void load(AccessType &value, Scalar_ const *pointer, int offset)
The load function.
Definition: fragment_load_store.h:80
Defines abstractions for efficiently loading and storing vectors to memory.
Definition: vector.h:61
Defines a 1D vector of elements held in the registers of each thread.
Definition: fragment_load_store.h:91
static CUTLASS_DEVICE void load(AccessType &value, Scalar_ const *pointer, int offset)
The load function.
Definition: fragment_load_store.h:60
Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix.
Definition: tile_iterator.h:66