From 7525c610e6b77247222099b59f21a227f1c3545c Mon Sep 17 00:00:00 2001 From: Brendan Roof Date: Fri, 18 Jan 2019 15:18:29 -0800 Subject: [PATCH] Remove scattering for multi-GPU training. (#2200) - Instead just pull off a batch for each GPU. - Enables increasing the effective batch size for `bidirectional_language_model.jsonnet` by 2x giving a 1.5x speedup. --- allennlp/commands/find_learning_rate.py | 13 ++- allennlp/common/util.py | 95 ------------------- allennlp/data/fields/metadata_field.py | 5 +- allennlp/data/fields/production_rule_field.py | 5 +- allennlp/data/iterators/bucket_iterator.py | 3 + allennlp/data/iterators/data_iterator.py | 2 + allennlp/tests/models/simple_tagger_test.py | 8 +- allennlp/tests/training/trainer_test.py | 19 +++- allennlp/training/trainer.py | 47 +++++---- allennlp/training/util.py | 19 ++-- .../bidirectional_language_model.jsonnet | 9 +- 11 files changed, 86 insertions(+), 139 deletions(-) diff --git a/allennlp/commands/find_learning_rate.py b/allennlp/commands/find_learning_rate.py index fb5d3ce5f13..ba23edde7e7 100644 --- a/allennlp/commands/find_learning_rate.py +++ b/allennlp/commands/find_learning_rate.py @@ -58,7 +58,7 @@ from allennlp.commands.subcommand import Subcommand from allennlp.common.checks import ConfigurationError, check_for_gpu from allennlp.common import Params, Tqdm -from allennlp.common.util import prepare_environment +from allennlp.common.util import prepare_environment, lazy_groups_of from allennlp.data import Vocabulary, DataIterator from allennlp.models import Model from allennlp.training import Trainer @@ -263,8 +263,11 @@ def search_learning_rate(trainer: Trainer, trainer.model.train() - train_generator = trainer.iterator(trainer.train_data, - shuffle=trainer.shuffle) + num_gpus = len(trainer._cuda_devices) # pylint: disable=protected-access + + raw_train_generator = trainer.iterator(trainer.train_data, + shuffle=trainer.shuffle) + train_generator = lazy_groups_of(raw_train_generator, num_gpus) train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_batches) @@ -276,7 +279,7 @@ def search_learning_rate(trainer: Trainer, else: lr_update_factor = (end_lr / start_lr) ** (1.0 / num_batches) - for i, batch in enumerate(train_generator_tqdm): + for i, batch_group in enumerate(train_generator_tqdm): if linear_steps: current_lr = start_lr + (lr_update_factor * i) @@ -287,7 +290,7 @@ def search_learning_rate(trainer: Trainer, param_group['lr'] = current_lr trainer.optimizer.zero_grad() - loss = trainer.batch_loss(batch, for_training=True) + loss = trainer.batch_loss(batch_group, for_training=True) loss.backward() loss = loss.detach().cpu().item() diff --git a/allennlp/common/util.py b/allennlp/common/util.py index 267a8109e69..0bf9ecce65f 100644 --- a/allennlp/common/util.py +++ b/allennlp/common/util.py @@ -1,7 +1,6 @@ """ Various utilities that don't fit anwhere else. """ -from ctypes import sizeof, c_void_p, c_int64, cast, py_object, c_uint64 from itertools import zip_longest, islice from typing import Any, Callable, Dict, List, Tuple, TypeVar, Iterable, Iterator, Union import importlib @@ -14,8 +13,6 @@ import os import re -from torch.nn.parallel._functions import Scatter - try: import resource except ImportError: @@ -392,98 +389,6 @@ def from_list(strings): # TODO(brendanr): Determine why mypy can't tell that this matches the Union. return int(cuda_device) # type: ignore -class ScatterableList(list): - """ - A normal list, but one that should be scattered like a tensor. 
- """ - - # Ensure pointers will fit in a torch.LongTensor. "64 bits ought to be enough for anybody." - assert sizeof(c_void_p) <= sizeof(c_int64) - - def to_pointer_tensor(self) -> torch.LongTensor: - """ - Converts the elements to pointers, casts them to ``int64`` and then returns them in a tensor. This cast is - important as ``id`` gives back unsigned integers while ``torch.LongTensor`` is signed. - - See: - https://github.com/python/cpython/blob/6ec5cf24b7f38ea72bb42d5cd60dca0d3ee332f9/Python/bltinmodule.c#L1118 - https://github.com/python/cpython/blob/6ec5cf24b7f38ea72bb42d5cd60dca0d3ee332f9/Objects/longobject.c#L990 - """ - pointers = [c_int64(id(element)).value for element in self] - return torch.LongTensor(pointers) - - @classmethod - def from_pointer_tensor(cls, pointers: torch.LongTensor) -> list: - """ - The inverse of ``to_pointer_tensor`` except that a plain ``list`` is returned. Typically this will be - called on a single chunk of the scattered tensor. - - Parameters - ---------- - pointers : ``torch.LongTensor``, required. - A tensor of shape (list_length,). - """ - return [cast(c_uint64(pointer.item()).value, py_object).value for pointer in pointers] - -def scatter(inputs, target_gpus, dim=0): - """ - Slices tensors and ScatterableLists into approximately equal chunks and distributes them across given GPUs. - Duplicates references to objects that are not tensors or ScatterableLists. - - Adapted from `scatter` at: - https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/torch/nn/parallel/scatter_gather.py#L5-L30. - - Please see the LICENSE and NOTICE files as well: - https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/LICENSE - https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/NOTICE - """ - def scatter_map(obj): - if isinstance(obj, torch.Tensor): - return Scatter.apply(target_gpus, None, dim, obj) - if isinstance(obj, ScatterableList): - # In order to have precisely the same method of scattering as PyTorch we scatter - # a tensor of pointers. - pointers = scatter_map(obj.to_pointer_tensor()) - # Then we reconstruct the lists from the pointer tensors. - return [obj.from_pointer_tensor(chunk) for chunk in pointers] - if isinstance(obj, tuple) and obj: - return list(zip(*map(scatter_map, obj))) - if isinstance(obj, list) and obj: - return list(map(list, zip(*map(scatter_map, obj)))) - if isinstance(obj, dict) and obj: - return list(map(type(obj), zip(*map(scatter_map, obj.items())))) - return [obj for _ in target_gpus] - - # After scatter_map is called, a scatter_map cell will exist. This cell - # has a reference to the actual function scatter_map, which has references - # to a closure that has a reference to the scatter_map cell (because the - # fn is recursive). To avoid this reference cycle, we set the function to - # None, clearing the cell - try: - return scatter_map(inputs) - finally: - scatter_map = None - -def scatter_kwargs(inputs, kwargs, target_gpus, dim=0): - """Scatter with support for kwargs dictionary. 
- - Adapted from `scatter_kwargs` at: - https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/torch/nn/parallel/scatter_gather.py#L33-L43 - - Please see the LICENSE and NOTICE files as well: - https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/LICENSE - https://github.com/pytorch/pytorch/blob/1d406c04ae56255e58dcec85e3479bb2b3dbd75e/NOTICE - """ - inputs = scatter(inputs, target_gpus, dim) if inputs else [] - kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] - if len(inputs) < len(kwargs): - inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) - elif len(kwargs) < len(inputs): - kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) - inputs = tuple(inputs) - kwargs = tuple(kwargs) - return inputs, kwargs - def get_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> List: frozen_parameter_names = [] tunable_parameter_names = [] diff --git a/allennlp/data/fields/metadata_field.py b/allennlp/data/fields/metadata_field.py index 2c301ff82d1..11a3d102e63 100644 --- a/allennlp/data/fields/metadata_field.py +++ b/allennlp/data/fields/metadata_field.py @@ -3,7 +3,6 @@ from overrides import overrides -from allennlp.common.util import ScatterableList from allennlp.data.fields.field import DataArray, Field @@ -61,8 +60,8 @@ def empty_field(self) -> 'MetadataField': @classmethod @overrides - def batch_tensors(cls, tensor_list: List[DataArray]) -> ScatterableList: # type: ignore - return ScatterableList(tensor_list) + def batch_tensors(cls, tensor_list: List[DataArray]) -> List[DataArray]: # type: ignore + return tensor_list def __str__(self) -> str: diff --git a/allennlp/data/fields/production_rule_field.py b/allennlp/data/fields/production_rule_field.py index 0fc3337f544..0de654f6207 100644 --- a/allennlp/data/fields/production_rule_field.py +++ b/allennlp/data/fields/production_rule_field.py @@ -3,7 +3,6 @@ import torch from overrides import overrides -from allennlp.common.util import ScatterableList from allennlp.data.fields.field import Field from allennlp.data.vocabulary import Vocabulary @@ -114,9 +113,9 @@ def empty_field(self): # pylint: disable=no-self-use return ProductionRuleField(rule='', is_global_rule=False) @overrides - def batch_tensors(self, tensor_list: List[ProductionRule]) -> ScatterableList: # type: ignore + def batch_tensors(self, tensor_list: List[ProductionRule]) -> List[ProductionRule]: # type: ignore # pylint: disable=no-self-use - return ScatterableList(tensor_list) + return tensor_list def __str__(self) -> str: return f"ProductionRuleField with rule: {self.rule} (is_global_rule: " \ diff --git a/allennlp/data/iterators/bucket_iterator.py b/allennlp/data/iterators/bucket_iterator.py index cb630fc0439..ef80a642d46 100644 --- a/allennlp/data/iterators/bucket_iterator.py +++ b/allennlp/data/iterators/bucket_iterator.py @@ -124,6 +124,9 @@ def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Itera if excess: batches.append(Batch(excess)) + # TODO(brendanr): Add multi-GPU friendly grouping, i.e. group + # num_gpu batches together, shuffle and then expand the groups. + # This guards against imbalanced batches across GPUs. move_to_front = self._biggest_batch_first and len(batches) > 1 if move_to_front: # We'll actually pop the last _two_ batches, because the last one might not be full. 
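The grouping TODO above, and the trainer changes further down, lean on `lazy_groups_of` from `allennlp.common.util` to pull `num_gpus` batches at a time off the raw iterator. The helper's implementation is not part of this patch; as a minimal sketch of one way such a grouping helper can be written (assuming the final group is simply allowed to come up short):

```python
from itertools import islice
from typing import Iterator, List, TypeVar

A = TypeVar("A")

def lazy_groups_of(iterator: Iterator[A], group_size: int) -> Iterator[List[A]]:
    # Lazily yield lists of up to `group_size` items; the last list may be
    # smaller if the underlying iterator runs out mid-group.
    return iter(lambda: list(islice(iterator, 0, group_size)), [])
```

With `group_size=2`, batches `b0 ... b4` come out as `[b0, b1]`, `[b2, b3]`, `[b4]`, which is exactly the per-step `batch_group` that `Trainer.batch_loss` consumes in the trainer changes below.
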
diff --git a/allennlp/data/iterators/data_iterator.py b/allennlp/data/iterators/data_iterator.py index c54ece2d4ef..a96a6a1dfca 100644 --- a/allennlp/data/iterators/data_iterator.py +++ b/allennlp/data/iterators/data_iterator.py @@ -125,6 +125,8 @@ def __call__(self, tensor_dicts = self._cache[key] if shuffle: + # TODO(brendanr): How can we handle this shuffle in a way + # that respects multi-GPU friendly grouping? random.shuffle(tensor_dicts) for tensor_dict in tensor_dicts: if self._track_epoch: diff --git a/allennlp/tests/models/simple_tagger_test.py b/allennlp/tests/models/simple_tagger_test.py index 15bd2846ad4..dbacabb82c0 100644 --- a/allennlp/tests/models/simple_tagger_test.py +++ b/allennlp/tests/models/simple_tagger_test.py @@ -64,8 +64,8 @@ def test_regularization(self): training_batch = next(iterator(self.instances, num_epochs=1)) validation_batch = next(iterator(self.instances, num_epochs=1)) - training_loss = trainer.batch_loss(training_batch, for_training=True).item() - validation_loss = trainer.batch_loss(validation_batch, for_training=False).item() + training_loss = trainer.batch_loss([training_batch], for_training=True).item() + validation_loss = trainer.batch_loss([validation_batch], for_training=False).item() # Training loss should have the regularization penalty, but validation loss should not. numpy.testing.assert_almost_equal(training_loss, validation_loss) @@ -116,8 +116,8 @@ def test_regularization(self): training_batch = next(self.iterator(self.instances, num_epochs=1)) validation_batch = next(self.iterator(self.instances, num_epochs=1)) - training_loss = self.trainer.batch_loss(training_batch, for_training=True).data - validation_loss = self.trainer.batch_loss(validation_batch, for_training=False).data + training_loss = self.trainer.batch_loss([training_batch], for_training=True).data + validation_loss = self.trainer.batch_loss([validation_batch], for_training=False).data # Training loss should have the regularization penalty, but validation loss should not. 
assert (training_loss != validation_loss).all() diff --git a/allennlp/tests/training/trainer_test.py b/allennlp/tests/training/trainer_test.py index 9a478977903..18b166e2ce3 100644 --- a/allennlp/tests/training/trainer_test.py +++ b/allennlp/tests/training/trainer_test.py @@ -20,7 +20,8 @@ from allennlp.common.params import Params from allennlp.models.simple_tagger import SimpleTagger from allennlp.data.iterators import BasicIterator -from allennlp.data.dataset_readers import SequenceTaggingDatasetReader +from allennlp.data.dataset_readers import SequenceTaggingDatasetReader, WikiTablesDatasetReader +from allennlp.models.archival import load_archive from allennlp.models.model import Model @@ -133,6 +134,22 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]: # type: ignore # pylint assert 'peak_gpu_1_memory_MB' in metrics assert isinstance(metrics['peak_gpu_1_memory_MB'], int) + @pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need multiple GPUs.") + def test_production_rule_field_with_multiple_gpus(self): + wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/' + wikitables_reader = WikiTablesDatasetReader(tables_directory=wikitables_dir, + dpd_output_directory=wikitables_dir + 'dpd_output/') + instances = wikitables_reader.read(wikitables_dir + 'sample_data.examples') + archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz' + model = load_archive(archive_path).model + model.cuda() + + multigpu_iterator = BasicIterator(batch_size=4) + multigpu_iterator.index_with(model.vocab) + trainer = Trainer(model, self.optimizer, multigpu_iterator, instances, num_epochs=2, cuda_device=[0, 1]) + trainer.train() + def test_trainer_can_resume_training(self): trainer = Trainer(self.model, self.optimizer, self.iterator, self.instances, diff --git a/allennlp/training/trainer.py b/allennlp/training/trainer.py index 2dca4865a55..b7350aba906 100644 --- a/allennlp/training/trainer.py +++ b/allennlp/training/trainer.py @@ -1,5 +1,6 @@ import logging +import math import os import time import re @@ -13,10 +14,10 @@ from allennlp.common import Params from allennlp.common.checks import ConfigurationError from allennlp.common.util import (dump_metrics, gpu_memory_mb, parse_cuda_device, peak_memory_mb, - get_frozen_and_tunable_parameter_names) + get_frozen_and_tunable_parameter_names, lazy_groups_of) from allennlp.common.tqdm import Tqdm from allennlp.data.instance import Instance -from allennlp.data.iterators.data_iterator import DataIterator +from allennlp.data.iterators.data_iterator import DataIterator, TensorDict from allennlp.data.vocabulary import Vocabulary from allennlp.models.model import Model from allennlp.nn import util as nn_util @@ -216,14 +217,16 @@ def __init__(self, def rescale_gradients(self) -> Optional[float]: return training_util.rescale_gradients(self.model, self._grad_norm) - def batch_loss(self, batch: torch.Tensor, for_training: bool) -> torch.Tensor: + def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor: """ - Does a forward pass on the given batch and returns the ``loss`` value in the result. + Does a forward pass on the given batches and returns the ``loss`` value in the result. If ``for_training`` is `True` also applies regularization penalty. 
""" if self._multiple_gpu: - output_dict = training_util.data_parallel(batch, self.model, self._cuda_devices) + output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices) else: + assert len(batch_group) == 1 + batch = batch_group[0] batch = nn_util.move_to_device(batch, self._cuda_devices[0]) output_dict = self.model(**batch) @@ -255,11 +258,14 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]: # Set the model to "train" mode. self.model.train() + num_gpus = len(self._cuda_devices) + # Get tqdm for the training batches - train_generator = self.iterator(self.train_data, - num_epochs=1, - shuffle=self.shuffle) - num_training_batches = self.iterator.get_num_batches(self.train_data) + raw_train_generator = self.iterator(self.train_data, + num_epochs=1, + shuffle=self.shuffle) + train_generator = lazy_groups_of(raw_train_generator, num_gpus) + num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data)/num_gpus) self._last_log = time.time() last_save_time = time.time() @@ -269,18 +275,20 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]: histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging()) + logger.info("Training") train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_training_batches) cumulative_batch_size = 0 - for batch in train_generator_tqdm: + for batch_group in train_generator_tqdm: batches_this_epoch += 1 self._batch_num_total += 1 batch_num_total = self._batch_num_total self.optimizer.zero_grad() - loss = self.batch_loss(batch, for_training=True) + loss = self.batch_loss(batch_group, for_training=True) + if torch.isnan(loss): raise ValueError("nan loss encountered") @@ -329,7 +337,7 @@ def _train_epoch(self, epoch: int) -> Dict[str, float]: self._tensorboard.log_histograms(self.model, histogram_parameters) if self._log_batch_size_period: - cur_batch = training_util.get_batch_size(batch) + cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group]) cumulative_batch_size += cur_batch if (batches_this_epoch - 1) % self._log_batch_size_period == 0: average = cumulative_batch_size/batches_this_epoch @@ -365,17 +373,20 @@ def _validation_loss(self) -> Tuple[float, int]: else: val_iterator = self.iterator - val_generator = val_iterator(self._validation_data, - num_epochs=1, - shuffle=False) - num_validation_batches = val_iterator.get_num_batches(self._validation_data) + num_gpus = len(self._cuda_devices) + + raw_val_generator = val_iterator(self._validation_data, + num_epochs=1, + shuffle=False) + val_generator = lazy_groups_of(raw_val_generator, num_gpus) + num_validation_batches = math.ceil(val_iterator.get_num_batches(self._validation_data)/num_gpus) val_generator_tqdm = Tqdm.tqdm(val_generator, total=num_validation_batches) batches_this_epoch = 0 val_loss = 0 - for batch in val_generator_tqdm: + for batch_group in val_generator_tqdm: - loss = self.batch_loss(batch, for_training=False) + loss = self.batch_loss(batch_group, for_training=False) if loss is not None: # You shouldn't necessarily have to compute a loss for validation, so we allow for # `loss` to be None. 
We need to be careful, though - `batches_this_epoch` is diff --git a/allennlp/training/util.py b/allennlp/training/util.py index 81deea63800..e0075d37de8 100644 --- a/allennlp/training/util.py +++ b/allennlp/training/util.py @@ -14,10 +14,10 @@ from allennlp.common.checks import ConfigurationError, check_for_gpu from allennlp.common.params import Params from allennlp.common.tqdm import Tqdm -from allennlp.common.util import scatter_kwargs from allennlp.data.dataset_readers import DatasetReader from allennlp.data import Instance from allennlp.data.iterators import DataIterator +from allennlp.data.iterators.data_iterator import TensorDict from allennlp.models.model import Model from allennlp.models.archival import CONFIG_NAME from allennlp.nn import util as nn_util @@ -223,24 +223,31 @@ def create_serialization_dir( "does not exist. There is nothing to recover from.") os.makedirs(serialization_dir, exist_ok=True) -def data_parallel(batch, model: Model, cuda_devices: List) -> Dict[str, torch.Tensor]: +def data_parallel(batch_group: List[TensorDict], + model: Model, + cuda_devices: List) -> Dict[str, torch.Tensor]: """ Performs a forward pass using multiple GPUs. This is a simplification of torch.nn.parallel.data_parallel to support the allennlp model interface. """ - inputs, module_kwargs = scatter_kwargs((), batch, cuda_devices, 0) + assert len(batch_group) <= len(cuda_devices) - used_device_ids = cuda_devices[:len(inputs)] + moved = [nn_util.move_to_device(batch, device) + for batch, device in zip(batch_group, cuda_devices)] + + used_device_ids = cuda_devices[:len(moved)] replicas = replicate(model, used_device_ids) - outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) + # We pass all our arguments as kwargs. Create a list of empty tuples of the + # correct shape to serve as (non-existent) positional arguments. + inputs = [()] * len(batch_group) + outputs = parallel_apply(replicas, inputs, moved, used_device_ids) # Only the 'loss' is needed. # a (num_gpu, ) tensor with loss on each GPU losses = gather([output['loss'].unsqueeze(0) for output in outputs], used_device_ids[0], 0) return {'loss': losses.mean()} - def enable_gradient_clipping(model: Model, grad_clipping: Optional[float]) -> None: if grad_clipping is not None: for parameter in model.parameters(): diff --git a/training_config/bidirectional_language_model.jsonnet b/training_config/bidirectional_language_model.jsonnet index 5c9a10a4da9..6272abe370d 100644 --- a/training_config/bidirectional_language_model.jsonnet +++ b/training_config/bidirectional_language_model.jsonnet @@ -21,7 +21,7 @@ local BASE_READER = { "type": "elmo_characters" } }, - "max_sequence_length": 500, + "max_sequence_length": 400, "start_tokens": [""], "end_tokens": [""] }; @@ -34,7 +34,7 @@ local BASE_ITERATOR = { // samples in every batch. "batch_size": 512 * NUM_GPUS, "sorting_keys": [["source", "num_tokens"]], - "maximum_samples_per_batch": ["num_tokens", NUM_GPUS * 1000] + "maximum_samples_per_batch": ["num_tokens", 2000] }; { @@ -117,7 +117,7 @@ local BASE_ITERATOR = { // The multiprocess dataset reader and iterator use many file descriptors, // so we need to increase the ulimit depending on the size of this queue. // See https://pytorch.org/docs/stable/multiprocessing.html#file-descriptor-file-descriptor - // for a description of the underlying issue. `ulimit -n 4096` has sufficed, + // for a description of the underlying issue. `ulimit -n 8192` has sufficed, // but that number could use tuning. 
"output_queue_size": 500 }, @@ -139,6 +139,7 @@ local BASE_ITERATOR = { // See https://github.com/allenai/calypso/blob/master/bin/train_transformer_lm1b.py#L51. // Adjusted based on our sample size relative to Calypso's. "warmup_steps": 6000 - } + }, + "should_log_learning_rate": true } }