Skip to content

Commit

Permalink
Fix FP8 test case (#3243)
Browse files (browse the repository at this point in the history)
  • Loading branch information
umangyadav authored and TedThemistokleous committed Aug 21, 2024
1 parent 0d549b4 commit 2161d29
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
42 changes: 28 additions & 14 deletions src/targets/gpu/gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ struct gemm_impl
{
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
auto common_args =
create_strided_batched_args_common(ctx, compute_type, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
common_args,
rocblas_gemm_algo_standard,
Expand All @@ -272,7 +273,7 @@ struct gemm_impl
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
auto common_args = create_gemm_ex_args_common(ctx, compute_type, input_args);
rocblas_invoke(&rocblas_gemm_ex3,
common_args,
rocblas_gemm_algo_standard,
Expand All @@ -285,7 +286,8 @@ struct gemm_impl
{
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
auto common_args =
create_strided_batched_args_common(ctx, compute_type, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
Expand All @@ -294,7 +296,7 @@ struct gemm_impl
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
auto common_args = create_gemm_ex_args_common(ctx, compute_type, input_args);
rocblas_invoke(&rocblas_gemm_ex,
common_args,
rocblas_gemm_algo_solution_index,
Expand Down Expand Up @@ -333,7 +335,7 @@ struct gemm_impl

if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
auto common_args = create_strided_batched_args_common(ctx, compute_type, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
common_args,
rocblas_gemm_algo_solution_index,
Expand All @@ -342,7 +344,7 @@ struct gemm_impl
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
auto common_args = create_gemm_ex_args_common(ctx, compute_type, input_args);
check_valid = rocblas_invoke(&rocblas_gemm_ex,
common_args,
rocblas_gemm_algo_solution_index,
Expand All @@ -369,7 +371,9 @@ struct gemm_impl
* A and args[0] as B in calling the rocblas_gemm.
*
*/
auto create_strided_batched_args_common(context& ctx, const std::vector<argument>& args) const
auto create_strided_batched_args_common(context& ctx,
rb_compute_type rbcompute_type,
const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
Expand All @@ -396,7 +400,7 @@ struct gemm_impl
ldd,
d_stride,
num_matrices,
compute_type);
rbcompute_type);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
Expand All @@ -408,7 +412,9 @@ struct gemm_impl
* A and args[0] as B in calling the rocblas_gemm.
*
* */
auto create_gemm_ex_args_common(context& ctx, const std::vector<argument>& args) const
auto create_gemm_ex_args_common(context& ctx,
rb_compute_type rbcompute_type,
const std::vector<argument>& args) const
{
return pack(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
Expand All @@ -430,7 +436,7 @@ struct gemm_impl
is_3inputs ? args[3].data() : args[2].data(),
output_type,
ldd,
compute_type);
rbcompute_type);
}

#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
Expand All @@ -455,9 +461,16 @@ struct gemm_impl
//
rocblas_int list_size = 0;
std::vector<rocblas_int> solution_indices;
rb_compute_type rbcompute_type = compute_type;
// rocblas_gemm_get_solutions() API requires compute_type as rocblas_datatype. Convert
// manually for FP8
if(arg_type == rocblas_datatype_f8_r)
{
rbcompute_type = rocblas_datatype_f32_r;
}
if(strided_batched)
{
auto common_args = create_strided_batched_args_common(ctx, input_args);
auto common_args = create_strided_batched_args_common(ctx, rbcompute_type, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_args,
rocblas_gemm_algo_solution_index,
Expand All @@ -466,7 +479,8 @@ struct gemm_impl
&list_size);
solution_indices.resize(list_size);

auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
auto common_sol_args =
create_strided_batched_args_common(ctx, rbcompute_type, input_args);
rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
Expand All @@ -476,7 +490,7 @@ struct gemm_impl
}
else
{
auto common_args = create_gemm_ex_args_common(ctx, input_args);
auto common_args = create_gemm_ex_args_common(ctx, rbcompute_type, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_args,
rocblas_gemm_algo_solution_index,
Expand All @@ -485,7 +499,7 @@ struct gemm_impl
&list_size);
solution_indices.resize(list_size);

auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
auto common_sol_args = create_gemm_ex_args_common(ctx, rbcompute_type, input_args);
rocblas_invoke(&rocblas_gemm_ex_get_solutions,
common_sol_args,
rocblas_gemm_algo_solution_index,
Expand Down
4 changes: 1 addition & 3 deletions test/verify/test_gemm_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,4 @@ struct test_gemm_add : verify_program<test_gemm_add<DType>>

template struct test_gemm_add<migraphx::shape::float_type>;
template struct test_gemm_add<migraphx::shape::half_type>;
// TODO: Investigate failure: rocblas_invoke: rocBLAS call failed with status 2
// Github Issue: #3199
// template struct test_gemm_add<migraphx::shape::fp8e4m3fnuz_type>;
template struct test_gemm_add<migraphx::shape::fp8e4m3fnuz_type>;

0 comments on commit 2161d29

Please sign in to comment.