Skip to content

Commit

Permalink
refactor solve_fixed_iter test to compiled and first to prevent huge …
Browse files Browse the repository at this point in the history
…standard deviation
  • Loading branch information
YigitElma committed Nov 19, 2024
1 parent 70bdc81 commit 0c0fddb
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
19 changes: 18 additions & 1 deletion tests/benchmarks/benchmark_cpu_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,22 @@ def run(x):
benchmark.pedantic(run, args=(x,), rounds=15, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_solve_fixed_iter_compiled(benchmark):
"""Benchmark running eq.solve for fixed iteration count."""
jax.clear_caches()
eq = desc.examples.get("ESTELL")
with pytest.warns(UserWarning, match="Reducing radial"):
eq.change_resolution(6, 6, 6, 12, 12, 12)
eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0)

def run(eq):
eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0)

benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_solve_fixed_iter(benchmark):
Expand All @@ -421,9 +437,10 @@ def test_solve_fixed_iter(benchmark):
eq.change_resolution(6, 6, 6, 12, 12, 12)

def run(eq):
jax.clear_caches()
eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0)

benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1)
benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1)


@pytest.mark.slow
Expand Down
19 changes: 18 additions & 1 deletion tests/benchmarks/benchmark_gpu_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,22 @@ def run(x):
benchmark.pedantic(run, args=(x,), rounds=15, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_solve_fixed_iter_compiled(benchmark):
"""Benchmark running eq.solve for fixed iteration count after compilation."""
jax.clear_caches()
eq = desc.examples.get("ESTELL")
with pytest.warns(UserWarning, match="Reducing radial"):
eq.change_resolution(6, 6, 6, 12, 12, 12)
eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0)

def run(eq):
eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0)

benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1)


@pytest.mark.slow
@pytest.mark.benchmark
def test_solve_fixed_iter(benchmark):
Expand All @@ -421,9 +437,10 @@ def test_solve_fixed_iter(benchmark):
eq.change_resolution(6, 6, 6, 12, 12, 12)

def run(eq):
jax.clear_caches()
eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0)

benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1)
benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1)


@pytest.mark.slow
Expand Down

0 comments on commit 0c0fddb

Please sign in to comment.