Remove scattering for multi-GPU training. (allenai#2200)
- Instead, just pull off a batch for each GPU (see the grouping sketch below).
- Enables increasing the effective batch size for `bidirectional_language_model.jsonnet` by 2x, giving a 1.5x speedup.
brendan-ai2 authored Jan 18, 2019
1 parent d0a5a40 commit 7525c61
Showing 11 changed files with 86 additions and 139 deletions.
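
This commit replaces tensor scattering with batch grouping: the iterator's batches are bundled into groups of `num_gpus`, and each batch in a group is moved whole to its own device. The grouping uses `lazy_groups_of` from `allennlp.common.util`; the sketch below paraphrases that helper and its use in the trainer changes further down, for illustration rather than as a verbatim copy.

```python
from itertools import islice
from typing import Iterator, List, TypeVar

A = TypeVar("A")

def lazy_groups_of(iterator: Iterator[A], group_size: int) -> Iterator[List[A]]:
    """Lazily yield lists of up to ``group_size`` items; the final list may be smaller."""
    return iter(lambda: list(islice(iterator, 0, group_size)), [])

# Hypothetical usage mirroring the diffs below: with 2 GPUs, each ``batch_group``
# holds up to 2 independent batches, one destined for each device.
# raw_train_generator = iterator(train_data, num_epochs=1, shuffle=True)
# for batch_group in lazy_groups_of(raw_train_generator, group_size=2):
#     loss = trainer.batch_loss(batch_group, for_training=True)
```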
13 changes: 8 additions & 5 deletions allennlp/commands/find_learning_rate.py
@@ -58,7 +58,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.common.checks import ConfigurationError, check_for_gpu
from allennlp.common import Params, Tqdm
from allennlp.common.util import prepare_environment
from allennlp.common.util import prepare_environment, lazy_groups_of
from allennlp.data import Vocabulary, DataIterator
from allennlp.models import Model
from allennlp.training import Trainer
@@ -263,8 +263,11 @@ def search_learning_rate(trainer: Trainer,

trainer.model.train()

train_generator = trainer.iterator(trainer.train_data,
shuffle=trainer.shuffle)
num_gpus = len(trainer._cuda_devices) # pylint: disable=protected-access

raw_train_generator = trainer.iterator(trainer.train_data,
shuffle=trainer.shuffle)
train_generator = lazy_groups_of(raw_train_generator, num_gpus)
train_generator_tqdm = Tqdm.tqdm(train_generator,
total=num_batches)

@@ -276,7 +279,7 @@
else:
lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches)

for i, batch in enumerate(train_generator_tqdm):
for i, batch_group in enumerate(train_generator_tqdm):

if linear_steps:
current_lr = start_lr + (lr_update_factor * i)
@@ -287,7 +290,7 @@
param_group['lr'] = current_lr

trainer.optimizer.zero_grad()
loss = trainer.batch_loss(batch, for_training=True)
loss = trainer.batch_loss(batch_group, for_training=True)
loss.backward()
loss = loss.detach().cpu().item()

95 changes: 0 additions & 95 deletions allennlp/common/util.py
@@ -1,7 +1,6 @@
"""
Various utilities that don't fit anwhere else.
"""
from ctypes import sizeof, c_void_p, c_int64, cast, py_object, c_uint64
from itertools import zip_longest, islice
from typing import Any, Callable, Dict, List, Tuple, TypeVar, Iterable, Iterator, Union
import importlib
@@ -14,8 +13,6 @@
import os
import re

from torch.nn.parallel._functions import Scatter

try:
import resource
except ImportError:
@@ -392,98 +389,6 @@ def from_list(strings):
# TODO(brendanr): Determine why mypy can't tell that this matches the Union.
return int(cuda_device) # type: ignore

class ScatterableList(list):
"""
A normal list, but one that should be scattered like a tensor.
"""

# Ensure pointers will fit in a torch.LongTensor. "64 bits ought to be enough for anybody."
assert sizeof(c_void_p) <= sizeof(c_int64)

def to_pointer_tensor(self) -> torch.LongTensor:
"""
Converts the elements to pointers, casts them to ``int64`` and then returns them in a tensor. This cast is
important as ``id`` gives back unsigned integers while ``torch.LongTensor`` is signed.
See:
https://github.com/python/cpython/blob/6ec5cf24b7f38ea72bb42d5cd60dca0d3ee332f9/Python/bltinmodule.c#L1118
https://github.com/python/cpython/blob/6ec5cf24b7f38ea72bb42d5cd60dca0d3ee332f9/Objects/longobject.c#L990
"""
pointers = [c_int64(id(element)).value for element in self]
return torch.LongTensor(pointers)

@classmethod
def from_pointer_tensor(cls, pointers: torch.LongTensor) -> list:
"""
The inverse of ``to_pointer_tensor`` except that a plain ``list`` is returned. Typically this will be
called on a single chunk of the scattered tensor.
Parameters
----------
pointers : ``torch.LongTensor``, required.
A tensor of shape (list_length,).
"""
return [cast(c_uint64(pointer.item()).value, py_object).value for pointer in pointers]

def scatter(inputs, target_gpus, dim=0):
"""
Slices tensors and ScatterableLists into approximately equal chunks and distributes them across given GPUs.
Duplicates references to objects that are not tensors or ScatterableLists.
Adapted from `scatter` at:
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/torch/nn/parallel/scatter_gather.py#L5-L30.
Please see the LICENSE and NOTICE files as well:
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/LICENSE
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/NOTICE
"""
def scatter_map(obj):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if isinstance(obj, ScatterableList):
# In order to have precisely the same method of scattering as PyTorch we scatter
# a tensor of pointers.
pointers = scatter_map(obj.to_pointer_tensor())
# Then we reconstruct the lists from the pointer tensors.
return [obj.from_pointer_tensor(chunk) for chunk in pointers]
if isinstance(obj, tuple) and obj:
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and obj:
return list(map(list, zip(*map(scatter_map, obj))))
if isinstance(obj, dict) and obj:
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return [obj for _ in target_gpus]

# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None

def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
"""Scatter with support for kwargs dictionary.
Adapted from `scatter_kwargs` at:
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/torch/nn/parallel/scatter_gather.py#L33-L43
Please see the LICENSE and NOTICE files as well:
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/LICENSE
https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/NOTICE
"""
inputs = scatter(inputs, target_gpus, dim) if inputs else []
kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
if len(inputs) < len(kwargs):
inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
elif len(kwargs) < len(inputs):
kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
inputs = tuple(inputs)
kwargs = tuple(kwargs)
return inputs, kwargs

def get_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> List:
frozen_parameter_names = []
tunable_parameter_names = []
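
For context on the code deleted above: `ScatterableList` let non-tensor objects ride along with PyTorch's tensor scatter by round-tripping their `id()` values through a `torch.LongTensor`. The standalone sketch below illustrates that round trip and why the signed/unsigned casts matter; it is CPython-specific, for illustration only, and unnecessary once scattering is gone.

```python
from ctypes import c_int64, c_uint64, cast, py_object

import torch

# Keep strong references so the addresses stay valid while we round-trip them.
objects = [{"words": ["a", "b"]}, {"words": ["c"]}]

# id() returns an unsigned value that may not fit in a signed LongTensor, so it
# is first wrapped to a signed 64-bit integer (as in to_pointer_tensor).
pointers = torch.LongTensor([c_int64(id(obj)).value for obj in objects])

# Undo the wrap and reinterpret each address as a Python object
# (as in from_pointer_tensor).
recovered = [cast(c_uint64(p.item()).value, py_object).value for p in pointers]

assert all(a is b for a, b in zip(recovered, objects))
```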
5 changes: 2 additions & 3 deletions allennlp/data/fields/metadata_field.py
@@ -3,7 +3,6 @@

from overrides import overrides

from allennlp.common.util import ScatterableList
from allennlp.data.fields.field import DataArray, Field


@@ -61,8 +60,8 @@ def empty_field(self) -> 'MetadataField':

@classmethod
@overrides
def batch_tensors(cls, tensor_list: List[DataArray]) -> ScatterableList: # type: ignore
return ScatterableList(tensor_list)
def batch_tensors(cls, tensor_list: List[DataArray]) -> List[DataArray]: # type: ignore
return tensor_list


def __str__(self) -> str:
5 changes: 2 additions & 3 deletions allennlp/data/fields/production_rule_field.py
@@ -3,7 +3,6 @@
import torch
from overrides import overrides

from allennlp.common.util import ScatterableList
from allennlp.data.fields.field import Field
from allennlp.data.vocabulary import Vocabulary

@@ -114,9 +113,9 @@ def empty_field(self): # pylint: disable=no-self-use
return ProductionRuleField(rule='', is_global_rule=False)

@overrides
def batch_tensors(self, tensor_list: List[ProductionRule]) -> ScatterableList: # type: ignore
def batch_tensors(self, tensor_list: List[ProductionRule]) -> List[ProductionRule]: # type: ignore
# pylint: disable=no-self-use
return ScatterableList(tensor_list)
return tensor_list

def __str__(self) -> str:
return f"ProductionRuleField with rule: {self.rule} (is_global_rule: " \
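
With scattering gone, these fields no longer need a special list type: `batch_tensors` hands back the plain Python list, and the whole batch, tensors and lists alike, is processed on a single device. A small illustration using `MetadataField` with made-up values:

```python
from allennlp.data.fields import MetadataField

fields = [MetadataField({"source": "a"}), MetadataField({"source": "b"})]
tensors = [f.as_tensor(f.get_padding_lengths()) for f in fields]

# After this change the batched value is simply the list of metadata dicts.
assert MetadataField.batch_tensors(tensors) == [{"source": "a"}, {"source": "b"}]
```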
3 changes: 3 additions & 0 deletions allennlp/data/iterators/bucket_iterator.py
@@ -124,6 +124,9 @@ def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Itera
if excess:
batches.append(Batch(excess))

# TODO(brendanr): Add multi-GPU friendly grouping, i.e. group
# num_gpu batches together, shuffle and then expand the groups.
# This guards against imbalanced batches across GPUs.
move_to_front = self._biggest_batch_first and len(batches) > 1
if move_to_front:
# We'll actually pop the last _two_ batches, because the last one might not be full.
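
The TODO above is not implemented in this commit. A hypothetical sketch of the grouping it describes, which keeps `num_gpus` consecutive (similarly sized) batches together so that one device does not end up with all of the long sequences:

```python
import random
from typing import List

from allennlp.common.util import lazy_groups_of

def shuffle_in_gpu_groups(batches: List, num_gpus: int) -> List:
    """Shuffle whole groups of ``num_gpus`` consecutive batches, then flatten."""
    groups = list(lazy_groups_of(iter(batches), num_gpus))
    random.shuffle(groups)
    return [batch for group in groups for batch in group]
```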
2 changes: 2 additions & 0 deletions allennlp/data/iterators/data_iterator.py
@@ -125,6 +125,8 @@ def __call__(self,
tensor_dicts = self._cache[key]

if shuffle:
# TODO(brendanr): How can we handle this shuffle in a way
# that respects multi-GPU friendly grouping?
random.shuffle(tensor_dicts)
for tensor_dict in tensor_dicts:
if self._track_epoch:
8 changes: 4 additions & 4 deletions allennlp/tests/models/simple_tagger_test.py
@@ -64,8 +64,8 @@ def test_regularization(self):
training_batch = next(iterator(self.instances, num_epochs=1))
validation_batch = next(iterator(self.instances, num_epochs=1))

training_loss = trainer.batch_loss(training_batch, for_training=True).item()
validation_loss = trainer.batch_loss(validation_batch, for_training=False).item()
training_loss = trainer.batch_loss([training_batch], for_training=True).item()
validation_loss = trainer.batch_loss([validation_batch], for_training=False).item()

# Training loss should have the regularization penalty, but validation loss should not.
numpy.testing.assert_almost_equal(training_loss, validation_loss)
@@ -116,8 +116,8 @@ def test_regularization(self):
training_batch = next(self.iterator(self.instances, num_epochs=1))
validation_batch = next(self.iterator(self.instances, num_epochs=1))

training_loss = self.trainer.batch_loss(training_batch, for_training=True).data
validation_loss = self.trainer.batch_loss(validation_batch, for_training=False).data
training_loss = self.trainer.batch_loss([training_batch], for_training=True).data
validation_loss = self.trainer.batch_loss([validation_batch], for_training=False).data

# Training loss should have the regularization penalty, but validation loss should not.
assert (training_loss != validation_loss).all()
19 changes: 18 additions & 1 deletion allennlp/tests/training/trainer_test.py
@@ -20,7 +20,8 @@
from allennlp.common.params import Params
from allennlp.models.simple_tagger import SimpleTagger
from allennlp.data.iterators import BasicIterator
from allennlp.data.dataset_readers import SequenceTaggingDatasetReader
from allennlp.data.dataset_readers import SequenceTaggingDatasetReader, WikiTablesDatasetReader
from allennlp.models.archival import load_archive
from allennlp.models.model import Model


@@ -133,6 +134,22 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore # pylint
assert 'peak_gpu_1_memory_MB' in metrics
assert isinstance(metrics['peak_gpu_1_memory_MB'], int)

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need multiple GPUs.")
def test_production_rule_field_with_multiple_gpus(self):
wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
wikitables_reader = WikiTablesDatasetReader(tables_directory=wikitables_dir,
dpd_output_directory=wikitables_dir + 'dpd_output/')
instances = wikitables_reader.read(wikitables_dir + 'sample_data.examples')
archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
model = load_archive(archive_path).model
model.cuda()

multigpu_iterator = BasicIterator(batch_size=4)
multigpu_iterator.index_with(model.vocab)
trainer = Trainer(model, self.optimizer, multigpu_iterator, instances, num_epochs=2, cuda_device=[0, 1])
trainer.train()

def test_trainer_can_resume_training(self):
trainer = Trainer(self.model, self.optimizer,
self.iterator, self.instances,
47 changes: 29 additions & 18 deletions allennlp/training/trainer.py
@@ -1,5 +1,6 @@

import logging
import math
import os
import time
import re
@@ -13,10 +14,10 @@
from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.common.util import (dump_metrics, gpu_memory_mb, parse_cuda_device, peak_memory_mb,
get_frozen_and_tunable_parameter_names)
get_frozen_and_tunable_parameter_names, lazy_groups_of)
from allennlp.common.tqdm import Tqdm
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
from allennlp.data.vocabulary import Vocabulary
from allennlp.models.model import Model
from allennlp.nn import util as nn_util
@@ -216,14 +217,16 @@ def __init__(self,
def rescale_gradients(self) -> Optional[float]:
return training_util.rescale_gradients(self.model, self._grad_norm)

def batch_loss(self, batch: torch.Tensor, for_training: bool) -> torch.Tensor:
def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor:
"""
Does a forward pass on the given batch and returns the ``loss`` value in the result.
Does a forward pass on the given batches and returns the ``loss`` value in the result.
If ``for_training`` is `True` also applies regularization penalty.
"""
if self._multiple_gpu:
output_dict = training_util.data_parallel(batch, self.model, self._cuda_devices)
output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices)
else:
assert len(batch_group) == 1
batch = batch_group[0]
batch = nn_util.move_to_device(batch, self._cuda_devices[0])
output_dict = self.model(**batch)

@@ -255,11 +258,14 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
# Set the model to "train" mode.
self.model.train()

num_gpus = len(self._cuda_devices)

# Get tqdm for the training batches
train_generator = self.iterator(self.train_data,
num_epochs=1,
shuffle=self.shuffle)
num_training_batches = self.iterator.get_num_batches(self.train_data)
raw_train_generator = self.iterator(self.train_data,
num_epochs=1,
shuffle=self.shuffle)
train_generator = lazy_groups_of(raw_train_generator, num_gpus)
num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data)/num_gpus)
self._last_log = time.time()
last_save_time = time.time()

@@ -269,18 +275,20 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:

histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())


logger.info("Training")
train_generator_tqdm = Tqdm.tqdm(train_generator,
total=num_training_batches)
cumulative_batch_size = 0
for batch in train_generator_tqdm:
for batch_group in train_generator_tqdm:
batches_this_epoch += 1
self._batch_num_total += 1
batch_num_total = self._batch_num_total

self.optimizer.zero_grad()

loss = self.batch_loss(batch, for_training=True)
loss = self.batch_loss(batch_group, for_training=True)

if torch.isnan(loss):
raise ValueError("nan loss encountered")

@@ -329,7 +337,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]:
self._tensorboard.log_histograms(self.model, histogram_parameters)

if self._log_batch_size_period:
cur_batch = training_util.get_batch_size(batch)
cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group])
cumulative_batch_size += cur_batch
if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
average = cumulative_batch_size/batches_this_epoch
@@ -365,17 +373,20 @@ def _validation_loss(self) -> Tuple[float, int]:
else:
val_iterator = self.iterator

val_generator = val_iterator(self._validation_data,
num_epochs=1,
shuffle=False)
num_validation_batches = val_iterator.get_num_batches(self._validation_data)
num_gpus = len(self._cuda_devices)

raw_val_generator = val_iterator(self._validation_data,
num_epochs=1,
shuffle=False)
val_generator = lazy_groups_of(raw_val_generator, num_gpus)
num_validation_batches = math.ceil(val_iterator.get_num_batches(self._validation_data)/num_gpus)
val_generator_tqdm = Tqdm.tqdm(val_generator,
total=num_validation_batches)
batches_this_epoch = 0
val_loss = 0
for batch in val_generator_tqdm:
for batch_group in val_generator_tqdm:

loss = self.batch_loss(batch, for_training=False)
loss = self.batch_loss(batch_group, for_training=False)
if loss is not None:
# You shouldn't necessarily have to compute a loss for validation, so we allow for
# `loss` to be None. We need to be careful, though - `batches_this_epoch` is
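
One consequence of grouping worth spelling out: an epoch now counts batch groups rather than individual batches, and the logged batch size is summed across the group. A quick check of the arithmetic with hypothetical numbers:

```python
import math

num_batches = 10  # hypothetical value of iterator.get_num_batches(train_data)
num_gpus = 4

# Groups of 4, 4 and 2 batches mean 3 optimizer steps per epoch.
num_training_batches = math.ceil(num_batches / num_gpus)
assert num_training_batches == 3

# Batch-size logging sums over the group, e.g. a full group of four batches
# of 32 instances each:
cur_batch = sum([32, 32, 32, 32])
assert cur_batch == 128
```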