Skip to content

Commit

Permalink
Merge pull request #276 from twsearle/datetime64_argmax
Browse files Browse the repository at this point in the history
Argmax of a datetime64
  • Loading branch information
andrewgryan authored Feb 10, 2020
2 parents f215c70 + 9a283ad commit 5a4cd2b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion forest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
.. automodule:: forest.presets
"""
__version__ = '0.7.1'
__version__ = '0.7.2'

from .config import *
from . import (
Expand Down
13 changes: 11 additions & 2 deletions forest/satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
from forest.exceptions import FileNotFound, IndexNotFound


MIN_DATETIME64 = np.datetime64('0001-01-01T00:00:00.000000')


def _natargmax(arr):
""" Find the arg max when an array contains NaT's"""
no_nats = np.where(np.isnat(arr), MIN_DATETIME64, arr)
return np.argmax(no_nats)


class EIDA50(object):
def __init__(self, pattern):
self.locator = Locator(pattern)
Expand Down Expand Up @@ -86,7 +95,7 @@ def find_file_index(self, paths, user_date):
raise FileNotFound(msg)
before_dates = np.ma.array(
dates, mask=mask, dtype='datetime64[s]')
return np.ma.argmax(before_dates)
return _natargmax(before_dates.filled())

@staticmethod
def find_index(times, time, length):
Expand All @@ -99,7 +108,7 @@ def find_index(times, time, length):
if valid_times.mask.all():
msg = "{}: not found".format(time)
raise IndexNotFound(msg)
return np.ma.argmax(valid_times)
return _natargmax(valid_times.filled())

@staticmethod
def parse_date(path):
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ intake
intake-esm
pygrib
libwebp=1.0.2 # Pin to prevent Travis issue
numpy=1.17.*
xarray

0 comments on commit 5a4cd2b

Please sign in to comment.