Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into nik-sig-export
Browse files Browse the repository at this point in the history
  • Loading branch information
nsthorat committed Feb 2, 2024
2 parents 16e515b + 0c98360 commit 8e9ca2a
Show file tree
Hide file tree
Showing 15 changed files with 172 additions and 69 deletions.
33 changes: 27 additions & 6 deletions lilac/concepts/concept.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from joblib import Parallel, delayed
from pydantic import BaseModel, field_validator
from sklearn.base import clone
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve, roc_auc_score
from sklearn.model_selection import KFold
from sklearn.utils.validation import check_is_fitted

from ..embeddings.embedding import get_embed_fn
from ..signal import TextEmbeddingSignal, get_signal_cls
Expand Down Expand Up @@ -140,6 +142,15 @@ class ConceptMetrics(BaseModel):
overall: OverallScore


def _is_fitted(model: LogisticRegression) -> bool:
    """Return whether `model` has been fitted.

    Wraps `check_is_fitted`, which signals an unfitted estimator by raising
    `NotFittedError`, so that callers can branch on a plain boolean instead.
    Any other exception raised by the check is propagated unchanged.
    """
    try:
        check_is_fitted(model)
    except NotFittedError:
        # The estimator has not been fitted yet.
        return False
    else:
        return True


@dataclasses.dataclass
class LogisticEmbeddingModel:
"""A model that uses logistic regression with embeddings."""
Expand All @@ -155,7 +166,10 @@ def __post_init__(self) -> None:

