Skip to content

Commit

Permalink
Merge pull request #381 from informatics-lab/fix-datetime-support
Browse files Browse the repository at this point in the history
Fix datetime support
  • Loading branch information
andrewgryan authored May 20, 2020
2 parents 17f330a + 9b9cee6 commit 3c7489c
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 18 deletions.
2 changes: 1 addition & 1 deletion forest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
.. automodule:: forest.presets
"""
__version__ = '0.17.3'
__version__ = '0.17.4'

from .config import *
from . import (
Expand Down
18 changes: 3 additions & 15 deletions forest/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import pandas as pd
from .connection import Connection
from forest import mark


__all__ = [
Expand Down Expand Up @@ -373,9 +374,9 @@ def insert_pressure(self, path, variable, pressure, i):
(SELECT id FROM pressure WHERE value=:pressure AND i=:i))
""", dict(path=path, variable=variable, pressure=pressure, i=i))

@mark.sql_sanitize_time("initial_time")
def valid_times(self, pattern, variable, initial_time):
"""Valid times associated with search criteria"""
initial_time = self.sanitize_time(initial_time)
query = self.valid_times_query(pattern, variable, initial_time)
self.cursor.execute(query, dict(
variable=variable,
Expand Down Expand Up @@ -417,22 +418,9 @@ def valid_times_query(pattern, variable, initial_time):
variable=variable,
pattern=pattern)

@staticmethod
def sanitize_time(value):
    """Return ``value`` in a form that can be used in a query

    ``None`` and ``str`` values pass through unchanged; any other value is
    rendered as ``"%Y-%m-%d %H:%M:%S"`` text.
    """
    passthrough = value is None or isinstance(value, str)
    if passthrough:
        return value
    # numpy datetimes are converted through pandas before being formatted
    if isinstance(value, np.datetime64):
        stamp = pd.to_datetime(str(value))
    else:
        stamp = value
    return stamp.strftime("%Y-%m-%d %H:%M:%S")

@mark.sql_sanitize_time("initial_time")
def pressures(self, pattern=None, variable=None, initial_time=None):
"""Select pressures from database"""
initial_time = self.sanitize_time(initial_time)
query = self.pressures_query(pattern, variable, initial_time)
self.cursor.execute(query, dict(
variable=variable,
Expand Down
3 changes: 3 additions & 0 deletions forest/db/locate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from .connection import Connection
from forest.exceptions import SearchFail
from forest import mark


__all__ = [
Expand All @@ -17,6 +18,7 @@ def __init__(self, connection, directory=None):
self.connection = connection
self.cursor = self.connection.cursor()

@mark.sql_sanitize_time("initial_time", "valid_time")
def locate(
self,
pattern,
Expand Down Expand Up @@ -66,6 +68,7 @@ def locate(
return path, (ti, pi)
raise SearchFail("Could not locate: {}".format(pattern))

@mark.sql_sanitize_time("initial_time", "valid_time")
@lru_cache()
def file_names(self, pattern, variable, initial_time, valid_time):
self.cursor.execute("""
Expand Down
48 changes: 48 additions & 0 deletions forest/mark.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Decorators to mark classes and functions"""
import inspect
from unittest.mock import Mock
from contextlib import contextmanager
from functools import wraps
from forest.observe import Observable
import pandas as pd
import numpy as np


def component(cls):
Expand All @@ -29,3 +32,48 @@ def disable(obj, method_name):
setattr(obj, method_name, Mock())
yield
setattr(obj, method_name, method)


def sql_sanitize_time(*labels):
    """Decorator to protect SQL statements from unsupported datetime types

    Every argument named in ``labels`` is passed through ``sanitize_time``
    before the decorated callable runs, whether it was supplied by keyword
    or by position. Labels that were not supplied in a call are skipped so
    the callable's own defaults apply.

    >>> @sql_sanitize_time("b", "c")
    ... def method(self, a, b, c=None, d=False):
    ...     # b and c will be converted to a str compatible with SQL queries
    ...     pass

    :param labels: names of the parameters to sanitize
    :returns: a decorator producing a wrapper with the same call signature
    """
    def outer(f):
        parameters = inspect.signature(f).parameters

        # Get positional index. Only parameters that can actually be passed
        # by position are recorded; keyword-only and **kwargs labels can
        # only ever arrive through ``kwargs``.
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        index = {
            name: i
            for i, (name, parameter) in enumerate(parameters.items())
            if name in labels and parameter.kind in positional_kinds
        }

        @wraps(f)  # keep __name__, __doc__ and __wrapped__ of the original
        def inner(*args, **kwargs):
            args = list(args)
            for label in labels:
                if label in kwargs:
                    kwargs[label] = sanitize_time(kwargs[label])
                    continue
                # A label with no positional slot, or one left to its
                # default, has nothing to convert.
                i = index.get(label)
                if i is not None and i < len(args):
                    args[i] = sanitize_time(args[i])
            return f(*args, **kwargs)
        return inner
    return outer


def sanitize_time(value):
    """Convert a datetime-like value into text usable in a SQL query

    ``None`` and ``str`` values are handed back untouched; everything else
    is rendered using the ``%Y-%m-%d %H:%M:%S`` format.
    """
    if value is None or isinstance(value, str):
        return value
    if isinstance(value, np.datetime64):
        # numpy datetimes are converted through pandas before formatting
        value = pd.to_datetime(str(value))
    return value.strftime("%Y-%m-%d %H:%M:%S")
5 changes: 3 additions & 2 deletions test/test_db_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import re

import forest.db.database as database
import forest.mark


def _create_db():
Expand Down Expand Up @@ -203,5 +204,5 @@ def test_Database_valid_times_given_datetime_like_objects(initial_time):
pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"),
pytest.param(np.datetime64("2020-01-01", "ns"), id="np.datetime64"),
])
def test_Database_sanitize_datetime_like_objects(time):
assert database.Database.sanitize_time(time) == "2020-01-01 00:00:00"
def test_sanitize_datetime_like_objects(time):
assert forest.mark.sanitize_time(time) == "2020-01-01 00:00:00"
24 changes: 24 additions & 0 deletions test/test_db_locate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
import pytest
import unittest
from unittest.mock import Mock, sentinel
import sqlite3
import datetime as dt
import cftime
import numpy as np
import pandas as pd
from forest import db
from forest.exceptions import SearchFail


@pytest.mark.parametrize("time", [
    pytest.param("2020-01-01 00:00:00", id="str"),
    pytest.param(dt.datetime(2020, 1, 1), id="datetime"),
    pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"),
    pytest.param(np.datetime64("2020-01-01 00:00:00", "ns"), id="numpy"),
    pytest.param(pd.Timestamp("2020-01-01 00:00:00"), id="pandas"),
])
def test_locator_file_names_supports_datetime_types(time):
    """Locator.file_names finds the file for each parametrized time type"""
    file_name, parameter = "file.nc", "variable"

    # Populate an in-memory database with a single file/time entry
    store = db.Database.connect(":memory:")
    store.insert_file_name(file_name, "2020-01-01 00:00:00")
    store.insert_time(file_name, parameter, "2020-01-01 00:00:00", 0)

    # The same value is used for both initial_time and valid_time
    found = db.Locator(store.connection).file_names(
        file_name, parameter, time, time)
    assert [file_name] == found


class TestLocate(unittest.TestCase):
def setUp(self):
self.database = db.Database.connect(":memory:")
Expand Down

0 comments on commit 3c7489c

Please sign in to comment.