From 08c86a2802cd92322e46adebd668131464a3c7ea Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Mon, 5 Feb 2024 16:27:51 -0600 Subject: [PATCH] add test and fix bugs --- tests/test_predictions.py | 29 ++++++++++++++++++++++++++++ whereistheplanet/whereistheplanet.py | 18 ++++++++++++----- 2 files changed, 42 insertions(+), 5 deletions(-) create mode 100644 tests/test_predictions.py diff --git a/tests/test_predictions.py b/tests/test_predictions.py new file mode 100644 index 0000000..992b86d --- /dev/null +++ b/tests/test_predictions.py @@ -0,0 +1,29 @@ +""" +Testing the prediciton tool +""" +import subprocess +import numpy as np +import whereistheplanet.whereistheplanet as witp + +def test_all_predictions(): + """ + Test all predictions to make sure the posteriors exist. + Checks the values produced are not nan + """ + labels = witp.post_dict.keys() + + for name in labels: + print(name) + ra_args, dec_args, sep_args, pa_args = witp.predict_planet(name) + assert np.all(~np.isnan(ra_args)) + assert np.all(~np.isnan(pa_args)) + +def test_cmd_line(): + """ + Test command line script + """ + subprocess.run(['whereistheplanet', 'betpicb', '-t', '2022-01-01']) + +if __name__ == "__main__": + test_all_predictions() + diff --git a/whereistheplanet/whereistheplanet.py b/whereistheplanet/whereistheplanet.py index e608373..ba9b4a6 100644 --- a/whereistheplanet/whereistheplanet.py +++ b/whereistheplanet/whereistheplanet.py @@ -8,6 +8,7 @@ from astropy.time import Time import orbitize.kepler as kepler +import orbitize.results as results moduledir = os.path.dirname(__file__) basedir = os.path.dirname(moduledir) # up one leve @@ -68,7 +69,7 @@ "hr5362b" : ("binary_HR5362B.hdf5", "GRAVITY Binary"), "kap01sclb" : ("binary_kap01SclB.hdf5", "GRAVITY Binary"), "hd30003b" : ("binary_HD30003B.hdf5", "GRAVITY Binary"), - "hd1663b" : ("binary_wds00209+1059.hdf5", "Nowak et al. 2024"), + "hd1663b" : ("binary_wdsj00209+1059.hdf5", "Nowak et al. 2024"), "lam01sclb" : ("binary_wdsj00427-3828.hdf5", "Nowak et al. 2024"), "hd25535b" : ("binary_wdsj04021-3429.hdf5", "Nowak et al. 2024"), "hd32642b" : ("binary_wdsj05055+1948.hdf5", "Nowak et al. 2024"), @@ -258,10 +259,17 @@ def get_chains(planet_name): filename, reference = post_dict[planet_name] filepath = os.path.join(datadir, filename) - with h5py.File(filepath,'r') as hf: # Opens file for reading - post = np.array(hf.get('post')) - tau_ref_epoch = float(hf.attrs['tau_ref_epoch']) - + try: + res = results.Results() + res.load_results(filepath) + post = res.post + tau_ref_epoch = res.tau_ref_epoch + except KeyError: + with h5py.File(filepath,'r') as hf: # Opens file for reading + post = np.array(hf.get('post')) + tau_ref_epoch = float(hf.attrs['tau_ref_epoch']) + post = np.array(post, dtype=float) + return post, tau_ref_epoch def get_reference(planet_name):