From 1d14adfa66e2ae437253eebe223710588648eee7 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 16 Jan 2024 23:54:38 +0000 Subject: [PATCH] [mta] Fused SGD (#116585) depends on #116583 rel: - #94791 Pull Request resolved: https://github.com/pytorch/pytorch/pull/116585 Approved by: https://github.com/janeyx99 --- aten/src/ATen/native/cuda/FusedSgdKernel.cu | 428 ++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 16 + ...asDecompTest.test_has_decomposition.expect | 6 + test/optim/test_optim.py | 29 +- test/test_cuda.py | 46 +- torch/distributed/optim/functional_sgd.py | 8 + torch/optim/sgd.py | 105 ++++- torch/testing/_internal/common_optimizers.py | 3 +- 8 files changed, 621 insertions(+), 20 deletions(-) create mode 100644 aten/src/ATen/native/cuda/FusedSgdKernel.cu diff --git a/aten/src/ATen/native/cuda/FusedSgdKernel.cu b/aten/src/ATen/native/cuda/FusedSgdKernel.cu new file mode 100644 index 0000000000000..d4e14c8c50946 --- /dev/null +++ b/aten/src/ATen/native/cuda/FusedSgdKernel.cu @@ -0,0 +1,428 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { + +namespace { + +template +C10_DEVICE __forceinline__ void sgd_math( + scalar_t r_args[depth][kILP], + const double weight_decay, + const double momentum, + const float* lr_ptr, + const double lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step, + const float* grad_scale_ptr) { + using opmath_t = at::opmath_type; + const double double_lr = lr_ptr != nullptr ? *lr_ptr : lr; +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + auto p = static_cast(r_args[0][ii]); + auto g = static_cast(r_args[1][ii]); + if (grad_scale_ptr) { + g /= static_cast(*grad_scale_ptr); + r_args[1][ii] = g; + } + if (maximize) { + g *= -1.0; + } + if (weight_decay != 0) { + g += weight_decay * p; + } + if (depth > 2) { + const auto momentum_buffer = is_first_step + ? 
g + : (momentum * static_cast(r_args[2][ii]) + + (1 - dampening) * g); + r_args[2][ii] = momentum_buffer; + + if (nesterov) { + g = g + momentum * momentum_buffer; + } else { + g = momentum_buffer; + } + } + p -= double_lr * g; + r_args[0][ii] = p; + } +} + +template +struct FusedSgdMathFunctor { + static_assert( + depth == 2 || depth == 3, + "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0"); + C10_DEVICE __forceinline__ void operator()( + int chunk_size, + TensorListMetadata& tl, + const double weight_decay, + const double momentum, + const float* lr_ptr, + const double lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step, + const float* grad_scale_ptr, + const float* found_inf_ptr) { + if (found_inf_ptr && *found_inf_ptr == 1) { + return; + } + auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + auto n = tl.numel_for_tensor[tensor_loc]; + + scalar_t* args[depth]; + scalar_t r_args[depth][kILP]; + const auto all_aligned{ + init_args(args, tl, chunk_idx, chunk_size, tensor_loc)}; + n -= chunk_idx * chunk_size; + +#ifndef USE_ROCM + const auto use_faster_load_store = + (n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned; +#else + const auto use_faster_load_store{false}; +#endif + if (use_faster_load_store) { + for (auto i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { +#pragma unroll + for (auto i = 0; i < depth; i++) { + load_store(r_args[i], args[i], 0, i_start); + } + sgd_math( + r_args, + weight_decay, + momentum, + lr_ptr, + lr, + dampening, + nesterov, + maximize, + is_first_step, + grad_scale_ptr); + load_store(args[0], r_args[0], i_start, 0); + if (grad_scale_ptr) { + load_store(args[1], r_args[1], i_start, 0); + } + if (depth > 2) { + load_store(args[2], r_args[2], i_start, 0); + } + } + } else { + for (auto i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args(r_args, args, i_start, chunk_size, n); + sgd_math( + r_args, + weight_decay, + momentum, + lr_ptr, + lr, + dampening, + nesterov, + maximize, + is_first_step, + grad_scale_ptr); + store_args(args[0], r_args[0], i_start, chunk_size, n); + if (grad_scale_ptr) { + store_args(args[1], r_args[1], i_start, chunk_size, n); + } + if (depth > 2) { + store_args(args[2], r_args[2], i_start, chunk_size, n); + } + } + } + } +}; + +void _fused_sgd_with_momentum_kernel_cuda_( + at::TensorList params, + at::TensorList grads, + at::TensorList momentum_buffer_list, + const double weight_decay, + const double momentum, + const double lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + TORCH_CHECK_GT(momentum, 0); + TORCH_CHECK(at::native::check_fast_path_restrictions( + {params, grads, momentum_buffer_list})); + float* grad_scale_ptr = + grad_scale.has_value() ? grad_scale->data_ptr() : nullptr; + float* found_inf_ptr = + found_inf.has_value() ? 
found_inf->data_ptr() : nullptr; + float* lr_ptr = nullptr; + + std::vector> tensor_lists{ + params.vec(), grads.vec(), momentum_buffer_list.vec()}; + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_sgd_with_momentum_kernel_cuda", + [&]() { + multi_tensor_apply<3>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr_ptr, + lr, + dampening, + nesterov, + maximize, + is_first_step, + grad_scale_ptr, + found_inf_ptr); + }); +} + +void _fused_sgd_with_momentum_kernel_cuda_( + at::TensorList params, + at::TensorList grads, + at::TensorList momentum_buffer_list, + const double weight_decay, + const double momentum, + const at::Tensor& lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + if (lr.is_cpu()) { + _fused_sgd_with_momentum_kernel_cuda_( + params, + grads, + momentum_buffer_list, + weight_decay, + momentum, + lr.item(), + dampening, + nesterov, + maximize, + is_first_step, + grad_scale, + found_inf); + return; + } + TORCH_CHECK_GT(momentum, 0); + TORCH_CHECK(at::native::check_fast_path_restrictions( + {params, grads, momentum_buffer_list})); + if (grad_scale != c10::nullopt) { + TORCH_CHECK( + grad_scale->device() == params[0].device(), + "grad_scale must be on the same GPU device as the params"); + } + if (found_inf != c10::nullopt) { + TORCH_CHECK( + found_inf->device() == params[0].device(), + "found_inf must be on the same GPU device as the params"); + } + TORCH_CHECK( + lr.device() == params[0].device(), + "found_inf must be on the same GPU device as the params"); + float* grad_scale_ptr = + grad_scale.has_value() ? grad_scale->data_ptr() : nullptr; + float* found_inf_ptr = + found_inf.has_value() ? found_inf->data_ptr() : nullptr; + + std::vector> tensor_lists{ + params.vec(), grads.vec(), momentum_buffer_list.vec()}; + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_sgd_with_momentum_kernel_cuda", + [&]() { + multi_tensor_apply<3>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr.data_ptr(), + 1.0, + dampening, + nesterov, + maximize, + is_first_step, + grad_scale_ptr, + found_inf_ptr); + }); +} + +} // namespace + +void _fused_sgd_kernel_cuda_( + at::TensorList params, + at::TensorList grads, + at::TensorList momentum_buffer_list, + const double weight_decay, + const double momentum, + const double lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + if (!momentum_buffer_list.empty()) { + _fused_sgd_with_momentum_kernel_cuda_( + params, + grads, + momentum_buffer_list, + weight_decay, + momentum, + lr, + dampening, + nesterov, + maximize, + is_first_step, + grad_scale, + found_inf); + return; + } + TORCH_CHECK_EQ(momentum, 0); + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads})); + if (is_first_step) { + TORCH_WARN_ONCE( + "`is_first_step` argument has no effect when `momentum_buffer_list` is empty"); + } + float* grad_scale_ptr = + grad_scale.has_value() ? grad_scale->data_ptr() : nullptr; + float* found_inf_ptr = + found_inf.has_value() ? 
found_inf->data_ptr() : nullptr; + float* lr_ptr = nullptr; + + std::vector> tensor_lists{params.vec(), grads.vec()}; + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_sgd_kernel_cuda", + [&]() { + multi_tensor_apply<2>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr_ptr, + lr, + dampening, + nesterov, + maximize, + /* is_first_step */ false, + grad_scale_ptr, + found_inf_ptr); + }); +} + +void _fused_sgd_kernel_cuda_( + at::TensorList params, + at::TensorList grads, + at::TensorList momentum_buffer_list, + const double weight_decay, + const double momentum, + const at::Tensor& lr, + const double dampening, + const bool nesterov, + const bool maximize, + const bool is_first_step, + const c10::optional& grad_scale, + const c10::optional& found_inf) { + if (!momentum_buffer_list.empty()) { + _fused_sgd_with_momentum_kernel_cuda_( + params, + grads, + momentum_buffer_list, + weight_decay, + momentum, + lr, + dampening, + nesterov, + maximize, + is_first_step, + grad_scale, + found_inf); + return; + } + if (lr.is_cpu()) { + _fused_sgd_kernel_cuda_( + params, + grads, + momentum_buffer_list, + weight_decay, + momentum, + lr.item(), + dampening, + nesterov, + maximize, + is_first_step, + grad_scale, + found_inf); + return; + } + TORCH_CHECK_EQ(momentum, 0); + TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads})); + if (is_first_step) { + TORCH_WARN_ONCE( + "`is_first_step` argument has no effect when `momentum_buffer_list` is empty"); + } + if (grad_scale.has_value()) { + TORCH_CHECK( + grad_scale->device() == params[0].device(), + "grad_scale must be on the same GPU device as the params"); + } + if (found_inf.has_value()) { + TORCH_CHECK( + found_inf->device() == params[0].device(), + "found_inf must be on the same GPU device as the params"); + } + TORCH_CHECK( + lr.device() == params[0].device(), + "found_inf must be on the same GPU device as the params"); + float* grad_scale_ptr = + grad_scale.has_value() ? grad_scale->data_ptr() : nullptr; + float* found_inf_ptr = + found_inf.has_value() ? found_inf->data_ptr() : nullptr; + + std::vector> tensor_lists{params.vec(), grads.vec()}; + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_sgd_kernel_cuda", + [&]() { + multi_tensor_apply<2>( + tensor_lists, + FusedSgdMathFunctor(), + weight_decay, + momentum, + lr.data_ptr(), + 1.0, + dampening, + nesterov, + maximize, + /* is_first_step */ false, + grad_scale_ptr, + found_inf_ptr); + }); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 3d8f59f61ccd3..a95d61b548864 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -15397,6 +15397,22 @@ CUDA: _fused_adamw_kernel_cuda_ autogen: _fused_adamw.tensor_lr, _fused_adamw.tensor_lr_out +- func: _fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now). 
+ variants: function + dispatch: + CUDA: _fused_sgd_kernel_cuda_ + autogen: _fused_sgd, _fused_sgd.out + +- func: _fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + # Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now). + # but still skip the device check as the Tensor LR can be on CPU + device_check: NoCheck + variants: function + dispatch: + CUDA: _fused_sgd_kernel_cuda_ + autogen: _fused_sgd.tensor_lr, _fused_sgd.tensor_lr_out + # This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts. - func: _propagate_xla_data(Tensor input, Tensor output) -> () variants: function diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index e2c3c0251702b..3332bbadc3789 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -353,6 +353,12 @@ aten::_fused_moving_avg_obs_fq_helper aten::_fused_moving_avg_obs_fq_helper.out aten::_fused_moving_avg_obs_fq_helper_functional aten::_fused_sdp_choice +aten::_fused_sgd +aten::_fused_sgd.out +aten::_fused_sgd.tensor_lr +aten::_fused_sgd.tensor_lr_out +aten::_fused_sgd_ +aten::_fused_sgd_.tensor_lr aten::_fw_primal aten::_fw_primal_copy aten::_fw_primal_copy.out diff --git a/test/optim/test_optim.py b/test/optim/test_optim.py index d0bcad889b2f7..33ea35e1c9102 100644 --- a/test/optim/test_optim.py +++ b/test/optim/test_optim.py @@ -1290,7 +1290,7 @@ def test_fused_optimizer_does_not_step_if_foundinf(self): if not torch.cuda.is_available(): self.skipTest("CUDA is required.") - from torch.optim import adam, adamw + from torch.optim import adam, adamw, sgd num_tensors = 5 for functional_optim, amsgrad, no_grad_scale in itertools.product((adam.adam, adamw.adamw), (False, True), (False, True)): @@ -1331,6 +1331,31 @@ def test_fused_optimizer_does_not_step_if_foundinf(self): ], ) self.assertEqual(params, prev_params) + else: + for momentum in (0.0, 0.1): + params, d_p_list, momentum_buffer_list = ( + [torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(3)) + if momentum == 0.0: + momentum_buffer_list = [None for _ in range(num_tensors)] + prev_params = [t.clone().detach() for t in params] + grad_scale = None if no_grad_scale else torch.ones((1,), dtype=torch.float32, device="cuda") + found_inf = torch.ones((), dtype=torch.float32, device="cuda") + sgd.sgd( + params, + d_p_list, + momentum_buffer_list, + has_sparse_grad=False, + foreach=False, + fused=True, + grad_scale=grad_scale, + found_inf=found_inf, + weight_decay=0.0, + momentum=momentum, + lr=0.01, + dampening=0.0, + nesterov=False, + maximize=False, + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA is required.") @@ -1340,7 +1365,7 @@ def test_fused_optimizer_load_state_dict(self): # store checkpoints on CPU as CUDA memory is limited with torch.load(...map_location="cpu"). # Since this is a unit test, it is more expedient to simulate what the state_dict # would look like, which is basically CPU tensors with fused/capturable flag = True. 
- for optimC, kwarg in itertools.product((Adam, AdamW), ("fused", "capturable")): + for optimC, kwarg in list(itertools.product((Adam, AdamW), ("fused", "capturable"))) + [(SGD, "fused")]: input = torch.tensor([0.1, 0.2], dtype=torch.float32, device="cpu") optimizer = optimC([input]) optimizer.zero_grad() diff --git a/test/test_cuda.py b/test/test_cuda.py index 2615308277bbb..f074eb323339c 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1261,17 +1261,27 @@ def test_grad_scaling_autocast_foreach(self): self._grad_scaling_autocast_test(optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": True}) def test_grad_scaling_autocast_fused(self): - for optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW): + for optimizer_ctor in (torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW): self._grad_scaling_autocast_test(optimizer_ctor=optimizer_ctor, optimizer_kwargs={"fused": True}) # Compare non-fused optimizer vs fused one as the fused one unscales gradients # inside its cuda kernel unlike the other. def test_grad_scaling_autocast_fused_optimizers(self): - for optimizer_ctor, optimizer_kwargs, separate_unscale in product( + for optimizer_ctor, optimizer_kwargs, separate_unscale in list(product( (torch.optim.Adam, torch.optim.AdamW), ({"fused": True, "amsgrad": False}, {"fused": True, "amsgrad": True}), (False, True), - ): + )) + list(product( + (torch.optim.SGD,), + [ + {"momentum": 0.0, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True} + for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,)) + ] + [ + {"momentum": 0.5, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True} + for d, w, n in product((0.0,), (0.0, 0.5), (True, False)) + ], + (False, True), + )): with self.subTest(optim=optimizer_ctor, kwargs=optimizer_kwargs, separate_unscale=separate_unscale): self._grad_scaling_autocast_fused_optimizers( optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, separate_unscale=separate_unscale) @@ -2864,6 +2874,10 @@ def test_graph_cudnn_dropout(self): @unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs") def test_graph_grad_scaling(self): + for foreach, fused in ((False, False), (True, False), (False, True)): + self._test_graph_grad_scaling(foreach, fused) + + def _test_graph_grad_scaling(self, foreach, fused): torch.cuda.empty_cache() scaler = torch.cuda.amp.GradScaler(init_scale=4.) 
@@ -2871,7 +2885,7 @@ def test_graph_grad_scaling(self): s = torch.cuda.Stream() weight = torch.ones((100,), device="cuda", requires_grad=True) - opt = torch.optim.SGD([weight], lr=0.1) + opt = torch.optim.SGD([weight], lr=0.1, foreach=foreach, fused=fused) static_input = torch.ones_like(weight) static_grad = torch.ones_like(weight) @@ -3158,13 +3172,23 @@ def test_graph_scaling_fused_optimizers(self): cases = [ (optimizer_ctor, {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad}) for optimizer_ctor, amsgrad in product((torch.optim.Adam, torch.optim.AdamW), (False, True)) - ] + ] + list(product( + (torch.optim.SGD,), + [ + {"lr": 0.1, "momentum": 0.0, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True} + for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,)) + ] + [ + {"lr": 0.1, "momentum": 0.5, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True} + for d, w, n in product((0.0,), (0.0, 0.5), (True, False)) + ], + )) steps_warmup = 3 steps_train = 2 for OptClass, kwargs in cases: - for actually_do_graphs in (True, False): + has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW) + for actually_do_graphs in (True, False) if has_capturable_arg else (True,): params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)] params_control = [p.clone().requires_grad_() for p in params] params_graphed = [p.clone().requires_grad_() for p in params] @@ -3186,8 +3210,9 @@ def test_graph_scaling_fused_optimizers(self): scaler_for_graphed._lazy_init_scale_growth_tracker(torch.device("cuda")) # Control (capturable=False) - - opt = OptClass(params_control, capturable=False, **kwargs) + if has_capturable_arg: + kwargs["capturable"] = False + opt = OptClass(params_control, **kwargs) for i in range(steps_warmup + steps_train): for j, p in enumerate(params_control): @@ -3196,8 +3221,9 @@ def test_graph_scaling_fused_optimizers(self): scaler_for_control.update() # capturable=True - - opt = OptClass(params_graphed, capturable=True, **kwargs) + if has_capturable_arg: + kwargs["capturable"] = True + opt = OptClass(params_graphed, **kwargs) for i in range(steps_warmup): for j, p in enumerate(params_graphed): diff --git a/torch/distributed/optim/functional_sgd.py b/torch/distributed/optim/functional_sgd.py index ff6ce757735ba..4a807a6055719 100644 --- a/torch/distributed/optim/functional_sgd.py +++ b/torch/distributed/optim/functional_sgd.py @@ -28,6 +28,7 @@ def __init__( nesterov: bool = False, maximize: bool = False, foreach: bool = False, + fused: bool = False, _allow_empty_param_list: bool = False, ): self.defaults = { @@ -39,6 +40,7 @@ def __init__( self.nesterov = nesterov self.maximize = maximize self.foreach = foreach + self.fused = fused self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {}) if len(params) == 0 and not _allow_empty_param_list: @@ -88,6 +90,9 @@ def step_param(self, param: Tensor, grad: Optional[Tensor]): maximize=self.maximize, has_sparse_grad=has_sparse_grad, foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, ) # update momentum_buffer in state state = self.state[param] @@ -142,6 +147,9 @@ def step(self, gradients: List[Optional[Tensor]]): maximize=self.maximize, has_sparse_grad=has_sparse_grad, foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, ) # update momentum_buffers in state diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 8c7d73b83a2b4..a71ab394c027b 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -1,7 +1,7 @@ 
import torch from torch import Tensor from .optimizer import (Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach, - _differentiable_doc, _foreach_doc, _maximize_doc) + _differentiable_doc, _foreach_doc, _maximize_doc, _fused_doc) from typing import List, Optional __all__ = ['SGD', 'sgd'] @@ -10,7 +10,7 @@ class SGD(Optimizer): def __init__(self, params, lr=1e-3, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize: bool = False, foreach: Optional[bool] = None, - differentiable: bool = False): + differentiable: bool = False, fused: Optional[bool] = None): if lr < 0.0: raise ValueError(f"Invalid learning rate: {lr}") if momentum < 0.0: @@ -21,11 +21,18 @@ def __init__(self, params, lr=1e-3, momentum=0, dampening=0, defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov, maximize=maximize, foreach=foreach, - differentiable=differentiable) + differentiable=differentiable, fused=fused) if nesterov and (momentum <= 0 or dampening != 0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") super().__init__(params, defaults) + if fused: + self._step_supports_amp_scaling = True + if differentiable: + raise RuntimeError("`fused` does not support `differentiable`") + if foreach: + raise RuntimeError("`fused` and `foreach` cannot be `True` together.") + def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: @@ -33,6 +40,7 @@ def __setstate__(self, state): group.setdefault('maximize', False) group.setdefault('foreach', None) group.setdefault('differentiable', False) + group.setdefault('fused', False) def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list): has_sparse_grad = False @@ -82,7 +90,10 @@ def step(self, closure=None): nesterov=group['nesterov'], maximize=group['maximize'], has_sparse_grad=has_sparse_grad, - foreach=group['foreach']) + foreach=group['foreach'], + fused=group['fused'], + grad_scale=getattr(self, "grad_scale", None), + found_inf=getattr(self, "found_inf", None)) # update momentum_buffers in state for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list): @@ -138,6 +149,7 @@ def step(self, closure=None): {_maximize_doc} {_foreach_doc} {_differentiable_doc} + {_fused_doc} """ + r""" Example: @@ -189,6 +201,9 @@ def sgd(params: List[Tensor], # setting this as kwarg for now as functional API is compiled by torch/distributed/optim has_sparse_grad: bool = None, foreach: Optional[bool] = None, + fused: Optional[bool] = None, + grad_scale: Optional[Tensor] = None, + found_inf: Optional[Tensor] = None, *, weight_decay: float, momentum: float, @@ -201,19 +216,32 @@ def sgd(params: List[Tensor], See :class:`~torch.optim.SGD` for details. """ - if foreach is None: + # Respect when the user inputs False/True for foreach or fused. We only want to change + # the default when neither have been user-specified. Note that we default to foreach + # and pass False to use_fused. This is not a mistake--we want to give the fused impl + # bake-in time before making it the default, even if it is typically faster. + if foreach is None and fused is None: # why must we be explicit about an if statement for torch.jit.is_scripting here? 
# because JIT can't handle Optionals nor fancy conditionals when scripting if not torch.jit.is_scripting(): - _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False) + fused, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False) else: foreach = False + fused = False + if foreach is None: + foreach = False + if fused is None: + fused = False if foreach and torch.jit.is_scripting(): raise RuntimeError('torch.jit.script not supported with foreach optimizers') + if fused and torch.jit.is_scripting(): + raise RuntimeError('torch.jit.script not supported with fused optimizers') if foreach and not torch.jit.is_scripting(): func = _multi_tensor_sgd + elif fused and not torch.jit.is_scripting(): + func = _fused_sgd else: func = _single_tensor_sgd @@ -226,11 +254,15 @@ def sgd(params: List[Tensor], dampening=dampening, nesterov=nesterov, has_sparse_grad=has_sparse_grad, - maximize=maximize) + maximize=maximize, + grad_scale=grad_scale, + found_inf=found_inf) def _single_tensor_sgd(params: List[Tensor], d_p_list: List[Tensor], momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], *, weight_decay: float, momentum: float, @@ -239,6 +271,7 @@ def _single_tensor_sgd(params: List[Tensor], nesterov: bool, maximize: bool, has_sparse_grad: bool): + assert grad_scale is None and found_inf is None for i, param in enumerate(params): d_p = d_p_list[i] if not maximize else -d_p_list[i] @@ -266,6 +299,8 @@ def _single_tensor_sgd(params: List[Tensor], def _multi_tensor_sgd(params: List[Tensor], grads: List[Tensor], momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], *, weight_decay: float, momentum: float, @@ -274,6 +309,7 @@ def _multi_tensor_sgd(params: List[Tensor], nesterov: bool, maximize: bool, has_sparse_grad: bool): + assert grad_scale is None and found_inf is None if len(params) == 0: return @@ -329,3 +365,58 @@ def _multi_tensor_sgd(params: List[Tensor], # foreach APIs don't support sparse for i in range(len(device_params)): device_params[i].add_(device_grads[i], alpha=-lr) + + +def _fused_sgd( + params: List[Tensor], + grads: List[Tensor], + momentum_buffer_list: List[Optional[Tensor]], + grad_scale: Optional[Tensor], + found_inf: Optional[Tensor], + *, + weight_decay: float, + momentum: float, + lr: float, + dampening: float, + nesterov: bool, + maximize: bool, + has_sparse_grad: bool, +) -> None: + if not params: + return + if has_sparse_grad: + raise RuntimeError("`_fused_sgd` does not support sparse gradients") + grad_scale_dict = {grad_scale.device: grad_scale} if grad_scale is not None else None + found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None + + no_momentum_buffer = momentum == 0 + is_first_step = all(t is None for t in momentum_buffer_list) and not no_momentum_buffer + if is_first_step: + for i, g in enumerate(grads): + momentum_buffer_list[i] = torch.empty_like(g) + grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( + [params, grads, momentum_buffer_list], with_indices=False) + for (device, dtype), ((device_params, device_grads, device_momentum_buffer_list), _) in grouped_tensors.items(): + device_grad_scale, device_found_inf = None, None + if grad_scale is not None: + if device not in grad_scale_dict: + grad_scale_dict[device] = grad_scale.to(device) + device_grad_scale = grad_scale_dict[device] + if found_inf is not None: + if device not in found_inf_dict: + 
found_inf_dict[device] = found_inf.to(device) + device_found_inf = found_inf_dict[device] + torch._fused_sgd_( + device_params, + device_grads, + [] if no_momentum_buffer else device_momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=nesterov, + maximize=maximize, + is_first_step=is_first_step, + grad_scale=device_grad_scale, + found_inf=device_found_inf, + ) diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index f687ec610ea1e..4da1970dcb2d4 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -704,6 +704,7 @@ def optim_error_inputs_func_rprop(device, dtype): def optim_inputs_func_sgd(device=None): return [ OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="default"), + OptimizerInput(params=None, kwargs={"lr": 1e-2}, desc="Tensor lr"), OptimizerInput( params=None, kwargs={"lr": 1e-2, "momentum": 0.9}, desc="momentum" ), @@ -1374,7 +1375,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( SGD, optim_inputs_func=optim_inputs_func_sgd, optim_error_inputs_func=optim_error_inputs_func_sgd, - supported_impls=("foreach", "differentiable"), + supported_impls=("foreach", "differentiable", "fused"), supports_sparse_on=("cpu", "cuda"), skips=( DecorateInfo(
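
Editor's note (hypothetical sketch, not part of the patch): for readers of FusedSgdKernel.cu, the per-element update performed by sgd_math can be restated in plain Python as below. The ordering matches the kernel (unscale, maximize, weight decay, momentum buffer, parameter step); fused_sgd_reference is a made-up name used only for illustration.

    import torch

    def fused_sgd_reference(p, g, buf, *, lr, weight_decay, momentum, dampening,
                            nesterov, maximize, grad_scale=None):
        # Same ordering as sgd_math in FusedSgdKernel.cu.
        if grad_scale is not None:
            g = g / grad_scale          # unscaling happens inside the fused kernel
        if maximize:
            g = -g
        if weight_decay != 0:
            g = g + weight_decay * p
        if momentum != 0:
            # is_first_step in the kernel: the buffer is seeded with g as
            # transformed so far; afterwards it is the usual EMA-style update.
            buf = g.clone() if buf is None else momentum * buf + (1 - dampening) * g
            g = g + momentum * buf if nesterov else buf
        return p - lr * g, buf
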
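
Editor's note (hypothetical usage sketch, not part of the patch): the new fused=True flag on torch.optim.SGD is expected to be driven through torch.cuda.amp.GradScaler much like the existing fused Adam/AdamW paths exercised in test_cuda.py; the model, shapes, and hyperparameters below are made up, and a CUDA build that ships these kernels is assumed.

    import torch

    if torch.cuda.is_available():
        model = torch.nn.Linear(8, 8, device="cuda")
        # fused=True selects _fused_sgd, which dispatches to the new
        # _fused_sgd_kernel_cuda_; momentum buffers exist only when momentum != 0.
        opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, fused=True)
        scaler = torch.cuda.amp.GradScaler()
        for _ in range(3):
            opt.zero_grad(set_to_none=True)
            with torch.autocast("cuda"):
                loss = model(torch.randn(4, 8, device="cuda")).sum()
            scaler.scale(loss).backward()
            # SGD sets _step_supports_amp_scaling when fused=True, so the scaler
            # passes grad_scale/found_inf through to the kernel: gradients are
            # unscaled there, and the step is skipped when found_inf is set.
            scaler.step(opt)
            scaler.update()

A tensor learning rate is also accepted per the _fused_sgd_.tensor_lr overload: a CPU scalar tensor falls back to the scalar path via lr.item(), while a CUDA tensor lr is read directly by the kernel, which is why the tensor_lr schema uses device_check: NoCheck.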