Skip to content

Commit

Permalink
Made dace::math::pow forward to std::pow more generic (#1583)
Browse files Browse the repository at this point in the history
This overload solution will forward the argument types separately and
use the same return type.

---------

Co-authored-by: Berke Ates <[email protected]>
Co-authored-by: Berke Ates <[email protected]>
Co-authored-by: Philipp Schaad <[email protected]>
Co-authored-by: Philip Mueller <[email protected]>
  • Loading branch information
5 people authored Jun 12, 2024
1 parent e8aebc0 commit 0e3e098
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 22 deletions.
26 changes: 4 additions & 22 deletions dace/runtime/include/dace/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,36 +473,29 @@ namespace dace
return (thrust::complex<T>)thrust::pow(a, b);
}
#endif
template<typename T>
DACE_CONSTEXPR DACE_HDFI T pow(const T& a, const T& b)
template<typename T, typename U>
DACE_CONSTEXPR DACE_HDFI auto pow(const T& a, const U& b)
{
return (T)std::pow(a, b);
return std::pow(a, b);
}

#ifndef DACE_XILINX
static DACE_CONSTEXPR DACE_HDFI int pow(const int& a, const int& b)
{
/*#ifndef __CUDA_ARCH__
return std::pow(a, b);
#else*/
if (b < 0) return 0;
int result = 1;
for (int i = 0; i < b; ++i)
result *= a;
return result;
//#endif
}

static DACE_CONSTEXPR DACE_HDFI unsigned int pow(const unsigned int& a,
const unsigned int& b)
{
/*#ifndef __CUDA_ARCH__
return std::pow(a, b);
#else*/
unsigned int result = 1;
for (unsigned int i = 0; i < b; ++i)
result *= a;
return result;
//#endif
}
#endif

Expand All @@ -514,17 +507,6 @@ namespace dace
return result;
}

template<typename T>
DACE_CONSTEXPR DACE_HDFI T pow(const T& a, const int& b)
{
return (T)std::pow(a, (T)b);
}
template<typename T>
DACE_CONSTEXPR DACE_HDFI T pow(const T& a, const unsigned int& b)
{
return (T)std::pow(a, (T)b);
}

template<typename T, typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
DACE_CONSTEXPR DACE_HDFI T ifloor(const T& a)
{
Expand Down
29 changes: 29 additions & 0 deletions tests/python_frontend/power_operator_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.

import pytest
import dace
import numpy as np

Expand Down Expand Up @@ -35,8 +37,35 @@ def test_pow_neg_exp():
pow_neg_exp(a=res)
assert (res[0] == 0.01)

in_types = [dace.float32, dace.float64, dace.int8, dace.int16, dace.int32, dace.int64]

@pytest.mark.parametrize("a_type", in_types)
@pytest.mark.parametrize("b_type", in_types)
def test_pow_types(a_type, b_type):
@dace.program
def pow_types(A: a_type[1], B: b_type[1], R: dace.float64[1]):
with dace.tasklet(dace.Language.Python):
scalar_a << A[0]
scalar_b << B[0]
scalar_r >> R[0]
scalar_r = scalar_a ** scalar_b

# a ** b needs to fit into the smallest type (int8)
a = np.random.rand(1) * 4
b = np.random.rand(1) * 4
r = np.random.rand(1).astype(np.float64)

a = a.astype(a_type.as_numpy_dtype())
b = b.astype(b_type.as_numpy_dtype())

pow_types(A=a, B=b, R=r)
assert np.allclose(r, a ** b)

if __name__ == "__main__":
test_pow_num_literals()
test_pow_op_preced()
test_pow_neg_exp()

for a_type in in_types:
for b_type in in_types:
test_pow_types(a_type, b_type)

0 comments on commit 0e3e098

Please sign in to comment.