Skip to content

Commit

Permalink
Merge pull request #378 from informatics-lab/fix-timestamp-old-state
Browse files Browse the repository at this point in the history
Fix timestamp old state
  • Loading branch information
andrewgryan authored May 18, 2020
2 parents 9b9c281 + 0e53df8 commit b8cce79
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 4 deletions.
2 changes: 1 addition & 1 deletion forest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
.. automodule:: forest.presets
"""
__version__ = '0.17.1'
__version__ = '0.17.2'

from .config import *
from . import (
Expand Down
13 changes: 11 additions & 2 deletions forest/db/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,20 @@ def time_array_equal(x, y):
return False
elif (len(x) == 0) or (len(y) == 0):
return x == y
else:
if len(x) != len(y):
return False
left = _as_datetime_array(x)
right = _as_datetime_array(y)
return np.all(left == right)

def _as_datetime_array(x):
"""Either vectorized _to_datetime or pd.to_datetime"""
try:
return np.all(_vto_datetime(x) == _vto_datetime(y))
return _vto_datetime(x)
except TypeError:
# NOTE: Needed for EarthNetworks DatetimeIndex
return np.all(pd.to_datetime(x) == pd.to_datetime(y))
return pd.to_datetime(x)

def equal_value(a, b):
if (a is None) and (b is None):
Expand Down
47 changes: 46 additions & 1 deletion test/test_state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pytest
import datetime as dt
import numpy as np
import pandas as pd
import cftime
from forest import db
from forest.db.control import time_array_equal


@pytest.mark.parametrize("left,right,expect", [
Expand Down Expand Up @@ -35,8 +38,50 @@
(db.State(), db.State(variables=["a", "b"]), False),
(db.State(variables=["a", "b"]), db.State(variables=["a", "b"]), True),
(db.State(variables=np.array(["a", "b"])),
db.State(variables=["a", "b"]), True),
db.State(variables=["a", "b"]), True)
])
def test_equality_and_not_equality(left, right, expect):
assert (left == right) == expect
assert (left != right) == (not expect)


def test_state_equality_valueerror_lengths_must_match():
"""should return False if lengths do not match"""
valid_times = (
pd.date_range("2020-01-01", periods=2),
pd.date_range("2020-01-01", periods=3),
)
left = db.State(valid_times=valid_times[0])
right = db.State(valid_times=valid_times[1])
assert (left == right) == False


def test_time_array_equal():
left = pd.date_range("2020-01-01", periods=2)
right = pd.date_range("2020-01-01", periods=3)
assert time_array_equal(left, right) == False


def test_valueerror_lengths_must_match():
a = ["2020-01-01T00:00:00Z"]
b = ["2020-02-01T00:00:00Z", "2020-02-02T00:00:00Z", "2020-02-03T00:00:00Z"]
with pytest.raises(ValueError):
pd.to_datetime(a) == pd.to_datetime(b)


@pytest.mark.parametrize("left,right,expect", [
pytest.param([cftime.DatetimeGregorian(2020, 1, 1),
cftime.DatetimeGregorian(2020, 1, 2),
cftime.DatetimeGregorian(2020, 1, 3)],
pd.date_range("2020-01-01", periods=3),
True,
id="gregorian/pandas same values"),
pytest.param([cftime.DatetimeGregorian(2020, 2, 1),
cftime.DatetimeGregorian(2020, 2, 2),
cftime.DatetimeGregorian(2020, 2, 3)],
pd.date_range("2020-01-01", periods=3),
False,
id="gregorian/pandas same length different values"),
])
def test_time_array_equal_mixed_types(left, right, expect):
assert time_array_equal(left, right) == expect

0 comments on commit b8cce79

Please sign in to comment.