Skip to content

Commit

Permalink
🚚 Add ability to have multiple indexes with same PK in Dynamo layer (#…
Browse files Browse the repository at this point in the history
…6034)

In the past, we used to disallow having multiple indexes with the
same partition key; the reason was that we wanted to be able to
determine from any `{ "input": "dict" }` immediately which index
we should use.

And since sort keys are optional, the simplest way we could do that was
to make sure all PKs were unique: that way an index would always be
unambiguously identified.

This now blocks a use case we have, where we want to have 2 indexes with
the same partition key but with 2 different sort keys (for the purposes of
sorting the results differently).

In this PR, make it possible to have multiple indexes with the same
partition key, as long as the sort keys are different. This now
does pose a problem if you query with only a partition key and there
are two indexes that include that partition key:

```py
indexes = [
    Index('epoch', sort_key='created'), # Index 1
    Index('epoch', sort_key='username'), # Index 2
]

results = table.get_many({ "epoch": 1 })
```

Should this use index 1 (returning users sorted by timestamp) or should
this use index 2 (returning users sorted by username)?

I considered picking the first index in the array that applies because
this would be backwards compatible when adding indexes to the end of
the array... but I'm concerned that this is too implicit and could lead
to unexpected results or action-at-a-distance problems: changing the
order of the `indexes` array could all of a sudden make another part of
the code base incorrect.

So the query above is now an error. When there are multiple indexes
that could apply, you need to explicitly disambiguate between
them by putting a dummy condition on the sort key field you want
to use:

```py
results = table.get_many({
    "epoch": 1,
    "created": UseThisIndex(),
})
```

This makes it clear which index should be used, at the expense of having
to annotate all existing `get_many()` calls in the code base when you
add a conflicting index. The maintainability advantages of the extra
explicitness should hopefully win out in the long run.

**How to test**

No user visible changes, purely a feature change for @jpelay
  • Loading branch information
rix0rrr authored Dec 10, 2024
1 parent ce14c74 commit baf2815
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 28 deletions.
77 changes: 77 additions & 0 deletions tests/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,57 @@ def test_server_side_filter_may_not_filter_nonkey_attrs(self):
'sort': dynamo.BeginsWith('asd'),
})

def test_may_not_use_condition_on_table_partition_key(self):
    """A non-Equals condition on the table's own partition key must be rejected."""
    self.table.create({'id': 'key', 'sort': 'asdf'})

    bad_query = {'id': dynamo.BeginsWith('k')}
    with self.assertRaises(ValueError):
        self.table.get_many(bad_query)

def test_may_not_use_condition_on_index_partition_key(self):
    """A non-Equals condition on an index's partition key must be rejected as well."""
    table_indexes = [
        dynamo.Index('email'),
        dynamo.Index('epoch', sort_key='created'),
        dynamo.Index('id', sort_key='epoch', keys_only=True),
    ]
    self.table = dynamo.Table(
        dynamo.MemoryStorage(),
        'table',
        partition_key='id',
        sort_key='sort',
        indexes=table_indexes,
    )

    bad_query = {'id': dynamo.BeginsWith('k')}
    with self.assertRaises(ValueError):
        self.table.get_many(bad_query)

def test_can_disambiguate_between_indexes_with_same_pk(self):
    """Two indexes sharing a PK: a bare PK query errors, UseThisIndex() picks one."""
    self.table = dynamo.Table(
        dynamo.MemoryStorage(),
        'table',
        partition_key='id',
        sort_key='sort',
        indexes=[
            dynamo.Index('sort', sort_key='attr1', keys_only=True),
            dynamo.Index('sort', sort_key='attr2', keys_only=True),
        ],
    )

    self.table.create({'id': 'key', 'sort': 'asdf', 'attr1': 'asdf'})
    self.table.create({'id': 'key', 'sort': 'qwer', 'attr2': 'qwer'})

    # Querying on the shared PK alone cannot pick an index -> disambiguation error
    with self.assertRaises(ValueError):
        self.table.get_many({'sort': 'asdf'})

    # Adding a dummy condition on the desired sort key selects the index explicitly
    via_attr1 = self.table.get_many({'sort': 'asdf', 'attr1': dynamo.UseThisIndex()})
    self.assertEqual(list(via_attr1), [{'sort': 'asdf', 'attr1': 'asdf', 'id': 'key'}])

    via_attr2 = self.table.get_many({'sort': 'qwer', 'attr2': dynamo.UseThisIndex()})
    self.assertEqual(list(via_attr2), [{'sort': 'qwer', 'attr2': 'qwer', 'id': 'key'}])


