Skip to content

Commit

Permalink
Merge pull request #909 from steffenlarsen/steffen/raw_pointer_math
Browse files Browse the repository at this point in the history
Add raw pointer cases to math builtin tests
  • Loading branch information
steffenlarsen authored Jul 19, 2024
2 parents b4e3600 + 4476575 commit da9e27c
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 44 deletions.
34 changes: 17 additions & 17 deletions tests/math_builtin_api/math_builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,11 @@ void check_function(sycl_cts::util::logger& log, funT fun,
}

template <int N, typename returnT, typename funT, typename argT>
void check_function_multi_ptr_private(
sycl_cts::util::logger& log, funT fun, sycl_cts::resultRef<returnT> ref,
argT ptrRef, float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
void check_function_ptr_private(sycl_cts::util::logger& log, funT fun,
sycl_cts::resultRef<returnT> ref, argT ptrRef,
float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
sycl::range<1> ndRng(1);
returnT kernelResult;
argT kernelResultArg;
Expand Down Expand Up @@ -274,11 +274,11 @@ void check_function_multi_ptr_private(
}

template <int N, typename returnT, typename funT, typename argT>
void check_function_multi_ptr_global(
sycl_cts::util::logger& log, funT fun, argT arg,
sycl_cts::resultRef<returnT> ref, argT ptrRef, float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
void check_function_ptr_global(sycl_cts::util::logger& log, funT fun, argT arg,
sycl_cts::resultRef<returnT> ref, argT ptrRef,
float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
sycl::range<1> ndRng(1);
returnT kernelResult;
auto&& testQueue = once_per_unit::get_queue();
Expand Down Expand Up @@ -308,11 +308,11 @@ void check_function_multi_ptr_global(
}

template <int N, typename returnT, typename funT, typename argT>
void check_function_multi_ptr_local(
sycl_cts::util::logger& log, funT fun, argT arg,
sycl_cts::resultRef<returnT> ref, argT ptrRef, float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
void check_function_ptr_local(sycl_cts::util::logger& log, funT fun, argT arg,
sycl_cts::resultRef<returnT> ref, argT ptrRef,
float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
sycl::range<1> ndRng(1);
returnT kernelResult;
auto&& testQueue = once_per_unit::get_queue();
Expand Down Expand Up @@ -364,7 +364,7 @@ void test_function(funT fun) {
}

template <int T, typename returnT, typename funT, typename argT>
void test_function_multi_ptr_global(funT fun, argT arg) {
void test_function_ptr_global(funT fun, argT arg) {
sycl::range<1> ndRng(1);
returnT* kernelResult = new returnT[1];
auto&& testQueue = once_per_unit::get_queue();
Expand All @@ -384,7 +384,7 @@ void test_function_multi_ptr_global(funT fun, argT arg) {
}

template <int T, typename returnT, typename funT, typename argT>
void test_function_multi_ptr_local(funT fun, argT arg) {
void test_function_ptr_local(funT fun, argT arg) {
sycl::range<1> ndRng(1);
returnT* kernelResult = new returnT[1];
auto&& testQueue = once_per_unit::get_queue();
Expand Down
73 changes: 46 additions & 27 deletions tests/math_builtin_api/modules/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"local" : ("""
{
$DECL
test_function_multi_ptr_local<$TEST_ID, $RETURN_TYPE>(
test_function_ptr_local<$TEST_ID, $RETURN_TYPE>(
[=]($ACCESSOR acc){
$FUNCTION_CALL"
}, $DATA);
Expand All @@ -27,7 +27,7 @@
"global" : ("""
{
$DECL
test_function_multi_ptr_global<$TEST_ID, $RETURN_TYPE>(
test_function_ptr_global<$TEST_ID, $RETURN_TYPE>(
[=]($ACCESSOR acc){
$FUNCTION_CALL
}, $DATA);
Expand All @@ -49,7 +49,7 @@
"private" : ("""
{
$PTR_REF
check_function_multi_ptr_private<$TEST_ID, $RETURN_TYPE>(log,
check_function_ptr_private<$TEST_ID, $RETURN_TYPE>(log,
[=]{
$FUNCTION_PRIVATE_CALL
}, ref, refPtr$ACCURACY$COMMENT);
Expand All @@ -59,7 +59,7 @@
"local" : ("""
{
$PTR_REF
check_function_multi_ptr_local<$TEST_ID, $RETURN_TYPE>(log,
check_function_ptr_local<$TEST_ID, $RETURN_TYPE>(log,
[=]($ACCESSOR acc){
$FUNCTION_CALL
}, $DATA, ref, refPtr$ACCURACY$COMMENT);
Expand All @@ -69,7 +69,7 @@
"global" : ("""
{
$PTR_REF
check_function_multi_ptr_global<$TEST_ID, $RETURN_TYPE>(log,
check_function_ptr_global<$TEST_ID, $RETURN_TYPE>(log,
[=]($ACCESSOR acc){
$FUNCTION_CALL
}, $DATA, ref, refPtr$ACCURACY$COMMENT);
Expand Down Expand Up @@ -117,25 +117,38 @@ def generate_value(base_type, dim):
values = [str(generate_literal_value(base_type)) + get_literal_suffix(base_type) for _ in range(dim)]
return ','.join(values)

def generate_multi_ptr(var_name, var_type, memory, decorated):
raw_ptr_arg_decl_template = Template("""
std::add_pointer_t<${var_type}> ${var_name} = sycl::multi_ptr<${var_type}, sycl::access::address_space::${addr_space}_space, sycl::access::decorated::no>(${accessor}).get_raw();
""")
multi_ptr_arg_decl_template = Template("""
sycl::multi_ptr<${var_type}, sycl::access::address_space::${addr_space}_space, ${decorated}> ${var_name}(${accessor});
""")
def generate_ptr(var_name, var_type, memory, decorated_or_raw):
decl = ""
value = generate_value(var_type.base_type, var_type.dim)
if memory == "global":
source_name = "multiPtrSourceData"
if memory == "global" or memory == "local":
source_name = "ptrSourceData"
decl = var_type.name + " " + source_name + "(" + value + ");\n"
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::global_space," + decorated + "> "
decl += var_name + "(acc);\n"
if memory == "local":
source_name = "multiPtrSourceData"
decl = var_type.name + " " + source_name + "(" + value + ");\n"
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::local_space," + decorated + "> "
decl += var_name + "(acc);\n"
if decorated_or_raw == "raw":
decl += raw_ptr_arg_decl_template.substitute(var_type=var_type.name,
var_name=var_name,
addr_space=memory,
accessor="acc")
else:
decl += multi_ptr_arg_decl_template.substitute(var_type=var_type.name,
var_name=var_name,
addr_space=memory,
decorated=decorated_or_raw,
accessor="acc")
if memory == "private":
source_name = "multiPtrSourceData"
source_name = "ptrSourceData"
decl = var_type.name + " " + source_name + "(" + value + ");\n"
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::private_space," + decorated + "> "
decl += var_name + " = sycl::address_space_cast<sycl::access::address_space::private_space," + decorated + ">(&"
decl += source_name + ");\n"
if decorated_or_raw == "raw":
decl += "std::add_pointer_t<" + var_type.name + "> " + var_name + " = &" + source_name + ";\n"
else:
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::private_space, " + decorated_or_raw + "> "
decl += var_name + " = sycl::address_space_cast<sycl::access::address_space::private_space, " + decorated_or_raw + ">(&"
decl += source_name + ");\n"
return decl

def generate_variable(var_name, var_type, var_index):
Expand All @@ -161,7 +174,7 @@ def generate_arguments_clamp(sig):
return (arg_names, " ".join(args))


def generate_arguments(sig, memory, decorated):
def generate_arguments(sig, memory, decorated_or_raw):
arg_src = ""
arg_names = []
arg_index = 0
Expand All @@ -178,7 +191,7 @@ def generate_arguments(sig, memory, decorated):

current_arg = ""
if is_pointer:
current_arg = generate_multi_ptr(arg_name, arg, memory, decorated )
current_arg = generate_ptr(arg_name, arg, memory, decorated_or_raw )
else:
current_arg = generate_variable(arg_name, arg, arg_index)
arg_src += current_arg + " "
Expand Down Expand Up @@ -266,7 +279,7 @@ def generate_function_call(sig, arg_names, arg_src):
function_private_call_template = Template("""
${arg_src}
${ret_type} res = ${namespace}::${func_name}(${arg_names});
return privatePtrCheck<${ret_type}, ${arg_type}>(res, multiPtrSourceData);
return privatePtrCheck<${ret_type}, ${arg_type}>(res, ptrSourceData);
""")
def generate_function_private_call(sig, arg_names, arg_src, types):
fc = function_private_call_template.substitute(
Expand All @@ -292,24 +305,24 @@ def generate_reference(sig, arg_names, arg_src):

reference_ptr_template = Template("""
${arg_src}
${arg_type} refPtr = multiPtrSourceData;
${arg_type} refPtr = ptrSourceData;
sycl_cts::resultRef<${ret_type}> ref = reference::${func_name}(${arg_names}, &refPtr);
""")
def generate_reference_ptr(types, sig, arg_names, arg_src):
fc = reference_ptr_template.substitute(
arg_src=re.sub(r'^sycl::multi_ptr.*\n?', '', arg_src, flags=re.MULTILINE),
arg_src=re.sub(r'^.*sycl::multi_ptr.*\n?', '', arg_src, flags=re.MULTILINE),
func_name=sig.name,
ret_type=sig.ret_type.name,
arg_names=",".join(arg_names[:-1]),
arg_type=sig.arg_types[-1].name)
return fc

def generate_test_case(test_id, types, sig, memory, check, decorated = ""):
def generate_test_case(test_id, types, sig, memory, check, decorated_or_raw = ""):
testCaseSource = test_case_templates_check[memory] if check else test_case_templates[memory]
testCaseId = str(test_id)
# for the clamp function we use a separate argument generator to make sure that its preconditions are met,
# otherwise argument generation for clamp would be completely random.
(arg_names, arg_src) = generate_arguments(sig, memory, decorated) if sig.name != "clamp" else generate_arguments_clamp(sig)
(arg_names, arg_src) = generate_arguments(sig, memory, decorated_or_raw) if sig.name != "clamp" else generate_arguments_clamp(sig)
testCaseSource = testCaseSource.replace("$REFERENCE", generate_reference(sig, arg_names, arg_src))
testCaseSource = testCaseSource.replace("$PTR_REF", generate_reference_ptr(types, sig, arg_names, arg_src))
testCaseSource = testCaseSource.replace("$TEST_ID", testCaseId)
Expand All @@ -333,7 +346,7 @@ def generate_test_case(test_id, types, sig, memory, check, decorated = ""):
if memory != "private" and memory !="no_ptr":
# We rely on the fact that all SYCL math builtins have at most one arguments as pointer.
pointerType = sig.arg_types[sig.pntr_indx[0] - 1]
sourcePtrDataName = "multiPtrSourceData"
sourcePtrDataName = "ptrSourceData"
sourcePtrData = generate_variable(sourcePtrDataName, pointerType, 0)
testCaseSource = testCaseSource.replace("$DECL", sourcePtrData)
testCaseSource = testCaseSource.replace("$DATA", sourcePtrDataName)
Expand All @@ -357,14 +370,20 @@ def generate_test_cases(test_id, types, sig_list, check):
test_id += 1
test_source += generate_test_case(test_id, types, sig, "private", check, decorated_yes)
test_id += 1
test_source += generate_test_case(test_id, types, sig, "private", check, "raw")
test_id += 1
test_source += generate_test_case(test_id, types, sig, "local", check, decorated_no)
test_id += 1
test_source += generate_test_case(test_id, types, sig, "local", check, decorated_yes)
test_id += 1
test_source += generate_test_case(test_id, types, sig, "local", check, "raw")
test_id += 1
test_source += generate_test_case(test_id, types, sig, "global", check, decorated_no)
test_id += 1
test_source += generate_test_case(test_id, types, sig, "global", check, decorated_yes)
test_id += 1
test_source += generate_test_case(test_id, types, sig, "global", check, "raw")
test_id += 1
else:
if check:
test_source += generate_test_case(test_id, types, sig, "no_ptr", check)
Expand Down

0 comments on commit da9e27c

Please sign in to comment.