def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
"""Get the scores for the provided embeddings."""
y_probs = self._model.predict_proba(embeddings)[:, 1]
if _is_fitted(self._model):
y_probs = self._model.predict_proba(embeddings)[:, 1]
else:
y_probs = np.ones(len(embeddings)) * 0.5
# Map [0, threshold, 1] to [0, 0.5, 1].
power = np.log(self._threshold) / np.log(0.5)
return y_probs**power
Expand All @@ -173,7 +187,9 @@ def _setup_training(
def fit(self, embeddings: np.ndarray, labels: list[bool]) -> None:
"""Fit the model to the provided embeddings and labels."""
label_set = set(labels)
if len(label_set) < 2:
if len(label_set) == 0:
return
elif len(label_set) < 2:
dim = embeddings.shape[1]
random_vector = np.random.randn(dim).astype(np.float32)
random_vector /= np.linalg.norm(random_vector)
Expand Down Expand Up @@ -206,7 +222,10 @@ def _fit_and_score(
if len(set(y_train)) < 2:
return np.array([]), np.array([])
model.fit(X_train, y_train)
y_pred = model.predict_proba(X_test)[:, 1]
if _is_fitted(model):
y_pred = model.predict_proba(X_test)[:, 1]
else:
y_pred = np.ones_like(y_test) * 0.5
return y_test, y_pred

# Compute the metrics for each validation fold in parallel.
Expand Down Expand Up @@ -298,7 +317,11 @@ def score_embeddings(self, draft: DraftId, embeddings: np.ndarray) -> np.ndarray

def coef(self, draft: DraftId = DRAFT_MAIN) -> np.ndarray:
"""Get the coefficients of the underlying ML model."""
return self._get_logistic_model(draft)._model.coef_.reshape(-1)
model = self._get_logistic_model(draft)
if _is_fitted(model._model):
return model._model.coef_.reshape(-1)
else:
return np.zeros(0)

def _get_logistic_model(self, draft: DraftId = DRAFT_MAIN) -> LogisticEmbeddingModel:
"""Get the logistic model for the provided draft."""
Expand Down Expand Up @@ -345,8 +368,6 @@ def _compute_embeddings(self, concept: Concept) -> None:
concept_embeddings: dict[str, np.ndarray] = {}

examples = concept.data.items()
if not examples:
raise ValueError(f'Cannot sync concept "{concept.concept_name}". It has no examples.')

# Compute the embeddings for the examples with cache miss.
texts_of_missing_embeddings: dict[str, str] = {}
Expand Down
14 changes: 14 additions & 0 deletions lilac/concepts/db_concept_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,17 @@ def test_embedding_not_found_in_map(

with pytest.raises(ValueError, match='Example "unknown text" not in embedding map'):
model_db.sync(model.namespace, model.concept_name, model.embedding_name)

def test_empty_concept(
    self, concept_db_cls: Type[ConceptDB], model_db_cls: Type[ConceptModelDB]
) -> None:
    """Syncing a model for a concept that has no examples leaves it in sync."""
    db = concept_db_cls()
    models = model_db_cls(db)

    # Create the concept but deliberately add no examples to it.
    db.create(namespace='test', name='test_concept', type=ConceptType.TEXT)
    created = models.create('test', 'test_concept', embedding_name='test_embedding')
    synced = models.sync(created.namespace, created.concept_name, created.embedding_name)

    # Make sure the model is in sync.
    assert models.in_sync(synced) is True
13 changes: 6 additions & 7 deletions lilac/embeddings/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@ def setup(self) -> None:
azure_api_endpoint = env('AZURE_OPENAI_ENDPOINT')

if not api_key and not azure_api_key:
raise ValueError('`OPENAI_API_KEY` or `AZURE_OPENAI_KEY` '
'environment variables not set, please set one.')
raise ValueError(
'`OPENAI_API_KEY` or `AZURE_OPENAI_KEY` ' 'environment variables not set, please set one.'
)
if api_key and azure_api_key:
raise ValueError(
'Both `OPENAI_API_KEY` and `AZURE_OPENAI_KEY` '
'environment variables are set, please set only one.')
'environment variables are set, please set only one.'
)

try:
import openai
Expand All @@ -61,16 +63,13 @@ def setup(self) -> None:
)

else:

if api_key:
self._client = openai.OpenAI(api_key=api_key)
self._azure = False

elif azure_api_key:
self._client = openai.AzureOpenAI(
api_key=azure_api_key,
api_version=azure_api_version,
azure_endpoint=azure_api_endpoint
api_key=azure_api_key, api_version=azure_api_version, azure_endpoint=azure_api_endpoint
)
self._azure = True
OpenAIEmbedding.local_batch_size = AZURE_OPENAI_BATCH_SIZE
Expand Down
2 changes: 1 addition & 1 deletion lilac/project_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def test_create_project_and_set_env(tmp_path_factory: pytest.TempPathFacto


async def test_create_project_and_set_env_from_env(
tmp_path_factory: pytest.TempPathFactory
tmp_path_factory: pytest.TempPathFactory,
) -> None:
tmp_path = str(tmp_path_factory.mktemp('test_project'))

Expand Down
2 changes: 1 addition & 1 deletion web/blueprint/src/lib/components/Page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
on:click={() => ($navStore.open = true)}><SidePanelOpen /></button
>
<div
class="flex flex-grow flex-row items-center justify-between justify-items-center gap-x-6 py-2 pr-12"
class="flex flex-grow flex-row items-center justify-between justify-items-center gap-x-6 py-2 pr-8"
>
<div class="flex flex-row items-center">
<div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
$: filters = $datasetViewStore.query.filters;
</script>

<div class="relative mx-8 my-2 flex items-center justify-between pr-4">
<div class="relative mx-8 my-2 flex items-center justify-between">
<div class="flex w-full items-center justify-between gap-x-6 gap-y-2">
<div class="flex w-full flex-row items-center gap-x-4">
<div class="flex items-center gap-x-2">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,14 @@

<div>
{#if firstRowId != null}
<div class="text-xl text-gray-700">Preview</div>
<div class="mb-2 text-xl text-gray-700">Preview</div>
<RowItem
rowId={firstRowId}
index={0}
totalNumRows={$firstRow?.data?.total_num_rows}
{mediaFields}
{highlightedFields}
datasetViewHeight={320}
/>
{/if}
</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@
<SkeletonText />
{:else}
<div class="flex flex-col gap-y-6">
<section class="flex flex-col gap-y-4">
<section class="flex flex-col gap-y-2">
<div class="text-lg text-gray-700">Media fields</div>
<div class="text-sm text-gray-500">
<div class="mb-2 text-sm text-gray-500">
Media fields are text fields that are rendered large in the dataset viewer. They are the
fields on which you can compute signals, embeddings, search, and label.
</div>
Expand Down Expand Up @@ -255,7 +255,7 @@
</section>

<section class="flex flex-col gap-y-1">
<div class="text-lg text-gray-700">View type</div>
<div class="mb-2 text-lg text-gray-700">View type</div>
<div class="flex gap-x-2">
<Chip
icon={Table}
Expand Down
31 changes: 19 additions & 12 deletions web/blueprint/src/lib/components/datasetView/ItemMedia.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
type LilacValueNode,
type Path
} from '$lilac';
import {SkeletonText} from 'carbon-components-svelte';
import {
CatalogPublish,
ChevronDown,
Expand All @@ -56,6 +55,7 @@
// The root path contains the sub-path up to the point of this leaf.
export let rootPath: Path | undefined = undefined;
export let isFetching: boolean | undefined = undefined;
export let datasetViewHeight: number | undefined = undefined;
let childPathParts: string[];
// Find all the children path parts that match a media field.
Expand All @@ -74,7 +74,7 @@
if (lastMediaPath == null) continue;
const subPath = [...rootPath, field.path[rootPath.length]];
const valueNodes = getValueNodes(row!, subPath);
const valueNodes = row != null ? getValueNodes(row, subPath) : [];
for (const childNode of valueNodes) {
const childPath = L.path(childNode)![rootPath.length];
if (childPath != null) {
Expand Down Expand Up @@ -115,7 +115,7 @@
}
}
$: valueNodes = row != null ? getValueNodes(row, rootPath!) : [];
$: valueNodes = row != null ? getValueNodes(row, rootPath!) : null;
// The child component will communicate this back upwards to this component.
let textIsOverBudget = false;
Expand All @@ -126,7 +126,7 @@
const datasetViewStore = getDatasetViewContext();
const appSettings = getSettingsContext();
$: value = L.value(valueNodes[0]) as string;
$: value = valueNodes != null ? (L.value(valueNodes[0]) as string) : null;
$: settings = querySettings($datasetViewStore.namespace, $datasetViewStore.datasetName);
Expand Down Expand Up @@ -329,23 +329,21 @@
{/if}

<div class="grow pt-1">
{#if isFetching}
<SkeletonText class="!w-80" />
{:else if value == null || row == null}
<span class="ml-12 italic">null</span>
{:else if colCompareState == null && spanValuePaths != null && field != null}
{#if colCompareState == null && field != null}
<ItemMediaTextContent
hidden={markdown}
text={value}
{row}
path={rootPath}
{field}
isExpanded={userExpanded}
spanPaths={spanValuePaths.spanPaths}
spanValueInfos={spanValuePaths.spanValueInfos}
spanPaths={spanValuePaths?.spanPaths || []}
spanValueInfos={spanValuePaths?.spanValueInfos || []}
{datasetViewStore}
embeddings={computedEmbeddings}
{viewType}
{isFetching}
{datasetViewHeight}
bind:textIsOverBudget
/>
<div class="markdown w-full" class:hidden={!markdown}>
Expand All @@ -354,7 +352,14 @@
</div>
</div>
{:else if colCompareState != null}
<ItemMediaDiff {row} {colCompareState} bind:textIsOverBudget isExpanded={userExpanded} />
<ItemMediaDiff
{row}
{colCompareState}
bind:textIsOverBudget
isExpanded={userExpanded}
{datasetViewHeight}
{isFetching}
/>
{/if}
</div>
</div>
Expand Down Expand Up @@ -387,6 +392,8 @@
{mediaFields}
{row}
{highlightedFields}
{datasetViewHeight}
{isFetching}
/>
</div>
{/each}
Expand Down
39 changes: 28 additions & 11 deletions web/blueprint/src/lib/components/datasetView/ItemMediaDiff.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,29 @@
import type * as Monaco from 'monaco-editor/esm/vs/editor/editor.api';
import {onDestroy, onMount} from 'svelte';
import {getMonaco, MONACO_OPTIONS} from '$lib/monaco';
import {
DEFAULT_HEIGHT_PEEK_SCROLL_PX,
DEFAULT_HEIGHT_PEEK_SINGLE_ITEM_PX,
getMonaco,
MONACO_OPTIONS
} from '$lib/monaco';
import {getDatasetViewContext, type ColumnComparisonState} from '$lib/stores/datasetViewStore';
import {getDisplayPath} from '$lib/view_utils';
import {getValueNodes, L, type LilacValueNode} from '$lilac';
import {getValueNodes, L, type DatasetUISettings, type LilacValueNode} from '$lilac';
import {PropertyRelationship} from 'carbon-icons-svelte';
import {hoverTooltip} from '../common/HoverTooltip';
const MAX_MONACO_HEIGHT_COLLAPSED = 360;
const MAX_MONACO_HEIGHT_EXPANDED = 720;
const datasetViewStore = getDatasetViewContext();
export let row: LilacValueNode;
export let row: LilacValueNode | undefined | null;
export let colCompareState: ColumnComparisonState;
export let textIsOverBudget: boolean;
export let isExpanded: boolean;
export let viewType: DatasetUISettings['view_type'] | undefined = undefined;
export let datasetViewHeight: number | undefined = undefined;
export let isFetching: boolean | undefined = undefined;
let editorContainer: HTMLElement;
Expand All @@ -27,8 +34,8 @@
$: rightPath = colCompareState.swapDirection
? colCompareState.column
: colCompareState.compareToColumn;
$: leftValue = L.value(getValueNodes(row, leftPath)[0]) as string;
$: rightValue = L.value(getValueNodes(row, rightPath)[0]) as string;
$: leftValue = row != null ? (L.value(getValueNodes(row, leftPath)[0]) as string) : '';
$: rightValue = row != null ? (L.value(getValueNodes(row, rightPath)[0]) as string) : '';
let monaco: typeof Monaco;
let editor: Monaco.editor.IStandaloneDiffEditor;
Expand All @@ -38,7 +45,10 @@
relayout();
}
}
$: maxMonacoHeightCollapsed = datasetViewHeight
? datasetViewHeight -
(viewType === 'scroll' ? DEFAULT_HEIGHT_PEEK_SCROLL_PX : DEFAULT_HEIGHT_PEEK_SINGLE_ITEM_PX)
: MAX_MONACO_HEIGHT_COLLAPSED;
function relayout() {
if (
editor != null &&
Expand All @@ -51,10 +61,12 @@
);
textIsOverBudget = contentHeight > MAX_MONACO_HEIGHT_COLLAPSED;
if (isExpanded || !textIsOverBudget) {
editorContainer.style.height = `${Math.min(contentHeight, MAX_MONACO_HEIGHT_EXPANDED)}px`;
if (isExpanded) {
editorContainer.style.height = contentHeight + 'px';
} else if (!textIsOverBudget) {
editorContainer.style.height = `${Math.min(contentHeight, maxMonacoHeightCollapsed)}px`;
} else {
editorContainer.style.height = MAX_MONACO_HEIGHT_COLLAPSED + 'px';
editorContainer.style.height = maxMonacoHeightCollapsed + 'px';
}
editor.layout();
}
Expand Down Expand Up @@ -93,7 +105,8 @@
});
</script>

<div class="relative -ml-6 flex h-fit w-full flex-col gap-x-4">
<!-- For reasons unknown to me, the negative left margin (currently -ml-10) is required to make Monaco's automatic layout react. -->
<div class="relative left-16 -ml-10 flex h-fit w-full flex-col gap-x-4 pr-6">
<div class="flex flex-row items-center font-mono text-xs font-medium text-neutral-500">
<div class="ml-8 w-1/2">{getDisplayPath(leftPath)}</div>
<div class="ml-8 w-1/2">{getDisplayPath(rightPath)}</div>
Expand All @@ -106,6 +119,10 @@
</div>
</div>
<div class="editor-container" bind:this={editorContainer} />
{#if isFetching}
<!-- Transparent overlay when fetching rows. -->
<div class="absolute inset-0 flex items-center justify-center bg-white bg-opacity-70" />
{/if}
</div>

<style lang="postcss">
Expand Down
Loading

0 comments on commit 8e9ca2a

Please sign in to comment.