Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
zip_tile_iterator.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
30 #pragma once
31 
32 #include "cutlass/coord.h"
33 #include "cutlass/zip_tensor_ref.h"
34 #include "cutlass/zip_fragment.h"
35 #include "cutlass/util/pair.h"
36 
namespace cutlass {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Tile iterator composed of two underlying tile iterators.
///
/// Every load, store, and advance operation is forwarded to both `first` and `second`. The
/// fragment, index, long-index, and tensor-reference types are the "zipped" pairs of the
/// corresponding types of the two underlying iterators.
///
/// NOTE(review): predicates are initialized through `First` only (PredicateVector is
/// `First::PredicateVector`), and the same predicate iterator is handed to both underlying
/// iterators in the predicated load/store overloads. `Second` is therefore assumed to be
/// compatible with `First`'s predicate layout; this is not checked here.
template <typename First_, typename Second_>
class ZipTileIterator {
 public:
  /// First iterator type
  typedef First_ First;

  /// Second iterator type
  typedef Second_ Second;

  /// Params object holding the parameters of both underlying iterators
  struct Params {
    /// Parameters of first iterator
    typename First::Params first;

    /// Parameters of second iterator
    typename Second::Params second;

    /// Constructs a parameters object
    CUTLASS_HOST_DEVICE
    Params() {}

    /// Constructs a parameters object from the parameters of each iterator
    CUTLASS_HOST_DEVICE
    Params(typename First::Params const &_first, typename Second::Params const &_second)
        : first(_first), second(_second) {}
  };

  /// Fragment type: a pair of the underlying iterators' fragments
  typedef ZipFragment<typename First::Fragment, typename Second::Fragment> Fragment;

  /// Predicate vector (taken from the first iterator)
  typedef typename First::PredicateVector PredicateVector;

  /// Index type: a pair of the underlying iterators' index types
  typedef platform::Pair<typename First::Index, typename Second::Index> Index;

  /// Long index type: a pair of the underlying iterators' long index types
  typedef platform::Pair<typename First::LongIndex, typename Second::LongIndex> LongIndex;

  /// Tensor reference: a pair of the underlying iterators' tensor references
  typedef ZipTensorRef<
    typename First::TensorRef,
    typename Second::TensorRef> TensorRef;

  //
  // Data members
  //

  /// First iterator
  First first;

  /// Second iterator
  Second second;

  //
  // Methods
  //

  /// Default constructor
  CUTLASS_DEVICE
  ZipTileIterator() {}

  /// Constructs a zip iterator from params; the same threadblock offset is applied to both
  CUTLASS_DEVICE
  ZipTileIterator(Params const &_params, Coord<3> const &threadblock_offset = make_Coord(0, 0, 0))
      : first(_params.first, threadblock_offset), second(_params.second, threadblock_offset) {}

  /// Constructs a zip iterator from iterator instances
  CUTLASS_DEVICE
  ZipTileIterator(First const &_first, Second const &_second) : first(_first), second(_second) {}

  /// Constructs a zip iterator from a zip tensor reference
  CUTLASS_DEVICE
  ZipTileIterator(TensorRef const &ref) : first(ref.first), second(ref.second) {}

  /// Constructs a zip iterator from params and a zip tensor reference
  CUTLASS_DEVICE
  ZipTileIterator(Params const &_params, TensorRef const &ref)
      : first(_params.first, ref.first), second(_params.second, ref.second) {}

  //
  // Predicate initialization
  //

  /// Initializes a predicate vector from tile bounds.
  ///
  /// Delegates to the first iterator only; the resulting predicates are later shared by both
  /// iterators (see the class-level note).
  template <
    /// Predicate iterator type
    typename PredicateIterator>
  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
                                                 Coord<3> const &bounds,
                                                 Coord<3> const &block_offset = make_Coord(0,
                                                                                           0,
                                                                                           0)) {
    first.initialize_predicates(predicate_it, bounds, block_offset);
  }

