Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
tile_stream.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
30 #pragma once
31 
32 // clang-format off
33 
34 #include "cutlass/convert.h"
35 #include "cutlass/tile_iterator.h"
36 
38 
39 namespace cutlass {
40 
42 
/// Generic stream for loading and transforming fragments.
/// copy() loads one fragment through the iterator into fetched_fragment and advances it;
/// commit() is documented as applying the Transformer to the fetched fragment.
// NOTE(review): this is a lossy Doxygen rendering of the header. Missing from the
// listing below: the struct header (source line 45, `struct TileLoadStream {`), the
// CUTLASS_HOST_DEVICE qualifiers of the Params constructors (lines 81 and 85), the data
// member declarations (lines 93-103) and the bodies of commit(), intermediate_fragment()
// and fragment() (lines 126, 130, 134). Confirm against the real header before editing.
44 template <typename Iterator_, typename Transformer_ = Copy<typename Iterator_::Fragment> >
46  //
47  // Type definitions
48  //
49 
  /// Iterator used to load tiles (a TileLoadIterator).
51  typedef Iterator_ Iterator;
52 
  /// Transformer applied to fetched fragments; defaults to Copy<Iterator_::Fragment>.
54  typedef Transformer_ Transformer;
55 
  /// Fragment fetched from source memory (the iterator's fragment type).
57  typedef typename Iterator::Fragment Fragment;
58 
  /// Output fragment type produced by the transformer.
60  typedef typename Transformer::OutputFragment TransformedFragment;
61 
  /// Tensor reference type expected by the underlying iterator.
63  typedef typename Iterator::TensorRef TensorRef;
64 
  /// Empty predicate vector; this unpredicated stream's copy() passes no predicates.
66  struct PredicateVector {};
67 
  /// Index type of the underlying iterator.
69  typedef typename Iterator::Index Index;
70 
  /// Parameters used to construct the stream; simply wraps the iterator's params.
72  struct Params {
  /// Parameters forwarded to the iterator.
74  typename Iterator::Params iterator;
75 
76  //
77  // Methods
78  //
79 
  /// Default constructor.
82  Params() {}
83 
  /// Constructs from the iterator's params.
86  Params(typename Iterator::Params const &_iterator) : iterator(_iterator) {}
87  };
88 
89  //
90  // Data members
91  //
92 
  // The member declarations (source lines 93-103) are elided from this listing.
  // Per the cross-reference index they are, in order:
  //   Iterator iterator;                          // iterator to load tiles        (line 94)
  //   Fragment fetched_fragment;                  // fragment loaded via iterator  (line 97)
  //   Transformer transformer;                    // transformation applied        (line 100)
  //   TransformedFragment transformed_fragment;   // output of the transformer     (line 103)
95 
98 
101 
104 
105  //
106  // Methods
107  //
108 
  /// Constructs the underlying iterator from its params and a tensor reference.
110  CUTLASS_DEVICE
111  TileLoadStream(Params const &_params, TensorRef const &_ref)
112  : iterator(_params.iterator, _ref) {}
113 
  /// Constructs the underlying iterator from its params and a threadblock offset,
  /// which defaults to the origin make_Coord(0, 0, 0).
115  CUTLASS_DEVICE
116  TileLoadStream(Params const &_params,
117  Coord<3> const &threadblock_offset = make_Coord(0, 0, 0)
118  ): iterator(_params.iterator, threadblock_offset) { }
119 
  /// Loads a tile into fetched_fragment and advances the iterator.
  /// No transformation is applied here (see commit()).
  // Not virtual: derived streams (e.g. PredicatedTileLoadStream) hide it rather than override it.
121  CUTLASS_DEVICE
122  void copy() { iterator.load_post_increment(fetched_fragment); }
123 
  // Elided (source line 126): `void commit()`, documented as "Commits the fetched
  // fragment and applies a transformation".
125  CUTLASS_DEVICE
127 
  // Elided (source line 130): `Fragment & intermediate_fragment()`. It returns the
  // pre-transform Fragment type, presumably fetched_fragment. TODO confirm.
129  CUTLASS_DEVICE
131 
  // Elided (source line 134): `TransformedFragment & fragment()`, the accessor for the
  // transformed fragment, presumably transformed_fragment. TODO confirm.
133  CUTLASS_DEVICE
135 };
136 
138 
/// Generic stream for transforming and storing fragments.
/// The caller supplies source_fragment (directly or via copy(frag)); copy() stores
/// transformed_fragment through the iterator and advances it.
// NOTE(review): this is a lossy Doxygen rendering of the header. Missing from the
// listing below: the struct header (source line 141, `struct TileStoreStream {`), the
// CUTLASS_HOST_DEVICE qualifiers of the Params constructors (lines 177 and 181), the data
// member declarations (lines 189-199), line 220 inside copy(), and the bodies of
// fragment() and intermediate_fragment() (lines 237, 241).
140 template <typename Iterator_, typename Transformer_ = Copy<typename Iterator_::Fragment> >
142  //
143  // Type definitions
144  //
145 
  /// Iterator used to store tiles (its store_post_increment() is called by copy()).
