suppress pyre errors in xnnpack/fuse_batch_norm
Summary:
There were multiple Pyre errors in the fuse_batch_norm pass before this diff.

This diff adds several assertions to suppress them; the pattern is sketched after the screenshots below.

Before:
 {F1128589876}

After:
 {F1128590200}
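For context, the suppression works by type narrowing rather than by ignore comments: Pyre treats "assert x is not None" as a type refinement, so code after the assert sees x as a plain Tensor instead of Optional[Tensor]. A minimal, self-contained sketch of the pattern (the function and names are illustrative, not taken from the diff):

from typing import Optional

import torch


def require_tensor(maybe_tensor: Optional[torch.Tensor]) -> torch.Tensor:
    # Without the assert, Pyre reports an incompatible return type:
    # Optional[torch.Tensor] is not torch.Tensor.
    assert maybe_tensor is not None  # narrows Optional[Tensor] to Tensor
    return maybe_tensor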

Reviewed By: digantdesai

Differential Revision: D50563193

fbshipit-source-id: 5843d32e9b8724399c4aaedbacd31b918d8d5748
Gasoonjia authored and facebook-github-bot committed Oct 23, 2023
1 parent f4d9c4a commit d65bb6b
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions backends/xnnpack/passes/fuse_batch_norm_with_conv.py
@@ -66,7 +66,10 @@ def call(self, graph_module: torch.fx.GraphModule):
 
             # Get the parameters from conv op
             assert len(conv.args) == 9
+
             conv_weight = get_param_tensor(self.exported_program, conv.args[1])
+            assert conv_weight is not None
+
             conv_bias = get_param_tensor(self.exported_program, conv.args[2])
 
             # Get the parameters from the batchnorm op
@@ -80,8 +83,12 @@ def call(self, graph_module: torch.fx.GraphModule):
             )
             bn_weight = get_param_tensor(self.exported_program, bn.args[1])
             bn_bias = get_param_tensor(self.exported_program, bn.args[2])
+
             running_mean = get_param_tensor(self.exported_program, bn.args[3])
+            assert running_mean is not None
+
             running_var = get_param_tensor(self.exported_program, bn.args[4])
+            assert running_var is not None
 
             # args[7] for native_batch_norm, but args[6] for
             # _native_batch_norm_legit_no_training (which doesn't have training
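For reference, the parameters fetched above (conv weight and bias, batch-norm gamma and beta, running mean and variance) feed the standard batch-norm-into-conv fusion that this pass performs. A hedged sketch of the arithmetic only, assuming NCHW conv weights; the function name and signature are illustrative, not the pass's actual code:

import torch


def fuse_conv_bn_params(
    conv_weight: torch.Tensor,   # (out_channels, in_channels, kH, kW)
    conv_bias: torch.Tensor,     # (out_channels,)
    bn_weight: torch.Tensor,     # gamma, (out_channels,)
    bn_bias: torch.Tensor,       # beta, (out_channels,)
    running_mean: torch.Tensor,  # (out_channels,)
    running_var: torch.Tensor,   # (out_channels,)
    eps: float = 1e-5,
) -> tuple[torch.Tensor, torch.Tensor]:
    # BN(conv(x)) = gamma * (conv(x) - mean) / sqrt(var + eps) + beta,
    # which folds into the conv as a per-output-channel rescale and shift:
    #   W' = W * gamma / sqrt(var + eps)
    #   b' = (b - mean) * gamma / sqrt(var + eps) + beta
    scale = bn_weight / torch.sqrt(running_var + eps)
    fused_weight = conv_weight * scale.reshape(-1, 1, 1, 1)
    fused_bias = (conv_bias - running_mean) * scale + bn_bias
    return fused_weight, fused_bias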
