From 76a181621a400528238d0704e4cc1879bf653c5f Mon Sep 17 00:00:00 2001 From: Zheming Jin Date: Tue, 27 Jun 2023 08:40:34 -0700 Subject: [PATCH 01/12] [MKL][ROCBLAS] add the support of omatcopy in row-major layout --- .../backends/rocblas/rocblas_extensions.cpp | 140 +++++++++++++++++- 1 file changed, 132 insertions(+), 8 deletions(-) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index 315f9ce30..d4efb9013 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -363,24 +363,82 @@ void gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose tra void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); + overflow_check(m, n, lda, ldb); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.get_access(cgh); + auto b_acc = b.get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + const float beta = 0.f; + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(rocblas_sgeam, err, handle, get_rocblas_operation(trans), + rocblas_operation_none, m, n, &alpha, + a_, lda, &beta, a_, ldb, b_, ldb); + }); + }); } void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); + overflow_check(m, n, lda, ldb); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.get_access(cgh); + auto b_acc = b.get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + const 
double beta = 0.0; + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(rocblas_dgeam, err, handle, get_rocblas_operation(trans), + rocblas_operation_none, m, n, &alpha, + a_, lda, &beta, a_, ldb, b_, ldb); + }); + }); } void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, sycl::buffer, 1> &a, int64_t lda, sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); + using rocDataType = typename RocEquivalentType>::Type; + overflow_check(m, n, lda, ldb); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.get_access(cgh); + auto b_acc = b.get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + const rocDataType beta {0.f, 0.f}; + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(rocblas_cgeam, err, handle, get_rocblas_operation(trans), + rocblas_operation_none, m, n, (rocDataType *)&alpha, + a_, lda, &beta, a_, ldb, b_, ldb); + }); + }); } void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, sycl::buffer, 1> &a, int64_t lda, sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); + using rocDataType = typename RocEquivalentType>::Type; + overflow_check(m, n, lda, ldb); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.get_access(cgh); + auto b_acc = b.get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + const rocDataType beta {0.0, 0.0}; + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(rocblas_zgeam, err, handle, get_rocblas_operation(trans), + rocblas_operation_none, m, n, (rocDataType *)&alpha, + a_, lda, &beta, a_, ldb, b_, ldb); + }); + }); } void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t 
n, float alpha, @@ -496,27 +554,93 @@ sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transp sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, float *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); + overflow_check(m, n, lda, ldb); + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + const float beta = 0.f; + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(rocblas_sgeam, err, handle, get_rocblas_operation(trans), + rocblas_operation_none, m, n, &alpha, + a, lda, &beta, a, ldb, b, ldb); + }); + }); + return done; } sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, const double *a, int64_t lda, double *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); + overflow_check(m, n, lda, ldb); + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + const double beta = 0.0; + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(rocblas_dgeam, err, handle, get_rocblas_operation(trans), + rocblas_operation_none, m, n, &alpha, + a, lda, &beta, a, ldb, b, ldb); + }); + }); + return done; } sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, const std::complex *a, int64_t lda, std::complex *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for 
row_major layout"); + using rocDataType = typename RocEquivalentType>::Type; + overflow_check(m, n, lda, ldb); + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + const rocDataType beta {0.f, 0.f}; + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(rocblas_cgeam, err, handle, get_rocblas_operation(trans), + rocblas_operation_none, m, n, (rocDataType *)&alpha, + a_, lda, &beta, a_, ldb, b_, ldb); + }); + }); + return done; } sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, const std::complex *a, int64_t lda, std::complex *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); + using rocDataType = typename RocEquivalentType>::Type; + overflow_check(m, n, lda, ldb); + auto done = queue.submit([&](sycl::handler &cgh) { + int64_t num_events = dependencies.size(); + for (int64_t i = 0; i < num_events; i++) { + cgh.depends_on(dependencies[i]); + } + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + const rocDataType beta {0.0, 0.0}; + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(rocblas_zgeam, err, handle, get_rocblas_operation(trans), + rocblas_operation_none, m, n, (rocDataType *)&alpha, + a_, lda, &beta, a_, ldb, b_, ldb); + }); + }); + return done; } sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, From 9fb01eebce2f6e4ba5d8f15de64429f2c205e08d Mon Sep 17 00:00:00 2001 From: Zheming Jin Date: Tue, 22 Aug 2023 14:44:14 -0700 Subject: [PATCH 02/12] Update the omatcopy and omatadd 
functions in row and column majors --- .../backends/rocblas/rocblas_extensions.cpp | 523 +++++++++--------- src/blas/backends/rocblas/rocblas_helper.hpp | 11 + 2 files changed, 258 insertions(+), 276 deletions(-) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index d4efb9013..c485216a6 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -88,27 +88,41 @@ void gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose tra throw unimplemented("blas", "gemmt", "for column_major layout"); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); +template +void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? m : n); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? 
n : m); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), logical_m, logical_n, + (rocDataType *)&alpha, a_, lda, nullptr, nullptr, lda, b_, ldb); + }); + }); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} +#define OMATCOPY_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { \ + omatcopy(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb); \ + } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} +OMATCOPY_LAUNCHER(float, rocblas_sgeam) +OMATCOPY_LAUNCHER(double, rocblas_dgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_zgeam) -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} +#undef OMATCOPY_LAUNCHER void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb) { @@ -130,31 +144,44 @@ void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co throw unimplemented("blas", "imatcopy", "for column_major layout"); } -void omatadd(sycl::queue &queue, transpose 
transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, float beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); +template +void omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, + T beta, sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, a_, + lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + }); + }); } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, double beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +#define OMATADD_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + TYPE alpha, sycl::buffer &a, int64_t lda, TYPE beta, \ + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ + omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, beta, \ + b, ldb, c, ldc); \ + } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, 
sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +OMATADD_LAUNCHER(float, rocblas_sgeam) +OMATADD_LAUNCHER(double, rocblas_dgeam) +OMATADD_LAUNCHER(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER(std::complex, rocblas_zgeam) + +#undef OMATADD_LAUNCHER -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} // USM APIs @@ -220,31 +247,43 @@ sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transp throw unimplemented("blas", "gemmt", "for column_major layout"); } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, float *b, int64_t ldb, +template +sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, + int64_t m, int64_t n, T alpha, const T *a, int64_t lda, T *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb); + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? m : n); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? 
n : m); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), logical_m, logical_n, + (rocDataType *)&alpha, a_, lda, nullptr, nullptr, lda, b_, ldb); + }); + }); + return done; } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, double *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} +#define OMATCOPY_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, TYPE *b, int64_t ldb, \ + const std::vector &dependencies) { \ + return omatcopy(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, \ + ldb, dependencies); \ + } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} +OMATCOPY_LAUNCHER_USM(float, rocblas_sgeam) +OMATCOPY_LAUNCHER_USM(double, rocblas_dgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_zgeam) -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} +#undef OMATCOPY_LAUNCHER_USM sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, @@ 
-270,35 +309,44 @@ sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, throw unimplemented("blas", "imatcopy", "for column_major layout"); } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, const float *a, int64_t lda, float beta, const float *b, - int64_t ldb, float *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); +template +inline sycl::event omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, + T beta, const T *b, int64_t ldb, T *c, int64_t ldc, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, a_, + lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + }); + }); + return done; } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, const double *a, int64_t lda, double beta, const double *b, - int64_t ldb, double *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +#define OMATADD_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, TYPE alpha, const TYPE *a, int64_t lda, TYPE beta, \ + const TYPE *b, int64_t ldb, TYPE *c, int64_t ldc, \ 
+ const std::vector &dependencies) { \ + return omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, \ + lda, beta, b, ldb, c, ldc, dependencies); \ + } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +OMATADD_LAUNCHER_USM(float, rocblas_sgeam) +OMATADD_LAUNCHER_USM(double, rocblas_dgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_zgeam) -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +#undef OMATADD_LAUNCHER_USM } // namespace column_major namespace row_major { @@ -361,85 +409,41 @@ void gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose tra throw unimplemented("blas", "gemmt", "for row_major layout"); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { +template +void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb) { + using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb); queue.submit([&](sycl::handler &cgh) { - auto a_acc = a.get_access(cgh); - auto b_acc = b.get_access(cgh); + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + const int64_t logical_m = (trans == 
oneapi::mkl::transpose::nontrans ? n : m); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? m : n); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - const float beta = 0.f; + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(rocblas_sgeam, err, handle, get_rocblas_operation(trans), - rocblas_operation_none, m, n, &alpha, - a_, lda, &beta, a_, ldb, b_, ldb); + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), logical_m, logical_n, + (rocDataType *)&alpha, a_, lda, nullptr, nullptr, lda, b_, ldb); }); }); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - overflow_check(m, n, lda, ldb); - queue.submit([&](sycl::handler &cgh) { - auto a_acc = a.get_access(cgh); - auto b_acc = b.get_access(cgh); - onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - const double beta = 0.0; - rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(rocblas_dgeam, err, handle, get_rocblas_operation(trans), - rocblas_operation_none, m, n, &alpha, - a_, lda, &beta, a_, ldb, b_, ldb); - }); - }); -} +#define OMATCOPY_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { \ + omatcopy(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb); \ + } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - using rocDataType = typename 
RocEquivalentType>::Type; - overflow_check(m, n, lda, ldb); - queue.submit([&](sycl::handler &cgh) { - auto a_acc = a.get_access(cgh); - auto b_acc = b.get_access(cgh); - onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - const rocDataType beta {0.f, 0.f}; - rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(rocblas_cgeam, err, handle, get_rocblas_operation(trans), - rocblas_operation_none, m, n, (rocDataType *)&alpha, - a_, lda, &beta, a_, ldb, b_, ldb); - }); - }); -} +OMATCOPY_LAUNCHER(float, rocblas_sgeam) +OMATCOPY_LAUNCHER(double, rocblas_dgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_zgeam) -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - using rocDataType = typename RocEquivalentType>::Type; - overflow_check(m, n, lda, ldb); - queue.submit([&](sycl::handler &cgh) { - auto a_acc = a.get_access(cgh); - auto b_acc = b.get_access(cgh); - onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - const rocDataType beta {0.0, 0.0}; - rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(rocblas_zgeam, err, handle, get_rocblas_operation(trans), - rocblas_operation_none, m, n, (rocDataType *)&alpha, - a_, lda, &beta, a_, ldb, b_, ldb); - }); - }); -} +#undef OMATCOPY_LAUNCHER void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb) { @@ -461,31 +465,43 @@ void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co throw unimplemented("blas", "imatcopy", "for row_major layout"); } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float 
alpha, sycl::buffer &a, int64_t lda, float beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); +template +void omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, + T beta, sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), n, m, (rocDataType *)&alpha, a_, + lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + }); + }); } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, double beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +#define OMATADD_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + TYPE alpha, sycl::buffer &a, int64_t lda, TYPE beta, \ + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ + omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, beta, \ + b, ldb, c, ldc); \ + } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 
1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +OMATADD_LAUNCHER(float, rocblas_sgeam) +OMATADD_LAUNCHER(double, rocblas_dgeam) +OMATADD_LAUNCHER(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER(std::complex, rocblas_zgeam) -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +#undef OMATADD_LAUNCHER // USM APIs @@ -551,97 +567,43 @@ sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transp throw unimplemented("blas", "gemmt", "for row_major layout"); } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, float *b, int64_t ldb, - const std::vector &dependencies) { - overflow_check(m, n, lda, ldb); - auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } - onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - const float beta = 0.f; - rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(rocblas_sgeam, err, handle, get_rocblas_operation(trans), - rocblas_operation_none, m, n, &alpha, - a, lda, &beta, a, ldb, b, ldb); - }); - }); - return done; -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, double *b, int64_t ldb, +template +sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, + int64_t m, int64_t n, T alpha, const T *a, int64_t lda, T *b, int64_t ldb, const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; 
overflow_check(m, n, lda, ldb); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } - onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - const double beta = 0.0; - rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(rocblas_dgeam, err, handle, get_rocblas_operation(trans), - rocblas_operation_none, m, n, &alpha, - a, lda, &beta, a, ldb, b, ldb); - }); - }); - return done; -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - using rocDataType = typename RocEquivalentType>::Type; - overflow_check(m, n, lda, ldb); - auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? n : m); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? 
m : n); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); - const rocDataType beta {0.f, 0.f}; rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(rocblas_cgeam, err, handle, get_rocblas_operation(trans), - rocblas_operation_none, m, n, (rocDataType *)&alpha, - a_, lda, &beta, a_, ldb, b_, ldb); - }); + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), logical_m, logical_n, + (rocDataType *)&alpha, a_, lda, nullptr, nullptr, ldb, b_, ldb); + }); }); return done; } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - using rocDataType = typename RocEquivalentType>::Type; - overflow_check(m, n, lda, ldb); - auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } - onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { - auto handle = sc.get_handle(queue); - const rocDataType beta {0.0, 0.0}; - auto a_ = reinterpret_cast(a); - auto b_ = reinterpret_cast(b); - rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(rocblas_zgeam, err, handle, get_rocblas_operation(trans), - rocblas_operation_none, m, n, (rocDataType *)&alpha, - a_, lda, &beta, a_, ldb, b_, ldb); - }); - }); - return done; -} +#define OMATCOPY_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, TYPE *b, int64_t ldb, \ + const std::vector &dependencies) { \ + return omatcopy(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, \ + ldb, dependencies); \ + } + +OMATCOPY_LAUNCHER_USM(float, rocblas_sgeam) 
+OMATCOPY_LAUNCHER_USM(double, rocblas_dgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_zgeam) + +#undef OMATCOPY_LAUNCHER_USM sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, @@ -667,35 +629,44 @@ sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, throw unimplemented("blas", "imatcopy", "for row_major layout"); } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, const float *a, int64_t lda, float beta, const float *b, - int64_t ldb, float *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); +template +inline sycl::event omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, + T beta, const T *b, int64_t ldb, T *c, int64_t ldc, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), n, m, (rocDataType *)&alpha, a_, + lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + }); + }); + return done; } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, const double *a, int64_t lda, double beta, const double *b, - int64_t ldb, double *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", 
"omatadd", "for row_major layout"); -} +#define OMATADD_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, TYPE alpha, const TYPE *a, int64_t lda, TYPE beta, \ + const TYPE *b, int64_t ldb, TYPE *c, int64_t ldc, \ + const std::vector &dependencies) { \ + return omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, \ + lda, beta, b, ldb, c, ldc, dependencies); \ + } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +OMATADD_LAUNCHER_USM(float, rocblas_sgeam) +OMATADD_LAUNCHER_USM(double, rocblas_dgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_zgeam) -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +#undef OMATADD_LAUNCHER_USM } // namespace row_major } // namespace rocblas diff --git a/src/blas/backends/rocblas/rocblas_helper.hpp b/src/blas/backends/rocblas/rocblas_helper.hpp index 75490e333..38fd20088 100644 --- a/src/blas/backends/rocblas/rocblas_helper.hpp +++ b/src/blas/backends/rocblas/rocblas_helper.hpp @@ -172,6 +172,17 @@ class hip_error : virtual public std::runtime_error { hipError_t hip_err; \ HIP_ERROR_FUNC(hipStreamSynchronize, hip_err, currentStreamId); +#define ROCBLAS_ERROR_FUNC_T_SYNC(name, func, err, handle, ...) 
\ + err = func(handle, __VA_ARGS__); \ + if (err != rocblas_status_success) { \ + throw rocblas_error(std::string(name) + std::string(" : "), err); \ + } \ + hipStream_t currentStreamId; \ + ROCBLAS_ERROR_FUNC(rocblas_get_stream, err, handle, &currentStreamId); \ + hipError_t hip_err; \ + HIP_ERROR_FUNC(hipStreamSynchronize, hip_err, currentStreamId); + + inline rocblas_operation get_rocblas_operation(oneapi::mkl::transpose trn) { switch (trn) { case oneapi::mkl::transpose::nontrans: return rocblas_operation_none; From ec5940b03550cd90c3427110d7918c6fa5a3ebe1 Mon Sep 17 00:00:00 2001 From: jinz2014 <7799920+jinz2014@users.noreply.github.com> Date: Sun, 24 Sep 2023 10:44:08 -0400 Subject: [PATCH 03/12] Update src/blas/backends/rocblas/rocblas_extensions.cpp Co-authored-by: Muhammad Tanvir <84532306+muhammad-tanvir-1211@users.noreply.github.com> --- src/blas/backends/rocblas/rocblas_extensions.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index c485216a6..b52c77899 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -106,7 +106,7 @@ void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose tr rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), get_rocblas_operation(trans), logical_m, logical_n, - (rocDataType *)&alpha, a_, lda, nullptr, nullptr, lda, b_, ldb); + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, lda, b_, ldb); }); }); } From 98e0e89b27d8de355f46e6faea2baa996856ca69 Mon Sep 17 00:00:00 2001 From: jinz2014 <7799920+jinz2014@users.noreply.github.com> Date: Sun, 24 Sep 2023 10:44:22 -0400 Subject: [PATCH 04/12] Update src/blas/backends/rocblas/rocblas_extensions.cpp Co-authored-by: Muhammad Tanvir <84532306+muhammad-tanvir-1211@users.noreply.github.com> --- 
src/blas/backends/rocblas/rocblas_extensions.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index b52c77899..3b659f343 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -253,6 +253,7 @@ sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, trans const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb); + const T beta = 0; auto done = queue.submit([&](sycl::handler &cgh) { cgh.depends_on(dependencies); const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? m : n); From c3d663002cdfef06a1791343ca11579d072fc00c Mon Sep 17 00:00:00 2001 From: jinz2014 <7799920+jinz2014@users.noreply.github.com> Date: Sun, 24 Sep 2023 10:44:37 -0400 Subject: [PATCH 05/12] Update src/blas/backends/rocblas/rocblas_extensions.cpp Co-authored-by: Muhammad Tanvir <84532306+muhammad-tanvir-1211@users.noreply.github.com> --- src/blas/backends/rocblas/rocblas_extensions.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index 3b659f343..0717bcf2e 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -265,7 +265,7 @@ sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, trans rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), get_rocblas_operation(trans), logical_m, logical_n, - (rocDataType *)&alpha, a_, lda, nullptr, nullptr, lda, b_, ldb); + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, lda, b_, ldb); }); }); return done; From 11865e13e5c4c6cbaa8f90682c507469e04e7601 Mon Sep 17 00:00:00 2001 From: jinz2014 
<7799920+jinz2014@users.noreply.github.com> Date: Sun, 24 Sep 2023 10:44:47 -0400 Subject: [PATCH 06/12] Update src/blas/backends/rocblas/rocblas_extensions.cpp Co-authored-by: Muhammad Tanvir <84532306+muhammad-tanvir-1211@users.noreply.github.com> --- src/blas/backends/rocblas/rocblas_extensions.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index 0717bcf2e..cc9853682 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -416,6 +416,7 @@ void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose tr int64_t ldb) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb); + const T beta = 0; queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); From bcbad1eb333cfb640643f5e17353a81dd5ad4b94 Mon Sep 17 00:00:00 2001 From: jinz2014 <7799920+jinz2014@users.noreply.github.com> Date: Sun, 24 Sep 2023 10:44:53 -0400 Subject: [PATCH 07/12] Update src/blas/backends/rocblas/rocblas_extensions.cpp Co-authored-by: Muhammad Tanvir <84532306+muhammad-tanvir-1211@users.noreply.github.com> --- src/blas/backends/rocblas/rocblas_extensions.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index cc9853682..ff4d10d98 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -429,7 +429,7 @@ void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose tr rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), get_rocblas_operation(trans), logical_m, logical_n, - (rocDataType *)&alpha, a_, lda, nullptr, nullptr, lda, b_, ldb); + (rocDataType *)&alpha, a_, lda, 
(rocDataType *)&beta, nullptr, lda, b_, ldb); }); }); } From 1e2d1069d5e3f4f9e4f10708713ac4acc9d62a8c Mon Sep 17 00:00:00 2001 From: jinz2014 <7799920+jinz2014@users.noreply.github.com> Date: Sun, 24 Sep 2023 10:44:58 -0400 Subject: [PATCH 08/12] Update src/blas/backends/rocblas/rocblas_extensions.cpp Co-authored-by: Muhammad Tanvir <84532306+muhammad-tanvir-1211@users.noreply.github.com> --- src/blas/backends/rocblas/rocblas_extensions.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index ff4d10d98..888cdd583 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -575,6 +575,7 @@ sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, trans const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb); + const T beta = 0; auto done = queue.submit([&](sycl::handler &cgh) { cgh.depends_on(dependencies); const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? 
n : m); From b81bbc79bc0a7ea5c8a02d34f0fc63e44eeedd7e Mon Sep 17 00:00:00 2001 From: jinz2014 <7799920+jinz2014@users.noreply.github.com> Date: Sun, 24 Sep 2023 10:45:06 -0400 Subject: [PATCH 09/12] Update src/blas/backends/rocblas/rocblas_extensions.cpp Co-authored-by: Muhammad Tanvir <84532306+muhammad-tanvir-1211@users.noreply.github.com> --- src/blas/backends/rocblas/rocblas_extensions.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index 888cdd583..c4a513cd4 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -587,7 +587,7 @@ sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, trans rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), get_rocblas_operation(trans), logical_m, logical_n, - (rocDataType *)&alpha, a_, lda, nullptr, nullptr, ldb, b_, ldb); + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, ldb, b_, ldb); }); }); return done; From 611b36eef3b2f6b58803e208a3611ba1c278ec0e Mon Sep 17 00:00:00 2001 From: Jin Z <5zj@cousteau.ftpn.ornl.gov> Date: Sun, 24 Sep 2023 14:09:33 -0400 Subject: [PATCH 10/12] update the code to fix compile error --- src/blas/backends/rocblas/rocblas_extensions.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index c4a513cd4..1c6579424 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -94,6 +94,7 @@ void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose tr int64_t ldb) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb); + const T beta = 0; queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto 
b_acc = b.template get_access(cgh); From b5804b80ef12b950c8b0422a71a554f099d709fe Mon Sep 17 00:00:00 2001 From: Jin Z <5zj@cousteau.ftpn.ornl.gov> Date: Fri, 29 Sep 2023 21:13:09 -0400 Subject: [PATCH 11/12] clang format the source --- .../backends/rocblas/rocblas_extensions.cpp | 53 ++++++++++--------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index 1c6579424..59c20b03f 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -107,7 +107,8 @@ void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose tr rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), get_rocblas_operation(trans), logical_m, logical_n, - (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, lda, b_, ldb); + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, + lda, b_, ldb); }); }); } @@ -162,18 +163,18 @@ void omatadd(const char *func_name, Func func, sycl::queue &queue, transpose tra auto c_ = sc.get_mem(c_acc); rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, a_, - lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, + a_, lda, (rocDataType *)&beta, b_, ldb, c_, ldc); }); }); } -#define OMATADD_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ - void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - TYPE alpha, sycl::buffer &a, int64_t lda, TYPE beta, \ - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ - omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, beta, \ - b, ldb, c, ldc); \ +#define OMATADD_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatadd(sycl::queue &queue, 
transpose transa, transpose transb, int64_t m, int64_t n, \ + TYPE alpha, sycl::buffer &a, int64_t lda, TYPE beta, \ + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ + omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, \ + beta, b, ldb, c, ldc); \ } OMATADD_LAUNCHER(float, rocblas_sgeam) @@ -183,7 +184,6 @@ OMATADD_LAUNCHER(std::complex, rocblas_zgeam) #undef OMATADD_LAUNCHER - // USM APIs sycl::event gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset offsetc, @@ -266,7 +266,8 @@ sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, trans rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), get_rocblas_operation(trans), logical_m, logical_n, - (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, lda, b_, ldb); + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, + lda, b_, ldb); }); }); return done; @@ -327,8 +328,8 @@ inline sycl::event omatadd(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, a_, - lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, + a_, lda, (rocDataType *)&beta, b_, ldb, c_, ldc); }); }); return done; @@ -430,7 +431,8 @@ void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose tr rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), get_rocblas_operation(trans), logical_m, logical_n, - (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, lda, b_, ldb); + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, + lda, b_, ldb); }); }); } @@ -485,18 +487,18 @@ void omatadd(const char *func_name, Func func, sycl::queue &queue, transpose 
tra auto c_ = sc.get_mem(c_acc); rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), n, m, (rocDataType *)&alpha, a_, - lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + get_rocblas_operation(transb), n, m, (rocDataType *)&alpha, + a_, lda, (rocDataType *)&beta, b_, ldb, c_, ldc); }); }); } -#define OMATADD_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ - void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - TYPE alpha, sycl::buffer &a, int64_t lda, TYPE beta, \ - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ - omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, beta, \ - b, ldb, c, ldc); \ +#define OMATADD_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + TYPE alpha, sycl::buffer &a, int64_t lda, TYPE beta, \ + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ + omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, \ + beta, b, ldb, c, ldc); \ } OMATADD_LAUNCHER(float, rocblas_sgeam) @@ -588,7 +590,8 @@ sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, trans rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), get_rocblas_operation(trans), logical_m, logical_n, - (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, ldb, b_, ldb); + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, + ldb, b_, ldb); }); }); return done; @@ -649,8 +652,8 @@ inline sycl::event omatadd(const char *func_name, Func func, sycl::queue &queue, auto c_ = reinterpret_cast(c); rocblas_status err; ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), n, m, (rocDataType *)&alpha, a_, - lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + 
get_rocblas_operation(transb), n, m, (rocDataType *)&alpha, + a_, lda, (rocDataType *)&beta, b_, ldb, c_, ldc); }); }); return done; From 91b0c147d28b0cbc3f72a9ebd5776f85b197a7d9 Mon Sep 17 00:00:00 2001 From: jinz2014 <7799920+jinz2014@users.noreply.github.com> Date: Thu, 9 Nov 2023 14:40:29 -0500 Subject: [PATCH 12/12] Update src/blas/backends/rocblas/rocblas_helper.hpp Co-authored-by: Muhammad Tanvir <84532306+muhammad-tanvir-1211@users.noreply.github.com> --- src/blas/backends/rocblas/rocblas_helper.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/blas/backends/rocblas/rocblas_helper.hpp b/src/blas/backends/rocblas/rocblas_helper.hpp index 38fd20088..601a02a14 100644 --- a/src/blas/backends/rocblas/rocblas_helper.hpp +++ b/src/blas/backends/rocblas/rocblas_helper.hpp @@ -182,7 +182,6 @@ class hip_error : virtual public std::runtime_error { hipError_t hip_err; \ HIP_ERROR_FUNC(hipStreamSynchronize, hip_err, currentStreamId); - inline rocblas_operation get_rocblas_operation(oneapi::mkl::transpose trn) { switch (trn) { case oneapi::mkl::transpose::nontrans: return rocblas_operation_none;