add BYOLTrainer for huggingface accelerate distributed training
lucidrains committed Nov 16, 2023
1 parent 40565b7 commit 61621ed
Showing 5 changed files with 183 additions and 2 deletions.
45 changes: 45 additions & 0 deletions README.md
@@ -164,6 +164,51 @@ imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)
```

## Distributed Training

This repository now offers distributed training with <a href="https://huggingface.co/docs/accelerate/index">🤗 Hugging Face Accelerate</a>. Simply pass your own `Dataset` into the imported `BYOLTrainer`

First, set up the configuration for distributed training by invoking the `accelerate` CLI

```bash
$ accelerate config
```

Then craft your training script as shown below, say in `./train.py`

```python
from torchvision import models

from byol_pytorch import (
BYOL,
BYOLTrainer,
MockDataset
)

resnet = models.resnet50(pretrained = True)

dataset = MockDataset(256, 10000)

trainer = BYOLTrainer(
resnet,
dataset = dataset,
image_size = 256,
hidden_layer = 'avgpool',
learning_rate = 3e-4,
num_train_steps = 100_000,
batch_size = 16,
    checkpoint_every = 1000       # the model will be saved periodically to the ./checkpoints folder
)

trainer()
```

Then use the `accelerate` CLI again to launch the script

```bash
$ accelerate launch ./train.py
```
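
`MockDataset` above simply emits random image tensors and exists only for smoke testing. Any `Dataset` whose `__getitem__` returns a tensor of shape `(3, image_size, image_size)` can be passed to `BYOLTrainer` in its place. Below is a minimal sketch of such a dataset over a folder of images - the `ImageFolderDataset` name and `./images` path are assumptions for illustration

```python
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T

class ImageFolderDataset(Dataset):
    def __init__(self, folder, image_size, exts = ('.jpg', '.jpeg', '.png')):
        # recursively gather all image paths under the folder
        self.paths = [p for p in Path(folder).glob('**/*') if p.suffix.lower() in exts]

        # resize and center crop to the square resolution BYOLTrainer expects
        self.transform = T.Compose([
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        return self.transform(img)

dataset = ImageFolderDataset('./images', 256) # hypothetical path - swap in for MockDataset above
```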

## Alternatives

If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning.
1 change: 1 addition & 0 deletions byol_pytorch/__init__.py
@@ -1 +1,2 @@
from byol_pytorch.byol_pytorch import BYOL
from byol_pytorch.trainer import BYOLTrainer, MockDataset
1 change: 0 additions & 1 deletion byol_pytorch/byol_pytorch.py
@@ -268,7 +268,6 @@ def forward(
online_predictions = self.online_predictor(online_projections)

online_pred_one, online_pred_two = online_predictions.chunk(2, dim = 0)
online_proj_one, online_proj_two = online_projections.chunk(2, dim = 0)

with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
134 changes: 134 additions & 0 deletions byol_pytorch/trainer.py
@@ -0,0 +1,134 @@
from pathlib import Path

import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn import SyncBatchNorm

from torch.optim import Optimizer, Adam
from torch.utils.data import Dataset, DataLoader

from byol_pytorch.byol_pytorch import BYOL

from beartype import beartype
from beartype.typing import Optional

from accelerate import Accelerator

# functions

def exists(v):
return v is not None

def cycle(dl):
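    # loop over the dataloader endlessly, so training is measured in steps rather than epochs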
while True:
for batch in dl:
yield batch

# class

class MockDataset(Dataset):
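    # stand-in dataset of random image tensors, for testing the trainer end-to-end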
def __init__(self, image_size, length):
self.length = length
self.image_size = image_size

def __len__(self):
return self.length

def __getitem__(self, idx):
return torch.randn(3, self.image_size, self.image_size)

# main trainer

class BYOLTrainer(Module):
@beartype
def __init__(
self,
net: Module,
*,
image_size: int,
hidden_layer: str,
learning_rate: float,
dataset: Dataset,
num_train_steps: int,
batch_size: int = 16,
optimizer_klass = Adam,
checkpoint_every: int = 1000,
checkpoint_folder: str = './checkpoints',
byol_kwargs: dict = dict(),
optimizer_kwargs: dict = dict(),
accelerator_kwargs: dict = dict(),
):
super().__init__()
self.accelerator = Accelerator(**accelerator_kwargs)

if dist.is_initialized() and dist.get_world_size() > 1:
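            # convert all batchnorms to synced batchnorms when training across multiple processes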
net = SyncBatchNorm.convert_sync_batchnorm(net)

self.net = net

self.byol = BYOL(net, image_size = image_size, hidden_layer = hidden_layer, **byol_kwargs)

self.optimizer = optimizer_klass(self.byol.parameters(), lr = learning_rate, **optimizer_kwargs)

self.dataloader = DataLoader(dataset, shuffle = True, batch_size = batch_size)

self.num_train_steps = num_train_steps

self.checkpoint_every = checkpoint_every
self.checkpoint_folder = Path(checkpoint_folder)
self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
assert self.checkpoint_folder.is_dir()

# prepare with accelerate

(
self.byol,
self.optimizer,
self.dataloader
) = self.accelerator.prepare(
self.byol,
self.optimizer,
self.dataloader
)

self.register_buffer('step', torch.tensor(0))

def wait(self):
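        # synchronization point - blocks until all processes catch up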
return self.accelerator.wait_for_everyone()

def print(self, msg):
return self.accelerator.print(msg)

def forward(self):
step = self.step.item()
data_it = cycle(self.dataloader)

for _ in range(self.num_train_steps):
images = next(data_it)

with self.accelerator.autocast():
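                # run the forward and backward passes under accelerate's autocast, for mixed precision when configured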
loss = self.byol(images)
self.accelerator.backward(loss)

self.print(f'loss {loss.item():.3f}')

            self.optimizer.step()
            self.optimizer.zero_grad()

self.wait()

            # update the target network's exponential moving average, unwrapping in case of DDP
            self.accelerator.unwrap_model(self.byol).update_moving_average()

self.wait()

if not (step % self.checkpoint_every) and self.accelerator.is_main_process:
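                # save the latest online encoder weights, from the main process only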
checkpoint_num = step // self.checkpoint_every
checkpoint_path = self.checkpoint_folder / f'checkpoint.{checkpoint_num}.pt'
torch.save(self.net.state_dict(), str(checkpoint_path))

self.wait()

            step += 1
            self.step.add_(1) # keep the registered step buffer in sync with the local counter

self.print('training complete')
4 changes: 3 additions & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'byol-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.7.2',
version = '0.8.0',
license='MIT',
description = 'Self-supervised contrastive learning made simple',
author = 'Phil Wang',
@@ -15,6 +15,8 @@
'artificial intelligence'
],
install_requires=[
'accelerate',
'beartype',
'torch>=1.6',
'torchvision>=0.8'
],
