Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
cutlass_math.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
26 #pragma once
27 
33 #include "cutlass/util/platform.h"
34 
35 namespace cutlass {
36 
37 /******************************************************************************
38  * Static math utilities
39  ******************************************************************************/
40 
/**
 * @brief Statically determines whether N is a positive power of two.
 *
 * value is true for N in {1, 2, 4, 8, ...} and false for every other N.
 * Non-positive N is rejected explicitly: the bare bit trick (N & (N - 1)) == 0
 * would report true for N == 0, and N - 1 overflows for the most negative int
 * (the right-hand operand of && is not evaluated when N <= 0).
 */
template <int N>
struct is_pow2 : platform::integral_constant<bool, (N > 0) && ((N & (N - 1)) == 0)> {};
46 
/**
 * @brief Statically computes floor(log2(N)) for N > 0.
 *
 * Repeatedly halves CurrentVal (which starts at N) and counts the steps until
 * it reaches 1; that step count is the result.
 * e.g. log2_down<1> == 0, log2_down<8> == 3, log2_down<9> == 3.
 *
 * CurrentVal and Count are recursion state and should not be supplied by callers.
 */
template <int N, int CurrentVal = N, int Count = 0>
struct log2_down {
  static_assert(N > 0, "log2_down requires N > 0");
  // For N <= 0 the halving would never reach 1. Jump straight to the base case
  // so the static_assert above is the only diagnostic, not a runaway template
  // recursion.
  enum { value = log2_down<N, ((N > 0) ? (CurrentVal >> 1) : 1), Count + 1>::value };
};

// Base case: CurrentVal has been halved down to 1, so Count == floor(log2(N)).
template <int N, int Count>
struct log2_down<N, 1, Count> {
  enum { value = Count };
};
61 
/**
 * @brief Statically computes ceil(log2(N)) for N > 0.
 *
 * Halves CurrentVal (which starts at N) down to 1 to obtain floor(log2(N)).
 * It then adds one when N is not an exact power of two.
 * e.g. log2_up<1> == 0, log2_up<8> == 3, log2_up<9> == 4.
 *
 * CurrentVal and Count are recursion state and should not be supplied by callers.
 */
template <int N, int CurrentVal = N, int Count = 0>
struct log2_up {
  static_assert(N > 0, "log2_up requires N > 0");
  // For N <= 0 the halving would never reach 1. Jump straight to the base case
  // so the static_assert above is the only diagnostic, not a runaway template
  // recursion.
  enum { value = log2_up<N, ((N > 0) ? (CurrentVal >> 1) : 1), Count + 1>::value };
};

// Base case: Count == floor(log2(N)); bump it by one unless N is exactly 2^Count.
template <int N, int Count>
struct log2_up<N, 1, Count> {
  enum { value = ((1 << Count) < N) ? Count + 1 : Count };
};
76 
/**
 * @brief Statically estimates sqrt(N) as a power of two.
 *
 * Takes the exponent ceil(log2(N)), halves it (rounding down), and raises 2 to
 * that power. e.g. sqrt_est<16> == 4, sqrt_est<8> == 2, sqrt_est<9> == 4.
 */
template <int N>
struct sqrt_est {
  // log2_up<N>::value is non-negative, so >> 1 is the same as / 2 here.
  enum { value = 1 << (log2_up<N>::value >> 1) };
};
84 
/**
 * @brief Statically divides Dividend by Divisor, requiring an exact quotient.
 *
 * value is Dividend / Divisor. Compilation fails with a readable message when
 * Divisor is zero or does not evenly divide Dividend.
 */
template <int Dividend, int Divisor>
struct divide_assert {
  // Substitute 1 for a zero Divisor in the arithmetic below. A zero Divisor is
  // then reported by the static_assert instead of an opaque
  // "division by zero is not a constant expression" error.
  enum { value = Dividend / (Divisor != 0 ? Divisor : 1) };

  static_assert((Divisor != 0), "Divisor must be nonzero");
  static_assert((Dividend % (Divisor != 0 ? Divisor : 1) == 0), "Not an even multiple");
};
95 
96 /******************************************************************************
97  * Rounding
98  ******************************************************************************/
99 
/**
 * @brief Rounds dividend up to a multiple of divisor.
 *
 * Despite the name, this does not round to the *nearest* multiple. For a
 * non-negative dividend and positive divisor it returns the smallest multiple
 * of divisor that is >= dividend.
 * e.g. round_nearest(10, 4) == 12, round_nearest(8, 4) == 8.
 * The intermediate dividend + divisor - 1 can overflow near the type's maximum.
 */
template <typename dividend_t, typename divisor_t>
CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) {
  // Ceiling division gives the number of divisor-sized chunks needed to cover
  // dividend. Scaling back by divisor gives the rounded value, which is then
  // converted to dividend_t on return.
  return divisor * ((dividend + divisor - 1) / divisor);
}
107 
/**
 * @brief Greatest common divisor of a and b (Euclid's algorithm).
 *
 * Alternately reduces b modulo a and a modulo b until one of them becomes
 * zero; the other is the result. gcd(0, b) == b and gcd(a, 0) == a.
 * For signed inputs the result may be negative.
 */
