Skip to content

Commit

Permalink
refactor(yaml_model_schema): added Pydantic classes for parsing model…
Browse files Browse the repository at this point in the history
… yaml
  • Loading branch information
mathysgrapotte committed Jan 28, 2025
1 parent 1ee47aa commit 6ee91ea
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/stimulus/utils/yaml_model_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,66 @@
from copy import deepcopy
from typing import Any

import pydantic
import yaml
from ray import tune


class Loss(pydantic.BaseModel):
"""Loss parameters."""

loss_fn: dict[str, Any]


class Data(pydantic.BaseModel):
"""Data parameters."""

batch_size: dict[str, Any]


class TuneParams(pydantic.BaseModel):
"""Tune parameters."""

metric: str
mode: str
num_samples: int


class Scheduler(pydantic.BaseModel):
"""Scheduler parameters."""

name: str
params: dict[str, Any]


class RunParams(pydantic.BaseModel):
"""Run parameters."""

stop: dict[str, Any]


class Tune(pydantic.BaseModel):
"""Tune parameters."""

config_name: str
tune_params: TuneParams
scheduler: Scheduler
run_params: RunParams
step_size: int
gpu_per_trial: int
cpu_per_trial: int


class Model(pydantic.BaseModel):
"""Model configuration."""

model_params: dict[str, Any]
optimizer_params: dict[str, Any]
loss_params: Loss
data_params: Data
tune: Tune


class YamlRayConfigLoader:
"""Load and convert YAML configurations to Ray Tune format.
Expand Down

0 comments on commit 6ee91ea

Please sign in to comment.