From d510f0d9df3c8c99eee5760c78e0c07a165ba44c Mon Sep 17 00:00:00 2001 From: "Lumberbot (aka Jack)" <39504233+meeseeksmachine@users.noreply.github.com> Date: Wed, 4 Dec 2024 01:27:41 -0800 Subject: [PATCH] Backport PR #3067 on branch 1.2.x (fix: CUDA Minified Tests following refactoring) (#3068) Backport PR #3067: fix: CUDA Minified Tests following refactoring Co-authored-by: Ori Kronfeld --- tests/model/test_models_with_minified_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 52c9013362..e88efa14e6 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -281,10 +281,10 @@ def test_validate_supported_if_minified_keep_count(): assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS assert model2.minified_data_type is None - assert np.allclose(model2.get_elbo(), model.get_elbo(), rtol=5e-2) + assert np.allclose(model2.get_elbo().cpu(), model.get_elbo().cpu(), rtol=5e-2) assert np.allclose( - model2.get_reconstruction_error()["reconstruction_loss"], - model.get_reconstruction_error()["reconstruction_loss"], + model2.get_reconstruction_error()["reconstruction_loss"].cpu(), + model.get_reconstruction_error()["reconstruction_loss"].cpu(), rtol=5e-2, ) assert np.allclose(model2.get_marginal_ll(), model.get_marginal_ll(), rtol=5e-2)