diff --git a/.gitignore b/.gitignore index 2a3a157c9..631826c77 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,4 @@ .git/ # Build -build/ +build*/ diff --git a/include/oneapi/mkl/blas.hxx b/include/oneapi/mkl/blas.hxx index f5bee1232..6bc4f0d0d 100644 --- a/include/oneapi/mkl/blas.hxx +++ b/include/oneapi/mkl/blas.hxx @@ -261,11 +261,10 @@ static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose tran c, ldc); gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); detail::gemm(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -273,15 +272,14 @@ static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose tran } static inline void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); detail::gemm(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); gemm_postcondition(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); 
} - static inline void gemm_batch(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp b/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp index aad01181e..2eeef6fae 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp @@ -25,6 +25,7 @@ #include #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/config.hpp" #include "oneapi/mkl/detail/backend_selector.hpp" namespace oneapi { diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx index dea57bb71..62143da20 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx @@ -413,19 +413,17 @@ static inline void gemm(backend_selector selector, transpose t std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); - static inline void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, - half alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc); + cl::sycl::half alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc); static inline void gemm(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, - float alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, float beta, + float alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); - static inline void herk(backend_selector 
selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, float beta, diff --git a/include/oneapi/mkl/blas/detail/blas_loader.hpp b/include/oneapi/mkl/blas/detail/blas_loader.hpp index fe4bd892e..17abdbc1d 100644 --- a/include/oneapi/mkl/blas/detail/blas_loader.hpp +++ b/include/oneapi/mkl/blas/detail/blas_loader.hpp @@ -24,6 +24,7 @@ #include #include +#include "oneapi/mkl/detail/config.hpp" #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/detail/export.hpp" diff --git a/include/oneapi/mkl/blas/detail/blas_loader.hxx b/include/oneapi/mkl/blas/detail/blas_loader.hxx index 60183d179..2ef6f7069 100644 --- a/include/oneapi/mkl/blas/detail/blas_loader.hxx +++ b/include/oneapi/mkl/blas/detail/blas_loader.hxx @@ -398,13 +398,13 @@ ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, tran cl::sycl::buffer, 1> &c, std::int64_t ldc); ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, - half alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc); + cl::sycl::half alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc); ONEMKL_EXPORT void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, - float alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, float beta, + float alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); ONEMKL_EXPORT void syr2(oneapi::mkl::device libkey, cl::sycl::queue &queue, uplo upper_lower, diff --git a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hpp 
b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hpp index be19fe05d..6644cdc98 100644 --- a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hpp +++ b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hpp @@ -26,6 +26,7 @@ #include "oneapi/mkl/types.hpp" #include "oneapi/mkl/detail/backend_selector.hpp" +#include "oneapi/mkl/detail/config.hpp" #include "oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp" #include "oneapi/mkl/blas/detail/blas_ct_backends.hpp" diff --git a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx index bd74e9715..8a4f7ff00 100644 --- a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx @@ -743,11 +743,10 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); oneapi::mkl::blas::cublas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, a, @@ -757,8 +756,8 @@ void gemm(backend_selector selector, transpose transa, transpos } void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + 
std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -767,7 +766,6 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp index 0b3fc2fb6..3948625a1 100644 --- a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp +++ b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hpp @@ -23,6 +23,7 @@ #include #include #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/config.hpp" namespace oneapi { namespace mkl { diff --git a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx index bf2f502b0..f81325f2d 100644 --- a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx @@ -496,17 +496,15 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64 cl::sycl::buffer, 1> &a, std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); - void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc); + std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer 
&b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc); void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); - void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, diff --git a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hpp b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hpp index c35069646..de032e8f9 100644 --- a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hpp +++ b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hpp @@ -25,6 +25,7 @@ #include #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/config.hpp" #include "oneapi/mkl/detail/backend_selector.hpp" #include "oneapi/mkl/blas/detail/blas_ct_backends.hpp" diff --git a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx index 8956f15fc..33217b8ff 100644 --- a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx @@ -743,11 +743,10 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, 
+ std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); oneapi::mkl::blas::mklcpu::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, a, @@ -757,8 +756,8 @@ void gemm(backend_selector selector, transpose transa, transpos } void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -767,7 +766,6 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx index 1dfee1647..f6abd97aa 100644 --- a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx @@ -743,11 +743,10 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t 
ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); oneapi::mkl::blas::mklgpu::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, a, @@ -757,8 +756,8 @@ void gemm(backend_selector selector, transpose transa, transpos } void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -767,7 +766,6 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx index 935ecc59d..9d58cbfea 100644 --- a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx @@ -743,11 +743,10 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void gemm(backend_selector selector, transpose transa, 
transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); oneapi::mkl::blas::netlib::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, a, @@ -757,8 +756,8 @@ void gemm(backend_selector selector, transpose transa, transpos } void gemm(backend_selector selector, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { gemm_precondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); @@ -767,7 +766,6 @@ void gemm(backend_selector selector, transpose transa, transpos gemm_postcondition(selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, cl::sycl::buffer &a, std::int64_t lda) { diff --git a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx index 866aab302..cd58aa9f9 100644 --- a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx +++ b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx @@ -46,19 +46,17 @@ ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, 
oneapi::mkl::transpose transa, cl::sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, std::int64_t ldc); - ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, - std::int64_t k, half alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc); + std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc); ONEMKL_EXPORT void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, - std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t k, float alpha, cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); - ONEMKL_EXPORT void symm(cl::sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, std::int64_t m, std::int64_t n, float alpha, cl::sycl::buffer &a, std::int64_t lda, diff --git a/include/oneapi/mkl/blas/predicates.hpp b/include/oneapi/mkl/blas/predicates.hpp index 5a668be89..ebaddc02f 100644 --- a/include/oneapi/mkl/blas/predicates.hpp +++ b/include/oneapi/mkl/blas/predicates.hpp @@ -26,6 +26,7 @@ #include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/config.hpp" namespace oneapi { namespace mkl { diff --git a/include/oneapi/mkl/blas/predicates.hxx b/include/oneapi/mkl/blas/predicates.hxx index 19f7a3b76..88cbb63ab 100644 --- a/include/oneapi/mkl/blas/predicates.hxx +++ b/include/oneapi/mkl/blas/predicates.hxx @@ -1519,20 +1519,20 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo } inline void gemm_precondition(cl::sycl::queue 
&queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { #ifndef ONEMKL_DISABLE_PREDICATES /* add prechecks to queue here for input args. */ #endif } inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { #ifndef ONEMKL_DISABLE_PREDICATES /* add postchecks to queue here for input args. */ #endif @@ -1540,8 +1540,8 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { #ifndef ONEMKL_DISABLE_PREDICATES /* add prechecks to queue here for input args. 
*/ @@ -1550,14 +1550,13 @@ inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpos inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, float beta, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { #ifndef ONEMKL_DISABLE_PREDICATES /* add postchecks to queue here for input args. */ #endif } - inline void syr2_precondition(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, cl::sycl::buffer &x, std::int64_t incx, cl::sycl::buffer &y, std::int64_t incy, @@ -4749,11 +4748,10 @@ inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpo /* add postchecks to queue here for input args. */ #endif } - inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, - const half *a, std::int64_t lda, const half *b, std::int64_t ldb, - half beta, half *c, std::int64_t ldc, + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, + const cl::sycl::half *a, std::int64_t lda, const cl::sycl::half *b, std::int64_t ldb, + cl::sycl::half beta, cl::sycl::half *c, std::int64_t ldc, const cl::sycl::vector_class &dependencies) { #ifndef ONEMKL_DISABLE_PREDICATES /* add prechecks to queue here for input args. 
*/ @@ -4761,15 +4759,14 @@ inline void gemm_precondition(cl::sycl::queue &queue, transpose transa, transpos } inline void gemm_postcondition(cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, - const half *a, std::int64_t lda, const half *b, std::int64_t ldb, - half beta, half *c, std::int64_t ldc, + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, + const cl::sycl::half *a, std::int64_t lda, const cl::sycl::half *b, std::int64_t ldb, + cl::sycl::half beta, cl::sycl::half *c, std::int64_t ldc, const cl::sycl::vector_class &dependencies) { #ifndef ONEMKL_DISABLE_PREDICATES /* add postchecks to queue here for input args. */ #endif } - inline void syr2_precondition(cl::sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, const float *x, std::int64_t incx, const float *y, std::int64_t incy, float *a, std::int64_t lda, diff --git a/include/oneapi/mkl/detail/backend_selector.hpp b/include/oneapi/mkl/detail/backend_selector.hpp index b0c763ae0..9b5aef3c4 100644 --- a/include/oneapi/mkl/detail/backend_selector.hpp +++ b/include/oneapi/mkl/detail/backend_selector.hpp @@ -27,6 +27,8 @@ namespace oneapi { namespace mkl { +using namespace cl; + template class backend_selector { public: diff --git a/include/oneapi/mkl/rng/detail/curand/onemkl_rng_curand.hpp b/include/oneapi/mkl/rng/detail/curand/onemkl_rng_curand.hpp index e63c6ab56..e3d856958 100644 --- a/include/oneapi/mkl/rng/detail/curand/onemkl_rng_curand.hpp +++ b/include/oneapi/mkl/rng/detail/curand/onemkl_rng_curand.hpp @@ -50,7 +50,7 @@ * NOTICE. This Software was developed under funding from the U.S. Department * of Energy and the U.S. Government consequently retains certain rights. As * such, the U.S. 
Government has been granted for itself and others acting on - * its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the + * its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the * Software to reproduce, distribute copies to the public, prepare derivative * works, and perform publicly and display publicly, and to permit others to do * so. diff --git a/src/blas/backends/cublas/cublas_level3.cpp b/src/blas/backends/cublas/cublas_level3.cpp index 671a15ea7..341f35dee 100644 --- a/src/blas/backends/cublas/cublas_level3.cpp +++ b/src/blas/backends/cublas/cublas_level3.cpp @@ -99,6 +99,7 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, }); }); } +#ifdef ENABLE_HALF_ROUTINES #define GEMM_EX_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, \ CUDADATATYPE_C) \ @@ -109,9 +110,18 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, gemm(CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, CUDADATATYPE_C, queue, transa, \ transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } - -GEMM_EX_LAUNCHER(half, half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) -GEMM_EX_LAUNCHER(half, half, half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) +#else +#define GEMM_EX_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, \ + CUDADATATYPE_C) \ + void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_C alpha, cl::sycl::buffer &a, int64_t lda, \ + cl::sycl::buffer &b, int64_t ldb, TYPE_C beta, \ + cl::sycl::buffer &c, int64_t ldc) { \ + throw unimplemented("blas", "gemm", "half is disabled"); \ + } +#endif +GEMM_EX_LAUNCHER(cl::sycl::half, cl::sycl::half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) +GEMM_EX_LAUNCHER(cl::sycl::half, cl::sycl::half, cl::sycl::half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) #undef GEMM_EX_LAUNCHER @@ 
-465,14 +475,12 @@ GEMM_LAUNCHER_USM(std::complex, cublasCgemm) GEMM_LAUNCHER_USM(std::complex, cublasZgemm) #undef GEMM_LAUNCHER_USM - cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, half alpha, const half *a, std::int64_t lda, - const half *b, std::int64_t ldb, half beta, half *c, std::int64_t ldc, + std::int64_t n, std::int64_t k, cl::sycl::half alpha, const cl::sycl::half *a, std::int64_t lda, + const cl::sycl::half *b, std::int64_t ldb, cl::sycl::half beta, cl::sycl::half *c, std::int64_t ldc, const cl::sycl::vector_class &dependencies) { throw unimplemented("blas", "gemm", "for column_major layout"); } - template inline cl::sycl::event symm(Func func, cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, const T *b, @@ -860,10 +868,8 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, gemm(CUBLAS_ROUTINE, CUDADATATYPE_A, CUDADATATYPE_B, CUDADATATYPE_C, queue, transa, \ transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } - -GEMM_EX_LAUNCHER(half, half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) -GEMM_EX_LAUNCHER(half, half, half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) - +GEMM_EX_LAUNCHER(cl::sycl::half, cl::sycl::half, float, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_32F) +GEMM_EX_LAUNCHER(cl::sycl::half, cl::sycl::half, cl::sycl::half, cublasGemmEx, CUDA_R_16F, CUDA_R_16F, CUDA_R_16F) #undef GEMM_EX_LAUNCHER template @@ -1065,14 +1071,12 @@ GEMM_LAUNCHER_USM(std::complex, cublasCgemm) GEMM_LAUNCHER_USM(std::complex, cublasZgemm) #undef GEMM_LAUNCHER_USM - cl::sycl::event gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, half alpha, const half *a, std::int64_t lda, - const half *b, std::int64_t ldb, half beta, half *c, std::int64_t ldc, + std::int64_t n, std::int64_t k, cl::sycl::half alpha, 
const cl::sycl::half *a, std::int64_t lda, + const cl::sycl::half *b, std::int64_t ldb, cl::sycl::half beta, cl::sycl::half *c, std::int64_t ldc, const cl::sycl::vector_class &dependencies) { throw unimplemented("blas", "gemm", "for row_major layout"); } - template inline cl::sycl::event symm(Func func, cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, const T *b, diff --git a/src/blas/backends/mklcpu/mklcpu_level3.cpp b/src/blas/backends/mklcpu/mklcpu_level3.cpp index 694a3eb60..128465e98 100644 --- a/src/blas/backends/mklcpu/mklcpu_level3.cpp +++ b/src/blas/backends/mklcpu/mklcpu_level3.cpp @@ -19,6 +19,7 @@ #include +#include "oneapi/mkl/exceptions.hpp" #include "mklcpu_common.hpp" #include "fp16.hpp" #include "oneapi/mkl/blas/detail/mklcpu/onemkl_blas_mklcpu.hpp" diff --git a/src/blas/backends/mklcpu/mklcpu_level3.cxx b/src/blas/backends/mklcpu/mklcpu_level3.cxx index 7917be1ca..46a966cbc 100644 --- a/src/blas/backends/mklcpu/mklcpu_level3.cxx +++ b/src/blas/backends/mklcpu/mklcpu_level3.cxx @@ -98,9 +98,10 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, } void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, half alpha, cl::sycl::buffer &a, int64_t lda, - cl::sycl::buffer &b, int64_t ldb, half beta, cl::sycl::buffer &c, + int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, cl::sycl::half beta, cl::sycl::buffer &c, int64_t ldc) { +#ifdef ENABLE_HALF_ROUTINES auto a_fp16 = a.reinterpret(a.get_range()); auto b_fp16 = b.reinterpret(b.get_range()); auto c_fp16 = c.reinterpret(c.get_range()); @@ -131,7 +132,7 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, copy_mat(accessor_c, MKLMAJOR, transpose::N, m, n, ldc, 0.0f, f32_c); ::cblas_sgemm(CBLASMAJOR, transa_, transb_, m, n, k, f32_alpha, f32_a, lda, f32_b, ldb, f32_beta, f32_c, 
ldc); - // copy C back to half + // copy C back to cl::sycl::half fp16 co = 0.0f; copy_mat(f32_c, MKLMAJOR, m, n, ldc, offset::F, &co, accessor_c); ::free(f32_a); @@ -139,12 +140,15 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, ::free(f32_c); }); }); +#else + throw oneapi::mkl::unimplemented("blas", "gemm", "half is disabled"); +#endif } - void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, float alpha, cl::sycl::buffer &a, int64_t lda, - cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, + int64_t k, float alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, int64_t ldc) { +#ifdef ENABLE_HALF_ROUTINES auto a_fp16 = a.reinterpret(a.get_range()); auto b_fp16 = b.reinterpret(b.get_range()); queue.submit([&](cl::sycl::handler &cgh) { @@ -172,8 +176,10 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, ::free(f32_b); }); }); +#else + throw oneapi::mkl::unimplemented("blas", "cl::sycl::half", "when using hipSYCL"); +#endif } - void hemm(cl::sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, int64_t lda, cl::sycl::buffer, 1> &b, int64_t ldb, std::complex beta, diff --git a/src/blas/backends/mklgpu/mklgpu_common.hpp b/src/blas/backends/mklgpu/mklgpu_common.hpp index 86c4512e3..f23e432a1 100644 --- a/src/blas/backends/mklgpu/mklgpu_common.hpp +++ b/src/blas/backends/mklgpu/mklgpu_common.hpp @@ -779,15 +779,14 @@ void cgemmt(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_UPLO upper_lower, MKL cl::sycl::buffer, 1> &a, int64_t lda, cl::sycl::buffer, 1> &b, int64_t ldb, std::complex beta, cl::sycl::buffer, 1> &c, int64_t ldc); - void hgemm(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_TRANSPOSE transa, MKL_TRANSPOSE transb, - int64_t m, int64_t n, int64_t k, half alpha, cl::sycl::buffer &a, int64_t lda, - 
cl::sycl::buffer &b, int64_t ldb, half beta, cl::sycl::buffer &c, + int64_t m, int64_t n, int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, cl::sycl::half beta, cl::sycl::buffer &c, int64_t ldc); void gemm_f16f16f32(cl::sycl::queue &queue, MKL_LAYOUT layout, MKL_TRANSPOSE transa, MKL_TRANSPOSE transb, int64_t m, int64_t n, int64_t k, float alpha, - cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, + cl::sycl::buffer &a, int64_t lda, cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, int64_t ldc); cl::sycl::event gemm_s8u8s32_sycl(cl::sycl::queue *queue, MKL_LAYOUT layout, MKL_TRANSPOSE transa, diff --git a/src/blas/backends/mklgpu/mklgpu_level3.cxx b/src/blas/backends/mklgpu/mklgpu_level3.cxx index 594b33b5d..6d7a04b90 100644 --- a/src/blas/backends/mklgpu/mklgpu_level3.cxx +++ b/src/blas/backends/mklgpu/mklgpu_level3.cxx @@ -56,25 +56,23 @@ void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::tr ::mkl::cblas_convert(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { ::oneapi::mkl::gpu::hgemm(queue, MAJOR, ::mkl::cblas_convert(transa), ::mkl::cblas_convert(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void gemm(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + 
std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { ::oneapi::mkl::gpu::gemm_f16f16f32(queue, MAJOR, ::mkl::cblas_convert(transa), ::mkl::cblas_convert(transb), m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void symm(cl::sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, std::int64_t m, std::int64_t n, float alpha, cl::sycl::buffer &a, std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, diff --git a/src/blas/backends/netlib/netlib_level3.cxx b/src/blas/backends/netlib/netlib_level3.cxx index 54ce343da..693905fc3 100644 --- a/src/blas/backends/netlib/netlib_level3.cxx +++ b/src/blas/backends/netlib/netlib_level3.cxx @@ -92,9 +92,9 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, } void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "gemm", "for column_major layout"); #endif @@ -104,8 +104,8 @@ void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, std::int64 } void gemm(cl::sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, float alpha, cl::sycl::buffer &a, int64_t lda, - cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, + int64_t k, float alpha, cl::sycl::buffer &a, int64_t lda, + cl::sycl::buffer &b, int64_t ldb, float beta, cl::sycl::buffer &c, int64_t ldc) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "gemm", 
"for column_major layout"); diff --git a/src/blas/blas_loader.cpp b/src/blas/blas_loader.cpp index 206cb5846..b864bafe9 100644 --- a/src/blas/blas_loader.cpp +++ b/src/blas/blas_loader.cpp @@ -877,23 +877,21 @@ void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, function_tables[libkey].column_major_zgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { function_tables[libkey].column_major_hgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { function_tables[libkey].column_major_gemm_f16f16f32_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void hemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, @@ -3495,23 +3493,21 @@ void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, function_tables[libkey].row_major_zgemm_sycl(queue, transa, transb, m, n, k, 
alpha, a, lda, b, ldb, beta, c, ldc); } - void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, half beta, - cl::sycl::buffer &c, std::int64_t ldc) { + std::int64_t m, std::int64_t n, std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, cl::sycl::half beta, + cl::sycl::buffer &c, std::int64_t ldc) { function_tables[libkey].row_major_hgemm_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void gemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, transpose transa, transpose transb, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc) { function_tables[libkey].row_major_gemm_f16f16f32_sycl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } - void hemm(oneapi::mkl::device libkey, cl::sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, std::int64_t n, std::complex alpha, cl::sycl::buffer, 1> &a, std::int64_t lda, diff --git a/src/blas/function_table.hpp b/src/blas/function_table.hpp index 48c04f29c..6f05e57ab 100644 --- a/src/blas/function_table.hpp +++ b/src/blas/function_table.hpp @@ -24,6 +24,7 @@ #include #include #include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/config.hpp" typedef struct { int version; @@ -569,15 +570,16 @@ typedef struct { cl::sycl::buffer, 1> &c, std::int64_t ldc); void (*column_major_hgemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, - std::int64_t 
k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, - std::int64_t ldb, half beta, cl::sycl::buffer &c, + std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, + std::int64_t ldb, cl::sycl::half beta, cl::sycl::buffer &c, std::int64_t ldc); + void (*column_major_gemm_f16f16f32_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); void (*column_major_chemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::side left_right, @@ -2088,14 +2090,14 @@ typedef struct { cl::sycl::buffer, 1> &c, std::int64_t ldc); void (*row_major_hgemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, - std::int64_t k, half alpha, cl::sycl::buffer &a, - std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, - half beta, cl::sycl::buffer &c, std::int64_t ldc); + std::int64_t k, cl::sycl::half alpha, cl::sycl::buffer &a, + std::int64_t lda, cl::sycl::buffer &b, std::int64_t ldb, + cl::sycl::half beta, cl::sycl::buffer &c, std::int64_t ldc); void (*row_major_gemm_f16f16f32_sycl)(cl::sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, - cl::sycl::buffer &a, std::int64_t lda, - cl::sycl::buffer &b, std::int64_t ldb, + cl::sycl::buffer &a, std::int64_t lda, + cl::sycl::buffer &b, std::int64_t ldb, float beta, cl::sycl::buffer &c, std::int64_t ldc); void (*row_major_chemm_sycl)(cl::sycl::queue &queue, oneapi::mkl::side left_right, diff --git a/src/config.hpp.in b/src/config.hpp.in index 957820cc9..e7784c489 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in 
@@ -25,6 +25,7 @@ #cmakedefine ENABLE_MKLCPU_BACKEND #cmakedefine ENABLE_MKLGPU_BACKEND #cmakedefine ENABLE_NETLIB_BACKEND +#cmakedefine ENABLE_HALF_ROUTINES #cmakedefine BUILD_SHARED_LIBS #endif diff --git a/src/rng/backends/mklcpu/mrg32k3a.cpp b/src/rng/backends/mklcpu/mrg32k3a.cpp index 6d38d78d6..163767121 100755 --- a/src/rng/backends/mklcpu/mrg32k3a.cpp +++ b/src/rng/backends/mklcpu/mrg32k3a.cpp @@ -33,6 +33,8 @@ namespace mkl { namespace rng { namespace mklcpu { +using namespace cl; + class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { public: mrg32k3a_impl(cl::sycl::queue queue, std::uint32_t seed) diff --git a/src/rng/backends/mklcpu/philox4x32x10.cpp b/src/rng/backends/mklcpu/philox4x32x10.cpp index f204912f4..b65d375d6 100644 --- a/src/rng/backends/mklcpu/philox4x32x10.cpp +++ b/src/rng/backends/mklcpu/philox4x32x10.cpp @@ -33,6 +33,8 @@ namespace mkl { namespace rng { namespace mklcpu { +using namespace cl; + class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { public: philox4x32x10_impl(cl::sycl::queue queue, std::uint64_t seed)