diff --git a/src/datachain/lib/dc.py b/src/datachain/lib/dc.py index beb3b9849..8b2cca25d 100644 --- a/src/datachain/lib/dc.py +++ b/src/datachain/lib/dc.py @@ -1524,6 +1524,7 @@ def from_records( to_insert: Optional[Union[dict, list[dict]]], session: Optional[Session] = None, in_memory: bool = False, + schema: Optional[dict[str, DataType]] = None, ) -> "DataChain": """Create a DataChain from the provided records. This method can be used for programmatically generating a chain in contrast of reading data from storages @@ -1532,10 +1533,10 @@ def from_records( Parameters: to_insert : records (or a single record) to insert. Each record is a dictionary of signals and theirs values. + schema : describes chain signals and their corresponding types Example: ```py - empty = DataChain.from_records() single_record = DataChain.from_records(DataChain.DEFAULT_FILE_RECORD) ``` """ @@ -1543,11 +1544,27 @@ def from_records( catalog = session.catalog name = session.generate_temp_dataset_name() - columns: tuple[sqlalchemy.Column[Any], ...] = tuple( - sqlalchemy.Column(name, typ) - for name, typ in File._datachain_column_types.items() + signal_schema = None + columns: list[sqlalchemy.Column] = [] + + if schema: + signal_schema = SignalSchema(schema) + columns = signal_schema.db_signals(as_columns=True) # type: ignore[assignment] + else: + columns = [ + sqlalchemy.Column(name, typ) + for name, typ in File._datachain_column_types.items() + ] + + dsr = catalog.create_dataset( + name, + columns=columns, + feature_schema=( + signal_schema.clone_without_sys_signals().serialize() + if signal_schema + else None + ), ) - dsr = catalog.create_dataset(name, columns=columns) if isinstance(to_insert, dict): to_insert = [to_insert] diff --git a/src/datachain/lib/signal_schema.py b/src/datachain/lib/signal_schema.py index 338afd06d..43232b7a8 100644 --- a/src/datachain/lib/signal_schema.py +++ b/src/datachain/lib/signal_schema.py @@ -14,6 +14,7 @@ get_origin, ) +import sqlalchemy as sa from pydantic import BaseModel, create_model from typing_extensions import Literal as LiteralEx @@ -232,7 +233,7 @@ def db_signals( signals = [ DEFAULT_DELIMITER.join(path) if not as_columns - else Column(DEFAULT_DELIMITER.join(path), python_to_sql(_type)) + else sa.Column(DEFAULT_DELIMITER.join(path), python_to_sql(_type)) for path, _type, has_subtree, _ in self.get_flat_tree() if not has_subtree ] diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 2994d525a..53ebcb0d0 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -159,6 +159,52 @@ def test_from_features(test_session): assert t1 == features[i] +def test_from_records_empty_chain_with_schema(test_session): + schema = {"my_file": File, "my_col": int} + ds = DataChain.from_records([], schema=schema, session=test_session) + ds_sys = ds.settings(sys=True) + + ds_name = "my_ds" + ds.save(ds_name) + ds = DataChain(name=ds_name) + + assert isinstance(ds.feature_schema, dict) + assert isinstance(ds.signals_schema, SignalSchema) + assert ds.schema.keys() == {"my_file", "my_col"} + assert set(ds.schema.values()) == {File, int} + assert ds.count() == 0 + + # check that columns have actually been created from schema + dr = ds_sys.catalog.warehouse.dataset_rows(ds_sys.catalog.get_dataset(ds_name)) + assert sorted([c.name for c in dr.c]) == sorted(ds.signals_schema.db_signals()) + + +def test_from_records_empty_chain_without_schema(test_session): + ds = DataChain.from_records([], schema=None, session=test_session) + ds_sys = ds.settings(sys=True) + + ds_name = "my_ds" + ds.save(ds_name) + ds = DataChain(name=ds_name) + + assert ds.schema.keys() == { + "source", + "path", + "size", + "version", + "etag", + "is_latest", + "last_modified", + "location", + "vtype", + } + assert ds.count() == 0 + + # check that columns have actually been created from schema + dr = ds_sys.catalog.warehouse.dataset_rows(ds_sys.catalog.get_dataset(ds_name)) + assert sorted([c.name for c in dr.c]) == sorted(ds.signals_schema.db_signals()) + + def test_datasets(test_session): ds = DataChain.datasets(session=test_session) datasets = [d for d in ds.collect("dataset") if d.name == "fibonacci"] diff --git a/tests/unit/lib/test_signal_schema.py b/tests/unit/lib/test_signal_schema.py index 13f2b8da6..56431a23b 100644 --- a/tests/unit/lib/test_signal_schema.py +++ b/tests/unit/lib/test_signal_schema.py @@ -2,8 +2,9 @@ from typing import Optional, Union import pytest +from sqlalchemy import Column -from datachain import Column, DataModel +from datachain import DataModel from datachain.lib.convert.flatten import flatten from datachain.lib.file import File from datachain.lib.signal_schema import (