diff --git a/tiled/_tests/test_writing.py b/tiled/_tests/test_writing.py
index 83813c757..e499a1282 100644
--- a/tiled/_tests/test_writing.py
+++ b/tiled/_tests/test_writing.py
@@ -19,9 +19,11 @@
 from ..client import Context, from_context, record_history
 from ..queries import Key
 from ..server.app import build_app
-from ..structures.core import Spec
+from ..structures.array import ArrayStructure
+from ..structures.core import Spec, StructureFamily
 from ..structures.data_source import DataSource
 from ..structures.sparse import COOStructure
+from ..structures.table import TableStructure
 from ..validation_registration import ValidationRegistry
 from .utils import fail_with_status_code
 
@@ -451,3 +453,172 @@ async def test_container_export(tree):
     a.write_array([1, 2, 3], key="b")
     buffer = io.BytesIO()
     client.export(buffer, format="application/json")
+
+
+def test_union_one_table(tree):
+    with Context.from_app(build_app(tree)) as context:
+        client = from_context(context)
+        df = pandas.DataFrame({"A": [], "B": []})
+        structure = TableStructure.from_pandas(df)
+        data_source = DataSource(
+            structure_family=StructureFamily.table,
+            structure=structure,
+            name="table",
+        )
+        client.create_union([data_source], key="x")
+
+
+def test_union_two_tables(tree):
+    with Context.from_app(build_app(tree)) as context:
+        client = from_context(context)
+        df1 = pandas.DataFrame({"A": [], "B": []})
+        df2 = pandas.DataFrame({"C": [], "D": [], "E": []})
+        structure1 = TableStructure.from_pandas(df1)
+        structure2 = TableStructure.from_pandas(df2)
+        x = client.create_union(
+            [
+                DataSource(
+                    structure_family=StructureFamily.table,
+                    structure=structure1,
+                    name="table1",
+                ),
+                DataSource(
+                    structure_family=StructureFamily.table,
+                    structure=structure2,
+                    name="table2",
+                ),
+            ],
+            key="x",
+        )
+        x.parts["table1"].write(df1)
+        x.parts["table2"].write(df2)
+        x.parts["table1"].read()
+        x.parts["table2"].read()
+
+
+def test_union_two_tables_colliding_names(tree):
+    with Context.from_app(build_app(tree)) as context:
+        client = from_context(context)
+        df1 = pandas.DataFrame({"A": [], "B": []})
+        df2 = pandas.DataFrame({"C": [], "D": [], "E": []})
+        structure1 = TableStructure.from_pandas(df1)
+        structure2 = TableStructure.from_pandas(df2)
+        with fail_with_status_code(422):
+            client.create_union(
+                [
+                    DataSource(
+                        structure_family=StructureFamily.table,
+                        structure=structure1,
+                        name="table1",
+                    ),
+                    DataSource(
+                        structure_family=StructureFamily.table,
+                        structure=structure2,
+                        name="table1",  # collision
+                    ),
+                ],
+                key="x",
+            )
+
+
+def test_union_two_tables_colliding_keys(tree):
+    with Context.from_app(build_app(tree)) as context:
+        client = from_context(context)
+        df1 = pandas.DataFrame({"A": [], "B": []})
+        df2 = pandas.DataFrame({"A": [], "C": [], "D": []})
+        structure1 = TableStructure.from_pandas(df1)
+        structure2 = TableStructure.from_pandas(df2)
+        with fail_with_status_code(422):
+            client.create_union(
+                [
+                    DataSource(
+                        structure_family=StructureFamily.table,
+                        structure=structure1,
+                        name="table1",
+                    ),
+                    DataSource(
+                        structure_family=StructureFamily.table,
+                        structure=structure2,
+                        name="table2",
+                    ),
+                ],
+                key="x",
+            )
+
+
+def test_union_two_tables_two_arrays(tree):
+    with Context.from_app(build_app(tree)) as context:
+        client = from_context(context)
+        df1 = pandas.DataFrame({"A": [], "B": []})
+        df2 = pandas.DataFrame({"C": [], "D": [], "E": []})
+        arr1 = numpy.ones((5, 5), dtype=numpy.float64)
+        arr2 = 2 * numpy.ones((5, 5), dtype=numpy.int8)
+        structure1 = TableStructure.from_pandas(df1)
+        structure2 = TableStructure.from_pandas(df2)
+        structure3 = ArrayStructure.from_array(arr1)
+        structure4 = ArrayStructure.from_array(arr2)
+        x = client.create_union(
+            [
+                DataSource(
+                    structure_family=StructureFamily.table,
+                    structure=structure1,
+                    name="table1",
+                ),
+                DataSource(
+                    structure_family=StructureFamily.table,
+                    structure=structure2,
+                    name="table2",
+                ),
+                DataSource(
+                    structure_family=StructureFamily.array,
+                    structure=structure3,
+                    name="F",
+                ),
+                DataSource(
+                    structure_family=StructureFamily.array,
+                    structure=structure4,
+                    name="G",
+                ),
+            ],
+            key="x",
+        )
+        # Write by data source.
+        x.parts["table1"].write(df1)
+        x.parts["table2"].write(df2)
+        x.parts["F"].write_block(arr1, (0, 0))
+        x.parts["G"].write_block(arr2, (0, 0))
+
+        # Read by data source.
+        x.parts["table1"].read()
+        x.parts["table2"].read()
+        x.parts["F"].read()
+        x.parts["G"].read()
+
+        # Read by column.
+        for column in ["A", "B", "C", "D", "E", "F", "G"]:
+            x[column].read()
+
+
+def test_union_table_column_array_key_collision(tree):
+    with Context.from_app(build_app(tree)) as context:
+        client = from_context(context)
+        df = pandas.DataFrame({"A": [], "B": []})
+        arr = numpy.array([], dtype=numpy.float64)
+        structure1 = TableStructure.from_pandas(df)
+        structure2 = ArrayStructure.from_array(arr)
+        with fail_with_status_code(422):
+            client.create_union(
+                [
+                    DataSource(
+                        structure_family=StructureFamily.table,
+                        structure=structure1,
+                        name="table",
+                    ),
+                    DataSource(
+                        structure_family=StructureFamily.array,
+                        structure=structure2,
+                        name="B",
+                    ),
+                ],
+                key="x",
+            )
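The tests above exercise the new `create_union` API end to end. As a hedged sketch of the intended workflow (reusing `build_app` and the writable `tree` fixture from the test module; key and part names are illustrative):

```python
# Sketch only: assumes the union-enabled server above; names like
# "table1" and "C" are illustrative, not part of the API.
import numpy
import pandas

from tiled.client import Context, from_context
from tiled.server.app import build_app
from tiled.structures.array import ArrayStructure
from tiled.structures.core import StructureFamily
from tiled.structures.data_source import DataSource
from tiled.structures.table import TableStructure


def demo(tree):  # 'tree' is a writable catalog fixture, as in the tests
    df = pandas.DataFrame({"A": [1, 2], "B": [3, 4]})
    arr = numpy.ones((2, 2))
    with Context.from_app(build_app(tree)) as context:
        client = from_context(context)
        x = client.create_union(
            [
                DataSource(
                    structure_family=StructureFamily.table,
                    structure=TableStructure.from_pandas(df),
                    name="table1",
                ),
                DataSource(
                    structure_family=StructureFamily.array,
                    structure=ArrayStructure.from_array(arr),
                    name="C",
                ),
            ],
            key="x",
        )
        x.parts["table1"].write(df)  # address a whole part by name
        x.parts["C"].write_block(arr, (0, 0))
        return x["A"].read()  # or address a single key (column or array)
```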
diff --git a/tiled/adapters/parquet.py b/tiled/adapters/parquet.py
index 9c6903bff..1f872e836 100644
--- a/tiled/adapters/parquet.py
+++ b/tiled/adapters/parquet.py
@@ -74,3 +74,6 @@ def read_partition(self, *args, **kwargs):
 
     def structure(self):
         return self._structure
+
+    def get(self, key):
+        return self.dataframe_adapter.get(key)
diff --git a/tiled/adapters/table.py b/tiled/adapters/table.py
index 41b3ce742..6e2937626 100644
--- a/tiled/adapters/table.py
+++ b/tiled/adapters/table.py
@@ -80,6 +80,11 @@ def __getitem__(self, key):
         # Must compute to determine shape.
         return ArrayAdapter.from_array(self.read([key])[key].values)
 
+    def get(self, key):
+        if key not in self.structure().columns:
+            return None
+        return ArrayAdapter.from_array(self.read([key])[key].values)
+
     def items(self):
         yield from (
             (key, ArrayAdapter.from_array(self.read([key])[key].values))
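Both adapters now expose a dict-style `get` that returns `None` for an unknown key instead of raising, which lets the catalog's lookup code probe for a column without a try/except. A minimal, standalone illustration of that contract (not the real adapter classes):

```python
# Standalone illustration of the .get() contract; ExampleAdapter is
# hypothetical and only mirrors the None-for-missing behavior above.
class ExampleAdapter:
    def __init__(self, columns):
        self._columns = set(columns)

    def __getitem__(self, key):
        if key not in self._columns:
            raise KeyError(key)
        return f"adapter for {key}"

    def get(self, key):
        # Like dict.get: a miss is a value, not an exception.
        if key not in self._columns:
            return None
        return self[key]


adapter = ExampleAdapter(["A", "B"])
assert adapter.get("A") == "adapter for A"
assert adapter.get("missing") is None
```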
diff --git a/tiled/catalog/adapter.py b/tiled/catalog/adapter.py
index 9fc01cc28..358913140 100644
--- a/tiled/catalog/adapter.py
+++ b/tiled/catalog/adapter.py
@@ -42,6 +42,8 @@
     ZARR_MIMETYPE,
 )
 from ..query_registration import QueryTranslationRegistry
+from ..server.pydantic_container import ContainerStructure
+from ..server.pydantic_union import UnionStructure, UnionStructurePart
 from ..server.schemas import Asset, DataSource, Management, Revision, Spec
 from ..structures.core import StructureFamily
 from ..utils import (
@@ -257,6 +259,8 @@ def __init__(
         context,
         node,
         *,
+        structure_family=None,
+        data_sources=None,
         conditions=None,
         queries=None,
         sorting=None,
@@ -274,13 +278,18 @@
         self.order_by_clauses = order_by_clauses(self.sorting)
         self.conditions = conditions or []
         self.queries = queries or []
-        self.structure_family = node.structure_family
         self.specs = [Spec.parse_obj(spec) for spec in node.specs]
         self.ancestors = node.ancestors
         self.key = node.key
         self.access_policy = access_policy
        self.startup_tasks = [self.startup]
         self.shutdown_tasks = [self.shutdown]
+        self.structure_family = structure_family or node.structure_family
+        if data_sources is None:
+            data_sources = [
+                DataSource.from_orm(ds) for ds in (self.node.data_sources or [])
+            ]
+        self.data_sources = data_sources
 
     def metadata(self):
         return self.node.metadata_
@@ -319,10 +328,6 @@ async def __aiter__(self):
         async with self.context.session() as db:
             return (await db.execute(statement)).scalar().all()
 
-    @property
-    def data_sources(self):
-        return [DataSource.from_orm(ds) for ds in self.node.data_sources or []]
-
     async def asset_by_id(self, asset_id):
         statement = (
             select(orm.Asset)
@@ -344,6 +349,25 @@ async def asset_by_id(self, asset_id):
         return Asset.from_orm(asset)
 
     def structure(self):
+        if self.structure_family == StructureFamily.container:
+            # Give no inlined contents.
+            return ContainerStructure(contents=None, count=None)
+        if self.structure_family == StructureFamily.union:
+            parts = []
+            all_keys = []
+            for data_source in self.data_sources:
+                parts.append(
+                    UnionStructurePart(
+                        structure=data_source.structure,
+                        structure_family=data_source.structure_family,
+                        name=data_source.name,
+                    )
+                )
+                if data_source.structure_family == StructureFamily.table:
+                    all_keys.extend(data_source.structure.columns)
+                else:
+                    all_keys.append(data_source.name)
+            return UnionStructure(parts=parts, all_keys=all_keys)
         if self.data_sources:
             assert len(self.data_sources) == 1  # more not yet implemented
             return self.data_sources[0].structure
@@ -359,7 +383,8 @@ async def async_len(self):
         return (await db.execute(statement)).scalar_one()
 
     async def lookup_adapter(
-        self, segments
+        self,
+        segments,
     ):  # TODO: Accept filter for predicate-pushdown.
         if not segments:
             return self
@@ -399,6 +424,13 @@
 
         for i in range(len(segments)):
             catalog_adapter = await self.lookup_adapter(segments[:i])
+            if (catalog_adapter.structure_family == StructureFamily.union) and len(
+                segments[i:]
+            ) == 1:
+                # All the segments up to the final segment, segments[-1], resolve
+                # to a union structure. Dispatch to the union Adapter to get the
+                # inner Adapter for whatever type of structure it is.
+                return await ensure_awaitable(catalog_adapter.get, segments[-1])
             if catalog_adapter.data_sources:
                 adapter = await catalog_adapter.get_adapter()
                 for segment in segments[i:]:
@@ -408,66 +440,60 @@
                     return adapter
             return None
         return STRUCTURES[node.structure_family](
-            self.context, node, access_policy=self.access_policy
+            self.context,
+            node,
+            access_policy=self.access_policy,
         )
 
     async def get_adapter(self):
-        num_data_sources = len(self.data_sources)
-        if num_data_sources > 1:
-            raise NotImplementedError
-        if num_data_sources == 1:
-            (data_source,) = self.data_sources
-            try:
-                adapter_factory = self.context.adapters_by_mimetype[
-                    data_source.mimetype
-                ]
-            except KeyError:
-                raise RuntimeError(
-                    f"Server configuration has no adapter for mimetype {data_source.mimetype!r}"
+        (data_source,) = self.data_sources
+        try:
+            adapter_factory = self.context.adapters_by_mimetype[data_source.mimetype]
+        except KeyError:
+            raise RuntimeError(
+                f"Server configuration has no adapter for mimetype {data_source.mimetype!r}"
+            )
+        parameters = collections.defaultdict(list)
+        for asset in data_source.assets:
+            if asset.parameter is None:
+                continue
+            scheme = urlparse(asset.data_uri).scheme
+            if scheme != "file":
+                raise NotImplementedError(
+                    f"Only 'file://...' scheme URLs are currently supported, not {asset.data_uri}"
                 )
-            parameters = collections.defaultdict(list)
-            for asset in data_source.assets:
-                if asset.parameter is None:
-                    continue
-                scheme = urlparse(asset.data_uri).scheme
-                if scheme != "file":
-                    raise NotImplementedError(
-                        f"Only 'file://...' scheme URLs are currently supported, not {asset.data_uri}"
-                    )
-                if scheme == "file":
-                    # Protect against misbehaving clients reading from unintended
-                    # parts of the filesystem.
-                    asset_path = path_from_uri(asset.data_uri)
-                    for readable_storage in self.context.readable_storage:
-                        if Path(
-                            os.path.commonpath(
-                                [path_from_uri(readable_storage), asset_path]
-                            )
-                        ) == path_from_uri(readable_storage):
-                            break
-                    else:
-                        raise RuntimeError(
-                            f"Refusing to serve {asset.data_uri} because it is outside "
-                            "the readable storage area for this server."
-                        )
-                    if asset.num is None:
-                        parameters[asset.parameter] = asset.data_uri
-                    else:
-                        parameters[asset.parameter].append(asset.data_uri)
-            adapter_kwargs = dict(parameters)
-            adapter_kwargs.update(data_source.parameters)
-            adapter_kwargs["specs"] = self.node.specs
-            adapter_kwargs["metadata"] = self.node.metadata_
-            adapter_kwargs["structure"] = data_source.structure
-            adapter_kwargs["access_policy"] = self.access_policy
-            adapter = await anyio.to_thread.run_sync(
-                partial(adapter_factory, **adapter_kwargs)
-            )
-            for query in self.queries:
-                adapter = adapter.search(query)
-            return adapter
-        else:  # num_data_sources == 0
-            assert False
+            if scheme == "file":
+                # Protect against misbehaving clients reading from unintended
+                # parts of the filesystem.
+                asset_path = path_from_uri(asset.data_uri)
+                for readable_storage in self.context.readable_storage:
+                    if Path(
+                        os.path.commonpath(
+                            [path_from_uri(readable_storage), asset_path]
+                        )
+                    ) == path_from_uri(readable_storage):
+                        break
+                else:
+                    raise RuntimeError(
+                        f"Refusing to serve {asset.data_uri} because it is outside "
+                        "the readable storage area for this server."
+                    )
+            if asset.num is None:
+                parameters[asset.parameter] = asset.data_uri
+            else:
+                parameters[asset.parameter].append(asset.data_uri)
+        adapter_kwargs = dict(parameters)
+        adapter_kwargs.update(data_source.parameters)
+        adapter_kwargs["specs"] = self.node.specs
+        adapter_kwargs["metadata"] = self.node.metadata_
+        adapter_kwargs["structure"] = data_source.structure
+        adapter_kwargs["access_policy"] = self.access_policy
+        adapter = await anyio.to_thread.run_sync(
+            partial(adapter_factory, **adapter_kwargs)
+        )
+        for query in self.queries:
+            adapter = adapter.search(query)
+        return adapter
 
     def new_variation(
         self,
@@ -597,12 +623,17 @@ async def create_node(
             if data_source.management != Management.external:
                 if structure_family == StructureFamily.container:
                     raise NotImplementedError(structure_family)
-                data_source.mimetype = DEFAULT_CREATION_MIMETYPE[structure_family]
+                data_source.mimetype = DEFAULT_CREATION_MIMETYPE[
+                    data_source.structure_family
+                ]
                 data_source.parameters = {}
+                data_uri_path_parts = self.segments + [key]
+                if structure_family == StructureFamily.union:
+                    data_uri_path_parts.append(data_source.name)
                 data_uri = str(self.context.writable_storage) + "".join(
-                    f"/{quote_plus(segment)}" for segment in (self.segments + [key])
+                    f"/{quote_plus(segment)}" for segment in data_uri_path_parts
                 )
-                init_storage = DEFAULT_INIT_STORAGE[structure_family]
+                init_storage = DEFAULT_INIT_STORAGE[data_source.structure_family]
                 assets = await ensure_awaitable(
                     init_storage, data_uri, data_source.structure
                 )
@@ -622,7 +653,7 @@
                 # Obtain and hash the canonical (RFC 8785) representation of
                 # the JSON structure.
                 structure = _prepare_structure(
-                    structure_family, data_source.structure
+                    data_source.structure_family, data_source.structure
                 )
                 structure_id = compute_structure_id(structure)
                 # The only way to do "insert if does not exist" i.e. ON CONFLICT
@@ -642,6 +673,7 @@
                 await db.execute(statement)
             data_source_orm = orm.DataSource(
                 structure_family=data_source.structure_family,
+                name=data_source.name,
                 mimetype=data_source.mimetype,
                 management=data_source.management,
                 parameters=data_source.parameters,
@@ -952,6 +984,9 @@ class CatalogSparseAdapter(CatalogArrayAdapter):
 
 
 class CatalogTableAdapter(CatalogNodeAdapter):
+    async def get(self, *args, **kwargs):
+        return (await self.get_adapter()).get(*args, **kwargs)
+
     async def read(self, *args, **kwargs):
         return await ensure_awaitable((await self.get_adapter()).read, *args, **kwargs)
 
@@ -969,6 +1004,34 @@ async def write_partition(self, *args, **kwargs):
         )
 
 
+class CatalogUnionAdapter(CatalogNodeAdapter):
+    async def get(self, key):
+        if key not in self.structure().all_keys:
+            return None
+        for data_source in self.data_sources:
+            if data_source.structure_family == StructureFamily.table:
+                if key in data_source.structure.columns:
+                    return await ensure_awaitable(
+                        self.for_part(data_source.name).get, key
+                    )
+            if key == data_source.name:
+                return self.for_part(data_source.name)
+
+    def for_part(self, name):
+        for data_source in self.data_sources:
+            if name == data_source.name:
+                break
+        else:
+            raise ValueError(f"No DataSource named {name} on this node")
+        return STRUCTURES[data_source.structure_family](
+            self.context,
+            self.node,
+            access_policy=self.access_policy,
+            structure_family=data_source.structure_family,
+            data_sources=[data_source],
+        )
+
+
 def delete_asset(data_uri, is_directory):
     url = urlparse(data_uri)
     if url.scheme == "file":
@@ -1307,9 +1370,10 @@ def specs_array_to_json(specs):
 
 
 STRUCTURES = {
-    StructureFamily.container: CatalogContainerAdapter,
     StructureFamily.array: CatalogArrayAdapter,
     StructureFamily.awkward: CatalogAwkwardAdapter,
-    StructureFamily.table: CatalogTableAdapter,
+    StructureFamily.container: CatalogContainerAdapter,
     StructureFamily.sparse: CatalogSparseAdapter,
+    StructureFamily.table: CatalogTableAdapter,
+    StructureFamily.union: CatalogUnionAdapter,
 }
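The union adapter resolves a trailing path segment in two ways: a table column resolves through the named table part, while any other key must match a data source name. A simplified, synchronous sketch of that rule (stand-in types, not the real adapters):

```python
# Simplified sketch of CatalogUnionAdapter.get's resolution rule above;
# Part is a hypothetical stand-in for DataSource.
from dataclasses import dataclass


@dataclass
class Part:
    name: str
    structure_family: str
    columns: tuple = ()


def resolve(parts, key):
    """Return (part_name, column) for a key, or None if it cannot resolve."""
    for part in parts:
        if part.structure_family == "table" and key in part.columns:
            return (part.name, key)  # e.g. column "A" of part "table1"
        if key == part.name:
            return (part.name, None)  # e.g. array part "F"
    return None


parts = [Part("table1", "table", ("A", "B")), Part("F", "array")]
assert resolve(parts, "A") == ("table1", "A")
assert resolve(parts, "F") == ("F", None)
assert resolve(parts, "missing") is None
```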
diff --git a/tiled/catalog/core.py b/tiled/catalog/core.py
index 6c9869e7d..e085ef4f5 100644
--- a/tiled/catalog/core.py
+++ b/tiled/catalog/core.py
@@ -5,6 +5,8 @@
 
 # This is list of all valid revisions (from current to oldest).
 ALL_REVISIONS = [
+    "0dc110294112",
+    "7c8130c40b8f",
     "e756b9381c14",
     "2ca16566d692",
     "1cd99c02d0c7",
diff --git a/tiled/catalog/migrations/versions/0dc110294112_add_union_to_structure_family_enum.py b/tiled/catalog/migrations/versions/0dc110294112_add_union_to_structure_family_enum.py
new file mode 100644
index 000000000..bf83c8baa
--- /dev/null
+++ b/tiled/catalog/migrations/versions/0dc110294112_add_union_to_structure_family_enum.py
@@ -0,0 +1,33 @@
+"""Add 'union' to structure_family enum.
+
+Revision ID: 0dc110294112
+Revises: 7c8130c40b8f
+Create Date: 2024-02-23 09:13:23.658921
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "0dc110294112"
+down_revision = "7c8130c40b8f"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    connection = op.get_bind()
+
+    if connection.engine.dialect.name == "postgresql":
+        with op.get_context().autocommit_block():
+            op.execute(
+                sa.text(
+                    "ALTER TYPE structurefamily ADD VALUE IF NOT EXISTS 'union' AFTER 'table'"
+                )
+            )
+
+
+def downgrade():
+    # This _could_ be implemented but we will wait for a need since we are
+    # still in alpha releases.
+    raise NotImplementedError
diff --git a/tiled/catalog/migrations/versions/7c8130c40b8f_add_name_column_to_data_sources_table.py b/tiled/catalog/migrations/versions/7c8130c40b8f_add_name_column_to_data_sources_table.py
new file mode 100644
index 000000000..6bc1467be
--- /dev/null
+++ b/tiled/catalog/migrations/versions/7c8130c40b8f_add_name_column_to_data_sources_table.py
@@ -0,0 +1,25 @@
+"""Add 'name' column to data_sources table.
+
+Revision ID: 7c8130c40b8f
+Revises: e756b9381c14
+Create Date: 2024-02-23 08:53:24.008576
+
+"""
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "7c8130c40b8f"
+down_revision = "e756b9381c14"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.add_column("data_sources", sa.Column("name", sa.Unicode(1023), nullable=True))
+
+
+def downgrade():
+    # This _could_ be implemented but we will wait for a need since we are
+    # still in alpha releases.
+    raise NotImplementedError
diff --git a/tiled/catalog/orm.py b/tiled/catalog/orm.py
index f2ec818c0..c29eaf286 100644
--- a/tiled/catalog/orm.py
+++ b/tiled/catalog/orm.py
@@ -291,7 +291,6 @@ class DataSource(Timestamped, Base):
     node_id = Column(
         Integer, ForeignKey("nodes.id", ondelete="CASCADE"), nullable=False
     )
-    structure_family = Column(Enum(StructureFamily), nullable=False)
     structure_id = Column(
         Unicode(32), ForeignKey("structures.id", ondelete="CASCADE"), nullable=True
     )
@@ -301,6 +300,10 @@
     parameters = Column(JSONVariant, nullable=True)
     # This relates to the mutability of the data.
     management = Column(Enum(Management), nullable=False)
+    structure_family = Column(Enum(StructureFamily), nullable=False)
+    # This is used by `union` structures to address arrays.
+    # It may have additional uses in the future.
+    name = Column(Unicode(1023), nullable=True)
 
     # many-to-one relationship to Structure
     structure: Mapped["Structure"] = relationship(
diff --git a/tiled/client/base.py b/tiled/client/base.py
index 80a1550d7..ffe3daab8 100644
--- a/tiled/client/base.py
+++ b/tiled/client/base.py
@@ -109,13 +109,10 @@ def __init__(
         self._metadata_revisions = None
         self._include_data_sources = include_data_sources
         attributes = self.item["attributes"]
-        structure_family = attributes["structure_family"]
         if structure is not None:
             # Allow the caller to optionally hand us a structure that is already
             # parsed from a dict into a structure dataclass.
             self._structure = structure
-        elif structure_family == StructureFamily.container:
-            self._structure = None
         else:
             structure_type = STRUCTURE_TYPES[attributes["structure_family"]]
             self._structure = structure_type.from_json(attributes["structure"])
@@ -435,11 +432,17 @@ def delete_tree(self):
         StructureFamily.awkward: lambda: importlib.import_module(
             "...structures.awkward", BaseClient.__module__
         ).AwkwardStructure,
-        StructureFamily.table: lambda: importlib.import_module(
-            "...structures.table", BaseClient.__module__
-        ).TableStructure,
+        StructureFamily.container: lambda: importlib.import_module(
+            "...structures.container", BaseClient.__module__
+        ).ContainerStructure,
         StructureFamily.sparse: lambda: importlib.import_module(
             "...structures.sparse", BaseClient.__module__
         ).SparseStructure,
+        StructureFamily.table: lambda: importlib.import_module(
+            "...structures.table", BaseClient.__module__
+        ).TableStructure,
+        StructureFamily.union: lambda: importlib.import_module(
+            "...structures.union", BaseClient.__module__
+        ).UnionStructure,
     }
 )
diff --git a/tiled/client/constructors.py b/tiled/client/constructors.py
index b9b7ee385..8a035c89d 100644
--- a/tiled/client/constructors.py
+++ b/tiled/client/constructors.py
@@ -150,11 +150,14 @@ def from_context(
         and (context.http_client.auth is None)
     ):
         context.authenticate()
+    params = {}
+    if include_data_sources:
+        params["include_data_sources"] = True
     content = handle_error(
         context.http_client.get(
             item_uri,
             headers={"Accept": MSGPACK_MIME_TYPE},
-            params={"include_data_sources": include_data_sources},
+            params=params,
         )
     ).json()
     else:
diff --git a/tiled/client/container.py b/tiled/client/container.py
index 2f9d99a0a..87ad53113 100644
--- a/tiled/client/container.py
+++ b/tiled/client/container.py
@@ -161,9 +161,8 @@ def __len__(self):
         # If the contents of this node was provided in-line, there is an
         # implication that the contents are not expected to be dynamic. Used the
         # count provided in the structure.
-        structure = self.item["attributes"]["structure"]
-        if structure["contents"]:
-            return structure["count"]
+        if self.structure().count is not None:
+            return self.structure().count
         now = time.monotonic()
         if self._cached_len is not None:
             length, deadline = self._cached_len
@@ -194,14 +193,15 @@ def __iter__(self, _ignore_inlined_contents=False):
         # If the contents of this node was provided in-line, and we don't need
         # to apply any filtering or sorting, we can slice the in-lined data
         # without fetching anything from the server.
-        contents = self.item["attributes"]["structure"]["contents"]
+        structure = self.structure()
         if (
-            (contents is not None)
+            structure
+            and structure.contents
             and (not self._queries)
             and ((not self.sorting) or (self.sorting == [("_", 1)]))
             and (not _ignore_inlined_contents)
         ):
-            return (yield from contents)
+            return (yield from structure.contents)
         next_page_url = self.item["links"]["search"]
         while next_page_url is not None:
             content = handle_error(
@@ -249,16 +249,18 @@ def __getitem__(self, keys, _ignore_inlined_contents=False):
             # Lookup this key *within the search results* of this Node.
             key, *tail = keys
             tail = tuple(tail)  # list -> tuple
+            params = {
+                **_queries_to_params(KeyLookup(key)),
+                **self._queries_as_params,
+                **self._sorting_params,
+            }
+            if self._include_data_sources:
+                params["include_data_sources"] = True
             content = handle_error(
                 self.context.http_client.get(
                     self.item["links"]["search"],
                     headers={"Accept": MSGPACK_MIME_TYPE},
-                    params={
-                        "include_data_sources": self._include_data_sources,
-                        **_queries_to_params(KeyLookup(key)),
-                        **self._queries_as_params,
-                        **self._sorting_params,
-                    },
+                    params=params,
                 )
             ).json()
             self._cached_len = (
@@ -296,7 +298,8 @@
             # to the node of interest without downloading information about
             # intermediate parents.
             for i, key in enumerate(keys):
-                item = (self.item["attributes"]["structure"]["contents"] or {}).get(key)
+                structure = self.structure()
+                item = (structure.contents or {}).get(key)
                 if (item is None) or _ignore_inlined_contents:
                     # The item was not inlined, either because nothing was inlined
                     # or because it was added after we fetched the inlined contents.
@@ -305,13 +308,14 @@
                         self_link = self.item["links"]["self"]
                         if self_link.endswith("/"):
                             self_link = self_link[:-1]
+                        params = {}
+                        if self._include_data_sources:
+                            params["include_data_sources"] = True
                         content = handle_error(
                             self.context.http_client.get(
                                 self_link + "".join(f"/{key}" for key in keys[i:]),
                                 headers={"Accept": MSGPACK_MIME_TYPE},
-                                params={
-                                    "include_data_sources": self._include_data_sources
-                                },
+                                params=params,
                             )
                         ).json()
                     except ClientError as err:
@@ -413,15 +417,17 @@ def _items_slice(self, start, stop, direction, _ignore_inlined_contents=False):
         next_page_url = f"{self.item['links']['search']}?page[offset]={start}"
         item_counter = itertools.count(start)
         while next_page_url is not None:
+            params = {
+                **self._queries_as_params,
+                **sorting_params,
+            }
+            if self._include_data_sources:
+                params["include_data_sources"] = True
             content = handle_error(
                 self.context.http_client.get(
                     next_page_url,
                     headers={"Accept": MSGPACK_MIME_TYPE},
-                    params={
-                        "include_data_sources": self._include_data_sources,
-                        **self._queries_as_params,
-                        **sorting_params,
-                    },
+                    params=params,
                )
             ).json()
             self._cached_len = (
@@ -620,8 +626,11 @@ def new(
         ).json()
         if structure_family == StructureFamily.container:
             structure = {"contents": None, "count": None}
+        elif structure_family == StructureFamily.union:
+            structure = None
+            # To be filled in below, by server response.
+            # We need the server to tell us data_source_ids.
         else:
-            # Only containers can have multiple data_sources right now.
             (data_source,) = data_sources
             structure = data_source.structure
         item["attributes"]["structure"] = structure
@@ -631,7 +640,7 @@
             item["attributes"]["metadata"] = document.pop("metadata")
         # Ditto for structure
         if "structure" in document:
-            item["attributes"]["structure"] = STRUCTURE_TYPES[structure_family](
+            structure = STRUCTURE_TYPES[structure_family].from_json(
                 document.pop("structure")
             )
@@ -676,6 +685,31 @@ def create_container(self, key=None, *, metadata=None, dims=None, specs=None):
             specs=specs,
         )
 
+    def create_union(self, data_sources, key=None, *, metadata=None, specs=None):
+        """
+        EXPERIMENTAL: Create a new union backed by data sources.
+
+        Parameters
+        ----------
+        data_sources : List[DataSource]
+            The data sources that will back this node. Each must have a
+            name that is unique within this union.
+        metadata : dict, optional
+            User metadata. May be nested. Must contain only basic types
+            (e.g. numbers, strings, lists, dicts) that are JSON-serializable.
+        specs : List[Spec], optional
+            List of names that are used to label the data and/or metadata as
+            conforming to some named standard specification.
+
+        """
+        return self.new(
+            StructureFamily.union,
+            data_sources,
+            key=key,
+            metadata=metadata,
+            specs=specs,
+        )
+
     def write_array(self, array, *, key=None, metadata=None, dims=None, specs=None):
         """
         EXPERIMENTAL: Write an array.
@@ -1018,6 +1052,7 @@ def __call__(self):
         "table": _LazyLoad(
             ("..dataframe", Container.__module__), "DataFrameClient"
         ),
+        "union": _LazyLoad(("..union", Container.__module__), "UnionClient"),
         "xarray_dataset": _LazyLoad(
             ("..xarray", Container.__module__), "DatasetClient"
         ),
@@ -1036,6 +1071,7 @@
         "table": _LazyLoad(
             ("..dataframe", Container.__module__), "DaskDataFrameClient"
         ),
+        "union": _LazyLoad(("..union", Container.__module__), "UnionClient"),
         "xarray_dataset": _LazyLoad(
             ("..xarray", Container.__module__), "DaskDatasetClient"
         ),
diff --git a/tiled/client/union.py b/tiled/client/union.py
new file mode 100644
index 000000000..6d313a813
--- /dev/null
+++ b/tiled/client/union.py
@@ -0,0 +1,80 @@
+import copy
+
+from .base import STRUCTURE_TYPES, BaseClient
+from .utils import MSGPACK_MIME_TYPE, ClientError, client_for_item, handle_error
+
+
+class UnionClient(BaseClient):
+    def __repr__(self):
+        return (
+            f"<{type(self).__name__} {{"
+            + ", ".join(f"'{key}'" for key in self.structure().all_keys)
+            + "}>"
+        )
+
+    @property
+    def parts(self):
+        return UnionContents(self)
+
+    def __getitem__(self, key):
+        if key not in self.structure().all_keys:
+            raise KeyError(key)
+        try:
+            self_link = self.item["links"]["self"]
+            if self_link.endswith("/"):
+                self_link = self_link[:-1]
+            params = {}
+            if self._include_data_sources:
+                params["include_data_sources"] = True
+            content = handle_error(
+                self.context.http_client.get(
+                    f"{self_link}/{key}",
+                    headers={"Accept": MSGPACK_MIME_TYPE},
+                    params=params,
+                )
+            ).json()
+        except ClientError as err:
+            if err.response.status_code == 404:
+                raise KeyError(key)
+            raise
+        item = content["data"]
+        return client_for_item(
+            self.context,
+            self.structure_clients,
+            item,
+            include_data_sources=self._include_data_sources,
+        )
+
+
+class UnionContents:
+    def __init__(self, node):
+        self.node = node
+
+    def __repr__(self):
+        return (
+            f"<{type(self).__name__} {{"
+            + ", ".join(f"'{item.name}'" for item in self.node.structure().parts)
+            + "}>"
+        )
+
+    def __getitem__(self, name):
+        for index, union_item in enumerate(self.node.structure().parts):
+            if union_item.name == name:
+                structure_family = union_item.structure_family
+                structure_dict = union_item.structure
+                break
+        else:
+            raise KeyError(name)
+        item = copy.deepcopy(self.node.item)
+        item["attributes"]["structure_family"] = structure_family
+        item["attributes"]["structure"] = structure_dict
+        item["links"] = item["links"]["parts"][index]
+        structure_type = STRUCTURE_TYPES[structure_family]
+        structure = structure_type.from_json(structure_dict)
+        return client_for_item(
+            self.node.context,
+            self.node.structure_clients,
+            item,
+            structure=structure,
+            include_data_sources=self.node._include_data_sources,
+        )
diff --git a/tiled/serialization/table.py b/tiled/serialization/table.py
index b205a230c..541375bb2 100644
--- a/tiled/serialization/table.py
+++ b/tiled/serialization/table.py
@@ -43,6 +43,13 @@ def serialize_csv(df, metadata, preserve_index=False):
     return file.getvalue().encode()
 
 
+@deserialization_registry.register(StructureFamily.table, "text/csv")
+def deserialize_csv(buffer):
+    import pandas
+
+    return pandas.read_csv(io.BytesIO(buffer), header=None)
+
+
 serialization_registry.register(StructureFamily.table, "text/csv", serialize_csv)
 serialization_registry.register(
     StructureFamily.table, "text/x-comma-separated-values", serialize_csv
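The new `text/csv` deserializer is the read-side complement of `serialize_csv`. Note that `header=None` treats every row as data, so it matches request bodies that omit a header row. A hedged round-trip check under that assumption:

```python
# Round-trip sketch assuming headerless CSV bodies; pandas is the only
# dependency.
import io

import pandas

df = pandas.DataFrame({"A": [1, 2], "B": [3.0, 4.0]})
body = df.to_csv(index=False, header=False).encode()
round_tripped = pandas.read_csv(io.BytesIO(body), header=None)
assert round_tripped.shape == df.shape  # column labels become 0, 1, ...
```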
diff --git a/tiled/server/core.py b/tiled/server/core.py
index 857d547a0..10e0d93d6 100644
--- a/tiled/server/core.py
+++ b/tiled/server/core.py
@@ -34,6 +34,7 @@
 )
 from . import schemas
 from .etag import tokenize
+from .links import links_for_node
 from .utils import record_timing
 
 del queries
@@ -404,6 +405,7 @@ async def construct_resource(
     depth=0,
 ):
     path_str = "/".join(path_parts)
+    id_ = path_parts[-1] if path_parts else ""
     attributes = {"ancestors": path_parts[:-1]}
     if include_data_sources and hasattr(entry, "data_sources"):
         attributes["data_sources"] = entry.data_sources
@@ -488,15 +490,16 @@
                 for key, direction in entry.sorting
             ]
         d = {
-            "id": path_parts[-1] if path_parts else "",
+            "id": id_,
             "attributes": schemas.NodeAttributes(**attributes),
         }
         if not omit_links:
-            d["links"] = {
-                "self": f"{base_url}/metadata/{path_str}",
-                "search": f"{base_url}/search/{path_str}",
-                "full": f"{base_url}/container/full/{path_str}",
-            }
+            d["links"] = links_for_node(
+                entry.structure_family,
+                entry.structure(),
+                base_url,
+                path_str,
+            )
 
         resource = schemas.Resource[
             schemas.NodeAttributes, schemas.ContainerLinks, schemas.ContainerMeta
@@ -510,34 +513,16 @@
             entry.structure_family
         ]
         links.update(
-            {
-                link: template.format(base_url=base_url, path=path_str)
-                for link, template in FULL_LINKS[entry.structure_family].items()
-            }
+            links_for_node(
+                entry.structure_family,
+                entry.structure(),
+                base_url,
+                path_str,
+            )
         )
         structure = asdict(entry.structure())
         if schemas.EntryFields.structure_family in fields:
             attributes["structure_family"] = entry.structure_family
-        if entry.structure_family == StructureFamily.sparse:
-            shape = structure.get("shape")
-            block_template = ",".join(f"{{{index}}}" for index in range(len(shape)))
-            links[
-                "block"
-            ] = f"{base_url}/array/block/{path_str}?block={block_template}"
-        elif entry.structure_family == StructureFamily.array:
-            shape = structure.get("shape")
-            block_template = ",".join(
-                f"{{index_{index}}}" for index in range(len(shape))
-            )
-            links[
-                "block"
-            ] = f"{base_url}/array/block/{path_str}?block={block_template}"
-        elif entry.structure_family == StructureFamily.table:
-            links[
-                "partition"
-            ] = f"{base_url}/table/partition/{path_str}?partition={{index}}"
-        elif entry.structure_family == StructureFamily.awkward:
-            links["buffers"] = f"{base_url}/awkward/buffers/{path_str}"
         if schemas.EntryFields.structure in fields:
             attributes["structure"] = structure
         else:
@@ -719,15 +704,6 @@ class WrongTypeForRoute(Exception):
     pass
 
 
-FULL_LINKS = {
-    StructureFamily.array: {"full": "{base_url}/array/full/{path}"},
-    StructureFamily.awkward: {"full": "{base_url}/awkward/full/{path}"},
-    StructureFamily.container: {"full": "{base_url}/container/full/{path}"},
-    StructureFamily.table: {"full": "{base_url}/table/full/{path}"},
-    StructureFamily.sparse: {"full": "{base_url}/array/full/{path}"},
-}
-
-
 def asdict(dc):
     "Compat for converting dataclass or pydantic.BaseModel to dict."
     if dc is None:
diff --git a/tiled/server/dependencies.py b/tiled/server/dependencies.py
index 7cf89c4f3..db0a7fb02 100644
--- a/tiled/server/dependencies.py
+++ b/tiled/server/dependencies.py
@@ -11,6 +11,7 @@
     serialization_registry as default_serialization_registry,
 )
 from ..query_registration import query_registry as default_query_registry
+from ..structures.core import StructureFamily
 from ..validation_registration import validation_registry as default_validation_registry
 from .authentication import get_current_principal, get_session_state
 from .core import NoEntry
@@ -48,10 +49,11 @@ def get_root_tree():
     )
 
 
-def SecureEntry(scopes):
+def SecureEntry(scopes, structure_families=None):
     async def inner(
         path: str,
         request: Request,
+        part: Optional[str] = None,
         principal: str = Depends(get_current_principal),
         root_tree: pydantic.BaseSettings = Depends(get_root_tree),
         session_state: dict = Depends(get_session_state),
@@ -116,7 +118,41 @@ async def inner(
             )
         except NoEntry:
             raise HTTPException(status_code=404, detail=f"No such entry: {path_parts}")
-        return entry
+        # Fast path for the common successful case
+        if (structure_families is None) or (
+            entry.structure_family in structure_families
+        ):
+            return entry
+        # Handle union structure_family
+        if entry.structure_family == StructureFamily.union:
+            if not part:
+                raise HTTPException(
+                    status_code=400,
+                    detail=(
+                        "A part query parameter is required on this endpoint "
+                        "when addressing a 'union' structure."
+                    ),
+                )
+            entry_for_part = entry.for_part(part)
+            if entry_for_part.structure_family in structure_families:
+                return entry_for_part
+            raise HTTPException(
+                status_code=404,
+                detail=(
+                    f"The data source named {part} backing the node "
+                    f"at {path} has structure family {entry_for_part.structure_family} "
+                    "and this endpoint is compatible only with structure families "
+                    f"{structure_families}"
+                ),
+            )
+        raise HTTPException(
+            status_code=404,
+            detail=(
+                f"The node at {path} has structure family {entry.structure_family} "
+                "and this endpoint is compatible only with structure families "
+                f"{structure_families}"
+            ),
+        )
 
     return Security(inner, scopes=scopes)
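With `SecureEntry` now enforcing structure families, a union node is addressed through the existing family-specific routes plus a `part` query parameter naming the backing data source. Illustrative raw HTTP usage (host, keys, and part names are examples only):

```python
# Illustrative requests against a local server; paths, keys, and part
# names are examples, not new endpoints.
import httpx

base = "http://localhost:8000/api/v1"

# Table part "table1" of union node "x", via the table route:
r1 = httpx.get(f"{base}/table/full/x", params={"part": "table1"})

# Array part "F", one block at a time, via the array route:
r2 = httpx.get(f"{base}/array/block/x", params={"part": "F", "block": "0,0"})

# Without ?part=..., these routes answer 400 for a union node.
```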
+""" +from ..structures.core import StructureFamily + + +def links_for_node(structure_family, structure, base_url, path_str): + links = {} + links = LINKS_BY_STRUCTURE_FAMILY[structure_family]( + structure_family, structure, base_url, path_str + ) + links["self"] = f"{base_url}/metadata/{path_str}" + return links + + +def links_for_array(structure_family, structure, base_url, path_str, part=None): + links = {} + block_template = ",".join(f"{{{index}}}" for index in range(len(structure.shape))) + links["block"] = f"{base_url}/array/block/{path_str}?block={block_template}" + links["full"] = f"{base_url}/array/full/{path_str}" + if part: + links["block"] += f"&part={part}" + links["full"] += f"?part={part}" + return links + + +def links_for_awkward(structure_family, structure, base_url, path_str, part=None): + links = {} + links["buffers"] = f"{base_url}/awkward/buffers/{path_str}" + links["full"] = f"{base_url}/awkward/full/{path_str}" + if part: + links["buffers"] += "?part={part}" + links["full"] += "?part={part}" + return links + + +def links_for_container(structure_family, structure, base_url, path_str): + # Cannot be used inside union, so there is no part parameter. + links = {} + links["full"] = f"{base_url}/container/full/{path_str}" + links["search"] = f"{base_url}/search/{path_str}" + return links + + +def links_for_table(structure_family, structure, base_url, path_str, part=None): + links = {} + links["partition"] = f"{base_url}/table/partition/{path_str}?partition={{index}}" + links["full"] = f"{base_url}/table/full/{path_str}" + if part: + links["partition"] += f"&part={part}" + links["full"] += f"?part={part}" + return links + + +def links_for_union(structure_family, structure, base_url, path_str): + links = {} + # This contains the links for each structure. + links["parts"] = [] + for item in structure.parts: + item_links = LINKS_BY_STRUCTURE_FAMILY[item.structure_family]( + item.structure_family, + item.structure, + base_url, + path_str, + part=item.name, + ) + item_links["self"] = f"{base_url}/metadata/{path_str}" + links["parts"].append(item_links) + return links + + +LINKS_BY_STRUCTURE_FAMILY = { + StructureFamily.array: links_for_array, + StructureFamily.awkward: links_for_awkward, + StructureFamily.container: links_for_container, + StructureFamily.sparse: links_for_array, # sparse and array are the same + StructureFamily.table: links_for_table, + StructureFamily.union: links_for_union, +} diff --git a/tiled/server/pydantic_container.py b/tiled/server/pydantic_container.py new file mode 100644 index 000000000..58fe4b3b9 --- /dev/null +++ b/tiled/server/pydantic_container.py @@ -0,0 +1,8 @@ +from typing import Optional + +import pydantic + + +class ContainerStructure(pydantic.BaseModel): + contents: Optional[dict] + count: Optional[int] diff --git a/tiled/server/pydantic_union.py b/tiled/server/pydantic_union.py new file mode 100644 index 000000000..7d13645df --- /dev/null +++ b/tiled/server/pydantic_union.py @@ -0,0 +1,27 @@ +from typing import Any, List, Optional + +import pydantic + +from ..structures.core import StructureFamily + + +class UnionStructurePart(pydantic.BaseModel): + structure_family: StructureFamily + structure: Any # Union of Structures, but we do not want to import them... 
diff --git a/tiled/server/router.py b/tiled/server/router.py
index 9d7db130b..25df1820f 100644
--- a/tiled/server/router.py
+++ b/tiled/server/router.py
@@ -14,6 +14,7 @@
 from starlette.responses import FileResponse
 
 from .. import __version__
+from ..server.pydantic_union import UnionStructure, UnionStructurePart
 from ..structures.core import StructureFamily
 from ..utils import ensure_awaitable, path_from_uri
 from ..validation_registration import ValidationError
@@ -44,6 +45,7 @@
     get_validation_registry,
     slice_,
 )
+from .links import links_for_node
 from .settings import get_settings
 from .utils import filter_for_access, get_base_url, record_timing
@@ -346,27 +348,23 @@
 )
 async def array_block(
     request: Request,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(
+        scopes=["read:data"],
+        structure_families={StructureFamily.array, StructureFamily.sparse},
+    ),
     block=Depends(block),
     slice=Depends(slice_),
     expected_shape=Depends(expected_shape),
     format: Optional[str] = None,
     filename: Optional[str] = None,
+    data_source: Optional[str] = None,
     serialization_registry=Depends(get_serialization_registry),
     settings: BaseSettings = Depends(get_settings),
 ):
     """
     Fetch a chunk of array-like data.
     """
-    if entry.structure_family == "array":
-        shape = entry.structure().shape
-    elif entry.structure_family == "sparse":
-        shape = entry.structure().shape
-    else:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Cannot read {entry.structure_family} structure with /array/block route.",
-        )
+    shape = entry.structure().shape
     # Check that block dimensionality matches array dimensionality.
     ndim = len(shape)
     if len(block) != ndim:
@@ -405,10 +403,14 @@
                 "Use slicing ('?slice=...') to request smaller chunks."
             ),
         )
+    if entry.structure_family == StructureFamily.union:
+        structure_family = entry.data_source.structure_family
+    else:
+        structure_family = entry.structure_family
     try:
         with record_timing(request.state.metrics, "pack"):
             return await construct_data_response(
-                entry.structure_family,
+                structure_family,
                 serialization_registry,
                 array,
                 entry.metadata(),
@@ -428,7 +430,10 @@
 )
 async def array_full(
     request: Request,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(
+        scopes=["read:data"],
+        structure_families={StructureFamily.array, StructureFamily.sparse},
+    ),
     slice=Depends(slice_),
     expected_shape=Depends(expected_shape),
     format: Optional[str] = None,
@@ -439,12 +444,10 @@
     """
     Fetch a slice of array-like data.
     """
-    structure_family = entry.structure_family
-    if structure_family not in {"array", "sparse"}:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Cannot read {entry.structure_family} structure with /array/full route.",
-        )
+    if entry.structure_family == StructureFamily.union:
+        structure_family = entry.data_source.structure_family
+    else:
+        structure_family = entry.structure_family
     # Deferred import because this is not a required dependency of the server
     # for some use cases.
     import numpy
@@ -452,7 +455,7 @@
     try:
         with record_timing(request.state.metrics, "read"):
             array = await ensure_awaitable(entry.read, slice)
-            if structure_family == "array":
+            if structure_family == StructureFamily.array:
                 array = numpy.asarray(array)  # Force dask or PIMS or ... to do I/O.
     except IndexError:
         raise HTTPException(status_code=400, detail="Block index out of range")
@@ -494,7 +497,7 @@
 async def get_table_partition(
     request: Request,
     partition: int,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}),
     column: Optional[List[str]] = Query(None, min_length=1),
     field: Optional[List[str]] = Query(None, min_length=1, deprecated=True),
     format: Optional[str] = None,
@@ -542,7 +545,7 @@
 async def post_table_partition(
     request: Request,
     partition: int,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}),
     column: Optional[List[str]] = Body(None, min_length=1),
     format: Optional[str] = None,
     filename: Optional[str] = None,
@@ -577,11 +580,6 @@ async def table_partition(
     """
     Fetch a partition (continuous block of rows) from a DataFrame.
     """
-    if entry.structure_family != StructureFamily.table:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Cannot read {entry.structure_family} structure with /table/partition route.",
-        )
     try:
         # The singular/plural mismatch here of "fields" and "field" is
         # due to the ?field=A&field=B&field=C... encodes in a URL.
@@ -625,7 +623,7 @@
 )
 async def get_table_full(
     request: Request,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}),
     column: Optional[List[str]] = Query(None, min_length=1),
     format: Optional[str] = None,
     filename: Optional[str] = None,
@@ -653,7 +651,7 @@
 )
 async def post_table_full(
     request: Request,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(scopes=["read:data"], structure_families={StructureFamily.table}),
     column: Optional[List[str]] = Body(None, min_length=1),
     format: Optional[str] = None,
     filename: Optional[str] = None,
@@ -686,11 +684,6 @@
     """
     Fetch the data for the given table.
     """
-    if entry.structure_family != StructureFamily.table:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Cannot read {entry.structure_family} structure with /table/full route.",
-        )
     try:
         with record_timing(request.state.metrics, "read"):
             data = await ensure_awaitable(entry.read, column)
@@ -706,10 +699,14 @@
                 "request a smaller chunks."
             ),
         )
+    if entry.structure_family == StructureFamily.union:
+        structure_family = entry.data_source.structure_family
+    else:
+        structure_family = entry.structure_family
     try:
         with record_timing(request.state.metrics, "pack"):
             return await construct_data_response(
-                entry.structure_family,
+                structure_family,
                 serialization_registry,
                 data,
                 entry.metadata(),
@@ -731,7 +728,9 @@
 )
 async def get_container_full(
     request: Request,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(
+        scopes=["read:data"], structure_families={StructureFamily.container}
+    ),
     principal: str = Depends(get_current_principal),
     field: Optional[List[str]] = Query(None, min_length=1),
     format: Optional[str] = None,
@@ -759,7 +758,9 @@
 )
 async def post_container_full(
     request: Request,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(
+        scopes=["read:data"], structure_families={StructureFamily.container}
+    ),
     principal: str = Depends(get_current_principal),
     field: Optional[List[str]] = Body(None, min_length=1),
     format: Optional[str] = None,
@@ -792,11 +793,6 @@
     """
     Fetch the data for the given container.
     """
-    if entry.structure_family != StructureFamily.container:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Cannot read {entry.structure_family} structure with /container/full route.",
-        )
     try:
         with record_timing(request.state.metrics, "read"):
             data = await ensure_awaitable(entry.read, fields=field)
@@ -836,7 +832,10 @@
 )
 async def node_full(
     request: Request,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(
+        scopes=["read:data"],
+        structure_families={StructureFamily.table, StructureFamily.container},
+    ),
     principal: str = Depends(get_current_principal),
     field: Optional[List[str]] = Query(None, min_length=1),
     format: Optional[str] = None,
@@ -899,7 +898,9 @@
 )
 async def get_awkward_buffers(
     request: Request,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(
+        scopes=["read:data"], structure_families={StructureFamily.awkward}
+    ),
     form_key: Optional[List[str]] = Query(None, min_length=1),
     format: Optional[str] = None,
     filename: Optional[str] = None,
@@ -935,7 +936,9 @@
 async def post_awkward_buffers(
     request: Request,
     body: List[str],
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(
+        scopes=["read:data"], structure_families={StructureFamily.awkward}
+    ),
     format: Optional[str] = None,
     filename: Optional[str] = None,
     serialization_registry=Depends(get_serialization_registry),
@@ -973,11 +976,6 @@
 ):
     structure_family = entry.structure_family
     structure = entry.structure()
-    if structure_family != StructureFamily.awkward:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Cannot read {entry.structure_family} structure with /awkward/buffers route.",
-        )
     with record_timing(request.state.metrics, "read"):
         # The plural vs. singular mismatch is due to the way query parameters
         # are given as ?form_key=A&form_key=B&form_key=C.
@@ -1018,7 +1016,9 @@
 )
 async def awkward_full(
     request: Request,
-    entry=SecureEntry(scopes=["read:data"]),
+    entry=SecureEntry(
+        scopes=["read:data"], structure_families={StructureFamily.awkward}
+    ),
     # slice=Depends(slice_),
     format: Optional[str] = None,
     filename: Optional[str] = None,
@@ -1029,11 +1029,6 @@
     Fetch a slice of AwkwardArray data.
     """
     structure_family = entry.structure_family
-    if structure_family != StructureFamily.awkward:
-        raise HTTPException(
-            status_code=404,
-            detail=f"Cannot read {entry.structure_family} structure with /awkward/full route.",
-        )
     # Deferred import because this is not a required dependency of the server
     # for some use cases.
     import awkward
@@ -1129,14 +1124,24 @@
         body.structure_family,
         body.specs,
     )
-    if structure_family == StructureFamily.container:
-        structure = None
-    else:
-        if len(body.data_sources) != 1:
-            raise NotImplementedError
-        structure = body.data_sources[0].structure
-
     metadata_modified = False
+    if structure_family == StructureFamily.union:
+        structure = UnionStructure(
+            parts=[
+                UnionStructurePart(
+                    data_source_id=data_source.id,
+                    structure=data_source.structure,
+                    structure_family=data_source.structure_family,
+                    name=data_source.name,
+                )
+                for data_source in body.data_sources
+            ]
+        )
+    elif body.data_sources:
+        assert len(body.data_sources) == 1  # more not yet implemented
+        structure = body.data_sources[0].structure
+    else:
+        structure = None
 
     # Specs should be ordered from most specific/constrained to least.
     # Validate them in reverse order, with the least constrained spec first,
@@ -1172,34 +1177,17 @@
         specs=body.specs,
         data_sources=body.data_sources,
     )
-    links = {}
-    base_url = get_base_url(request)
-    path_parts = [segment for segment in path.split("/") if segment] + [key]
-    path_str = "/".join(path_parts)
-    links["self"] = f"{base_url}/metadata/{path_str}"
-    if body.structure_family in {StructureFamily.array, StructureFamily.sparse}:
-        block_template = ",".join(
-            f"{{{index}}}" for index in range(len(node.structure().shape))
-        )
-        links["block"] = f"{base_url}/array/block/{path_str}?block={block_template}"
-        links["full"] = f"{base_url}/array/full/{path_str}"
-    elif body.structure_family == StructureFamily.table:
-        links[
-            "partition"
-        ] = f"{base_url}/table/partition/{path_str}?partition={{index}}"
-        links["full"] = f"{base_url}/table/full/{path_str}"
-    elif body.structure_family == StructureFamily.container:
-        links["full"] = f"{base_url}/container/full/{path_str}"
-        links["search"] = f"{base_url}/search/{path_str}"
-    elif body.structure_family == StructureFamily.awkward:
-        links["buffers"] = f"{base_url}/awkward/buffers/{path_str}"
-        links["full"] = f"{base_url}/awkward/full/{path_str}"
-    else:
-        raise NotImplementedError(body.structure_family)
+    links = links_for_node(
+        structure_family, structure, get_base_url(request), path + f"/{key}"
+    )
+    structure = node.structure()
+    if structure is not None:
+        structure = structure.dict()
     response_data = {
         "id": key,
         "links": links,
         "data_sources": [ds.dict() for ds in node.data_sources],
+        "structure": structure,
     }
     if metadata_modified:
         response_data["metadata"] = metadata
@@ -1237,7 +1225,10 @@ async def bulk_delete(
 @router.put("/array/full/{path:path}")
 async def put_array_full(
     request: Request,
-    entry=SecureEntry(scopes=["write:data"]),
+    entry=SecureEntry(
+        scopes=["write:data"],
+        structure_families={StructureFamily.array, StructureFamily.sparse},
+    ),
     deserialization_registry=Depends(get_deserialization_registry),
 ):
     body = await request.body()
@@ -1263,7 +1254,10 @@
 @router.put("/array/block/{path:path}")
 async def put_array_block(
     request: Request,
-    entry=SecureEntry(scopes=["write:data"]),
+    entry=SecureEntry(
+        scopes=["write:data"],
+        structure_families={StructureFamily.array, StructureFamily.sparse},
+    ),
     deserialization_registry=Depends(get_deserialization_registry),
     block=Depends(block),
 ):
@@ -1295,7 +1289,9 @@
 @router.put("/node/full/{path:path}", deprecated=True)
 async def put_node_full(
     request: Request,
-    entry=SecureEntry(scopes=["write:data"]),
+    entry=SecureEntry(
+        scopes=["write:data"], structure_families={StructureFamily.table}
+    ),
     deserialization_registry=Depends(get_deserialization_registry),
 ):
     if not hasattr(entry, "write"):
@@ -1332,14 +1328,12 @@ async def put_table_partition(
 @router.put("/awkward/full/{path:path}")
 async def put_awkward_full(
     request: Request,
-    entry=SecureEntry(scopes=["write:data"]),
+    entry=SecureEntry(
+        scopes=["write:data"], structure_families={StructureFamily.awkward}
+    ),
     deserialization_registry=Depends(get_deserialization_registry),
 ):
     body = await request.body()
-    if entry.structure_family != StructureFamily.awkward:
-        raise HTTPException(
-            status_code=404, detail="This route is not applicable to this node."
-        )
     if not hasattr(entry, "write"):
         raise HTTPException(status_code=405, detail="This node cannot be written to.")
     media_type = request.headers["content-type"]
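On the wire, creating a union node is a single POST whose body carries one data source per part; the validator added to the schemas below rejects name and key collisions before anything is written. A hedged sketch of such a payload (the per-family "structure" contents are elided; see ArrayStructure/TableStructure for the real field sets):

```python
# Hedged sketch of a POST /api/v1/metadata/{path} body for a union node.
# Structure payloads are elided placeholders, not the actual schema.
payload = {
    "structure_family": "union",
    "metadata": {},
    "specs": [],
    "data_sources": [
        {"structure_family": "table", "structure": {"...": "..."}, "name": "table1"},
        {"structure_family": "array", "structure": {"...": "..."}, "name": "F"},
    ],
}
```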
diff --git a/tiled/server/schemas.py b/tiled/server/schemas.py
index 94c893aa1..8cd321d76 100644
--- a/tiled/server/schemas.py
+++ b/tiled/server/schemas.py
@@ -11,11 +11,12 @@
 import pydantic.generics
 
 from ..structures.core import StructureFamily
-from ..structures.data_source import Management
+from ..structures.data_source import Management, validate_data_sources
 from .pydantic_array import ArrayStructure
 from .pydantic_awkward import AwkwardStructure
 from .pydantic_sparse import SparseStructure
 from .pydantic_table import TableStructure
+from .pydantic_union import UnionStructure
 
 DataT = TypeVar("DataT")
 LinksT = TypeVar("LinksT")
@@ -137,15 +138,17 @@ class DataSource(pydantic.BaseModel):
         Union[
             ArrayStructure,
             AwkwardStructure,
-            TableStructure,
             NodeStructure,
             SparseStructure,
+            TableStructure,
+            UnionStructure,
         ]
     ] = None
     mimetype: Optional[str] = None
     parameters: dict = {}
     assets: List[Asset] = []
     management: Management = Management.writable
+    name: Optional[str] = None
 
     @classmethod
     def from_orm(cls, orm):
@@ -157,6 +160,7 @@
             parameters=orm.parameters,
             assets=[Asset.from_assoc_orm(assoc) for assoc in orm.asset_associations],
             management=orm.management,
+            name=orm.name,
         )
 
 
@@ -169,9 +173,10 @@ class NodeAttributes(pydantic.BaseModel):
         Union[
             ArrayStructure,
             AwkwardStructure,
-            TableStructure,
             NodeStructure,
             SparseStructure,
+            TableStructure,
+            UnionStructure,
         ]
     ]
     sorting: Optional[List[SortingItem]]
@@ -217,12 +222,20 @@ class SparseLinks(pydantic.BaseModel):
     block: str
 
 
+class UnionLinks(pydantic.BaseModel):
+    self: str
+    contents: List[
+        Union[ArrayLinks, AwkwardLinks, ContainerLinks, DataFrameLinks, SparseLinks]
+    ]
+
+
 resource_links_type_by_structure_family = {
-    StructureFamily.container: ContainerLinks,
     StructureFamily.array: ArrayLinks,
     StructureFamily.awkward: AwkwardLinks,
-    StructureFamily.table: DataFrameLinks,
+    StructureFamily.container: ContainerLinks,
     StructureFamily.sparse: SparseLinks,
+    StructureFamily.table: DataFrameLinks,
+    StructureFamily.union: UnionLinks,
 }
 
 
@@ -400,10 +413,22 @@ def specs_uniqueness_validator(cls, v):
             raise pydantic.errors.ListUniqueItemsError()
         return v
 
+    @pydantic.validator("data_sources", always=True)
+    def check_consistency(cls, v, values):
+        return validate_data_sources(values["structure_family"], v)
+
 
 class PostMetadataResponse(pydantic.BaseModel, Generic[ResourceLinksT]):
     id: str
-    links: Union[ArrayLinks, DataFrameLinks, SparseLinks]
+    links: Union[ArrayLinks, DataFrameLinks, SparseLinks, UnionLinks]
+    structure: Union[
+        ArrayStructure,
+        AwkwardStructure,
+        NodeStructure,
+        SparseStructure,
+        TableStructure,
+        UnionStructure,
+    ]
     metadata: Dict
     data_sources: List[DataSource]
diff --git a/tiled/structures/container.py b/tiled/structures/container.py
new file mode 100644
index 000000000..c451ddb25
--- /dev/null
+++ b/tiled/structures/container.py
@@ -0,0 +1,12 @@
+import dataclasses
+from typing import Optional
+
+
+@dataclasses.dataclass
+class ContainerStructure:
+    contents: Optional[dict]
+    count: Optional[int]
+
+    @classmethod
+    def from_json(cls, structure):
+        return cls(**structure)
diff --git a/tiled/structures/core.py b/tiled/structures/core.py
index 065e3dd4e..e5e794f47 100644
--- a/tiled/structures/core.py
+++ b/tiled/structures/core.py
@@ -9,12 +9,22 @@
 from typing import Optional
 
 
-class StructureFamily(str, enum.Enum):
+class BaseStructureFamily(str, enum.Enum):
+    array = "array"
     awkward = "awkward"
     container = "container"
+    sparse = "sparse"
+    table = "table"
+    # excludes union, which DataSources cannot have
+
+
+class StructureFamily(str, enum.Enum):
     array = "array"
+    awkward = "awkward"
+    container = "container"
     sparse = "sparse"
     table = "table"
+    union = "union"
 
 
 @dataclass(frozen=True)
diff --git a/tiled/structures/data_source.py b/tiled/structures/data_source.py
index 97367d097..aedd8f34c 100644
--- a/tiled/structures/data_source.py
+++ b/tiled/structures/data_source.py
@@ -1,8 +1,9 @@
+import collections
 import dataclasses
 import enum
 from typing import Any, List, Optional
 
-from .core import StructureFamily
+from ..structures.core import BaseStructureFamily, StructureFamily
 
 
 class Management(str, enum.Enum):
@@ -23,10 +24,69 @@ class Asset:
 
 @dataclasses.dataclass
 class DataSource:
-    structure_family: StructureFamily
+    structure_family: BaseStructureFamily
     structure: Any
     id: Optional[int] = None
     mimetype: Optional[str] = None
     parameters: dict = dataclasses.field(default_factory=dict)
     assets: List[Asset] = dataclasses.field(default_factory=list)
     management: Management = Management.writable
+    name: Optional[str] = None
+
+
+def validate_data_sources(node_structure_family, data_sources):
+    "Check that data sources are consistent."
+    return validators[node_structure_family](node_structure_family, data_sources)
+
+
+def validate_container_data_sources(node_structure_family, data_sources):
+    if len(data_sources) > 1:
+        raise ValueError(
+            "A container node can be backed by 0 or 1 data source, "
+            f"not {len(data_sources)}"
+        )
+    return data_sources
+
+
+def validate_union_data_sources(node_structure_family, data_sources):
+    "Check that column names and keys of others (e.g. arrays) do not collide."
+    keys = set()
+    names = set()
+    for data_source in data_sources:
+        if data_source.name is None:
+            raise ValueError(
+                "Data sources backing a union structure_family must "
+                "all have non-NULL names."
+            )
+        if data_source.name in names:
+            raise ValueError(
+                "Data sources must have unique names. "
+                f"This name is used more than once: {data_source.name}"
+            )
+        names.add(data_source.name)
+        if data_source.structure_family == StructureFamily.table:
+            columns = data_source.structure.columns
+            if keys.intersection(columns):
+                raise ValueError(
+                    f"Data sources provide colliding keys: {keys.intersection(columns)}"
+                )
+            keys.update(columns)
+        else:
+            key = data_source.name
+            if key in keys:
+                raise ValueError(f"Data sources provide colliding keys: {key}")
+            keys.add(key)
+    return data_sources
+
+
+def validate_other_data_sources(node_structure_family, data_sources):
+    if len(data_sources) != 1:
+        raise ValueError(
+            f"A {node_structure_family} node must be backed by 1 data source."
+        )
+    return data_sources
+
+
+validators = collections.defaultdict(lambda: validate_other_data_sources)
+validators[StructureFamily.container] = validate_container_data_sources
+validators[StructureFamily.union] = validate_union_data_sources
diff --git a/tiled/structures/union.py b/tiled/structures/union.py
new file mode 100644
index 000000000..3d4a6cc4b
--- /dev/null
+++ b/tiled/structures/union.py
@@ -0,0 +1,28 @@
+import dataclasses
+from typing import Any, List, Optional
+
+from .core import StructureFamily
+
+
+@dataclasses.dataclass
+class UnionStructurePart:
+    structure_family: StructureFamily
+    structure: Any  # Union of Structures, but we do not want to import them...
+    name: Optional[str]
+
+    @classmethod
+    def from_json(cls, item):
+        return cls(**item)
+
+
+@dataclasses.dataclass
+class UnionStructure:
+    parts: List[UnionStructurePart]
+    all_keys: List[str]
+
+    @classmethod
+    def from_json(cls, structure):
+        return cls(
+            parts=[UnionStructurePart.from_json(item) for item in structure["parts"]],
+            all_keys=structure["all_keys"],
+        )