Add meta learning framework #183

Merged · merged 10 commits · Jun 20, 2024
Changes from 3 commits
87 changes: 87 additions & 0 deletions examples/2d-meta_train.py
@@ -0,0 +1,87 @@
import pytz
Member:
[Minor] Nice tool, but pytz is not included in RL4CO's dependency packages. Maybe better to have a package check here:

try:
    import pytz
except ImportError:
    # raise a warning and fall back to Python's default time handling
    pass

Contributor Author:
Yeah, maybe it could be removed.
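As a side note to the thread above, a minimal sketch of how the timestamp could be built without the extra dependency, assuming Python 3.9+ where zoneinfo ships with the standard library (not part of this PR):

from datetime import datetime

try:
    from zoneinfo import ZoneInfo  # standard library since Python 3.9
    process_start_time = datetime.now(ZoneInfo("Asia/Singapore"))
except ImportError:
    # zoneinfo unavailable (Python < 3.9): fall back to local time
    process_start_time = datetime.now()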

import torch

from datetime import datetime
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
from lightning.pytorch.loggers import WandbLogger

from rl4co.envs import CVRPEnv
from rl4co.models.zoo.am import AttentionModelPolicy
from rl4co.models.zoo.pomo import POMO
from rl4co.utils.meta_trainer import RL4COMetaTrainer, MetaModelCallback

def main():
    # Set device
    device_id = 0

    # RL4CO env based on TorchRL
    env = CVRPEnv(generator_params={'num_loc': 50})

    # Policy: neural network, in this case with an encoder-decoder architecture
    # Note that this follows the same policy configuration as POMO in the original paper
    policy = AttentionModelPolicy(
        env_name=env.name,
        embed_dim=128,
        num_encoder_layers=6,
        num_heads=8,
        normalization="instance",
        use_graph_context=False,
    )

    # RL Model (POMO)
    model = POMO(
        env,
        policy,
        batch_size=64,  # meta_batch_size
        train_data_size=64 * 50,  # each epoch
        val_data_size=0,
        optimizer_kwargs={"lr": 1e-4, "weight_decay": 1e-6},
        # for the task scheduler of the size setting, where sch_epoch = 0.9 * epochs
    )

    # Example callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",  # save to checkpoints/
        filename="epoch_{epoch:03d}",  # save as epoch_XXX.ckpt
        save_top_k=1,  # save only the best model
        save_last=True,  # save the last model
        monitor="val/reward",  # monitor validation reward
        mode="max",  # maximize validation reward
    )
    rich_model_summary = RichModelSummary(max_depth=3)  # model summary callback
    # Meta callbacks
    meta_callback = MetaModelCallback(
        meta_params={
            'meta_method': 'reptile',  # choose from ['maml', 'fomaml', 'maml_fomaml', 'reptile']
            'data_type': 'size',  # choose from ["size", "distribution", "size_distribution"]
            'sch_bar': 0.9,  # for the task scheduler of the size setting, where sch_epoch = sch_bar * epochs
            'B': 1,  # number of tasks in a mini-batch
            'alpha': 0.99,  # parameter for the outer-loop optimization of Reptile
            'alpha_decay': 0.999,  # decay parameter for the outer-loop optimization of Reptile
            'min_size': 20,  # minimum sampled size in meta tasks
            'max_size': 150,  # maximum sampled size in meta tasks
        },
        print_log=True,  # whether to print the sampled tasks in each meta iteration
    )
    callbacks = [meta_callback, checkpoint_callback, rich_model_summary]

    # Logger
    process_start_time = datetime.now(pytz.timezone("Asia/Singapore"))
    logger = WandbLogger(project="rl4co", name=f"{env.name}_{process_start_time.strftime('%Y%m%d_%H%M%S')}")
    # logger = None  # uncomment this line if you don't want logging

    # Adjust your trainer to the number of epochs you want to run
    trainer = RL4COMetaTrainer(
        max_epochs=20000,  # (number of meta-model updates) * (number of tasks in a mini-batch)
        callbacks=callbacks,
        accelerator="gpu",
        devices=[device_id],
        logger=logger,
        limit_train_batches=50,  # gradient descent steps in the inner-loop optimization of the meta-learning method
    )

    # Fit
    trainer.fit(model)


if __name__ == "__main__":
    main()

1 change: 1 addition & 0 deletions rl4co/utils/__init__.py
@@ -2,6 +2,7 @@
from rl4co.utils.pylogger import get_pylogger
from rl4co.utils.rich_utils import enforce_tags, print_config_tree
from rl4co.utils.trainer import RL4COTrainer
from rl4co.utils.meta_trainer import RL4COMetaTrainer
from rl4co.utils.utils import (
    extras,
    get_metric_value,
245 changes: 245 additions & 0 deletions rl4co/utils/meta_trainer.py
@@ -0,0 +1,245 @@
import copy
import math
import random
from typing import Iterable, List, Optional, Union

import lightning.pytorch as pl
import torch
from lightning import Callback, Trainer
from lightning.fabric.accelerators.cuda import num_cuda_devices
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.core.datamodule import LightningDataModule
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies import DDPStrategy, Strategy
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.optim import Adam

from rl4co import utils

log = utils.get_pylogger(__name__)


class MetaModelCallback(Callback):
Member:
I haven't thought of Meta-learning as Lightning Callbacks, but it looks neat! :D

Member:
Tagging @Junyoungpark for the good ol' Reptile

    def __init__(self, meta_params, print_log=True):
Member:
[Minor] I think it would be clearer to list the hyperparameters explicitly. For example:

def __init__(self,
             alpha=...,
             alpha_decay=...,
             ...
             print_log=True):

which is generally easier to maintain and to document.
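A sketch of what such an explicit signature could look like, with defaults taken from the meta_params used in examples/2d-meta_train.py above; storing them back into self.meta_params is just one way to keep the existing callback code unchanged:

from lightning import Callback


class MetaModelCallback(Callback):
    def __init__(
        self,
        meta_method: str = "reptile",   # only 'reptile' is currently supported
        data_type: str = "size",        # only 'size' is currently supported
        sch_bar: float = 0.9,           # task scheduler: sch_epoch = sch_bar * epochs
        B: int = 1,                     # number of tasks in a mini-batch
        alpha: float = 0.99,            # outer-loop step size of Reptile
        alpha_decay: float = 0.999,     # decay of the outer-loop step size
        min_size: int = 20,             # minimum sampled size in meta tasks
        max_size: int = 150,            # maximum sampled size in meta tasks
        print_log: bool = True,         # print the sampled tasks in each meta iteration
    ):
        super().__init__()
        self.meta_params = dict(
            meta_method=meta_method, data_type=data_type, sch_bar=sch_bar, B=B,
            alpha=alpha, alpha_decay=alpha_decay, min_size=min_size, max_size=max_size,
        )
        self.print_log = print_log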

        super().__init__()
        self.meta_params = meta_params
        assert meta_params["meta_method"] == 'reptile', NotImplementedError
Member:
Another general comment: it seems that this callback is only for Reptile, so we should consider calling it ReptileCallback. I think it would be cool to have these in, say, a meta_learning/ folder.

Contributor Author:
Yes. ReptileCallback is better. Where should this new meta_learning/ folder be located? In utils/?

Member:
I'd say under models/rl (which includes the LightningModules), since meta-learning is a way to optimize a policy.
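If the callback does move there, the import in the example script would presumably become something like the following (hypothetical module path, not part of this diff):

# Hypothetical location following the suggestion above
from rl4co.models.rl.meta_learning.reptile import ReptileCallback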

        assert meta_params["data_type"] == 'size', NotImplementedError
Member:
It seems that this parameter data_type is not used anywhere? 🤔

        self.print_log = print_log

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

        # Initialize some hyperparameters
        self.alpha = self.meta_params["alpha"]
        self.alpha_decay = self.meta_params["alpha_decay"]
        self.sch_bar = self.meta_params["sch_bar"]
        self.task_set = [(n,) for n in range(self.meta_params["min_size"], self.meta_params["max_size"] + 1)]

        # Sample a batch of tasks
        self._sample_task()
        self.selected_tasks[0] = (pl_module.env.generator.num_loc,)

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
Member:
[Minor] Type hints should be classes, not strings, e.g. Trainer and LightningModule (maybe even better: RL4COTrainer and RL4COLitModule, since everything inherits from those).
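A minimal illustration of the suggested hint style, assuming the classes are imported at module level:

from lightning.pytorch import Callback, LightningModule, Trainer


class MetaModelCallback(Callback):
    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        ...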

Contributor Author (@jieyibi, May 28, 2024):
Hi. I am refactoring the ReptileCallback to inherit from REINFORCE and the RL4COLitModule. But in that case, I would need to add a new meta_model.py that inherits from model.py under the zoo/pomo/ or zoo/am/ folder in order to call Reptile from outside, which is a little redundant... Maybe a Lightning Callback is more generic: we could apply it to every model and every policy.


        # Alpha scheduler (decay for the update of the meta model)
        self._alpha_scheduler()

        # Reinitialize the task model with the parameters of the meta model
        if trainer.current_epoch % self.meta_params['B'] == 0:  # Save the meta model
            self.meta_model_state_dict = copy.deepcopy(pl_module.state_dict())
            self.task_models = []
            # Print sampled tasks
            if self.print_log:
                print('\n>> Meta epoch: {} (Exact epoch: {}), Training task: {}'.format(
                    trainer.current_epoch // self.meta_params['B'], trainer.current_epoch, self.selected_tasks))
        else:
            pl_module.load_state_dict(self.meta_model_state_dict)

        # Reinitialize the optimizer every epoch
        lr_decay = 0.1 if trainer.current_epoch + 1 == int(self.sch_bar * trainer.max_epochs) else 1
        old_lr = trainer.optimizers[0].param_groups[0]['lr']
        new_optimizer = Adam(pl_module.parameters(), lr=old_lr * lr_decay)
        trainer.optimizers = [new_optimizer]

        # Print
        if self.print_log:
            print('\n>> Training task: {}, capacity: {}'.format(
                pl_module.env.generator.num_loc, pl_module.env.generator.capacity))

    def on_train_epoch_end(self, trainer, pl_module):

        # Save the task model
        self.task_models.append(copy.deepcopy(pl_module.state_dict()))
        if (trainer.current_epoch + 1) % self.meta_params['B'] == 0:
            # Outer-loop optimization (update the meta model with the parameters of the task models)
            with torch.no_grad():
                state_dict = {params_key: (self.meta_model_state_dict[params_key] +
                              self.alpha * torch.mean(torch.stack([fast_weight[params_key] - self.meta_model_state_dict[params_key]
                                                                   for fast_weight in self.task_models], dim=0).float(), dim=0))
                              for params_key in self.meta_model_state_dict}
                pl_module.load_state_dict(state_dict)

        # Get ready for the next meta-training iteration
        if (trainer.current_epoch + 1) % self.meta_params['B'] == 0:
            # Sample a batch of tasks
            self._sample_task()

        # Load a new training task (update the environment)
        self._load_task(pl_module, task_idx=(trainer.current_epoch + 1) % self.meta_params['B'])

    def _sample_task(self):
        # Sample a batch of tasks
        w, self.selected_tasks = [1.0] * self.meta_params['B'], []
        for b in range(self.meta_params['B']):
            task_params = random.sample(self.task_set, 1)[0]
            self.selected_tasks.append(task_params)
        self.w = torch.softmax(torch.Tensor(w), dim=0)

    def _load_task(self, pl_module, task_idx=0):
        # Load a new training task (update the environment)
        task_params, task_w = self.selected_tasks[task_idx], self.w[task_idx].item()
        task_capacity = math.ceil(30 + task_params[0] / 5) if task_params[0] >= 20 else 20
        pl_module.env.generator.num_loc = task_params[0]
        pl_module.env.generator.capacity = task_capacity

    def _alpha_scheduler(self):
        self.alpha = max(self.alpha * self.alpha_decay, 0.0001)
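For reference, the outer-loop update performed in on_train_epoch_end above corresponds to the Reptile rule, with theta the meta-model parameters (meta_model_state_dict), phi_i the task models (task_models), B the task mini-batch size, and alpha the outer step size:

    \theta \leftarrow \theta + \alpha \cdot \frac{1}{B} \sum_{i=1}^{B} \left( \phi_i - \theta \right)

_alpha_scheduler decays alpha multiplicatively by alpha_decay each epoch, with a floor of 1e-4.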

class RL4COMetaTrainer(Trainer):
Member:
Are there any differences compared to RL4COTrainer? I could not find any at first glance. I guess the only difference is that we pass the MetaModelCallback?

"""Wrapper around Lightning Trainer, with some RL4CO magic for efficient training.

# Meta training framework for addressing the generalization issue
# Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587

Note:
The most important hyperparameter to use is `reload_dataloaders_every_n_epochs`.
This allows for datasets to be re-created on the run and distributed by Lightning across
devices on each epoch. Setting to a value different than 1 may lead to overfitting to a
specific (such as the initial) data distribution.

Args:
accelerator: hardware accelerator to use.
callbacks: list of callbacks.
logger: logger (or iterable collection of loggers) for experiment tracking.
min_epochs: minimum number of training epochs.
max_epochs: maximum number of training epochs.
strategy: training strategy to use (if any), such as Distributed Data Parallel (DDP).
devices: number of devices to train on (int) or which GPUs to train on (list or str) applied per node.
gradient_clip_val: 0 means don't clip. Defaults to 1.0 for stability.
precision: allows for mixed precision training. Can be specified as a string (e.g., '16').
This also allows to use `FlashAttention` by default.
disable_profiling_executor: Disable JIT profiling executor. This reduces memory and increases speed.
auto_configure_ddp: Automatically configure DDP strategy if multiple GPUs are available.
reload_dataloaders_every_n_epochs: Set to a value different than 1 to reload dataloaders every n epochs.
matmul_precision: Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
**kwargs: Additional keyword arguments passed to the Lightning Trainer. See :class:`lightning.pytorch.trainer.Trainer` for details.
"""

    def __init__(
        self,
        accelerator: Union[str, Accelerator] = "auto",
        callbacks: Optional[List[Callback]] = None,
        logger: Optional[Union[Logger, Iterable[Logger]]] = None,
        min_epochs: Optional[int] = None,
        max_epochs: Optional[int] = None,
        strategy: Union[str, Strategy] = "auto",
        devices: Union[List[int], str, int] = "auto",
        gradient_clip_val: Union[int, float] = 1.0,
        precision: Union[str, int] = "16-mixed",
        reload_dataloaders_every_n_epochs: int = 1,
        disable_profiling_executor: bool = True,
        auto_configure_ddp: bool = True,
        matmul_precision: Union[str, int] = "medium",
        **kwargs,
    ):
        # Disable JIT profiling executor. This reduces memory and increases speed.
        # Reference: https://github.com/HazyResearch/safari/blob/111d2726e7e2b8d57726b7a8b932ad8a4b2ad660/train.py#LL124-L129C17
        if disable_profiling_executor:
            try:
                torch._C._jit_set_profiling_executor(False)
                torch._C._jit_set_profiling_mode(False)
            except AttributeError:
                pass

        # Configure DDP automatically if multiple GPUs are available
        if auto_configure_ddp and strategy == "auto":
            if devices == "auto":
                n_devices = num_cuda_devices()
            elif isinstance(devices, Iterable):
                n_devices = len(devices)
            else:
                n_devices = devices
            if n_devices > 1:
                log.info(
                    "Configuring DDP strategy automatically with {} GPUs".format(
                        n_devices
                    )
                )
                strategy = DDPStrategy(
                    find_unused_parameters=True,  # We set to True due to RL envs
                    gradient_as_bucket_view=True,  # https://pytorch-lightning.readthedocs.io/en/stable/advanced/advanced_gpu.html#ddp-optimizations
                )

        # Set matmul precision for faster inference https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
        if matmul_precision is not None:
            torch.set_float32_matmul_precision(matmul_precision)

        # Check if gradient_clip_val is set to None
        if gradient_clip_val is None:
            log.warning(
                "gradient_clip_val is set to None. This may lead to unstable training."
            )

        # We should reload dataloaders every epoch for RL training
        if reload_dataloaders_every_n_epochs != 1:
            log.warning(
                "We reload dataloaders every epoch for RL training. Setting reload_dataloaders_every_n_epochs to a value different than 1 "
                + "may lead to unexpected behavior since the initial conditions will be the same for `n_epochs` epochs."
            )

        # Main call to `Trainer` superclass
        super().__init__(
            accelerator=accelerator,
            callbacks=callbacks,
            logger=logger,
            min_epochs=min_epochs,
            max_epochs=max_epochs,
            strategy=strategy,
            gradient_clip_val=gradient_clip_val,
            devices=devices,
            precision=precision,
            reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs,
            **kwargs,
        )

    def fit(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        ckpt_path: Optional[str] = None,
    ) -> None:
        """
        We override the `fit` method to automatically apply and handle RL4CO magic
        for models with `self.automatic_optimization = False`, such as PPO.

        It behaves exactly like the original `fit` method, but with the following change:
        - if the given model has `self.automatic_optimization = False`, we override `gradient_clip_val` to None.
        """

        if not model.automatic_optimization:
            if self.gradient_clip_val is not None:
                log.warning(
                    "Overriding gradient_clip_val to None for 'automatic_optimization=False' models"
                )
            self.gradient_clip_val = None

        # Fit (inner-loop optimization)
        super().fit(
            model=model,
            train_dataloaders=train_dataloaders,
            val_dataloaders=val_dataloaders,
            datamodule=datamodule,
            ckpt_path=ckpt_path,
        )


