Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
Signed-off-by: nscipione <[email protected]>
  • Loading branch information
s-Nick committed Dec 11, 2024
1 parent cdd5e1e commit 8e59102
Show file tree
Hide file tree
Showing 8 changed files with 637 additions and 631 deletions.
1,076 changes: 537 additions & 539 deletions include/oneapi/math/blas/detail/generic/blas_ct.hxx

Large diffs are not rendered by default.

64 changes: 32 additions & 32 deletions src/blas/backends/generic/generic_batch.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ void axpy_batch(sycl::queue& queue, std::int64_t n, float alpha, sycl::buffer<fl
std::int64_t incx, std::int64_t stridex, sycl::buffer<float, 1>& y,
std::int64_t incy, std::int64_t stridey, std::int64_t batch_size) {
CALL_GENERIC_BLAS_FN(::blas::_axpy_batch, queue, n, alpha, x, incx, stridex, y, incy, stridey,
batch_size);
batch_size);
}

void axpy_batch(sycl::queue& queue, std::int64_t n, double alpha, sycl::buffer<double, 1>& x,
std::int64_t incx, std::int64_t stridex, sycl::buffer<double, 1>& y,
std::int64_t incy, std::int64_t stridey, std::int64_t batch_size) {
CALL_GENERIC_BLAS_FN(::blas::_axpy_batch, queue, n, alpha, x, incx, stridex, y, incy, stridey,
batch_size);
batch_size);
}

void axpy_batch(sycl::queue& queue, std::int64_t n, std::complex<float> alpha,
Expand Down Expand Up @@ -172,8 +172,8 @@ void gemm_batch(sycl::queue& queue, oneapi::math::transpose transa, oneapi::math
sycl::buffer<float, 1>& b, std::int64_t ldb, std::int64_t stride_b, float beta,
sycl::buffer<float, 1>& c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size) {
CALL_GENERIC_BLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, lda,
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
CALL_GENERIC_BLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
}

void gemm_batch(sycl::queue& queue, oneapi::math::transpose transa, oneapi::math::transpose transb,
Expand All @@ -182,8 +182,8 @@ void gemm_batch(sycl::queue& queue, oneapi::math::transpose transa, oneapi::math
sycl::buffer<double, 1>& b, std::int64_t ldb, std::int64_t stride_b, double beta,
sycl::buffer<double, 1>& c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size) {
CALL_GENERIC_BLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, lda,
stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
CALL_GENERIC_BLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size);
}

void gemm_batch(sycl::queue& queue, oneapi::math::transpose transa, oneapi::math::transpose transb,
Expand Down Expand Up @@ -277,16 +277,16 @@ void omatcopy_batch(sycl::queue& queue, oneapi::math::transpose trans, std::int6
std::int64_t n, float alpha, sycl::buffer<float, 1>& a, std::int64_t lda,
std::int64_t stride_a, sycl::buffer<float, 1>& b, std::int64_t ldb,
std::int64_t stride_b, std::int64_t batch_size) {
CALL_GENERIC_BLAS_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b, ldb,
stride_b, batch_size);
CALL_GENERIC_BLAS_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b,
ldb, stride_b, batch_size);
}

void omatcopy_batch(sycl::queue& queue, oneapi::math::transpose trans, std::int64_t m,
std::int64_t n, double alpha, sycl::buffer<double, 1>& a, std::int64_t lda,
std::int64_t stride_a, sycl::buffer<double, 1>& b, std::int64_t ldb,
std::int64_t stride_b, std::int64_t batch_size) {
CALL_GENERIC_BLAS_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b, ldb,
stride_b, batch_size);
CALL_GENERIC_BLAS_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b,
ldb, stride_b, batch_size);
}

void omatcopy_batch(sycl::queue& queue, oneapi::math::transpose trans, std::int64_t m,
Expand Down Expand Up @@ -337,8 +337,8 @@ void omatadd_batch(sycl::queue& queue, oneapi::math::transpose transa,
sycl::buffer<float, 1>& b, std::int64_t ldb, std::int64_t stride_b,
sycl::buffer<float, 1>& c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size) {
CALL_GENERIC_BLAS_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda, stride_a,
beta, b, ldb, stride_b, c, ldc, stride_c, batch_size);
CALL_GENERIC_BLAS_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda,
stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size);
}

