-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: settle minmod limiters here
- Loading branch information
Showing
5 changed files
with
137 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .num_api_weno3 import * | ||
|
||
from .minmod_limiter import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import numpy as np | ||
from numba import njit | ||
|
||
|
||
@njit(cache=True) | ||
def minmod(a, b, c): | ||
if isinstance(a, (np.int32, np.int64, np.float64, np.float32, int, float)): | ||
a = np.array([a]) | ||
a = a.reshape(-1) | ||
|
||
if isinstance(b, (np.int32, np.int64, np.float64, np.float32, int, float)): | ||
b = np.array([b]) | ||
b = b.reshape(-1) | ||
|
||
if isinstance(c, (np.int32, np.int64, np.float64, np.float32, int, float)): | ||
c = np.array([c]) | ||
c = c.reshape(-1) | ||
|
||
# check the consistency of input length | ||
if not (len(a) == len(b) == len(c)): | ||
raise ValueError("Input length must be the same!") | ||
|
||
res = np.zeros_like(a) | ||
|
||
for i in range(len(a)): | ||
if a[i] * b[i] > 0 and a[i] * c[i] > 0: | ||
res[i] = np.min(np.abs(np.array([a[i], b[i], c[i]]))) * np.sign(a[i]) | ||
|
||
return res | ||
|
||
|
||
@njit(cache=True) | ||
def minmod_flag(a, b, c): | ||
""" | ||
Return the index of minmod_flag | ||
""" | ||
|
||
if isinstance(a, (np.int32, np.int64, np.float64, np.float32, int, float)): | ||
a = np.array([a]) | ||
a = a.reshape(-1) | ||
|
||
if isinstance(b, (np.int32, np.int64, np.float64, np.float32, int, float)): | ||
b = np.array([b]) | ||
b = b.reshape(-1) | ||
|
||
if isinstance(c, (np.int32, np.int64, np.float64, np.float32, int, float)): | ||
c = np.array([c]) | ||
c = c.reshape(-1) | ||
|
||
# check the consistency of input length | ||
if not (len(a) == len(b) == len(c)): | ||
raise ValueError("Input length must be the same!") | ||
|
||
res = np.zeros_like(a).astype(np.int32) | ||
|
||
for i in range(len(a)): | ||
if a[i] * b[i] > 0 and a[i] * c[i] > 0: | ||
res[i] = np.abs(np.array([a[i], b[i], c[i]])).argmin() + 1 | ||
|
||
return res | ||
|
||
|
||
@njit(cache=True) | ||
def switch_minmod(a, b, c, flag): | ||
""" | ||
Conditionally output the derivatives of minmod according to the flag | ||
""" | ||
|
||
if isinstance(a, (np.int32, np.int64, np.float64, np.float32, int, float)): | ||
a = np.array([a]) | ||
a = a.reshape(-1) | ||
|
||
if isinstance(b, (np.int32, np.int64, np.float64, np.float32, int, float)): | ||
b = np.array([b]) | ||
b = b.reshape(-1) | ||
|
||
if isinstance(c, (np.int32, np.int64, np.float64, np.float32, int, float)): | ||
c = np.array([c]) | ||
c = c.reshape(-1) | ||
|
||
if isinstance(flag, (np.int32, np.int64, np.float64, np.float32, int, float)): | ||
flag = np.array([flag]) | ||
flag = flag.reshape(-1) | ||
|
||
# if not (len(a) == len(b) == len(c) == len(flag)): | ||
# raise ValueError("Input length must be the same!") | ||
|
||
res = np.zeros_like(flag) | ||
|
||
for i in range(len(flag)): | ||
if flag[i] == 1: | ||
if len(a) > 1: | ||
res[i] = a[i] | ||
else: | ||
res[i] = a[0] | ||
elif flag[i] == 2: | ||
if len(b) > 1: | ||
res[i] = b[i] | ||
else: | ||
res[i] = b[0] | ||
elif flag[i] == 3: | ||
if len(c) > 1: | ||
res[i] = c[i] | ||
else: | ||
res[i] = c[0] | ||
|
||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,33 @@ | ||
from Solverz.sym_algebra.functions import MulVarFunc | ||
from sympy import Integer | ||
|
||
|
||
class SolPde(MulVarFunc): | ||
def _numpycode(self, printer, **kwargs): | ||
return (f'SolMF.pde.{self.__class__.__name__}' + r'(' + | ||
', '.join([printer._print(arg, **kwargs) for arg in self.args]) + r')') | ||
|
||
|
||
class minmod(SolPde): | ||
arglength = 3 | ||
|
||
def _eval_derivative(self, s): | ||
return switch_minmod(*[arg.diff(s) for arg in self.args], | ||
minmod_flag(*self.args)) | ||
|
||
|
||
class minmod_flag(SolPde): | ||
""" | ||
Different from `minmod`, minmod function outputs the position of args instead of the values of args. | ||
""" | ||
arglength = 3 | ||
|
||
def _eval_derivative(self, s): | ||
return Integer(0) | ||
|
||
|
||
class switch_minmod(SolPde): | ||
arglength = 4 | ||
|
||
def _eval_derivative(self, s): | ||
return switch_minmod(*[arg.diff(s) for arg in self.args[0:len(self.args) - 1]], self.args[-1]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters