// NOTE(review): this text looks like a scraped source listing rather than the
// header itself. The leading integers on many lines (44, 45, 52, 53, ...)
// appear to be original line numbers, and numbered lines are skipped (e.g.
// 56-68, 77-89, 100-105, 109-110). So tokens such as `template<>`,
// `namespace Teuchos {`, `public:`, return types and braces are missing.
// It cannot compile as-is; restore it from the original Kokkos_MV_GEMM.hpp
// before changing code.
// NOTE(review): the `#ifdef KOKKOS_HAVE_CUDA` on the next line has no visible
// matching `#endif`; it presumably guards only a CUDA-specific include on the
// skipped lines 56-68 — confirm.
//
// Specialization of BLAS (presumably Teuchos::BLAS, given the
// Teuchos_BLAS.hpp include and the unqualified ETransp) for ordinal type int
// and scalar type ::Kokkos::complex<float>. The visible routines, GEMV and
// GEMM, forward to BLAS<int, std::complex<float> >.
// NOTE(review): the pointer reinterpret_casts below assume
// ::Kokkos::complex<float> and std::complex<float> have identical size and
// layout — TODO confirm, ideally with a static_assert on sizeof/alignof.
44 #ifndef KOKKOS_MV_GEMM_HPP 45 #define KOKKOS_MV_GEMM_HPP 52 #include <Teuchos_BLAS.hpp> 53 #include <Kokkos_Blas2_MV.hpp> 55 #ifdef KOKKOS_HAVE_CUDA 69 class BLAS<int, ::Kokkos::complex<float> > {
// mag_type: presumably the magnitude (real component) type;
// val_type: the element type callers pass in;
// impl_type: the element type handed to the underlying BLAS.
71 typedef float mag_type;
72 typedef ::Kokkos::complex<float> val_type;
73 typedef std::complex<float> impl_type;
// Copy constructor. Its body is empty; the only visible members are typedefs.
76 BLAS (
const BLAS<int, val_type>&) {}
// GEMV: passes every dimension, leading-dimension and increment argument
// unchanged to BLAS<int, impl_type>::GEMV. alpha and beta are converted with
// static_cast<impl_type>; A, x and y are reinterpreted as impl_type arrays.
// NOTE(review): the return type (presumably void) and the opening brace are
// on skipped lines.
90 GEMV (ETransp trans,
const int m,
const int n,
const val_type alpha,
91 const val_type* A,
const int lda,
const val_type* x,
const int incx,
92 const val_type beta, val_type* y,
const int incy)
// The `const` on the next line ends the GEMV declaration (a const member
// function). A local BLAS<int, impl_type> object is created on every call.
const 94 BLAS<int, impl_type> blas;
95 blas.GEMV (trans, m, n, static_cast<impl_type> (alpha),
96 reinterpret_cast<const impl_type*> (A), lda,
97 reinterpret_cast<const impl_type*> (x), incx,
98 static_cast<impl_type> (beta),
99 reinterpret_cast<impl_type*> (y), incy);
// GEMM: forwards to BLAS<int, impl_type>::GEMM in the same way. alpha and
// beta are converted with static_cast<impl_type>; A, B and C are
// reinterpreted as impl_type arrays.
// NOTE(review): the `ldc` parameter used below, the end of the parameter
// list and the opening brace fall on skipped lines (109-110).
106 GEMM (ETransp transa, ETransp transb,
const int m,
const int n,
const int k,
107 const val_type alpha,
const val_type* A,
const int lda,
108 const val_type* B,
const int ldb,
const val_type beta, val_type* C,
111 BLAS<int, impl_type> blas;
112 blas.GEMM (transa, transb, m, n, k,
113 static_cast<impl_type> (alpha),
114 reinterpret_cast<const impl_type*> (A), lda,
115 reinterpret_cast<const impl_type*> (B), ldb,
116 static_cast<impl_type> (beta),
117 reinterpret_cast<impl_type*> (C), ldc);
// NOTE(review): scraped listing. The leading integers appear to be original
// line numbers, and skipped lines (e.g. 128, 132-133, 135-147, 158-163,
// 167-168) hold `template<>`, `public:`, return types, braces and the `ldc`
// parameter of GEMM. Not compilable as-is.
//
// Specialization of BLAS (presumably Teuchos::BLAS) for ordinal type int and
// scalar type ::Kokkos::complex<double>. The visible routines, GEMV and GEMM,
// forward to BLAS<int, std::complex<double> >.
// NOTE(review): the pointer reinterpret_casts below assume
// ::Kokkos::complex<double> and std::complex<double> have identical size and
// layout — TODO confirm, ideally with a static_assert on sizeof/alignof.
127 class BLAS<int, ::Kokkos::complex<double> > {
// mag_type: presumably the magnitude (real component) type;
// val_type: the element type callers pass in;
// impl_type: the element type handed to the underlying BLAS.
129 typedef double mag_type;
130 typedef ::Kokkos::complex<double> val_type;
131 typedef std::complex<double> impl_type;
// Copy constructor. Its body is empty; the only visible members are typedefs.
134 BLAS (
const BLAS<int, val_type>&) {}
// GEMV: passes every dimension, leading-dimension and increment argument
// unchanged to BLAS<int, impl_type>::GEMV. alpha and beta are converted with
// static_cast<impl_type>; A, x and y are reinterpreted as impl_type arrays.
// NOTE(review): the return type (presumably void) and the opening brace are
// on skipped lines.
148 GEMV (ETransp trans,
const int m,
const int n,
const val_type alpha,
149 const val_type* A,
const int lda,
const val_type* x,
const int incx,
150 const val_type beta, val_type* y,
const int incy)
// The `const` on the next line ends the GEMV declaration (a const member
// function). A local BLAS<int, impl_type> object is created on every call.
const 152 BLAS<int, impl_type> blas;
153 blas.GEMV (trans, m, n, static_cast<impl_type> (alpha),
154 reinterpret_cast<const impl_type*> (A), lda,
155 reinterpret_cast<const impl_type*> (x), incx,
156 static_cast<impl_type> (beta),
157 reinterpret_cast<impl_type*> (y), incy);
// GEMM: forwards to BLAS<int, impl_type>::GEMM in the same way. alpha and
// beta are converted with static_cast<impl_type>; A, B and C are
// reinterpreted as impl_type arrays.
// NOTE(review): the `ldc` parameter used below, the end of the parameter
// list and the opening brace fall on skipped lines (167-168).
164 GEMM (ETransp transa, ETransp transb,
const int m,
const int n,
const int k,
165 const val_type alpha,
const val_type* A,
const int lda,
166 const val_type* B,
const int ldb,
const val_type beta, val_type* C,
169 BLAS<int, impl_type> blas;
170 blas.GEMM (transa, transb, m, n, k,
171 static_cast<impl_type> (alpha),
172 reinterpret_cast<const impl_type*> (A), lda,
173 reinterpret_cast<const impl_type*> (B), ldb,
174 static_cast<impl_type> (beta),
175 reinterpret_cast<impl_type*> (C), ldc);
/// \brief Leading dimension of a rank-2 view, in the form BLAS expects for
///   its lda / ldb / ldc arguments.
///
/// \tparam ViewType A rank-2 view type providing dimension_0(),
///   dimension_1() and stride(size_t*) (e.g. a LayoutLeft Kokkos::View).
/// \param A The view. Only its metadata is read.
/// \return The column stride (stride[1]) when A has more than one column.
///   Otherwise returns A.dimension_0(), since with a single column the
///   inter-column stride is never used. The result is clamped to at least 1:
///   BLAS and cuBLAS require ld >= max(1, rows) and reject ld == 0 even for
///   matrices with no rows, so an empty view must not yield 0.
template<class ViewType>
size_t getStride2DView (ViewType A) {
  // View::stride writes stride(r) for each dimension r, plus one trailing
  // entry (rank + 1 entries, i.e. 3 for rank 2); 8 leaves ample room.
  size_t stride[8];
  A.stride (stride);
  const size_t ld = A.dimension_1 () > 1 ? stride[1] : A.dimension_0 ();
  return ld > 0 ? ld : static_cast<size_t> (1);
}
// NOTE(review): scraped listing. The leading integers appear to be original
// line numbers. Skipped lines (205-207, 210, 213, 215, 219-221, 227-228,
// 234, 239-244) hold:
//   - the enclosing struct header (presumably `struct DeviceGEMM`),
//   - the `alpha`/`beta` parameters,
//   - braces and the `else` that separates the two branches below.
// Not compilable as-is.
//
// Generic GEMM for any Scalar/DeviceType: C := alpha*op(A)*op(B) + beta*C
// through Teuchos::BLAS<int,Scalar>, with op() chosen by transA / transB.
// All three views are LayoutLeft; leading dimensions come from
// Impl::getStride2DView.
204 template <
typename Scalar,
typename DeviceType>
208 GEMM (
const Teuchos::ETransp transA,
209 const Teuchos::ETransp transB,
211 const View<const Scalar**, LayoutLeft, DeviceType>& A,
212 const View<const Scalar**, LayoutLeft, DeviceType>& B,
214 const View<Scalar**, LayoutLeft, DeviceType>& C)
216 const int n =
static_cast<int> (C.dimension_1 ());
217 const int lda =
static_cast<int> (Impl::getStride2DView (A));
218 Teuchos::BLAS<int,Scalar> blas;
// Fast path: if C has a single column and B is not transposed, the product
// is a matrix-vector product. So GEMV is called on the first column of B and
// C (their data pointers) with unit increments.
// NOTE(review): presumably chosen because some BLAS implementations run GEMM
// slowly with a single right-hand side — confirm.
// NOTE(review): A.dimension_0() / A.dimension_1() are passed to GEMV without
// the static_cast<int> used for every other dimension in this function.
222 if (n == 1 && transB == Teuchos::NO_TRANS) {
223 blas.GEMV (transA, A.dimension_0 (), A.dimension_1 (),
224 alpha, A.ptr_on_device (), lda,
225 B.ptr_on_device (),
static_cast<int> (1),
226 beta, C.ptr_on_device (),
static_cast<int> (1));
// General path (its `else` is on a skipped line). k is the inner dimension:
// the columns of A when A is not transposed, its rows otherwise.
229 const int m =
static_cast<int> (C.dimension_0 ());
230 const int k =
static_cast<int> (transA == Teuchos::NO_TRANS ?
231 A.dimension_1 () : A.dimension_0 ());
232 const int ldb =
static_cast<int> (Impl::getStride2DView (B));
233 const int ldc =
static_cast<int> (Impl::getStride2DView (C));
235 blas.GEMM (transA, transB, m, n, k, alpha,
236 A.ptr_on_device(), lda,
237 B.ptr_on_device(), ldb,
238 beta, C.ptr_on_device(), ldc);
// NOTE(review): scraped listing. The leading integers appear to be original
// line numbers. Skipped lines (247-249, 252, 255, 257, 259-261, 263,
// 265-266, 268-269, 273-274, 280, 286-288) hold:
//   - the struct header (presumably a DeviceGEMM<double, DeviceType>
//     specialization),
//   - the alpha/beta parameters,
//   - braces and `else` keywords,
//   - the declaration of `trans`.
// Not compilable as-is.
//
// Compiled only when HAVE_TPETRAKERNELS_MKL is defined: GEMM for double on
// any DeviceType. A single-column C with untransposed B goes through
// KokkosBlas::gemv on column 0 of B and C. Every other case calls
// Teuchos::BLAS<int,double>::GEMM.
245 #ifdef HAVE_TPETRAKERNELS_MKL 246 template <
typename DeviceType>
250 GEMM (
const Teuchos::ETransp transA,
251 const Teuchos::ETransp transB,
253 const View<const double**, LayoutLeft, DeviceType>& A,
254 const View<const double**, LayoutLeft, DeviceType>& B,
256 const View<double**, LayoutLeft, DeviceType>& C)
258 const int n =
static_cast<int> (C.dimension_1 ());
262 if (n == 1 && transB == Teuchos::NO_TRANS) {
// Map transA onto the transpose flag handed to gemv through `trans`.
// NOTE(review): `trans` is declared and assigned on skipped lines (263,
// 265-266, 268-269). It is presumably a char defaulting to 'N', set to 'T'
// for TRANS and 'C' for CONJ_TRANS — confirm.
264 if (transA == Teuchos::TRANS) {
267 else if (transA == Teuchos::CONJ_TRANS) {
// Column 0 of B and C, taken as rank-1 subviews for the matrix-vector call.
270 auto B_0 = Kokkos::subview (B, Kokkos::ALL (), 0);
271 auto C_0 = Kokkos::subview (C, Kokkos::ALL (), 0);
272 KokkosBlas::gemv (&trans, alpha, A, B_0, beta, C_0);
// General path (its `else` is on a skipped line). k is the inner dimension:
// the columns of A when A is not transposed, its rows otherwise. Leading
// dimensions come from Impl::getStride2DView.
275 const int m =
static_cast<int> (C.dimension_0 ());
276 const int k =
static_cast<int> (transA == Teuchos::NO_TRANS ? A.dimension_1 () : A.dimension_0 ());
277 const int lda =
static_cast<int> (Impl::getStride2DView (A));
278 const int ldb =
static_cast<int> (Impl::getStride2DView (B));
279 const int ldc =
static_cast<int> (Impl::getStride2DView (C));
281 Teuchos::BLAS<int,double> blas;
282 blas.GEMM (transA, transB, m, n, k, alpha,
283 A.ptr_on_device(), lda,
284 B.ptr_on_device(), ldb,
285 beta, C.ptr_on_device(), ldc);
// NOTE(review): scraped listing. The leading integers appear to be original
// line numbers. Skipped lines (290, 294-295, 298, 301, 303, 307-310) hold
// `public:`, the return type, the alpha/beta parameters and braces.
// Not compilable as-is.
//
// The next line closes the MKL-only section and opens the CUDA-only section.
// This partial specialization is the Cuda fallback for any Scalar without
// its own Cuda specialization (float and double have one further down).
289 #endif // HAVE_TPETRAKERNELS_MKL 291 #ifdef KOKKOS_HAVE_CUDA 292 template <
typename Scalar>
293 struct DeviceGEMM<Scalar,Cuda> {
296 GEMM (
const Teuchos::ETransp transA,
297 const Teuchos::ETransp transB,
299 const View<const Scalar**, LayoutLeft, Cuda>& A,
300 const View<const Scalar**,LayoutLeft,Cuda>& B,
302 const View<Scalar**,LayoutLeft,Cuda>& C)
// Always throws std::logic_error (the condition is the literal `true`). The
// message names the scalar type via Teuchos::typeName(alpha); alpha is
// declared on a skipped line (presumably 298).
// NOTE(review): the `306` between the two string literals is an extraction
// artifact; in the original header the literals presumably concatenate.
304 TEUCHOS_TEST_FOR_EXCEPTION
305 (
true, std::logic_error,
"DeviceGEMM: Kokkos::Cuda has no support " 306 "for GEMM operations over Scalar=" << Teuchos::typeName(alpha) <<
".");
// NOTE(review): scraped listing. The leading integers appear to be original
// line numbers. Skipped lines (e.g. 310, 312-313, 316, 319, 321, 333) hold
// the `template<>` / `public:` lines, the return type, the alpha/beta
// parameters and braces. Not compilable as-is.
//
// Cuda GEMM for float: computes the GEMM dimensions and leading dimensions
// from the views, then calls cublasSgemm on the views' data pointers.
// Unlike the generic version, there is no GEMV fast path for n == 1.
311 struct DeviceGEMM<float,Cuda> {
314 GEMM (
const Teuchos::ETransp transA,
315 const Teuchos::ETransp transB,
317 const View<const float**,LayoutLeft,Cuda>& A,
318 const View<const float**,LayoutLeft,Cuda>& B,
320 const View<float**,LayoutLeft,Cuda>& C)
// NOTE(review): k is initialised from a size_t dimension without the
// static_cast<int> applied to m, n and the leading dimensions.
322 const int m =
static_cast<int>(C.dimension_0()),
323 n = static_cast<int>(C.dimension_1()),
324 k = (transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0()),
325 lda = static_cast<int>(Impl::getStride2DView(A)),
326 ldb = static_cast<int>(Impl::getStride2DView(B)),
327 ldc = static_cast<int>(Impl::getStride2DView(C));
// Every transpose mode other than NO_TRANS (i.e. TRANS and CONJ_TRANS) maps
// to 'T'. For real float data the conjugate transpose equals the transpose,
// so this mapping is exact.
328 const char char_transA = (transA == Teuchos::NO_TRANS ?
'N' :
'T'),
329 char_transB = (transB == Teuchos::NO_TRANS ?
'N' :
'T');
// Legacy-style cuBLAS call: char transpose flags, scalars by value, no
// handle argument.
// NOTE(review): the legacy cuBLAS API is deprecated in newer CUDA releases;
// consider the cublas_v2 handle-based API — confirm against supported CUDA
// versions.
330 cublasSgemm (char_transA, char_transB, m, n, k, alpha,
331 A.ptr_on_device(), lda, B.ptr_on_device(),
332 ldb, beta, C.ptr_on_device(), ldc);
// The cuBLAS status is checked only when HAVE_KOKKOS_DEBUG is defined.
// In other builds a failed cublasSgemm is not reported here. The matching
// `#endif` (original line 339) lies just past this block.
334 #ifdef HAVE_KOKKOS_DEBUG 335 const cublasStatus info = cublasGetError ();
336 TEUCHOS_TEST_FOR_EXCEPTION
337 (info != CUBLAS_STATUS_SUCCESS, std::runtime_error,
338 "cublasSgemm failed with status " << info <<
"." );
// NOTE(review): scraped listing. The leading integers appear to be original
// line numbers. Skipped lines (340-343, 345, 361-367) hold the end of the
// float specialization, `template<>`, `public:` and the closing
// `#endif`/braces. Not compilable as-is.
//
// The leading `#endif` on the next line closes the HAVE_KOKKOS_DEBUG section
// of DeviceGEMM<float,Cuda>.
// Cuda GEMM for double: computes the GEMM dimensions and leading dimensions
// from the views, then calls cublasDgemm on the views' data pointers. Views
// are taken by value here, and alpha/beta as plain double. There is no GEMV
// fast path for n == 1.
339 #endif // HAVE_KOKKOS_DEBUG 344 struct DeviceGEMM<double,Cuda> {
346 static void GEMM(Teuchos::ETransp transA, Teuchos::ETransp transB,
double alpha,
347 View<const double**,LayoutLeft,Cuda> A, View<const double**,LayoutLeft,Cuda> B,
348 double beta, View<double**,LayoutLeft,Cuda> C) {
// NOTE(review): k is initialised from a size_t dimension without the
// static_cast<int> applied to m, n and the leading dimensions.
349 const int m =
static_cast<int>(C.dimension_0()),
350 n = static_cast<int>(C.dimension_1()),
351 k = (transA == Teuchos::NO_TRANS ? A.dimension_1() : A.dimension_0()),
352 lda = static_cast<int>(Impl::getStride2DView(A)),
353 ldb = static_cast<int>(Impl::getStride2DView(B)),
354 ldc = static_cast<int>(Impl::getStride2DView(C));
// Every transpose mode other than NO_TRANS (i.e. TRANS and CONJ_TRANS) maps
// to 'T'. For real double data the conjugate transpose equals the
// transpose, so this mapping is exact.
355 const char char_transA = (transA == Teuchos::NO_TRANS ?
'N' :
'T'),
356 char_transB = (transB == Teuchos::NO_TRANS ?
'N' :
'T');
// Legacy-style cuBLAS call: char transpose flags, scalars by value, no
// handle argument.
// NOTE(review): the legacy cuBLAS API is deprecated in newer CUDA releases;
// consider the cublas_v2 handle-based API — confirm.
357 cublasDgemm(char_transA, char_transB, m, n, k, alpha, A.ptr_on_device(), lda, B.ptr_on_device(), ldb, beta, C.ptr_on_device(), ldc);
// The cuBLAS status is checked only when HAVE_KOKKOS_DEBUG is defined.
// In other builds a failed cublasDgemm is not reported here.
358 #ifdef HAVE_KOKKOS_DEBUG 359 cublasStatus info = cublasGetError();
360 TEUCHOS_TEST_FOR_EXCEPTION( info != CUBLAS_STATUS_SUCCESS, std::runtime_error,
"cublasDgemm failed with status " << info <<
"." );
// NOTE(review): this line fuses the header's closing include guard
// (`#endif // KOKKOS_MV_GEMM_HPP`) with what appears to be a Doxygen
// member-index entry: a body-less declaration of an inline GEMV that the
// following summary line describes as y := y + alpha * A * x. That
// declaration does not appear to belong to this header's code — confirm, and
// drop it when restoring the file.
368 #endif // KOKKOS_MV_GEMM_HPP KOKKOS_INLINE_FUNCTION void GEMV(const CoeffType &alpha, const BlkType &A, const VecType1 &x, const VecType2 &y)
y := y + alpha * A * x (dense matrix-vector multiply)
// NOTE(review): appears to be a Doxygen member-index entry (declaration only,
// no body) for a small dense GEMM, which the following summary line
// describes as C := alpha*A*B + beta*C. It does not appear to belong to this
// header's code — confirm, and drop it when restoring the file.
KOKKOS_INLINE_FUNCTION void GEMM(const char transA[], const char transB[], const CoefficientType &alpha, const ViewType1 &A, const ViewType2 &B, const CoefficientType &beta, const ViewType3 &C)
Small dense matrix-matrix multiply: C := alpha*A*B + beta*C
Class that provides GEMM for a particular Kokkos Device.