Skip to content

Commit

Permalink
add examples smoke tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon committed Jul 30, 2024
1 parent 22175cb commit 582c5b3
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ env:
FORCE_COLOR: "1"

jobs:
build:
run:
if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-benchmarks') }}
runs-on: ubuntu-latest

Expand Down
37 changes: 37 additions & 0 deletions .github/workflows/examples.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
name: Examples

on:
workflow_dispatch:
schedule:
- cron: '0 3 * * *'
push: # to remove

env:
FORCE_COLOR: "1"

jobs:
run:
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest-8-cores, macos-latest, windows-latest-8-cores]
pyv: ['3.9', '3.12']
steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.pyv }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.pyv }}
cache: 'pip'

- name: Upgrade nox and uv
run: |
python -m pip install --upgrade 'nox[uv]'
nox --version
uv --version
- name: Run examples
run: nox -s examples -p ${{ matrix.pyv }}
19 changes: 9 additions & 10 deletions examples/get_started/json-csv-reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def main():
print("========================================================================")
uri = "gs://datachain-demo/jsonl/object.jsonl"
jsonl_ds = DataChain.from_json(uri, meta_type="jsonl", show_schema=True)
print(jsonl_ds.to_pandas())
jsonl_ds.show()

print()
print("========================================================================")
Expand All @@ -49,8 +49,7 @@ def main():
json_pairs_ds = DataChain.from_json(
uri, schema_from=schema_uri, jmespath="@", model_name="OpenImage"
)
print(json_pairs_ds.to_pandas())
# print(list(json_pairs_ds.collect())[0])
json_pairs_ds.show()

uri = "gs://datachain-demo/coco2017/annotations_captions/"

Expand All @@ -72,14 +71,14 @@ def main():
static_json_ds = DataChain.from_json(
uri, jmespath="licenses", spec=LicenseFeature, nrows=3
)
print(static_json_ds.to_pandas())
static_json_ds.show()

print()
print("========================================================================")
print("dynamic JSON schema test parsing 5K objects")
print("========================================================================")
dynamic_json_ds = DataChain.from_json(uri, jmespath="images", show_schema=True)
print(dynamic_json_ds.to_pandas())
dynamic_json_ds.show()

uri = "gs://datachain-demo/chatbot-csv/"
print()
Expand All @@ -88,16 +87,16 @@ def main():
print("========================================================================")
static_csv_ds = DataChain.from_csv(uri, output=ChatDialog, object_name="chat")
static_csv_ds.print_schema()
print(static_csv_ds.to_pandas())
static_csv_ds.show()

uri = "gs://datachain-demo/laion-aesthetics-csv"
uri = "gs://datachain-demo/laion-aesthetics-csv/"
print()
print("========================================================================")
print("dynamic CSV with header schema test parsing 3/3M objects")
print("dynamic CSV with header schema test parsing 3M objects")
print("========================================================================")
dynamic_csv_ds = DataChain.from_csv(uri, object_name="laion", nrows=3)
dynamic_csv_ds = DataChain.from_csv(uri, object_name="laion")
dynamic_csv_ds.print_schema()
print(dynamic_csv_ds.to_pandas())
dynamic_csv_ds.show()


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion examples/get_started/torch-loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def forward(self, x):
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
num_epochs = 10
num_epochs = 3
for epoch in range(num_epochs):
for i, data in enumerate(train_loader):
inputs, labels = data
Expand Down
31 changes: 20 additions & 11 deletions examples/multimodal/wds.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pandas as pd

from datachain import C, DataChain
from datachain.lib.webdataset import process_webdataset
from datachain.lib.webdataset_laion import WDSLaion, process_laion_meta
Expand All @@ -9,25 +7,36 @@
.filter(C("file.name").glob("00000000.tar"))
.settings(cache=True)
.gen(laion=process_webdataset(spec=WDSLaion), params="file")
.save() # materialize chain to avoid downloading data multiple times
)

meta_pq = (
DataChain.from_parquet("gs://datachain-demo/datacomp-small/metadata/0020f*.parquet")
.filter(
C("uid").in_(values[0] for values in wds.select("laion.json.uid").collect())
)
.map(stem=lambda file: file.get_file_stem(), params=["source.file"], output=str)
.save()
)

meta_emd = (
DataChain.from_storage("gs://datachain-demo/datacomp-small/metadata")
.filter(C("file.name").glob("0020f*.npz"))
DataChain.from_storage("gs://datachain-demo/datacomp-small/metadata/0020f*.npz")
.gen(emd=process_laion_meta)
.filter(
C("emd.index").in_(
values[0] for values in meta_pq.select("source.index").collect()
)
)
.map(stem=lambda file: file.get_file_stem(), params=["emd.file"], output=str)
)

meta_pq = DataChain.from_parquet(
"gs://datachain-demo/datacomp-small/metadata/0020f*.parquet"
).map(stem=lambda file: file.get_file_stem(), params=["source.file"], output=str)

meta = meta_emd.merge(
meta_pq, on=["stem", "emd.index"], right_on=["stem", "source.index"]
meta_pq,
on=["stem", "emd.index"],
right_on=["stem", "source.index"],
)

res = wds.merge(meta, on="laion.json.uid", right_on="uid")

df = res.limit(10).to_pandas()
with pd.option_context("display.max_columns", None):
print(df)
res.show()
11 changes: 11 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,14 @@ def dev(session: nox.Session) -> None:

python = os.path.join(venv_dir, "bin/python")
session.run(python, "-m", "pip", "install", "-e", ".[dev]", external=True)


@nox.session(python=["3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"])
def examples(session: nox.Session) -> None:
session.install(".[tests]")
session.run(
"pytest",
"-m",
"examples",
*session.posargs,
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ namespaces = false
[tool.setuptools_scm]

[tool.pytest.ini_options]
addopts = "-rfEs -m 'not benchmark'"
addopts = "-rfEs -m 'not benchmark and not examples'"
markers = [
"benchmark: benchmarks.",
"e2e: End-to-end tests"
Expand Down
48 changes: 48 additions & 0 deletions tests/examples/test_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import glob
import os
import subprocess
import sys

import pytest

get_started_examples = glob.glob("examples/get_started/**/*.py", recursive=True)
llm_and_nlp_examples = filter(
# no anthropic token
lambda filename: "claude" not in filename,
glob.glob("examples/llm_and_nlp/**/*.py", recursive=True),
)
multimodal_examples = filter(
# no OpenAI token and hf download painfully slow
lambda filename: "openai" not in filename and "hf" not in filename,
glob.glob("examples/multimodal/**/*.py", recursive=True),
)


def smoke_test(example: str):
completed_process = subprocess.run( # noqa: S603
[sys.executable, example],
capture_output=True,
cwd=os.path.abspath(os.path.join(__file__, "..", "..", "..")),
check=True,
)

assert completed_process.stdout
assert completed_process.stderr


@pytest.mark.examples
@pytest.mark.parametrize("example", get_started_examples)
def test_get_started_examples(example):
smoke_test(example)


@pytest.mark.examples
@pytest.mark.parametrize("example", llm_and_nlp_examples)
def test_llm_and_nlp_examples(example):
smoke_test(example)


@pytest.mark.examples
@pytest.mark.parametrize("example", multimodal_examples)
def test_multimodal(example):
smoke_test(example)

0 comments on commit 582c5b3

Please sign in to comment.