From d2a920a7c86a84aff3cee78611a34847fce2184c Mon Sep 17 00:00:00 2001 From: andrewgryan Date: Tue, 19 May 2020 12:07:32 +0100 Subject: [PATCH 1/2] separate argument sanitization from query generation --- forest/db/database.py | 46 ++++++++++--- test/test_db_database.py | 135 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 165 insertions(+), 16 deletions(-) diff --git a/forest/db/database.py b/forest/db/database.py index 5fb529a65..5a84e2248 100644 --- a/forest/db/database.py +++ b/forest/db/database.py @@ -5,6 +5,8 @@ pass import netCDF4 import jinja2 +import numpy as np +import pandas as pd from .connection import Connection @@ -373,10 +375,22 @@ def insert_pressure(self, path, variable, pressure, i): def valid_times(self, pattern, variable, initial_time): """Valid times associated with search criteria""" + initial_time = self.sanitize_time(initial_time) + query = self.valid_times_query(pattern, variable, initial_time) + self.cursor.execute(query, dict( + variable=variable, + pattern=pattern, + initial_time=initial_time)) + rows = self.cursor.fetchall() + return [time for time, in rows] + + @staticmethod + def valid_times_query(pattern, variable, initial_time): + """Valid times SQL query syntax""" # Note: SQL injection possible if not properly escaped # use ? and :name syntax in template environment = jinja2.Environment(extensions=['jinja2.ext.do']) - query = environment.from_string(""" + return environment.from_string(""" {% set EQNS = [] %} {% if initial_time is not none %} {% do EQNS.append('file.reference = :initial_time') %} @@ -402,6 +416,24 @@ def valid_times(self, pattern, variable, initial_time): initial_time=initial_time, variable=variable, pattern=pattern) + + @staticmethod + def sanitize_time(value): + """Query-compatible equivalent of value""" + fmt = "%Y-%m-%d %H:%M:%S" + if value is None: + return value + elif isinstance(value, str): + return value + elif isinstance(value, np.datetime64): + return pd.to_datetime(str(value)).strftime(fmt) + else: + return value.strftime(fmt) + + def pressures(self, pattern=None, variable=None, initial_time=None): + """Select pressures from database""" + initial_time = self.sanitize_time(initial_time) + query = self.pressures_query(pattern, variable, initial_time) self.cursor.execute(query, dict( variable=variable, pattern=pattern, @@ -409,12 +441,12 @@ def valid_times(self, pattern, variable, initial_time): rows = self.cursor.fetchall() return [time for time, in rows] - def pressures(self, pattern=None, variable=None, initial_time=None): - """Select pressures from database""" + @staticmethod + def pressures_query(pattern, variable, initial_time): # Note: SQL injection possible if not properly escaped # use ? and :name syntax in template environment = jinja2.Environment(extensions=['jinja2.ext.do']) - query = environment.from_string(""" + return environment.from_string(""" {% set EQNS = [] %} {% if variable is not none %} {% do EQNS.append('v.name = :variable') %} @@ -445,12 +477,6 @@ def pressures(self, pattern=None, variable=None, initial_time=None): variable=variable, pattern=pattern, initial_time=initial_time) - self.cursor.execute(query, dict( - variable=variable, - pattern=pattern, - initial_time=initial_time)) - rows = self.cursor.fetchall() - return [time for time, in rows] def fetch_times(self, path, variable): """Helper method to find times related to a variable""" diff --git a/test/test_db_database.py b/test/test_db_database.py index 805150bb1..e02dfc55c 100644 --- a/test/test_db_database.py +++ b/test/test_db_database.py @@ -1,4 +1,8 @@ from unittest.mock import Mock, sentinel +import pytest +import datetime as dt +import cftime +import numpy as np import re import forest.db.database as database @@ -18,12 +22,21 @@ def _assert_query_and_params(db, expected_query, expected_params): db.cursor.execute.assert_called_once() args, kwargs = db.cursor.execute.call_args query, params = args - query = re.sub(r'\s+', ' ', query).strip() - assert query == expected_query + assert_query_equal(query, expected_query) assert params == expected_params assert kwargs == {} +def assert_query_equal(left, right): + left, right = single_spaced(left), single_spaced(right) + assert left == right + + +def single_spaced(query): + query = query.replace("\n", "") + return re.sub(r'\s+', ' ', query).strip() + + def test_Database_valid_times__defaults(): db = _create_db() @@ -39,7 +52,7 @@ def test_Database_valid_times__all_args(): db = _create_db() valid_times = db.valid_times(sentinel.pattern, sentinel.variable, - sentinel.initial_time) + dt.datetime(2020, 1, 1)) _assert_query_and_params( db, 'SELECT time.value FROM time' @@ -49,10 +62,94 @@ def test_Database_valid_times__all_args(): ' WHERE file.reference = :initial_time' ' AND file.name GLOB :pattern AND v.name = :variable', {'pattern': sentinel.pattern, 'variable': sentinel.variable, - 'initial_time':sentinel.initial_time}) + 'initial_time': "2020-01-01 00:00:00"}) assert valid_times == [sentinel.value1, sentinel.value2] +@pytest.mark.parametrize("pattern, variable, initial_time, expect", [ + (None, None, None, """ + SELECT time.value FROM time + """), + (sentinel.pattern, None, None, """ + SELECT time.value FROM time + JOIN variable_to_time AS vt ON vt.time_id = time.id + JOIN variable AS v ON vt.variable_id = v.id + JOIN file ON v.file_id = file.id + WHERE file.name GLOB :pattern + """), + (None, sentinel.variable, None, """ + SELECT time.value FROM time + JOIN variable_to_time AS vt ON vt.time_id = time.id + JOIN variable AS v ON vt.variable_id = v.id + JOIN file ON v.file_id = file.id + WHERE v.name = :variable + """), + (sentinel.pattern, sentinel.variable, None, """ + SELECT time.value FROM time + JOIN variable_to_time AS vt ON vt.time_id = time.id + JOIN variable AS v ON vt.variable_id = v.id + JOIN file ON v.file_id = file.id + WHERE file.name GLOB :pattern AND v.name = :variable + """), + (sentinel.pattern, sentinel.variable, sentinel.initial_time, """ + SELECT time.value FROM time + JOIN variable_to_time AS vt ON vt.time_id = time.id + JOIN variable AS v ON vt.variable_id = v.id + JOIN file ON v.file_id = file.id + WHERE file.reference = :initial_time + AND file.name GLOB :pattern AND v.name = :variable + """), +]) +def test_valid_times_query(pattern, variable, initial_time, expect): + result = database.Database.valid_times_query(pattern, variable, initial_time) + assert_query_equal(expect, result) + + +@pytest.mark.parametrize("pattern, variable, initial_time, expect", [ + (None, None, None, """ + SELECT DISTINCT value FROM pressure + ORDER BY value + """), + (sentinel.pattern, None, None, """ + SELECT DISTINCT pressure.value FROM pressure + JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id + JOIN variable AS v ON v.id = vp.variable_id + JOIN file ON v.file_id = file.id + WHERE file.name GLOB :pattern + ORDER BY value + """), + (None, sentinel.variable, None, """ + SELECT DISTINCT pressure.value FROM pressure + JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id + JOIN variable AS v ON v.id = vp.variable_id + JOIN file ON v.file_id = file.id + WHERE v.name = :variable + ORDER BY value + """), + (sentinel.pattern, sentinel.variable, None, """ + SELECT DISTINCT pressure.value FROM pressure + JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id + JOIN variable AS v ON v.id = vp.variable_id + JOIN file ON v.file_id = file.id + WHERE v.name = :variable AND file.name GLOB :pattern + ORDER BY value + """), + (sentinel.pattern, sentinel.variable, sentinel.initial_time, """ + SELECT DISTINCT pressure.value FROM pressure + JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id + JOIN variable AS v ON v.id = vp.variable_id + JOIN file ON v.file_id = file.id + WHERE v.name = :variable + AND file.name GLOB :pattern + AND file.reference = :initial_time + ORDER BY value + """), +]) +def test_pressures_query(pattern, variable, initial_time, expect): + result = database.Database.pressures_query(pattern, variable, initial_time) + assert_query_equal(expect, result) + + def test_Database_pressures__defaults(): db = _create_db() @@ -69,7 +166,7 @@ def test_Database_pressures__all_args(): db = _create_db() pressures = db.pressures(sentinel.pattern, sentinel.variable, - sentinel.initial_time) + dt.datetime(2020, 1, 1)) _assert_query_and_params( db, 'SELECT DISTINCT pressure.value FROM pressure' @@ -80,5 +177,31 @@ def test_Database_pressures__all_args(): ' AND file.reference = :initial_time' ' ORDER BY value', {'pattern': sentinel.pattern, 'variable': sentinel.variable, - 'initial_time':sentinel.initial_time}) + 'initial_time': "2020-01-01 00:00:00"}) assert pressures == [sentinel.value1, sentinel.value2] + + + +@pytest.mark.parametrize("initial_time", [ + pytest.param(dt.datetime(2020, 1, 1), id="datetime"), + pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"), + pytest.param(np.datetime64("2020-01-01", "ns"), id="np.datetime64"), +]) +def test_Database_valid_times_given_datetime_like_objects(initial_time): + initial_datetime = dt.datetime(2020, 1, 1) + valid_times = [dt.datetime(2020, 1, 1, 12)] + db = database.Database.connect(":memory:") + db.insert_file_name("file.nc", initial_datetime) + db.insert_times("file.nc", "air_temperature", valid_times) + result = db.valid_times("file.nc", "air_temperature", initial_time) + expect = ["2020-01-01 12:00:00"] + assert expect == result + + +@pytest.mark.parametrize("time", [ + pytest.param(dt.datetime(2020, 1, 1), id="datetime"), + pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"), + pytest.param(np.datetime64("2020-01-01", "ns"), id="np.datetime64"), +]) +def test_Database_sanitize_datetime_like_objects(time): + assert database.Database.sanitize_time(time) == "2020-01-01 00:00:00" From 8a9ec5211bd9e61c52c54714e8b345652f2ad386 Mon Sep 17 00:00:00 2001 From: andrewgryan Date: Tue, 19 May 2020 12:14:07 +0100 Subject: [PATCH 2/2] bump version to 0.17.3 --- forest/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/forest/__init__.py b/forest/__init__.py index b97dbc1e1..d9f846557 100644 --- a/forest/__init__.py +++ b/forest/__init__.py @@ -23,7 +23,7 @@ .. automodule:: forest.presets """ -__version__ = '0.17.2' +__version__ = '0.17.3' from .config import * from . import (