
Is ROCm apex.amp deprecated? & behavior mismatch vs NVIDIA APEX #118

Open
fxmarty opened this issue Sep 19, 2023 · 1 comment
Labels
bug Something isn't working

Comments


fxmarty commented Sep 19, 2023

Hi, I am wondering whether ROCm apex.amp is deprecated? NVIDIA APEX carries deprecation warnings that are not present in this repo: https://github.com/NVIDIA/apex/pull/1506/files

Moreover, I realize that this code

import torch
import torch.nn as nn
from torch.optim import AdamW
from apex import amp

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 4)

    def forward(self, attn_probs, value_states):
        # bmm on mixed-dtype inputs: under opt_level O1, apex.amp patches
        # torch.bmm and is expected to cast both inputs to half.
        attn_output = torch.bmm(attn_probs, value_states)
        return attn_output

model = MyModule().to("cuda")
optimizer = AdamW(model.parameters())

model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Deliberately mixed dtypes: attn_probs is float32, value_states is float16.
attn_probs = torch.rand(4, 16, 16).to("cuda")
value_states = torch.rand(4, 16, 2).to(torch.float16).to("cuda")

attn_output = model(attn_probs, value_states)

runs fine with NVIDIA APEX but fails on ROCm APEX with the following log:

Traceback (most recent call last):
  File "run_bmm.py", line 26, in <module>
    attn_output = model(attn_probs, value_states)
  File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "run_bmm.py", line 11, in forward
    attn_output = torch.bmm(attn_probs, value_states)
RuntimeError: expected scalar type Half but found Float
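
For what it's worth, manually casting the inputs to a common dtype sidesteps the error on both backends. This is only a workaround sketch, not what apex O1 is supposed to require:

# Hypothetical workaround: unify the dtypes by hand instead of relying on
# apex's O1 patching of torch.bmm. Plain bmm accepts matching half inputs.
attn_output = torch.bmm(attn_probs.to(value_states.dtype), value_states)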

However, using torch.cuda.amp.autocast instead works fine on both ROCm and CUDA-powered devices (with torch 2.0.1).
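
For reference, this is the autocast variant I mean (a minimal sketch using the same module and mixed-dtype inputs as above):

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(3, 4)

    def forward(self, attn_probs, value_states):
        return torch.bmm(attn_probs, value_states)

model = MyModule().to("cuda")

attn_probs = torch.rand(4, 16, 16).to("cuda")
value_states = torch.rand(4, 16, 2).to(torch.float16).to("cuda")

# Under autocast, torch.bmm runs in float16, so the mixed
# float32/float16 inputs are cast to a common dtype automatically.
with torch.cuda.amp.autocast():
    attn_output = model(attn_probs, value_states)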

Thank you!

fxmarty added the bug label on Sep 19, 2023
@pruthvistony

@fxmarty,
I believe the problem could be due to a missing fix in the Adam optimizer handling in ROCm apex. Checking on it; will get back.
