Skip to content

Commit

Permalink
refactor: use moto pagination decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
felixscherz committed Jan 20, 2025
1 parent 04b9fb7 commit fdafeb8
Showing 1 changed file with 35 additions and 96 deletions.
131 changes: 35 additions & 96 deletions moto/s3tables/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import datetime
import re
from base64 import b64decode, b64encode
from hashlib import md5
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Union
from uuid import uuid4

from moto.core.base_backend import BackendDict, BaseBackend
Expand All @@ -22,6 +21,7 @@
TableDoesNotExist,
VersionTokenMismatch,
)
from moto.utilities.paginator import paginate
from moto.utilities.utils import get_partition

# https://docs.aws.amazon.com/AmazonS3/latest/userguide/s3-tables-buckets-naming.html
Expand Down Expand Up @@ -57,6 +57,30 @@ def _validate_table_name(name: str) -> None:
S3TABLES_DEFAULT_MAX_NAMESPACES = 1000
S3TABLES_DEFAULT_MAX_TABLES = 1000

PAGINATION_MODEL = {
"list_table_buckets": {
"input_token": "continuation_token",
"limit_key": "max_buckets",
"limit_default": 1000,
"unique_attribute": ["arn"],
"fail_on_invalid_token": InvalidContinuationToken,
},
"list_namespaces": {
"input_token": "continuation_token",
"limit_key": "max_namespaces",
"limit_default": 1000,
"unique_attribute": ["name"],
"fail_on_invalid_token": InvalidContinuationToken,
},
"list_tables": {
"input_token": "continuation_token",
"limit_key": "max_tables",
"limit_default": 1000,
"unique_attribute": ["arn"],
"fail_on_invalid_token": InvalidContinuationToken,
},
}


class Table:
def __init__(
Expand Down Expand Up @@ -168,44 +192,18 @@ def create_table_bucket(self, name: str) -> FakeTableBucket:

return new_table_bucket

@paginate(pagination_model=PAGINATION_MODEL)
def list_table_buckets(
self,
prefix: Optional[str] = None,
continuation_token: Optional[str] = None,
max_buckets: Optional[int] = None,
) -> Tuple[List[FakeTableBucket], Optional[str]]:
if not max_buckets:
max_buckets = S3TABLES_DEFAULT_MAX_BUCKETS

) -> List[FakeTableBucket]:
all_buckets = list(
bucket
for bucket in self.table_buckets.values()
if (prefix is None or bucket.name.startswith(prefix))
)

# encode bucket arn in the continuation_token together with prefix value
# raise invalidcontinuationtoken if the prefix changed
if continuation_token:
# expect continuation token to be b64encoded
token_arn, token_prefix = (
b64decode(continuation_token.encode()).decode("utf-8").split("|")
)
if token_prefix and token_prefix != prefix:
raise InvalidContinuationToken()
last_bucket_index = list(b.arn for b in all_buckets).index(token_arn)
start = last_bucket_index + 1
else:
start = 0

buckets = all_buckets[start : start + max_buckets]

next_continuation_token = None
if start + max_buckets < len(all_buckets):
next_continuation_token = b64encode(
f"{buckets[-1].arn}|{prefix if prefix else ''}".encode()
).decode()

return buckets, next_continuation_token
return all_buckets

def get_table_bucket(self, table_bucket_arn: str) -> FakeTableBucket:
bucket = self.table_buckets.get(table_bucket_arn)
Expand Down Expand Up @@ -237,49 +235,21 @@ def create_namespace(self, table_bucket_arn: str, namespace: str) -> Namespace:
bucket.namespaces[ns.name] = ns
return ns

@paginate(pagination_model=PAGINATION_MODEL)
def list_namespaces(
self,
table_bucket_arn: str,
prefix: Optional[str] = None,
continuation_token: Optional[str] = None,
max_namespaces: Optional[int] = None,
) -> Tuple[List[Namespace], Optional[str]]:
) -> List[Namespace]:
bucket = self.get_table_bucket(table_bucket_arn)

if not max_namespaces:
max_namespaces = S3TABLES_DEFAULT_MAX_NAMESPACES

all_namespaces = list(
ns
for ns in bucket.namespaces.values()
if (prefix is None or ns.name.startswith(prefix))
)

# encode bucket arn in the continuation_token together with prefix value
# raise invalidcontinuationtoken if the prefix changed
if continuation_token:
# expect continuation token to be b64encoded
ns_name, table_bucket, token_prefix = (
b64decode(continuation_token.encode()).decode("utf-8").split("|")
)
if token_prefix and token_prefix != prefix:
raise InvalidContinuationToken()
if table_bucket != table_bucket_arn:
raise InvalidContinuationToken()
last_namespace_index = list(ns.name for ns in all_namespaces).index(ns_name)
start = last_namespace_index + 1
else:
start = 0

namespaces = all_namespaces[start : start + max_namespaces]

next_continuation_token = None
if start + max_namespaces < len(all_namespaces):
next_continuation_token = b64encode(
f"{namespaces[-1].name}|{table_bucket_arn}|{prefix if prefix else ''}".encode()
).decode()
# implement here
return namespaces, next_continuation_token
return all_namespaces

def get_namespace(self, table_bucket_arn: str, namespace: str) -> Namespace:
bucket = self.table_buckets.get(table_bucket_arn)
Expand Down Expand Up @@ -338,23 +308,19 @@ def get_table(self, table_bucket_arn: str, namespace: str, name: str) -> Table:
return bucket.namespaces[namespace].tables[name]
raise TableDoesNotExist()

@paginate(pagination_model=PAGINATION_MODEL)
def list_tables(
self,
table_bucket_arn: str,
namespace: Optional[str] = None,
prefix: Optional[str] = None,
continuation_token: Optional[str] = None,
max_tables: Optional[int] = None,
) -> Tuple[List[Table], Optional[str]]:
) -> List[Table]:
bucket = self.table_buckets.get(table_bucket_arn)
if not bucket or (namespace and namespace not in bucket.namespaces):
raise NotFoundException(
"The request was rejected because the specified resource could not be found."
)

if not max_tables:
max_tables = S3TABLES_DEFAULT_MAX_TABLES

if namespace:
all_tables = list(
table
Expand All @@ -369,34 +335,7 @@ def list_tables(
if (prefix is None or table.name.startswith(prefix))
)

# encode bucket arn in the continuation_token together with prefix value
# raise invalidcontinuationtoken if the prefix changed
if continuation_token:
# expect continuation token to be b64encoded
table_name, ns_name, table_bucket, token_prefix = (
b64decode(continuation_token.encode()).decode("utf-8").split("|")
)
if token_prefix and token_prefix != prefix:
raise InvalidContinuationToken()
if table_bucket != table_bucket_arn:
raise InvalidContinuationToken()
if ns_name and ns_name != namespace:
raise InvalidContinuationToken()
last_table_index = list(table.name for table in all_tables).index(
table_name
)
start = last_table_index + 1
else:
start = 0

tables = all_tables[start : start + max_tables]

next_continuation_token = None
if start + max_tables < len(all_tables):
next_continuation_token = b64encode(
f"{tables[-1].name}|{namespace if namespace else ''}|{table_bucket_arn}|{prefix if prefix else ''}".encode()
).decode()
return tables, next_continuation_token
return all_tables

def delete_table(
self,
Expand Down

0 comments on commit fdafeb8

Please sign in to comment.