-
Notifications
You must be signed in to change notification settings - Fork 1
/
split-cityscapes.py
103 lines (84 loc) · 3.36 KB
/
split-cityscapes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import logging
import hydra
import torch
import wandb
from omegaconf import DictConfig, OmegaConf
from src.datasets.cityscapes2 import Cityscapes2SplitDataModule
from src.models.factory.cosmos.upsampler import Upsampler
from src.models.factory.phn.phn_wrappers import HyperModel
from src.models.factory.rotograd import RotogradWrapper
from src.models.factory.segnet_cityscapes import (
SegNet,
SegNetDepthDecoder,
SegNetMtan,
SegNetSegmentationDecoder,
SegNetSplitEncoder,
)
from src.utils import set_seed
from src.utils._selectors import get_ensemble_model, get_trainer
from src.utils.callbacks.auto_lambda_callback import AutoLambdaCallback
from src.utils.callbacks.cityscapes_metric_cb import CityscapesMetricCallback
from src.utils.callbacks.save_model import SaveModelCallback
from src.utils.logging_utils import initialize_wandb, install_logging
from src.utils.losses import CityscapesTwoTaskLoss
def _build_model(config: DictConfig, dm: Cityscapes2SplitDataModule):
    """Instantiate the network for the configured method and architecture.

    Only the architecture actually selected is constructed. Building every
    candidate and discarding the unused ones would allocate throw-away
    networks (and, assuming their constructors randomly initialise weights,
    would also advance the seeded RNG before the real model is built).

    Args:
        config: Hydra config; reads ``method.name`` and ``model.type``.
        dm: Benchmark data module; supplies ``num_tasks`` and ``input_dims``.

    Returns:
        The (possibly wrapped) model to train.

    Raises:
        ValueError: If ``config.model.type`` names an unknown architecture
            (only checked for non-RotoGrad methods, which are the only ones
            that use it).
    """
    method = config.method.name
    # COSMOS uses 5 input channels instead of RGB's 3 — presumably the image
    # plus one preference channel per task (num_tasks == 2); confirm against Upsampler.
    in_channels = 5 if method == "cosmos" else 3

    if method == "rotograd":
        # RotoGrad needs a split encoder plus per-task decoders rather than a
        # monolithic SegNet variant, so config.model.type is not consulted.
        backbone = SegNetSplitEncoder(in_channels=in_channels, rotograd=True)
        heads = [SegNetSegmentationDecoder(rotograd=True), SegNetDepthDecoder(rotograd=True)]
        return RotogradWrapper(backbone=backbone, heads=heads, latent_size=50)

    architectures = {"segnetmtan": SegNetMtan, "segnet": SegNet}
    try:
        architecture = architectures[config.model.type]
    except KeyError:
        raise ValueError(
            f"Unknown model.type {config.model.type!r}; expected one of {sorted(architectures)}"
        ) from None
    model = architecture(in_channels)

    if method == "pamal":
        model = get_ensemble_model(model, dm.num_tasks, config)
    elif method == "cosmos":
        model = Upsampler(dm.num_tasks, model, input_dim=dm.input_dims)
    return model


@hydra.main(config_path="configs/experiment/cityscapes", config_name="cityscapes")
def my_app(config: DictConfig) -> None:
    """Train and evaluate a two-task model on the split Cityscapes benchmark.

    Runs one experiment end to end:
    1. Set up logging, the random seed, and the W&B run.
    2. Build the data module and the model.
    3. Create an Adam optimizer with a StepLR schedule.
    4. Train for ``config.training.epochs`` epochs.
    5. Predict on the test loader.
    6. Close the W&B run.

    Args:
        config: Hydra-composed experiment configuration. Reads ``seed``,
            ``data.*``, ``method.*``, ``model.type``, ``optimizer.lr``,
            ``scheduler.step``, ``scheduler.gamma``, and ``training.epochs``.
    """
    import warnings

    # Silence this repeated "argument order will change" warning so it does not flood the logs.
    warnings.filterwarnings(
        "ignore", message="Note that order of the arguments: ceil_mode and return_indices will change"
    )
    install_logging()
    logging.info(OmegaConf.to_yaml(config))
    set_seed(config.seed)
    initialize_wandb(config)
    # NOTE(review): hard-coded tag; this assignment replaces any tags already on the run.
    wandb.run.tags = ["313"]

    dm = Cityscapes2SplitDataModule(
        batch_size=config.data.batch_size,
        num_workers=config.data.num_workers,
        apply_augmentation=config.data.apply_augmentation,
    )
    logging.info("I am using the following benchmark %s", dm.name)

    model = _build_model(config, dm)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.optimizer.lr)
    # StepLR multiplies the LR by `gamma` every `step` scheduler steps;
    # scheduler_step_on_epoch=True below presumably makes that one step per epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.scheduler.step,
        gamma=config.scheduler.gamma,
    )
    logging.info(scheduler)

    callbacks = [CityscapesMetricCallback(), SaveModelCallback()]
    if config.method.name == "autol":
        callbacks.append(AutoLambdaCallback(config.method.meta_lr))

    trainer_kwargs = dict(
        model=model,
        benchmark=dm,
        optimizer=optimizer,
        loss_fn=CityscapesTwoTaskLoss(),
        gpu=0,
        scheduler=scheduler,
        scheduler_step_on_epoch=True,
        callbacks=callbacks,
    )
    trainer = get_trainer(config, trainer_kwargs, dm.num_tasks)
    trainer.fit(epochs=config.training.epochs)

    test_loader = dm.test_dataloader()
    if config.method.name == "pamal":
        # PaMaL wraps the model in an ensemble; evaluation goes through the
        # trainer's interpolation-based prediction instead of a single predict().
        trainer.predict_interpolations(test_loader)
    else:
        trainer.predict(test_loader=test_loader)
    wandb.finish()
if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator composes the config from
    # the command line, so the function is invoked without arguments.
    my_app()
"""
"""