Skip to content

Commit

Permalink
formats
Browse files Browse the repository at this point in the history
  • Loading branch information
nsthorat committed Feb 15, 2024
1 parent 78f0c55 commit b91e496
Show file tree
Hide file tree
Showing 10 changed files with 304 additions and 44 deletions.
2 changes: 1 addition & 1 deletion lilac/formats/sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,5 @@ class ShareGPT(DatasetFormat):

input_selectors: ClassVar[dict[str, DatasetFormatInputSelector]] = {
selector.name: selector
for selector in [_SYSTEM_SELECTOR, _HUMAN_SELECTOR, _GPT_SELECTOR, _TOOL_SELECTOR]
for selector in [_HUMAN_SELECTOR, _SYSTEM_SELECTOR, _GPT_SELECTOR, _TOOL_SELECTOR]
}
10 changes: 10 additions & 0 deletions lilac/router_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,13 @@ def restore_rows(
searches=options.searches,
filters=sanitized_filters,
)


@router.get('/{namespace}/{dataset_name}/format_selectors')
def get_format_selectors(namespace: str, dataset_name: str) -> list[str]:
"""Get format selectors for the dataset if a format has been inferred."""
dataset = get_dataset(namespace, dataset_name)
manifest = dataset.manifest()
if manifest.dataset_format:
return list(manifest.dataset_format.input_selectors.keys())
return []
49 changes: 42 additions & 7 deletions lilac/router_dataset_signals.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""Routing endpoints for running signals on datasets."""
from typing import Annotated, Optional
from typing import Annotated, Optional, Union

from fastapi import APIRouter, HTTPException
from fastapi.params import Depends
from pydantic import BaseModel, SerializeAsAny, field_validator
from pydantic import Field as PydanticField

from .auth import UserInfo, get_session_user, get_user_access
from .config import ClusterInputSelectorConfig
from .dataset_format import DatasetFormatInputSelector, get_dataset_format_cls
from .db_manager import get_dataset
from .router_utils import RouteErrorHandler
from .schema import Path
from .schema import Path, PathTuple, normalize_path
from .signal import Signal, resolve_signal
from .tasks import TaskId, get_task_manager, launch_task

Expand Down Expand Up @@ -82,7 +84,9 @@ def run() -> None:
class ClusterOptions(BaseModel):
"""The request for the cluster endpoint."""

input: Path
input: Optional[Path] = None
input_selector: Optional[ClusterInputSelectorConfig] = None

output_path: Optional[Path] = None
use_garden: bool = PydanticField(
default=False, description='Accelerate computation by running remotely on Lilac Garden.'
Expand All @@ -107,14 +111,45 @@ def cluster(
if not get_user_access(user).dataset.compute_signals:
raise HTTPException(401, 'User does not have access to compute clusters over this dataset.')

path_str = '.'.join(map(str, options.input))
task_name = f'[{namespace}/{dataset_name}] Clustering "{path_str}"'
task_id = get_task_manager().task_id(name=task_name)
if options.input is None and options.input_selector is None:
raise HTTPException(400, 'Either input or input_selector must be provided.')

dataset = get_dataset(namespace, dataset_name)
manifest = dataset.manifest()

cluster_input: Optional[Union[DatasetFormatInputSelector, PathTuple]] = None
if options.input:
path_str = '.'.join(map(str, options.input))
task_name = f'[{namespace}/{dataset_name}] Clustering "{path_str}"'
cluster_input = normalize_path(options.input)
elif options.input_selector:
dataset_format = manifest.dataset_format
if dataset_format is None:
raise ValueError('Dataset format is not defined.')

format_cls = get_dataset_format_cls(dataset_format.name)
if format_cls is None:
raise ValueError(f'Unknown format: {c.input_selector.format}')

format = format_cls()
if format != manifest.dataset_format:
raise ValueError(
f'Cluster input format {c.input_selector.format} does not match '
f'dataset format {manifest.dataset_format}'
)

cluster_input = format_cls.input_selectors[c.input_selector.selector]

task_name = (
f'[{namespace}/{dataset_name}] Clustering using input selector '
f'"{options.input_selector.selector}"'
)

task_id = get_task_manager().task_id(name=task_name)

def run() -> None:
dataset.cluster(
options.input,
cluster_input,
options.output_path,
use_garden=options.use_garden,
overwrite=options.overwrite,
Expand Down
120 changes: 120 additions & 0 deletions notebooks/Clustering copy.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Clustering\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook accompanies the [Cluster a dataset](https://docs.lilacml.com/datasets/dataset_cluster.html) guide.\n",
"Let's start by loading a small dataset of multi-turn conversations between a human and a chatbot:\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset \"capybara\" written to ./datasets/local/capybara\n"
]
}
],
"source": [
"import lilac as ll\n",
"\n",
"ds = ll.get_dataset('local', 'OpenHermes-2.5-100k')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can cluster the `input` field under the `conversation` array by calling:\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[local/capybara][1 shards] map \"extract_text\" to \"('conversation_input__cluster',)\": 100%|██████████| 16006/16006 [00:00<00:00, 30424.61it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wrote map output to conversation_input__cluster-00000-of-00001.parquet\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[local/capybara][1 shards] map \"compute_clusters\" to \"('conversation_input__cluster',)\": 0%| | 0/16006 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"jinaai/jina-embeddings-v2-small-en using device: mps:0\n"
]
}
],
"source": [
"ds.cluster('conversation.*.input')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start a web server to visualize the data\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ll.start_server()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
119 changes: 84 additions & 35 deletions web/blueprint/src/lib/components/ComputeClusterModal.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
</script>

<script lang="ts">
import {clusterMutation} from '$lib/queries/datasetQueries';
import {clusterMutation, queryFormatSelectors} from '$lib/queries/datasetQueries';
import {queryAuthInfo} from '$lib/queries/serverQueries';
import type {Path} from '$lilac';
import {serializePath, type Path} from '$lilac';
import {
ComposedModal,
ModalBody,
ModalFooter,
ModalHeader,
Select,
SelectItem,
Toggle
} from 'carbon-components-svelte';
import FieldSelect from './commands/selectors/FieldSelect.svelte';
Expand All @@ -36,6 +38,29 @@
$: canComputeRemotely = $authInfo.data?.access.dataset.execute_remotely;
$: formatSelectorsQuery =
options != null ? queryFormatSelectors(options?.namespace, options?.datasetName) : null;
let selectedFormatSelector: string | undefined = undefined;
let formatSelectors: string[] | undefined = undefined;
let outputColumn: string | undefined = undefined;
$: outputColumnRequired = formatSelectors != null;
$: {
if (options?.output_path != null) {
outputColumn = serializePath(options.output_path);
}
}
$: {
if (
formatSelectorsQuery != null &&
$formatSelectorsQuery != null &&
$formatSelectorsQuery.data != null
) {
selectedFormatSelector = $formatSelectorsQuery.data[0];
formatSelectors = $formatSelectorsQuery.data;
}
}
function close() {
store.set(null);
}
Expand All @@ -47,7 +72,7 @@
{
input: options.input,
use_garden: options.use_garden,
output_path: options.output_path,
output_path: outputColumn,
overwrite: options.overwrite
}
]);
Expand All @@ -59,47 +84,71 @@
<ComposedModal open on:submit={submit} on:close={close}>
<ModalHeader title="Compute clusters" />
<ModalBody hasForm>
<div class="max-w-2xl">
<FieldSelect
filter={f => f.dtype?.type === 'string'}
defaultPath={options.input}
bind:path={options.input}
labelText="Field"
/>
</div>
<div class="mt-8">
<div class="label mb-2 font-medium text-gray-700">Use Garden</div>
<div class="label mb-2 text-sm text-gray-700">
Accelerate computation by running remotely on <a
href="https://lilacml.com/#garden"
target="_blank">Lilac Garden</a
>
<div class="flex max-w-2xl flex-col gap-y-8">
<div>
<FieldSelect
filter={f => f.dtype?.type === 'string'}
defaultPath={options.input}
bind:path={options.input}
labelText="Field"
/>
</div>
<Toggle
disabled={!canComputeRemotely}
labelA={'False'}
labelB={'True'}
bind:toggled={options.use_garden}
hideLabel
/>
{#if !canComputeRemotely}
<div class="mt-2">
<a href="https://forms.gle/Gz9cpeKJccNar5Lq8" target="_blank">
Sign up for Lilac Garden
</a>
to enable this feature.
{#if formatSelectors != null}
<div>
<div class="label text-s mb-2 font-medium text-gray-700">Selector</div>
<Select hideLabel={true} bind:selected={selectedFormatSelector} required>
{#each formatSelectors as formatSelector}
<SelectItem value={formatSelector} text={formatSelector} />
{/each}
</Select>
</div>
{/if}
</div>
<div class="mt-8">
<div class="label text-s mb-2 font-medium text-gray-700">Overwrite</div>
<Toggle labelA={'False'} labelB={'True'} bind:toggled={options.overwrite} hideLabel />
<div>
<div class="label text-s mb-2 font-medium text-gray-700">
{outputColumnRequired ? '*' : ''} Output column
</div>
<input
required={outputColumnRequired}
class="h-full w-full rounded border border-neutral-300 p-2"
placeholder="Choose a new column name to write clusters"
bind:value={outputColumn}
/>
</div>
<div>
<div class="label mb-2 font-medium text-gray-700">Use Garden</div>
<div class="label text-sm text-gray-700">
Accelerate computation by running remotely on <a
href="https://lilacml.com/#garden"
target="_blank">Lilac Garden</a
>
</div>
<Toggle
disabled={!canComputeRemotely}
labelA={'False'}
labelB={'True'}
bind:toggled={options.use_garden}
hideLabel
/>
{#if !canComputeRemotely}
<div>
<a href="https://forms.gle/Gz9cpeKJccNar5Lq8" target="_blank">
Sign up for Lilac Garden
</a>
to enable this feature.
</div>
{/if}
</div>
<div>
<div class="label text-s mb-2 font-medium text-gray-700">Overwrite</div>
<Toggle labelA={'False'} labelB={'True'} bind:toggled={options.overwrite} hideLabel />
</div>
</div>
</ModalBody>
<ModalFooter
primaryButtonText="Cluster"
secondaryButtonText="Cancel"
on:click:button--secondary={close}
primaryButtonDisabled={outputColumnRequired && !outputColumn}
/>
</ComposedModal>
{/if}
4 changes: 4 additions & 0 deletions web/blueprint/src/lib/queries/datasetQueries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,7 @@ function invalidateQueriesLabelEdit(
]);
}
}
export const queryFormatSelectors = createApiQuery(
DatasetsService.getFormatSelectors,
DATASETS_TAG
);
Loading

0 comments on commit b91e496

Please sign in to comment.