Skip to content

Commit

Permalink
sympy: drop older versions
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Feb 13, 2023
1 parent 0ad1787 commit fb1197b
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 51 deletions.
28 changes: 1 addition & 27 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast',
'DefFunction', 'InlineIf', 'Keyword', 'String', 'Macro', 'MacroArgument',
'CustomType', 'Deref', 'INT', 'FLOAT', 'DOUBLE', 'VOID',
'Null', 'SizeOf', 'rfunc', 'cast_mapper',
'BasicWrapperMixin']
'Null', 'SizeOf', 'rfunc', 'cast_mapper', 'BasicWrapperMixin']


class CondEq(sympy.Eq):
Expand Down Expand Up @@ -705,28 +704,3 @@ def rfunc(func, item, *args):
min: Min,
max: Max,
}


def integer_args(*args):
"""
Check if expression is Integer.
Used to choose the function printed in the c-code
"""
if len(args) == 0:
return False

if len(args) == 1:
try:
return np.issubdtype(args[0].dtype, np.integer)
except AttributeError:
return args[0].is_integer
res = True
for a in args:
try:
if len(a.args) > 0:
res = res and integer_args(*a.args)
else:
res = res and integer_args(a)
except AttributeError:
res = res and integer_args(a)
return res
34 changes: 30 additions & 4 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from sympy.printing.c import C99CodePrinter

from devito.arch.compiler import AOMPCompiler
from devito.symbolics.extended_sympy import integer_args

__all__ = ['ccode']

Expand Down Expand Up @@ -104,12 +103,12 @@ def _print_Mod(self, expr):

def _print_Min(self, expr):
"""Print Min using devito defined header Min"""
func = 'MIN' if integer_args(*expr.args) else 'fmin'
func = 'MIN' if has_integer_args(*expr.args) else 'fmin'
return "%s(%s)" % (func, self._print(expr.args)[1:-1])

def _print_Max(self, expr):
"""Print Max using devito defined header Max"""
func = 'MAX' if integer_args(*expr.args) else 'fmax'
func = 'MAX' if has_integer_args(*expr.args) else 'fmax'
return "%s(%s)" % (func, self._print(expr.args)[1:-1])

def _print_Abs(self, expr):
Expand All @@ -118,7 +117,7 @@ def _print_Abs(self, expr):
if isinstance(self.compiler, AOMPCompiler):
return "fabs(%s)" % self._print(expr.args[0])
# Check if argument is an integer
func = "abs" if integer_args(*expr.args[0].args) else "fabs"
func = "abs" if has_integer_args(*expr.args[0].args) else "fabs"
return "%s(%s)" % (func, self._print(expr.args[0]))

def _print_Add(self, expr, order=None):
Expand Down Expand Up @@ -259,3 +258,30 @@ def ccode(expr, **settings):
# to always use the correct one from our printer
if Version(sympy.__version__) >= Version("1.11"):
setattr(sympy.printing.str.StrPrinter, '_print_Add', CodePrinter._print_Add)


# Check arguements type
def has_integer_args(*args):
"""
Check if expression is Integer.
Used to choose the function printed in the c-code
"""
if len(args) == 0:
return False

if len(args) == 1:
try:
return np.issubdtype(args[0].dtype, np.integer)
except AttributeError:
return args[0].is_integer

