Skip to content

Commit

Permalink
Merge pull request #15 from spcl/fix-tolerances
Browse files Browse the repository at this point in the history
Error tolerances
  • Loading branch information
tbennun authored Oct 21, 2022
2 parents 7e90d52 + bab8bc2 commit f18e3c7
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 8 deletions.
4 changes: 3 additions & 1 deletion bench_info/durbin.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
},
"input_args": ["r"],
"array_args": ["r"],
"output_args": []
"output_args": [],
"rtol": 1e-3,
"atol": 1e-3
}
}
3 changes: 2 additions & 1 deletion bench_info/mandelbrot1.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
},
"input_args": ["xmin", "xmax", "ymin", "ymax", "XN", "YN", "maxiter", "horizon"],
"array_args": [],
"output_args": []
"output_args": [],
"norm_error": 1e-3
}
}
3 changes: 2 additions & 1 deletion bench_info/mandelbrot2.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
},
"input_args": ["xmin", "xmax", "ymin", "ymax", "XN", "YN", "maxiter", "horizon"],
"array_args": [],
"output_args": []
"output_args": [],
"norm_error": 1e-3
}
}
3 changes: 2 additions & 1 deletion bench_info/nbody.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
},
"input_args": ["mass", "pos", "vel", "N", "Nt", "dt", "G", "softening"],
"array_args": ["mass", "pos", "vel"],
"output_args": []
"output_args": [],
"norm_error": 1e-1
}
}
6 changes: 5 additions & 1 deletion npbench/infrastructure/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ def first_execution(impl, impl_name):
if validate and np_out is not None:
try:
frmwrk_name = self.frmwrk.info["full_name"]
valid = util.validate(np_out, frmwrk_out, frmwrk_name)

rtol = 1e-5 if not 'rtol' in self.bench.info else self.bench.info['rtol']
atol = 1e-8 if not 'atol' in self.bench.info else self.bench.info['atol']
norm_error = 1e-5 if not 'norm_error' in self.bench.info else self.bench.info['norm_error']
valid = util.validate(np_out, frmwrk_out, frmwrk_name, rtol=rtol, atol=atol, norm_error=norm_error)
if valid:
print("{} - {} - validation: SUCCESS".format(frmwrk_name, impl_name))
elif not ignore_errors:
Expand Down
6 changes: 3 additions & 3 deletions npbench/infrastructure/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ def benchmark(stmt, setup="pass", out_text="", repeat=1, context={}, output=None
return res, raw_time_list


def validate(ref, val, framework="Unknown"):
def validate(ref, val, framework="Unknown", rtol=1e-5, atol=1e-8, norm_error=1e-5):
if not isinstance(ref, (tuple, list)):
ref = [ref]
if not isinstance(val, (tuple, list)):
val = [val]
valid = True
for r, v in zip(ref, val):
if not np.allclose(r, v):
if not np.allclose(r, v, rtol=rtol, atol=atol):
try:
import cupy
if isinstance(v, cupy.ndarray):
Expand All @@ -163,7 +163,7 @@ def validate(ref, val, framework="Unknown"):
relerror = relative_error(r, v)
except Exception:
relerror = relative_error(r, v)
if relerror < 1e-05:
if relerror < norm_error:
continue
valid = False
print("Relative error: {}".format(relerror))
Expand Down

0 comments on commit f18e3c7

Please sign in to comment.