diff --git a/forest/__init__.py b/forest/__init__.py index 85188e088..7b73e8720 100644 --- a/forest/__init__.py +++ b/forest/__init__.py @@ -23,7 +23,7 @@ .. automodule:: forest.presets """ -__version__ = '0.16.4' +__version__ = '0.17.0' from .config import * from . import ( diff --git a/forest/db/control.py b/forest/db/control.py index 33b19b4a2..2036be6aa 100644 --- a/forest/db/control.py +++ b/forest/db/control.py @@ -2,6 +2,7 @@ import copy import datetime as dt import numpy as np +import pandas as pd import bokeh.models import bokeh.layouts from collections import namedtuple @@ -117,7 +118,11 @@ def time_array_equal(x, y): return False elif (len(x) == 0) or (len(y) == 0): return x == y - return np.all(_vto_datetime(x) == _vto_datetime(y)) + try: + return np.all(_vto_datetime(x) == _vto_datetime(y)) + except TypeError: + # NOTE: Needed for EarthNetworks DatetimeIndex + return np.all(pd.to_datetime(x) == pd.to_datetime(y)) def equal_value(a, b): if (a is None) and (b is None): diff --git a/forest/drivers/earth_networks.py b/forest/drivers/earth_networks.py index 04ea25bce..e5006c9b7 100644 --- a/forest/drivers/earth_networks.py +++ b/forest/drivers/earth_networks.py @@ -2,12 +2,16 @@ import os import glob import datetime as dt +import datashader import pandas as pd +from functools import lru_cache from forest import geo from forest.util import to_datetime as _to_datetime +import forest.util from forest.old_state import old_state, unique import bokeh.models import bokeh.palettes +import bokeh.colors import numpy as np @@ -15,23 +19,22 @@ class Dataset: """High-level class to relate navigators, loaders and views""" def __init__(self, pattern=None, **kwargs): self.pattern = pattern - if pattern is not None: - self._paths = glob.glob(pattern) - else: - self._paths = [] + self.loader = Loader() + self.locator = TimestampLocator(pattern) def navigator(self): """Construct navigator""" - return Navigator(self._paths) + return Navigator(self.locator) def map_view(self): """Construct view""" - return View(Loader(self._paths)) + return View(self.loader, self.locator) -class View(object): - def __init__(self, loader): +class View: + def __init__(self, loader, locator): self.loader = loader + self.locator = locator palette = bokeh.palettes.all_palettes['Spectral'][11][::-1] self.color_mapper = bokeh.models.LinearColorMapper(low=-1000, high=0, palette=palette) self.empty_image = { @@ -43,16 +46,134 @@ def __init__(self, loader): "flash_type": [], "time_since_flash": [] } - self.source = bokeh.models.ColumnDataSource(self.empty_image) + self.hover_tools = { + "image": [] + } + self.color_mappers = {} + self.color_mappers["image"] = bokeh.models.LinearColorMapper( + low=0, + high=1, + palette="Inferno256", + nan_color=bokeh.colors.RGB(0, 0, 0, a=0) + ) + self.sources = {} + self.sources["scatter"] = bokeh.models.ColumnDataSource(self.empty_image) + self.sources["image"] = bokeh.models.ColumnDataSource({ + "x": [], + "y": [], + "dw": [], + "dh": [], + "image": [], + }) + self.variable_to_method = { + "Lightning": self.scatter, + } @old_state @unique def render(self, state): if state.valid_time is None: return + if state.variable is None: + return + self.variable_to_method.get(state.variable, self.image)(state) + + def image(self, state): + """Image colored by time since flash or flash density""" + valid_time =_to_datetime(state.valid_time) + + # 15 minute/1 hour slice of data? + window = dt.timedelta(minutes=60) # 1 hour window + paths = self.locator.find_period(valid_time, window) + frame = self.loader.load(paths) + frame = self.select_date(frame, valid_time, window) + + # Filter intra-cloud/cloud-ground rows + if "intra-cloud" in state.variable.lower(): + frame = frame[frame["flash_type"] == "IC"] + elif "cloud-ground" in state.variable.lower(): + frame = frame[frame["flash_type"] == "CG"] + + # EarthNetworks validity box (not needed if tiling algorithm) + longitude_range = (26, 40) + latitude_range = (-12, 4) + x_range, y_range = geo.web_mercator(longitude_range, + latitude_range) + + x, y = geo.web_mercator(frame["longitude"], frame["latitude"]) + frame["x"] = x + frame["y"] = y + pixels = 256 + canvas = datashader.Canvas( + plot_width=pixels, + plot_height=pixels, + x_range=x_range, + y_range=y_range + ) + + if "density" in state.variable.lower(): + # N flashes per pixel + agg = canvas.points(frame, "x", "y", datashader.count()) + else: + frame["since_flash"] = self.since_flash(frame["date"], valid_time) + agg = canvas.points(frame, "x", "y", datashader.max("since_flash")) + + # Note: DataArray objects are not JSON serializable, .values is the + # same data cast as a numpy array + x = agg.x.values.min() + y = agg.y.values.min() + dw = agg.x.values.max() - x + dh = agg.y.values.max() - y + image = np.ma.masked_array(agg.values.astype(np.float), + mask=np.isnan(agg.values)) + if "density" in state.variable.lower(): + image[image == 0] = np.ma.masked # Remove pixels with no data + + # Update color_mapper + color_mapper = self.color_mappers["image"] + if "density" in state.variable.lower(): + color_mapper.palette = bokeh.palettes.all_palettes["Spectral"][8] + color_mapper.low = 0 + color_mapper.high = agg.values.max() + else: + color_mapper.palette = bokeh.palettes.all_palettes["RdGy"][8] + color_mapper.low = 0 + color_mapper.high = 60 * 60 # 1 hour + # Update tooltips + for hover_tool in self.hover_tools["image"]: + hover_tool.tooltips = self.tooltips(state.variable) + hover_tool.formatters = self.formatters(state.variable) + + if "density" in state.variable.lower(): + units = "events" + else: + units = "seconds" + + data = { + "x": [x], + "y": [y], + "dw": [dw], + "dh": [dh], + "image": [image], + } + meta_data = { + "variable": [state.variable], + "date": [valid_time], + "units": [units], + "window": [window.total_seconds()] + } + data.update(meta_data) + self.sources["image"].data = data + + def scatter(self, state): + """Scatter plot of flash position colored by time since flash""" valid_time = _to_datetime(state.valid_time) - frame = self.loader.load_date(valid_time) + paths = self.locator.find(valid_time) + frame = self.loader.load(paths) + frame = self.select_date(frame, valid_time) + frame = frame[:400] # Limit points + frame['time_since_flash'] = self.since_flash(frame['date'], valid_time) if len(frame) == 0: return self.empty_image x, y = geo.web_mercator( @@ -60,7 +181,7 @@ def render(self, state): frame.latitude) self.color_mapper.low = np.min(frame.time_since_flash) self.color_mapper.high = np.max(frame.time_since_flash) - self.source.data = { + self.sources["scatter"].data = { "x": x, "y": y, "date": frame.date, @@ -70,6 +191,53 @@ def render(self, state): "time_since_flash": frame.time_since_flash } + def select_date(self, frame, date, window): + if len(frame) == 0: + return frame + frame = frame.set_index('date') + start = date + end = start + window + s = "{:%Y-%m-%dT%H:%M}".format(start) + e = "{:%Y-%m-%dT%H:%M}".format(end) + small_frame = frame[s:e].copy() + return small_frame.reset_index() + + def since_flash(self, date_column, date): + """Pandas helper to calculate seconds since valid date""" + if len(date_column) == 0: + return [] + if isinstance(date, str): + date = pd.Timestamp(date) + if isinstance(date_column, list): + date_column = pd.Series(pd.to_datetime(date_column)) + return (date_column - date).dt.total_seconds() + + @staticmethod + def tooltips(variable): + if "density" in variable.lower(): + return [ + ('Variable', '@variable'), + ('Time window', '@window{00:00:00}'), + ('Period start', '@date{%Y-%m-%d %H:%M:%S}'), + ('Value', '@image @units')] + else: + return [ + ('Variable', '@variable'), + ('Time window', '@window{00:00:00}'), + ('Period start', '@date{%Y-%m-%d %H:%M:%S}'), + ('Since start', '@image{00:00:00}')] + + @staticmethod + def formatters(variable): + defaults = { + "@date": "datetime", + "@window": "numeral" + } + if "density" in variable.lower(): + return defaults + else: + return {**defaults, **{'@image': 'numeral'}} + def add_figure(self, figure): renderer = figure.cross( x="x", @@ -77,30 +245,81 @@ def add_figure(self, figure): size=10, fill_color={'field': 'time_since_flash', 'transform': self.color_mapper}, line_color={'field': 'time_since_flash', 'transform': self.color_mapper}, - source=self.source) + source=self.sources["scatter"]) + + # Add image glyph_renderer + renderer = figure.image(x="x", y="y", dw="dw", dh="dh", image="image", + source=self.sources["image"], + color_mapper=self.color_mappers["image"]) + custom_js = bokeh.models.CustomJS( + args=dict(source=self.sources["image"]), code=""" + let variable = source.data['variable'][0] + let idx = cb_data.index.image_indices[0] + if (typeof idx !== 'undefined') { + let number = source.data['image'][0][idx.flat_index] + if (typeof number === "undefined") { + number = source.data['image'][0][idx.dim1][idx.dim2] + } + if (isNaN(number)) { + if (typeof window._tooltips === 'undefined') { + window._tooltips = {} + } + if (typeof window._tooltips[variable] === 'undefined') { + // TODO: Remove global variable + window._tooltips[variable] = cb_obj.tooltips + } + cb_obj.tooltips = [ + ['Variable', '@variable'], + ]; + } else { + cb_obj.tooltips = window._tooltips[variable] + } + } + """) + variable = "Strike density (cloud-ground)" tool = bokeh.models.HoverTool( - tooltips=[ - ('Time', '@date{%F}'), - ('Since flash', '@time_since_flash'), - ('Lon', '@longitude'), - ('Lat', '@latitude'), - ('Flash type', '@flash_type')], - formatters={ - 'date': 'datetime' - }, - renderers=[renderer]) + tooltips=self.tooltips(variable), + formatters=self.formatters(variable), + renderers=[renderer], + callback=custom_js + ) + self.hover_tools["image"].append(tool) figure.add_tools(tool) return renderer -class Navigator: - def __init__(self, paths): +class TimestampLocator: + """Find files by time stamp""" + def __init__(self, pattern): + if pattern is None: + paths = [] + else: + paths = glob.glob(pattern) + + # TODO: Find better way to reduce data volume + if len(paths) > 1000: + paths = sorted(paths)[-1000:] self.paths = paths + + self.table = {} + for path in self.paths: + self.table[self._parse_date(path)] = path + times = [ - self._parse_date(path) for path in paths + self._parse_date(path) for path in self.paths ] times = [t for t in times if t is not None] - self._valid_times = list(sorted(set(times))) + index = pd.DatetimeIndex(times) + self._valid_times = index.sort_values() + + def find_period(self, date, window): + return self.find(date) # TODO: implement search window + + def find(self, date): + if date in self.table: + return [self.table[date]] + else: + return [] @staticmethod def _parse_date(path): @@ -108,58 +327,65 @@ def _parse_date(path): if groups is not None: return dt.datetime.strptime(groups[0], "%Y%m%dT%H%M") + def valid_times(self): + if len(self._valid_times) == 0: + return [] + return self._valid_times + + +class Navigator: + """Meta-data needed to navigate the dataset""" + def __init__(self, locator): + self.locator = locator + def variables(self, pattern): - return ["Lightning"] + labels = [] + for metric in ("time since flash", "strike density"): + for category in ("cloud-ground", "intra-cloud", "total"): + label = "{} ({})".format(metric.capitalize(), category) + labels.append(label) + return labels def initial_times(self, pattern, variable): return [dt.datetime(1970, 1, 1)] def valid_times(self, pattern, variable, initial_time): - return self._valid_times + # Populates initial_state and used by forest.db.control.Control + return self.locator.valid_times() def pressures(self, pattern, variable, initial_time): return [] -class Loader(object): - def __init__(self, paths): - self.paths = paths - if len(self.paths) > 0: - self.frame = self.read(paths) - - @classmethod - def pattern(cls, text): - return cls(list(sorted(glob.glob(os.path.expanduser(text))))) - - def load_date(self, date): - frame = self.frame.set_index('date') - start = date - end = start + dt.timedelta(minutes=60) # 1 hour window - s = "{:%Y-%m-%dT%H:%M}".format(start) - e = "{:%Y-%m-%dT%H:%M}".format(end) - small_frame = frame[s:e].copy() - small_frame['time_since_flash'] = [t.total_seconds() for t in date - small_frame.index] - return small_frame.reset_index() - - @staticmethod - def read(csv_files): +class Loader: + """Methods to manipulate EarthNetworks data""" + def load(self, csv_files): if isinstance(csv_files, str): csv_files = [csv_files] frames = [] for csv_file in csv_files: - frame = pd.read_csv( - csv_file, - parse_dates=[1], - converters={0: Loader.flash_type}, - usecols=[0, 1, 2, 3], - names=["flash_type", "date", "latitude", "longitude"], - header=None) + frame = self.load_file(csv_file) frames.append(frame) if len(frames) == 0: - return None + return pd.DataFrame({ + "flash_type": [], + "date": [], + "latitude": [], + "longitude": [], + }) else: return pd.concat(frames, ignore_index=True) + @lru_cache(maxsize=32) + def load_file(self, path): + return pd.read_csv( + path, + parse_dates=[1], + converters={0: self.flash_type}, + usecols=[0, 1, 2, 3], + names=["flash_type", "date", "latitude", "longitude"], + header=None) + @staticmethod def flash_type(value): return { diff --git a/forest/util.py b/forest/util.py index 273ae4744..4b1e851f1 100644 --- a/forest/util.py +++ b/forest/util.py @@ -78,10 +78,17 @@ def to_datetime(d): elif isinstance(d, cftime.DatetimeGregorian): return dt.datetime(d.year, d.month, d.day, d.hour, d.minute, d.second) elif isinstance(d, str): - try: - return dt.datetime.strptime(d, "%Y-%m-%d %H:%M:%S") - except ValueError: - return dt.datetime.strptime(d, "%Y-%m-%dT%H:%M:%S") + errors = [] + for fmt in ( + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%dT%H:%M:%SZ"): + try: + return dt.datetime.strptime(d, fmt) + except ValueError as e: + errors.append(e) + continue + raise Exception(errors) elif isinstance(d, np.datetime64): return d.astype(dt.datetime) else: diff --git a/test/test_earth_networks.py b/test/test_earth_networks.py index b1525f2ca..438b9ae07 100644 --- a/test/test_earth_networks.py +++ b/test/test_earth_networks.py @@ -1,3 +1,8 @@ +import pytest +from unittest.mock import sentinel, Mock +import bokeh.palettes +import pandas as pd +import pandas.testing as pdt import datetime as dt import numpy as np import glob @@ -16,8 +21,8 @@ def test_earth_networks(tmpdir): with open(path, "w") as stream: stream.writelines(LINES) - loader = earth_networks.Loader([path]) - frame = loader.load_date(dt.datetime(2019, 4, 17)) + loader = earth_networks.Loader() + frame = loader.load([path]) result = frame.iloc[0] atol = 0.000001 if isinstance(result["date"], dt.datetime): @@ -36,11 +41,100 @@ def test_dataset(): assert isinstance(dataset, forest.drivers.earth_networks.Dataset) -def test_dataset_navigator(): - settings = { - "pattern": "*.txt" - } +def get_navigator(settings): dataset = forest.drivers.get_dataset("earth_networks", settings) - navigator = dataset.navigator() - assert isinstance(navigator, - forest.drivers.earth_networks.Navigator) + return dataset.navigator() + + +def test_dataset_navigator(): + navigator = get_navigator({"pattern": "*.txt"}) + assert isinstance(navigator, forest.drivers.earth_networks.Navigator) + + +def test_navigator_variables(): + navigator = earth_networks.Navigator([]) + assert set(navigator.variables(None)) == set([ + "Strike density (cloud-ground)", + "Strike density (intra-cloud)", + "Strike density (total)", + "Time since flash (cloud-ground)", + "Time since flash (intra-cloud)", + "Time since flash (total)" + ]) + + +def test_view_render_density(): + locator = Mock(specs=["find"]) + loader = Mock(specs=["load"]) + loader.load.return_value = pd.DataFrame({ + "flash_type": [], + "longitude": [], + "latitude": [], + }) + view = earth_networks.View(loader, locator) + view.render({ + "variable": "Strike density (cloud-ground)", + "valid_time": "1970-01-01T00:00:00Z" + }) + expect = bokeh.palettes.all_palettes["Spectral"][8] + assert view.color_mappers["image"].palette == expect + + +def test_view_render_time_since_flash(): + locator = Mock(specs=["find"]) + loader = Mock(specs=["load"]) + loader.load.return_value = pd.DataFrame({ + "date": [], + "flash_type": [], + "longitude": [], + "latitude": [], + }) + view = earth_networks.View(loader, locator) + view.render({ + "variable": "Time since flash (cloud-ground)", + "valid_time": "1970-01-01T00:00:00Z" + }) + expect = bokeh.palettes.all_palettes["RdGy"][8] + assert view.color_mappers["image"].palette == expect + + +@pytest.mark.parametrize("variable, expect", [ + pytest.param("Time since flash (intra-cloud)", [ + ('Variable', '@variable'), + ('Time window', '@window{00:00:00}'), + ('Period start', '@date{%Y-%m-%d %H:%M:%S}'), + ("Since start", "@image{00:00:00}") + ], id="time since flash"), + pytest.param("Strike density (cloud-ground)", [ + ('Variable', '@variable'), + ('Time window', '@window{00:00:00}'), + ('Period start', '@date{%Y-%m-%d %H:%M:%S}'), + ('Value', '@image @units'), + ], id="strike density"), +]) +def test_view_tooltips(variable, expect): + assert earth_networks.View.tooltips(variable) == expect + + +@pytest.mark.parametrize("variable, expect", [ + pytest.param("Time since flash (intra-cloud)", { + '@date': 'datetime', + '@window': 'numeral', + '@image': 'numeral' + }, id="time since flash"), + pytest.param("Strike density (cloud-ground)", { + '@date': 'datetime', + '@window': 'numeral' + }, id="strike density"), +]) +def test_view_formatters(variable, expect): + assert earth_networks.View.formatters(variable) == expect + + +def test_view_since_flash(): + view = earth_networks.View(Mock(), Mock()) + strike_times = ["2020-01-01T00:00:00Z", "2020-01-01T01:00:00Z"] + period_start = "2020-01-01T00:00:00Z" + result = view.since_flash(strike_times, period_start) + expect = pd.Series([0., 3600.]) + pdt.assert_series_equal(result, expect)