diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 52c9013362..e88efa14e6 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -281,10 +281,10 @@ def test_validate_supported_if_minified_keep_count(): assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR_WITH_COUNTS assert model2.minified_data_type is None - assert np.allclose(model2.get_elbo(), model.get_elbo(), rtol=5e-2) + assert np.allclose(model2.get_elbo().cpu(), model.get_elbo().cpu(), rtol=5e-2) assert np.allclose( - model2.get_reconstruction_error()["reconstruction_loss"], - model.get_reconstruction_error()["reconstruction_loss"], + model2.get_reconstruction_error()["reconstruction_loss"].cpu(), + model.get_reconstruction_error()["reconstruction_loss"].cpu(), rtol=5e-2, ) assert np.allclose(model2.get_marginal_ll(), model.get_marginal_ll(), rtol=5e-2)