Pangu Improvements #656

Open

wants to merge 12 commits into base: main
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -12,10 +12,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Graph Transformer processor for GraphCast/GenCast.
- Utility to generate STL from Signed Distance Field.
- Improved Pangu training code.

### Changed

- Refactored CorrDiff training recipe for improved usability
- Refactored Pangu model for better extensibility and gradient checkpointing support.
Some of these changes are not backward compatible.
Collaborator

Perhaps comment on what specifically is not backward compatible? Is it just the removal of the `prepare_input` routine from the Pangu model?


### Deprecated

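The changelog entry above mentions gradient checkpointing support, which the new configs expose through a `checkpoint_flag` option. As a rough illustration of how such a flag is typically wired into a transformer stack with `torch.utils.checkpoint` (the `CheckpointedEncoder` wrapper below is a hypothetical sketch, not the PR's actual implementation):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CheckpointedEncoder(nn.Module):
    """Hypothetical transformer stack that trades compute for memory
    when ``checkpoint_flag`` is set."""

    def __init__(self, blocks: nn.ModuleList, checkpoint_flag: bool = True):
        super().__init__()
        self.blocks = blocks
        self.checkpoint_flag = checkpoint_flag

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.checkpoint_flag and self.training:
                # Recompute this block's activations during backward
                # instead of storing them, lowering peak memory at the
                # cost of a second forward pass through the block.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
```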
93 changes: 77 additions & 16 deletions examples/weather/pangu_weather/conf/config.yaml
@@ -24,30 +24,91 @@ hydra:
run:
dir: ./outputs/

start_epoch: 0
max_epoch: 100
max_epoch: 101

pangu:
img_size: [721, 1440]
patch_size: [2, 4, 4]
embed_dim: 192
num_heads: [6, 12, 12, 6]
window_size: [2, 6, 12]
number_constant_variables: 3
number_surface_variables: 4
number_atmosphere_variables: 5
number_atmosphere_levels: 13
number_up_sampled_blocks: 2
number_down_sampled_blocks: 6
checkpoint_flag: True

train:
data_dir: "/data/train/"
stats_dir: "/data/stats/"
checkpoint_dir: "/data/checkpoints/"
mask_dir: "/data/constant_mask/"

num_samples_per_year: 1456
batch_size: 1
patch_size: [1, 1]
num_workers: 8
lr: 1e-3
weight_decay: 0.05
use_cosine_zenith: True
mask_dtype: "float32"
enable_amp: False
enable_graphs: False

stages:
- name: "Learning Rate Warmup"
max_iterations: .inf
num_epochs: 1
batch_size: 1
num_rollout_steps: 1
lr_scheduler_name: LinearLR
args:
start_factor: 0.001
end_factor: 1.0
total_iters: 1

- name: "Cosine Annealing LR"
max_iterations: .inf
num_epochs: 100
batch_size: 1
num_rollout_steps: 1
lr_scheduler_name: CosineAnnealingLR
args:
T_max: 100
eta_min: 0.0

- name: "LambdaLR Rollout: 2 Steps"
max_iterations: .inf
num_epochs: 20
batch_size: 2
num_rollout_steps: 2
lr_scheduler_name: LambdaLR
args:
lr_lambda: ${lambda_lr:3e-7,${train.lr}}

- name: "LambdaLR Rollout: 3 Steps"
max_iterations: .inf
num_epochs: 20
batch_size: 2
num_rollout_steps: 3
lr_scheduler_name: LambdaLR
args:
lr_lambda: ${lambda_lr:3e-7,${train.lr}}

- name: "LambdaLR Rollout: 4 Steps"
max_iterations: .inf
num_epochs: 20
batch_size: 1
num_rollout_steps: 4
lr_scheduler_name: LambdaLR
args:
lr_lambda: ${lambda_lr:3e-7,${train.lr}}

val:
data_dir: "/data/test/"
stats_dir: "/data/stats/"
num_samples_per_year: 4
num_samples_per_year: 32
batch_size: 1
patch_size: [1, 1]
num_workers: 8

pangu:
img_size: [721, 1440]
patch_size: [2, 4, 4]
embed_dim: 192
num_heads: [6, 12, 12, 6]
window_size: [2, 6, 12]

mask_dir: "/data/constant_mask"
mask_dtype: "float32"
num_rollout_steps: 6
channels: [0, 1, 2, 3, 4, 7, 43]
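Two details of the new `train.stages` list are worth unpacking. Each stage names a scheduler from `torch.optim.lr_scheduler` and passes it keyword arguments, and the `${lambda_lr:3e-7,${train.lr}}` entries rely on a custom OmegaConf resolver that the training script must register before the config is resolved. A minimal sketch of both pieces, assuming `lambda_lr` maps a floor learning rate to a constant multiplier relative to the base LR (the resolver's actual semantics in this PR may differ):

```python
import torch
from omegaconf import OmegaConf

# Hypothetical resolver backing ${lambda_lr:3e-7,${train.lr}}: returns a
# function LambdaLR can call, here a constant factor min_lr / base_lr.
OmegaConf.register_new_resolver(
    "lambda_lr",
    lambda min_lr, base_lr: lambda epoch: float(min_lr) / float(base_lr),
)


def build_scheduler(optimizer, stage_cfg):
    """Instantiate the scheduler class named by the stage from its args."""
    scheduler_cls = getattr(torch.optim.lr_scheduler, stage_cfg.lr_scheduler_name)
    return scheduler_cls(optimizer, **stage_cfg.args)
```

Under that reading, a training loop would walk the stages in order, rebuilding the scheduler (and the dataloader, since `batch_size` and `num_rollout_steps` change per stage) at each boundary, roughly:

```python
for stage in cfg.train.stages:
    scheduler = build_scheduler(optimizer, stage)  # optimizer assumed to exist
    for _ in range(stage.num_epochs):
        ...  # train one epoch with stage.num_rollout_steps autoregressive steps
        scheduler.step()
```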
70 changes: 52 additions & 18 deletions examples/weather/pangu_weather/conf/config_lite.yaml
@@ -16,38 +16,72 @@

experiment_name: "Modulus-Launch-Dev"
experiment_desc: "Modulus launch development"
run_desc: "Pangu lite ERA5 Training"
run_desc: "Pangu ERA5 Lite-Training"

hydra:
job:
chdir: True
run:
dir: ./outputs/

start_epoch: 0
max_epoch: 100
max_epoch: 11

pangu:
img_size: [721, 1440]
patch_size: [2, 4, 4]
embed_dim: 96
num_heads: [6, 12, 12, 6]
window_size: [2, 6, 12]
number_constant_variables: 3
number_surface_variables: 4
number_atmosphere_variables: 5
number_atmosphere_levels: 13
number_up_sampled_blocks: 2
number_down_sampled_blocks: 6
checkpoint_flag: True

train:
data_dir: "/data/train/"
stats_dir: "/data/stats/"
num_samples_per_year: 1456
batch_size: 1
patch_size: [1, 1]
checkpoint_dir: "/data/checkpoints/"
mask_dir: "/data/constant_mask/"

num_samples_per_year: 600
num_workers: 8
lr: 1e-3
weight_decay: 0.05
use_cosine_zenith: True
mask_dtype: "float32"
enable_amp: False
enable_graphs: False

stages:
- name: "Learning Rate Warmup"
max_iterations: .inf
num_epochs: 1
batch_size: 1
num_rollout_steps: 1
lr_scheduler_name: LinearLR
args:
start_factor: 0.001
end_factor: 1.0
total_iters: 1

- name: "Cosine Annealing LR"
max_iterations: .inf
num_epochs: 10
batch_size: 1
num_rollout_steps: 1
lr_scheduler_name: CosineAnnealingLR
args:
T_max: 10
eta_min: 0.0

val:
data_dir: "/data/test/"
stats_dir: "/data/stats/"
num_samples_per_year: 4
num_samples_per_year: 1
batch_size: 1
patch_size: [1, 1]
num_workers: 8

pangu:
img_size: [721, 1440]
patch_size: [2, 8, 8]
embed_dim: 192
num_heads: [6, 12, 12, 6]
window_size: [2, 6, 12]

mask_dir: "/data/constant_mask"
mask_dtype: "float32"
num_rollout_steps: 1
channels: [0, 1, 2, 3, 4, 7, 43]
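Both configs also set `use_cosine_zenith: True`, which conventionally means the cosine of the solar zenith angle is fed to the model as a time-dependent input channel. For reference, the standard formula is cos θ = sin φ sin δ + cos φ cos δ cos h for latitude φ, solar declination δ, and hour angle h; the helper below is an illustrative sketch of that computation and is not taken from this PR:

```python
import numpy as np


def cosine_zenith_angle(lat_deg, lon_deg, day_of_year, utc_hour):
    """Approximate cos(zenith) on a lat/lon grid (illustrative only).

    Uses Cooper's approximation for the solar declination; a production
    implementation would use a proper solar-position algorithm.
    """
    lat = np.deg2rad(lat_deg)
    # Solar declination delta, Cooper's formula (radians).
    decl = np.deg2rad(23.45) * np.sin(2.0 * np.pi * (284 + day_of_year) / 365.0)
    # Hour angle h: zero at local solar noon, 15 degrees per hour,
    # shifted by longitude to convert from UTC to local solar time.
    hour_angle = np.deg2rad(15.0 * (utc_hour - 12.0) + lon_deg)
    cos_zen = np.sin(lat) * np.sin(decl) + np.cos(lat) * np.cos(decl) * np.cos(hour_angle)
    # Clamp the night side to zero, a common choice for NN input channels.
    return np.clip(cos_zen, 0.0, None)
```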
1 change: 0 additions & 1 deletion examples/weather/pangu_weather/requirements.txt

This file was deleted.