class TestSortKeysAgainstAws(unittest.TestCase):
"""Test that the operations send out appropriate Dynamo requests."""
Expand Down Expand Up @@ -727,6 +778,32 @@ def test_key_with_hash_in_it(self):
ScanIndexForward=mock.ANY
)

def test_usethisindex_object(self):
    """A UseThisIndex() condition selects the index but emits no DDB condition."""
    self.table = dynamo.Table(
        dynamo.AwsDynamoStorage(self.db, ''),
        'table',
        partition_key='pk',
        sort_key='sk',
        indexes=[dynamo.Index('ik', sort_key='pk')],
    )

    self.table.get_many({'ik': '3', 'pk': dynamo.UseThisIndex()})

    # The dummy 'pk' condition must be absent from the key expression,
    # while the query is still routed to the ik/pk index.
    self.db.query.assert_called_with(
        KeyConditionExpression='#ik = :ik',
        ExpressionAttributeValues={':ik': {'S': '3'}},
        ExpressionAttributeNames={'#ik': 'ik'},
        TableName=mock.ANY,
        IndexName='ik-pk-index',
        ScanIndexForward=mock.ANY,
    )


def try_to_delete(filename):
if os.path.exists(filename):
Expand Down
178 changes: 150 additions & 28 deletions website/dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class KeySchema:
def __init__(self, partition_key, sort_key=None):
self.partition_key = partition_key
self.sort_key = sort_key
self.is_compound = sort_key is not None

