Skip to content

Commit

Permalink
Merge pull request #380 from informatics-lab/fix-sqlite3-datetime
Browse files Browse the repository at this point in the history
separate argument sanitization from query generation
  • Loading branch information
andrewgryan authored May 19, 2020
2 parents c298063 + 8a9ec52 commit 17f330a
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 17 deletions.
2 changes: 1 addition & 1 deletion forest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
.. automodule:: forest.presets
"""
__version__ = '0.17.2'
__version__ = '0.17.3'

from .config import *
from . import (
Expand Down
46 changes: 36 additions & 10 deletions forest/db/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
pass
import netCDF4
import jinja2
import numpy as np
import pandas as pd
from .connection import Connection


Expand Down Expand Up @@ -373,10 +375,22 @@ def insert_pressure(self, path, variable, pressure, i):

def valid_times(self, pattern, variable, initial_time):
"""Valid times associated with search criteria"""
initial_time = self.sanitize_time(initial_time)
query = self.valid_times_query(pattern, variable, initial_time)
self.cursor.execute(query, dict(
variable=variable,
pattern=pattern,
initial_time=initial_time))
rows = self.cursor.fetchall()
return [time for time, in rows]

@staticmethod
def valid_times_query(pattern, variable, initial_time):
"""Valid times SQL query syntax"""
# Note: SQL injection possible if not properly escaped
# use ? and :name syntax in template
environment = jinja2.Environment(extensions=['jinja2.ext.do'])
query = environment.from_string("""
return environment.from_string("""
{% set EQNS = [] %}
{% if initial_time is not none %}
{% do EQNS.append('file.reference = :initial_time') %}
Expand All @@ -402,19 +416,37 @@ def valid_times(self, pattern, variable, initial_time):
initial_time=initial_time,
variable=variable,
pattern=pattern)

@staticmethod
def sanitize_time(value):
"""Query-compatible equivalent of value"""
fmt = "%Y-%m-%d %H:%M:%S"
if value is None:
return value
elif isinstance(value, str):
return value
elif isinstance(value, np.datetime64):
return pd.to_datetime(str(value)).strftime(fmt)
else:
return value.strftime(fmt)

def pressures(self, pattern=None, variable=None, initial_time=None):
"""Select pressures from database"""
initial_time = self.sanitize_time(initial_time)
query = self.pressures_query(pattern, variable, initial_time)
self.cursor.execute(query, dict(
variable=variable,
pattern=pattern,
initial_time=initial_time))
rows = self.cursor.fetchall()
return [time for time, in rows]

def pressures(self, pattern=None, variable=None, initial_time=None):
"""Select pressures from database"""
@staticmethod
def pressures_query(pattern, variable, initial_time):
# Note: SQL injection possible if not properly escaped
# use ? and :name syntax in template
environment = jinja2.Environment(extensions=['jinja2.ext.do'])
query = environment.from_string("""
return environment.from_string("""
{% set EQNS = [] %}
{% if variable is not none %}
{% do EQNS.append('v.name = :variable') %}
Expand Down Expand Up @@ -445,12 +477,6 @@ def pressures(self, pattern=None, variable=None, initial_time=None):
variable=variable,
pattern=pattern,
initial_time=initial_time)
self.cursor.execute(query, dict(
variable=variable,
pattern=pattern,
initial_time=initial_time))
rows = self.cursor.fetchall()
return [time for time, in rows]

def fetch_times(self, path, variable):
"""Helper method to find times related to a variable"""
Expand Down
135 changes: 129 additions & 6 deletions test/test_db_database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from unittest.mock import Mock, sentinel
import pytest
import datetime as dt
import cftime
import numpy as np
import re

import forest.db.database as database
Expand All @@ -18,12 +22,21 @@ def _assert_query_and_params(db, expected_query, expected_params):
db.cursor.execute.assert_called_once()
args, kwargs = db.cursor.execute.call_args
query, params = args
query = re.sub(r'\s+', ' ', query).strip()
assert query == expected_query
assert_query_equal(query, expected_query)
assert params == expected_params
assert kwargs == {}


def assert_query_equal(left, right):
left, right = single_spaced(left), single_spaced(right)
assert left == right


def single_spaced(query):
query = query.replace("\n", "")
return re.sub(r'\s+', ' ', query).strip()


def test_Database_valid_times__defaults():
db = _create_db()

Expand All @@ -39,7 +52,7 @@ def test_Database_valid_times__all_args():
db = _create_db()

valid_times = db.valid_times(sentinel.pattern, sentinel.variable,
sentinel.initial_time)
dt.datetime(2020, 1, 1))

_assert_query_and_params(
db, 'SELECT time.value FROM time'
Expand All @@ -49,10 +62,94 @@ def test_Database_valid_times__all_args():
' WHERE file.reference = :initial_time'
' AND file.name GLOB :pattern AND v.name = :variable',
{'pattern': sentinel.pattern, 'variable': sentinel.variable,
'initial_time':sentinel.initial_time})
'initial_time': "2020-01-01 00:00:00"})
assert valid_times == [sentinel.value1, sentinel.value2]


@pytest.mark.parametrize("pattern, variable, initial_time, expect", [
(None, None, None, """
SELECT time.value FROM time
"""),
(sentinel.pattern, None, None, """
SELECT time.value FROM time
JOIN variable_to_time AS vt ON vt.time_id = time.id
JOIN variable AS v ON vt.variable_id = v.id
JOIN file ON v.file_id = file.id
WHERE file.name GLOB :pattern
"""),
(None, sentinel.variable, None, """
SELECT time.value FROM time
JOIN variable_to_time AS vt ON vt.time_id = time.id
JOIN variable AS v ON vt.variable_id = v.id
JOIN file ON v.file_id = file.id
WHERE v.name = :variable
"""),
(sentinel.pattern, sentinel.variable, None, """
SELECT time.value FROM time
JOIN variable_to_time AS vt ON vt.time_id = time.id
JOIN variable AS v ON vt.variable_id = v.id
JOIN file ON v.file_id = file.id
WHERE file.name GLOB :pattern AND v.name = :variable
"""),
(sentinel.pattern, sentinel.variable, sentinel.initial_time, """
SELECT time.value FROM time
JOIN variable_to_time AS vt ON vt.time_id = time.id
JOIN variable AS v ON vt.variable_id = v.id
JOIN file ON v.file_id = file.id
WHERE file.reference = :initial_time
AND file.name GLOB :pattern AND v.name = :variable
"""),
])
def test_valid_times_query(pattern, variable, initial_time, expect):
result = database.Database.valid_times_query(pattern, variable, initial_time)
assert_query_equal(expect, result)


@pytest.mark.parametrize("pattern, variable, initial_time, expect", [
(None, None, None, """
SELECT DISTINCT value FROM pressure
ORDER BY value
"""),
(sentinel.pattern, None, None, """
SELECT DISTINCT pressure.value FROM pressure
JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id
JOIN variable AS v ON v.id = vp.variable_id
JOIN file ON v.file_id = file.id
WHERE file.name GLOB :pattern
ORDER BY value
"""),
(None, sentinel.variable, None, """
SELECT DISTINCT pressure.value FROM pressure
JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id
JOIN variable AS v ON v.id = vp.variable_id
JOIN file ON v.file_id = file.id
WHERE v.name = :variable
ORDER BY value
"""),
(sentinel.pattern, sentinel.variable, None, """
SELECT DISTINCT pressure.value FROM pressure
JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id
JOIN variable AS v ON v.id = vp.variable_id
JOIN file ON v.file_id = file.id
WHERE v.name = :variable AND file.name GLOB :pattern
ORDER BY value
"""),
(sentinel.pattern, sentinel.variable, sentinel.initial_time, """
SELECT DISTINCT pressure.value FROM pressure
JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id
JOIN variable AS v ON v.id = vp.variable_id
JOIN file ON v.file_id = file.id
WHERE v.name = :variable
AND file.name GLOB :pattern
AND file.reference = :initial_time
ORDER BY value
"""),
])
def test_pressures_query(pattern, variable, initial_time, expect):
result = database.Database.pressures_query(pattern, variable, initial_time)
assert_query_equal(expect, result)


def test_Database_pressures__defaults():
db = _create_db()

Expand All @@ -69,7 +166,7 @@ def test_Database_pressures__all_args():
db = _create_db()

pressures = db.pressures(sentinel.pattern, sentinel.variable,
sentinel.initial_time)
dt.datetime(2020, 1, 1))

_assert_query_and_params(
db, 'SELECT DISTINCT pressure.value FROM pressure'
Expand All @@ -80,5 +177,31 @@ def test_Database_pressures__all_args():
' AND file.reference = :initial_time'
' ORDER BY value',
{'pattern': sentinel.pattern, 'variable': sentinel.variable,
'initial_time':sentinel.initial_time})
'initial_time': "2020-01-01 00:00:00"})
assert pressures == [sentinel.value1, sentinel.value2]



@pytest.mark.parametrize("initial_time", [
pytest.param(dt.datetime(2020, 1, 1), id="datetime"),
pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"),
pytest.param(np.datetime64("2020-01-01", "ns"), id="np.datetime64"),
])
def test_Database_valid_times_given_datetime_like_objects(initial_time):
initial_datetime = dt.datetime(2020, 1, 1)
valid_times = [dt.datetime(2020, 1, 1, 12)]
db = database.Database.connect(":memory:")
db.insert_file_name("file.nc", initial_datetime)
db.insert_times("file.nc", "air_temperature", valid_times)
result = db.valid_times("file.nc", "air_temperature", initial_time)
expect = ["2020-01-01 12:00:00"]
assert expect == result


@pytest.mark.parametrize("time", [
pytest.param(dt.datetime(2020, 1, 1), id="datetime"),
pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"),
pytest.param(np.datetime64("2020-01-01", "ns"), id="np.datetime64"),
])
def test_Database_sanitize_datetime_like_objects(time):
assert database.Database.sanitize_time(time) == "2020-01-01 00:00:00"

0 comments on commit 17f330a

Please sign in to comment.