diff --git a/forest/__init__.py b/forest/__init__.py index eecc9fe38..b97dbc1e1 100644 --- a/forest/__init__.py +++ b/forest/__init__.py @@ -23,7 +23,7 @@ .. automodule:: forest.presets """ -__version__ = '0.17.1' +__version__ = '0.17.2' from .config import * from . import ( diff --git a/forest/db/control.py b/forest/db/control.py index 2036be6aa..c6c533b52 100644 --- a/forest/db/control.py +++ b/forest/db/control.py @@ -118,11 +118,20 @@ def time_array_equal(x, y): return False elif (len(x) == 0) or (len(y) == 0): return x == y + else: + if len(x) != len(y): + return False + left = _as_datetime_array(x) + right = _as_datetime_array(y) + return np.all(left == right) + +def _as_datetime_array(x): + """Either vectorized _to_datetime or pd.to_datetime""" try: - return np.all(_vto_datetime(x) == _vto_datetime(y)) + return _vto_datetime(x) except TypeError: # NOTE: Needed for EarthNetworks DatetimeIndex - return np.all(pd.to_datetime(x) == pd.to_datetime(y)) + return pd.to_datetime(x) def equal_value(a, b): if (a is None) and (b is None): diff --git a/test/test_state.py b/test/test_state.py index 40ce71c10..ea75e5b33 100644 --- a/test/test_state.py +++ b/test/test_state.py @@ -1,7 +1,10 @@ import pytest import datetime as dt import numpy as np +import pandas as pd +import cftime from forest import db +from forest.db.control import time_array_equal @pytest.mark.parametrize("left,right,expect", [ @@ -35,8 +38,50 @@ (db.State(), db.State(variables=["a", "b"]), False), (db.State(variables=["a", "b"]), db.State(variables=["a", "b"]), True), (db.State(variables=np.array(["a", "b"])), - db.State(variables=["a", "b"]), True), + db.State(variables=["a", "b"]), True) ]) def test_equality_and_not_equality(left, right, expect): assert (left == right) == expect assert (left != right) == (not expect) + + +def test_state_equality_valueerror_lengths_must_match(): + """should return False if lengths do not match""" + valid_times = ( + pd.date_range("2020-01-01", periods=2), + pd.date_range("2020-01-01", periods=3), + ) + left = db.State(valid_times=valid_times[0]) + right = db.State(valid_times=valid_times[1]) + assert (left == right) == False + + +def test_time_array_equal(): + left = pd.date_range("2020-01-01", periods=2) + right = pd.date_range("2020-01-01", periods=3) + assert time_array_equal(left, right) == False + + +def test_valueerror_lengths_must_match(): + a = ["2020-01-01T00:00:00Z"] + b = ["2020-02-01T00:00:00Z", "2020-02-02T00:00:00Z", "2020-02-03T00:00:00Z"] + with pytest.raises(ValueError): + pd.to_datetime(a) == pd.to_datetime(b) + + +@pytest.mark.parametrize("left,right,expect", [ + pytest.param([cftime.DatetimeGregorian(2020, 1, 1), + cftime.DatetimeGregorian(2020, 1, 2), + cftime.DatetimeGregorian(2020, 1, 3)], + pd.date_range("2020-01-01", periods=3), + True, + id="gregorian/pandas same values"), + pytest.param([cftime.DatetimeGregorian(2020, 2, 1), + cftime.DatetimeGregorian(2020, 2, 2), + cftime.DatetimeGregorian(2020, 2, 3)], + pd.date_range("2020-01-01", periods=3), + False, + id="gregorian/pandas same length different values"), +]) +def test_time_array_equal_mixed_types(left, right, expect): + assert time_array_equal(left, right) == expect