147  typedef Iterator_ Iterator;
148 
  /// Transformer applied before storing; defaults to Copy<Iterator_::Fragment>.
150  typedef Transformer_ Transformer;
151 
  /// Source fragment type: the transformer's input.
  // Unlike TileLoadStream, Fragment here is Transformer::InputFragment, not Iterator::Fragment.
153  typedef typename Transformer::InputFragment Fragment;
154 
  /// Transformed fragment type; this is what the iterator stores.
156  typedef typename Transformer::OutputFragment TransformedFragment;
157 
  /// Tensor reference type expected by the underlying iterator.
159  typedef typename Iterator::TensorRef TensorRef;
160 
  /// Empty predicate vector; this unpredicated stream's copy() passes no predicates.
162  struct PredicateVector {};
163 
  /// Index type of the underlying iterator.
165  typedef typename Iterator::Index Index;
166 
  /// Parameters used to construct the stream; simply wraps the iterator's params.
168  struct Params {
  /// Parameters forwarded to the iterator.
170  typename Iterator::Params iterator;
171 
172  //
173  // Methods
174  //
175 
  /// Default constructor.
178  Params() {}
179 
  /// Constructs from the iterator's params.
182  Params(typename Iterator::Params const &_iterator) : iterator(_iterator) {}
183  };
184 
185  //
186  // Data members
187  //
188 
  // The member declarations (source lines 189-199) are elided from this listing.
  // Per the cross-reference index they are, in order:
  //   Iterator iterator;                          // iterator to store tiles       (line 190)
  //   Transformer transformer;                    // transformation applied        (line 193)
  //   Fragment source_fragment;                   // source fragment               (line 196)
  //   TransformedFragment transformed_fragment;   // output of the transformer     (line 199)
191 
194 
197 
200 
201  //
202  // Methods
203  //
204 
  /// Constructs the underlying iterator from its params and a tensor reference.
206  CUTLASS_DEVICE
207  TileStoreStream(Params const &_params, TensorRef const &_ref)
208  : iterator(_params.iterator, _ref) {}
209 
  /// Constructs the underlying iterator from its params and a threadblock offset,
  /// which defaults to the origin make_Coord(0, 0, 0).
211  CUTLASS_DEVICE
212  TileStoreStream(Params const &_params,
213  Coord<3> const &threadblock_offset = make_Coord(0, 0, 0)
214  ): iterator(_params.iterator, threadblock_offset) { }
215 
  /// Stores transformed_fragment through the iterator and advances it.
  // NOTE(review): source line 220 is missing from this listing. As rendered, copy()
  // stores transformed_fragment without ever filling it from source_fragment, whereas
  // PredicatedTileStoreStream::copy() first calls
  // transformer.transform(source_fragment, transformed_fragment). The elided line is
  // presumably that same call. Confirm against the real header; if it is truly absent,
  // the stored data would not reflect source_fragment.
217  CUTLASS_DEVICE
218  void copy() {
219 
221  iterator.store_post_increment(transformed_fragment);
222  }
223 
  /// Copies frag into source_fragment, then delegates to copy().
  // copy() is non-virtual, so this always calls TileStoreStream::copy(); derived streams
  // that redefine copy() must also redefine this overload.
225  CUTLASS_DEVICE
226  void copy(Fragment const &frag) {
227  source_fragment = frag;
228  copy();
229  }
230 
  /// No-op: copy() already performs the store.
232  CUTLASS_DEVICE
233  void commit() {}
234 
  // Elided (source line 237): `Fragment & fragment()`. It returns the source
  // (pre-transform) Fragment type, presumably source_fragment. TODO confirm.
236  CUTLASS_DEVICE
238 
  // Elided (source line 241): `TransformedFragment & intermediate_fragment()`, the
  // accessor for the fragment after transforming, presumably transformed_fragment.
