From 61621ed3d6816f5fec62d934c7853d3293d812ed Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 16 Nov 2023 11:20:12 -0800 Subject: [PATCH] add BYOLTrainer for huggingface accelerate distributed training --- README.md | 45 ++++++++++++ byol_pytorch/__init__.py | 1 + byol_pytorch/byol_pytorch.py | 1 - byol_pytorch/trainer.py | 134 +++++++++++++++++++++++++++++++++++ setup.py | 4 +- 5 files changed, 183 insertions(+), 2 deletions(-) create mode 100644 byol_pytorch/trainer.py diff --git a/README.md b/README.md index 5447bda64..770b45099 100644 --- a/README.md +++ b/README.md @@ -164,6 +164,51 @@ imgs = torch.randn(2, 3, 256, 256) projection, embedding = learner(imgs, return_embedding = True) ``` +## Distributed Training + +The repository now offers distributed training with 🤗 Huggingface Accelerate. You just have to pass in your own `Dataset` into the imported `BYOLTrainer` + +First setup the configuration for distributed training by invoking the accelerate CLI + +```bash +$ accelerate config +``` + +Then craft your training script as shown below, say in `./train.py` + +```python +from torchvision import models + +from byol_pytorch import ( + BYOL, + BYOLTrainer, + MockDataset +) + +resnet = models.resnet50(pretrained = True) + +dataset = MockDataset(256, 10000) + +trainer = BYOLTrainer( + resnet, + dataset = dataset, + image_size = 256, + hidden_layer = 'avgpool', + learning_rate = 3e-4, + num_train_steps = 100_000, + batch_size = 16, + checkpoint_every = 1000 # improved model will be saved periodically to ./checkpoints folder +) + +trainer() +``` + +Then use the accelerate CLI again to launch the script + +```bash +$ accelerate launch ./train.py +``` + ## Alternatives If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning. 
diff --git a/byol_pytorch/__init__.py b/byol_pytorch/__init__.py index 78f0c5ac8..afbaef66b 100644 --- a/byol_pytorch/__init__.py +++ b/byol_pytorch/__init__.py @@ -1 +1,2 @@ from byol_pytorch.byol_pytorch import BYOL +from byol_pytorch.trainer import BYOLTrainer, MockDataset diff --git a/byol_pytorch/byol_pytorch.py b/byol_pytorch/byol_pytorch.py index e3150b6b6..4b843725e 100644 --- a/byol_pytorch/byol_pytorch.py +++ b/byol_pytorch/byol_pytorch.py @@ -268,7 +268,6 @@ def forward( online_predictions = self.online_predictor(online_projections) online_pred_one, online_pred_two = online_predictions.chunk(2, dim = 0) - online_proj_one, online_proj_two = online_projections.chunk(2, dim = 0) with torch.no_grad(): target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder diff --git a/byol_pytorch/trainer.py b/byol_pytorch/trainer.py new file mode 100644 index 000000000..eb1ffec0a --- /dev/null +++ b/byol_pytorch/trainer.py @@ -0,0 +1,134 @@ +from pathlib import Path + +import torch +import torch.distributed as dist +from torch.nn import Module +from torch.nn import SyncBatchNorm + +from torch.optim import Optimizer, Adam +from torch.utils.data import Dataset, DataLoader + +from byol_pytorch.byol_pytorch import BYOL + +from beartype import beartype +from beartype.typing import Optional + +from accelerate import Accelerator + +# functions + +def exists(v): + return v is not None + +def cycle(dl): + while True: + for batch in dl: + yield batch + +# class + +class MockDataset(Dataset): + def __init__(self, image_size, length): + self.length = length + self.image_size = image_size + + def __len__(self): + return self.length + + def __getitem__(self, idx): + return torch.randn(3, self.image_size, self.image_size) + +# main trainer + +class BYOLTrainer(Module): + @beartype + def __init__( + self, + net: Module, + *, + image_size: int, + hidden_layer: str, + learning_rate: float, + dataset: Dataset, + num_train_steps: int, + batch_size: int = 16, 
+ optimizer_klass = Adam, + checkpoint_every: int = 1000, + checkpoint_folder: str = './checkpoints', + byol_kwargs: dict = dict(), + optimizer_kwargs: dict = dict(), + accelerator_kwargs: dict = dict(), + ): + super().__init__() + self.accelerator = Accelerator(**accelerator_kwargs) + + if dist.is_initialized() and dist.get_world_size() > 1: + net = SyncBatchNorm.convert_sync_batchnorm(net) + + self.net = net + + self.byol = BYOL(net, image_size = image_size, hidden_layer = hidden_layer, **byol_kwargs) + + self.optimizer = optimizer_klass(self.byol.parameters(), lr = learning_rate, **optimizer_kwargs) + + self.dataloader = DataLoader(dataset, shuffle = True, batch_size = batch_size) + + self.num_train_steps = num_train_steps + + self.checkpoint_every = checkpoint_every + self.checkpoint_folder = Path(checkpoint_folder) + self.checkpoint_folder.mkdir(exist_ok = True, parents = True) + assert self.checkpoint_folder.is_dir() + + # prepare with accelerate + + ( + self.byol, + self.optimizer, + self.dataloader + ) = self.accelerator.prepare( + self.byol, + self.optimizer, + self.dataloader + ) + + self.register_buffer('step', torch.tensor(0)) + + def wait(self): + return self.accelerator.wait_for_everyone() + + def print(self, msg): + return self.accelerator.print(msg) + + def forward(self): + step = self.step.item() + data_it = cycle(self.dataloader) + + for _ in range(self.num_train_steps): + images = next(data_it) + + with self.accelerator.autocast(): + loss = self.byol(images) + self.accelerator.backward(loss) + + self.print(f'loss {loss.item():.3f}') + + self.optimizer.step() + self.optimizer.zero_grad() + + self.wait() + + self.accelerator.unwrap_model(self.byol).update_moving_average() + + self.wait() + + if not (step % self.checkpoint_every) and self.accelerator.is_main_process: + checkpoint_num = step // self.checkpoint_every + checkpoint_path = self.checkpoint_folder / f'checkpoint.{checkpoint_num}.pt' + torch.save(self.net.state_dict(), str(checkpoint_path)) + + self.wait() + + step += 1
+ + self.print('training complete') diff --git a/setup.py b/setup.py index 2e9173eea..96cc77df2 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'byol-pytorch', packages = find_packages(exclude=['examples']), - version = '0.7.2', + version = '0.8.0', license='MIT', description = 'Self-supervised contrastive learning made simple', author = 'Phil Wang', @@ -15,6 +15,8 @@ 'artificial intelligence' ], install_requires=[ + 'accelerate', + 'beartype', 'torch>=1.6', 'torchvision>=0.8' ],