  /// Initializes a predicate vector using an arbitrary predicate functor.
  ///
  /// Delegates to the first iterator only (see the class-level note).
  template <
    /// Predicate iterator type
    typename PredicateIterator,
    /// Functor computing predicates
    typename PredicateFunctor>
  CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it,
                                                 PredicateFunctor const &functor,
                                                 Coord<3> const &block_offset) {
    first.initialize_predicates(predicate_it, functor, block_offset);
  }

  //
  // No predicates
  //

  /// Loads a fragment and increments without predicates
  template <typename Fragment>
  CUTLASS_DEVICE void load_post_increment(Fragment &fragment) {
    first.load_post_increment(fragment.first);
    second.load_post_increment(fragment.second);
  }

  /// Loads a fragment and increments without predicates; the same offset is used for both
  template <typename Fragment>
  CUTLASS_DEVICE void load_post_increment(Fragment &fragment,
                                          Coord<4> const &offset) {
    first.load_post_increment(fragment.first, offset);
    second.load_post_increment(fragment.second, offset);
  }

  /// Loads a fragment without predicates
  template <typename Fragment>
  CUTLASS_DEVICE void load(Fragment &fragment) const {
    first.load(fragment.first);
    second.load(fragment.second);
  }

  /// Loads a fragment without predicates; the same offset is used for both
  template <typename Fragment>
  CUTLASS_DEVICE void load(Fragment &fragment,
                           Coord<4> const &offset) const {
    first.load(fragment.first, offset);
    second.load(fragment.second, offset);
  }

  /// Stores a fragment and increments without predicates
  template <typename Fragment>
  CUTLASS_DEVICE void store_post_increment(Fragment const &fragment) {
    first.store_post_increment(fragment.first);
    second.store_post_increment(fragment.second);
  }

  /// Stores a fragment and increments without predicates; the same offset is used for both
  template <typename Fragment>
  CUTLASS_DEVICE void store_post_increment(Fragment const &fragment,
                                           Coord<4> const &offset) {
    first.store_post_increment(fragment.first, offset);
    second.store_post_increment(fragment.second, offset);
  }

  /// Stores a fragment without predicates
  template <typename Fragment>
  CUTLASS_DEVICE void store(Fragment const &fragment) const {
    first.store(fragment.first);
    second.store(fragment.second);
  }

  /// Stores a fragment without predicates; the same offset is used for both
  template <typename Fragment>
  CUTLASS_DEVICE void store(Fragment const &fragment,
                            Coord<4> const &offset) const {
    first.store(fragment.first, offset);
    second.store(fragment.second, offset);
  }

  //
  // With predication
  //

  /// Loads a fragment and increments, using predicates (same predicates for both iterators)
  template <typename Fragment, typename PredicateIterator>
  CUTLASS_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) {
    first.load_post_increment(fragment.first, pred_it);
    second.load_post_increment(fragment.second, pred_it);
  }

  /// Loads a fragment with predicates (same predicates for both iterators)
  template <typename Fragment, typename PredicateIterator>
  CUTLASS_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const {
    first.load(fragment.first, pred_it);
    second.load(fragment.second, pred_it);
  }

  /// Stores a fragment and increments, using predicates (same predicates for both iterators)
  template <typename Fragment, typename PredicateIterator>
  CUTLASS_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) {
    first.store_post_increment(fragment.first, pred_it);
    second.store_post_increment(fragment.second, pred_it);
  }

  /// Stores a fragment with predicates (same predicates for both iterators)
  template <typename Fragment, typename PredicateIterator>
  CUTLASS_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const {
    first.store(fragment.first, pred_it);
    second.store(fragment.second, pred_it);
  }

  //
  // Advances the iterators
  //

  /// Increments both iterators by `count` tiles
  CUTLASS_DEVICE ZipTileIterator &increment(int count = 1) {
    first.increment(count);
    second.increment(count);
    return *this;
  }

  /// Increments to next tile
  CUTLASS_DEVICE ZipTileIterator &operator++() { return increment(); }

  /// Advances both iterators by `count` tiles
  CUTLASS_DEVICE ZipTileIterator &operator+=(int count) { return increment(count); }

  /// Adds a vector offset to the underlying iterators (same offset for both)
  CUTLASS_DEVICE ZipTileIterator &operator+=(Coord<3> const &offset) {
    first += offset;
    second += offset;
    return *this;
  }

  /// Decrements both iterators by `count` tiles
  CUTLASS_DEVICE ZipTileIterator &decrement(int count = 1) {
    first.decrement(count);
    second.decrement(count);
    return *this;
  }

  /// Decrements to previous tile
  CUTLASS_DEVICE ZipTileIterator &operator--() { return decrement(); }

  /// Moves both iterators back by `count` tiles
  CUTLASS_DEVICE ZipTileIterator &operator-=(int count) { return decrement(count); }

  /// Adds a pointer offset to both iterators; each receives its own component of the pair
  CUTLASS_DEVICE void add_pointer_offset(LongIndex offset) {
    first.add_pointer_offset(offset.first);
    second.add_pointer_offset(offset.second);
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass
ZipFragment< typename First::Fragment, typename Second::Fragment > Fragment
Fragment type.
Definition: zip_tile_iterator.h:70
Second::Params second
Parameters of second iterator.
Definition: zip_tile_iterator.h:57
First_ First
First iterator type.
Definition: zip_tile_iterator.h:46
Definition: convert.h:33
Definition: zip_tensor_ref.h:38
CUTLASS_DEVICE ZipTileIterator(Params const &_params, Coord< 3 > const &threadblock_offset=make_Coord(0, 0, 0))
Constructs a zip iterator from params.
Definition: zip_tile_iterator.h:106
CUTLASS_DEVICE void store_post_increment(Fragment const &fragment, Coord< 4 > const &offset)
Stores a fragment and increments without predicates.
Definition: zip_tile_iterator.h:193
CUTLASS_DEVICE void load_post_increment(Fragment &fragment, Coord< 4 > const &offset)
Loads a fragment and increments without predicates.
Definition: zip_tile_iterator.h:163
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
CUTLASS_DEVICE void store(Fragment const &fragment) const
Stores a fragment without predicates.
Definition: zip_tile_iterator.h:201
platform::Pair< typename First::Index, typename Second::Index > Index
Index type.
Definition: zip_tile_iterator.h:76
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 1-element coordinate.
Definition: coord.h:368
CUTLASS_HOST_DEVICE Params(typename First::Params const &_first, typename Second::Params const &_second)
Constructs a parameters object.
Definition: zip_tile_iterator.h:65
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
Initializes a predicate vector using a RegularTilePredicateFunctor.
Definition: zip_tile_iterator.h:130
CUTLASS_HOST_DEVICE Params()
Constructs a parameters object.
Definition: zip_tile_iterator.h:61
A template defining Fragment Concept.
Definition: zip_fragment.h:46
CUTLASS_DEVICE ZipTileIterator & operator+=(Coord< 3 > const &offset)
Adds a vector offset to the underlying iterators.
Definition: zip_tile_iterator.h:263
CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)
Initializes a predicate vector using an arbitrary predicate functor.
Definition: zip_tile_iterator.h:144
First::Params first
Parameters of first iterator.
Definition: zip_tile_iterator.h:54
CUTLASS_DEVICE void load_post_increment(Fragment &fragment)
Loads a fragment and increments without predicates.
Definition: zip_tile_iterator.h:156
CUTLASS_DEVICE ZipTileIterator(Params const &_params, TensorRef const &ref)
Constructs a zip iterator from params and a zip tensor reference.
Definition: zip_tile_iterator.h:119
CUTLASS_DEVICE ZipTileIterator & operator+=(int count)
Advances both iterators by count tiles.
Definition: zip_tile_iterator.h:260
CUTLASS_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const
Stores a fragment with predicates.
Definition: zip_tile_iterator.h:241
ZipTensorRef< typename First::TensorRef, typename Second::TensorRef > TensorRef
Tensor reference.
Definition: zip_tile_iterator.h:84
CUTLASS_DEVICE ZipTileIterator & operator-=(int count)
Decrements to previous tile.
Definition: zip_tile_iterator.h:280
CUTLASS_DEVICE void store_post_increment(Fragment const &fragment)
Stores a fragment and increments without predicates.
Definition: zip_tile_iterator.h:186
Models a pair of fragments.
First first
First fragment object.
Definition: zip_fragment.h:61
CUTLASS_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it)
Loads a fragment and increments, using predicates.
Definition: zip_tile_iterator.h:220
Defines a structure containing a pair of TensorRef-like objects.
Second_ Second
Second iterator type.
Definition: zip_tile_iterator.h:49
CUTLASS_DEVICE ZipTileIterator(First const &_first, Second const &_second)
Constructs a zip iterator from iterator instances.
Definition: zip_tile_iterator.h:111
CUTLASS_DEVICE ZipTileIterator & operator++()
Increments to next tile.
Definition: zip_tile_iterator.h:258
Constructs an iterator from a pair of iterators.
Definition: zip_tile_iterator.h:43
Second second
Second fragment object.
Definition: zip_fragment.h:64
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_DEVICE ZipTileIterator()
Default constructor.
Definition: zip_tile_iterator.h:102
First first
First iterator.
Definition: zip_tile_iterator.h:91
Params object.
Definition: zip_tile_iterator.h:52
Second second
Second iterator.
Definition: zip_tile_iterator.h:94
T1 second
Definition: pair.h:49
Defines a pair<>
CUTLASS_DEVICE void store(Fragment const &fragment, Coord< 4 > const &offset) const
Stores a fragment without predicates.
Definition: zip_tile_iterator.h:208
Constructs an iterator from a pair of iterators.
Definition: pair.h:39
T1 first
Definition: pair.h:48
platform::Pair< typename First::LongIndex, typename Second::LongIndex > LongIndex
Long index type.
Definition: zip_tile_iterator.h:79
CUTLASS_DEVICE void load(Fragment &fragment) const
Loads a fragment without predicates.
Definition: zip_tile_iterator.h:171
First::PredicateVector PredicateVector
Predicate vector.
Definition: zip_tile_iterator.h:73
CUTLASS_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const
Loads a fragment with predicates.
Definition: zip_tile_iterator.h:227
CUTLASS_DEVICE ZipTileIterator & decrement(int count=1)
Decrements both iterators to a previous tile.
Definition: zip_tile_iterator.h:270
CUTLASS_DEVICE ZipTileIterator & operator--()
Decrements to previous tile.
Definition: zip_tile_iterator.h:277
CUTLASS_DEVICE void add_pointer_offset(LongIndex offset)
Adds an offset to both iterators.
Definition: zip_tile_iterator.h:283
CUTLASS_DEVICE ZipTileIterator(TensorRef const &ref)
Constructs a zip iterator from a zip tensor reference.
Definition: zip_tile_iterator.h:115
CUTLASS_DEVICE ZipTileIterator & increment(int count=1)
Increments both iterators to a subsequent tile.
Definition: zip_tile_iterator.h:251
CUTLASS_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it)
Stores a fragment and increments, using predicates.
Definition: zip_tile_iterator.h:234
CUTLASS_DEVICE void load(Fragment &fragment, Coord< 4 > const &offset) const
Loads a fragment without predicates.
Definition: zip_tile_iterator.h:178