-
Notifications
You must be signed in to change notification settings - Fork 131
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
239 additions
and
161 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -159,3 +159,6 @@ coverage.xml | |
# ignore testmon and coverage files | ||
.coverage | ||
.testmondata* | ||
|
||
# ignore data files | ||
datasets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import logging | ||
|
||
import torch.distributed as dist | ||
|
||
|
||
def create_logger(logging_dir): | ||
""" | ||
Create a logger that writes to a log file and stdout. | ||
""" | ||
if dist.get_rank() == 0: # real logger | ||
logging.basicConfig( | ||
level=logging.INFO, | ||
format="[\033[34m%(asctime)s\033[0m] %(message)s", | ||
datefmt="%Y-%m-%d %H:%M:%S", | ||
handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")], | ||
) | ||
logger = logging.getLogger(__name__) | ||
else: # dummy logger (does nothing) | ||
logger = logging.getLogger(__name__) | ||
logger.addHandler(logging.NullHandler()) | ||
return logger |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import random | ||
from typing import Iterator, Optional | ||
|
||
import numpy as np | ||
import torch | ||
from PIL import Image | ||
from torch.distributed import ProcessGroup | ||
from torch.distributed.distributed_c10d import _get_default_group | ||
from torch.utils.data import DataLoader, Dataset, DistributedSampler | ||
from torch.utils.data.distributed import DistributedSampler | ||
|
||
|
||
class StatefulDistributedSampler(DistributedSampler): | ||
def __init__( | ||
self, | ||
dataset: Dataset, | ||
num_replicas: Optional[int] = None, | ||
rank: Optional[int] = None, | ||
shuffle: bool = True, | ||
seed: int = 0, | ||
drop_last: bool = False, | ||
) -> None: | ||
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last) | ||
self.start_index: int = 0 | ||
|
||
def __iter__(self) -> Iterator: | ||
iterator = super().__iter__() | ||
indices = list(iterator) | ||
indices = indices[self.start_index :] | ||
return iter(indices) | ||
|
||
def __len__(self) -> int: | ||
return self.num_samples - self.start_index | ||
|
||
def set_start_index(self, start_index: int) -> None: | ||
self.start_index = start_index | ||
|
||
|
||
def prepare_dataloader( | ||
dataset, | ||
batch_size, | ||
shuffle=False, | ||
seed=1024, | ||
drop_last=False, | ||
pin_memory=False, | ||
num_workers=0, | ||
process_group: Optional[ProcessGroup] = None, | ||
**kwargs, | ||
): | ||
r""" | ||
Prepare a dataloader for distributed training. The dataloader will be wrapped by | ||
`torch.utils.data.DataLoader` and `StatefulDistributedSampler`. | ||
Args: | ||
dataset (`torch.utils.data.Dataset`): The dataset to be loaded. | ||
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. | ||
seed (int, optional): Random worker seed for sampling, defaults to 1024. | ||
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. | ||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size | ||
is not divisible by the batch size. If False and the size of dataset is not divisible by | ||
the batch size, then the last batch will be smaller, defaults to False. | ||
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. | ||
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. | ||
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in | ||
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_. | ||
Returns: | ||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. | ||
""" | ||
_kwargs = kwargs.copy() | ||
process_group = process_group or _get_default_group() | ||
sampler = StatefulDistributedSampler( | ||
dataset, num_replicas=process_group.size(), rank=process_group.rank(), shuffle=shuffle | ||
) | ||
|
||
# Deterministic dataloader | ||
def seed_worker(worker_id): | ||
worker_seed = seed | ||
np.random.seed(worker_seed) | ||
torch.manual_seed(worker_seed) | ||
random.seed(worker_seed) | ||
|
||
return DataLoader( | ||
dataset, | ||
batch_size=batch_size, | ||
sampler=sampler, | ||
worker_init_fn=seed_worker, | ||
drop_last=drop_last, | ||
pin_memory=pin_memory, | ||
num_workers=num_workers, | ||
**_kwargs, | ||
) | ||
|
||
|
||
def center_crop_arr(pil_image, image_size): | ||
""" | ||
Center cropping implementation from ADM. | ||
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 | ||
""" | ||
while min(*pil_image.size) >= 2 * image_size: | ||
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX) | ||
|
||
scale = image_size / min(*pil_image.size) | ||
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC) | ||
|
||
arr = np.array(pil_image) | ||
crop_y = (arr.shape[0] - image_size) // 2 | ||
crop_x = (arr.shape[1] - image_size) // 2 | ||
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) |
Oops, something went wrong.