diff --git a/.pyup.yml b/.pyup.yml new file mode 100644 index 0000000..d7b5957 --- /dev/null +++ b/.pyup.yml @@ -0,0 +1 @@ +schedule: "every two weeks" diff --git a/collate/collate.py b/collate/collate.py index 786babd..0f747c3 100644 --- a/collate/collate.py +++ b/collate/collate.py @@ -1,13 +1,15 @@ # -*- coding: utf-8 -*- +from .sql import execute_insert +from numbers import Number from itertools import product, chain import sqlalchemy.sql.expression as ex +import re from joblib import Parallel, delayed from .sql import make_sql_clause, to_sql_name, CreateTableAs, InsertFromSelect - def make_list(a): return [a] if not isinstance(a, list) else a @@ -16,7 +18,106 @@ def make_tuple(a): return (a,) if not isinstance(a, tuple) else a -class Aggregate(object): +DISTINCT_REGEX = re.compile(r"distinct[ (]") + + +def split_distinct(quantity): + # Only support distinct clauses with one-argument quantities + if len(quantity) != 1: + return ('', quantity) + q = quantity[0] + if DISTINCT_REGEX.match(q): + return "distinct ", (q[8:].lstrip(" "),) + else: + return "", (q,) + + +class AggregateExpression(object): + def __init__(self, aggregate1, aggregate2, operator, + cast=None, operator_str=None, expression_template=None): + """ + Args: + aggregate1: first aggregate + aggregate2: second aggregate + operator: string of SQL operator, e.g. "+" + cast: optional string to put after aggregate1, e.g. "*1.0", "::decimal" + operator_str: optional name of operator to use, defaults to operator + expression_template: optional formatting template with the following keywords: + name1, operator, name2 + """ + self.aggregate1 = aggregate1 + self.aggregate2 = aggregate2 + self.operator = operator + self.cast = cast if cast else "" + self.operator_str = operator_str if operator_str else operator + self.expression_template = expression_template \ + if expression_template else "{name1}{operator}{name2}" + + def alias(self, expression_template): + """ + Set the expression template used for naming columns of an AggregateExpression + Returns: self, for chaining + """ + self.expression_template = expression_template + return self + + def get_columns(self, when=None, prefix=None, format_kwargs=None): + if prefix is None: + prefix = "" + if format_kwargs is None: + format_kwargs = {} + + columns1 = self.aggregate1.get_columns(when) + columns2 = self.aggregate2.get_columns(when) + + for c1, c2 in product(columns1, columns2): + c = ex.literal_column("({}{} {} {})".format( + c1, self.cast, self.operator, c2)) + yield c.label(prefix + self.expression_template.format( + name1=c1.name, operator=self.operator_str, name2=c2.name, + **format_kwargs)) + + def __add__(self, other): + return AggregateExpression(self, other, "+") + + def __sub__(self, other): + return AggregateExpression(self, other, "-") + + def __mul__(self, other): + return AggregateExpression(self, other, "*") + + def __div__(self, other): + return AggregateExpression(self, other, "/", "*1.0") + + def __truediv__(self, other): + return AggregateExpression(self, other, "/", "*1.0") + + def __lt__(self, other): + return AggregateExpression(self, other, "<") + + def __le__(self, other): + return AggregateExpression(self, other, "<=") + + def __eq__(self, other): + return AggregateExpression(self, other, "=") + + def __ne__(self, other): + return AggregateExpression(self, other, "!=") + + def __gt__(self, other): + return AggregateExpression(self, other, ">") + + def __ge__(self, other): + return AggregateExpression(self, other, ">=") + + def __or__(self, other): + return 
AggregateExpression(self, other, "or", operator_str="|") + + def __and__(self, other): + return AggregateExpression(self, other, "and", operator_str="&") + + +class Aggregate(AggregateExpression): """ An object representing one or more SQL aggregate columns in a groupby """ @@ -40,7 +141,8 @@ def __init__(self, quantity, function, order=None): for the expressions and values are expressions. """ if isinstance(quantity, dict): - self.quantities = quantity + # make quantity values tuples + self.quantities = {k: make_tuple(q) for k, q in quantity.items()} else: # first convert to list of tuples quantities = [make_tuple(q) for q in make_list(quantity)] @@ -66,25 +168,32 @@ def get_columns(self, when=None, prefix=None, format_kwargs=None): format_kwargs = {} name_template = "{prefix}{quantity_name}_{function}" - column_template = "{function}({args})" + column_template = "{function}({distinct}{args}){order_clause}{filter}" arg_template = "{quantity}" order_template = "" + filter_template = "" if self.orders != [None]: - column_template += " WITHIN GROUP (ORDER BY {order_clause})" - order_template = "CASE WHEN {when} THEN {order} END" if when else "{order}" - elif when: - arg_template = "CASE WHEN {when} THEN {quantity} END" + order_template += " WITHIN GROUP (ORDER BY {order})" + if when: + filter_template = " FILTER (WHERE {when})" for function, (quantity_name, quantity), order in product( - self.functions, self.quantities.items(), self.orders): - args = str.join(", ", (arg_template.format(when=when, quantity=q) - for q in make_tuple(quantity))) - order_clause = order_template.format(when=when, order=order) + self.functions, self.quantities.items(), self.orders): + distinct, quantity = split_distinct(quantity) + args = str.join(", ", (arg_template.format(quantity=q) + for q in quantity)) + order_clause = order_template.format(order=order) + filter = filter_template.format(when=when) + + if order is not None: + if len(quantity_name) > 0: + quantity_name += '_' + quantity_name += to_sql_name(order) kwargs = dict(function=function, args=args, prefix=prefix, - order_clause=order_clause, - quantity_name=quantity_name, **format_kwargs) + distinct=distinct, order_clause=order_clause, + quantity_name=quantity_name, filter=filter, **format_kwargs) column = column_template.format(**kwargs).format(**format_kwargs) name = name_template.format(**kwargs) @@ -92,6 +201,101 @@ def get_columns(self, when=None, prefix=None, format_kwargs=None): yield ex.literal_column(column).label(to_sql_name(name)) +def maybequote(elt, quote_override=None): + "Quote for passing to SQL if necessary, based upon the python type" + def quote_string(string): + return "'{}'".format(string) + + if quote_override is None: + if isinstance(elt, Number): + return elt + else: + return quote_string(elt) + elif quote_override: + return quote_string(elt) + else: + return elt + + +class Compare(Aggregate): + """ + A simple shorthand to automatically create many comparisons against one column + """ + def __init__(self, col, op, choices, function, + order=None, include_null=False, maxlen=None, op_in_name=True, + quote_choices=None): + """ + Args: + col: the column name (or equivalent SQL expression) + op: the SQL operation (e.g., '=' or '~' or 'LIKE') + choices: A list or dictionary of values. When a dictionary is + passed, the keys are a short name for the value. + function: (from Aggregate) + order: (from Aggregate) + include_null: Add an extra `{col} is NULL` if True (default False). 
+ May also be non-boolean, in which case its truthiness determines + the behavior and the value is used as the value short name. + maxlen: The maximum length of aggregate quantity names, if specified. + Names longer than this will be truncated. + op_in_name: Include the operator in aggregate names (default True) + quote_choices: Override smart quoting if present (default None) + + A simple helper method to easily create many comparison columns from + one source column by comparing it against many values. It effectively + creates many quantities of the form "({col} {op} {elt})::INT" for elt + in choices. It automatically quotes strings appropriately and leaves + numbers unquoted. The type of the comparison is converted to an + integer so it can easily be used with 'sum' (for total count) and + 'avg' (for relative fraction) aggregate functions. + + By default, the aggregates are named "{col}_{op}_{elt}", but the + operator may be omitted if `op_in_name=False`. This name can become + long and exceed the maximum column name length. If ``maxlen`` is + specified then any aggregate name longer than ``maxlen`` gets + truncated with a number appended to ensure that they remain unique and + identifiable (but note that sequential ordering is not preserved). + """ + if type(choices) is not dict: + choices = {k: k for k in choices} + opname = '_{}_'.format(op) if op_in_name else '_' + d = {'{}{}{}'.format(col, opname, nickname): + "({} {} {})::INT".format(col, op, maybequote(choice, quote_choices)) + for nickname, choice in choices.items()} + if include_null is True: + include_null = '_NULL' + if include_null: + d['{}_{}'.format(col, include_null)] = '({} is NULL)::INT'.format(col) + if maxlen is not None and any(len(k) > maxlen for k in d.keys()): + for i, k in enumerate(list(d.keys())): + d['%s_%02d' % (k[:maxlen-3], i)] = d.pop(k) + + Aggregate.__init__(self, d, function, order) + + +class Categorical(Compare): + """ + A simple shorthand to automatically create many equality comparisons against one column + """ + def __init__(self, col, choices, function, order=None, op_in_name=False, **kwargs): + """ + Create a Compare object with an equality operator, omitting the `=` + from the generated aggregation names. See Compare for more details. + + As a special extension, Compare's 'include_null' keyword option may be + enabled by including the value `None` in the choices list. Multiple + None values are ignored. + """ + if None in choices: + kwargs['include_null'] = True + choices.remove(None) + elif type(choices) is dict and None in choices.values(): + ks = [k for k, v in choices.items() if v is None] + for k in ks: + choices.pop(k) + kwargs['include_null'] = str(k) + Compare.__init__(self, col, '=', choices, function, order, op_in_name=op_in_name, **kwargs) + + class Aggregation(object): def __init__(self, aggregates, groups, from_obj, prefix=None, suffix=None, schema=None): """ @@ -129,8 +333,8 @@ def _get_aggregates_sql(self, group): prefix = "{prefix}_{group}_".format( prefix=self.prefix, group=group) - return chain(*(a.get_columns(prefix=prefix) - for a in self.aggregates)) + return chain(*[a.get_columns(prefix=prefix) + for a in self.aggregates]) def get_selects(self): """ @@ -258,20 +462,24 @@ def get_create_schema(self): if self.schema is not None: return "CREATE SCHEMA IF NOT EXISTS %s" % self.schema - def execute(self, conn): + def execute(self, conn, join_table=None): """ Execute all SQL statements to create final aggregation table. 
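To make the Compare/Categorical shorthand above concrete, here is a usage sketch (the column and choice values are illustrative, not from this diff; note that Aggregate stores each quantity as a tuple):

```python
from collate import collate

# Compare builds one "(col op choice)::INT" quantity per choice; maybequote
# quotes the strings and leaves numbers bare, and op_in_name=True by default
# puts the operator into the generated names.
cmp_agg = collate.Compare('color', '=', ['red', 'blue'], ['sum', 'avg'])
sorted(cmp_agg.quantities)           # ['color_=_blue', 'color_=_red']
sorted(cmp_agg.quantities.values())  # [("(color = 'blue')::INT",), ("(color = 'red')::INT",)]

# Categorical drops the '=' from the names; a None choice enables include_null,
# which adds a 'color__NULL' null-indicator quantity.
cat_agg = collate.Categorical('color', ['red', 'blue', None], ['sum'])
sorted(cat_agg.quantities)           # ['color__NULL', 'color_blue', 'color_red']
```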
Args: conn: the SQLAlchemy connection on which to execute """ + self.validate(conn) + create_schema = self.get_create_schema() creates = self.get_creates() drops = self.get_drops() indexes = self.get_indexes() inserts = self.get_inserts() + drop = self.get_drop() + create = self.get_create(join_table=join_table) trans = conn.begin() - if self.schema is not None: - conn.execute(self.get_create_schema()) + if create_schema is not None: + conn.execute(create_schema) for group in self.groups: conn.execute(drops[group]) @@ -280,27 +488,10 @@ def execute(self, conn): conn.execute(insert) conn.execute(indexes[group]) - conn.execute(self.get_drop()) - conn.execute(self.get_create()) + conn.execute(drop) + conn.execute(create) trans.commit() - def execute_insert(get_engine, insert): - try: - engine = get_engine() - except: - print('Could not connect to the database within spawned process') - raise - - print("Starting parallel process") - - # transaction - with engine.begin() as conn: - conn.execute(insert) - - engine.dispose() - - return True - def execute_par(self, conn_func, n_jobs=14): """ Execute all SQL statements to create final aggregation table. @@ -328,7 +519,7 @@ def execute_par(self, conn_func, n_jobs=14): insert_list = [insert for insert in inserts[group]] - out = Parallel(n_jobs=n_jobs, verbose=51)(delayed(Aggregation.execute_insert)(conn_func, insert) + out = Parallel(n_jobs=n_jobs, verbose=51)(delayed(execute_insert)(conn_func, insert) for insert in insert_list) # transaction with engine.begin() as conn: @@ -341,264 +532,12 @@ def execute_par(self, conn_func, n_jobs=14): engine.dispose() - -class SpacetimeAggregation(Aggregation): - def __init__(self, aggregates, groups, intervals, from_obj, dates, - prefix=None, suffix=None, schema=None, date_column=None, output_date_column=None): - """ - Args: - aggregates: collection of Aggregate objects - from_obj: defines the from clause, e.g. the name of the table - groups: a list of expressions to group by in the aggregation or a dictionary - pairs group: expr pairs where group is the alias (used in column names) - intervals: the intervals to aggregate over. either a list of - datetime intervals, e.g. ["1 month", "1 year"], or - a dictionary of group : intervals pairs where - group is a group in groups and intervals is a collection - of datetime intervals, e.g. {"address_id": ["1 month", "1 year]} - dates: list of PostgreSQL date strings, - e.g. ["2012-01-01", "2013-01-01"] - prefix: prefix for column names, defaults to from_obj - suffix: suffix for aggregation table, defaults to "aggregation" - date_column: name of date column in from_obj, defaults to "date" - output_date_column: name of date column in aggregated output, defaults to "date" - - The from_obj and group arguments are passed directly to the - SQLAlchemy Select object so could be anything supported there. 
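Since execute_par() above fans the INSERT statements out to joblib workers, each worker must open its own engine; that is why it takes an engine factory (conn_func) rather than a live connection. A minimal sketch, with a placeholder DSN and job count:

```python
import sqlalchemy

def get_engine():
    # called once per spawned process; execute_insert() disposes of it afterwards
    return sqlalchemy.create_engine('postgresql://localhost/collate_test')

# st = Aggregation(...)  # constructed as in tests/test_integration.py
# st.execute_par(get_engine, n_jobs=4)
```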
- For details see: - http://docs.sqlalchemy.org/en/latest/core/selectable.html - """ - Aggregation.__init__(self, - aggregates=aggregates, - from_obj=from_obj, - groups=groups, - prefix=prefix, - suffix=suffix, - schema=schema) - - if isinstance(intervals, dict): - self.intervals = intervals - else: - self.intervals = {g: intervals for g in self.groups} - self.dates = dates - self.date_column = date_column if date_column else "date" - self.output_date_column = output_date_column if output_date_column else "date" - - def _get_aggregates_sql(self, interval, date, group): - """ - Helper for getting aggregates sql - Args: - interval: SQL time interval string, or "all" - date: SQL date string - group: group clause, for naming columns - Returns: collection of aggregate column SQL strings - """ - if interval != 'all': - when = "{date_column} >= '{date}'::date - interval '{interval}'".format( - interval=interval, date=date, date_column=self.date_column) - else: - when = None - - prefix = "{prefix}_{group}_{interval}_".format( - prefix=self.prefix, interval=interval, - group=group) - - return chain(*(a.get_columns(when, prefix, format_kwargs={"collate_date": date}) - for a in self.aggregates)) - - def get_selects(self): - """ - Constructs select queries for this aggregation - - Returns: a dictionary of group : queries pairs where - group are the same keys as groups - queries is a list of Select queries, one for each date in dates - """ - queries = {} - - for group, groupby in self.groups.items(): - intervals = self.intervals[group] - queries[group] = [] - for date in self.dates: - columns = [groupby, - ex.literal_column("'%s'::date" - % date).label(self.output_date_column)] - columns += list(chain(*(self._get_aggregates_sql( - i, date, group) for i in intervals))) - - # upper bound on date_column by date - where = ex.text("{date_column} < '{date}'".format( - date_column=self.date_column, date=date)) - - gb_clause = make_sql_clause(groupby, ex.literal_column) - query = ex.select(columns=columns, from_obj=self.from_obj) \ - .where(where) \ - .group_by(gb_clause) - - if 'all' not in intervals: - greatest = "greatest(%s)" % str.join( - ",", ["interval '%s'" % i for i in intervals]) - query = query.where(ex.text( - "{date_column} >= '{date}'::date - {greatest}".format( - date_column=self.date_column, date=date, - greatest=greatest))) - - queries[group].append(query) - - return queries - - def get_indexes(self): - """ - Generate create index queries for this aggregation - - Returns: a dictionary of group : index pairs where - group are the same keys as groups - index is a raw create index query for the corresponding table - """ - return {group: "CREATE INDEX ON %s (%s, %s);" % - (self.get_table_name(group), groupby, self.output_date_column) - for group, groupby in self.groups.items()} - - def get_create(self, join_table=None): - """ - Generate a single aggregation table creation query by joining - together the results of get_creates() - Returns: a CREATE TABLE AS query - """ - if not join_table: - join_table = '(%s) t1' % self.get_join_table() - - query = ("SELECT * FROM %s\n" - "CROSS JOIN (select unnest('{%s}'::date[]) as %s) t2\n") % ( - join_table, str.join(',', self.dates), self.output_date_column) - for group, groupby in self.groups.items(): - query += "LEFT JOIN %s USING (%s, %s)" % ( - self.get_table_name(group), groupby, self.output_date_column) - - return "CREATE TABLE %s AS (%s);" % (self.get_table_name(), query) - - -class SpacetimeSubQueryAggregation(SpacetimeAggregation): - def 
__init__(self, aggregates, groups, intervals, from_obj, dates, - prefix=None, suffix=None, schema=None, date_column=None, output_date_column=None, - sub_query=None, join_table=None): - """ - Args: - aggregates: collection of Aggregate objects - from_obj: defines the name of the sub query - groups: a list of expressions to group by in the aggregation or a dictionary - pairs group: expr pairs where group is the alias (used in column names) - intervals: the intervals to aggregate over. either a list of - datetime intervals, e.g. ["1 month", "1 year"], or - a dictionary of group : intervals pairs where - group is a group in groups and intervals is a collection - of datetime intervals, e.g. {"address_id": ["1 month", "1 year]} - dates: list of PostgreSQL date strings, - e.g. ["2012-01-01", "2013-01-01"] - prefix: prefix for column names, defaults to from_obj - suffix: suffix for aggregation table, defaults to "aggregation" - date_column: name of date column in from_obj, defaults to "date" - output_date_column: name of date column in aggregated output, defaults to "date" - join_table: specify a join table, i.e. a table containing unique sets of all possible - valid groups to left join the aggregations onto. - Defaults to None, in which case this table is created by querying the from_obj. - - The group arguments is passed directly to the - SQLAlchemy Select object so could be anything supported there. - For details see: - http://docs.sqlalchemy.org/en/latest/core/selectable.html - """ - Aggregation.__init__(self, - aggregates=aggregates, - from_obj=from_obj, - groups=groups, - prefix=prefix, - suffix=suffix, - schema=schema) - - if isinstance(intervals, dict): - self.intervals = intervals - else: - self.intervals = {g: intervals for g in self.groups} - self.dates = dates - self.date_column = date_column if date_column else "date" - self.output_date_column = output_date_column if output_date_column else "date" - self.sub_query = sub_query - self.join_table = join_table - - def get_selects(self): - """ - Constructs select queries for this aggregation using a sub query - - Returns: a dictionary of group : queries pairs where - group are the same keys as groups - queries is a list of Select queries, one for each date in dates - """ - queries = {} - - for group, groupby in self.groups.items(): - intervals = self.intervals[group] - queries[group] = [] - for date in self.dates: - # sub query - - # upper bound on date_column by date - where = ex.text("{date_column} < '{date}'".format( - date_column=self.date_column, date=date)) - - # the where clause is applied at the the sub_query as this query can make use of indices - sub_query = self.sub_query.where(where) - - if 'all' not in intervals: - greatest = "greatest(%s)" % str.join( - ",", ["interval '%s'" % i for i in intervals]) - sub_query = sub_query.where(ex.text( - "{date_column} >= '{date}'::date - {greatest}".format( - date_column=self.date_column, date=date, - greatest=greatest))) - - # name the sub query - sub_query = sub_query.alias(str(self.from_obj)) - - # main query - columns = [groupby, - ex.literal_column("'%s'::date" - % date).label(self.output_date_column)] - columns += list(chain(*(self._get_aggregates_sql( - i, date, group) for i in intervals))) - - gb_clause = make_sql_clause(groupby, ex.literal_column) - - # note: there is no where clause as the filtering is applied at the sub query level - query = ex.select(columns=columns, from_obj=sub_query) \ - .group_by(gb_clause) - - queries[group].append(query) - - return queries - - def 
get_join_table(self): + def validate(self, conn): """ - Generate a query for a join table + Validate the Aggregation to ensure that it will perform as expected. + This is done against an active SQL connection in order to enable + validation of the SQL itself. """ - if self.join_table is not None: - return '(%s) t1' % ex.Select(columns=self.groups.values(), from_obj=self.join_table) \ - .group_by(*self.groups.values()) - else: - return '(%s) t1' % ex.Select(columns=self.groups.values(), from_obj=self.from_obj) \ - .group_by(*self.groups.values()) + pass - def get_create(self): - """ - Generate a single aggregation table creation query by joining - together the results of get_creates() - Returns: a CREATE TABLE AS query - """ - query = ("SELECT * FROM %s\n" - "CROSS JOIN (select unnest('{%s}'::date[]) as %s) t2\n") % ( - self.get_join_table(), str.join(',', self.dates), self.output_date_column) - for group, groupby in self.groups.items(): - query += "LEFT JOIN %s USING (%s, %s)" % ( - self.get_table_name(group), groupby, self.output_date_column) - return "CREATE TABLE %s AS (%s);" % (self.get_table_name(), query) diff --git a/collate/spacetime.py b/collate/spacetime.py new file mode 100644 index 0000000..e79ac86 --- /dev/null +++ b/collate/spacetime.py @@ -0,0 +1,313 @@ +# -*- coding: utf-8 -*- +from itertools import chain +import sqlalchemy.sql.expression as ex + +from .sql import make_sql_clause +from .collate import Aggregation + + +class SpacetimeAggregation(Aggregation): + def __init__(self, aggregates, groups, intervals, from_obj, dates, + prefix=None, suffix=None, schema=None, date_column=None, + output_date_column=None, input_min_date=None): + """ + Args: + intervals: the intervals to aggregate over. either a list of + datetime intervals, e.g. ["1 month", "1 year"], or + a dictionary of group : intervals pairs where + group is a group in groups and intervals is a collection + of datetime intervals, e.g. {"address_id": ["1 month", "1 year]} + dates: list of PostgreSQL date strings, + e.g. 
["2012-01-01", "2013-01-01"] + date_column: name of date column in from_obj, defaults to "date" + output_date_column: name of date column in aggregated output, defaults to "date" + input_min_date: minimum date for which rows shall be included, defaults + to no absolute time restrictions on the minimum date of included rows + + For all other arguments see collate.Aggregation + """ + Aggregation.__init__(self, + aggregates=aggregates, + from_obj=from_obj, + groups=groups, + prefix=prefix, + suffix=suffix, + schema=schema) + + if isinstance(intervals, dict): + self.intervals = intervals + else: + self.intervals = {g: intervals for g in self.groups} + self.dates = dates + self.date_column = date_column if date_column else "date" + self.output_date_column = output_date_column if output_date_column else "date" + self.input_min_date = input_min_date + + def _get_aggregates_sql(self, interval, date, group): + """ + Helper for getting aggregates sql + Args: + interval: SQL time interval string, or "all" + date: SQL date string + group: group clause, for naming columns + Returns: collection of aggregate column SQL strings + """ + if interval != 'all': + when = "{date_column} >= '{date}'::date - interval '{interval}'".format( + interval=interval, date=date, date_column=self.date_column) + else: + when = None + + prefix = "{prefix}_{group}_{interval}_".format( + prefix=self.prefix, interval=interval, + group=group) + + return chain(*[a.get_columns(when, prefix, format_kwargs={"collate_date": date, + "collate_interval": interval}) + for a in self.aggregates]) + + def get_selects(self): + """ + Constructs select queries for this aggregation + + Returns: a dictionary of group : queries pairs where + group are the same keys as groups + queries is a list of Select queries, one for each date in dates + """ + queries = {} + + for group, groupby in self.groups.items(): + intervals = self.intervals[group] + queries[group] = [] + for date in self.dates: + columns = [groupby, + ex.literal_column("'%s'::date" + % date).label(self.output_date_column)] + columns += list(chain(*[self._get_aggregates_sql( + i, date, group) for i in intervals])) + + gb_clause = make_sql_clause(groupby, ex.literal_column) + query = ex.select(columns=columns, from_obj=self.from_obj)\ + .group_by(gb_clause) + query = query.where(self.where(date, intervals)) + + queries[group].append(query) + + return queries + + def where(self, date, intervals): + """ + Generates a WHERE clause + Args: + date: the end date + intervals: intervals + + Returns: a clause for filtering the from_obj to be between the date and + the greatest interval + """ + # upper bound + w = "{date_column} < '{date}'".format( + date_column=self.date_column, date=date) + + # lower bound (if possible) + if 'all' not in intervals: + greatest = "greatest(%s)" % str.join( + ",", ["interval '%s'" % i for i in intervals]) + min_date = "'{date}'::date - {greatest}".format(date=date, greatest=greatest) + w += "AND {date_column} >= {min_date}".format( + date_column=self.date_column, min_date=min_date) + if self.input_min_date is not None: + w += "AND {date_column} >= '{bot}'::date".format( + date_column=self.date_column, bot=self.input_min_date) + return ex.text(w) + + def get_indexes(self): + """ + Generate create index queries for this aggregation + + Returns: a dictionary of group : index pairs where + group are the same keys as groups + index is a raw create index query for the corresponding table + """ + return {group: "CREATE INDEX ON %s (%s, %s);" % + 
(self.get_table_name(group), groupby, self.output_date_column) + for group, groupby in self.groups.items()} + + def get_join_table(self): + """ + Generates a join table, consisting of an entry for each combination of + groups and dates in the from_obj + """ + groups = list(self.groups.values()) + intervals = list(set(chain(*self.intervals.values()))) + + queries = [] + for date in self.dates: + columns = groups + [ex.literal_column("'%s'::date" % date).label( + self.output_date_column)] + queries.append(ex.select(columns, from_obj=self.from_obj) + .where(self.where(date, intervals)) + .group_by(*groups)) + + return str.join("\nUNION ALL\n", map(str, queries)) + + def get_create(self, join_table=None): + """ + Generate a single aggregation table creation query by joining + together the results of get_creates() + Returns: a CREATE TABLE AS query + """ + if not join_table: + join_table = '(%s) t1' % self.get_join_table() + query = "SELECT * FROM %s\n" % join_table + for group, groupby in self.groups.items(): + query += " LEFT JOIN %s USING (%s, %s)" % ( + self.get_table_name(group), groupby, self.output_date_column) + + return "CREATE TABLE %s AS (%s);" % (self.get_table_name(), query) + + def validate(self, conn): + """ + SpacetimeAggregations ensure that no intervals extend beyond the absolute + minimum time. + """ + if self.input_min_date is not None: + all_intervals = set(chain(*self.intervals.values())) + for date in self.dates: + for interval in all_intervals: + if interval == "all": + continue + # This could be done more efficiently all at once, but doing + # it this way allows for nicer error messages. + r = conn.execute("select ('%s'::date - '%s'::interval) < '%s'::date" % + (date, interval, self.input_min_date)) + if r.fetchone()[0]: + raise ValueError( + "date '%s' - '%s' is before input_min_date ('%s')" % + (date, interval, self.input_min_date)) + + +class SpacetimeSubQueryAggregation(SpacetimeAggregation): + def __init__(self, aggregates, groups, intervals, from_obj, dates, + prefix=None, suffix=None, schema=None, date_column=None, output_date_column=None, + sub_query=None, join_table=None): + """ + Args: + aggregates: collection of Aggregate objects + from_obj: defines the name of the sub query + groups: a list of expressions to group by in the aggregation or a dictionary + pairs group: expr pairs where group is the alias (used in column names) + intervals: the intervals to aggregate over. either a list of + datetime intervals, e.g. ["1 month", "1 year"], or + a dictionary of group : intervals pairs where + group is a group in groups and intervals is a collection + of datetime intervals, e.g. {"address_id": ["1 month", "1 year"]} + dates: list of PostgreSQL date strings, + e.g. ["2012-01-01", "2013-01-01"] + prefix: prefix for column names, defaults to from_obj + suffix: suffix for aggregation table, defaults to "aggregation" + date_column: name of date column in from_obj, defaults to "date" + output_date_column: name of date column in aggregated output, defaults to "date" + join_table: specify a join table, i.e. a table containing unique sets of all possible + valid groups to left join the aggregations onto. + Defaults to None, in which case this table is created by querying the from_obj. + + The group argument is passed directly to the + SQLAlchemy Select object so could be anything supported there. 
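The validate() hook above fails fast when any (date, interval) pair reaches back past input_min_date, using the database itself for the date arithmetic. A hedged sketch of that behavior (table name and DSN are placeholders):

```python
import sqlalchemy
from collate.collate import Aggregate
from collate.spacetime import SpacetimeAggregation

engine = sqlalchemy.create_engine('postgresql://localhost/collate_test')

st = SpacetimeAggregation([Aggregate('outcome::int', ['sum'])],
                          from_obj='events',
                          groups=['entity_id'],
                          intervals=['1 year'],
                          dates=['2016-01-01'],
                          input_min_date='2015-06-01')

# '2016-01-01' - '1 year' = 2015-01-01, which is before input_min_date,
# so this raises ValueError; execute() calls validate() first and so
# refuses to build the tables.
st.validate(engine.connect())
```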
+ For details see: + http://docs.sqlalchemy.org/en/latest/core/selectable.html + """ + Aggregation.__init__(self, + aggregates=aggregates, + from_obj=from_obj, + groups=groups, + prefix=prefix, + suffix=suffix, + schema=schema) + + if isinstance(intervals, dict): + self.intervals = intervals + else: + self.intervals = {g: intervals for g in self.groups} + self.dates = dates + self.date_column = date_column if date_column else "date" + self.output_date_column = output_date_column if output_date_column else "date" + self.sub_query = sub_query + self.join_table = join_table + + def get_selects(self): + """ + Constructs select queries for this aggregation using a sub query + + Returns: a dictionary of group : queries pairs where + group are the same keys as groups + queries is a list of Select queries, one for each date in dates + """ + queries = {} + + for group, groupby in self.groups.items(): + intervals = self.intervals[group] + queries[group] = [] + for date in self.dates: + # sub query + + # upper bound on date_column by date + where = ex.text("{date_column} < '{date}'".format( + date_column=self.date_column, date=date)) + + # the where clause is applied at the sub_query as this query can make use of indices + sub_query = self.sub_query.where(where) + + if 'all' not in intervals: + greatest = "greatest(%s)" % str.join( + ",", ["interval '%s'" % i for i in intervals]) + sub_query = sub_query.where(ex.text( + "{date_column} >= '{date}'::date - {greatest}".format( + date_column=self.date_column, date=date, + greatest=greatest))) + + # name the sub query + sub_query = sub_query.alias(str(self.from_obj)) + + # main query + columns = [groupby, + ex.literal_column("'%s'::date" + % date).label(self.output_date_column)] + columns += list(chain(*(self._get_aggregates_sql( + i, date, group) for i in intervals))) + + gb_clause = make_sql_clause(groupby, ex.literal_column) + + # note: there is no where clause as the filtering is applied at the sub query level + query = ex.select(columns=columns, from_obj=sub_query) \ + .group_by(gb_clause) + + queries[group].append(query) + + return queries + + def get_join_table(self): + """ + Generate a query for a join table + """ + if self.join_table is not None: + return '(%s) t1' % ex.Select(columns=self.groups.values(), from_obj=self.join_table) \ + .group_by(*self.groups.values()) + else: + return '(%s) t1' % ex.Select(columns=self.groups.values(), from_obj=self.from_obj) \ + .group_by(*self.groups.values()) + + def get_create(self): + """ + Generate a single aggregation table creation query by joining + together the results of get_creates() + Returns: a CREATE TABLE AS query + """ + query = ("SELECT * FROM %s\n" + "CROSS JOIN (select unnest('{%s}'::date[]) as %s) t2\n") % ( + self.get_join_table(), str.join(',', self.dates), self.output_date_column) + for group, groupby in self.groups.items(): + query += "LEFT JOIN %s USING (%s, %s)" % ( + self.get_table_name(group), groupby, self.output_date_column) + + return "CREATE TABLE %s AS (%s);" % (self.get_table_name(), query) diff --git a/collate/sql.py b/collate/sql.py index 50d282e..432adac 100644 --- a/collate/sql.py +++ b/collate/sql.py @@ -4,6 +4,7 @@ #from sqlalchemy.sql import compiler #from psycopg2.extensions import adapt as sqlescape + def make_sql_clause(s, constructor): if not isinstance(s, ex.ClauseElement): return constructor(s) @@ -11,6 +12,24 @@ def make_sql_clause(s, constructor): return s + +def execute_insert(get_engine, insert): + try: + engine = get_engine() + except: + print('Could not 
connect to the database within spawned process') + raise + + print("Starting parallel process") + + # transaction + with engine.begin() as conn: + conn.execute(insert) + + engine.dispose() + + return True + + class CreateTableAs(ex.Executable, ex.ClauseElement): def __init__(self, name, query): diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1e5d0ef --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +SQLAlchemy==1.1.9 diff --git a/requirements_dev.txt b/requirements_dev.txt index 0f72e0d..e523e62 100644 --- a/requirements_dev.txt +++ b/requirements_dev.txt @@ -2,15 +2,16 @@ pip==9.0.1 bumpversion==0.5.3 wheel==0.29.0 watchdog==0.8.3 -flake8==3.2.1 -tox==2.5.0 -coverage==4.2 -Sphinx==1.5.1 -cryptography==1.7.1 +flake8==3.3.0 +tox==2.7.0 +coverage==4.3.4 +Sphinx==1.5.4 +cryptography==1.8.1 PyYAML==3.12 -pytest==3.0.4 -SQLAlchemy==1.1.4 -psycopg2==2.6.2 -csvkit==0.9.1 +pytest==3.0.7 +SQLAlchemy==1.1.9 +psycopg2==2.7.1 +csvkit==1.0.1 codecov==2.0.5 pytest-cov==2.4.0 +testing.postgresql==1.3.0 diff --git a/setup.py b/setup.py index 1f0acfe..d27826e 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setup( name='collate', - version='0.1.0', + version='0.2.0', description="Aggregated feature generation made easy.", long_description=readme + '\n\n' + history, author="DSaPP Researchers", diff --git a/tests/test_collate.py b/tests/test_collate.py index eca2cd4..9fbc8ae 100755 --- a/tests/test_collate.py +++ b/tests/test_collate.py @@ -18,16 +18,17 @@ def test_aggregate(): def test_aggregate_when(): agg = collate.Aggregate("1", "count") assert str(list(agg.get_columns(when="date < '2012-01-01'"))[0]) == ( - "count(CASE WHEN date < '2012-01-01' THEN 1 END)") + "count(1) FILTER (WHERE date < '2012-01-01')") def test_ordered_aggregate(): agg = collate.Aggregate("", "mode", "x") assert str(list(agg.get_columns())[0]) == "mode() WITHIN GROUP (ORDER BY x)" + assert list(agg.get_columns())[0].name == "x_mode" def test_ordered_aggregate_when(): agg = collate.Aggregate("", "mode", "x") assert str(list(agg.get_columns(when="date < '2012-01-01'"))[0]) == ( - "mode() WITHIN GROUP (ORDER BY CASE WHEN date < '2012-01-01' THEN x END)") + "mode() WITHIN GROUP (ORDER BY x) FILTER (WHERE date < '2012-01-01')") def test_aggregate_tuple_quantity(): agg = collate.Aggregate(("x","y"), "corr") @@ -36,8 +37,16 @@ def test_aggregate_tuple_quantity(): def test_aggregate_tuple_quantity_when(): agg = collate.Aggregate(("x","y"), "corr") assert str(list(agg.get_columns(when="date < '2012-01-01'"))[0]) == ( - "corr(CASE WHEN date < '2012-01-01' THEN x END, " - "CASE WHEN date < '2012-01-01' THEN y END)") + "corr(x, y) FILTER (WHERE date < '2012-01-01')") + +def test_aggregate_arithmetic(): + n = collate.Aggregate("x", "sum") + d = collate.Aggregate("1", "count") + m = collate.Aggregate("y", "avg") + + e = list((n/d + m).get_columns(prefix="prefix_"))[0] + assert str(e) == "((sum(x)*1.0 / count(1)) + avg(y))" + assert e.name == "prefix_x_sum/1_count+y_avg" def test_aggregate_format_kwargs(): agg = collate.Aggregate("'{collate_date}' - date", "min") @@ -58,3 +67,12 @@ def test_aggregation_table_name_no_schema(): assert collate.Aggregation([], from_obj='source', schema='schema', groups=[])\ .get_table_name() == '"schema"."source_aggregation"' + +def test_distinct(): + assert str(list(collate.Aggregate("distinct x", "count").get_columns())[0]) == "count(distinct x)" + + assert str(list(collate.Aggregate("distinct x", "count").get_columns(when="date < '2012-01-01'"))[0]) == "count(distinct x) 
FILTER (WHERE date < '2012-01-01')" + + assert str(list(collate.Aggregate("distinct(x)", "count").get_columns(when="date < '2012-01-01'"))[0]) == "count(distinct (x)) FILTER (WHERE date < '2012-01-01')" + + assert str(list(collate.Aggregate("distinct(x,y)", "count").get_columns(when="date < '2012-01-01'"))[0]) == "count(distinct (x,y)) FILTER (WHERE date < '2012-01-01')" diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 0000000..49fd06e --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,117 @@ +from collate import collate + +def assert_contains(haystack, needle): + for h in haystack: + if type(h) is tuple and needle in h[0]: return + elif needle in h: return + assert False + +def test_compare_lists(): + d = collate.Compare('col','=',['a','b','c'],[],include_null=True).quantities + assert len(d) == 4 + assert len(set(d.values())) == len(d) + assert len(set(d.keys())) == len(d) + assert_contains(d.values(), "col = 'a'") + assert_contains(d.values(), "col = 'b'") + assert_contains(d.values(), "col = 'c'") + assert_contains(map(lambda x: x[0].lower(), d.values()), "col is null") + + d = collate.Compare('col','>',[1,2,3],[]).quantities + assert len(d) == 3 + assert len(set(d.values())) == len(d) + assert len(set(d.keys())) == len(d) + assert_contains(d.values(), "col > 1") + assert_contains(d.values(), "col > 2") + assert_contains(d.values(), "col > 3") + + d = collate.Compare('col','=',['a','b','c'], [], include_null=False).quantities + assert len(d) == 3 + assert len(set(d.values())) == len(d) + assert len(set(d.keys())) == len(d) + assert_contains(d.values(), "col = 'a'") + assert_contains(d.values(), "col = 'b'") + assert_contains(d.values(), "col = 'c'") + + d = collate.Compare('really_long_column_name','=', + ['really long string value that is similar to others', + 'really long string value that is like others', + 'really long string value that is quite alike to others', + 'really long string value that is also like everything else'], [], maxlen=32).quantities + assert len(d) == 4 + assert len(set(d.values())) == len(d) + assert len(set(d.keys())) == len(d) + assert all(len(k) <= 32 for k in d.keys()) + assert_contains(d.values(), "really_long_column_name = 'really long string value that is similar to others'") + assert_contains(d.values(), "really_long_column_name = 'really long string value that is like others'") + assert_contains(d.values(), "really_long_column_name = 'really long string value that is quite alike to others'") + assert_contains(d.values(), "really_long_column_name = 'really long string value that is also like everything else'") + + +def test_compare_override_quoting(): + d = collate.Compare( + 'col', + '@>', + {'one': "array['one'::varchar]", 'two': "array['two'::varchar]"}, + [], + quote_choices=False + ).quantities + assert len(d) == 2 + assert_contains(d.values(), "col @> array['one'::varchar]") + assert_contains(d.values(), "col @> array['two'::varchar]") + + +def test_compare_dicts(): + d = collate.Compare('col','=',{'vala': 'a','valb': 'b','valc': 'c'}, [], include_null=True).quantities + assert len(d) == 4 + assert len(set(d.values())) == len(d) + assert len(set(d.keys())) == len(d) + assert_contains(d.values(), "col = 'a'") + assert_contains(d.values(), "col = 'b'") + assert_contains(d.values(), "col = 'c'") + assert_contains(d.keys(), 'vala') + assert_contains(d.keys(), 'valb') + assert_contains(d.keys(), 'valc') + assert_contains(map(str.lower, d.keys()), 'null') + assert_contains(map(lambda x: x[0].lower(), d.values()), 
"col is null") + + d = collate.Compare('col','<',{'val1': 1,'val2': 2,'val3': 3}, [], include_null='missing').quantities + assert len(d) == 4 + assert len(set(d.values())) == len(d) + assert len(set(d.keys())) == len(d) + assert_contains(d.values(), "col < 1") + assert_contains(d.values(), "col < 2") + assert_contains(d.values(), "col < 3") + assert_contains(map(lambda x: x[0].lower(), d.values()), "null") + assert_contains(d.keys(), 'val1') + assert_contains(d.keys(), 'val2') + assert_contains(d.keys(), 'val3') + assert_contains(d.keys(), 'missing') + + d = collate.Compare('long_column_name','=', + {'really long string key that is similar to others': 'really long string value that is similar to others', + 'really long string key that is like others': 'really long string value that is like others', + 'different key': 'really long string value that is quite alike to others', + 'ni': 'really long string value that is also like everything else'}, [], maxlen=32).quantities + assert len(d) == 4 + assert len(set(d.values())) == len(d) + assert len(set(d.keys())) == len(d) + assert all(len(k) <= 32 for k in d.keys()) + assert_contains(d.keys(), 'differ') + assert_contains(d.values(), "long_column_name = 'really long string value that is similar to others'") + assert_contains(d.values(), "long_column_name = 'really long string value that is like others'") + assert_contains(d.values(), "long_column_name = 'really long string value that is quite alike to others'") + assert_contains(d.values(), "long_column_name = 'really long string value that is also like everything else'") + +def test_categorical_same_as_compare(): + d1 = collate.Categorical('col',{'vala': 'a','valb': 'b','valc': 'c'}, []).quantities + d2 = collate.Compare('col','=',{'vala': 'a','valb': 'b','valc': 'c'}, []).quantities + assert sorted(d1.values()) == sorted(d2.values()) + d3 = collate.Categorical('col',{'vala': 'a','valb': 'b','valc': 'c'}, [], op_in_name=True).quantities + assert d2 == d3 + +def test_categorical_nones(): + d1 = collate.Categorical('col',{'vala': 'a','valb': 'b','valc': 'c','_NULL': None}, []).quantities + d2 = collate.Compare('col','=',{'vala': 'a','valb': 'b','valc': 'c'}, [], op_in_name=False, include_null=True).quantities + assert d1 == d2 + d3 = collate.Categorical('col',['a','b','c',None],[]).quantities + assert sorted(d1.values()) == sorted(d2.values()) diff --git a/tests/test_integration.py b/tests/test_integration.py index 06a2053..41f8fb4 100755 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -13,7 +13,8 @@ from os import path import sqlalchemy.sql.expression as ex -from collate import collate +from collate.collate import Aggregation, Aggregate +from collate.spacetime import SpacetimeAggregation with open(path.join(path.dirname(__file__), "config/database.yml")) as f: config = yaml.load(f) @@ -23,8 +24,9 @@ def test_engine(): assert len(engine.execute("SELECT * FROM food_inspections").fetchall()) == 966 def test_st_explicit_execute(): - agg = collate.Aggregate("results='Fail'",["count"]) - st = collate.SpacetimeAggregation([agg], + agg = Aggregate("results='Fail'",["count"]) + mode = Aggregate("", "mode", order="zip") + st = SpacetimeAggregation([agg, agg+agg, mode], from_obj = ex.table('food_inspections'), groups = {'license':ex.column('license_no'), 'zip':ex.column('zip')}, @@ -37,8 +39,8 @@ def test_st_explicit_execute(): st.execute(engine.connect()) def test_st_lazy_execute(): - agg = collate.Aggregate("results='Fail'",["count"]) - st = collate.SpacetimeAggregation([agg], + agg = 
Aggregate("results='Fail'",["count"]) + st = SpacetimeAggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip'], intervals = {'license_no':["1 year", "2 years", "all"], @@ -49,8 +51,8 @@ def test_st_lazy_execute(): st.execute(engine.connect()) def test_st_execute_broadcast_intervals(): - agg = collate.Aggregate("results='Fail'",["count"]) - st = collate.SpacetimeAggregation([agg], + agg = Aggregate("results='Fail'",["count"]) + st = SpacetimeAggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip'], intervals = ["1 year", "2 years", "all"], @@ -60,16 +62,16 @@ def test_st_execute_broadcast_intervals(): st.execute(engine.connect()) def test_execute(): - agg = collate.Aggregate("results='Fail'",["count"]) - st = collate.Aggregation([agg], + agg = Aggregate("results='Fail'",["count"]) + st = Aggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip']) st.execute(engine.connect()) def test_execute_schema_output_date_column(): - agg = collate.Aggregate("results='Fail'",["count"]) - st = collate.SpacetimeAggregation([agg], + agg = Aggregate("results='Fail'",["count"]) + st = SpacetimeAggregation([agg], from_obj = 'food_inspections', groups = ['license_no', 'zip'], intervals = {'license_no':["1 year", "2 years", "all"], diff --git a/tests/test_spacetime.py b/tests/test_spacetime.py new file mode 100755 index 0000000..266465e --- /dev/null +++ b/tests/test_spacetime.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +test_spacetime +---------------------------------- + +Unit tests for `collate.spacetime` module. +""" + +import pytest +from collate.collate import Aggregate +from collate.spacetime import SpacetimeAggregation + +import sqlalchemy +import testing.postgresql +from datetime import date + +events_data = [ + # entity id, event_date, outcome + [1, date(2014, 1, 1), True], + [1, date(2014, 11, 10), False], + [1, date(2015, 1, 1), False], + [1, date(2015, 11, 10), True], + [2, date(2013, 6, 8), True], + [2, date(2014, 6, 8), False], + [3, date(2014, 3, 3), False], + [3, date(2014, 7, 24), False], + [3, date(2015, 3, 3), True], + [3, date(2015, 7, 24), False], + [4, date(2015, 12, 13), False], + [4, date(2016, 12, 13), True], +] + + +def test_basic_spacetime(): + with testing.postgresql.Postgresql() as psql: + engine = sqlalchemy.create_engine(psql.url()) + engine.execute( + 'create table events (entity_id int, date date, outcome bool)' + ) + for event in events_data: + engine.execute( + 'insert into events values (%s, %s, %s::bool)', + event + ) + + st = SpacetimeAggregation([Aggregate('outcome::int',['sum','avg'])], + from_obj = 'events', + groups = ['entity_id'], + intervals = ['1y', '2y', 'all'], + dates = ['2016-01-01', '2015-01-01'], + date_column = '"date"') + + st.execute(engine.connect()) + + r = engine.execute('select * from events_entity_id order by entity_id, date') + rows = [x for x in r] + assert rows[0]['entity_id'] == 1 + assert rows[0]['date'] == date(2015, 1, 1) + assert rows[0]['events_entity_id_1y_outcome::int_sum'] == 1 + assert rows[0]['events_entity_id_1y_outcome::int_avg'] == 0.5 + assert rows[0]['events_entity_id_2y_outcome::int_sum'] == 1 + assert rows[0]['events_entity_id_2y_outcome::int_avg'] == 0.5 + assert rows[0]['events_entity_id_all_outcome::int_sum'] == 1 + assert rows[0]['events_entity_id_all_outcome::int_avg'] == 0.5 + assert rows[1]['entity_id'] == 1 + assert rows[1]['date'] == date(2016, 1, 1) + assert rows[1]['events_entity_id_1y_outcome::int_sum'] == 1 + assert 
rows[1]['events_entity_id_1y_outcome::int_avg'] == 0.5 + assert rows[1]['events_entity_id_2y_outcome::int_sum'] == 2 + assert rows[1]['events_entity_id_2y_outcome::int_avg'] == 0.5 + assert rows[1]['events_entity_id_all_outcome::int_sum'] == 2 + assert rows[1]['events_entity_id_all_outcome::int_avg'] == 0.5 + + assert rows[2]['entity_id'] == 2 + assert rows[2]['date'] == date(2015, 1, 1) + assert rows[2]['events_entity_id_1y_outcome::int_sum'] == 0 + assert rows[2]['events_entity_id_1y_outcome::int_avg'] == 0 + assert rows[2]['events_entity_id_2y_outcome::int_sum'] == 1 + assert rows[2]['events_entity_id_2y_outcome::int_avg'] == 0.5 + assert rows[2]['events_entity_id_all_outcome::int_sum'] == 1 + assert rows[2]['events_entity_id_all_outcome::int_avg'] == 0.5 + assert rows[3]['entity_id'] == 2 + assert rows[3]['date'] == date(2016, 1, 1) + assert rows[3]['events_entity_id_1y_outcome::int_sum'] == None + assert rows[3]['events_entity_id_1y_outcome::int_avg'] == None + assert rows[3]['events_entity_id_2y_outcome::int_sum'] == 0 + assert rows[3]['events_entity_id_2y_outcome::int_avg'] == 0 + assert rows[3]['events_entity_id_all_outcome::int_sum'] == 1 + assert rows[3]['events_entity_id_all_outcome::int_avg'] == 0.5 + + assert rows[4]['entity_id'] == 3 + assert rows[4]['date'] == date(2015, 1, 1) + assert rows[4]['events_entity_id_1y_outcome::int_sum'] == 0 + assert rows[4]['events_entity_id_1y_outcome::int_avg'] == 0 + assert rows[4]['events_entity_id_2y_outcome::int_sum'] == 0 + assert rows[4]['events_entity_id_2y_outcome::int_avg'] == 0 + assert rows[4]['events_entity_id_all_outcome::int_sum'] == 0 + assert rows[4]['events_entity_id_all_outcome::int_avg'] == 0 + assert rows[5]['entity_id'] == 3 + assert rows[5]['date'] == date(2016, 1, 1) + assert rows[5]['events_entity_id_1y_outcome::int_sum'] == 1 + assert rows[5]['events_entity_id_1y_outcome::int_avg'] == 0.5 + assert rows[5]['events_entity_id_2y_outcome::int_sum'] == 1 + assert rows[5]['events_entity_id_2y_outcome::int_avg'] == 0.25 + assert rows[5]['events_entity_id_all_outcome::int_sum'] == 1 + assert rows[5]['events_entity_id_all_outcome::int_avg'] == 0.25 + + assert rows[6]['entity_id'] == 4 + # rows[6]['date'] == date(2015, 1, 1) is skipped due to no data! 
+ assert rows[6]['date'] == date(2016, 1, 1) + assert rows[6]['events_entity_id_1y_outcome::int_sum'] == 0 + assert rows[6]['events_entity_id_1y_outcome::int_avg'] == 0 + assert rows[6]['events_entity_id_2y_outcome::int_sum'] == 0 + assert rows[6]['events_entity_id_2y_outcome::int_avg'] == 0 + assert rows[6]['events_entity_id_all_outcome::int_sum'] == 0 + assert rows[6]['events_entity_id_all_outcome::int_avg'] == 0 + assert len(rows) == 7 + +def test_input_min_date(): + with testing.postgresql.Postgresql() as psql: + engine = sqlalchemy.create_engine(psql.url()) + engine.execute( + 'create table events (entity_id int, date date, outcome bool)' + ) + for event in events_data: + engine.execute( + 'insert into events values (%s, %s, %s::bool)', + event + ) + + st = SpacetimeAggregation([Aggregate('outcome::int',['sum','avg'])], + from_obj = 'events', + groups = ['entity_id'], + intervals = ['all'], + dates = ['2016-01-01'], + date_column = '"date"', + input_min_date = '2015-11-10') + + st.execute(engine.connect()) + + r = engine.execute('select * from events_entity_id order by entity_id') + rows = [x for x in r] + + assert rows[0]['entity_id'] == 1 + assert rows[0]['date'] == date(2016, 1, 1) + assert rows[0]['events_entity_id_all_outcome::int_sum'] == 1 + assert rows[0]['events_entity_id_all_outcome::int_avg'] == 1 + assert rows[1]['entity_id'] == 4 + assert rows[1]['date'] == date(2016, 1, 1) + assert rows[1]['events_entity_id_all_outcome::int_sum'] == 0 + assert rows[1]['events_entity_id_all_outcome::int_avg'] == 0 + + assert len(rows) == 2 + + st = SpacetimeAggregation([Aggregate('outcome::int',['sum','avg'])], + from_obj = 'events', + groups = ['entity_id'], + intervals = ['1y', 'all'], + dates = ['2016-01-01', '2015-01-01'], + date_column = '"date"', + input_min_date = '2014-11-10') + with pytest.raises(ValueError): + st.validate(engine.connect()) + with pytest.raises(ValueError): + st.execute(engine.connect())
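As the assertions in these tests show, output columns follow the pattern {prefix}_{group}_{interval}_{quantity}_{function} (e.g. events_entity_id_1y_outcome::int_sum). The SELECTs behind them can also be inspected without a live database; a small sketch mirroring the tests above:

```python
from collate.collate import Aggregate
from collate.spacetime import SpacetimeAggregation

st = SpacetimeAggregation([Aggregate('outcome::int', ['sum', 'avg'])],
                          from_obj='events',
                          groups=['entity_id'],
                          intervals=['1y', 'all'],
                          dates=['2016-01-01'],
                          date_column='"date"')

# one Select per (group, date); str() compiles it through SQLAlchemy
for group, selects in st.get_selects().items():
    for select in selects:
        print(group)
        print(select)
```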