FroSSL Implementation #390

Open · wants to merge 5 commits into main
2 changes: 1 addition & 1 deletion main_pretrain.py
@@ -69,7 +69,7 @@ def main(cfg: DictConfig):
     assert cfg.method in METHODS, f"Choose from {METHODS.keys()}"
 
     if cfg.data.num_large_crops != 2:
-        assert cfg.method in ["wmse", "mae"]
+        assert cfg.method in ["wmse", "mae", "frossl"]
 
     model = METHODS[cfg.method](cfg)
     make_contiguous(model)
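Relaxing this check lets FroSSL configs request more than two large crops; the loss added later in this diff handles an arbitrary number of views through the leading V dimension of its input. A toy illustration with four views (hypothetical shapes, assuming this branch is installed):

import torch
from solo.losses.frossl import frossl_loss_func

z = torch.randn(4, 8, 16)   # V=4 views, N=8 samples, D=16 dims (illustrative only)
print(frossl_loss_func(z))  # runs for any V, which is why the two-crop restriction can be relaxed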
54 changes: 54 additions & 0 deletions scripts/pretrain/cifar/frossl.yaml
@@ -0,0 +1,54 @@
defaults:
  - _self_
  - augmentations: symmetric.yaml
  - wandb: private.yaml
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# disable hydra outputs
hydra:
  output_subdir: null
  run:
    dir: .

name: "frossl-cifar10" # change here for cifar100
method: "frossl"
backbone:
  name: "resnet18"
method_kwargs:
  proj_hidden_dim: 2048
  proj_output_dim: 1024
  invariance_weight: 1.4

data:
  dataset: cifar10 # change here for cifar100
  train_path: "./datasets"
  val_path: "./datasets"
  format: "image_folder"
  num_workers: 8
optimizer:
  name: "lars"
  batch_size: 256
  lr: 0.3
  classifier_lr: 0.1
  weight_decay: 1e-4
  kwargs:
    clip_lr: True
    eta: 0.02
    exclude_bias_n_norm: True
scheduler:
  name: "warmup_cosine"
checkpoint:
  enabled: True
  dir: "trained_models"
  frequency: 1
auto_resume:
  enabled: True

# overwrite PL stuff
max_epochs: 1000
devices: [0]
sync_batchnorm: True
accelerator: "gpu"
strategy: "ddp"
precision: 16-mixed
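As a quick way to double-check the values above, the raw file can be read with OmegaConf (this loads only this YAML; the defaults list, e.g. the symmetric augmentations, is composed by Hydra when main_pretrain.py is launched). A minimal sketch, assuming the path added by this PR:

from omegaconf import OmegaConf

cfg = OmegaConf.load("scripts/pretrain/cifar/frossl.yaml")  # raw file only, defaults not resolved
print(cfg.method)                            # "frossl"
print(cfg.method_kwargs.proj_output_dim)     # 1024
print(cfg.method_kwargs.invariance_weight)   # 1.4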
53 changes: 53 additions & 0 deletions scripts/pretrain/imagenet-100/frossl.yaml
@@ -0,0 +1,53 @@
defaults:
  - _self_
  - augmentations: symmetric.yaml
  - wandb: private.yaml
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# disable hydra outputs
hydra:
  output_subdir: null
  run:
    dir: .

name: "frossl-imagenet100"
method: "frossl"
backbone:
  name: "resnet18"
method_kwargs:
  proj_hidden_dim: 2048
  proj_output_dim: 1024
  invariance_weight: 2.0
data:
  dataset: imagenet100
  train_path: "./datasets/imagenet100/train"
  val_path: "./datasets/imagenet100/val"
  format: "dali"
  num_workers: 16
optimizer:
  name: "lars"
  batch_size: 256
  lr: 0.3
  classifier_lr: 0.1
  weight_decay: 1e-4
  kwargs:
    clip_lr: True
    eta: 0.02
    exclude_bias_n_norm: True
scheduler:
  name: "warmup_cosine"
checkpoint:
  enabled: True
  dir: "trained_models"
  frequency: 1
auto_resume:
  enabled: False

# overwrite PL stuff
max_epochs: 400
devices: [0, 1]
sync_batchnorm: True
accelerator: "gpu"
strategy: "ddp"
precision: 16-mixed
54 changes: 54 additions & 0 deletions scripts/pretrain/imagenet/frossl.yaml
@@ -0,0 +1,54 @@
defaults:
  - _self_
  - augmentations: vicreg.yaml
  - wandb: private.yaml
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# disable hydra outputs
hydra:
  output_subdir: null
  run:
    dir: .

name: "frossl-imagenet"
method: "frossl"
backbone:
  name: "resnet18"
method_kwargs:
  proj_hidden_dim: 2048
  proj_output_dim: 1024
  invariance_weight: 2.0

data:
  dataset: imagenet
  train_path: "./datasets/imagenet/train"
  val_path: "./datasets/imagenet/val"
  format: "dali"
  num_workers: 8
optimizer:
  name: "lars"
  batch_size: 256
  lr: 0.3
  classifier_lr: 0.1
  weight_decay: 1e-4
  kwargs:
    clip_lr: True
    eta: 0.02
    exclude_bias_n_norm: True
scheduler:
  name: "warmup_cosine"
checkpoint:
  enabled: True
  dir: "trained_models"
  frequency: 1
auto_resume:
  enabled: True

# overwrite PL stuff
max_epochs: 100
devices: [0, 1]
sync_batchnorm: True
accelerator: "gpu"
strategy: "ddp"
precision: 16-mixed
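The three configs differ mainly in invariance_weight: 1.4 for CIFAR and 2.0 for ImageNet-100 and ImageNet. In the loss added below, this weight is multiplied by V * D before scaling the view-alignment MSE, so with two large crops (an assumption here, set by the augmentation defaults) and proj_output_dim = 1024 the effective multipliers compare as follows:

V, D = 2, 1024  # assumed two large crops; projector output dim from these configs
for name, weight in [("cifar", 1.4), ("imagenet-100", 2.0), ("imagenet", 2.0)]:
    print(name, V * D * weight)  # 2867.2 for cifar, 4096.0 for the imagenet configs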
2 changes: 2 additions & 0 deletions solo/losses/__init__.py
@@ -21,6 +22,7 @@
 from solo.losses.byol import byol_loss_func
 from solo.losses.deepclusterv2 import deepclusterv2_loss_func
 from solo.losses.dino import DINOLoss
+from solo.losses.frossl import frossl_loss_func
 from solo.losses.mae import mae_loss_func
 from solo.losses.mocov2plus import mocov2plus_loss_func
 from solo.losses.mocov3 import mocov3_loss_func
@@ -38,6 +39,7 @@
"byol_loss_func",
"deepclusterv2_loss_func",
"DINOLoss",
"frossl_loss_func",
"mae_loss_func",
"mocov2plus_loss_func",
"mocov3_loss_func",
89 changes: 89 additions & 0 deletions solo/losses/frossl.py
@@ -0,0 +1,89 @@
# Copyright 2024 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Any, List, Sequence, Dict
import torch
import torch.distributed as dist
import torch.nn.functional as F

def calculate_frobenius_regularization_term(z: torch.Tensor) -> torch.Tensor:
    V, N, D = z.shape

    if N > D:
        cov = torch.matmul(z.transpose(1, 2), z)  # V x D x D
    else:
        cov = torch.matmul(z, z.transpose(1, 2))  # V x N x N

    # divide each view covariance by its trace
    trace = torch.diagonal(cov, dim1=1, dim2=2)  # V x D or V x N
    trace = torch.sum(trace, dim=1)  # V
    cov = cov / trace.unsqueeze(-1).unsqueeze(-1)

    # REGULARIZATION TERM - sum the log-Frobenius norm of each view covariance matrix
    fro_norm_per_view = torch.linalg.norm(cov, dim=(1, 2), ord="fro")  # V
    regularization_term = -torch.sum(2 * torch.log(fro_norm_per_view))  # the Frobenius square is brought outside the log

    return regularization_term

def calculate_invariance_term(z: torch.Tensor) -> torch.Tensor:
    V, N, D = z.shape

    # INVARIANCE - align each view to the average view
    average_z = torch.mean(z, dim=0)  # N x D, samples are averaged across views
    average_z = average_z.repeat(V, 1, 1)  # V x N x D
    invariance_loss_term = F.mse_loss(z, average_z)

    return invariance_loss_term

def frossl_loss_func(
    z: torch.Tensor, invariance_weight=1, logger=None
) -> torch.Tensor:
    """
    Implements FroSSL (https://arxiv.org/pdf/2310.02903)
    Heavily adapted from https://github.com/OFSkean/FroSSL. The main difference is that this
    implementation stacks the views and operates on all of them at once, rather than one at a time.
    This saves ~2 seconds (about 5% improvement) per batch with N=2, D=1024 on an A5000 GPU. For a
    simpler, albeit slower, implementation of the loss that operates on one view at a time, please
    see the original implementation.

    Args:
        z (torch.Tensor): V x N x D Tensor containing projected features from the views.
            z[v] holds the projections of the v-th view for all N samples.
        invariance_weight (float): weight for the invariance loss term. Default is 1.
        logger (optional): logging function, called as logger(name, value, sync_dist=True).

    Returns:
        torch.Tensor: FroSSL loss.
    """
    V, N, D = z.shape

    z = F.normalize(z, dim=1)  # V x N x D

    regularization_term = calculate_frobenius_regularization_term(z)
    regularization_term = -1 * regularization_term  # negate so the term is maximized when the loss is minimized

    invariance_tradeoff = V * D * invariance_weight
    invariance_term = calculate_invariance_term(z)
    invariance_term = invariance_tradeoff * invariance_term

    if logger is not None:
        logger("frossl_regularization_loss", -regularization_term, sync_dist=True)
Owner (review comment on the line above):
Is this supposed to be logged as - regularization_term? I would also say that for consistency with the other loss functions, it's better if you return the individual elements and log them on the method's side. You could return the two terms and sum them in the method.

logger("frossl_invariance_loss", invariance_term, sync_dist=True)

total_loss = regularization_term + invariance_term
return total_loss
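A small usage sketch of frossl_loss_func with random tensors (shapes illustrative, assuming this branch is installed). The second half checks the idea behind the N > D branch in calculate_frobenius_regularization_term: the D x D and N x N covariances share the same nonzero eigenvalues, so after trace normalization their Frobenius norms match and the cheaper one can be used:

import torch
from solo.losses.frossl import frossl_loss_func

V, N, D = 2, 8, 16                          # tiny illustrative shapes
z = torch.randn(V, N, D)                    # stand-in for the stacked projector outputs
print(frossl_loss_func(z, invariance_weight=1.4))

big = torch.matmul(z.transpose(1, 2), z)    # V x D x D
small = torch.matmul(z, z.transpose(1, 2))  # V x N x N
norms = []
for cov in (big, small):
    tr = torch.diagonal(cov, dim1=1, dim2=2).sum(dim=1)
    norms.append(torch.linalg.norm(cov / tr[:, None, None], dim=(1, 2), ord="fro"))
print(torch.allclose(norms[0], norms[1]))   # True: either branch gives the same regularizer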
2 changes: 2 additions & 0 deletions solo/methods/__init__.py
@@ -22,6 +23,7 @@
 from solo.methods.byol import BYOL
 from solo.methods.deepclusterv2 import DeepClusterV2
 from solo.methods.dino import DINO
+from solo.methods.frossl import FroSSL
 from solo.methods.linear import LinearModel
 from solo.methods.mae import MAE
 from solo.methods.mocov2plus import MoCoV2Plus
@@ -49,6 +50,7 @@
"byol": BYOL,
"deepclusterv2": DeepClusterV2,
"dino": DINO,
"frossl": FroSSL,
"mae": MAE,
"mocov2plus": MoCoV2Plus,
"mocov3": MoCoV3,
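With both registrations in place, no further wiring is needed: main_pretrain.py resolves the method by name from METHODS. A minimal check, assuming this branch is installed:

from solo.losses import frossl_loss_func
from solo.methods import METHODS

assert "frossl" in METHODS        # main_pretrain.py builds the model as METHODS[cfg.method](cfg)
print(METHODS["frossl"], frossl_loss_func)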