template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(Scalar const *alpha, Scalar const *beta)
Initialize the parameters.
Definition: linear_scaling_device_ptr.h:102
struct Params
The parameters.
Definition: linear_scaling_device_ptr.h:55
CUTLASS_HOST_DEVICE Params(Scalar const *alpha_ptr, Scalar const *beta_ptr)
Definition: linear_scaling_device_ptr.h:83
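Because Params stores alpha_ptr and beta_ptr instead of values, the scalars are read when they are needed, not when Params is built. The following is a minimal host-only sketch of that idea. PtrParamsSketch is a hypothetical name. In a CUDA build, the pointers would refer to device memory and be dereferenced inside the kernel.

```cpp
#include <cassert>

// Hypothetical sketch (not the CUTLASS definition): Params that keep pointers
// to alpha and beta. The pointers are dereferenced only when a value is
// requested, so the scalars may be written after Params is constructed.
template <typename Scalar>
struct PtrParamsSketch {
  Scalar const *alpha_ptr;
  Scalar const *beta_ptr;

  PtrParamsSketch(Scalar const *alpha_ptr_, Scalar const *beta_ptr_)
      : alpha_ptr(alpha_ptr_), beta_ptr(beta_ptr_) {}

  // Read the current values behind the pointers.
  Scalar alpha() const {
    assert(alpha_ptr != nullptr);
    return *alpha_ptr;
  }
  Scalar beta() const {
    assert(beta_ptr != nullptr);
    return *beta_ptr;
  }
};
```

Values written after construction are the ones observed, which is the point of passing pointers.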
Implements the BLAS linear scaling function alpha*AB + beta*C.
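For reference, the BLAS-style update can be written as a naive triple loop. This host-only sketch shows only the arithmetic, D = alpha*(A*B) + beta*C, and not CUTLASS's tiling. The helper reference_gemm is hypothetical, and the matrices are assumed to be row-major.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive reference GEMM: D = alpha * (A * B) + beta * C.
// A is M x K, B is K x N, C and D are M x N, all row-major (an assumption).
template <typename Scalar>
std::vector<Scalar> reference_gemm(std::size_t M, std::size_t N, std::size_t K,
                                   Scalar alpha, std::vector<Scalar> const &A,
                                   std::vector<Scalar> const &B, Scalar beta,
                                   std::vector<Scalar> const &C) {
  assert(A.size() == M * K && B.size() == K * N && C.size() == M * N);
  std::vector<Scalar> D(M * N);
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      Scalar acc = Scalar(0);  // accumulates (A * B)[i][j]
      for (std::size_t k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * N + j];
      }
      D[i * N + j] = alpha * acc + beta * C[i * N + j];  // linear scaling
    }
  }
  return D;
}
```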
CUTLASS_HOST_DEVICE int initialize(Scalar alpha, Scalar beta)
Initialize the parameters.
Definition: linear_scaling_device_ptr.h:91
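The two initialize overloads let the same Params accept either scalar values or pointers to scalars. The sketch below shows one way such a dual-mode object could work. It is an assumption, not the CUTLASS layout. DualParamsSketch and its use_pointers flag are hypothetical, and treating a return value of 0 as success is also assumed.

```cpp
#include <cassert>

// Hypothetical sketch: Params that can be initialized with values or with
// pointers. A flag records which form is active. Returning 0 on success is
// an assumed convention.
template <typename Scalar>
struct DualParamsSketch {
  Scalar alpha_value = Scalar(1);
  Scalar beta_value = Scalar(0);
  Scalar const *alpha_ptr = nullptr;
  Scalar const *beta_ptr = nullptr;
  bool use_pointers = false;

  int initialize(Scalar alpha, Scalar beta) {
    alpha_value = alpha;
    beta_value = beta;
    use_pointers = false;
    return 0;
  }

  int initialize(Scalar const *alpha, Scalar const *beta) {
    if (alpha == nullptr || beta == nullptr) {
      return -1;  // reject null pointers, leave state unchanged
    }
    alpha_ptr = alpha;
    beta_ptr = beta;
    use_pointers = true;
    return 0;
  }

  // Resolve the scalar from whichever form is active.
  Scalar alpha() const { return use_pointers ? *alpha_ptr : alpha_value; }
  Scalar beta() const { return use_pointers ? *beta_ptr : beta_value; }
};
```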
Params params
Definition: linear_scaling.h:92
LinearScaling< Scalar_, FragmentMultiplyAdd_ > Base
Linear Scaling class used.
Definition: linear_scaling_device_ptr.h:49
CUTLASS_HOST_DEVICE Params()
Definition: linear_scaling_device_ptr.h:70
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
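CUTLASS_HOST_DEVICE marks functions so they can be compiled for both host and device. The sketch shows the common preprocessor pattern behind such a macro. SKETCH_HOST_DEVICE is a hypothetical stand-in, and the exact definition in cutlass.h may differ.

```cpp
#include <cassert>

// Typical host/device qualifier pattern (not copied from cutlass.h).
// Under nvcc, __CUDACC__ is defined and the function is compiled for both
// host and device; a plain host compiler sees only `inline`.
#if defined(__CUDACC__)
#define SKETCH_HOST_DEVICE __forceinline__ __host__ __device__
#else
#define SKETCH_HOST_DEVICE inline
#endif

// Callable from host code and, when built with nvcc, from kernels.
SKETCH_HOST_DEVICE float axpby(float alpha, float x, float beta, float y) {
  return alpha * x + beta * y;
}
```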
CUTLASS_HOST_DEVICE Params(Scalar alpha, Scalar beta)
Definition: linear_scaling_device_ptr.h:74
CUTLASS_HOST_DEVICE Scalar beta() const
Gets the beta scalar.
Definition: linear_scaling_device_ptr.h:130
CUTLASS_HOST_DEVICE LinearScalingDevicePtr(Params const &_params)
Ctor.
Definition: linear_scaling_device_ptr.h:140
CUTLASS_HOST_DEVICE Scalar alpha() const
Gets the alpha scalar.
Definition: linear_scaling_device_ptr.h:124
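The constructor takes Params, and alpha() and beta() expose the scalars to the epilogue. The sketch below shows one possible arrangement using the hypothetical ScalingFunctorSketch. It assumes the values are read once at construction and cached. The source does not say when CUTLASS actually dereferences the pointers.

```cpp
#include <cassert>

// Hypothetical sketch: a scaling functor constructed from pointer-based
// Params. Assumption: alpha and beta are read once, in the constructor.
template <typename Scalar>
struct ScalingFunctorSketch {
  struct Params {
    Scalar const *alpha_ptr;
    Scalar const *beta_ptr;
  };

  explicit ScalingFunctorSketch(Params const &params)
      : alpha_(load(params.alpha_ptr)), beta_(load(params.beta_ptr)) {}

  Scalar alpha() const { return alpha_; }
  Scalar beta() const { return beta_; }

  // alpha * ab + beta * c for a single element.
  Scalar operator()(Scalar ab, Scalar c) const { return alpha_ * ab + beta_ * c; }

 private:
  // Dereference a scalar pointer after checking it is non-null.
  static Scalar load(Scalar const *p) {
    assert(p != nullptr);
    return *p;
  }

  Scalar alpha_;
  Scalar beta_;
};
```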
struct LinearScalingDevicePtr
Definition: linear_scaling_device_ptr.h:46
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:51
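A functor that forms a linear combination of fragments can be built from two primitives, a scale and a scale-and-add. This is the role the FragmentMultiplyAdd_ template parameter suggests. The sketch assumes fragments are std::arrays. MultiplyAddSketch and LinearCombinationSketch are hypothetical names. The two-step order, which scales C by beta first and then adds alpha*AB, is one reasonable choice and not necessarily CUTLASS's.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical fragment multiply-add helper.
template <typename Scalar, std::size_t N>
struct MultiplyAddSketch {
  using Fragment = std::array<Scalar, N>;

  // d = s * a
  void multiply(Scalar s, Fragment const &a, Fragment &d) const {
    for (std::size_t i = 0; i < N; ++i) d[i] = s * a[i];
  }
  // d = s * a + c
  void multiply_add(Scalar s, Fragment const &a, Fragment const &c, Fragment &d) const {
    for (std::size_t i = 0; i < N; ++i) d[i] = s * a[i] + c[i];
  }
};

// Hypothetical functor computing D = alpha * AB + beta * C on fragments.
template <typename Scalar, std::size_t N>
struct LinearCombinationSketch {
  using Fragment = std::array<Scalar, N>;
  Scalar alpha, beta;

  void evaluate(Fragment const &accum, Fragment const &source, Fragment &output) const {
    MultiplyAddSketch<Scalar, N> mad;
    Fragment tmp{};
    mad.multiply(beta, source, tmp);              // tmp = beta * C
    mad.multiply_add(alpha, accum, tmp, output);  // D = alpha * AB + tmp
  }
};
```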
Scalar_ Scalar
Definition: linear_scaling.h:53
Base::Scalar Scalar
Definition: linear_scaling_device_ptr.h:52
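Base names the LinearScaling class used by LinearScalingDevicePtr, and Base::Scalar Scalar re-exports its scalar type. The sketch below shows that typedef-forwarding pattern with the hypothetical ScalingBaseSketch and ScalingDerivedSketch types. Two points are assumptions for illustration: that the derived class inherits from Base, and that only the parameter source changes.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical base functor exposing its scalar type.
template <typename Scalar_>
struct ScalingBaseSketch {
  typedef Scalar_ Scalar;
  Scalar alpha, beta;
  Scalar apply(Scalar ab, Scalar c) const { return alpha * ab + beta * c; }
};

// Hypothetical derived functor: names its base and forwards Scalar.
template <typename Scalar_>
struct ScalingDerivedSketch : public ScalingBaseSketch<Scalar_> {
  typedef ScalingBaseSketch<Scalar_> Base;  // the class being extended
  typedef typename Base::Scalar Scalar;     // forwarded from the base

  ScalingDerivedSketch(Scalar const *alpha_ptr, Scalar const *beta_ptr) {
    // Read the scalars through pointers, then reuse the base's math.
    this->alpha = *alpha_ptr;
    this->beta = *beta_ptr;
  }
};

static_assert(std::is_same<ScalingDerivedSketch<float>::Scalar, float>::value,
              "Scalar is forwarded from the base");
```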
Basic include for CUTLASS macros.
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
Initialize the parameters.
Definition: linear_scaling_device_ptr.h:114
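The templated initialize(GemmDesc_ const &desc) fills the parameters from a GEMM descriptor. In the sketch, GemmDescSketch is a hypothetical descriptor that is assumed to carry alpha and beta by value. The real descriptor's fields are not shown in this listing.

```cpp
#include <cassert>

// Hypothetical GEMM descriptor; the field names are illustrative only.
template <typename Scalar>
struct GemmDescSketch {
  int m, n, k;
  Scalar alpha, beta;
};

template <typename Scalar>
struct ParamsFromDescSketch {
  Scalar alpha = Scalar(1);
  Scalar beta = Scalar(0);

  // Generic over the descriptor type: any type with .alpha and .beta works.
  template <typename GemmDesc_>
  int initialize(GemmDesc_ const &desc) {
    alpha = desc.alpha;
    beta = desc.beta;
    return 0;  // 0 = success (assumed convention)
  }
};
```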