Skip to content

Commit

Permalink
Add an upload dataset script. Some other cleanups. (#1059)
Browse files Browse the repository at this point in the history
Python:
- Add limit, offset, and filter support to labeling and deletion. This allows us to
filter by a cluster group, sort by membership probability, limit to the top N rows,
and add a label to just those rows.
- Export useful methods to lilac public API.

UI:
- Remove filters, searches, and group-by when clicking the cluster button in
the UI.

Example dataset:
https://huggingface.co/datasets/lilacai/hermes-cluster-sample

<img width="1203" alt="image"
src="https://github.com/lilacai/lilac/assets/1100749/162f32e8-bc50-42d2-8dbf-72d8275245d6">
  • Loading branch information
nsthorat authored Jan 12, 2024
1 parent 1784239 commit bb71881
Show file tree
Hide file tree
Showing 11 changed files with 389 additions and 78 deletions.
3 changes: 2 additions & 1 deletion lilac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)
from .data import * # noqa: F403
from .data.dataset_duckdb import DatasetDuckDB
from .db_manager import get_dataset, set_default_dataset_cls
from .db_manager import get_dataset, list_datasets, set_default_dataset_cls
from .deploy import deploy_config, deploy_project
from .embeddings import * # noqa: F403
from .env import * # noqa: F403
Expand Down Expand Up @@ -49,6 +49,7 @@
'from_dicts',
'from_huggingface',
'get_dataset',
'list_datasets',
'init',
'span',
'load',
Expand Down
56 changes: 54 additions & 2 deletions lilac/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import __version__
from .concepts.db_concept import DISK_CONCEPT_DB
from .data.dataset_storage_utils import download
from .data.dataset_storage_utils import download, upload
from .deploy import deploy_project
from .env import env, get_project_dir
from .hf_docker_start import hf_docker_start
Expand Down Expand Up @@ -236,15 +236,66 @@ def deploy_project_command(
'This can also be set via the `HF_ACCESS_TOKEN` environment flag.',
type=str,
)
@click.option(
'--overwrite',
help='When true, overwrites any existing datasets with the same name.',
is_flag=True,
default=False,
)
def download_command(
  url_or_repo: str,
  project_dir: Optional[str],
  dataset_namespace: Optional[str],
  dataset_name: Optional[str],
  hf_token: Optional[str],
  # Click's is_flag option always supplies a bool, so `bool` is the accurate type.
  overwrite: bool = False,
) -> None:
  """Download a Lilac dataset from HuggingFace.

  Forwards all CLI options straight to `download`, including the HF access
  token (for private datasets) and the overwrite flag.
  """
  # Single call with the full argument list; the older 4-argument call was a
  # stale duplicate that ignored `hf_token` and `overwrite`.
  download(url_or_repo, project_dir, dataset_namespace, dataset_name, hf_token, overwrite)


@click.command()
@click.argument(
  'dataset',
  required=True,
)
@click.option(
  '--project_dir',
  help='The project directory to use for the demo. Defaults to `env.LILAC_PROJECT_DIR`.',
  type=str,
)
@click.option(
  '--url_or_repo',
  help='The repo id, or full dataset URL, to use for uploading. For example: lilacai/my-dataset.',
  type=str,
)
@click.option(
  '--public',
  help='When true, makes the dataset public.',
  is_flag=True,
  default=False,
)
@click.option(
  '--readme_suffix',
  help='A suffix string for the readme file.',
  type=str,
)
@click.option(
  '--hf_token',
  help='The HuggingFace access token to use when writing private datasets. '
  'This can also be set via the `HF_ACCESS_TOKEN` environment flag.',
  type=str,
)
def upload_command(
  dataset: str,
  project_dir: Optional[str],
  url_or_repo: Optional[str] = None,
  # `is_flag` options always yield a bool; annotated Optional for symmetry
  # with the other parameters.
  public: Optional[bool] = False,
  readme_suffix: Optional[str] = None,
  hf_token: Optional[str] = None,
) -> None:
  """Upload a Lilac dataset to HuggingFace.

  DATASET identifies the local dataset to upload — presumably the
  '<namespace>/<name>' form used elsewhere in the CLI; confirm against
  `upload`'s signature. All options are forwarded verbatim to `upload`;
  the HF token is only needed for private (non --public) uploads.
  """
  upload(dataset, project_dir, url_or_repo, public, readme_suffix, hf_token)


@click.command()
Expand All @@ -268,6 +319,7 @@ def cli() -> None:
cli.add_command(deploy_project_command, name='deploy-project')
cli.add_command(hf_docker_start_command, name='hf-docker-start')
cli.add_command(download_command, name='download')
cli.add_command(upload_command, name='upload')
cli.add_command(concepts)

