Skip to content

Commit

Permalink
Fix format try 2.
Browse files Browse the repository at this point in the history
Signed-off-by: JackAKirk <[email protected]>
  • Loading branch information
JackAKirk committed Oct 7, 2024
1 parent 27b251f commit 94dcc7e
Showing 1 changed file with 43 additions and 54 deletions.
97 changes: 43 additions & 54 deletions src/blas/backends/cublas/cublas_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,23 +168,19 @@ inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose tran
auto c_ = sc.get_mem<cuTypeC *>(c_acc);
cublasStatus_t err;
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx,
err, handle, get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta, c_,
get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
CUBLAS_ERROR_FUNC_T(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx",
cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta,
c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c,
batch_size, get_cublas_datatype<cuTypeS>(),
cublas_gemm_algo);
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b_, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c_, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
});
});
Expand Down Expand Up @@ -622,23 +618,19 @@ inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose tra
auto handle = sc.get_handle(queue);
cublasStatus_t err;
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T("cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx,
err, handle, get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta, c,
get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
CUBLAS_ERROR_FUNC_T(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC("cublasGemmStridedBatchedEx",
cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa),
get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b,
get_cublas_datatype<cuTypeB>(), ldb, stride_b, &beta,
c, get_cublas_datatype<cuTypeC>(), ldc, stride_c,
batch_size, get_cublas_datatype<cuTypeS>(),
cublas_gemm_algo);
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle,
get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a,
get_cublas_datatype<cuTypeA>(), lda, stride_a, b, get_cublas_datatype<cuTypeB>(),
ldb, stride_b, &beta, c, get_cublas_datatype<cuTypeC>(), ldc, stride_c, batch_size,
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
});
});
Expand Down Expand Up @@ -714,26 +706,23 @@ inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, tr
cublasStatus_t err;
for (int64_t i = 0; i < group_count; i++) {
#ifdef SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND
CUBLAS_ERROR_FUNC_T("cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]),
get_cublas_operation(transb[i]), (int)m[i], (int)n[i],
(int)k[i], &alpha[i], (const void *const *)(a + offset),
get_cublas_datatype<cuTypeA>(), (int)lda[i],
(const void *const *)(b + offset),
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(),
(int)ldc[i], (int)group_size[i],
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
CUBLAS_ERROR_FUNC_T(
"cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i],
(int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset),
get_cublas_datatype<cuTypeA>(), (int)lda[i], (const void *const *)(b + offset),
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(), (int)ldc[i],
(int)group_size[i], get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#else
CUBLAS_ERROR_FUNC_T_SYNC(
"cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle,
get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i],
(int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset),
get_cublas_datatype<cuTypeA>(), (int)lda[i],
(const void *const *)(b + offset), get_cublas_datatype<cuTypeB>(),
(int)ldb[i], &beta[i], (void *const *)(c + offset),
get_cublas_datatype<cuTypeC>(), (int)ldc[i], (int)group_size[i],
get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
get_cublas_datatype<cuTypeA>(), (int)lda[i], (const void *const *)(b + offset),
get_cublas_datatype<cuTypeB>(), (int)ldb[i], &beta[i],
(void *const *)(c + offset), get_cublas_datatype<cuTypeC>(), (int)ldc[i],
(int)group_size[i], get_cublas_datatype<cuTypeS>(), cublas_gemm_algo);
#endif
offset += group_size[i];
}
Expand Down Expand Up @@ -832,13 +821,13 @@ inline sycl::event trsm_batch(const char *func_name, Func func, sycl::queue &que
for (int64_t i = 0; i < group_count; i++) {
auto **a_ = reinterpret_cast<const cuDataType **>(a);
auto **b_ = reinterpret_cast<cuDataType **>(b);
cublas_native_named_func(func_name, func, err, handle,
get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]),
get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i],
b_ + offset, (int)ldb[i], (int)group_size[i]);
cublas_native_named_func(
func_name, func, err, handle, get_cublas_side_mode(left_right[i]),
get_cublas_fill_mode(upper_lower[i]), get_cublas_operation(trans[i]),
get_cublas_diag_type(unit_diag[i]), (int)m[i], (int)n[i],
(cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i],
(int)group_size[i]);

offset += group_size[i];
}
});
Expand Down

0 comments on commit 94dcc7e

Please sign in to comment.