Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 22, 2024
1 parent 21f89a6 commit 9b6c932
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 17 deletions.
1 change: 0 additions & 1 deletion docs/source/solo/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,3 @@ forward
~~~~~~~
.. automethod:: solo.utils.positional_encoding.Summer.forward
:noindex:

6 changes: 3 additions & 3 deletions main_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ def main(cfg: DictConfig):
"logger": wandb_logger if cfg.wandb.enabled else None,
"callbacks": callbacks,
"enable_checkpointing": False,
"strategy": DDPStrategy(find_unused_parameters=False)
if cfg.strategy == "ddp"
else cfg.strategy,
"strategy": (
DDPStrategy(find_unused_parameters=False) if cfg.strategy == "ddp" else cfg.strategy
),
}
)
trainer = Trainer(**trainer_kwargs)
Expand Down
6 changes: 3 additions & 3 deletions main_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,9 +221,9 @@ def main(cfg: DictConfig):
"logger": wandb_logger if cfg.wandb.enabled else None,
"callbacks": callbacks,
"enable_checkpointing": False,
"strategy": DDPStrategy(find_unused_parameters=False)
if cfg.strategy == "ddp"
else cfg.strategy,
"strategy": (
DDPStrategy(find_unused_parameters=False) if cfg.strategy == "ddp" else cfg.strategy
),
}
)
trainer = Trainer(**trainer_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion scripts/pretrain/cifar/all4one.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data:
dataset: cifar100 # change here for cifar10
train_path: "./datasets/"
val_path: "./datasets/"
format: "image_folder"
format: "image_folder"
num_workers: 4
optimizer:
name: "lars"
Expand Down
8 changes: 5 additions & 3 deletions solo/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,11 @@ def configure_optimizers(self) -> Tuple[List, List]:
if idxs_no_scheduler:
partial_fn = partial(
static_lr,
get_lr=scheduler["scheduler"].get_lr
if isinstance(scheduler, dict)
else scheduler.get_lr,
get_lr=(
scheduler["scheduler"].get_lr
if isinstance(scheduler, dict)
else scheduler.get_lr
),
param_group_indexes=idxs_no_scheduler,
lrs_to_replace=[self.lr] * len(idxs_no_scheduler),
)
Expand Down
8 changes: 2 additions & 6 deletions solo/utils/positional_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,7 @@ def forward(self, tensor):
sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
emb_x = get_emb(sin_inp_x).unsqueeze(1)
emb_y = get_emb(sin_inp_y)
emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(
tensor.type()
)
emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(tensor.type())

Check warning on line 103 in solo/utils/positional_encodings.py

View check run for this annotation

Codecov / codecov/patch

solo/utils/positional_encodings.py#L103

Added line #L103 was not covered by tests
emb[:, :, : self.channels] = emb_x
emb[:, :, self.channels : 2 * self.channels] = emb_y

Expand Down Expand Up @@ -165,9 +163,7 @@ def forward(self, tensor):
emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1)
emb_y = get_emb(sin_inp_y).unsqueeze(1)
emb_z = get_emb(sin_inp_z)
emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(
tensor.type()
)
emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(tensor.type())

Check warning on line 166 in solo/utils/positional_encodings.py

View check run for this annotation

Codecov / codecov/patch

solo/utils/positional_encodings.py#L166

Added line #L166 was not covered by tests
emb[:, :, :, : self.channels] = emb_x
emb[:, :, :, self.channels : 2 * self.channels] = emb_y
emb[:, :, :, 2 * self.channels :] = emb_z
Expand Down

0 comments on commit 9b6c932

Please sign in to comment.