Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for pydantic<3.0 #88

Open
wants to merge 25 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,22 @@ jobs:
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12' ]
torch-version: [ '1.13', '2.0.1' ]
torch-version: [ '1.13', '2.0.1', '2.5.0']
exclude:
- torch-version: '1.13'
python-version: '3.12'
- python-version: '3.8'
torch-version: '2.5.0'
- python-version: '3.9'
torch-version: '2.5.0'
- python-version: '3.10'
torch-version: '2.5.0'
- python-version: '3.11'
torch-version: '1.13'
- python-version: '3.11'
torch-version: '2.5.0'
- python-version: '3.12'
torch-version: '1.13'
- python-version: '3.12'
torch-version: '2.0.1'

steps:
- uses: actions/checkout@v3
Expand All @@ -23,16 +35,35 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Build the package
- name: Check Python version # https://github.com/python/cpython/issues/95299
id: check-version
run: |
python_version=$(python --version | awk '{print $2}')
major=$(echo $python_version | cut -d'.' -f1)
minor=$(echo $python_version | cut -d'.' -f2)
if ([ "$major" -eq 3 ] && [ "$minor" -ge 12 ]); then
echo "setuptools_present=false" >> $GITHUB_ENV
else
echo "setuptools_present=true" >> $GITHUB_ENV
fi

- name: Build the package (python >= 3.12)
if: env.setuptools_present == 'false'
run: |
python -m pip install build
python -m build

- name: Build the package (python < 3.12)
if: env.setuptools_present == 'true'
run: |
python setup.py sdist

- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e .
python -m pip install -r tests/dev-requirements.txt
python -m pip install torch==${{ matrix.torch-version }}
python -m pip install -e .
cd tests
export MODULE_PARENT=$(python -c "import $MODULE_NAME, os; print(os.path.dirname($MODULE_NAME.__path__[0]))")
export MODULE_PARENT=${MODULE_PARENT%"/"}
Expand Down
2 changes: 1 addition & 1 deletion docs/cli/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ The command shown above will run SLURM job with 4 CPUs and 100G of RAM.

