From d65bb6be59621655561a431bd48bd45ddd644b3a Mon Sep 17 00:00:00 2001
From: Songhao Jia
Date: Mon, 23 Oct 2023 12:33:24 -0700
Subject: [PATCH] suppress pyre errors in xnnpack/fuse_batch_norm

Summary:
There were multiple Pyre errors in the fuse_batch_norm pass before this
diff. This diff adds several `assert ... is not None` checks to narrow
the Optional values returned by get_param_tensor and suppress those
errors.

Before: {F1128589876}

After: {F1128590200}

Reviewed By: digantdesai

Differential Revision: D50563193

fbshipit-source-id: 5843d32e9b8724399c4aaedbacd31b918d8d5748
---
 backends/xnnpack/passes/fuse_batch_norm_with_conv.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/backends/xnnpack/passes/fuse_batch_norm_with_conv.py b/backends/xnnpack/passes/fuse_batch_norm_with_conv.py
index 9de71b08c8..f87e8a1a80 100644
--- a/backends/xnnpack/passes/fuse_batch_norm_with_conv.py
+++ b/backends/xnnpack/passes/fuse_batch_norm_with_conv.py
@@ -66,7 +66,10 @@ def call(self, graph_module: torch.fx.GraphModule):
 
             # Get the parameters from conv op
             assert len(conv.args) == 9
+
             conv_weight = get_param_tensor(self.exported_program, conv.args[1])
+            assert conv_weight is not None
+
             conv_bias = get_param_tensor(self.exported_program, conv.args[2])
 
             # Get the parameters from the batchnorm op
@@ -80,8 +83,12 @@ def call(self, graph_module: torch.fx.GraphModule):
             )
             bn_weight = get_param_tensor(self.exported_program, bn.args[1])
             bn_bias = get_param_tensor(self.exported_program, bn.args[2])
+
             running_mean = get_param_tensor(self.exported_program, bn.args[3])
+            assert running_mean is not None
+
             running_var = get_param_tensor(self.exported_program, bn.args[4])
+            assert running_var is not None
 
             # args[7] for native_batch_norm, but args[6] for
             # _native_batch_norm_legit_no_training (which doesn't have training
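
For reference, a minimal sketch of the type-narrowing pattern this diff
relies on. The helper name load_param and the tensor shapes below are
made up for illustration; only the Optional return type of a
get_param_tensor-style helper is taken from the diff. Pyre types the
result as Optional[torch.Tensor], and the assert narrows it to
torch.Tensor for every line that follows:

    from typing import Optional

    import torch


    def load_param(name: str) -> Optional[torch.Tensor]:
        # Stand-in for get_param_tensor: may return None when the
        # argument does not resolve to a parameter.
        return torch.ones(3) if name == "weight" else None


    def scale(name: str) -> torch.Tensor:
        weight = load_param(name)
        # Without this assert, Pyre flags the call below as
        # "Optional[Tensor] has no attribute `mul`" -- the class of
        # error this diff suppresses. The assert narrows the type of
        # `weight` from Optional[Tensor] to Tensor.
        assert weight is not None
        return weight.mul(2.0)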