From ac795da51046ceb5ead683d0745cb151597b1c61 Mon Sep 17 00:00:00 2001
From: Shiv
Date: Tue, 12 Nov 2024 11:18:35 -0800
Subject: [PATCH] layernorm fix None inputs

---
 .../fx/converters/acc_ops_converters.py         |  8 ++++++++
 tests/dynamo/converters/test_norm_ops_dynamo.py | 14 ++++++++++++++
 2 files changed, 22 insertions(+)

diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py
index 8a93263b..75c3af1c 100644
--- a/py/torch_migraphx/fx/converters/acc_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -2056,6 +2056,14 @@ def acc_ops_layer_norm(mgx_module, node, args, kwargs):
     normalized_shape = kwargs['normalized_shape']
     weight = kwargs['weight']
     bias = kwargs['bias']
+
+    dtype = get_arg_dtype(inp.instr_ref)
+    if weight is None:
+        weight = MGXInstruction(
+            mgx_module.add_literal(torch.tensor(1, dtype=dtype).numpy()))
+    if bias is None:
+        bias = MGXInstruction(
+            mgx_module.add_literal(torch.tensor(0, dtype=dtype).numpy()))
 
     assert all(not i.is_quantized() for i in (inp, weight, bias))
     inp, weight, bias = inp.instr_ref, weight.instr_ref, bias.instr_ref
diff --git a/tests/dynamo/converters/test_norm_ops_dynamo.py b/tests/dynamo/converters/test_norm_ops_dynamo.py
index f8f0908f..c517f03f 100644
--- a/tests/dynamo/converters/test_norm_ops_dynamo.py
+++ b/tests/dynamo/converters/test_norm_ops_dynamo.py
@@ -94,3 +94,17 @@ def test_layernorm(op_alias, ln):
 
     mgx_mod = convert_to_mgx(mod, [inp])
     verify_outputs(mod, mgx_mod, inp)
+
+
+@pytest.mark.parametrize('op_alias',
+                         [torch.ops.aten.native_layer_norm.default])
+@pytest.mark.parametrize('ln', [
+    torch.nn.LayerNorm((25, 25), 1e-2).cuda().eval(),
+])
+def test_layernorm_defaults(op_alias, ln):
+    inp = torch.randn(8, 6, 25, 25).cuda()
+    norm_shape, eps = ln.normalized_shape, ln.eps
+    mod = FuncModuleFirstOut(op_alias, norm_shape, None, None, eps)
+
+    mgx_mod = convert_to_mgx(mod, [inp])
+    verify_outputs(mod, mgx_mod, inp)
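
As a reference for the semantics this change targets: in plain PyTorch, aten.native_layer_norm with weight=None and bias=None behaves the same as passing an all-ones weight and an all-zeros bias over normalized_shape, which is what the literal defaults in the converter emulate. A minimal CPU-only sketch, no torch_migraphx APIs assumed; shapes and eps mirror the new test:

import torch

inp = torch.randn(8, 6, 25, 25)
normalized_shape = (25, 25)
eps = 1e-2

# None weight/bias: the op falls back to its implicit defaults.
out_default, mean, rstd = torch.ops.aten.native_layer_norm.default(
    inp, normalized_shape, None, None, eps)

# Explicit all-ones weight and all-zeros bias over normalized_shape.
weight = torch.ones(normalized_shape)
bias = torch.zeros(normalized_shape)
out_explicit, _, _ = torch.ops.aten.native_layer_norm.default(
    inp, normalized_shape, weight, bias, eps)

assert torch.allclose(out_default, out_explicit)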