Add meta learning framework #183
@@ -0,0 +1,87 @@
import pytz
import torch

from datetime import datetime
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger

from rl4co.envs import CVRPEnv
from rl4co.models.zoo.am import AttentionModelPolicy
from rl4co.models.zoo.pomo import POMO
from rl4co.utils.meta_trainer import RL4COMetaTrainer, MetaModelCallback


def main():
    # Set device
    device_id = 0

    # RL4CO env based on TorchRL
    env = CVRPEnv(generator_params={'num_loc': 50})

    # Policy: neural network, in this case with an encoder-decoder architecture
    # Note that this follows the same settings as POMO in the original paper
    policy = AttentionModelPolicy(env_name=env.name,
                                  embed_dim=128,
                                  num_encoder_layers=6,
                                  num_heads=8,
                                  normalization="instance",
                                  use_graph_context=False
                                  )

    # RL Model (POMO)
    model = POMO(env,
                 policy,
                 batch_size=64,  # meta_batch_size
                 train_data_size=64 * 50,  # samples per epoch
                 val_data_size=0,
                 optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6},
                 # for the task scheduler of the size setting, where sch_epoch = 0.9 * epochs
                 )

    # Example callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",  # save to checkpoints/
        filename="epoch_{epoch:03d}",  # save as epoch_XXX.ckpt
        save_top_k=1,  # save only the best model
        save_last=True,  # save the last model
        monitor="val/reward",  # monitor validation reward
        mode="max",  # maximize validation reward
    )
    rich_model_summary = RichModelSummary(max_depth=3)  # model summary callback

    # Meta callbacks
    meta_callback = MetaModelCallback(
        meta_params={
            'meta_method': 'reptile',  # choose from ['maml', 'fomaml', 'maml_fomaml', 'reptile']
            'data_type': 'size',  # choose from ['size', 'distribution', 'size_distribution']
            'sch_bar': 0.9,  # for the task scheduler of the size setting, where sch_epoch = sch_bar * epochs
            'B': 1,  # the number of tasks in a mini-batch
            'alpha': 0.99,  # outer-loop step size of Reptile
            'alpha_decay': 0.999,  # decay of the outer-loop step size of Reptile
            'min_size': 20,  # minimum sampled size in meta tasks
            'max_size': 150,  # maximum sampled size in meta tasks
        },
        print_log=True,  # whether to print the sampled tasks in each meta iteration
    )
    callbacks = [meta_callback, checkpoint_callback, rich_model_summary]

    # Logger
    process_start_time = datetime.now(pytz.timezone("Asia/Singapore"))
    logger = WandbLogger(project="rl4co", name=f"{env.name}_{process_start_time.strftime('%Y%m%d_%H%M%S')}")
    # logger = None  # uncomment this line if you don't want logging

    # Adjust your trainer to the number of epochs you want to run
    trainer = RL4COMetaTrainer(
        max_epochs=20000,  # (number of meta-model updates) * (number of tasks in a mini-batch)
        callbacks=callbacks,
        accelerator="gpu",
        devices=[device_id],
        logger=logger,
        limit_train_batches=50,  # gradient descent steps in the inner-loop optimization of the meta-learning method
    )

    # Fit
    trainer.fit(model)


if __name__ == "__main__":
    main()
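For reference, the `alpha` / `alpha_decay` values above control the Reptile outer-loop step performed by `MetaModelCallback` (implemented in the second file of this diff). A minimal standalone sketch of that interpolation, using toy tensors in place of a real model state dict (an illustration, not the callback's exact code):

    import torch

    alpha = 0.99  # outer-loop step size, decayed every epoch

    # Toy "state dicts" for the meta model and for B task models trained in the inner loop
    meta_state = {"weight": torch.zeros(3)}
    task_states = [{"weight": torch.ones(3)}, {"weight": 2 * torch.ones(3)}]

    # Reptile update: move each meta parameter toward the mean of the task parameters
    new_meta_state = {
        k: meta_state[k]
        + alpha * torch.stack([ts[k] - meta_state[k] for ts in task_states]).mean(dim=0)
        for k in meta_state
    }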
@@ -0,0 +1,245 @@
from typing import Iterable, List, Optional, Union

import copy
import math
import random

import lightning.pytorch as pl
import torch
from torch.optim import Adam

from lightning import Callback, Trainer
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.core.datamodule import LightningDataModule
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies import DDPStrategy, Strategy
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS

from rl4co import utils

log = utils.get_pylogger(__name__)


class MetaModelCallback(Callback):
[Review comment] I haven't thought of Meta-learning as Lightning Callbacks, but it looks neat! :D
[Review comment] Tagging @Junyoungpark for the good ol' Reptile

    def __init__(self, meta_params, print_log=True):

[Review comment] [Minor] I think it would be more clear to list the hyperparameters. For example

    def __init__(self, alpha=..., alpha_decay=..., ..., print_log=True):

which is generally easier to maintain and to document.

        super().__init__()
        self.meta_params = meta_params
        assert meta_params["meta_method"] == 'reptile', NotImplementedError

[Review comment] Another general comment: it seems that this callback is only for Reptile, so we should consider calling it …
[Reply] Yes.
[Reply] I'd say under …

        assert meta_params["data_type"] == 'size', NotImplementedError

[Review comment] It seems that this parameter …

        self.print_log = print_log
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

        # Initialize some hyperparameters
        self.alpha = self.meta_params["alpha"]
        self.alpha_decay = self.meta_params["alpha_decay"]
        self.sch_bar = self.meta_params["sch_bar"]
        self.task_set = [(n,) for n in range(self.meta_params["min_size"], self.meta_params["max_size"] + 1)]

        # Sample a batch of tasks
        self._sample_task()
        self.selected_tasks[0] = (pl_module.env.generator.num_loc, )

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

[Review comment] [Minor] Type hints should be classes, not strings. For example …
[Reply] Hi. I am refactoring the …
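A possible reading of the suggestion above (an illustration, not code from this PR), using the file's existing `import lightning.pytorch as pl`:

    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        ...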
        # Alpha scheduler (decay for the update of the meta model)
        self._alpha_scheduler()

        # Reinitialize the task model with the parameters of the meta model
        if trainer.current_epoch % self.meta_params['B'] == 0:  # Save the meta model
            self.meta_model_state_dict = copy.deepcopy(pl_module.state_dict())
            self.task_models = []
            # Print sampled tasks
            if self.print_log:
                print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(
                    trainer.current_epoch // self.meta_params['B'], trainer.current_epoch, self.selected_tasks))
        else:
            pl_module.load_state_dict(self.meta_model_state_dict)

        # Reinitialize the optimizer every epoch
        lr_decay = 0.1 if trainer.current_epoch + 1 == int(self.sch_bar * trainer.max_epochs) else 1
        old_lr = trainer.optimizers[0].param_groups[0]['lr']
        new_optimizer = Adam(pl_module.parameters(), lr=old_lr * lr_decay)
        trainer.optimizers = [new_optimizer]

        if self.print_log:
            print('\n>> Training task: {}, capacity: {}'.format(
                pl_module.env.generator.num_loc, pl_module.env.generator.capacity))
    def on_train_epoch_end(self, trainer, pl_module):

        # Save the task model
        self.task_models.append(copy.deepcopy(pl_module.state_dict()))
        if (trainer.current_epoch + 1) % self.meta_params['B'] == 0:
            # Outer-loop optimization (update the meta model with the parameters of the task models)
            # Reptile update: move each meta parameter toward the mean of the task parameters,
            # scaled by the (decayed) outer-loop step size alpha
            with torch.no_grad():
                state_dict = {params_key: (self.meta_model_state_dict[params_key] +
                              self.alpha * torch.mean(torch.stack([fast_weight[params_key] - self.meta_model_state_dict[params_key]
                                                                   for fast_weight in self.task_models], dim=0).float(), dim=0))
                              for params_key in self.meta_model_state_dict}
            pl_module.load_state_dict(state_dict)

        # Get ready for the next meta-training iteration
        if (trainer.current_epoch + 1) % self.meta_params['B'] == 0:
            # Sample a batch of tasks
            self._sample_task()

        # Load new training task (update the environment)
        self._load_task(pl_module, task_idx=(trainer.current_epoch + 1) % self.meta_params['B'])

    def _sample_task(self):
        # Sample a batch of tasks
        w, self.selected_tasks = [1.0] * self.meta_params['B'], []
        for b in range(self.meta_params['B']):
            task_params = random.sample(self.task_set, 1)[0]
            self.selected_tasks.append(task_params)
        self.w = torch.softmax(torch.Tensor(w), dim=0)

    def _load_task(self, pl_module, task_idx=0):
        # Load new training task (update the environment)
        task_params, task_w = self.selected_tasks[task_idx], self.w[task_idx].item()
        task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20
        pl_module.env.generator.num_loc = task_params[0]
        pl_module.env.generator.capacity = task_capacity

    def _alpha_scheduler(self):
        # Decay the outer-loop step size, with a floor of 1e-4
        self.alpha = max(self.alpha * self.alpha_decay, 0.0001)
class RL4COMetaTrainer(Trainer):

[Review comment] Are there any differences compared to the RL4COTrainer?

    """Wrapper around Lightning Trainer, with some RL4CO magic for efficient training.

    Meta-training framework for addressing the generalization issue.
    Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587

    Note:
        The most important hyperparameter to use is `reload_dataloaders_every_n_epochs`.
        This allows for datasets to be re-created on the run and distributed by Lightning across
        devices on each epoch. Setting to a value different than 1 may lead to overfitting to a
        specific (such as the initial) data distribution.

    Args:
        accelerator: hardware accelerator to use.
        callbacks: list of callbacks.
        logger: logger (or iterable collection of loggers) for experiment tracking.
        min_epochs: minimum number of training epochs.
        max_epochs: maximum number of training epochs.
        strategy: training strategy to use (if any), such as Distributed Data Parallel (DDP).
        devices: number of devices to train on (int) or which GPUs to train on (list or str) applied per node.
        gradient_clip_val: 0 means don't clip. Defaults to 1.0 for stability.
        precision: allows for mixed precision training. Can be specified as a string (e.g., '16').
            This also allows to use `FlashAttention` by default.
        disable_profiling_executor: Disable JIT profiling executor. This reduces memory and increases speed.
        auto_configure_ddp: Automatically configure DDP strategy if multiple GPUs are available.
        reload_dataloaders_every_n_epochs: Set to a value different than 1 to reload dataloaders every n epochs.
        matmul_precision: Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
        **kwargs: Additional keyword arguments passed to the Lightning Trainer. See :class:`lightning.pytorch.trainer.Trainer` for details.
    """
    def __init__(
        self,
        accelerator: Union[str, Accelerator] = "auto",
        callbacks: Optional[List[Callback]] = None,
        logger: Optional[Union[Logger, Iterable[Logger]]] = None,
        min_epochs: Optional[int] = None,
        max_epochs: Optional[int] = None,
        strategy: Union[str, Strategy] = "auto",
        devices: Union[List[int], str, int] = "auto",
        gradient_clip_val: Union[int, float] = 1.0,
        precision: Union[str, int] = "16-mixed",
        reload_dataloaders_every_n_epochs: int = 1,
        disable_profiling_executor: bool = True,
        auto_configure_ddp: bool = True,
        matmul_precision: Union[str, int] = "medium",
        **kwargs,
    ):
        # Disable JIT profiling executor. This reduces memory and increases speed.
        # Reference: https://github.com/HazyResearch/safari/blob/111d2726e7e2b8d57726b7a8b932ad8a4b2ad660/train.py#LL124-L129C17
        if disable_profiling_executor:
            try:
                torch._C._jit_set_profiling_executor(False)
                torch._C._jit_set_profiling_mode(False)
            except AttributeError:
                pass

        # Configure DDP automatically if multiple GPUs are available
        if auto_configure_ddp and strategy == "auto":
            if devices == "auto":
                n_devices = num_cuda_devices()
            elif isinstance(devices, Iterable):
                n_devices = len(devices)
            else:
                n_devices = devices
            if n_devices > 1:
                log.info(
                    "Configuring DDP strategy automatically with {} GPUs".format(
                        n_devices
                    )
                )
                strategy = DDPStrategy(
                    find_unused_parameters=True,  # We set to True due to RL envs
                    gradient_as_bucket_view=True,  # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
                )

        # Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
        if matmul_precision is not None:
            torch.set_float32_matmul_precision(matmul_precision)

        # Check if gradient_clip_val is set to None
        if gradient_clip_val is None:
            log.warning(
                "gradient_clip_val is set to None. This may lead to unstable training."
            )

        # We should reload dataloaders every epoch for RL training
        if reload_dataloaders_every_n_epochs != 1:
            log.warning(
                "We reload dataloaders every epoch for RL training. Setting reload_dataloaders_every_n_epochs to a value different than 1 "
                + "may lead to unexpected behavior since the initial conditions will be the same for `n_epochs` epochs."
            )

        # Main call to `Trainer` superclass
        super().__init__(
            accelerator=accelerator,
            callbacks=callbacks,
            logger=logger,
            min_epochs=min_epochs,
            max_epochs=max_epochs,
            strategy=strategy,
            gradient_clip_val=gradient_clip_val,
            devices=devices,
            precision=precision,
            reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs,
            **kwargs,
        )

    def fit(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        """
        We override the `fit` method to automatically apply and handle RL4CO magic
        for 'self.automatic_optimization = False' models, such as PPO.

        It behaves exactly like the original `fit` method, but with the following changes:
        - if the given model is 'self.automatic_optimization = False', we override 'gradient_clip_val' as None
        """

        if not model.automatic_optimization:
            if self.gradient_clip_val is not None:
                log.warning(
                    "Overriding gradient_clip_val to None for 'automatic_optimization=False' models"
                )
                self.gradient_clip_val = None

        # Fit (inner-loop optimization)
        super().fit(
            model=model,
            train_dataloaders=train_dataloaders,
            val_dataloaders=val_dataloaders,
            datamodule=datamodule,
            ckpt_path=ckpt_path,
        )
[Review comment] [Minor] Nice tool, but `pytz` is not included in RL4CO's dependency packages. Maybe better to have a package check here:
[Reply] Yeah, maybe it could be removed.
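The reviewer's suggested snippet is not shown in the thread; a minimal sketch of what such a guard could look like in the example script, falling back to the standard library when `pytz` is unavailable (an assumption, not the reviewer's code):

    from datetime import datetime, timezone

    try:
        import pytz
        tz = pytz.timezone("Asia/Singapore")
    except ImportError:
        tz = timezone.utc  # assumed fallback if pytz is not installed

    process_start_time = datetime.now(tz)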