The reproduction script (run_bmm.py in the traceback below) runs fine with NVIDIA APEX but fails on ROCm APEX with the following log:
Traceback (most recent call last):
  File "run_bmm.py", line 26, in <module>
    attn_output = model(attn_probs, value_states)
  File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "run_bmm.py", line 11, in forward
    attn_output = torch.bmm(attn_probs, value_states)
RuntimeError: expected scalar type Half but found Float
However, using torch.cuda.amp.autocast instead works fine on both ROCm and CUDA devices (with torch 2.0.1).
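For anyone hitting the same error, here is a minimal sketch of the autocast route. It is not the original run_bmm.py (that script is not shown here), and it assumes the failure comes from torch.bmm receiving one float32 operand and one half-precision operand. The helper name attention_output and the tensor shapes are made up for illustration. It uses torch.autocast, the device-generic form of torch.cuda.amp.autocast, so it also runs on a CPU-only machine (with bfloat16 instead of float16).

```python
import torch

# ROCm builds of PyTorch expose their GPUs through the torch.cuda API,
# so the same check covers both CUDA and ROCm devices.
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is the usual autocast dtype on GPU; CPU autocast uses bfloat16.
low_dtype = torch.float16 if device == "cuda" else torch.bfloat16


def attention_output(attn_probs, value_states):
    # Hypothetical helper: inside the autocast region, bmm casts both
    # operands to the low-precision dtype before multiplying.
    with torch.autocast(device_type=device, dtype=low_dtype):
        return torch.bmm(attn_probs, value_states)


# Assumed inputs: mixed dtypes, which plain torch.bmm would reject.
attn_probs = torch.rand(2, 4, 4, device=device)  # float32
value_states = torch.rand(2, 4, 8, device=device, dtype=low_dtype)

attn_output = attention_output(attn_probs, value_states)
print(attn_output.shape, attn_output.dtype)
```

Autocast picks the dtype per operation, so the mixed inputs do not need to be cast by hand before the call.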
Also, I am wondering if ROCm apex.amp is deprecated? NVIDIA APEX has some deprecation warnings that are not present in this repo: https://github.com/NVIDIA/apex/pull/1506/files
Thank you!