Skip to content

Commit

Permalink
refactor: settle minmod limiters here
Browse files Browse the repository at this point in the history
  • Loading branch information
rzyu45 committed Nov 20, 2024
1 parent 2e0becf commit 47408ec
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 8 deletions.
2 changes: 1 addition & 1 deletion SolMuseum/num_api/pde/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .num_api_weno3 import *

from .minmod_limiter import *
107 changes: 107 additions & 0 deletions SolMuseum/num_api/pde/minmod_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import numpy as np
from numba import njit


@njit(cache=True)
def minmod(a, b, c):
if isinstance(a, (np.int32, np.int64, np.float64, np.float32, int, float)):
a = np.array([a])
a = a.reshape(-1)

if isinstance(b, (np.int32, np.int64, np.float64, np.float32, int, float)):
b = np.array([b])
b = b.reshape(-1)

if isinstance(c, (np.int32, np.int64, np.float64, np.float32, int, float)):
c = np.array([c])
c = c.reshape(-1)

# check the consistency of input length
if not (len(a) == len(b) == len(c)):
raise ValueError("Input length must be the same!")

res = np.zeros_like(a)

for i in range(len(a)):
if a[i] * b[i] > 0 and a[i] * c[i] > 0:
res[i] = np.min(np.abs(np.array([a[i], b[i], c[i]]))) * np.sign(a[i])

return res


@njit(cache=True)
def minmod_flag(a, b, c):
"""
Return the index of minmod_flag
"""

if isinstance(a, (np.int32, np.int64, np.float64, np.float32, int, float)):
a = np.array([a])
a = a.reshape(-1)

if isinstance(b, (np.int32, np.int64, np.float64, np.float32, int, float)):
b = np.array([b])
b = b.reshape(-1)

if isinstance(c, (np.int32, np.int64, np.float64, np.float32, int, float)):
c = np.array([c])
c = c.reshape(-1)

# check the consistency of input length
if not (len(a) == len(b) == len(c)):
raise ValueError("Input length must be the same!")

res = np.zeros_like(a).astype(np.int32)

for i in range(len(a)):
if a[i] * b[i] > 0 and a[i] * c[i] > 0:
res[i] = np.abs(np.array([a[i], b[i], c[i]])).argmin() + 1

return res


@njit(cache=True)
def switch_minmod(a, b, c, flag):
"""
Conditionally output the derivatives of minmod according to the flag
"""

if isinstance(a, (np.int32, np.int64, np.float64, np.float32, int, float)):
a = np.array([a])
a = a.reshape(-1)

if isinstance(b, (np.int32, np.int64, np.float64, np.float32, int, float)):
b = np.array([b])
b = b.reshape(-1)

if isinstance(c, (np.int32, np.int64, np.float64, np.float32, int, float)):
c = np.array([c])
c = c.reshape(-1)

if isinstance(flag, (np.int32, np.int64, np.float64, np.float32, int, float)):
flag = np.array([flag])
flag = flag.reshape(-1)

# if not (len(a) == len(b) == len(c) == len(flag)):
# raise ValueError("Input length must be the same!")

res = np.zeros_like(flag)

for i in range(len(flag)):
if flag[i] == 1:
if len(a) > 1:
res[i] = a[i]
else:
res[i] = a[0]
elif flag[i] == 2:
if len(b) > 1:
res[i] = b[i]
else:
res[i] = b[0]
elif flag[i] == 3:
if len(c) > 1:
res[i] = c[i]
else:
res[i] = c[0]

return res
26 changes: 26 additions & 0 deletions SolMuseum/pde/basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,33 @@
from Solverz.sym_algebra.functions import MulVarFunc
from sympy import Integer


class SolPde(MulVarFunc):
def _numpycode(self, printer, **kwargs):
return (f'SolMF.pde.{self.__class__.__name__}' + r'(' +
', '.join([printer._print(arg, **kwargs) for arg in self.args]) + r')')


class minmod(SolPde):
arglength = 3

def _eval_derivative(self, s):
return switch_minmod(*[arg.diff(s) for arg in self.args],
minmod_flag(*self.args))


class minmod_flag(SolPde):
"""
Different from `minmod`, minmod function outputs the position of args instead of the values of args.
"""
arglength = 3

def _eval_derivative(self, s):
return Integer(0)


class switch_minmod(SolPde):
arglength = 4

def _eval_derivative(self, s):
return switch_minmod(*[arg.diff(s) for arg in self.args[0:len(self.args) - 1]], self.args[-1])
7 changes: 2 additions & 5 deletions SolMuseum/pde/gas/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from sympy import Add, Mul, Rational, Integer
import numpy as np

from ..basic import minmod
from .weno3.weno_pipe import weno_odep, weno_odeq


def cdm(p11, p10, p01, p00,
q11, q10, q01, q00,
var, lam, va, S, D, dx, dt):
Expand Down Expand Up @@ -543,15 +545,11 @@ def mol_tvd1_q_eqn_rhs0(p_list, q_list, S, va, lam, D, dx):
2 * D * S * p0)


from Solverz import minmod


def ux(theta, um1, u, up1, dx):
return minmod(theta * (u - um1) / dx, (up1 - um1) / (2 * dx), theta * (up1 - u) / dx)


def mol_tvd2_p_eqn_rhs(p_list, q_list, S, va, dx):

pm2, pm1, p0, pp1, pp2 = p_list
qm2, qm1, q0, qp1, qp2 = q_list

Expand Down Expand Up @@ -586,7 +584,6 @@ def Source(p, q):


def mol_tvd2_q_eqn_rhs(p_list, q_list, S, va, lam, D, dx):

pm2, pm1, p0, pp1, pp2 = p_list
qm2, qm1, q0, qp1, qp2 = q_list

Expand Down
3 changes: 1 addition & 2 deletions SolMuseum/pde/heat/util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from Solverz import Eqn, Ode, AliasVar, TimeSeriesParam, Param
from Solverz import iVar, idx, Var, Abs
from Solverz.utilities.type_checker import is_integer, is_number

from Solverz import minmod
from ..basic import minmod


def ux(theta, um1, u, up1, dx):
Expand Down

0 comments on commit 47408ec

Please sign in to comment.