Skip to content

Commit

Permalink
Fix pyre
Browse files Browse the repository at this point in the history
Differential Revision: D66787624

Pull Request resolved: pytorch#7185
  • Loading branch information
kirklandsign authored Dec 5, 2024
1 parent 047fd37 commit 8861b9a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
f"Out of bounds index {key} for getting value in args (of size {len(args)})"
)
elif isinstance(key, str):
return args.get(key, default_value)
return args.get(key, default_value) # pyre-ignore[16]
elif isclass(key):
for arg in args:
if isinstance(arg, key):
Expand Down
9 changes: 7 additions & 2 deletions backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,17 @@ def call(self, graph_module: torch.fx.GraphModule):
continue

sum_node = cast(torch.fx.Node, node)
keep_dim = get_node_arg(sum_node.args, keep_dim_index, False)
keep_dim = get_node_arg(
# pyre-ignore[6]
sum_node.args,
keep_dim_index,
False,
)

if keep_dim:
continue

dim_list = get_node_arg(sum_node.args, 1, [0])
dim_list = get_node_arg(sum_node.args, 1, [0]) # pyre-ignore[6]

# Add keep_dim = True arg to sum node.
set_node_arg(sum_node, 2, True)
Expand Down

0 comments on commit 8861b9a

Please sign in to comment.