# Both names in an array
self.key_names = [self.partition_key] + ([self.sort_key] if self.sort_key else [])
Expand Down Expand Up @@ -247,12 +248,16 @@ class Table:
You can use: str, list, bool, bytes, int, float, numbers.Number, dict, list,
string_set, number_set, binary_set (last 3 declared in this module).
"""
key_schema: KeySchema
storage: TableStorage
indexes: List[Index]

def __init__(self, storage: TableStorage, table_name, partition_key, types=None, sort_key=None, indexes=None):
def __init__(self, storage: TableStorage, table_name, partition_key, types=None,
sort_key=None, indexes: Optional[List[Index]] = None):
self.key_schema = KeySchema(partition_key, sort_key)
self.storage = storage
self.table_name = table_name
self.indexes = indexes or []
self.indexes: List[Index] = indexes or []
self.indexed_fields = set()
if types is not None:
self.types = Validator.ensure_all(types)
Expand All @@ -264,11 +269,9 @@ def __init__(self, storage: TableStorage, table_name, partition_key, types=None,
for field in schema.key_names:
self.indexed_fields.add(field)

# Check to make sure the indexes have unique partition keys
part_names = reverse_index((index.index_name, index.key_schema.partition_key) for index in self.indexes)
duped = [names for names in part_names.values() if len(names) > 1]
if duped:
raise ValueError(f'Table {self.table_name}: indexes with the same partition key: {duped}')
# Check to make sure the indexes have unique partition keys. We do this to unambiguously
# check which index to use for a given query.
self._validate_indexes_unambiguous()

# Check to make sure all indexed fields have a declared type
if self.types:
Expand Down Expand Up @@ -659,28 +662,64 @@ def item_count(self):
return self.storage.item_count(self.table_name)

def _determine_lookup(self, key_data, many):
    """Given the key data, determine where we should perform the lookup.

    This can be either on the main table, or on one of the indexes.

    If the key data contains both a partition key and sort key, the table or
    index is identified unambiguously since the combination of (PK, SK) must
    be unique.

    If the key data contains only a partition key, we do the following:

    - If the PK matches the PK of the table, we do a table lookup.
    - If the PK matches a single index, we do a lookup in that index.
    - If the PK matches multiple indexes, we raise an error. The lookup
      needs to be disambiguated by adding a sort key with a `UseThisIndex`
      field.

    TODO: what this makes impossible is having an index with (for example)
    PK and SK reversed; we would short-circuit to using the table
    immediately, whereas in that case we'd actually want to query an index.
    This is something we can fix later. The more appropriate algorithm
    would probably be: filter down all candidates first taking into account
    not only the name of the field but also the type of condition, THEN
    prefer the table if it is still a viable query candidate.
    """
    if any(not v for v in key_data.values()):
        raise ValueError(f"Key data cannot have empty values: {key_data}")

    # We do a regular table lookup if both the table partition and sort keys occur in the given key.
    if self.key_schema.matches(key_data):
        # Sanity check that if we expect to query 1 element from the table, we must pass a sort key if defined
        if not many and not self.key_schema.contains_both_keys(key_data):
            raise ValueError(
                f"Looking up one value, but missing sort key: {self.key_schema.sort_key} in {key_data}")

        return TableLookup(self.table_name, key_data)

    # We do an index table lookup if the partition (and possibly the sort key) of an index occur in the given key.
    potential_indexes = [index for index in self.indexes if index.key_schema.matches(key_data)]
    data_keys = tuple(key_data.keys())

    if not potential_indexes:
        schemas = [self.key_schema] + [i.key_schema for i in self.indexes]
        str_schemas = ', '.join(s.to_string(opt=True) for s in schemas)
        raise ValueError(
            f"Table {self.table_name} can be queried using one of {str_schemas}. Got {data_keys}")

    if len(potential_indexes) == 1:
        index = potential_indexes[0]
        return IndexLookup(self.table_name, index.index_name, key_data,
                           index.key_schema.sort_key, keys_only=index.keys_only, key_schema=index.key_schema)

    # More than one index. This can only happen if a user passed a PK that is used
    # in multiple indexes. Frame a helpful error message.
    sort_keys = [i.key_schema.sort_key for i in potential_indexes]
    raise ValueError(
        f'Table {self.table_name} has multiple indexes with partition key \'{data_keys[0]}\'. ' +
        f'Include one of these sort keys in your query {sort_keys} ' +
        'with a value of UseThisIndex() to indicate which index you want to query')

def _validate_key(self, key):
if not self.key_schema.contains_both_keys(key):
Expand Down Expand Up @@ -729,6 +768,34 @@ def _validate_types(self, data, full):
if not validate_value_against_validator(value, validator):
raise ValueError(f'In {data}, value of {field} should be {validator} (got {value})')

def _validate_indexes_unambiguous(self):
    """Ensure no two key schemas (table or index) could be confused with each other.

    Rejects duplicate (PK, SK) combinations, and rejects a PK-only index whose
    partition key is also the PK of a compound schema: a query on that PK alone
    could not be routed unambiguously.
    """
    seen_key_tuples = set()
    compound_pks = set()

    # Seed with the table's own schema (we need to disambiguate against the table too)
    seen_key_tuples.add(tuple(self.key_schema.key_names))
    if self.key_schema.is_compound:
        compound_pks.add(self.key_schema.partition_key)

    for ix in self.indexes:
        key_names = tuple(ix.key_schema.key_names)
        if key_names in seen_key_tuples:
            raise ValueError(f'Table {self.table_name}: multiple indexes with the same key: {key_names}')

        seen_key_tuples.add(key_names)
        if ix.key_schema.is_compound:
            compound_pks.add(ix.key_schema.partition_key)

    for ix in self.indexes:
        if not ix.key_schema.is_compound and ix.key_schema.partition_key in compound_pks:
            raise ValueError(
                f'Table {self.table_name}: PK-only index is a subset of a compound index: {ix.key_schema}')


def validate_value_against_validator(value, validator: 'Validator'):
"""Validate a value against a validator.
Expand Down Expand Up @@ -900,16 +967,24 @@ def _prep_query_data(self, key, sort_key=None, is_key_expression=True):
if is_key_expression:
validate_only_sort_key(conditions, sort_key)

