Commit d4f6bef

Add implicit optional unwrapping (pytorch#15587)
Summary:
Add type-inference support for Optional type refinement.

If a conditional has the form "x is None" or "x is not None", or is a boolean expression containing multiple None checks, the appropriate type refinements are inserted in each branch.

For example:
if optional_tensor is not None and len(optional_tensor) < 2:
	# optional_tensor is a Tensor

if optional_tensor1 is not None and optional_tensor2 is not None:
	# both optional_tensor1 and optional_tensor2 are Tensors
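
A minimal runnable sketch of the new behavior, in the style of the tests below (the function name is illustrative, and this assumes a build that contains this change):

import torch
from typing import Optional

@torch.jit.script
def add_one_or_default(x, default):
    # type: (Optional[int], int) -> int
    if x is not None and x < 2:
        # `x` is refined from Optional[int] to int in this branch, so no
        # explicit torch.jit._unwrap_optional call is needed.
        return x + 1
    return default

print(add_one_or_default(1, 0))     # 2
print(add_one_or_default(None, 0))  # 0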

TODO:

- do not run an op for prim::unchecked_unwrap_optional in the interpreter

- potentially refine types to prim::None (omitted for now to simplify things & because it's not an actual use case).
Pull Request resolved: pytorch#15587

Differential Revision: D13733810

Pulled By: eellison

fbshipit-source-id: 57c32be9f5a09ab5542ba0144a6059b96de23d7a
Elias Ellison authored and facebook-github-bot committed Jan 18, 2019
1 parent da578b7 commit d4f6bef
Showing 6 changed files with 326 additions and 16 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -88,6 +88,7 @@ namespace c10 {
 _(aten, index_put_) \
 _(aten, device) \
 _(aten, len) \
+_(prim, unchecked_unwrap_optional)\
 FORALL_ATEN_BASE_SYMBOLS(_) \
 _(onnx, Add) \
 _(onnx, Concat) \
13 changes: 7 additions & 6 deletions test/expect/TestScript.test_if_is_none_dispatch.expect
@@ -6,17 +6,18 @@ graph(%input : Tensor
 %5 : int = prim::Constant[value=4]()
 %x.1 : Tensor = aten::add(%input, %4, %3)
 %7 : bool = aten::__isnot__(%opt.1, %2)
-%opt : Tensor?, %x.3 : Tensor = prim::If(%7)
+%opt.4 : Tensor?, %x.3 : Tensor = prim::If(%7)
   block0() {
-    %opt.2 : Tensor = aten::_unwrap_optional(%opt.1)
-    %x.2 : Tensor = aten::add(%opt.2, %x.1, %3)
-    -> (%opt.2, %x.2)
+    %opt.2 : Tensor = prim::unchecked_unwrap_optional(%opt.1)
+    %opt.3 : Tensor = aten::_unwrap_optional(%opt.2)
+    %x.2 : Tensor = aten::add(%opt.3, %x.1, %3)
+    -> (%opt.3, %x.2)
   }
   block1() {
     -> (%opt.1, %x.1)
   }
-%12 : bool = aten::__is__(%opt, %2)
-%x : Tensor = prim::If(%12)
+%13 : bool = aten::__is__(%opt.4, %2)
+%x : Tensor = prim::If(%13)
   block0() {
     %x.4 : Tensor = aten::add(%x.3, %5, %3)
     -> (%x.4)
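
One can observe this graph shape locally by scripting a function like the one behind this expect file and printing its IR. A sketch under assumptions: the exact source of test_if_is_none_dispatch is not shown on this page, and this requires a build containing this commit:

import torch
from typing import Optional

@torch.jit.script
def if_is_none_dispatch(input, opt):
    # type: (Tensor, Optional[Tensor]) -> Tensor
    x = input + 4
    if opt is not None:
        x = opt + x  # `opt` is implicitly unwrapped in this branch
    if opt is None:
        x = x + 4
    return x

# The printed IR should show prim::unchecked_unwrap_optional feeding
# aten::_unwrap_optional inside the first If block, as in the diff above.
print(if_is_none_dispatch.graph)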
101 changes: 101 additions & 0 deletions test/test_jit.py
@@ -4123,6 +4123,107 @@ def test_while(a, b):
                 return a + b
             ''')
 
+    def test_optional_refinement(self):
+        @torch.jit.script
+        def test_if_none_assignment(x):
+            # type: (Optional[int]) -> int
+            if x is None:
+                x = 1
+            return x + 1
+
+        self.assertEqual(test_if_none_assignment(1), 2)
+
+        @torch.jit.script
+        def test_ternary(x):
+            # type: (Optional[int]) -> int
+            x = x if x is not None else 2
+            return x
+
+        @torch.jit.script
+        def test_not_none(x):
+            # type: (Optional[int]) -> None
+            if x is not None:
+                print(x + 1)
+
+        @torch.jit.script
+        def test_and(x, y):
+            # type: (Optional[int], Optional[int]) -> None
+            if x is not None and y is not None:
+                print(x + y)
+
+        @torch.jit.script
+        def test_not(x, y):
+            # type: (Optional[int], Optional[int]) -> None
+            if not (x is not None and y is not None):
+                pass
+            else:
+                print(x + y)
+
+        @torch.jit.script
+        def test_bool_expression(x):
+            # type: (Optional[int]) -> None
+            if x is not None and x < 2:
+                print(x + 1)
+
+        @torch.jit.script
+        def test_nested_bool_expression(x, y):
+            # type: (Optional[int], Optional[int]) -> int
+            if x is not None and x < 2 and y is not None:
+                x = x + y
+            else:
+                x = 5
+            return x + 2
+
+        @torch.jit.script
+        def test_or(x, y):
+            # type: (Optional[int], Optional[int]) -> None
+            if y is None or x is None:
+                pass
+            else:
+                print(x + y)
+
+        # backwards compatibility
+        @torch.jit.script
+        def test_manual_unwrap_opt(x):
+            # type: (Optional[int]) -> int
+            if x is None:
+                x = 1
+            else:
+                x = torch.jit._unwrap_optional(x)
+            return x
+
+        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+            @torch.jit.script
+            def or_error(x, y):
+                # type: (Optional[int], Optional[int]) -> int
+                if x is None or y is None:
+                    print(x + y)
+
+        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+            @torch.jit.script
+            def and_error(x, y):
+                # type: (Optional[int], Optional[int]) -> int
+                if x is None and y is None:
+                    pass
+                else:
+                    print(x + y)
+
+        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+            @torch.jit.script
+            def named_var(x):
+                # type: (Optional[int]) -> None
+                x_none = x is not None
+                if x_none:
+                    print(x + 1)
+
+        with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
+            @torch.jit.script
+            def named_var_and(x, y):
+                # type: (Optional[int], Optional[int]) -> None
+                x_none = x is not None
+                if y is not None and x_none:
+                    print(x + y)
+
     def test_while_write_outer_then_read(self):
         def func(a, b):
             while bool(a < 10):
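
One subtlety these tests pin down: "x is None or y is None" refines x and y only in the else branch, since in the true branch at least one operand may still be None, so `x + y` has no valid overload there. A sketch of observing that failure (the error text is the regex from the test above; signature simplified to return None; assumes this build):

import torch
from typing import Optional

try:
    @torch.jit.script
    def or_error(x, y):
        # type: (Optional[int], Optional[int]) -> None
        if x is None or y is None:
            print(x + y)  # neither operand is refined in this branch
except RuntimeError as e:
    print(e)  # expected to mention "arguments for call are not valid"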
1 change: 1 addition & 0 deletions torch/csrc/jit/passes/constant_propagation.cpp
@@ -19,6 +19,7 @@ std::unordered_set<Symbol> skip_list = {
     prim::Loop, // TODO: handle Loop
     prim::Constant,
     prim::Undefined,
+    prim::unchecked_unwrap_optional, //TODO remove
     prim::None, // it is already a constant and propagating it will lose
                 // important type information about which Optional type it is
     // TODO (zach): we should consider skipping tensor factories in the cases
4 changes: 4 additions & 0 deletions torch/csrc/jit/register_prim_ops.cpp
@@ -799,6 +799,10 @@ RegisterOperators reg({
         return 0;
       };
     }),
+    // This op can be removed in preprocessing before being run in the interpreter
+    // (but is currently not removed); even when it is removed, it needs to remain
+    // a registered op so that constant prop can run.
+    Operator("prim::unchecked_unwrap_optional(t(a)? optional) -> t(a)", noop),
    Operator(
        prim::fork,
        [](const Node* node) {
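
In the registered schema, the `(a)` alias annotation marks the output as possibly aliasing the input, and the `noop` implementation leaves the value in place on the stack. Since the interpreter currently still executes the op (see the TODO in the summary), runtime behavior matches an ordinary unwrap. A sketch of exercising it end to end (illustrative name; assumes this build):

import torch
from typing import Optional

@torch.jit.script
def unwrap_or_zero(x):
    # type: (Optional[int]) -> int
    if x is not None:
        return x + 1  # compiled via the implicit (unchecked) unwrap
    return 0

assert unwrap_or_zero(3) == 4
assert unwrap_or_zero(None) == 0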
(diff for the sixth changed file not shown)
