AMP + BF16 failing #95

Open
jramapuram opened this issue Jan 28, 2024 · 4 comments

@jramapuram

Hi there,

Great work with dMoE! I'm trying to test dMoE with regular DDP + PyTorch AMP (BF16) and I get the following error:

    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/miniconda/lib/python3.10/site-packages/torch/cuda/amp/grad_scaler.py", line 248, in _unscale_grads_
    torch._amp_foreach_non_finite_check_and_unscale_(

I'm just wrapping your existing dmoe.dMoE(args) logic.

Is this something that is currently unsupported? If I force the entire network to BF16 then everything works fine.
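
For reference, this is roughly the setup that hits the error. This is a minimal sketch rather than a verified repro; the Arguments field names, the input shape, and the tuple-return handling are assumptions from memory and may not match every megablocks version (DDP wrapping is also omitted for brevity since the failure is in the GradScaler path):

    import torch
    from megablocks.layers.arguments import Arguments
    from megablocks.layers import dmoe

    # Hypothetical sizes; field names may differ across megablocks versions.
    args = Arguments(
        hidden_size=1024,
        ffn_hidden_size=4096,
        moe_num_experts=8,
        moe_top_k=2,
    )
    model = dmoe.dMoE(args).cuda()  # parameters stay in fp32

    opt = torch.optim.AdamW(model.parameters())
    scaler = torch.cuda.amp.GradScaler()  # scaler + bf16 autocast, matching the trace above

    x = torch.randn(8, 2048, 1024, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = model(x)
        if isinstance(out, tuple):  # some versions return (output, bias)
            out = out[0]
        loss = out.float().sum()

    scaler.scale(loss).backward()
    scaler.step(opt)   # fails inside _unscale_grads_ as in the traceback above
    scaler.update()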

@mvpatel2000
Contributor

I've also seen some issues with AMP. I think there's something missing somewhere... but all the functions seem wrapped to me?

@jramapuram
Author

@mvpatel2000: this can be worked around for moe.MoE by force-casting with moe.to(torch.float32), and AMP then works fine (a short sketch of that cast follows the traceback below). When doing the same with dmoe.dMoE I get a triton error:

  File "/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/miniconda/lib/python3.10/site-packages/torch/autograd/function.py", line 288, in apply
    return user_fn(self, *args)
  File "/miniconda/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 140, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/miniconda/lib/python3.10/site-packages/megablocks/layers/mlp.py", line 270, in backward
    stk.backend.triton_kernels.sdd(
  File "/miniconda/lib/python3.10/site-packages/stk/backend/triton_kernels.py", line 336, in sdd
    _sdd_kernel[grid](
  File "/miniconda/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 114, in run
    ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
  File "", line 63, in _sdd_kernel
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/compiler.py", line 381, in 
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/miniconda/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 26:25:    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    # pointers
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    # do matrix multiplication
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
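
For completeness, the workaround that does work with moe.MoE is just the cast described above. A rough sketch, reusing args and x from the earlier snippet:

    from megablocks.layers import moe

    layer = moe.MoE(args).cuda()
    layer.to(torch.float32)  # force the MoE block back to fp32 under AMP

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        out = layer(x)  # fine with moe.MoE; the same cast on dmoe.dMoE hits
                        # the triton CompilationError shown above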

@jramapuram
Author

jramapuram commented Jul 16, 2024

Some more small updates on the AMP bugs, @mvpatel2000:

What works:

  • glu/mlp + sparse + dmoe
  • glu/mlp + sparse + moe
  • glu/mlp + grouped + moe

What doesn't work:

  • glu (and MLP) + grouped + dmoe (a hedged sketch of what seems to trigger this follows the traceback below)

  File "/miniconda/lib/python3.10/site-packages/megablocks/layers/glu.py", line 158, in forward
    x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
  File "/miniconda/lib/python3.10/site-packages/grouped_gemm/ops.py", line 33, in gmm
    return GroupedGemm.apply(a, b, batch_sizes, trans_b)
  File "/miniconda/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/miniconda/lib/python3.10/site-packages/grouped_gemm/ops.py", line 11, in forward
    return backend.gmm(a, b, batch_sizes, trans_a=False, trans_b=trans_b)
  File "/miniconda/lib/python3.10/site-packages/grouped_gemm/backend.py", line 27, in gmm
    backend.gmm(a, b, c, batch_sizes, trans_a, trans_b)
RuntimeError: Expected b.scalar_type() == torch::kBFloat16 to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
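
This looks like a plain dtype mismatch (a guess, not a confirmed diagnosis): under autocast the activations reach glu.py in bf16, while the grouped path hands the fp32 master weight w1 straight to grouped_gemm, whose gmm kernel requires both operands in bf16. A standalone sketch of just that check, with purely illustrative shapes:

    import torch
    from grouped_gemm import ops

    batch_sizes = torch.tensor([4, 4], dtype=torch.long)                 # tokens per expert, on CPU
    x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)        # autocast activations
    w1 = torch.randn(2, 4096, 1024, device="cuda", dtype=torch.float32)  # fp32 expert weights

    try:
        ops.gmm(x, w1, batch_sizes, trans_b=True)         # raises the kBFloat16 error above
    except RuntimeError as e:
        print(e)
    ops.gmm(x, w1.bfloat16(), batch_sizes, trans_b=True)  # casting w1 to bf16 avoids it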

@mvpatel2000
Contributor

@jramapuram any chance you can provide a mini repro? happy to look into it
