diff --git a/forest/__init__.py b/forest/__init__.py index d9f846557..988785ac8 100644 --- a/forest/__init__.py +++ b/forest/__init__.py @@ -23,7 +23,7 @@ .. automodule:: forest.presets """ -__version__ = '0.17.3' +__version__ = '0.17.4' from .config import * from . import ( diff --git a/forest/db/database.py b/forest/db/database.py index 5a84e2248..2d9f220f5 100644 --- a/forest/db/database.py +++ b/forest/db/database.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd from .connection import Connection +from forest import mark __all__ = [ @@ -373,9 +374,9 @@ def insert_pressure(self, path, variable, pressure, i): (SELECT id FROM pressure WHERE value=:pressure AND i=:i)) """, dict(path=path, variable=variable, pressure=pressure, i=i)) + @mark.sql_sanitize_time("initial_time") def valid_times(self, pattern, variable, initial_time): """Valid times associated with search criteria""" - initial_time = self.sanitize_time(initial_time) query = self.valid_times_query(pattern, variable, initial_time) self.cursor.execute(query, dict( variable=variable, @@ -417,22 +418,9 @@ def valid_times_query(pattern, variable, initial_time): variable=variable, pattern=pattern) - @staticmethod - def sanitize_time(value): - """Query-compatible equivalent of value""" - fmt = "%Y-%m-%d %H:%M:%S" - if value is None: - return value - elif isinstance(value, str): - return value - elif isinstance(value, np.datetime64): - return pd.to_datetime(str(value)).strftime(fmt) - else: - return value.strftime(fmt) - + @mark.sql_sanitize_time("initial_time") def pressures(self, pattern=None, variable=None, initial_time=None): """Select pressures from database""" - initial_time = self.sanitize_time(initial_time) query = self.pressures_query(pattern, variable, initial_time) self.cursor.execute(query, dict( variable=variable, diff --git a/forest/db/locate.py b/forest/db/locate.py index 04e0e4473..32874d44d 100644 --- a/forest/db/locate.py +++ b/forest/db/locate.py @@ -3,6 +3,7 @@ import numpy as np from .connection import Connection from forest.exceptions import SearchFail +from forest import mark __all__ = [ @@ -17,6 +18,7 @@ def __init__(self, connection, directory=None): self.connection = connection self.cursor = self.connection.cursor() + @mark.sql_sanitize_time("initial_time", "valid_time") def locate( self, pattern, @@ -66,6 +68,7 @@ def locate( return path, (ti, pi) raise SearchFail("Could not locate: {}".format(pattern)) + @mark.sql_sanitize_time("initial_time", "valid_time") @lru_cache() def file_names(self, pattern, variable, initial_time, valid_time): self.cursor.execute(""" diff --git a/forest/mark.py b/forest/mark.py index 384eee858..494a4c700 100644 --- a/forest/mark.py +++ b/forest/mark.py @@ -1,8 +1,11 @@ """Decorators to mark classes and functions""" +import inspect from unittest.mock import Mock from contextlib import contextmanager from functools import wraps from forest.observe import Observable +import pandas as pd +import numpy as np def component(cls): @@ -29,3 +32,48 @@ def disable(obj, method_name): setattr(obj, method_name, Mock()) yield setattr(obj, method_name, method) + + +def sql_sanitize_time(*labels): + """Decorator to protect SQL statements from unsupported datetime types + + >>> @sql_sanitize_time("b", "c") + ... def method(self, a, b, c=None, d=False): + ... # b and c will be converted to a str compatible with SQL queries + ... pass + + """ + def outer(f): + parameters = inspect.signature(f).parameters + + # Get positional index + index = {} + for i, name in enumerate(parameters): + if name in labels: + index[name] = i + + def inner(*args, **kwargs): + args = list(args) + for label in labels: + if label in kwargs: + kwargs[label] = sanitize_time(kwargs[label]) + else: + i = index[label] + if i < len(args): + args[i] = sanitize_time(args[i]) + return f(*args, **kwargs) + return inner + return outer + + +def sanitize_time(value): + """Query-compatible equivalent of value""" + fmt = "%Y-%m-%d %H:%M:%S" + if value is None: + return value + elif isinstance(value, str): + return value + elif isinstance(value, np.datetime64): + return pd.to_datetime(str(value)).strftime(fmt) + else: + return value.strftime(fmt) diff --git a/test/test_db_database.py b/test/test_db_database.py index e02dfc55c..4dbfb7ed7 100644 --- a/test/test_db_database.py +++ b/test/test_db_database.py @@ -6,6 +6,7 @@ import re import forest.db.database as database +import forest.mark def _create_db(): @@ -203,5 +204,5 @@ def test_Database_valid_times_given_datetime_like_objects(initial_time): pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"), pytest.param(np.datetime64("2020-01-01", "ns"), id="np.datetime64"), ]) -def test_Database_sanitize_datetime_like_objects(time): - assert database.Database.sanitize_time(time) == "2020-01-01 00:00:00" +def test_sanitize_datetime_like_objects(time): + assert forest.mark.sanitize_time(time) == "2020-01-01 00:00:00" diff --git a/test/test_db_locate.py b/test/test_db_locate.py index 09fde6f4c..c205ef554 100644 --- a/test/test_db_locate.py +++ b/test/test_db_locate.py @@ -1,10 +1,34 @@ +import pytest import unittest +from unittest.mock import Mock, sentinel import sqlite3 import datetime as dt +import cftime +import numpy as np +import pandas as pd from forest import db from forest.exceptions import SearchFail +@pytest.mark.parametrize("time", [ + pytest.param("2020-01-01 00:00:00", id="str"), + pytest.param(dt.datetime(2020, 1, 1), id="datetime"), + pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"), + pytest.param(np.datetime64("2020-01-01 00:00:00", "ns"), id="numpy"), + pytest.param(pd.Timestamp("2020-01-01 00:00:00"), id="pandas"), +]) +def test_locator_file_names_supports_datetime_types(time): + path = "file.nc" + variable = "variable" + database = db.Database.connect(":memory:") + database.insert_file_name(path, "2020-01-01 00:00:00") + database.insert_time(path, variable, "2020-01-01 00:00:00", 0) + locator = db.Locator(database.connection) + result = locator.file_names(path, variable, time, time) + expect = [path] + assert expect == result + + class TestLocate(unittest.TestCase): def setUp(self): self.database = db.Database.connect(":memory:")