res = True
for a in args:
try:
if len(a.args) > 0:
res = res and has_integer_args(*a.args)
else:
res = res and has_integer_args(a)
except AttributeError:
res = res and has_integer_args(a)
return res
36 changes: 18 additions & 18 deletions examples/performance/00_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -340,9 +340,9 @@
" {\n",
" for (int y0_blk0 = y_m; y0_blk0 <= y_M; y0_blk0 += y0_blk0_size)\n",
" {\n",
" for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x += 1)\n",
" for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk0; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y += 1)\n",
" for (int y = y0_blk0; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1)\n",
" {\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
" {\n",
Expand Down Expand Up @@ -426,17 +426,17 @@
" {\n",
" for (int z0_blk0 = z_m; z0_blk0 <= z_M; z0_blk0 += z0_blk0_size)\n",
" {\n",
" for (int x0_blk1 = x0_blk0; x0_blk1 <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x0_blk1 += x0_blk1_size)\n",
" for (int x0_blk1 = x0_blk0; x0_blk1 <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x0_blk1 += x0_blk1_size)\n",
" {\n",
" for (int y0_blk1 = y0_blk0; y0_blk1 <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y0_blk1 += y0_blk1_size)\n",
" for (int y0_blk1 = y0_blk0; y0_blk1 <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y0_blk1 += y0_blk1_size)\n",
" {\n",
" for (int z0_blk1 = z0_blk0; z0_blk1 <= MIN(z0_blk0 + z0_blk0_size - 1, z_M); z0_blk1 += z0_blk1_size)\n",
" for (int z0_blk1 = z0_blk0; z0_blk1 <= MIN(z_M, z0_blk0 + z0_blk0_size - 1); z0_blk1 += z0_blk1_size)\n",
" {\n",
" for (int x = x0_blk1; x <= MIN(x0_blk1 + x0_blk1_size - 1, x_M); x += 1)\n",
" for (int x = x0_blk1; x <= MIN(x_M, x0_blk1 + x0_blk1_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk1; y <= MIN(y0_blk1 + y0_blk1_size - 1, y_M); y += 1)\n",
" for (int y = y0_blk1; y <= MIN(y_M, y0_blk1 + y0_blk1_size - 1); y += 1)\n",
" {\n",
" for (int z = z0_blk1; z <= MIN(z0_blk1 + z0_blk1_size - 1, z_M); z += 1)\n",
" for (int z = z0_blk1; z <= MIN(z_M, z0_blk1 + z0_blk1_size - 1); z += 1)\n",
" {\n",
" u[t1][x + 4][y + 4][z + 4] = ((-6.66666667e-1F/h_y)*(8.33333333e-2F*u[t0][x + 4][y + 1][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 2][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 4][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 5][z + 4]/h_y) + (-8.33333333e-2F/h_y)*(8.33333333e-2F*u[t0][x + 4][y + 4][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 5][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 7][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 8][z + 4]/h_y) + (8.33333333e-2F/h_y)*(8.33333333e-2F*u[t0][x + 4][y][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 1][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 3][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 4][z + 4]/h_y) + (6.66666667e-1F/h_y)*(8.33333333e-2F*u[t0][x + 4][y + 3][z + 4]/h_y - 6.66666667e-1F*u[t0][x + 4][y + 4][z + 4]/h_y + 6.66666667e-1F*u[t0][x + 4][y + 6][z + 4]/h_y - 8.33333333e-2F*u[t0][x + 4][y + 7][z + 4]/h_y))*sin(f[x + 1][y + 1][z + 1])*pow(f[x + 1][y + 1][z + 1], 2);\n",
" }\n",
Expand Down Expand Up @@ -1272,17 +1272,17 @@
" {\n",
" for (int y0_blk0 = y_m; y0_blk0 <= y_M; y0_blk0 += y0_blk0_size)\n",
" {\n",
" for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x += 1)\n",
" for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk0 - 2, ys = 0; y <= MIN(y0_blk0 + y0_blk0_size + 1, y_M + 2); y += 1, ys += 1)\n",
" for (int y = y0_blk0 - 2, ys = 0; y <= MIN(y_M + 2, y0_blk0 + y0_blk0_size + 1); y += 1, ys += 1)\n",
" {\n",
" #pragma omp simd aligned(u:32)\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
" {\n",
" r2[ys][z] = r1*(8.33333333e-2F*(u[t0][x + 4][y + 2][z + 4] - u[t0][x + 4][y + 6][z + 4]) + 6.66666667e-1F*(-u[t0][x + 4][y + 3][z + 4] + u[t0][x + 4][y + 5][z + 4]));\n",
" }\n",
" }\n",
" for (int y = y0_blk0, ys = 0; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y += 1, ys += 1)\n",
" for (int y = y0_blk0, ys = 0; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1)\n",
" {\n",
" #pragma omp simd aligned(f,u:32)\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
Expand Down Expand Up @@ -1390,9 +1390,9 @@
" {\n",
" for (int y0_blk0 = y_m; y0_blk0 <= y_M; y0_blk0 += y0_blk0_size)\n",
" {\n",
" for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M); x += 1)\n",
" for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk0, ys = 0, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = -2; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M); y += 1, ys += 1, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = 2)\n",
" for (int y = y0_blk0, ys = 0, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = -2; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = 2)\n",
" {\n",
" for (int yy = yii, ysi = (yy + ys + 2)%(5); yy <= 2; yy += 1, ysi = (yy + ys + 2)%(5))\n",
" {\n",
Expand Down Expand Up @@ -1706,9 +1706,9 @@
" {\n",
" for (int y0_blk0 = y_m - 2; y0_blk0 <= y_M + 2; y0_blk0 += y0_blk0_size)\n",
" {\n",
" for (int x = x0_blk0; x <= MIN(x0_blk0 + x0_blk0_size - 1, x_M + 2); x += 1)\n",
" for (int x = x0_blk0; x <= MIN(x_M + 2, x0_blk0 + x0_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk0; y <= MIN(y0_blk0 + y0_blk0_size - 1, y_M + 2); y += 1)\n",
" for (int y = y0_blk0; y <= MIN(y_M + 2, y0_blk0 + y0_blk0_size - 1); y += 1)\n",
" {\n",
" #pragma omp simd aligned(u:32)\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
Expand All @@ -1728,9 +1728,9 @@
" {\n",
" for (int y1_blk0 = y_m; y1_blk0 <= y_M; y1_blk0 += y1_blk0_size)\n",
" {\n",
" for (int x = x1_blk0; x <= MIN(x1_blk0 + x1_blk0_size - 1, x_M); x += 1)\n",
" for (int x = x1_blk0; x <= MIN(x_M, x1_blk0 + x1_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y1_blk0; y <= MIN(y1_blk0 + y1_blk0_size - 1, y_M); y += 1)\n",
" for (int y = y1_blk0; y <= MIN(y_M, y1_blk0 + y1_blk0_size - 1); y += 1)\n",
" {\n",
" #pragma omp simd aligned(f,u:32)\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
Expand Down Expand Up @@ -1788,7 +1788,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.9.16"
},
"varInspector": {
"cols": {
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
pip>=9.0.1
numpy>1.16
sympy>=1.7,<1.12
sympy>=1.9,<1.12
scipy
flake8>=2.1.0
nbval
Expand Down
2 changes: 1 addition & 1 deletion tests/test_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_codegen_quality1():
assert all('const long' not in str(i) for i in exprs[-3:])

# Only two access macros necessary, namely `uL0` and `r1L0` (the other five
# obviously are _POSIX_C_SOURCE, Min, Max, START_TIMER, STOP_TIMER)
# obviously are _POSIX_C_SOURCE, MIN, MAX, START_TIMER, STOP_TIMER)
assert len(op._headers) == 7


Expand Down

0 comments on commit fb1197b

Please sign in to comment.