43 #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16) 44 CUTLASS_DEVICE
bool is_zero(half x) {
  // Overload for `half`: the test is done on the raw 16-bit storage of `x`
  // rather than through a floating-point comparison. `x` is taken by value,
  // so the reference below aliases this function's own local copy.
  int16_t const &raw_bits = reinterpret_cast<int16_t &>(x);

  // True only when every one of the 16 bits is clear.
  // NOTE(review): a value whose only set bit is the sign bit (presumably
  // negative zero for `half`) compares unequal here - confirm this is intended.
  return raw_bits == int16_t(0);
}
50 template <
typename Scalar_,
typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
55 typedef typename FragmentMultiplyAdd_::ScalarAccum
ScalarAccum;
80 template <
typename GemmDesc_>
113 template <
typename FragmentA_,
typename FragmentB_>
114 CUTLASS_DEVICE
void evaluate(FragmentA_
const& accum, FragmentB_& output) {
121 template <
typename ScalarAccum,
typename ScalarOutput,
int size>
126 for (
int i = 0; i < size; i++) {
127 FragAccum[i] = accum[i];
128 FragOutput[i] = output[i];
132 for (
int i = 0; i < size; i++) {
133 output[i] = FragOutput[i];
138 template <
typename FragmentA_,
typename FragmentB_>
139 CUTLASS_DEVICE
void evaluate(FragmentA_
const& accum, FragmentB_
const& old, FragmentB_& output) {
143 mad.multiply_add(
params.
alpha, accum, tmp, output);
147 template <
typename ScalarAccum,
typename ScalarOutput,
int size>
153 for (
int i = 0; i < size; i++) {
154 FragAccum[i] = accum[i];
155 FragOutput[i] = output[i];
158 evaluate(FragAccum, FragOld, FragOutput);
160 for (
int i = 0; i < size; i++) {
161 output[i] = FragOutput[i];
CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta)
Initialize the parameters.
Definition: linear_scaling.h:73
Scalar alpha
The alpha/beta scaling params.
Definition: linear_scaling.h:62
CUTLASS_DEVICE bool source_required() const
Definition: linear_scaling.h:108
CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput *output)
Evaluate the functor, without using fragment in the API.
Definition: linear_scaling.h:122
CUTLASS_DEVICE void evaluate(FragmentA_ const &accum, FragmentB_ const &old, FragmentB_ &output)
Evaluate the functor.
Definition: linear_scaling.h:139
CUTLASS_DEVICE void evaluate(FragmentA_ const &accum, FragmentB_ &output)
Evaluate the functor.
Definition: linear_scaling.h:114
Scalar beta
Definition: linear_scaling.h:62
A template defining Fragment Concept.
Definition: fragment.h:99
Params params
Definition: linear_scaling.h:92
FragmentMultiplyAdd_::ScalarAccum ScalarAccum
Definition: linear_scaling.h:55
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: linear_scaling.h:81
Defines multiply-add operations on fragments within a thread.
FragmentMultiplyAdd_ FragmentMultiplyAdd
Definition: linear_scaling.h:57
CUTLASS_DEVICE LinearScaling()
Ctor.
Definition: linear_scaling.h:99
CUTLASS_DEVICE bool is_zero(T x)
Definition: linear_scaling.h:39
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_DEVICE LinearScaling(Params const &_params)
Ctor.
Definition: linear_scaling.h:102
The parameters.
Definition: linear_scaling.h:60
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:51
Scalar_ Scalar
Definition: linear_scaling.h:53
CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput const *old, ScalarOutput *output)
Evaluate the functor, without using fragment in the API.
Definition: linear_scaling.h:148
CUTLASS_HOST_DEVICE Params(Scalar _alpha=0, Scalar _beta=0)
Definition: linear_scaling.h:70