diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py index 7cf89c4f3..404f4a5d7 100644 --- a/tiled/server/dependencies.py +++ b/tiled/server/dependencies.py @@ -48,7 +48,7 @@ def get_root_tree(): ) -def SecureEntry(scopes): +def SecureEntry(scopes, structure_families=None): async def inner( path: str, request: Request, @@ -116,7 +116,19 @@ async def inner( ) except NoEntry: raise HTTPException(status_code=404, detail=f"No such entry: {path_parts}") - return entry + # Fast path for the common successful case + if (structure_families is None) or ( + entry.structure_family in structure_families + ): + return entry + raise HTTPException( + status_code=404, + detail=( + f"The node at {path} has structure family {entry.structure_family} " + "and this endpoint is compatible with structure families " + f"{structure_families}" + ), + ) return Security(inner, scopes=scopes) diff --git a/tiled/server/router.py b/tiled/server/router.py index 62612e884..0d4c73a4c 100644 --- a/tiled/server/router.py +++ b/tiled/server/router.py @@ -347,7 +347,10 @@ async def metadata( ) async def array_block( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), block=Depends(block), slice=Depends(slice_), expected_shape=Depends(expected_shape), @@ -359,15 +362,7 @@ async def array_block( """ Fetch a chunk of array-like data. """ - if entry.structure_family == "array": - shape = entry.structure().shape - elif entry.structure_family == "sparse": - shape = entry.structure().shape - else: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /array/block route.", - ) + shape = entry.structure().shape # Check that block dimensionality matches array dimensionality. ndim = len(shape) if len(block) != ndim: @@ -406,10 +401,14 @@ async def array_block( "Use slicing ('?slice=...') to request smaller chunks." ), ) + if entry.structure_family == StructureFamily.union: + structure_family = entry.data_source.structure_family + else: + structure_family = entry.structure_family try: with record_timing(request.state.metrics, "pack"): return await construct_data_response( - entry.structure_family, + structure_family, serialization_registry, array, entry.metadata(), @@ -429,7 +428,10 @@ async def array_block( ) async def array_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), slice=Depends(slice_), expected_shape=Depends(expected_shape), format: Optional[str] = None, @@ -440,12 +442,10 @@ async def array_full( """ Fetch a slice of array-like data. """ - structure_family = entry.structure_family - if structure_family not in {"array", "sparse"}: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /array/full route.", - ) + if entry.structure_family == StructureFamily.union: + structure_family = entry.data_source.structure_family + else: + structure_family = entry.structure_family # Deferred import because this is not a required dependency of the server # for some use cases. import numpy @@ -453,7 +453,7 @@ async def array_full( try: with record_timing(request.state.metrics, "read"): array = await ensure_awaitable(entry.read, slice) - if structure_family == "array": + if structure_family == StructureFamily.array: array = numpy.asarray(array) # Force dask or PIMS or ... to do I/O. except IndexError: raise HTTPException(status_code=400, detail="Block index out of range") @@ -495,7 +495,7 @@ async def array_full( async def get_table_partition( request: Request, partition: int, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), column: Optional[List[str]] = Query(None, min_length=1), field: Optional[List[str]] = Query(None, min_length=1, deprecated=True), format: Optional[str] = None, @@ -543,7 +543,7 @@ async def get_table_partition( async def post_table_partition( request: Request, partition: int, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), column: Optional[List[str]] = Body(None, min_length=1), format: Optional[str] = None, filename: Optional[str] = None, @@ -578,11 +578,6 @@ async def table_partition( """ Fetch a partition (continuous block of rows) from a DataFrame. """ - if entry.structure_family != StructureFamily.table: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /table/partition route.", - ) try: # The singular/plural mismatch here of "fields" and "field" is # due to the ?field=A&field=B&field=C... encodes in a URL. @@ -626,7 +621,7 @@ async def table_partition( ) async def get_table_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), column: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, filename: Optional[str] = None, @@ -654,7 +649,7 @@ async def get_table_full( ) async def post_table_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}), column: Optional[List[str]] = Body(None, min_length=1), format: Optional[str] = None, filename: Optional[str] = None, @@ -687,11 +682,6 @@ async def table_full( """ Fetch the data for the given table. """ - if entry.structure_family != StructureFamily.table: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /table/full route.", - ) try: with record_timing(request.state.metrics, "read"): data = await ensure_awaitable(entry.read, column) @@ -707,10 +697,14 @@ async def table_full( "request a smaller chunks." ), ) + if entry.structure_family == StructureFamily.union: + structure_family = entry.data_source.structure_family + else: + structure_family = entry.structure_family try: with record_timing(request.state.metrics, "pack"): return await construct_data_response( - entry.structure_family, + structure_family, serialization_registry, data, entry.metadata(), @@ -732,7 +726,9 @@ async def table_full( ) async def get_container_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.container} + ), principal: str = Depends(get_current_principal), field: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, @@ -760,7 +756,9 @@ async def get_container_full( ) async def post_container_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.container} + ), principal: str = Depends(get_current_principal), field: Optional[List[str]] = Body(None, min_length=1), format: Optional[str] = None, @@ -793,11 +791,6 @@ async def container_full( """ Fetch the data for the given container. """ - if entry.structure_family != StructureFamily.container: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /container/full route.", - ) try: with record_timing(request.state.metrics, "read"): data = await ensure_awaitable(entry.read, fields=field) @@ -837,7 +830,10 @@ async def container_full( ) async def node_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], + structure_families={StructureFamily.table, StructureFamily.container}, + ), principal: str = Depends(get_current_principal), field: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, @@ -900,7 +896,9 @@ async def node_full( ) async def get_awkward_buffers( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.awkward} + ), form_key: Optional[List[str]] = Query(None, min_length=1), format: Optional[str] = None, filename: Optional[str] = None, @@ -936,7 +934,9 @@ async def get_awkward_buffers( async def post_awkward_buffers( request: Request, body: List[str], - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.awkward} + ), format: Optional[str] = None, filename: Optional[str] = None, serialization_registry=Depends(get_serialization_registry), @@ -974,11 +974,6 @@ async def _awkward_buffers( ): structure_family = entry.structure_family structure = entry.structure() - if structure_family != StructureFamily.awkward: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /awkward/buffers route.", - ) with record_timing(request.state.metrics, "read"): # The plural vs. singular mismatch is due to the way query parameters # are given as ?form_key=A&form_key=B&form_key=C. @@ -1019,7 +1014,9 @@ async def _awkward_buffers( ) async def awkward_full( request: Request, - entry=SecureEntry(scopes=["read:data"]), + entry=SecureEntry( + scopes=["read:data"], structure_families={StructureFamily.awkward} + ), # slice=Depends(slice_), format: Optional[str] = None, filename: Optional[str] = None, @@ -1030,11 +1027,6 @@ async def awkward_full( Fetch a slice of AwkwardArray data. """ structure_family = entry.structure_family - if structure_family != StructureFamily.awkward: - raise HTTPException( - status_code=404, - detail=f"Cannot read {entry.structure_family} structure with /awkward/full route.", - ) # Deferred import because this is not a required dependency of the server # for some use cases. import awkward @@ -1217,7 +1209,10 @@ async def bulk_delete( @router.put("/array/full/{path:path}") async def put_array_full( request: Request, - entry=SecureEntry(scopes=["write:data"]), + entry=SecureEntry( + scopes=["write:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), deserialization_registry=Depends(get_deserialization_registry), ): body = await request.body() @@ -1243,7 +1238,10 @@ async def put_array_full( @router.put("/array/block/{path:path}") async def put_array_block( request: Request, - entry=SecureEntry(scopes=["write:data"]), + entry=SecureEntry( + scopes=["write:data"], + structure_families={StructureFamily.array, StructureFamily.sparse}, + ), deserialization_registry=Depends(get_deserialization_registry), block=Depends(block), ): @@ -1312,14 +1310,12 @@ async def put_table_partition( @router.put("/awkward/full/{path:path}") async def put_awkward_full( request: Request, - entry=SecureEntry(scopes=["write:data"]), + entry=SecureEntry( + scopes=["write:data"], structure_families={StructureFamily.awkward} + ), deserialization_registry=Depends(get_deserialization_registry), ): body = await request.body() - if entry.structure_family != StructureFamily.awkward: - raise HTTPException( - status_code=404, detail="This route is not applicable to this node." - ) if not hasattr(entry, "write"): raise HTTPException(status_code=405, detail="This node cannot be written to.") media_type = request.headers["content-type"]