Skip to content

Commit

Permalink
Argument to inform if the head will be frozen or not
Browse files Browse the repository at this point in the history
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
  • Loading branch information
Joao-L-S-Almeida committed Jan 21, 2025
1 parent eb20bf8 commit 1fdb5f5
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 3 deletions.
3 changes: 3 additions & 0 deletions terratorch/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def configure_models(self) -> None:
if self.hparams["freeze_decoder"]:
self.model.freeze_decoder()

if self.hparams["freeze_head"]:
self.model.freeze_head()

def configure_optimizers(
self,
) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
Expand Down
4 changes: 3 additions & 1 deletion terratorch/tasks/classification_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
#
freeze_backbone: bool = False, # noqa: FBT001, FBT002
freeze_decoder: bool = False, # noqa: FBT002, FBT001
freeze_head: bool = False, # noqa: FBT002, FBT001
class_names: list[str] | None = None,
lr_overrides: dict[str, float] | None = None,
) -> None:
Expand Down Expand Up @@ -96,7 +97,8 @@ def __init__(
scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
Overriden by config / cli specification through LightningCLI.
freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
freeze_head (bool, optional): Whether to freeze the segmentation_head. Defaults to False.
class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
Defaults to numeric ordering.
lr_overrides (dict[str, float] | None, optional): Dictionary to override the default lr in specific
Expand Down
4 changes: 3 additions & 1 deletion terratorch/tasks/regression_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__(
#
freeze_backbone: bool = False, # noqa: FBT001, FBT002
freeze_decoder: bool = False, # noqa: FBT001, FBT002
freeze_head: bool = False, # noqa: FBT001, FBT002
plot_on_val: bool | int = 10,
tiled_inference_parameters: TiledInferenceParameters | None = None,
lr_overrides: dict[str, float] | None = None,
Expand Down Expand Up @@ -183,7 +184,8 @@ def __init__(
scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
Overriden by config / cli specification through LightningCLI.
freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False.
plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
tiled_inference_parameters (TiledInferenceParameters | None, optional): Inference parameters
Expand Down
4 changes: 3 additions & 1 deletion terratorch/tasks/segmentation_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
#
freeze_backbone: bool = False, # noqa: FBT001, FBT002
freeze_decoder: bool = False, # noqa: FBT002, FBT001
freeze_head: bool = False,
plot_on_val: bool | int = 10,
class_names: list[str] | None = None,
tiled_inference_parameters: TiledInferenceParameters = None,
Expand Down Expand Up @@ -97,7 +98,8 @@ def __init__(
scheduler_hparams (dict | None): Parameters to be passed for instantiation of the scheduler.
Overriden by config / cli specification through LightningCLI.
freeze_backbone (bool, optional): Whether to freeze the backbone. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder and segmentation head. Defaults to False.
freeze_decoder (bool, optional): Whether to freeze the decoder. Defaults to False.
freeze_head (bool, optional): Whether to freeze the segmentation head. Defaults to False.
plot_on_val (bool | int, optional): Whether to plot visualizations on validation.
If true, log every epoch. Defaults to 10. If int, will plot every plot_on_val epochs.
class_names (list[str] | None, optional): List of class names passed to metrics for better naming.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ model:
ignore_index: -1
freeze_backbone: true
freeze_decoder: false
freeze_head: false
model_factory: PrithviModelFactory

# uncomment this block for tiled inference
Expand Down

0 comments on commit 1fdb5f5

Please sign in to comment.