Skip to content

Commit

Permalink
Change custom_skip_targets meaning for constant_prop_pass (pytorch#3491)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3491

Some users of `constant_prop_pass` want to fold across calls to
`full`, because representing a tensor as a program constant is a requirement
for some backends.
This came up when writing some tests using `torch.ones` as a weight tensor,
which is represented as `aten.full` in Edge Dialect.

When the user specifies a custom skip set, do *not* add the default `aten.full`
function, in case the user doesn't want it.

Reviewed By: angelayi

Differential Revision: D56894215

fbshipit-source-id: 24b7f570ce41576650c457586fc5540371889121
  • Loading branch information
dulinriley authored and facebook-github-bot committed May 10, 2024
1 parent 1871ec1 commit b93b7ae
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions exir/passes/constant_prop_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ def get_propagated_const_tensor_dict(
# Initialize dict with all constant placeholders.
const_node_to_tensor = get_constant_placeholder_dict(exported_program)

all_skip_targets: set[EdgeOpOverload] = set()
# Default set of targets to skip.
all_skip_targets.update(_DEFAULT_SKIP_TARGETS)
if custom_skip_targets is not None:
all_skip_targets.update(custom_skip_targets)
all_skip_targets = custom_skip_targets
else:
# Default set of targets to skip.
all_skip_targets = _DEFAULT_SKIP_TARGETS

for node in exported_program.graph.nodes:
if node.op != "call_function" or node.target in all_skip_targets:
Expand Down

0 comments on commit b93b7ae

Please sign in to comment.