### Predefined run configs
You can predefine run configs to avoid reentering the same flags.
Create `~/.config/thunder/backends.yml` (you can run `thunder show` in your terminal,
Create `~/.config/thunder/backends.yml` (you can run `thunder backend list` in your terminal,
required path will be at the title of the table) in you home directory.
Now you can specify config name and its parameters:
```yaml
Expand Down
9 changes: 8 additions & 1 deletion docs/examples/mnist.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,17 @@ module = ThunderModule(
architecture, nn.CrossEntropyLoss(), optimizer=torch.optim.Adam(architecture.parameters())
)

# Preparing metrics
# 'y' and 'x' are single label and
# model prediction for a single image,
# hence the 'np.argmax(x)' for extracting
# the predicted label.
group_accuracy = {lambda y, x: (y, np.argmax(x)): accuracy_score}

# Initialize a trainer
trainer = Trainer(
callbacks=[ModelCheckpoint(save_last=True),
MetricMonitor(group_metrics={lambda y, x: (np.argmax(y), x): accuracy_score})],
MetricMonitor(group_metrics=group_accuracy)],
accelerator="auto",
devices=1,
max_epochs=100,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
lightning>=2.0.0,<3.0.0
lazycon>=0.6.3,<1.0.0
typer>=0.9.0,<1.0.0
pydantic<2.0.0
pydantic<3.0.0
click
torch
toolz
Expand Down
16 changes: 11 additions & 5 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def test_build(temp_dir, mock_backend):

with cleanup(experiment):
result = invoke('build', config, experiment, '-u', 'c=3')
assert result.exit_code != 0
assert result.exit_code != 0, result.output
assert 'are missing from the config' in str(result.exception)

result = invoke('build', config, experiment, '-u', 'a=10')
assert result.exit_code == 0
assert result.exit_code == 0, result.output
assert Config.load(experiment / 'experiment.config').a == 10

with cleanup(experiment):
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_build_overwrite(temp_dir):
config.write_text('b = 2')

result = invoke('build', config, experiment, "--overwrite")
assert result.exit_code == 0
assert result.exit_code == 0, result.output
assert not hasattr(read_config(experiment / "experiment.config"), "a")
assert read_config(experiment / "experiment.config").b == 2

Expand Down Expand Up @@ -187,6 +187,11 @@ def test_backend_add(temp_dir, mock_backend):
local = load_backend_configs()
assert "new_config" in local and "new_config_2" in local

invoke("backend", "add", "new_config_3", "backend=cli", "n_workers=8")
local = load_backend_configs()
assert "new_config" in local and "new_config_2" in local
assert "new_config_3" in local


def test_backend_list(temp_dir, mock_backend):
# language=yaml
Expand All @@ -208,10 +213,11 @@ def test_backend_list(temp_dir, mock_backend):


def test_backend_set(temp_dir, mock_backend):
assert invoke("backend", "add", "config", "backend=slurm", "ram=100G", "--force").exit_code == 0
result = invoke("backend", "add", "config", "backend=slurm", "ram=100G", "--force")
assert result.exit_code == 0, result.output
result = invoke("backend", "set", "config")

assert result.exit_code == 0
assert result.exit_code == 0, result.output
local = load_backend_configs()
assert local[local["meta"].default].config.ram == "100G"

Expand Down
29 changes: 19 additions & 10 deletions thunder/backend/interface.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from pathlib import Path
from typing import Dict, Optional, Sequence, Type

from pydantic import BaseModel, Extra, validator
from pydantic import BaseModel

from ..layout import Node
from ..pydantic_compat import PYDANTIC_MAJOR, NoExtra, field_validator, model_validate


class BackendConfig(BaseModel):
class Config:
extra = Extra.ignore
class BackendConfig(NoExtra):
"""Backend Parameters"""


class Backend:
Expand All @@ -19,24 +19,33 @@ def run(config: BackendConfig, experiment: Path, nodes: Optional[Sequence[Node]]
"""Start running the given `nodes` of an experiment located at the given path"""


class BackendEntryConfig(BaseModel):
class BackendEntryConfig(NoExtra):
backend: str
config: BackendConfig

@validator('config', pre=True)
@field_validator("config", mode="before")
def _val_config(cls, v, values):
val = backends[values['backend']]
return val.Config.parse_obj(v)
return parse_backend_config(v, values)

@property
def backend_cls(self):
return backends[self.backend]

class Config:
extra = Extra.ignore

if PYDANTIC_MAJOR == 2:
def parse_backend_config(v, values):
val = backends[values.data["backend"]]
return model_validate(val.Config, v)
else:
def parse_backend_config(v, values):
val = backends[values["backend"]]
return model_validate(val.Config, v)


class MetaEntry(BaseModel):
"""
Default backend set by `thunder backend set`
"""
default: str


Expand Down
36 changes: 18 additions & 18 deletions thunder/backend/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from typing import Optional, Sequence

from deli import save
from pydantic import validator
from pytimeparse.timeparse import timeparse
from typer import Option
from typing_extensions import Annotated

from ..layout import Node
from ..pydantic_compat import field_validator
from .interface import Backend, BackendConfig, backends


Expand All @@ -26,45 +26,45 @@

class Slurm(Backend):
class Config(BackendConfig):
ram: Annotated[str, Option(
..., '-r', '--ram', '--mem',
ram: Annotated[Optional[str], Option(
None, '-r', '--ram', '--mem',
help='The amount of RAM required per node. Default units are megabytes. '
'Different units can be specified using the suffix [K|M|G|T].'
)] = None
cpu: Annotated[int, Option(
..., '-c', '--cpu', '--cpus-per-task', show_default=False,
cpu: Annotated[Optional[int], Option(
None, ..., '-c', '--cpu', '--cpus-per-task', show_default=False,
help='Number of CPU cores to allocate. Default to 1'
)] = None
gpu: Annotated[int, Option(
..., '-g', '--gpu', '--gpus-per-node',
gpu: Annotated[Optional[int], Option(
None, '-g', '--gpu', '--gpus-per-node',
help='Number of GPUs to allocate'
)] = None
partition: Annotated[str, Option(
..., '-p', '--partition',
partition: Annotated[Optional[str], Option(
None, '-p', '--partition',
help='Request a specific partition for the resource allocation'
)] = None
nodelist: Annotated[str, Option(
...,
nodelist: Annotated[Optional[str], Option(
None,
help='Request a specific list of hosts. The list may be specified as a comma-separated '
'list of hosts, a range of hosts (host[1-5,7,...] for example).'
'list of hosts, a range of hosts (host[1-5,7,None] for example).'
)] = None
time: Annotated[str, Option(
..., '-t', '--time',
time: Annotated[Optional[str], Option(
None, '-t', '--time',
help='Set a limit on the total run time of the job allocation. When the time limit is reached, '
'each task in each job step is sent SIGTERM followed by SIGKILL.'
)] = None
limit: Annotated[int, Option(
...,
limit: Annotated[Optional[int], Option(
None,
help='Limit the number of jobs that are simultaneously running during the experiment',
)] = None

@validator('time')
@field_validator("time")
def val_time(cls, v):
if v is None:
return
return parse_duration(v)

@validator('limit')
@field_validator("limit")
def val_limit(cls, v):
assert v is None or v > 0, 'The jobs limit, if specified, must be positive'
return v
Expand Down
51 changes: 34 additions & 17 deletions thunder/cli/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typer.models import ParamMeta

from ..backend import BackendEntryConfig, MetaEntry, backends
from ..pydantic_compat import model_validate, resolve_pydantic_major
from .app import app


Expand Down Expand Up @@ -92,23 +93,39 @@ def populate(backend_name):
show_default=False,
),
)]
for field in entry.backend_cls.Config.__fields__.values():
annotation = field.outer_type_
# TODO: https://stackoverflow.com/a/68337036
if not hasattr(annotation, '__metadata__') or not hasattr(annotation, '__origin__'):
raise ValueError('Please use the `Annotated` syntax to annotate you backend config')

# TODO
default, = annotation.__metadata__
default = copy.deepcopy(default)
default.default = getattr(entry.config, field.name)
default.help = f'[{backend_name} backend] {default.help}'
backend_params.append(ParamMeta(
name=field.name, default=default, annotation=annotation.__origin__,
))
backend_params.extend(_collect_backend_params(entry, backend_name))
return backend_params


if resolve_pydantic_major() >= 2:
def _collect_backend_params(entry, backend_name):
"""
Config Annotation depends on pydantic version.
"""
for field_name, field in entry.backend_cls.Config.model_fields.items():
field_clone = copy.deepcopy(field)
field_clone.default = getattr(entry.config, field_name)
yield ParamMeta(
name=field_name, default=field_clone.default, annotation=field.annotation,
)
else:
def _collect_backend_params(entry, backend_name):
for field in entry.backend_cls.Config.__fields__.values():
annotation = field.outer_type_
# TODO: https://stackoverflow.com/a/68337036
if not hasattr(annotation, '__metadata__') or not hasattr(annotation, '__origin__'):
raise ValueError('Please use the `Annotated` syntax to annotate you backend config')

# TODO
default, = annotation.__metadata__
default = copy.deepcopy(default)
default.default = getattr(entry.config, field.name)
default.help = f'[{backend_name} backend] {default.help}'
yield ParamMeta(
name=field.name, default=default, annotation=annotation.__origin__,
)


def collect_backends() -> ChainMap:
"""
Collects backend for each config.
Expand Down Expand Up @@ -144,7 +161,7 @@ def collect_configs() -> Tuple[ChainMap, Union[MetaEntry, None]]:
def load_backend_configs() -> Dict[str, Union[BackendEntryConfig, MetaEntry]]:
path = BACKENDS_CONFIG_PATH
if not path.exists():
# print(path, flush=True)
# TODO: return Option[Dict]
return {}

with path.open('r') as file:
Expand All @@ -153,5 +170,5 @@ def load_backend_configs() -> Dict[str, Union[BackendEntryConfig, MetaEntry]]:
return {}
# FIXME
assert isinstance(local, dict), type(local)
return {k: BackendEntryConfig.parse_obj(v)
if k != "meta" else MetaEntry.parse_obj(v) for k, v in local.items()}
return {k: model_validate(BackendEntryConfig, v)
if k != "meta" else model_validate(MetaEntry, v) for k, v in local.items()}
Loading
Loading