template <typename value_t>
CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) {
  while (a != 0) {
    b %= a;
    if (b == 0) {
      return a;
    }
    a %= b;
  }
  return b;
}
120 
/**
 * @brief Least common multiple of a and b.
 *
 * Computed as (a / gcd(a, b)) * b. Dividing before multiplying keeps the
 * intermediate value smaller. Returns 0 when the gcd is 0, which happens only
 * when a == b == 0.
 */
template <typename value_t>
CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
  value_t const divisor = gcd(a, b);
  if (!divisor) {
    return 0;  // Both inputs are zero; avoid dividing by zero.
  }
  return (a / divisor) * b;
}
130 
/**
 * @brief Counts the leading zero bits of x, treated as a 32-bit quantity.
 *
 * Scans from bit 31 down to bit 0. Returns the number of clear bits above the
 * highest set bit, or 32 when none of bits 0..31 are set.
 * Only the low 32 bits of x are examined, even when value_t is wider.
 */
template <typename value_t>
CUTLASS_HOST_DEVICE value_t clz(value_t x) {
  for (int i = 31; i >= 0; --i) {
    // Build the mask from an unsigned literal. `1 << 31` is not representable
    // in a signed int (undefined behavior before C++14). For a 64-bit value_t,
    // the resulting negative mask would also sign-extend and spuriously match
    // bits 32-63.
    if ((1u << i) & x) return 31 - i;
  }
  return 32;
}
143 
/**
 * @brief Returns ceil(log2(x)), derived from the 32-bit leading-zero count.
 *
 * floor(log2(x)) is 31 - clz(x); one is added when x is not an exact power of
 * two. For x == 0, clz(x) is 32, so the result is -1 converted to value_t.
 */
template <typename value_t>
CUTLASS_HOST_DEVICE value_t find_log2(value_t x) {
  int const floor_log2 = 31 - clz(x);
  bool const exact_power_of_two = (x & (x - 1)) == 0;
  return floor_log2 + (exact_power_of_two ? 0 : 1);
}
150 
151 /******************************************************************************
152  * Min/Max
153  ******************************************************************************/
154 
/// Compile-time minimum of two int constants, exposed as kValue.
template <int A, int B>
struct Min {
  static int const kValue = (B < A) ? B : A;
};
159 
/// Compile-time maximum of two int constants, exposed as kValue.
template <int A, int B>
struct Max {
  static int const kValue = (B > A) ? B : A;
};
164 
165 } // namespace cutlass
Definition: cutlass_math.h:91
Definition: convert.h:33
static int const kValue
Definition: cutlass_math.h:157
CUTLASS_HOST_DEVICE value_t find_log2(value_t x)
Definition: cutlass_math.h:145
Definition: cutlass_math.h:51
C++ features that may be otherwise unimplemented for CUDA device functions.
Definition: cutlass_math.h:156
Definition: cutlass_math.h:53
CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b)
Definition: cutlass_math.h:125
CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor)
Definition: cutlass_math.h:104
Definition: cutlass_math.h:68
std::integral_constant
Definition: platform.h:282
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
#define static_assert(__e, __m)
Definition: platform.h:153
Definition: cutlass_math.h:161
Definition: cutlass_math.h:82
CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b)
Definition: cutlass_math.h:112
Definition: cutlass_math.h:90
Definition: cutlass_math.h:66
CUTLASS_HOST_DEVICE value_t clz(value_t x)
Definition: cutlass_math.h:137
Definition: cutlass_math.h:45
static int const kValue
Definition: cutlass_math.h:162
Definition: cutlass_math.h:81