From 5e1ad19281613dd2b631e2feeceac271f4dc581f Mon Sep 17 00:00:00 2001 From: oliver Date: Wed, 1 Apr 2020 10:54:56 +0100 Subject: [PATCH] Implement Array operators --- docs/querysets.md | 8 +++++- src/infi/clickhouse_orm/query.py | 47 +++++++++++++++++++++++++++++++- tests/base_test_with_data.py | 10 +++---- tests/test_database.py | 4 ++- tests/test_querysets.py | 25 +++++++++++++++++ 5 files changed, 86 insertions(+), 8 deletions(-) diff --git a/docs/querysets.md b/docs/querysets.md index 056e794..d6dfe7d 100644 --- a/docs/querysets.md +++ b/docs/querysets.md @@ -60,6 +60,12 @@ There are different operators that can be used, by passing `__>](field_types.md) \ No newline at end of file +[<< Models and Databases](models_and_databases.md) | [Table of Contents](toc.md) | [Field Types >>](field_types.md) diff --git a/src/infi/clickhouse_orm/query.py b/src/infi/clickhouse_orm/query.py index f1fb119..2189c02 100644 --- a/src/infi/clickhouse_orm/query.py +++ b/src/infi/clickhouse_orm/query.py @@ -7,11 +7,11 @@ from .engines import CollapsingMergeTree from .utils import comma_join +from . import fields # TODO # - check that field names are valid -# - operators for arrays: length, has, empty class Operator(object): """ @@ -131,6 +131,45 @@ def to_sql(self, model_cls, field_name, value): if value1 and not value0: return ' '.join([field_name, '<=', value1]) + +class FuncOperator(Operator): + """ + An operator that implements func(field, value). Use this to write + selects involving functions: + - 'SELECT * FROM table WHERE func(field, value)' + 'value' must have either the same type as 'field', or be the 'inner' type + of the field (e.g. for arrays). + """ + def __init__(self, funcname, inner=False): + self._funcname = funcname + self._inner = inner + + def to_sql(self, model_cls, field_name, value): + field = getattr(model_cls, field_name) + if self._inner: + field = field.inner_field + value = field.to_db_string(field.to_python(value, pytz.utc)) + return '%s(%s, %s)' % (self._funcname, field_name, value) + + +class FuncEqOperator(Operator): + """ + An operator that implements func(field) = value. Use this to write selects + of the form + - 'SELECT col_x FROM table WHERE func(col_y) = value' + """ + + def __init__(self, funcname, return_type): + self._funcname = funcname + self._return_type = return_type + assert isinstance(return_type, fields.Field) + + def to_sql(self, model_cls, field_name, value): + type_ = self._return_type + value = type_.to_db_string(type_.to_python(value, pytz.utc)) + return '%s(%s) = %s' % (self._funcname, field_name, value) + + # Define the set of builtin operators _operators = {} @@ -154,6 +193,12 @@ def register_operator(name, sql): register_operator('istartswith', LikeOperator('{}%', False)) register_operator('iendswith', LikeOperator('%{}', False)) register_operator('iexact', IExactOperator()) +register_operator('has', FuncOperator('has', inner=True)) +register_operator('has_any', FuncOperator('hasAny')) +register_operator('has_all', FuncOperator('hasAll')) +register_operator('length', FuncEqOperator('length', return_type=fields.UInt64Field())) +register_operator('empty', FuncEqOperator('empty', return_type=fields.UInt8Field())) +register_operator('not_empty', FuncEqOperator('notEmpty', return_type=fields.UInt8Field())) class FOV(object): diff --git a/tests/base_test_with_data.py b/tests/base_test_with_data.py index 8cbea48..a135800 100644 --- a/tests/base_test_with_data.py +++ b/tests/base_test_with_data.py @@ -39,18 +39,18 @@ class Person(Model): birthday = DateField() height = Float32Field() passport = NullableField(UInt32Field()) + addresses = ArrayField(StringField()) engine = MergeTree('birthday', ('first_name', 'last_name', 'birthday')) data = [ {"first_name": "Abdul", "last_name": "Hester", "birthday": "1970-12-02", "height": "1.63", - "passport": 35052255}, - + "passport": 35052255, "addresses": ["Elm Street", "Accacia Avenue"]}, {"first_name": "Adam", "last_name": "Goodman", "birthday": "1986-01-07", "height": "1.74", - "passport": 36052255}, - - {"first_name": "Adena", "last_name": "Norman", "birthday": "1979-05-14", "height": "1.66"}, + "passport": 36052255, "addresses": ["Elm Street"]}, + {"first_name": "Adena", "last_name": "Norman", "birthday": "1979-05-14", "height": "1.66", + "addresses": ["My House"]}, {"first_name": "Aline", "last_name": "Crane", "birthday": "1988-05-01", "height": "1.62"}, {"first_name": "Althea", "last_name": "Barrett", "birthday": "2004-07-28", "height": "1.71"}, {"first_name": "Amanda", "last_name": "Vang", "birthday": "1973-02-23", "height": "1.68"}, diff --git a/tests/test_database.py b/tests/test_database.py index 4911a49..e005c69 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -142,7 +142,9 @@ def test_raw(self): self._insert_and_check(self._sample_data(), len(data)) query = "SELECT * FROM `test-db`.person WHERE first_name = 'Whitney' ORDER BY last_name" results = self.database.raw(query) - self.assertEqual(results, "Whitney\tDurham\t1977-09-15\t1.72\t\\N\nWhitney\tScott\t1971-07-04\t1.7\t\\N\n") + self.assertEqual( + results, "Whitney\tDurham\t1977-09-15\t1.72\t\\N\t[]\nWhitney\tScott\t1971-07-04\t1.7\t\\N\t[]\n" + ) def test_invalid_user(self): with self.assertRaises(ServerError) as cm: diff --git a/tests/test_querysets.py b/tests/test_querysets.py index 4a6a17b..18764ab 100644 --- a/tests/test_querysets.py +++ b/tests/test_querysets.py @@ -70,6 +70,31 @@ def test_filter_string_field(self): self._test_qs(qs.filter(first_name__iendswith='ia'), 3) # case insensitive self._test_qs(qs.filter(first_name__iendswith=''), 100) # empty suffix + def test_filter_array(self): + qs = Person.objects_in(self.database) + self._test_qs(qs.filter(addresses__has="Elm Street"), 2) + self._test_qs(qs.filter(addresses__has="Neverland"), 0) + self._test_qs(qs.filter(addresses__has="My House"), 1) + + self._test_qs(qs.filter(addresses__has_all=["Elm Street", "Accacia Avenue"]), 1) + self._test_qs(qs.filter(addresses__has_all=["Elm Street", "Neverland"]), 0) + self._test_qs(qs.filter(addresses__has_any=["Elm Street", "Neverland"]), 2) + + total = qs.count() + self._test_qs(qs.filter(addresses__length=2), 1) + self._test_qs(qs.filter(addresses__length=1), 2) + self._test_qs(qs.filter(addresses__length=0), total - 3) + + self._test_qs(qs.filter(addresses__empty=False), 3) + self._test_qs(qs.filter(addresses__empty=True), total - 3) + self._test_qs(qs.filter(addresses__empty=0), 3) + self._test_qs(qs.filter(addresses__empty=1), total - 3) + + self._test_qs(qs.filter(addresses__not_empty=True), 3) + self._test_qs(qs.filter(addresses__not_empty=1), 3) + self._test_qs(qs.filter(addresses__not_empty=False), total - 3) + self._test_qs(qs.filter(addresses__not_empty=0), total - 3) + def test_filter_with_q_objects(self): qs = Person.objects_in(self.database) self._test_qs(qs.filter(Q(first_name='Ciaran')), 2)