Differentiable Architecture Search Toolkit in PyTorch Lightning
Tip: Boost your research and use solid engineering practices out of the box.
Use this toolkit to:
- Research your own DARTS algorithms using pre-built components, or build custom components of your own
- Use existing DARTS variants, such as Partially-Connected Differentiable Architecture Search (PC-DARTS)
- Configure hyperparameters with YAML files
- Scale to multiple GPUs with minimal effort
- Visualize your neural network architecture
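The DARTS-style features above rest on two generic ideas: a continuous relaxation, where every edge computes a softmax(alpha)-weighted mix of candidate operations, and a discretization step that keeps the strongest operation per edge. PC-DARTS adds partial channel connections, sending only a fraction of the channels through that mix. The pure-Python sketch below illustrates those ideas on scalar "channels". It is not the toolkit's implementation: CANDIDATE_OPS, mixed_op, partial_channel_mixed_op, and derive_edge are hypothetical names, "toy_conv" is a placeholder for a real convolution, and real search spaces operate on tensors.

```python
import math
from typing import Callable, Dict, List, Sequence

# Toy candidate operations on a single scalar "channel" (hypothetical stand-ins
# for the convolution/pooling ops of a real search space).
CANDIDATE_OPS: Dict[str, Callable[[float], float]] = {
    "none": lambda x: 0.0,          # zero op: the edge contributes nothing
    "skip_connect": lambda x: x,    # identity
    "toy_conv": lambda x: 3.0 * x,  # placeholder for a learned convolution
}


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mixed_op(x: float, alphas: Sequence[float]) -> float:
    """DARTS relaxation: softmax(alpha)-weighted sum of all candidate ops."""
    weights = softmax(alphas)
    return sum(w * op(x) for w, op in zip(weights, CANDIDATE_OPS.values()))


def partial_channel_mixed_op(
    channels: Sequence[float], alphas: Sequence[float], k: int
) -> List[float]:
    """PC-DARTS idea: only the first len(channels) // k channels go through the
    mixed op; the rest bypass it unchanged. The paper's channel shuffle and
    edge normalization are omitted in this sketch."""
    if k < 1:
        raise ValueError("k must be >= 1")
    n_selected = len(channels) // k
    processed = [mixed_op(c, alphas) for c in channels[:n_selected]]
    return processed + list(channels[n_selected:])


def derive_edge(alphas: Sequence[float]) -> str:
    """Discretize one edge: keep the op with the largest alpha, excluding "none".
    (Simplified: full DARTS also keeps only the top-2 incoming edges per node.)"""
    candidates = [
        (alpha, name)
        for alpha, name in zip(alphas, CANDIDATE_OPS)
        if name != "none"
    ]
    return max(candidates, key=lambda pair: pair[0])[1]
```

With equal alphas every op gets weight 1/3; training the alphas is what shifts weight toward useful operations, and derive_edge then turns the learned weights into a discrete choice.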
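import os
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from darts_toolkit.data import CIFAR10DataModule
from darts_toolkit.models import LPCDARTSLightningModule
from darts_toolkit.utils.yaml import load_config
# DerivedPCDARTSModel, used below, must also be imported from darts_toolkit;
# its module path is not shown in this example.
# Load configuration
config = load_config(os.path.join("src", "config.yaml"))
# Create data module
data_module = CIFAR10DataModule(config)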
# Search phase
search_model = LPCDARTSLightningModule(config)
search_trainer = pl.Trainer(
    max_epochs=config["training"]["max_epochs"],
    accelerator="gpu" if config["training"].get("gpus") else "auto",
    devices=config["training"].get("gpus") or "auto",
    callbacks=[RichProgressBar()],
    logger=TensorBoardLogger(
        config["logging"]["log_dir"],
        name=f"{config['logging']['experiment_name']}_search",
    ),
)
# Train the search model
search_trainer.fit(search_model, data_module)
# Test the search model
search_trainer.test(search_model, datamodule=data_module)
# Derive and train the final architecture
derived_architecture = search_model.derive_architecture()
derived_model = DerivedPCDARTSModel(
    derived_architecture=derived_architecture, config=config
)
derived_trainer = pl.Trainer(
    max_epochs=config["training"]["derived_epochs"],
    accelerator="gpu" if config["training"].get("gpus") else "auto",
    devices=config["training"].get("gpus") or "auto",
    callbacks=[ModelCheckpoint(monitor="val_acc", mode="max"), RichProgressBar()],
    logger=TensorBoardLogger(
        config["logging"]["log_dir"],
        name=f"{config['logging']['experiment_name']}_derived",
    ),
)
# Train the derived model
# train_dataloader() returns a mapping of loaders here; only its "train" split is used
derived_trainer.fit(
    derived_model,
    train_dataloaders=data_module.train_dataloader()["train"],
    val_dataloaders=data_module.val_dataloader(),
)
# Test the derived model
derived_trainer.test(derived_model, datamodule=data_module)
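The example reads a handful of keys from config.yaml and resolves the accelerator and device settings from training.gpus, falling back to "auto" when no GPU count is given. Below is a minimal sketch of that lookup, assuming a config that contains only those keys. EXAMPLE_CONFIG and phase_settings are hypothetical helpers, not part of the toolkit, and the values are placeholders rather than shipped defaults. In YAML, the same structure would nest max_epochs, derived_epochs, and gpus under training, and log_dir and experiment_name under logging.

```python
from typing import Any, Dict

# Hypothetical config holding only the keys the example reads.
# The values are placeholders, not toolkit defaults.
EXAMPLE_CONFIG: Dict[str, Any] = {
    "training": {"max_epochs": 50, "derived_epochs": 100, "gpus": 1},
    "logging": {"log_dir": "logs", "experiment_name": "pc_darts"},
}


def phase_settings(config: Dict[str, Any], phase: str) -> Dict[str, Any]:
    """Resolve the Trainer/logger settings the example uses for one phase.

    phase is "search" or "derived", matching the run-name suffixes in the example.
    """
    if phase not in ("search", "derived"):
        raise ValueError(f"unknown phase: {phase!r}")
    training = config["training"]
    gpus = training.get("gpus")  # missing, None, or 0 -> let Lightning choose
    epochs_key = "max_epochs" if phase == "search" else "derived_epochs"
    return {
        "max_epochs": training[epochs_key],
        "accelerator": "gpu" if gpus else "auto",
        "devices": gpus or "auto",
        "log_dir": config["logging"]["log_dir"],
        "run_name": f"{config['logging']['experiment_name']}_{phase}",
    }
```

Keeping the GPU fallback in one function avoids repeating the same conditionals in both Trainer calls; passing the resolved values on to pl.Trainer and TensorBoardLogger is left to the caller.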
Using pip:
pip install git+https://github.com/jmaczan/darts-toolkit.git
Using uv:
uv pip install git+https://github.com/jmaczan/darts-toolkit.git
To install from source:
git clone https://github.com/jmaczan/darts-toolkit.git
cd darts-toolkit
# Install using uv (recommended)
uv pip install -e .
# Or install using pip
pip install -e .
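After installing with either method, you may want to confirm that the package resolves in the active environment. The sketch below uses only the standard library; is_importable is a hypothetical helper, and darts_toolkit is the package name taken from the example imports above.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if module_name can be located on the current Python path."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # find_spec raises this for a dotted name whose parent package is missing.
        return False


if __name__ == "__main__":
    print("darts_toolkit installed:", is_importable("darts_toolkit"))
```

find_spec only locates the module without executing it, so the check stays cheap even for packages with heavy imports such as PyTorch.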
This project uses uv for package management.
Code is formatted with Ruff. If you work in VS Code, install the Ruff extension by Astral Software from the Extensions Marketplace.
To install the dependencies and run the PC-DARTS module:
uv sync
uv run python -m src.models.lightning_pc_darts
If you use this software in your research, please use the following citation:
@software{Maczan_PCDARTS_2024,
author = {Maczan, Jędrzej Paweł},
title = {Differentiable Architecture Search Toolkit in PyTorch Lightning},
url = {https://github.com/jmaczan/darts-toolkit},
year = {2024},
publisher = {GitHub}
}
License: GNU GPLv3
Jędrzej Maczan, 2024