From 9b6c932e2f8a11fb14c2ee52ddeb20b77e96806b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jul 2024 19:25:30 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/solo/utils.rst | 1 - main_linear.py | 6 +++--- main_pretrain.py | 6 +++--- scripts/pretrain/cifar/all4one.yaml | 2 +- solo/methods/base.py | 8 +++++--- solo/utils/positional_encodings.py | 8 ++------ 6 files changed, 14 insertions(+), 17 deletions(-) diff --git a/docs/source/solo/utils.rst b/docs/source/solo/utils.rst index 83eb3636..e623bdc4 100644 --- a/docs/source/solo/utils.rst +++ b/docs/source/solo/utils.rst @@ -257,4 +257,3 @@ forward ~~~~~~~ .. automethod:: solo.utils.positional_encoding.Summer.forward :noindex: - diff --git a/main_linear.py b/main_linear.py index bb0b1454..03820b9a 100644 --- a/main_linear.py +++ b/main_linear.py @@ -203,9 +203,9 @@ def main(cfg: DictConfig): "logger": wandb_logger if cfg.wandb.enabled else None, "callbacks": callbacks, "enable_checkpointing": False, - "strategy": DDPStrategy(find_unused_parameters=False) - if cfg.strategy == "ddp" - else cfg.strategy, + "strategy": ( + DDPStrategy(find_unused_parameters=False) if cfg.strategy == "ddp" else cfg.strategy + ), } ) trainer = Trainer(**trainer_kwargs) diff --git a/main_pretrain.py b/main_pretrain.py index e83d5591..3e06ae40 100644 --- a/main_pretrain.py +++ b/main_pretrain.py @@ -221,9 +221,9 @@ def main(cfg: DictConfig): "logger": wandb_logger if cfg.wandb.enabled else None, "callbacks": callbacks, "enable_checkpointing": False, - "strategy": DDPStrategy(find_unused_parameters=False) - if cfg.strategy == "ddp" - else cfg.strategy, + "strategy": ( + DDPStrategy(find_unused_parameters=False) if cfg.strategy == "ddp" else cfg.strategy + ), } ) trainer = Trainer(**trainer_kwargs) diff --git a/scripts/pretrain/cifar/all4one.yaml b/scripts/pretrain/cifar/all4one.yaml index 7db7ea76..ccd7e7cd 100644 --- a/scripts/pretrain/cifar/all4one.yaml +++ b/scripts/pretrain/cifar/all4one.yaml @@ -28,7 +28,7 @@ data: dataset: cifar100 # change here for cifar10 train_path: "./datasets/" val_path: "./datasets/" - format: "image_folder" + format: "image_folder" num_workers: 4 optimizer: name: "lars" diff --git a/solo/methods/base.py b/solo/methods/base.py index d771b4f3..4020c7b7 100644 --- a/solo/methods/base.py +++ b/solo/methods/base.py @@ -389,9 +389,11 @@ def configure_optimizers(self) -> Tuple[List, List]: if idxs_no_scheduler: partial_fn = partial( static_lr, - get_lr=scheduler["scheduler"].get_lr - if isinstance(scheduler, dict) - else scheduler.get_lr, + get_lr=( + scheduler["scheduler"].get_lr + if isinstance(scheduler, dict) + else scheduler.get_lr + ), param_group_indexes=idxs_no_scheduler, lrs_to_replace=[self.lr] * len(idxs_no_scheduler), ) diff --git a/solo/utils/positional_encodings.py b/solo/utils/positional_encodings.py index e65be6a4..c72483bf 100644 --- a/solo/utils/positional_encodings.py +++ b/solo/utils/positional_encodings.py @@ -100,9 +100,7 @@ def forward(self, tensor): sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) emb_x = get_emb(sin_inp_x).unsqueeze(1) emb_y = get_emb(sin_inp_y) - emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type( - tensor.type() - ) + emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(tensor.type()) emb[:, :, : self.channels] = emb_x emb[:, :, self.channels : 2 * self.channels] = emb_y @@ -165,9 +163,7 @@ def forward(self, tensor): emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1) emb_y = get_emb(sin_inp_y).unsqueeze(1) emb_z = get_emb(sin_inp_z) - emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type( - tensor.type() - ) + emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(tensor.type()) emb[:, :, :, : self.channels] = emb_x emb[:, :, :, self.channels : 2 * self.channels] = emb_y emb[:, :, :, 2 * self.channels :] = emb_z