240  CUTLASS_DEVICE
242 };
243 
245 
/// Load stream whose loads are guarded by predicates that are computed once at
/// construction from the problem bounds.
// NOTE(review): this is a lossy Doxygen rendering. Missing from the listing below:
// source line 256 (`typedef TileLoadStream<Iterator_, Transformer_> Base;`, which the
// Fragment/Params typedefs depend on), line 271 (the TransformedFragment typedef taken
// from Base) and line 289 (the constructor's first line,
// `PredicatedTileLoadStream(Params const &_params,`).
247 template <typename Iterator_,
248  typename PredicateFunctor_ =
249  RegularTilePredicateFunctor<typename Iterator_::Traits::Delta>,
250  typename Transformer_ = Copy<typename Iterator_::Fragment> >
251 struct PredicatedTileLoadStream : public TileLoadStream<Iterator_, Transformer_> {
252  //
253  // Type definitions
254  //
255 
257 
  /// Iterator used to load tiles (a TileLoadIterator).
259  typedef Iterator_ Iterator;
260 
  /// Functor constructed from the bounds and used to initialize the predicates.
262  typedef PredicateFunctor_ PredicateFunctor;
263 
  /// Transformer, forwarded to the base stream.
265  typedef Transformer_ Transformer;
266 
  /// Fragment fetched from source memory (inherited from the base stream).
268  typedef typename Base::Fragment Fragment;
269 
272 
  /// Parameters object; identical to the base stream's Params.
274  typedef typename Base::Params Params;
275 
276  //
277  // Data members
278  //
279 
  /// Predicates filled in by the constructor and consulted on every copy().
281  typename Iterator::PredicateVector predicates;
282 
283  //
284  // Methods
285  //
286 
  /// Constructs the base stream at threadblock_offset. It then initializes the predicates
  /// by passing PredicateFunctor(bounds) and the same offset to iterator.initialize_predicates().
288  CUTLASS_DEVICE
290  Coord<3> const &bounds,
291  Coord<3> const &threadblock_offset = make_Coord(0, 0, 0))
292  : Base(_params, threadblock_offset) {
293  this->iterator.initialize_predicates(
294  predicates.begin(), PredicateFunctor(bounds), threadblock_offset);
295  }
296 
  /// Loads a tile into fetched_fragment under the predicates and advances the iterator.
  // This hides (does not override) the non-virtual TileLoadStream::copy(). A call made
  // through a TileLoadStream reference would therefore perform an unpredicated load.
298  CUTLASS_DEVICE
299  void copy() { this->iterator.load_post_increment(this->fetched_fragment, predicates.begin()); }
300 };
301 
303 
/// Store stream whose stores are guarded by predicates that are computed once at
/// construction from the problem bounds.
// NOTE(review): this is a lossy Doxygen rendering. Missing from the listing below:
// source line 314 (`typedef TileStoreStream<Iterator_, Transformer_> Base;`), line 329
// (the TransformedFragment typedef taken from Base) and line 347 (the constructor's
// first line, `PredicatedTileStoreStream(Params const &_params,`).
305 template <typename Iterator_,
306  typename PredicateFunctor_ =
307  RegularTilePredicateFunctor<typename Iterator_::Traits::Delta>,
308  typename Transformer_ = Copy<typename Iterator_::Fragment> >
309 struct PredicatedTileStoreStream : public TileStoreStream<Iterator_, Transformer_> {
310  //
311  // Type definitions
312  //
313 
315 
  /// Iterator used to store tiles (its store_post_increment() is called by copy()).
317  typedef Iterator_ Iterator;
318 
  /// Functor constructed from the bounds and used to initialize the predicates.
320  typedef PredicateFunctor_ PredicateFunctor;
321 
  /// Transformer, forwarded to the base stream.
323  typedef Transformer_ Transformer;
324 
  /// Source fragment type (the base stream's Transformer::InputFragment).
326  typedef typename Base::Fragment Fragment;
327 
330 
  /// Parameters object; identical to the base stream's Params.
332  typedef typename Base::Params Params;
333 
334  //
335  // Data members
336  //
337 
  /// Predicates filled in by the constructor and consulted on every copy().
339  typename Iterator::PredicateVector predicates;
340 
341  //
342  // Methods
343  //
344 
  /// Constructs the base stream at threadblock_offset. It then initializes the predicates
  /// by passing PredicateFunctor(bounds) and the same offset to iterator.initialize_predicates().
346  CUTLASS_DEVICE
348  Coord<3> const &bounds,
349  Coord<3> const &threadblock_offset = make_Coord(0, 0, 0))
350  : Base(_params, threadblock_offset) {
351  this->iterator.initialize_predicates(
352  predicates.begin(), PredicateFunctor(bounds), threadblock_offset);
353  }
354 
  /// Transforms source_fragment into transformed_fragment, stores it under the
  /// predicates, and advances the iterator.
