AMP + BF16 failing #95
Comments
I've also seen some issues with AMP. I think there's something missing somewhere... but all the functions seem wrapped to me?
@mvpatel2000: this can be worked around for the following error:

```
  File "/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/miniconda/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/miniconda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 140, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/megablocks/layers/mlp.py", line 270, in backward
    stk.backend.triton_kernels.sdd(
  File "/miniconda/lib/python3.10/site-packages/stk/backend/triton_kernels.py", line 336, in sdd
    _sdd_kernel[grid](
  File "/miniconda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 114, in run
    ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "<string>", line 63, in _sdd_kernel
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 381, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 26:25:
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    # do matrix multiplication
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
```
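For context, the failing frame goes through torch.cuda.amp's custom_bwd wrapper, and a CompilationError at `tl.dot` is what you tend to see when the Triton kernel is handed operands of mixed dtypes (bf16 activations from autocast alongside fp32 tensors). Below is a minimal sketch, not the megablocks code, of how `custom_fwd(cast_inputs=...)` and `custom_bwd` are typically paired so that every tensor a custom autograd.Function touches shares one dtype under autocast; the `MatmulLike` class and its arguments are hypothetical.

```python
import torch
from torch.cuda.amp import custom_bwd, custom_fwd


class MatmulLike(torch.autograd.Function):
    """Hypothetical illustration of keeping dtypes uniform under autocast."""

    @staticmethod
    @custom_fwd(cast_inputs=torch.bfloat16)  # all tensor args arrive as bf16
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a @ b

    @staticmethod
    @custom_bwd  # backward runs with the same autocast state as forward
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        return grad_out @ b.t(), a.t() @ grad_out
```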
Some more small updates on AMP bugs, @mvpatel2000.

What works:

What doesn't work:
File "/miniconda/lib/python3.10/site-packages/megablocks/layers/glu.py", line 158, in forward
x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
File "/miniconda/lib/python3.10/site-packages/grouped_gemm/ops.py", line 33, in gmm
return GroupedGemm.apply(a, b, batch_sizes, trans_b) File "/miniconda/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/miniconda/lib/python3.10/site-packages/grouped_gemm/ops.py", line 11, in forward
return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
File "/miniconda/lib/python3.10/site-packages/grouped_gemm/backend.py", line 27, in gmm
backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
RuntimeError: Expected b.scalar_type() == torch::kBFloat16 to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.) |
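For what it's worth, that RuntimeError is consistent with autocast casting the activations `x` to bf16 while the expert weight `w1` stays fp32, so grouped_gemm's dtype check on `b` fails. A minimal sketch of one possible workaround, using the `gg.ops.gmm` call from the traceback; `gmm_bf16` is a hypothetical helper, not part of grouped_gemm:

```python
import torch
import grouped_gemm as gg


def gmm_bf16(x, w, batch_sizes, trans_b=True):
    # Hypothetical helper: cast both operands to bf16 before the grouped GEMM,
    # since the backend expects matching bf16 dtypes for a and b.
    return gg.ops.gmm(x.to(torch.bfloat16), w.to(torch.bfloat16),
                      batch_sizes, trans_b=trans_b)
```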
@jramapuram any chance you can provide a mini repro? Happy to look into it.
Hi there,

Great work with dMoE! I'm trying to test dMoE with regular DDP + PyTorch AMP (BF16) and I get the following error:

I'm just wrapping your existing `dmoe.dMoE(args)` logic. Is this something that is currently unsupported? If I force the entire network to BF16 then everything works fine.
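In case it helps as the mini repro requested above, here is a minimal sketch of the setup described, assuming a single-GPU, single-process DDP group; the `Arguments` values, input shape, and process-group settings are hypothetical and may differ from the real config.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import dMoE

# Single-process "DDP" just to mirror the reported setup (values are illustrative).
dist.init_process_group("nccl", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)

# Hypothetical argument values; the real config may differ.
args = Arguments(hidden_size=1024,
                 ffn_hidden_size=4096,
                 moe_num_experts=8,
                 moe_top_k=2)

model = DDP(dMoE(args).cuda())                   # parameters stay fp32
x = torch.randn(8, 512, 1024, device="cuda")     # (batch, seq, hidden) assumed

# Activations are autocast to bf16 while the fp32 parameters are not; the
# backward through this forward is where the tracebacks above surface.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    out = model(x)
```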