if __name__ == '__main__':
Expand Down
4 changes: 4 additions & 0 deletions lilac/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
Filter,
FilterLike,
FilterOp,
GroupsSortBy,
KeywordSearch,
ListOp,
MetadataSearch,
Schema,
SelectGroupsResult,
SelectRowsResult,
SemanticSearch,
SortOrder,
UnaryOp,
)

Expand All @@ -34,4 +36,6 @@
'SelectRowsResult',
'SelectGroupsResult',
'FilterLike',
'SortOrder',
'GroupsSortBy',
]
39 changes: 37 additions & 2 deletions lilac/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,10 @@ def add_labels(
row_ids: Optional[Sequence[str]] = None,
searches: Optional[Sequence[Search]] = None,
filters: Optional[Sequence[FilterLike]] = None,
sort_by: Optional[Sequence[Path]] = None,
sort_order: Optional[SortOrder] = SortOrder.DESC,
limit: Optional[int] = None,
offset: Optional[int] = 0,
include_deleted: bool = False,
value: Optional[str] = 'true',
) -> int:
Expand Down Expand Up @@ -668,6 +672,10 @@ def remove_labels(
row_ids: Optional[Sequence[str]] = None,
searches: Optional[Sequence[Search]] = None,
filters: Optional[Sequence[FilterLike]] = None,
sort_by: Optional[Sequence[Path]] = None,
sort_order: Optional[SortOrder] = SortOrder.DESC,
limit: Optional[int] = None,
offset: Optional[int] = 0,
include_deleted: bool = False,
) -> int:
"""Removes labels from a row, or a set of rows defined by searches and filters.
Expand All @@ -681,24 +689,51 @@ def delete_rows(
row_ids: Optional[Sequence[str]] = None,
searches: Optional[Sequence[Search]] = None,
filters: Optional[Sequence[FilterLike]] = None,
sort_by: Optional[Sequence[Path]] = None,
sort_order: Optional[SortOrder] = SortOrder.DESC,
limit: Optional[int] = None,
offset: Optional[int] = 0,
) -> int:
"""Deletes rows from the dataset.
Returns the number of deleted rows.
"""
return self.add_labels(DELETED_LABEL_NAME, row_ids, searches, filters)
return self.add_labels(
name=DELETED_LABEL_NAME,
row_ids=row_ids,
searches=searches,
filters=filters,
sort_by=sort_by,
sort_order=sort_order,
limit=limit,
offset=offset,
)

def restore_rows(
  self,
  row_ids: Optional[Sequence[str]] = None,
  searches: Optional[Sequence[Search]] = None,
  filters: Optional[Sequence[FilterLike]] = None,
  sort_by: Optional[Sequence[Path]] = None,
  sort_order: Optional[SortOrder] = SortOrder.DESC,
  limit: Optional[int] = None,
  offset: Optional[int] = 0,
) -> int:
  """Undeletes rows from the dataset.

  The rows to restore can be chosen by explicit `row_ids` or by
  searches/filters, optionally sorted and windowed with
  `sort_by`/`sort_order`/`limit`/`offset`.

  Returns the number of restored rows.
  """
  # include_deleted=True is essential: the rows being restored are, by
  # definition, currently deleted and would otherwise be excluded from the
  # selection. The older positional one-line return was a stale duplicate
  # that dropped the sort/limit/offset arguments.
  return self.remove_labels(
    DELETED_LABEL_NAME,
    row_ids=row_ids,
    searches=searches,
    filters=filters,
    sort_by=sort_by,
    sort_order=sort_order,
    limit=limit,
    offset=offset,
    include_deleted=True,
  )

@abc.abstractmethod
def stats(self, leaf_path: Path, include_deleted: bool = False) -> StatsResult:
Expand Down
26 changes: 24 additions & 2 deletions lilac/data/dataset_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2140,6 +2140,10 @@ def add_labels(
row_ids: Optional[Sequence[str]] = None,
searches: Optional[Sequence[Search]] = None,
filters: Optional[Sequence[FilterLike]] = None,
sort_by: Optional[Sequence[Path]] = None,
sort_order: Optional[SortOrder] = SortOrder.DESC,
limit: Optional[int] = None,
offset: Optional[int] = 0,
include_deleted: bool = False,
value: Optional[str] = 'true',
) -> int:
Expand All @@ -2152,7 +2156,14 @@ def add_labels(
insert_row_ids = (
row[ROWID]
for row in self.select_rows(
columns=[ROWID], searches=searches, filters=filters, include_deleted=include_deleted
columns=[ROWID],
searches=searches,
filters=filters,
sort_by=sort_by,
sort_order=sort_order,
limit=limit,
offset=offset,
include_deleted=include_deleted,
)
)

Expand Down Expand Up @@ -2206,6 +2217,10 @@ def remove_labels(
row_ids: Optional[Sequence[str]] = None,
searches: Optional[Sequence[Search]] = None,
filters: Optional[Sequence[FilterLike]] = None,
sort_by: Optional[Sequence[Path]] = None,
sort_order: Optional[SortOrder] = SortOrder.DESC,
limit: Optional[int] = None,
offset: Optional[int] = 0,
include_deleted: bool = False,
) -> int:
# Check if the label file exists.
Expand All @@ -2222,7 +2237,14 @@ def remove_labels(
remove_row_ids = [
row[ROWID]
for row in self.select_rows(
columns=[ROWID], searches=searches, filters=filters, include_deleted=include_deleted
columns=[ROWID],
searches=searches,
filters=filters,
sort_by=sort_by,
sort_order=sort_order,
limit=limit,
offset=offset,
include_deleted=include_deleted,
)
]

Expand Down
95 changes: 95 additions & 0 deletions lilac/data/dataset_labels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,50 @@ def test_add_single_label(make_test_data: TestDataMaker, mocker: MockerFixture)
assert dataset.get_label_names() == ['test_label']


@freeze_time(TEST_TIME)
def test_add_labels_sort_limit(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
  """add_labels with sort_by/sort_order/limit labels only the first N rows of the sort."""
  dataset = make_test_data(TEST_ITEMS)

  # Label the 2 rows with the smallest 'int' values (ascending sort, limit=2).
  num_labels = dataset.add_labels('test_label', sort_by=['int'], sort_order=SortOrder.ASC, limit=2)
  assert num_labels == 2
  # The label field is added to the schema even though not every row is labeled.
  assert dataset.manifest() == DatasetManifest(
    source=TestSource(),
    namespace='test_namespace',
    dataset_name='test_dataset',
    data_schema=schema(
      {
        'str': 'string',
        'int': 'int32',
        'test_label': field(fields={'label': 'string', 'created': 'timestamp'}, label='test_label'),
      }
    ),
    num_items=3,
  )

  # Rows with int=1 and int=2 are labeled; the int=3 row is untouched.
  assert list(dataset.select_rows([PATH_WILDCARD])) == [
    {
      'str': 'a',
      'int': 1,
      'test_label.label': 'true',
      'test_label.created': Timestamp(TEST_TIME),
    },
    {
      'str': 'b',
      'int': 2,
      'test_label.label': 'true',
      'test_label.created': Timestamp(TEST_TIME),
    },
    {
      'str': 'c',
      'int': 3,
      'test_label.label': None,
      'test_label.created': None,
    },
  ]

  assert dataset.get_label_names() == ['test_label']


@freeze_time(TEST_TIME)
def test_add_row_labels(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
dataset = make_test_data(TEST_ITEMS)
Expand Down Expand Up @@ -301,6 +345,57 @@ def test_remove_labels_no_filters(make_test_data: TestDataMaker, mocker: MockerF
assert dataset.get_label_names() == []


@freeze_time(TEST_TIME)
def test_remove_labels_sort_limit(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
  """remove_labels with sort_by/sort_order/limit strips only the first N rows of the sort."""
  dataset = make_test_data(TEST_ITEMS)

  # Add labels to every row.
  num_labels = dataset.add_labels('test_label')
  assert num_labels == 3

  # Remove 2 labels: the two rows with the smallest 'int' values.
  num_labels = dataset.remove_labels(
    'test_label', sort_by=['int'], sort_order=SortOrder.ASC, limit=2
  )
  assert num_labels == 2
  # The label field stays in the schema while at least one row still has it.
  assert dataset.manifest() == DatasetManifest(
    source=TestSource(),
    namespace='test_namespace',
    dataset_name='test_dataset',
    data_schema=schema(
      {
        'str': 'string',
        'int': 'int32',
        'test_label': field(fields={'label': 'string', 'created': 'timestamp'}, label='test_label'),
      }
    ),
    num_items=3,
  )

  # Only the int=3 row keeps its label after the limited removal.
  assert list(dataset.select_rows([PATH_WILDCARD], sort_by=('str',), sort_order=SortOrder.ASC)) == [
    {
      'str': 'a',
      'int': 1,
      'test_label.label': None,
      'test_label.created': None,
    },
    {
      'str': 'b',
      'int': 2,
      'test_label.label': None,
      'test_label.created': None,
    },
    {
      'str': 'c',
      'int': 3,
      'test_label.label': 'true',
      'test_label.created': Timestamp(TEST_TIME),
    },
  ]

  assert dataset.get_label_names() == ['test_label']


@freeze_time(TEST_TIME)
def test_label_overwrites(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
dataset = make_test_data(TEST_ITEMS)
Expand Down
Loading

0 comments on commit bb71881

Please sign in to comment.