Skip to content

Commit

Permalink
update
Browse files (browse the repository at this point in the history)
  • Loading branch information
YuriPlyakhin committed Jan 31, 2025
1 parent c5f0f39 commit 0ffd40a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
7 changes: 4 additions & 3 deletions sycl/test-e2e/Matrix/Inputs/joint_matrix_out_bounds_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
#include <iostream>
#include <sycl/usm.hpp>

template <size_t K, layout B_layout, unsigned int vnniFactor> class mult;
template <typename Tab, size_t K, layout B_layout, unsigned int vnniFactor>
class mult;

template <typename T1, typename T2, size_t M, size_t N, size_t K, size_t TM,
size_t TN, size_t TK, layout A_layout, layout B_layout,
Expand All @@ -19,11 +20,11 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
// Add one iteration for the out of bounds dpas instruction
size_t NDRangeM = M / TM + (((M % TM) != 0) ? 1 : 0);
size_t NDRangeN = N / TN;
size_t sg_size = get_sg_size<mult<K, B_layout, vnniFactor>>(q);
size_t sg_size = get_sg_size<mult<T2, K, B_layout, vnniFactor>>(q);
std::cout << "SG size: " << sg_size << " ";

q.submit([&](handler &cgh) {
cgh.parallel_for<mult<K, B_layout, vnniFactor>>(
cgh.parallel_for<mult<T2, K, B_layout, vnniFactor>>(
nd_range<2>({NDRangeM, NDRangeN * sg_size}, {1, 1 * sg_size}),
[=](nd_item<2> spmd_item)
#ifdef SG_SZ
Expand Down
14 changes: 7 additions & 7 deletions sycl/test-e2e/Matrix/joint_matrix_out_bounds_colmajor.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//==-------- joint_matrix_out_bounds.cpp - DPC++ joint_matrix--------------==//
//==----joint_matrix_out_bounds_colmajor.cpp - DPC++ joint_matrix---------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -25,20 +25,20 @@ int main() {
test<bfloat16, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::col_major, layout::col_major, 1>();
std::cout << "half A col major, B col major: ";
test<half, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16,
layout::col_major, layout::col_major, 1>();
test<half, float, 1024 + 14, 1024, 1024 + 24, 8, 16, 16, layout::col_major,
layout::col_major, 1>();
std::cout << "int8 A col major, B col major: ";
test<int8_t, int32_t, 1024 + 14, 1024, 1024 + 24, 8, 16, 32,
layout::col_major, layout::col_major, 2>();
layout::col_major, layout::col_major, 1>();

// unaligned k:
std::cout << "bf16 A col major, B col major: ";
test<bfloat16, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::col_major, layout::col_major, 1>();
std::cout << "half A col major, B col major: ";
test<half, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16,
layout::col_major, layout::col_major, 1>();
test<half, float, 1024 + 14, 1024, 1024 + 14, 8, 16, 16, layout::col_major,
layout::col_major, 1>();
std::cout << "int8 A col major, B col major: ";
test<int8_t, int32_t, 1024 + 14, 1024, 1024 + 14, 8, 16, 32,
layout::col_major, layout::col_major, 2>();
layout::col_major, layout::col_major, 1>();
}

0 comments on commit 0ffd40a

Please sign in to comment.