356  CUTLASS_DEVICE
357  void copy() {
358  this->transformer.transform(this->source_fragment, this->transformed_fragment);
359  this->iterator.store_post_increment(this->transformed_fragment, predicates.begin());
360  }
361 
  /// Copies frag into source_fragment, then performs the predicated copy().
  // Redefined here on purpose: the inherited TileStoreStream::copy(frag) calls the
  // non-virtual TileStoreStream::copy(), which would store without predicates.
363  CUTLASS_DEVICE
364  void copy(Fragment const &frag) {
365  this->source_fragment = frag;
366  copy();
367  }
368 
  /// No-op: copy() already performs the store. This mirrors the base stream's commit().
370  CUTLASS_DEVICE
371  void commit() {}
372 };
373 
375 
376 } // namespace cutlass
377 
378 // clang-format on
Base::TransformedFragment TransformedFragment
Output fragment from transformer.
Definition: tile_stream.h:329
CUTLASS_DEVICE TransformedFragment & fragment()
Accesses the loaded, transformed fragment.
Definition: tile_stream.h:134
CUTLASS_DEVICE void copy(Fragment const &frag)
Stores the fragment and increments the iterator.
Definition: tile_stream.h:364
CUTLASS_DEVICE Fragment & fragment()
Accesses the source fragment (before transformation).
Definition: tile_stream.h:237
CUTLASS_DEVICE void copy()
Loads a tile and increments the iterator.
Definition: tile_stream.h:122
CUTLASS_DEVICE void copy()
Stores the fragment and increments the iterator.
Definition: tile_stream.h:357
Definition: convert.h:33
Transformer::InputFragment Fragment
Source fragment.
Definition: tile_stream.h:153
Iterator::PredicateVector predicates
Predicates.
Definition: tile_stream.h:281
CUTLASS_DEVICE void copy(Fragment const &frag)
Stores a fragment and increments the iterator.
Definition: tile_stream.h:226
Defines the Tile Traits concept and iterators for efficiently loading from and storing to tiles.
Iterator_ Iterator
TileStoreIterator.
Definition: tile_stream.h:317
Iterator::TensorRef TensorRef
Tensor reference expected by the underlying iterator.
Definition: tile_stream.h:159
Iterator_ Iterator
TileLoadIterator.
Definition: tile_stream.h:51
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
Generic stream for transforming and storing fragments.
Definition: tile_stream.h:141
PredicateFunctor_ PredicateFunctor
Predicate functor.
Definition: tile_stream.h:262
Generic stream for loading and transforming fragments.
Definition: tile_stream.h:251
Base::TransformedFragment TransformedFragment
Output fragment from transformer.
Definition: tile_stream.h:271
Base::Fragment Fragment
Fragment fetched from source memory.
Definition: tile_stream.h:326
Empty predicate vector struct.
Definition: tile_stream.h:162
Base::Params Params
Parameters object used to construct generic store stream.
Definition: tile_stream.h:332
Parameters used to construct the stream.
Definition: tile_stream.h:168
CUTLASS_DEVICE void copy()
Loads a tile and increments the iterator.
Definition: tile_stream.h:299
TransformedFragment transformed_fragment
Transformed fragment from transformer.
Definition: tile_stream.h:103
CUTLASS_DEVICE void commit()
Commits the fetched fragment and applies a transformation.
Definition: tile_stream.h:126
Transformer_ Transformer
Transformer.
Definition: tile_stream.h:323
Iterator iterator
Iterator to load tiles.
Definition: tile_stream.h:94
CUTLASS_HOST_DEVICE Params(typename Iterator::Params const &_iterator)
Constructor with iterator params.
Definition: tile_stream.h:182
CUTLASS_DEVICE Fragment & intermediate_fragment()
Accesses the loaded fragment before transformation.
Definition: tile_stream.h:130
Transformer_ Transformer
Transformer.
Definition: tile_stream.h:265
Base::Params Params
Parameters object used to construct generic load stream.
Definition: tile_stream.h:274
Transformer_ Transformer
Transformer.
Definition: tile_stream.h:150
Iterator::Index Index
Index type.
Definition: tile_stream.h:69
Iterator::Params iterator
Parameters to the iterator.
Definition: tile_stream.h:170
Iterator::PredicateVector predicates
Predicates.
Definition: tile_stream.h:339
Transformer transformer
Transformation applied to fragments.
Definition: tile_stream.h:100
CUTLASS_DEVICE PredicatedTileStoreStream(Params const &_params, Coord< 3 > const &bounds, Coord< 3 > const &threadblock_offset=make_Coord(0, 0, 0))
Ctor.
Definition: tile_stream.h:347
Fragment fetched_fragment
Fragment loaded via iterator.
Definition: tile_stream.h:97
Transformer_ Transformer
Transformer.
Definition: tile_stream.h:54
CUTLASS_DEVICE TileStoreStream(Params const &_params, TensorRef const &_ref)
Ctor.
Definition: tile_stream.h:207
CUTLASS_HOST_DEVICE Params()
Default constructor.
Definition: tile_stream.h:82
CUTLASS_DEVICE TileStoreStream(Params const &_params, Coord< 3 > const &threadblock_offset=make_Coord(0, 0, 0))
Ctor.
Definition: tile_stream.h:212
CUTLASS_DEVICE TileLoadStream(Params const &_params, TensorRef const &_ref)
Ctor.
Definition: tile_stream.h:111
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
Iterator::TensorRef TensorRef
Tensor reference expected by the stream.
Definition: tile_stream.h:63
CUTLASS_HOST_DEVICE Params()
Default constructor.
Definition: tile_stream.h:178
Transformer transformer
Transformation applied to inputs.
Definition: tile_stream.h:193
Transformer::OutputFragment TransformedFragment
Output fragment from transformer.
Definition: tile_stream.h:60
CUTLASS_DEVICE PredicatedTileLoadStream(Params const &_params, Coord< 3 > const &bounds, Coord< 3 > const &threadblock_offset=make_Coord(0, 0, 0))
Ctor.
Definition: tile_stream.h:289
PredicateFunctor_ PredicateFunctor
Predicate functor.
Definition: tile_stream.h:320
Generic stream for loading and transforming fragments.
Definition: tile_stream.h:45
Empty predicate vector struct.
Definition: tile_stream.h:66
CUTLASS_DEVICE void commit()
Commits the store operation.
Definition: tile_stream.h:233
Iterator::Fragment Fragment
Fragment fetched from source memory.
Definition: tile_stream.h:57
Parameters object used to construct generic load stream.
Definition: tile_stream.h:72
Iterator::Index Index
Index type.
Definition: tile_stream.h:165
Base::Fragment Fragment
Fragment fetched from source memory.
Definition: tile_stream.h:268
CUTLASS_HOST_DEVICE Params(typename Iterator::Params const &_iterator)
Constructor with iterator params.
Definition: tile_stream.h:86
Transformer::OutputFragment TransformedFragment
Transformed fragment, compatible with Iterator::Fragment.
Definition: tile_stream.h:156
Fragment source_fragment
Source fragment.
Definition: tile_stream.h:196
CUTLASS_DEVICE TransformedFragment & intermediate_fragment()
Accesses the fragment after transforming.
Definition: tile_stream.h:241
TileStoreStream< Iterator_, Transformer_ > Base
Definition: tile_stream.h:314
CUTLASS_DEVICE TileLoadStream(Params const &_params, Coord< 3 > const &threadblock_offset=make_Coord(0, 0, 0))
Ctor.
Definition: tile_stream.h:116
Iterator_ Iterator
TileStoreIterator.
Definition: tile_stream.h:147
CUTLASS_DEVICE void commit()
Commits the store operation.
Definition: tile_stream.h:371
Generic stream for transforming and storing fragments.
Definition: tile_stream.h:309
TransformedFragment transformed_fragment
Transformed fragment from transformer.
Definition: tile_stream.h:199
Defines conversion operations among Fragments of different base type.
Iterator_ Iterator
TileLoadIterator.
Definition: tile_stream.h:259
TileLoadStream< Iterator_, Transformer_ > Base
Definition: tile_stream.h:256
CUTLASS_DEVICE void copy()
Stores a fragment and increments the iterator.
Definition: tile_stream.h:218
Iterator::Params iterator
Parameters to the iterator.
Definition: tile_stream.h:74
Iterator iterator
Iterator to store tiles.
Definition: tile_stream.h:190