void omatadd_batch(sycl::queue& queue, oneapi::math::transpose transa,
Expand All @@ -347,8 +347,8 @@ void omatadd_batch(sycl::queue& queue, oneapi::math::transpose transa,
sycl::buffer<double, 1>& b, std::int64_t ldb, std::int64_t stride_b,
sycl::buffer<double, 1>& c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size) {
CALL_GENERIC_BLAS_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda, stride_a,
beta, b, ldb, stride_b, c, ldc, stride_c, batch_size);
CALL_GENERIC_BLAS_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda,
stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size);
}

void omatadd_batch(sycl::queue& queue, oneapi::math::transpose transa,
Expand Down Expand Up @@ -605,16 +605,16 @@ sycl::event axpy_batch(sycl::queue& queue, std::int64_t n, float alpha, const fl
std::int64_t incx, std::int64_t stridex, float* y, std::int64_t incy,
std::int64_t stridey, std::int64_t batch_size,
const std::vector<sycl::event>& dependencies) {
CALL_GENERIC_BLAS_USM_FN(::blas::_axpy_batch, queue, n, alpha, x, incx, stridex, y, incy, stridey,
batch_size, dependencies);
CALL_GENERIC_BLAS_USM_FN(::blas::_axpy_batch, queue, n, alpha, x, incx, stridex, y, incy,
stridey, batch_size, dependencies);
}

sycl::event axpy_batch(sycl::queue& queue, std::int64_t n, double alpha, const double* x,
std::int64_t incx, std::int64_t stridex, double* y, std::int64_t incy,
std::int64_t stridey, std::int64_t batch_size,
const std::vector<sycl::event>& dependencies) {
CALL_GENERIC_BLAS_USM_FN(::blas::_axpy_batch, queue, n, alpha, x, incx, stridex, y, incy, stridey,
batch_size, dependencies);
CALL_GENERIC_BLAS_USM_FN(::blas::_axpy_batch, queue, n, alpha, x, incx, stridex, y, incy,
stridey, batch_size, dependencies);
}

sycl::event axpy_batch(sycl::queue& queue, std::int64_t n, std::complex<float> alpha,
Expand Down Expand Up @@ -764,9 +764,9 @@ sycl::event gemm_batch(sycl::queue& queue, oneapi::math::transpose transa,
std::int64_t stride_b, float beta, float* c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size,
const std::vector<sycl::event>& dependencies) {
CALL_GENERIC_BLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size,
dependencies);
CALL_GENERIC_BLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha,
a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size,
dependencies);
}

sycl::event gemm_batch(sycl::queue& queue, oneapi::math::transpose transa,
Expand All @@ -776,9 +776,9 @@ sycl::event gemm_batch(sycl::queue& queue, oneapi::math::transpose transa,
std::int64_t stride_b, double beta, double* c, std::int64_t ldc,
std::int64_t stride_c, std::int64_t batch_size,
const std::vector<sycl::event>& dependencies) {
CALL_GENERIC_BLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a,
lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size,
dependencies);
CALL_GENERIC_BLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha,
a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size,
dependencies);
}

sycl::event gemm_batch(sycl::queue& queue, oneapi::math::transpose transa,
Expand Down Expand Up @@ -920,17 +920,17 @@ sycl::event omatcopy_batch(sycl::queue& queue, oneapi::math::transpose trans, st
std::int64_t n, float alpha, const float* a, std::int64_t lda,
std::int64_t stride_a, float* b, std::int64_t ldb, std::int64_t stride_b,
std::int64_t batch_size, const std::vector<sycl::event>& dependencies) {
CALL_GENERIC_BLAS_USM_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b,
ldb, stride_b, batch_size, dependencies);
CALL_GENERIC_BLAS_USM_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a,
b, ldb, stride_b, batch_size, dependencies);
}