escaped_names = {k: slugify(k) for k in conditions.keys()}

key_expression = " AND ".join(cond.to_dynamo_expression(escaped_names[field])
for field, cond in conditions.items())

escaped_names = {}
key_conditions = []
attr_values = {}
attr_names = {}

for field, cond in conditions.items():
attr_values.update(cond.to_dynamo_values(escaped_names[field]))
escaped_name = slugify(field)
expr = cond.to_dynamo_expression(escaped_name)
# This may return 'None' to avoid emitting this condition to DDB altogether
if expr is None:
continue

escaped_names[field] = escaped_name
key_conditions.append(expr)
attr_values.update(cond.to_dynamo_values(escaped_name))
attr_names[f'#{escaped_name}'] = field

attr_names = {f'#{e}': k for k, e in escaped_names.items()}
key_expression = " AND ".join(key_conditions)
return key_expression, attr_values, attr_names

def put(self, table_name, _key, data):
Expand Down Expand Up @@ -1410,7 +1485,13 @@ def make_conditions(key):


class Equals(DynamoCondition):
"""Assert that a value is equal to another value."""
"""Assert that a value is equal to another value.
Conditions can be applied to sort keys for efficient lookup, or as a
`server_side_filter` as a post-retrieval, pre-download filter. Queries will
never fetch more than 1MB from disk, so your server-side filter should
not filter out more than ~50% of the rows.
"""

def __init__(self, value):
self.value = value
Expand All @@ -1428,7 +1509,13 @@ def matches(self, value):


class Between(DynamoCondition):
"""Assert that a value is between two other values."""
"""Assert that a value is between two other values.
Conditions can be applied to sort keys for efficient lookup, or as a
`server_side_filter` as a post-retrieval, pre-download filter. Queries will
never fetch more than 1MB from disk, so your server-side filter should
not filter out more than ~50% of the rows.
"""

def __init__(self, minval, maxval):
self.minval = minval
Expand All @@ -1448,7 +1535,13 @@ def matches(self, value):


class BeginsWith(DynamoCondition):
"""Assert that a string begins with another string."""
"""Assert that a string begins with another string.
Conditions can be applied to sort keys for efficient lookup, or as a
`server_side_filter` as a post-retrieval, pre-download filter. Queries will
never fetch more than 1MB from disk, so your server-side filter should
not filter out more than ~50% of the rows.
"""

def __init__(self, prefix):
self.prefix = prefix
Expand All @@ -1465,6 +1558,35 @@ def matches(self, value):
return isinstance(value, str) and value.startswith(self.prefix)


class UseThisIndex(DynamoCondition):
    """A no-op condition whose only purpose is to select a particular index.

    When multiple indexes share the same partition key but have different
    sort keys, a query on just the partition key is ambiguous. Adding the
    desired sort key field with a `UseThisIndex` value picks the index
    without constraining the results:

        table.get_many({
            "pk": "some_value",
            "preferred_sortkey": UseThisIndex(),
        })
    """

    def __init__(self):
        pass

    def to_dynamo_expression(self, field_name):
        # Dynamo has no literal "true" expression, so we return None to signal
        # that this condition should be omitted from the request entirely.
        return None

    def to_dynamo_values(self, field_name):
        # Contributes no expression attribute values.
        return {}

    def matches(self, value):
        # Matches every value: this condition never filters anything.
        return True


def replace_decimals(obj):
"""
Replace Decimals with native Python values.
Expand Down Expand Up @@ -1511,7 +1633,7 @@ def validate_only_sort_key(conds, sort_key):
"""Check that non-Equals conditions are only used on the sort key."""
non_equals_fields = [k for k, v in conds.items() if not isinstance(v, Equals)]
if sort_key and set(non_equals_fields) - {sort_key}:
raise RuntimeError(f"Non-Equals conditions only allowed on sort key {sort_key}, got: {list(conds)}")
raise ValueError(f"Non-Equals conditions only allowed on sort key {sort_key}, got: {list(conds)}")


def encode_page_token(x, inverted):
Expand Down

0 comments on commit baf2815

Please sign in to comment.