layernorm fix None inputs (#213)
shivadbhavsar authored Nov 13, 2024
1 parent 203e437 commit 8d95e63
Showing 2 changed files with 22 additions and 0 deletions.
8 changes: 8 additions & 0 deletions py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -2056,6 +2056,14 @@ def acc_ops_layer_norm(mgx_module, node, args, kwargs):
    normalized_shape = kwargs['normalized_shape']
    weight = kwargs['weight']
    bias = kwargs['bias']

    dtype = get_arg_dtype(inp.instr_ref)
    if weight is None:
        weight = MGXInstruction(
            mgx_module.add_literal(torch.tensor(1, dtype=dtype).numpy()))
    if bias is None:
        bias = MGXInstruction(
            mgx_module.add_literal(torch.tensor(0, dtype=dtype).numpy()))

    assert all(not i.is_quantized() for i in (inp, weight, bias))
    inp, weight, bias = inp.instr_ref, weight.instr_ref, bias.instr_ref
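The added branch relies on the fact that PyTorch's layer_norm treats a missing weight as all ones and a missing bias as all zeros; the converter now materializes those defaults as literals instead of dereferencing None. A minimal standalone sketch of that PyTorch behavior (illustration only, not part of this commit):

# Sketch: layer_norm with weight=None / bias=None matches explicit ones/zeros.
import torch
import torch.nn.functional as F

x = torch.randn(8, 6, 25, 25)
shape = (25, 25)

out_none = F.layer_norm(x, shape, weight=None, bias=None, eps=1e-2)
out_explicit = F.layer_norm(x, shape, weight=torch.ones(shape),
                            bias=torch.zeros(shape), eps=1e-2)
assert torch.allclose(out_none, out_explicit)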
14 changes: 14 additions & 0 deletions tests/dynamo/converters/test_norm_ops_dynamo.py
@@ -94,3 +94,17 @@ def test_layernorm(op_alias, ln):

    mgx_mod = convert_to_mgx(mod, [inp])
    verify_outputs(mod, mgx_mod, inp)


@pytest.mark.parametrize('op_alias',
                         [torch.ops.aten.native_layer_norm.default])
@pytest.mark.parametrize('ln', [
    torch.nn.LayerNorm((25, 25), 1e-2).cuda().eval(),
])
def test_layernorm_defaults(op_alias, ln):
    inp = torch.randn(8, 6, 25, 25).cuda()
    norm_shape, eps = ln.normalized_shape, ln.eps
    mod = FuncModuleFirstOut(op_alias, norm_shape, None, None, eps)

    mgx_mod = convert_to_mgx(mod, [inp])
    verify_outputs(mod, mgx_mod, inp)
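FuncModuleFirstOut comes from the repository's test utilities and is not shown in this diff. A hypothetical sketch of such a wrapper (an assumption for illustration, not the repo's actual implementation): aten.native_layer_norm returns an (output, mean, rstd) tuple, so the module applies the op with the stored arguments and keeps only the first output for comparison:

# Hypothetical FuncModuleFirstOut-style wrapper (assumption, not the repo's code).
import torch

class FuncModuleFirstOut(torch.nn.Module):
    def __init__(self, func, *args):
        super().__init__()
        self.func = func
        self.args = args

    def forward(self, x):
        # e.g. torch.ops.aten.native_layer_norm.default(x, norm_shape, None, None, eps)
        return self.func(x, *self.args)[0]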
