Add meta learning framework #183
Conversation
Great! 🚀
Left some comments here and there~
rl4co/utils/meta_trainer.py
Outdated
    def _alpha_scheduler(self):
        self.alpha = max(self.alpha * self.alpha_decay, 0.0001)


class RL4COMetaTrainer(Trainer):
Are there any differences compared to the RL4COTrainer? I could not find any at first glance. I guess the only difference is that we pass the MetaModelCallback?
rl4co/utils/meta_trainer.py
Outdated
log = utils.get_pylogger(__name__)


class MetaModelCallback(Callback):
I haven't thought of Meta-learning as Lightning Callbacks, but it looks neat! :D
Tagging @Junyoungpark for the good ol' Reptile
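For context: Reptile's outer loop simply interpolates the meta-weights toward the task-adapted weights, with the step size alpha decayed over training, which is what the _alpha_scheduler above does. A minimal sketch of that update, assuming two plain torch modules rather than the PR's actual code:

import torch

@torch.no_grad()
def reptile_meta_update(meta_model, task_model, alpha):
    # theta <- theta + alpha * (theta_task - theta)
    for meta_p, task_p in zip(meta_model.parameters(), task_model.parameters()):
        meta_p.add_(alpha * (task_p - meta_p))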
rl4co/utils/meta_trainer.py
Outdated
    def __init__(self, meta_params, print_log=True):
        super().__init__()
        self.meta_params = meta_params
        assert meta_params["meta_method"] == 'reptile', NotImplementedError
Another general comment: it seems that this callback is only for Reptile, so we should consider calling it ReptileCallback. I think it would be cool to have these in, say, a meta_learning/ folder.
Yes, ReptileCallback is better. Where should this new meta_learning/ folder be located? In utils/?
I'd say under models/rl (which includes the LightningModules), since meta_learning is a way to optimize a policy.
rl4co/utils/meta_trainer.py
Outdated
        self._sample_task()
        self.selected_tasks[0] = (pl_module.env.generator.num_loc,)

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
[Minor] Type hints should be classes, not strings: for example Trainer and LightningModule (maybe even better: RL4COTrainer and RL4COLitModule, since everything inherits from those).
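A sketch of the suggested change, assuming Lightning 2.x import paths:

from lightning.pytorch import LightningModule, Trainer

    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        ...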
Hi. I tried refactoring the ReptileCallback to inherit from REINFORCE and the RL4COLitModule. But in that case, I would need to add a new meta_model.py, inheriting from the model.py under the zoo/pomo/ or zoo/am/ folder, to call REPTILE from the outside, which is a little bit redundant... Maybe a Lightning Callback is more generic: we could apply it to every model and every policy.
Great job! 🚀 Really nice to have meta learning supported 😁
examples/2d-meta_train.py
Outdated
@@ -0,0 +1,87 @@
import pytz
[Minor] Nice tool, but pytz is not included in RL4CO's dependencies. Maybe better to have a package check here:
try:
    import pytz
except ImportError:
    # raise a warning and fall back to Python's default time handling
    pass
Yeah, maybe it could be removed.
rl4co/utils/meta_trainer.py
Outdated
        super().__init__()
        self.meta_params = meta_params
        assert meta_params["meta_method"] == 'reptile', NotImplementedError
        assert meta_params["data_type"] == 'size', NotImplementedError
It seems that this parameter data_type is not used anywhere? 🤔
Great!
Also:
- Could you reproduce the learning curves of the original model?
- Could you make a simple test like this for your model, so that it can be checked automatically?
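A minimal sketch of such a test, assuming the callback stays importable from rl4co.utils.meta_trainer (the import path, model choice, and sizes here are assumptions, not the PR's actual test):

from rl4co.envs import TSPEnv
from rl4co.models import AttentionModel
from rl4co.utils.trainer import RL4COTrainer
from rl4co.utils.meta_trainer import ReptileCallback

def test_reptile_callback():
    # tiny instance sizes so the test runs quickly on CPU
    env = TSPEnv(generator_params=dict(num_loc=10))
    model = AttentionModel(env, train_data_size=64, val_data_size=32)
    callback = ReptileCallback(
        num_tasks=2, alpha=0.99, alpha_decay=0.999, min_size=10, max_size=20
    )
    trainer = RL4COTrainer(
        max_epochs=1, callbacks=[callback], accelerator="cpu", devices=1
    )
    trainer.fit(model)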
rl4co/utils/meta_trainer.py
Outdated
    # Meta training framework for addressing the generalization issue
    # Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587
    def __init__(self, meta_params, print_log=True):
[Minor] I think it would be clearer to list the hyperparameters explicitly. For example:
def __init__(self,
             alpha=...,
             alpha_decay=...,
             ...,
             print_log=True):
which is generally easier to maintain and to document.
It seems there was a mistake here - I noticed that the "tasks" are a bit hardcoded, i.e. the "size" and "capacity" (I guess for TSP and CVRP, right?).
Speaking of testing: you can make sure things work on your device before committing by running the tests locally.
rl4co/utils/meta_trainer.py
Outdated
class ReptileCallback(Callback):

    # Meta training framework for addressing the generalization issue
    # Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587
    def __init__(self,
                 num_tasks,
                 alpha,
                 alpha_decay,
                 min_size,
                 max_size,
                 sch_bar=0.9,
                 data_type="size",
                 print_log=True):
        super().__init__()
[Documentation] It's recommended to have a docstring with the parameters, possibly including data types, constraints, hints, etc. Better for us "non-experts" to understand 😆
class ReptileCallback(Callback):
    """Meta training framework for addressing the generalization issue.
    Based on Zhou et al. (2023): https://arxiv.org/abs/2305.19587

    Args:
        num_tasks: number of task types, i.e. `B` in the original paper
        alpha: ...
        ...
    """

    def __init__(
        self,
        num_tasks: int,
        alpha: float,
        alpha_decay: float,
        min_size: int,
        max_size: int,
        sch_bar: float = 0.9,
        data_type: str = "size",
        print_log: bool = True,
    ):
        super().__init__()
Hi Chuanbo. I have added the documentation as you recommended, along with the generation code for some distributions defined in the generalization-related works. Now the meta learning framework also supports cross-distribution generalization.
@@ -0,0 +1,184 @@
import torch


class Cluster():
I like this! @cbhua I think we should train some model with, say, TSP50 / CVRP50 with a mixed distribution and test its generalization performance.
Minor comment: shouldn't this be a subclass of torch.distributions.distribution.Distribution?
Yes, actually this class is what we are missing in the distribution utilities! It could be used by various environments' generators.
About the experiment: we want to test the distribution generalization ability, right?
Routine check: how is progress going? I think the multi-distribution generators in particular should be included, since they are part of this PR.
The updated commit supports training on multiple mixed distributions by changing the data_type argument. The rest of the arguments remain the same. Note that this mixed distribution setting follows the setting in Bi et al. (2022).
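For illustration, the switch presumably looks something like this (the argument value "distribution" is an assumption, not confirmed in the thread):

callback = ReptileCallback(
    num_tasks=4,
    alpha=0.99,
    alpha_decay=0.999,
    min_size=50,
    max_size=200,
    data_type="distribution",  # hypothetical value; "size" was the default above
)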
Great! Should the code be merged?
Change some parameters for performance
Yes, I think it is ready to be merged. The newly reproduced performance (after changing some key parameters) is similar to that reported in the literature.
Great! How about the generalization experiments, i.e., MDPOMO? Do you have those?
Yeah, already added to the main branch :)
Awesome! Then we can go ahead and merge :)
Description
Add a new training framework based on meta learning. For details, refer to Zhou et al. (2023).
Motivation and Context
To address the generalization issue.
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!