[mta] Fused SGD (pytorch#116585)
Depends on pytorch#116583.

Related:
- pytorch#94791

Pull Request resolved: pytorch#116585
Approved by: https://github.com/janeyx99
crcrpar authored and pytorchmergebot committed Jan 16, 2024
1 parent 5aac95c commit 1d14adf
Showing 8 changed files with 621 additions and 20 deletions.
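At a high level, this commit adds a fused (multi-tensor-apply) CUDA kernel for SGD and wires it up behind `fused=True`. A minimal sketch of exercising the new path, assuming a CUDA build that includes this change, mirroring the AMP coverage added to `test_cuda.py`:

```python
import torch

# Hedged sketch: fused=True dispatches SGD's update to the new multi-tensor
# CUDA kernel, which also unscales gradients inside the kernel when driven
# through a GradScaler (what test_grad_scaling_autocast_fused checks).
model = torch.nn.Linear(8, 8).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, fused=True)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    opt.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(torch.randn(4, 8, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.step(opt)   # hands grad_scale/found_inf to the fused kernel
    scaler.update()
```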
428 changes: 428 additions & 0 deletions aten/src/ATen/native/cuda/FusedSgdKernel.cu

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -15397,6 +15397,22 @@
CUDA: _fused_adamw_kernel_cuda_
autogen: _fused_adamw.tensor_lr, _fused_adamw.tensor_lr_out

- func: _fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
# Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
variants: function
dispatch:
CUDA: _fused_sgd_kernel_cuda_
autogen: _fused_sgd, _fused_sgd.out

- func: _fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()
# Unlike "foreach" functions, lists of tensors should be guaranteed to be on the same device (for now).
# but still skip the device check as the Tensor LR can be on CPU
device_check: NoCheck
variants: function
dispatch:
CUDA: _fused_sgd_kernel_cuda_
autogen: _fused_sgd.tensor_lr, _fused_sgd.tensor_lr_out

# This op is ONLY used by pytorch/XLA in functionalization, and should never show up in vanilla eager mode or in any pytorch tracing contexts.
- func: _propagate_xla_data(Tensor input, Tensor output) -> ()
variants: function
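For reference, the two schemas above surface as private ATen ops. A hedged sketch of invoking the base overload directly; the `torch._fused_sgd_` binding name is inferred from the yaml entry, and this is not public API (the optimizer frontend is the intended entry point):

```python
import torch

# Illustration only: all tensor lists must live on the same CUDA device,
# and the keyword-only arguments follow the _fused_sgd_ schema above.
params = [torch.rand(2, 3, device="cuda") for _ in range(3)]
grads = [torch.rand_like(p) for p in params]
momentum_buffers = [torch.zeros_like(p) for p in params]

torch._fused_sgd_(
    params,
    grads,
    momentum_buffers,
    weight_decay=0.0,
    momentum=0.9,
    lr=1e-2,
    dampening=0.0,
    nesterov=False,
    maximize=False,
    is_first_step=True,   # treat the momentum buffers as uninitialized
    grad_scale=None,      # optional scale/inf tensors supplied by GradScaler
    found_inf=None,
)
```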
6 changes: 6 additions & 0 deletions test/expect/HasDecompTest.test_has_decomposition.expect
@@ -353,6 +353,12 @@ aten::_fused_moving_avg_obs_fq_helper
aten::_fused_moving_avg_obs_fq_helper.out
aten::_fused_moving_avg_obs_fq_helper_functional
aten::_fused_sdp_choice
aten::_fused_sgd
aten::_fused_sgd.out
aten::_fused_sgd.tensor_lr
aten::_fused_sgd.tensor_lr_out
aten::_fused_sgd_
aten::_fused_sgd_.tensor_lr
aten::_fw_primal
aten::_fw_primal_copy
aten::_fw_primal_copy.out
29 changes: 27 additions & 2 deletions test/optim/test_optim.py
@@ -1290,7 +1290,7 @@ def test_fused_optimizer_does_not_step_if_foundinf(self):
if not torch.cuda.is_available():
self.skipTest("CUDA is required.")

from torch.optim import adam, adamw
from torch.optim import adam, adamw, sgd

num_tensors = 5
for functional_optim, amsgrad, no_grad_scale in itertools.product((adam.adam, adamw.adamw), (False, True), (False, True)):
@@ -1331,6 +1331,31 @@ def test_fused_optimizer_does_not_step_if_foundinf(self):
],
)
self.assertEqual(params, prev_params)
else:
for momentum in (0.0, 0.1):
params, d_p_list, momentum_buffer_list = (
[torch.ones((1,), device="cuda") for _ in range(num_tensors)] for _ in range(3))
if momentum == 0.0:
momentum_buffer_list = [None for _ in range(num_tensors)]
prev_params = [t.clone().detach() for t in params]
grad_scale = None if no_grad_scale else torch.ones((1,), dtype=torch.float32, device="cuda")
found_inf = torch.ones((), dtype=torch.float32, device="cuda")
sgd.sgd(
params,
d_p_list,
momentum_buffer_list,
has_sparse_grad=False,
foreach=False,
fused=True,
grad_scale=grad_scale,
found_inf=found_inf,
weight_decay=0.0,
momentum=momentum,
lr=0.01,
dampening=0.0,
nesterov=False,
maximize=False,
)


@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required.")
@@ -1340,7 +1365,7 @@ def test_fused_optimizer_load_state_dict(self):
# store checkpoints on CPU as CUDA memory is limited with torch.load(...map_location="cpu").
# Since this is a unit test, it is more expedient to simulate what the state_dict
# would look like, which is basically CPU tensors with fused/capturable flag = True.
for optimC, kwarg in itertools.product((Adam, AdamW), ("fused", "capturable")):
for optimC, kwarg in list(itertools.product((Adam, AdamW), ("fused", "capturable"))) + [(SGD, "fused")]:
input = torch.tensor([0.1, 0.2], dtype=torch.float32, device="cpu")
optimizer = optimC([input])
optimizer.zero_grad()
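The load_state_dict parametrization above now includes `(SGD, "fused")`. A hedged sketch of the scenario the test simulates (CPU-resident checkpoint state loaded into a fused CUDA optimizer; names chosen for illustration):

```python
import torch
from torch.optim import SGD

# Hypothetical reconstruction of the checkpoint round-trip: state saved from
# a plain CPU run, then loaded into a fused=True optimizer on CUDA.
param_cpu = torch.tensor([0.1, 0.2], requires_grad=True)
ckpt_opt = SGD([param_cpu], lr=0.1, momentum=0.9)
param_cpu.grad = torch.ones_like(param_cpu)
ckpt_opt.step()                      # populates a CPU momentum buffer
cpu_state = ckpt_opt.state_dict()    # stands in for torch.load(..., map_location="cpu")

param_cuda = param_cpu.detach().clone().cuda().requires_grad_()
fused_opt = SGD([param_cuda], lr=0.1, momentum=0.9, fused=True)
fused_opt.load_state_dict(cpu_state)

param_cuda.grad = torch.ones_like(param_cuda)
fused_opt.step()                     # fused step after loading CPU state
```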
46 changes: 36 additions & 10 deletions test/test_cuda.py
@@ -1261,17 +1261,27 @@ def test_grad_scaling_autocast_foreach(self):
self._grad_scaling_autocast_test(optimizer_ctor=optimizer_ctor, optimizer_kwargs={"foreach": True})

def test_grad_scaling_autocast_fused(self):
for optimizer_ctor in (torch.optim.Adam, torch.optim.AdamW):
for optimizer_ctor in (torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW):
self._grad_scaling_autocast_test(optimizer_ctor=optimizer_ctor, optimizer_kwargs={"fused": True})

# Compare non-fused optimizer vs fused one as the fused one unscales gradients
# inside its cuda kernel unlike the other.
def test_grad_scaling_autocast_fused_optimizers(self):
for optimizer_ctor, optimizer_kwargs, separate_unscale in product(
for optimizer_ctor, optimizer_kwargs, separate_unscale in list(product(
(torch.optim.Adam, torch.optim.AdamW),
({"fused": True, "amsgrad": False}, {"fused": True, "amsgrad": True}),
(False, True),
):
)) + list(product(
(torch.optim.SGD,),
[
{"momentum": 0.0, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True}
for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
] + [
{"momentum": 0.5, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True}
for d, w, n in product((0.0,), (0.0, 0.5), (True, False))
],
(False, True),
)):
with self.subTest(optim=optimizer_ctor, kwargs=optimizer_kwargs, separate_unscale=separate_unscale):
self._grad_scaling_autocast_fused_optimizers(
optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs, separate_unscale=separate_unscale)
@@ -2864,14 +2874,18 @@ def test_graph_cudnn_dropout(self):

@unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
def test_graph_grad_scaling(self):
for foreach, fused in ((False, False), (True, False), (False, True)):
self._test_graph_grad_scaling(foreach, fused)

def _test_graph_grad_scaling(self, foreach, fused):
torch.cuda.empty_cache()

scaler = torch.cuda.amp.GradScaler(init_scale=4.)
g = torch.cuda.CUDAGraph()
s = torch.cuda.Stream()

weight = torch.ones((100,), device="cuda", requires_grad=True)
opt = torch.optim.SGD([weight], lr=0.1)
opt = torch.optim.SGD([weight], lr=0.1, foreach=foreach, fused=fused)
static_input = torch.ones_like(weight)
static_grad = torch.ones_like(weight)

@@ -3158,13 +3172,23 @@ def test_graph_scaling_fused_optimizers(self):
cases = [
(optimizer_ctor, {"lr": 0.1, "betas": (0.8, 0.7), "fused": True, "amsgrad": amsgrad})
for optimizer_ctor, amsgrad in product((torch.optim.Adam, torch.optim.AdamW), (False, True))
]
] + list(product(
(torch.optim.SGD,),
[
{"lr": 0.1, "momentum": 0.0, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True}
for d, w, n in product((0.0, 0.5), (0.0, 0.5), (False,))
] + [
{"lr": 0.1, "momentum": 0.5, "dampening": d, "weight_decay": w, "nesterov": n, "fused": True}
for d, w, n in product((0.0,), (0.0, 0.5), (True, False))
],
))

steps_warmup = 3
steps_train = 2

for OptClass, kwargs in cases:
for actually_do_graphs in (True, False):
has_capturable_arg = OptClass in (torch.optim.Adam, torch.optim.AdamW)
for actually_do_graphs in (True, False) if has_capturable_arg else (True,):
params = [torch.randn((i + 5, i + 5), device="cuda") for i in range(2)]
params_control = [p.clone().requires_grad_() for p in params]
params_graphed = [p.clone().requires_grad_() for p in params]
@@ -3186,8 +3210,9 @@ def test_graph_scaling_fused_optimizers(self):
scaler_for_graphed._lazy_init_scale_growth_tracker(torch.device("cuda"))

# Control (capturable=False)

opt = OptClass(params_control, capturable=False, **kwargs)
if has_capturable_arg:
kwargs["capturable"] = False
opt = OptClass(params_control, **kwargs)

for i in range(steps_warmup + steps_train):
for j, p in enumerate(params_control):
@@ -3196,8 +3221,9 @@ def test_graph_scaling_fused_optimizers(self):
scaler_for_control.update()

# capturable=True

opt = OptClass(params_graphed, capturable=True, **kwargs)
if has_capturable_arg:
kwargs["capturable"] = True
opt = OptClass(params_graphed, **kwargs)

for i in range(steps_warmup):
for j, p in enumerate(params_graphed):
8 changes: 8 additions & 0 deletions torch/distributed/optim/functional_sgd.py
@@ -28,6 +28,7 @@ def __init__(
nesterov: bool = False,
maximize: bool = False,
foreach: bool = False,
fused: bool = False,
_allow_empty_param_list: bool = False,
):
self.defaults = {
@@ -39,6 +40,7 @@ def __init__(
self.nesterov = nesterov
self.maximize = maximize
self.foreach = foreach
self.fused = fused
self.state = torch.jit.annotate(Dict[torch.Tensor, Dict[str, torch.Tensor]], {})

if len(params) == 0 and not _allow_empty_param_list:
@@ -88,6 +90,9 @@ def step_param(self, param: Tensor, grad: Optional[Tensor]):
maximize=self.maximize,
has_sparse_grad=has_sparse_grad,
foreach=self.foreach,
fused=self.fused,
grad_scale=None,
found_inf=None,
)
# update momentum_buffer in state
state = self.state[param]
@@ -142,6 +147,9 @@ def step(self, gradients: List[Optional[Tensor]]):
maximize=self.maximize,
has_sparse_grad=has_sparse_grad,
foreach=self.foreach,
fused=self.fused,
grad_scale=None,
found_inf=None,
)

# update momentum_buffers in state
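Finally, the functional wrapper now accepts and forwards the flag. A hedged sketch of using `_FunctionalSGD` (an internal helper defined in this file) with `fused=True`:

```python
import torch
from torch.distributed.optim.functional_sgd import _FunctionalSGD

# Sketch only: the new `fused` flag is forwarded to torch.optim.sgd.sgd with
# grad_scale/found_inf left as None, matching the diff above.
params = [torch.rand(4, device="cuda", requires_grad=True)]
opt = _FunctionalSGD(params, lr=0.01, momentum=0.9, fused=True)

grad = torch.ones_like(params[0])
opt.step_param(params[0], grad)   # per-parameter update entry point
```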