Fix issues with exporting data.
nsthorat committed Feb 2, 2024
1 parent 9185105 commit 16e515b
Showing 7 changed files with 144 additions and 38 deletions.
2 changes: 1 addition & 1 deletion lilac/data/dataset_duckdb.py
@@ -3381,7 +3381,7 @@ def to_json(
file.write(orjson.dumps(row))
file.write('\n'.encode('utf-8'))
else:
- file.write(orjson.dumps(rows))
+ file.write(orjson.dumps(list(rows)))
log(f'Dataset exported to {filepath}')

@override
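The one-line fix above exists because `orjson.dumps` refuses lazy iterators: it serializes lists, dicts, and a few extension types, but a generator raises `TypeError`. A minimal sketch of the failure mode, assuming the row source arrives as a generator (the sample values are made up):

```python
import orjson

# Assumption: rows arrives lazily, e.g. from a select_rows-style iterator.
rows = ({'text': t} for t in ['hello', 'everybody'])

try:
    orjson.dumps(rows)  # generators are not JSON serializable
except TypeError as err:
    print(err)

# Materializing first, as the patched line does, produces the expected array.
print(orjson.dumps(list(rows)))  # b'[{"text":"hello"},{"text":"everybody"}]'
```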
81 changes: 68 additions & 13 deletions lilac/data/dataset_export_test.py
@@ -101,15 +101,15 @@ def test_export_to_json(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
filepath = tmp_path / 'dataset.json'
dataset.to_json(filepath)

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
parsed_items = [json.loads(line) for line in f.readlines()]

assert parsed_items == [{'text': 'hello'}, {'text': 'everybody'}]

# Include signals.
dataset.to_json(filepath, include_signals=True)

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
parsed_items = [json.loads(line) for line in f.readlines()]

assert parsed_items == [
@@ -126,7 +126,7 @@ def test_export_to_json(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
include_signals=True,
)

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
parsed_items = [json.loads(line) for line in f.readlines()]

assert parsed_items == [
@@ -138,7 +138,62 @@ def test_export_to_json(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
filepath, filters=[('text.test_signal.flen', 'less_equal', '5')], include_signals=True
)

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
parsed_items = [json.loads(line) for line in f.readlines()]

assert parsed_items == [{'text': {VALUE_KEY: 'hello', 'test_signal': {'flen': 5.0, 'len': 5}}}]


+ def test_export_to_jsonl(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
+ dataset = make_test_data([{'text': 'hello'}, {'text': 'everybody'}])
+ dataset.compute_signal(TestSignal(), 'text')
+
+ # Download all columns.
+ filepath = tmp_path / 'dataset.json'
+ dataset.to_json(filepath, jsonl=True)
+
+ with open(filepath, 'r') as f:
+ parsed_items = [json.loads(line) for line in f.readlines()]
+
+ assert parsed_items == [{'text': 'hello'}, {'text': 'everybody'}]
+
+ # Include signals.
+ dataset.to_json(filepath, jsonl=True, include_signals=True)
+
+ with open(filepath, 'r') as f:
+ parsed_items = [json.loads(line) for line in f.readlines()]
+
+ assert parsed_items == [
+ {'text': {VALUE_KEY: 'hello', 'test_signal': {'flen': 5.0, 'len': 5}}},
+ {'text': {VALUE_KEY: 'everybody', 'test_signal': {'flen': 9.0, 'len': 9}}},
+ ]
+
+ # Download a subset of columns with filter.
+ filepath = tmp_path / 'dataset2.json'
+ dataset.to_json(
+ filepath,
+ jsonl=True,
+ columns=['text', 'text.test_signal'],
+ filters=[('text.test_signal.len', 'greater', '6')],
+ include_signals=True,
+ )
+
+ with open(filepath, 'r') as f:
+ parsed_items = [json.loads(line) for line in f.readlines()]
+
+ assert parsed_items == [
+ {'text': {VALUE_KEY: 'everybody', 'test_signal': {'flen': 9.0, 'len': 9}}}
+ ]
+
+ filepath = tmp_path / 'dataset3.json'
+ dataset.to_json(
+ filepath,
+ jsonl=True,
+ filters=[('text.test_signal.flen', 'less_equal', '5')],
+ include_signals=True,
+ )
+
+ with open(filepath, 'r') as f:
+ parsed_items = [json.loads(line) for line in f.readlines()]
+
+ assert parsed_items == [{'text': {VALUE_KEY: 'hello', 'test_signal': {'flen': 5.0, 'len': 5}}}]
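For contrast, a standard-library sketch (not Lilac code) of the two on-disk shapes these tests exercise: `jsonl=True` writes one object per line, which is why the tests can parse the file line by line, while `jsonl=False` writes a single JSON array in one call:

```python
import json

rows = [{'text': 'hello'}, {'text': 'everybody'}]

# jsonl=True: newline-delimited objects, one json.loads per line.
with open('dataset.jsonl', 'w') as f:
    for row in rows:
        f.write(json.dumps(row) + '\n')

# jsonl=False: a single array serialized in one shot (hence list(rows) above).
with open('dataset.json', 'w') as f:
    json.dump(rows, f)
```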
@@ -152,7 +207,7 @@ def test_export_to_csv(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
filepath = tmp_path / 'dataset.csv'
dataset.to_csv(filepath)

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
rows = list(csv.reader(f))

assert rows == [
@@ -172,7 +227,7 @@ def test_export_to_csv_include_signals(
filepath = tmp_path / 'dataset.csv'
dataset.to_csv(filepath, include_signals=True)

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
rows = list(csv.reader(f))

assert rows == [
@@ -196,7 +251,7 @@ def test_export_to_csv_subset_source_columns(
filepath = tmp_path / 'dataset.csv'
dataset.to_csv(filepath, columns=['age', 'metric'])

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
rows = list(csv.reader(f))

assert rows == [
@@ -232,7 +287,7 @@ def test_export_to_csv_subset_of_nested_data(
filepath = tmp_path / 'dataset.csv'
dataset.to_csv(filepath, columns=['doc.content', 'doc.paragraphs.*.text'])

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
rows = list(csv.reader(f))

assert rows == [
@@ -323,7 +378,7 @@ def test_label_and_export_by_excluding(
filepath = tmp_path / 'dataset.json'
dataset.to_json(filepath)

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
parsed_items = [json.loads(line) for line in f.readlines()]

assert parsed_items == [{f'{DELETED_LABEL_NAME}': None, 'text': 'a'}]
@@ -332,7 +387,7 @@ def test_label_and_export_by_excluding(
filepath = tmp_path / 'dataset.json'
dataset.to_json(filepath, include_deleted=True)

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
parsed_items = [json.loads(line) for line in f.readlines()]

assert parsed_items == [
@@ -357,7 +412,7 @@ def test_include_multiple_labels(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
filepath = tmp_path / 'dataset.json'
dataset.to_json(filepath, columns=['text'], include_labels=['good', 'very_good'])

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
parsed_items = [json.loads(line) for line in f.readlines()]

parsed_items = sorted(parsed_items, key=lambda x: x['text'])
@@ -373,7 +428,7 @@ def test_exclude_multiple_labels(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
filepath = tmp_path / 'dataset.json'
dataset.to_json(filepath, columns=['text'], exclude_labels=['bad', 'very_bad'])

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
parsed_items = [json.loads(line) for line in f.readlines()]

parsed_items = sorted(parsed_items, key=lambda x: x['text'])
@@ -389,7 +444,7 @@ def test_exclude_trumps_include(make_test_data: TestDataMaker, tmp_path: pathlib.Path) -> None:
filepath = tmp_path / 'dataset.json'
dataset.to_json(filepath, columns=['text'], include_labels=['good'], exclude_labels=['bad'])

- with open(filepath) as f:
+ with open(filepath, 'r') as f:
parsed_items = [json.loads(line) for line in f.readlines()]

assert parsed_items == [{'text': 'b'}]
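Read together, the label tests above pin down one selection predicate; here is a hedged sketch of that rule (the helper name is ours, not Lilac's):

```python
def should_export(row_labels: set, include: set, exclude: set) -> bool:
    """Selection rule the tests assert: exclusion always wins over inclusion."""
    if row_labels & exclude:
        return False  # test_exclude_trumps_include
    if include:
        return bool(row_labels & include)  # any one included label suffices
    return True

# A row labeled both 'good' and 'bad' is dropped despite 'good' being included.
assert not should_export({'good', 'bad'}, {'good'}, {'bad'})
```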
30 changes: 22 additions & 8 deletions lilac/router_dataset.py
@@ -155,6 +155,7 @@ class SelectRowsOptions(BaseModel):
offset: Optional[int] = None
combine_columns: Optional[bool] = None
include_deleted: bool = False
+ exclude_signals: bool = False


class SelectRowsSchemaOptions(BaseModel):
@@ -206,6 +207,7 @@ def select_rows(
offset=options.offset,
combine_columns=options.combine_columns or False,
include_deleted=options.include_deleted,
+ exclude_signals=options.exclude_signals,
user=user,
)
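A hedged sketch of how a client might set the new flag; `exclude_signals` comes from the diff above, while the other field values are illustrative (the UI diff below passes `exclude_signals: !includeSignals` for its preview query):

```python
# Assumption: limit is an existing SelectRowsOptions field, mirroring the UI's
# preview query; only exclude_signals is new in this commit.
options = SelectRowsOptions(
    limit=3,
    combine_columns=True,
    exclude_signals=True,  # drop signal-derived columns from the results
)
```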

@@ -303,6 +305,7 @@ class ExportOptions(BaseModel):
columns: Sequence[Path] = []
include_labels: Sequence[str] = []
exclude_labels: Sequence[str] = []
+ include_signals: bool = False
# Note: "__deleted__" is "just" another label, and the UI
# will default to adding the "__deleted__" label to the exclude_labels list. If the user wants
# to include deleted items, they can remove the "__deleted__" label from the exclude_labels list.
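A small illustration of the comment above, with the label value quoted from that comment ("__deleted__" is just another label):

```python
DELETED_LABEL_NAME = '__deleted__'

# Default UI behavior: deleted rows are excluded like any other labeled rows.
exclude_labels = [DELETED_LABEL_NAME, 'bad']

# "Include deleted items" simply removes that label from the exclusions.
exclude_labels = [label for label in exclude_labels if label != DELETED_LABEL_NAME]
```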
@@ -328,20 +331,31 @@ def export_dataset(namespace: str, dataset_name: str, options: ExportOptions) ->

if options.format == 'csv':
dataset.to_csv(
- options.filepath, options.columns, [], options.include_labels, options.exclude_labels
+ filepath=options.filepath,
+ columns=options.columns,
+ filters=[],
+ include_labels=options.include_labels,
+ exclude_labels=options.exclude_labels,
+ include_signals=options.include_signals,
)
elif options.format == 'json':
dataset.to_json(
- options.filepath,
- options.jsonl or False,
- options.columns,
- [],
- options.include_labels,
- options.exclude_labels,
+ filepath=options.filepath,
+ jsonl=options.jsonl or False,
+ columns=options.columns,
+ filters=[],
+ include_labels=options.include_labels,
+ exclude_labels=options.exclude_labels,
+ include_signals=options.include_signals,
)
elif options.format == 'parquet':
dataset.to_parquet(
- options.filepath, options.columns, [], options.include_labels, options.exclude_labels
+ filepath=options.filepath,
+ columns=options.columns,
+ filters=[],
+ include_labels=options.include_labels,
+ exclude_labels=options.exclude_labels,
+ include_signals=options.include_signals,
)
else:
raise ValueError(f'Unknown format: {options.format}')
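Finally, a hedged end-to-end sketch of the refactored call path. The `ExportOptions` fields follow the diff above (`format`, `filepath`, and `jsonl` are read inside `export_dataset`); the namespace, dataset name, and filepath are made-up values:

```python
options = ExportOptions(
    format='json',
    filepath='/tmp/dataset.jsonl',
    jsonl=True,
    columns=['text'],
    include_labels=[],
    exclude_labels=['__deleted__'],
    include_signals=True,  # new in this commit; defaults to False
)
export_dataset('local', 'my_dataset', options)
```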
63 changes: 47 additions & 16 deletions web/blueprint/src/lib/components/datasetView/ExportModal.svelte
@@ -26,7 +26,6 @@
NotificationActionButton,
RadioButton,
RadioButtonGroup,
- SkeletonPlaceholder,
SkeletonText,
TextArea,
TextInput,
@@ -48,14 +47,29 @@
const datasetViewStore = getDatasetViewContext();
- $: ({sourceFields, enrichedFields, labelFields, mapFields} = getFields(schema));
+ $: ({sourceFields, signalFields: signalFields, labelFields, mapFields} = getFields(schema));
let checkedSourceFields: LilacField[] | undefined = undefined;
let checkedLabeledFields: LilacField[] = [];
- let checkedEnrichedFields: LilacField[] = [];
+ let checkedSignalFields: LilacField[] = [];
let checkedMapFields: LilacField[] = [];
let includeOnlyLabels: boolean[] = [];
let excludeLabels: boolean[] = [];
+ let includeSignals = false;
+ function includeSignalsChecked(e: Event) {
+ includeSignals = (e.target as HTMLInputElement).checked;
+ if (includeSignals) {
+ checkedSignalFields = signalFields;
+ } else {
+ checkedSignalFields = [];
+ }
+ }
+ function signalCheckboxClicked() {
+ if (checkedSignalFields.length > 0) {
+ includeSignals = true;
+ }
+ }
// Default the checked source fields to all of them.
$: {
@@ -67,16 +81,19 @@
$: exportFields = [
...(checkedSourceFields || []),
...checkedLabeledFields,
- ...checkedEnrichedFields,
+ ...checkedSignalFields,
...checkedMapFields
];
+ $: console.log('export fields:', exportFields);
$: previewRows =
exportFields.length > 0 && open
? querySelectRows($datasetViewStore.namespace, $datasetViewStore.datasetName, {
columns: exportFields.map(x => x.path),
limit: 3,
- combine_columns: true
+ combine_columns: true,
+ exclude_signals: !includeSignals
})
: null;
$: exportDisabled =
@@ -87,21 +104,21 @@
const petalFields = petals(schema).filter(f => !isEmbeddingField(f));
const labelFields = allFields.filter(f => isLabelRootField(f));
- const enrichedFields = allFields
- .filter(f => isSignalField(f) || isClusterField(f))
+ const signalFields = allFields
+ .filter(f => isSignalField(f))
.filter(f => !childFields(f).some(f => f.dtype?.type === 'embedding'));
- const mapFields = allFields.filter(f => isMapField(f));
+ const mapFields = allFields.filter(f => isMapField(f) || isClusterField(f));
const sourceFields = petalFields.filter(
f =>
!labelFields.includes(f) &&
- !enrichedFields.includes(f) &&
+ !signalFields.includes(f) &&
!mapFields.includes(f) &&
// Labels are special in that we only show the root of the label field so the children do
// not show up in the labelFields.
!isLabelField(f)
);
- return {sourceFields, enrichedFields, labelFields, mapFields};
+ return {sourceFields, signalFields, labelFields, mapFields};
}
async function submit() {
@@ -113,7 +130,8 @@
jsonl,
columns: exportFields.map(x => x.path),
include_labels: labelFields.filter((_, i) => includeOnlyLabels[i]).map(x => x.path[0]),
- exclude_labels: labelFields.filter((_, i) => excludeLabels[i]).map(x => x.path[0])
+ exclude_labels: labelFields.filter((_, i) => excludeLabels[i]).map(x => x.path[0]),
+ include_signals: includeSignals
};
$exportDataset.mutate([namespace, datasetName, options]);
}
@@ -150,7 +168,7 @@
<p class="text-red-600" class:invisible={exportFields.length > 0}>
No fields selected. Please select at least one field to export.
</p>
<div class="flex flex-wrap gap-x-8">
<div class="flex flex-row gap-x-8">
<section>
<h4>Source fields</h4>
{#if checkedSourceFields != null}
@@ -163,10 +181,23 @@
<FieldList fields={labelFields} bind:checkedFields={checkedLabeledFields} />
</section>
{/if}
- {#if enrichedFields.length > 0}
+ {#if signalFields.length > 0}
<section>
- <h4>Enriched fields</h4>
- <FieldList fields={enrichedFields} bind:checkedFields={checkedEnrichedFields} />
+ <div class="flex flex-row items-center">
+ <div class="w-8">
+ <Checkbox
+ hideLabel
+ checked={includeSignals}
+ on:change={e => includeSignalsChecked(e)}
+ />
+ </div>
+ <h4>Signal fields</h4>
+ </div>
+ <FieldList
+ fields={signalFields}
+ bind:checkedFields={checkedSignalFields}
+ on:change={() => signalCheckboxClicked()}
+ />
</section>
{/if}
{#if mapFields.length > 0}
@@ -241,7 +272,7 @@
hideCloseButton
/>
{:else if $exportDataset.isLoading}
- <SkeletonPlaceholder />
+ <SkeletonText lines={5} class="mt-4" />
{:else if $exportDataset.data}
<div class="export-success">
<InlineNotification kind="success" lowContrast hideCloseButton>