sycl::event omatcopy_batch(sycl::queue& queue, oneapi::math::transpose trans, std::int64_t m,
std::int64_t n, double alpha, const double* a, std::int64_t lda,
std::int64_t stride_a, double* b, std::int64_t ldb,
std::int64_t stride_b, std::int64_t batch_size,
const std::vector<sycl::event>& dependencies) {
CALL_GENERIC_BLAS_USM_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b,
ldb, stride_b, batch_size, dependencies);
CALL_GENERIC_BLAS_USM_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a,
b, ldb, stride_b, batch_size, dependencies);
}

sycl::event omatcopy_batch(sycl::queue& queue, oneapi::math::transpose trans, std::int64_t m,
Expand Down Expand Up @@ -984,8 +984,8 @@ sycl::event omatadd_batch(sycl::queue& queue, oneapi::math::transpose transa,
float* c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size, const std::vector<sycl::event>& dependencies) {
CALL_GENERIC_BLAS_USM_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda,
stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size,
dependencies);
stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size,
dependencies);
}

sycl::event omatadd_batch(sycl::queue& queue, oneapi::math::transpose transa,
Expand All @@ -995,8 +995,8 @@ sycl::event omatadd_batch(sycl::queue& queue, oneapi::math::transpose transa,
double* c, std::int64_t ldc, std::int64_t stride_c,
std::int64_t batch_size, const std::vector<sycl::event>& dependencies) {
CALL_GENERIC_BLAS_USM_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda,
stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size,
dependencies);
stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size,
dependencies);
}

sycl::event omatadd_batch(sycl::queue& queue, oneapi::math::transpose transa,
Expand Down
10 changes: 5 additions & 5 deletions src/blas/backends/generic/generic_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ template <typename InputT>
struct generic_type;

#define DEF_GENERIC_BLAS_TYPE(onemath_t, generic_t) \
template <> \
struct generic_type<onemath_t> { \
using type = generic_t; \
template <> \
struct generic_type<onemath_t> { \
using type = generic_t; \
};

DEF_GENERIC_BLAS_TYPE(sycl::queue, handle_t)
Expand Down Expand Up @@ -210,15 +210,15 @@ struct throw_if_unsupported_by_device {
throw unimplemented("blas", "onemath_sycl_blas function"); \
}

#define CALL_GENERIC_BLAS_USM_FN(genericFunc, ...) \
#define CALL_GENERIC_BLAS_USM_FN(genericFunc, ...) \
if constexpr (is_column_major()) { \
detail::throw_if_unsupported_by_device<double, sycl::aspect::fp64>{}( \
" generic BLAS function requiring fp64 support", __VA_ARGS__); \
detail::throw_if_unsupported_by_device<sycl::half, sycl::aspect::fp16>{}( \
" generic BLAS function requiring fp16 support", __VA_ARGS__); \
auto args = detail::convert_to_generic_type(__VA_ARGS__); \
auto fn = [](auto&&... targs) { \
return genericFunc(std::forward<decltype(targs)>(targs)...).back(); \
return genericFunc(std::forward<decltype(targs)>(targs)...).back(); \
}; \
try { \
return std::apply(fn, args); \
Expand Down
3 changes: 2 additions & 1 deletion src/blas/backends/generic/generic_level1.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ sycl::event sdsdot(sycl::queue& queue, std::int64_t n, real_t sb, const real_t*
[&](sycl::handler& cgh) { cgh.single_task([=]() { result[0] = real_t(0); }); });
std::vector<sycl::event> new_dependencies = dependencies;
new_dependencies.emplace_back(init_res_val);
CALL_GENERIC_BLAS_USM_FN(::blas::_sdsdot, queue, n, sb, x, incx, y, incy, result, new_dependencies);
CALL_GENERIC_BLAS_USM_FN(::blas::_sdsdot, queue, n, sb, x, incx, y, incy, result,
new_dependencies);
}

sycl::event nrm2(sycl::queue& queue, std::int64_t n, const std::complex<real_t>* x,
Expand Down
Loading

0 comments on commit 8e59102

Please sign in to comment.