diff --git a/tests/test_dynamo.py b/tests/test_dynamo.py index 93ae8b35c31..2539b3c190b 100644 --- a/tests/test_dynamo.py +++ b/tests/test_dynamo.py @@ -664,6 +664,57 @@ def test_server_side_filter_may_not_filter_nonkey_attrs(self): 'sort': dynamo.BeginsWith('asd'), }) + def test_may_not_use_condition_on_table_partition_key(self): + self.table.create({'id': 'key', 'sort': 'asdf'}) + + with self.assertRaises(ValueError): + self.table.get_many({ + 'id': dynamo.BeginsWith('k'), + }) + + def test_may_not_use_condition_on_index_partition_key(self): + self.table = dynamo.Table( + dynamo.MemoryStorage(), + 'table', + partition_key='id', + sort_key='sort', + indexes=[ + dynamo.Index('email'), + dynamo.Index('epoch', sort_key='created'), + dynamo.Index('id', sort_key='epoch', keys_only=True), + ], + ) + + with self.assertRaises(ValueError): + self.table.get_many({ + 'id': dynamo.BeginsWith('k'), + }) + + def test_can_disambiguate_between_indexes_with_same_pk(self): + self.table = dynamo.Table( + dynamo.MemoryStorage(), + 'table', + partition_key='id', + sort_key='sort', + indexes=[ + dynamo.Index('sort', sort_key='attr1', keys_only=True), + dynamo.Index('sort', sort_key='attr2', keys_only=True), + ], + ) + + self.table.create({'id': 'key', 'sort': 'asdf', 'attr1': 'asdf'}) + self.table.create({'id': 'key', 'sort': 'qwer', 'attr2': 'qwer'}) + + with self.assertRaises(ValueError): + # Expect to be given a disambiguation error + self.table.get_many(dict(sort='asdf')) + + result1 = self.table.get_many(dict(sort='asdf', attr1=dynamo.UseThisIndex())) + self.assertEqual(list(result1), [{'sort': 'asdf', 'attr1': 'asdf', 'id': 'key'}]) + + result2 = self.table.get_many(dict(sort='qwer', attr2=dynamo.UseThisIndex())) + self.assertEqual(list(result2), [{'sort': 'qwer', 'attr2': 'qwer', 'id': 'key'}]) + class TestSortKeysAgainstAws(unittest.TestCase): """Test that the operations send out appropriate Dynamo requests.""" @@ -727,6 +778,32 @@ def test_key_with_hash_in_it(self): ScanIndexForward=mock.ANY ) + def test_usethisindex_object(self): + self.table = dynamo.Table( + dynamo.AwsDynamoStorage(self.db, ''), + 'table', + partition_key='pk', + sort_key='sk', + indexes=[ + dynamo.Index('ik', sort_key='pk'), + ], + ) + self.table.get_many({ + 'ik': '3', + 'pk': dynamo.UseThisIndex(), + }) + self.db.query.assert_called_with( + KeyConditionExpression='#ik = :ik', + ExpressionAttributeValues={ + ':ik': {'S': '3'}, + }, ExpressionAttributeNames={ + '#ik': 'ik', + }, + TableName=mock.ANY, + IndexName='ik-pk-index', + ScanIndexForward=mock.ANY + ) + def try_to_delete(filename): if os.path.exists(filename): diff --git a/website/dynamo.py b/website/dynamo.py index 49281525925..9a1447b9c2b 100644 --- a/website/dynamo.py +++ b/website/dynamo.py @@ -76,6 +76,7 @@ class KeySchema: def __init__(self, partition_key, sort_key=None): self.partition_key = partition_key self.sort_key = sort_key + self.is_compound = sort_key is not None # Both names in an array self.key_names = [self.partition_key] + ([self.sort_key] if self.sort_key else []) @@ -247,12 +248,16 @@ class Table: You can use: str, list, bool, bytes, int, float, numbers.Number, dict, list, string_set, number_set, binary_set (last 3 declared in this module). """ + key_schema: KeySchema + storage: TableStorage + indexes: List[Index] - def __init__(self, storage: TableStorage, table_name, partition_key, types=None, sort_key=None, indexes=None): + def __init__(self, storage: TableStorage, table_name, partition_key, types=None, + sort_key=None, indexes: Optional[List[Index]] = None): self.key_schema = KeySchema(partition_key, sort_key) self.storage = storage self.table_name = table_name - self.indexes = indexes or [] + self.indexes: List[Index] = indexes or [] self.indexed_fields = set() if types is not None: self.types = Validator.ensure_all(types) @@ -264,11 +269,9 @@ def __init__(self, storage: TableStorage, table_name, partition_key, types=None, for field in schema.key_names: self.indexed_fields.add(field) - # Check to make sure the indexes have unique partition keys - part_names = reverse_index((index.index_name, index.key_schema.partition_key) for index in self.indexes) - duped = [names for names in part_names.values() if len(names) > 1] - if duped: - raise ValueError(f'Table {self.table_name}: indexes with the same partition key: {duped}') + # Check to make sure the indexes have unique partition keys. We do this to unambiguously + # check which index to use for a given query. + self._validate_indexes_unambiguous() # Check to make sure all indexed fields have a declared type if self.types: @@ -659,28 +662,64 @@ def item_count(self): return self.storage.item_count(self.table_name) def _determine_lookup(self, key_data, many): + """Given the key data, determine where we should perform the lookup. + + This can be either on the main table, or on one of the indexes. + + If the key data contains both a partition key and sort key, the table or + index is identified unambiguously since the combination of (PK, SK) must + be unique. + + If the key data contains only a partition key, we do the following: + + - If the PK matches the PK of the table, we do a table lookup. + - If the PK matches a single index, we do a lookup in that index. + - If the PK matches multiple indexes, we raise an error. The lookup + needs to be disambiguated by adding a sort key with a `UseThisIndex` + field. + + TODO: what this makes impossible is having an index with (for example) + PK and SK reversed; we would short-circuit to using the table + immediately, whereas in that case we'd actually want to query an index. + This is something we can fix later. The more appropriate algorithm + would probably be: filter down all candidates first taking into account + not only the name of the field but also the type of condition, THEN + prefer the table if it is still a viable query candidate. + """ if any(not v for v in key_data.values()): raise ValueError(f"Key data cannot have empty values: {key_data}") # We do a regular table lookup if both the table partition and sort keys occur in the given key. if self.key_schema.matches(key_data): - # Sanity check that if we expect to query 1 element, we must pass a sort key if defined + # Sanity check that if we expect to query 1 element from the table, we must pass a sort key if defined if not many and not self.key_schema.contains_both_keys(key_data): - raise RuntimeError( + raise ValueError( f"Looking up one value, but missing sort key: {self.key_schema.sort_key} in {key_data}") return TableLookup(self.table_name, key_data) - # We do an index table lookup if the partition (and possibly the sort key) of an index occur in the given key. - for index in self.indexes: - if index.key_schema.matches(key_data): - return IndexLookup(self.table_name, index.index_name, key_data, - index.key_schema.sort_key, keys_only=index.keys_only, key_schema=index.key_schema) + potential_indexes = [index for index in self.indexes if index.key_schema.matches(key_data)] - schemas = [self.key_schema] + [i.key_schema for i in self.indexes] - str_schemas = ', '.join(s.to_string(opt=True) for s in schemas) + data_keys = tuple(key_data.keys()) + + if not potential_indexes: + schemas = [self.key_schema] + [i.key_schema for i in self.indexes] + str_schemas = ', '.join(s.to_string(opt=True) for s in schemas) + raise ValueError( + f"Table {self.table_name} can be queried using one of {str_schemas}. Got {data_keys}") + + if len(potential_indexes) == 1: + index = potential_indexes[0] + return IndexLookup(self.table_name, index.index_name, key_data, + index.key_schema.sort_key, keys_only=index.keys_only, key_schema=index.key_schema) + + # More than one index. This can only happen if a user passed a PK that is used + # in multiple indexes. Frame a helpful error message. + sort_keys = [i.key_schema.sort_key for i in potential_indexes] raise ValueError( - f"Table {self.table_name} can be queried using one of {str_schemas}. Got {tuple(key_data.keys())}") + f'Table {self.table_name} has multiple indexes with partition key \'{data_keys[0]}\'. ' + + f'Include one of these sort keys in your query {sort_keys} ' + + 'with a value of UseThisIndex() to indicate which index you want to query') def _validate_key(self, key): if not self.key_schema.contains_both_keys(key): @@ -729,6 +768,34 @@ def _validate_types(self, data, full): if not validate_value_against_validator(value, validator): raise ValueError(f'In {data}, value of {field} should be {validator} (got {value})') + def _validate_indexes_unambiguous(self): + """From a list of Index objects, make sure there are no duplicate sets of the same PK and SK. + + Also, there must not be an index with a PK that is a subset of an existing combination + of PK and SK, because we wouldn't be able to disambiguate between them. + """ + seen = set() + pk_of_compound = set() + + # Add the table schema to begin with (we need to disambiguate with the table as well) + seen.add(tuple(self.key_schema.key_names)) + if self.key_schema.is_compound: + pk_of_compound.add(self.key_schema.partition_key) + + for index in self.indexes: + key_names = tuple(index.key_schema.key_names) + if key_names in seen: + raise ValueError(f'Table {self.table_name}: multiple indexes with the same key: {key_names}') + + seen.add(key_names) + if index.key_schema.is_compound: + pk_of_compound.add(index.key_schema.partition_key) + + for index in self.indexes: + if not index.key_schema.is_compound and index.key_schema.partition_key in pk_of_compound: + raise ValueError( + f'Table {self.table_name}: PK-only index is a subset of a compound index: {index.key_schema}') + def validate_value_against_validator(value, validator: 'Validator'): """Validate a value against a validator. @@ -900,16 +967,24 @@ def _prep_query_data(self, key, sort_key=None, is_key_expression=True): if is_key_expression: validate_only_sort_key(conditions, sort_key) - escaped_names = {k: slugify(k) for k in conditions.keys()} - - key_expression = " AND ".join(cond.to_dynamo_expression(escaped_names[field]) - for field, cond in conditions.items()) - + escaped_names = {} + key_conditions = [] attr_values = {} + attr_names = {} + for field, cond in conditions.items(): - attr_values.update(cond.to_dynamo_values(escaped_names[field])) + escaped_name = slugify(field) + expr = cond.to_dynamo_expression(escaped_name) + # This may return 'None' to avoid emitting this condition to DDB altogether + if expr is None: + continue + + escaped_names[field] = escaped_name + key_conditions.append(expr) + attr_values.update(cond.to_dynamo_values(escaped_name)) + attr_names[f'#{escaped_name}'] = field - attr_names = {f'#{e}': k for k, e in escaped_names.items()} + key_expression = " AND ".join(key_conditions) return key_expression, attr_values, attr_names def put(self, table_name, _key, data): @@ -1410,7 +1485,13 @@ def make_conditions(key): class Equals(DynamoCondition): - """Assert that a value is equal to another value.""" + """Assert that a value is equal to another value. + + Conditions can be applied to sort keys for efficient lookup, or as a + `server_side_filter` as a post-retrieval, pre-download filter. Queries will + never fetch more than 1MB from disk, so your server-side filter should + not filter out more than ~50% of the rows. + """ def __init__(self, value): self.value = value @@ -1428,7 +1509,13 @@ def matches(self, value): class Between(DynamoCondition): - """Assert that a value is between two other values.""" + """Assert that a value is between two other values. + + Conditions can be applied to sort keys for efficient lookup, or as a + `server_side_filter` as a post-retrieval, pre-download filter. Queries will + never fetch more than 1MB from disk, so your server-side filter should + not filter out more than ~50% of the rows. + """ def __init__(self, minval, maxval): self.minval = minval @@ -1448,7 +1535,13 @@ def matches(self, value): class BeginsWith(DynamoCondition): - """Assert that a string begins with another string.""" + """Assert that a string begins with another string. + + Conditions can be applied to sort keys for efficient lookup, or as a + `server_side_filter` as a post-retrieval, pre-download filter. Queries will + never fetch more than 1MB from disk, so your server-side filter should + not filter out more than ~50% of the rows. + """ def __init__(self, prefix): self.prefix = prefix @@ -1465,6 +1558,35 @@ def matches(self, value): return isinstance(value, str) and value.startswith(self.prefix) +class UseThisIndex(DynamoCondition): + """A dummy condition that always matches, and allows picking a specific index. + + If you have multiple indexes on the same primary key but with a different + sort key, you need a way to disambiguate between those indexes. I.e. you add + a field with a `UseThisIndex` to indicate which sort key you want to use. + + In practice, it looks like this: + + table.get_many({ + "pk": "some_value", + "preferred_sortkey": UseThisIndex(), + }) + """ + + def __init__(self): + pass + + def to_dynamo_expression(self, field_name): + # Dynamo does not support the expression "true", so we'll have to make exceptions + return None + + def to_dynamo_values(self, field_name): + return {} + + def matches(self, value): + return True + + def replace_decimals(obj): """ Replace Decimals with native Python values. @@ -1511,7 +1633,7 @@ def validate_only_sort_key(conds, sort_key): """Check that non-Equals conditions are only used on the sort key.""" non_equals_fields = [k for k, v in conds.items() if not isinstance(v, Equals)] if sort_key and set(non_equals_fields) - {sort_key}: - raise RuntimeError(f"Non-Equals conditions only allowed on sort key {sort_key}, got: {list(conds)}") + raise ValueError(f"Non-Equals conditions only allowed on sort key {sort_key}, got: {list(conds)}") def encode_page_token(x, inverted):