From 0c0fddb0c93e7c0ecd027f6adf1bf45e93a88f4f Mon Sep 17 00:00:00 2001 From: YigitElma Date: Mon, 18 Nov 2024 22:54:38 -0500 Subject: [PATCH] refactor solve_fixed_iter test to compiled and first to prevent huge standard deviation --- tests/benchmarks/benchmark_cpu_small.py | 19 ++++++++++++++++++- tests/benchmarks/benchmark_gpu_small.py | 19 ++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/tests/benchmarks/benchmark_cpu_small.py b/tests/benchmarks/benchmark_cpu_small.py index 2d9b8c1ac..9285a9685 100644 --- a/tests/benchmarks/benchmark_cpu_small.py +++ b/tests/benchmarks/benchmark_cpu_small.py @@ -411,6 +411,22 @@ def run(x): benchmark.pedantic(run, args=(x,), rounds=15, iterations=1) +@pytest.mark.slow +@pytest.mark.benchmark +def test_solve_fixed_iter_compiled(benchmark): + """Benchmark running eq.solve for fixed iteration count.""" + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) + + def run(eq): + eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) + + benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) + + @pytest.mark.slow @pytest.mark.benchmark def test_solve_fixed_iter(benchmark): @@ -421,9 +437,10 @@ def test_solve_fixed_iter(benchmark): eq.change_resolution(6, 6, 6, 12, 12, 12) def run(eq): + jax.clear_caches() eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) - benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1) + benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) @pytest.mark.slow diff --git a/tests/benchmarks/benchmark_gpu_small.py b/tests/benchmarks/benchmark_gpu_small.py index ef1f69c68..bc8a6b7af 100644 --- a/tests/benchmarks/benchmark_gpu_small.py +++ b/tests/benchmarks/benchmark_gpu_small.py @@ -411,6 +411,22 @@ def run(x): benchmark.pedantic(run, args=(x,), rounds=15, iterations=1) +@pytest.mark.slow +@pytest.mark.benchmark +def test_solve_fixed_iter_compiled(benchmark): + """Benchmark running eq.solve for fixed iteration count after compilation.""" + jax.clear_caches() + eq = desc.examples.get("ESTELL") + with pytest.warns(UserWarning, match="Reducing radial"): + eq.change_resolution(6, 6, 6, 12, 12, 12) + eq.solve(maxiter=1, ftol=0, xtol=0, gtol=0) + + def run(eq): + eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) + + benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) + + @pytest.mark.slow @pytest.mark.benchmark def test_solve_fixed_iter(benchmark): @@ -421,9 +437,10 @@ def test_solve_fixed_iter(benchmark): eq.change_resolution(6, 6, 6, 12, 12, 12) def run(eq): + jax.clear_caches() eq.solve(maxiter=20, ftol=0, xtol=0, gtol=0) - benchmark.pedantic(run, args=(eq,), rounds=10, iterations=1) + benchmark.pedantic(run, args=(eq,), rounds=5, iterations=1) @pytest.mark.slow