Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add raw pointer cases to math builtin tests #909

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions tests/math_builtin_api/math_builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,11 @@ void check_function(sycl_cts::util::logger& log, funT fun,
}

template <int N, typename returnT, typename funT, typename argT>
void check_function_multi_ptr_private(
sycl_cts::util::logger& log, funT fun, sycl_cts::resultRef<returnT> ref,
argT ptrRef, float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
void check_function_ptr_private(sycl_cts::util::logger& log, funT fun,
sycl_cts::resultRef<returnT> ref, argT ptrRef,
float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
sycl::range<1> ndRng(1);
returnT kernelResult;
argT kernelResultArg;
Expand Down Expand Up @@ -274,11 +274,11 @@ void check_function_multi_ptr_private(
}

template <int N, typename returnT, typename funT, typename argT>
void check_function_multi_ptr_global(
sycl_cts::util::logger& log, funT fun, argT arg,
sycl_cts::resultRef<returnT> ref, argT ptrRef, float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
void check_function_ptr_global(sycl_cts::util::logger& log, funT fun, argT arg,
sycl_cts::resultRef<returnT> ref, argT ptrRef,
float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
sycl::range<1> ndRng(1);
returnT kernelResult;
auto&& testQueue = once_per_unit::get_queue();
Expand Down Expand Up @@ -308,11 +308,11 @@ void check_function_multi_ptr_global(
}

template <int N, typename returnT, typename funT, typename argT>
void check_function_multi_ptr_local(
sycl_cts::util::logger& log, funT fun, argT arg,
sycl_cts::resultRef<returnT> ref, argT ptrRef, float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
void check_function_ptr_local(sycl_cts::util::logger& log, funT fun, argT arg,
sycl_cts::resultRef<returnT> ref, argT ptrRef,
float accuracy = 0.0f,
AccuracyMode accuracy_mode = AccuracyMode::ULP,
const std::string& comment = {}) {
sycl::range<1> ndRng(1);
returnT kernelResult;
auto&& testQueue = once_per_unit::get_queue();
Expand Down Expand Up @@ -364,7 +364,7 @@ void test_function(funT fun) {
}

template <int T, typename returnT, typename funT, typename argT>
void test_function_multi_ptr_global(funT fun, argT arg) {
void test_function_ptr_global(funT fun, argT arg) {
sycl::range<1> ndRng(1);
returnT* kernelResult = new returnT[1];
auto&& testQueue = once_per_unit::get_queue();
Expand All @@ -384,7 +384,7 @@ void test_function_multi_ptr_global(funT fun, argT arg) {
}

template <int T, typename returnT, typename funT, typename argT>
void test_function_multi_ptr_local(funT fun, argT arg) {
void test_function_ptr_local(funT fun, argT arg) {
sycl::range<1> ndRng(1);
returnT* kernelResult = new returnT[1];
auto&& testQueue = once_per_unit::get_queue();
Expand Down
73 changes: 46 additions & 27 deletions tests/math_builtin_api/modules/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"local" : ("""
{
$DECL
test_function_multi_ptr_local<$TEST_ID, $RETURN_TYPE>(
test_function_ptr_local<$TEST_ID, $RETURN_TYPE>(
[=]($ACCESSOR acc){
$FUNCTION_CALL"
}, $DATA);
Expand All @@ -27,7 +27,7 @@
"global" : ("""
{
$DECL
test_function_multi_ptr_global<$TEST_ID, $RETURN_TYPE>(
test_function_ptr_global<$TEST_ID, $RETURN_TYPE>(
[=]($ACCESSOR acc){
$FUNCTION_CALL
}, $DATA);
Expand All @@ -49,7 +49,7 @@
"private" : ("""
{
$PTR_REF
check_function_multi_ptr_private<$TEST_ID, $RETURN_TYPE>(log,
check_function_ptr_private<$TEST_ID, $RETURN_TYPE>(log,
[=]{
$FUNCTION_PRIVATE_CALL
}, ref, refPtr$ACCURACY$COMMENT);
Expand All @@ -59,7 +59,7 @@
"local" : ("""
{
$PTR_REF
check_function_multi_ptr_local<$TEST_ID, $RETURN_TYPE>(log,
check_function_ptr_local<$TEST_ID, $RETURN_TYPE>(log,
[=]($ACCESSOR acc){
$FUNCTION_CALL
}, $DATA, ref, refPtr$ACCURACY$COMMENT);
Expand All @@ -69,7 +69,7 @@
"global" : ("""
{
$PTR_REF
check_function_multi_ptr_global<$TEST_ID, $RETURN_TYPE>(log,
check_function_ptr_global<$TEST_ID, $RETURN_TYPE>(log,
[=]($ACCESSOR acc){
$FUNCTION_CALL
}, $DATA, ref, refPtr$ACCURACY$COMMENT);
Expand Down Expand Up @@ -117,25 +117,38 @@ def generate_value(base_type, dim):
values = [str(generate_literal_value(base_type)) + get_literal_suffix(base_type) for _ in range(dim)]
return ','.join(values)

def generate_multi_ptr(var_name, var_type, memory, decorated):
raw_ptr_arg_decl_template = Template("""
std::add_pointer_t<${var_type}> ${var_name} = sycl::multi_ptr<${var_type}, sycl::access::address_space::${addr_space}_space, sycl::access::decorated::no>(${accessor}).get_raw();
""")
multi_ptr_arg_decl_template = Template("""
sycl::multi_ptr<${var_type}, sycl::access::address_space::${addr_space}_space, ${decorated}> ${var_name}(${accessor});
""")
def generate_ptr(var_name, var_type, memory, decorated_or_raw):
decl = ""
value = generate_value(var_type.base_type, var_type.dim)
if memory == "global":
source_name = "multiPtrSourceData"
if memory == "global" or memory == "local":
source_name = "ptrSourceData"
decl = var_type.name + " " + source_name + "(" + value + ");\n"
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::global_space," + decorated + "> "
decl += var_name + "(acc);\n"
if memory == "local":
source_name = "multiPtrSourceData"
decl = var_type.name + " " + source_name + "(" + value + ");\n"
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::local_space," + decorated + "> "
decl += var_name + "(acc);\n"
if decorated_or_raw == "raw":
decl += raw_ptr_arg_decl_template.substitute(var_type=var_type.name,
var_name=var_name,
addr_space=memory,
accessor="acc")
else:
decl += multi_ptr_arg_decl_template.substitute(var_type=var_type.name,
var_name=var_name,
addr_space=memory,
decorated=decorated_or_raw,
accessor="acc")
if memory == "private":
source_name = "multiPtrSourceData"
source_name = "ptrSourceData"
decl = var_type.name + " " + source_name + "(" + value + ");\n"
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::private_space," + decorated + "> "
decl += var_name + " = sycl::address_space_cast<sycl::access::address_space::private_space," + decorated + ">(&"
decl += source_name + ");\n"
if decorated_or_raw == "raw":
decl += "std::add_pointer_t<" + var_type.name + "> " + var_name + " = &" + source_name + ";\n"
else:
decl += "sycl::multi_ptr<" + var_type.name + ", sycl::access::address_space::private_space, " + decorated_or_raw + "> "
decl += var_name + " = sycl::address_space_cast<sycl::access::address_space::private_space, " + decorated_or_raw + ">(&"
decl += source_name + ");\n"
return decl

def generate_variable(var_name, var_type, var_index):
Expand All @@ -161,7 +174,7 @@ def generate_arguments_clamp(sig):
return (arg_names, " ".join(args))


def generate_arguments(sig, memory, decorated):
def generate_arguments(sig, memory, decorated_or_raw):
arg_src = ""
arg_names = []
arg_index = 0
Expand All @@ -178,7 +191,7 @@ def generate_arguments(sig, memory, decorated):

current_arg = ""
if is_pointer:
current_arg = generate_multi_ptr(arg_name, arg, memory, decorated )
current_arg = generate_ptr(arg_name, arg, memory, decorated_or_raw )
else:
current_arg = generate_variable(arg_name, arg, arg_index)
arg_src += current_arg + " "
Expand Down Expand Up @@ -266,7 +279,7 @@ def generate_function_call(sig, arg_names, arg_src):
function_private_call_template = Template("""
${arg_src}
${ret_type} res = ${namespace}::${func_name}(${arg_names});
return privatePtrCheck<${ret_type}, ${arg_type}>(res, multiPtrSourceData);
return privatePtrCheck<${ret_type}, ${arg_type}>(res, ptrSourceData);
""")
def generate_function_private_call(sig, arg_names, arg_src, types):
fc = function_private_call_template.substitute(
Expand All @@ -292,24 +305,24 @@ def generate_reference(sig, arg_names, arg_src):

reference_ptr_template = Template("""
${arg_src}
${arg_type} refPtr = multiPtrSourceData;
${arg_type} refPtr = ptrSourceData;
sycl_cts::resultRef<${ret_type}> ref = reference::${func_name}(${arg_names}, &refPtr);
""")
def generate_reference_ptr(types, sig, arg_names, arg_src):
fc = reference_ptr_template.substitute(
arg_src=re.sub(r'^sycl::multi_ptr.*\n?', '', arg_src, flags=re.MULTILINE),
arg_src=re.sub(r'^.*sycl::multi_ptr.*\n?', '', arg_src, flags=re.MULTILINE),
func_name=sig.name,
ret_type=sig.ret_type.name,
arg_names=",".join(arg_names[:-1]),
arg_type=sig.arg_types[-1].name)
return fc

def generate_test_case(test_id, types, sig, memory, check, decorated = ""):
def generate_test_case(test_id, types, sig, memory, check, decorated_or_raw = ""):
testCaseSource = test_case_templates_check[memory] if check else test_case_templates[memory]
testCaseId = str(test_id)
# for the clamp function we use a separate argument generator to make sure that its preconditions are met,
# otherwise argument generation for clamp would be completely random.
(arg_names, arg_src) = generate_arguments(sig, memory, decorated) if sig.name != "clamp" else generate_arguments_clamp(sig)
(arg_names, arg_src) = generate_arguments(sig, memory, decorated_or_raw) if sig.name != "clamp" else generate_arguments_clamp(sig)
testCaseSource = testCaseSource.replace("$REFERENCE", generate_reference(sig, arg_names, arg_src))
testCaseSource = testCaseSource.replace("$PTR_REF", generate_reference_ptr(types, sig, arg_names, arg_src))
testCaseSource = testCaseSource.replace("$TEST_ID", testCaseId)
Expand All @@ -333,7 +346,7 @@ def generate_test_case(test_id, types, sig, memory, check, decorated = ""):
if memory != "private" and memory !="no_ptr":
# We rely on the fact that all SYCL math builtins have at most one arguments as pointer.
pointerType = sig.arg_types[sig.pntr_indx[0] - 1]
sourcePtrDataName = "multiPtrSourceData"
sourcePtrDataName = "ptrSourceData"
sourcePtrData = generate_variable(sourcePtrDataName, pointerType, 0)
testCaseSource = testCaseSource.replace("$DECL", sourcePtrData)
testCaseSource = testCaseSource.replace("$DATA", sourcePtrDataName)
Expand All @@ -357,14 +370,20 @@ def generate_test_cases(test_id, types, sig_list, check):
test_id += 1
test_source += generate_test_case(test_id, types, sig, "private", check, decorated_yes)
test_id += 1
test_source += generate_test_case(test_id, types, sig, "private", check, "raw")
test_id += 1
test_source += generate_test_case(test_id, types, sig, "local", check, decorated_no)
test_id += 1
test_source += generate_test_case(test_id, types, sig, "local", check, decorated_yes)
test_id += 1
test_source += generate_test_case(test_id, types, sig, "local", check, "raw")
test_id += 1
test_source += generate_test_case(test_id, types, sig, "global", check, decorated_no)
test_id += 1
test_source += generate_test_case(test_id, types, sig, "global", check, decorated_yes)
test_id += 1
test_source += generate_test_case(test_id, types, sig, "global", check, "raw")
test_id += 1
else:
if check:
test_source += generate_test_case(test_id, types, sig, "no_ptr", check)
Expand Down
Loading