
Pangu Improvements #656

Open · wants to merge 12 commits into main

Conversation

@dallasfoster (Collaborator) commented Aug 27, 2024

Modulus Pull Request

Description

This PR adds the following features/changes to the Pangu model and training script:

  1. Configurable number of constant, surface, and atmosphere variables in the model.
  2. Configurable number of upsampled and downsampled transformer blocks.
  3. Gradient checkpointing support in the Pangu processor (encoder/decoder) layers (see the sketch after this list).
  4. An improved training script with better static-capture support, multistep rollout, a validation function, and a weighted loss function.
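
As a rough illustration of item 3, the sketch below shows one generic way to apply gradient checkpointing to a stack of transformer blocks with torch.utils.checkpoint; the class and argument names are placeholders, not the Pangu processor's actual code.

    import torch
    from torch.utils.checkpoint import checkpoint

    class CheckpointedBlocks(torch.nn.Module):
        # Placeholder stack of blocks; the real Pangu encoder/decoder layers differ.
        def __init__(self, blocks, use_checkpoint=True):
            super().__init__()
            self.blocks = torch.nn.ModuleList(blocks)
            self.use_checkpoint = use_checkpoint

        def forward(self, x):
            for block in self.blocks:
                if self.use_checkpoint and self.training:
                    # Recompute this block's activations during the backward pass
                    # instead of storing them, trading extra compute for lower memory.
                    x = checkpoint(block, x, use_reentrant=False)
                else:
                    x = block(x)
            return x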

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

@dallasfoster self-assigned this Aug 27, 2024

@dallasfoster (Collaborator Author): /blossom-ci (posted five times)

@dallasfoster (Collaborator Author): Depends on #660


### Changed

- Refactored CorrDiff training recipe for improved usability
- Refactored Pangu model for better extensibility and gradient checkpointing support.
  Some of these changes are not backward compatible.

Collaborator (review comment):
Perhaps comment on what specifically is not backward compatible? Is it just the removal of the prepare_input routine from the Pangu model?

outpred = my_model(invar_)
loss += loss_func(outpred, outvar[b : b + 1, t], weights) / batch_size
invar_ = outpred

Collaborator (review comment):

Is there a reason this inner multistep loop cannot be run with batch_size > 1? It would be nice to support batched rollout training if it could fit in memory.
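
To make the suggestion concrete, here is a minimal sketch of what a batched rollout could look like, reusing the names from the quoted snippet above; the tensor shapes and the num_steps argument are assumptions, not the script's actual interface.

    import torch

    def batched_rollout_loss(my_model, loss_func, invar, outvar, weights, num_steps):
        # invar:  (B, C, H, W) initial states for the whole batch
        # outvar: (B, T, C, H, W) ground-truth targets for the T rollout steps
        invar_ = invar
        loss = torch.zeros((), device=invar.device)
        for t in range(num_steps):
            outpred = my_model(invar_)                    # advance every sample one step
            loss = loss + loss_func(outpred, outvar[:, t], weights)
            invar_ = outpred                              # autoregressive feedback
        return loss / num_steps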

@@ -203,122 +262,197 @@ def main(cfg: DictConfig) -> None:
)
torch.cuda.current_stream().wait_stream(ddps)

# pangu_model = torch.compile(pangu_model, mode = "max-autotune")

Collaborator (review comment):

Can drop these commented-out .compile statements if they're unused.

import torch

from ..layers import DownSample3D, FuserLayer, UpSample3D
from ..module import Module

Collaborator (review comment):
Prefer direct imports here

@pzharrington (Collaborator)

I added some minor suggestions; overall this looks great, though. I found the LambdaLR scheme with the custom Hydra resolver a bit convoluted; maybe a ConstantLR would be simpler and more readable while achieving the same effect. However, it is nice to have the example in there if someone wants to do more custom scheduling.
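
For reference, a minimal side-by-side of the two scheduler options mentioned above; the model, optimizer, and numeric values are placeholders for illustration, not the recipe's actual configuration.

    import torch
    from torch.optim.lr_scheduler import ConstantLR, LambdaLR

    # Placeholder model/optimizers purely for illustration.
    model = torch.nn.Linear(4, 4)

    # LambdaLR: the per-step LR factor comes from an arbitrary user-supplied callable.
    opt_a = torch.optim.Adam(model.parameters(), lr=1e-3)
    sched_a = LambdaLR(opt_a, lr_lambda=lambda step: 1.0)  # flat schedule via a lambda

    # ConstantLR: scales the base LR by `factor` for the first `total_iters` steps,
    # then keeps the base LR -- simpler when a (near-)constant LR is all that is needed.
    opt_b = torch.optim.Adam(model.parameters(), lr=1e-3)
    sched_b = ConstantLR(opt_b, factor=0.5, total_iters=1000)

    for _ in range(3):   # stand-in for the training loop
        opt_a.step()
        sched_a.step()
        opt_b.step()
        sched_b.step()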
