diff --git a/exir/passes/sym_shape_eval_pass.py b/exir/passes/sym_shape_eval_pass.py index c1e3cdfe5d..2b6b0e27d2 100644 --- a/exir/passes/sym_shape_eval_pass.py +++ b/exir/passes/sym_shape_eval_pass.py @@ -259,6 +259,6 @@ def call(self, graph_module: GraphModule): "Please use export's constrain_as_size() or constrain_as_value() apis and set a concrete upper bound to resolve this." ) - spec.shape = concrete_shape # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[Optional[int]]` - spec.stride = concrete_stride # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[Optional[int]]` + spec.shape = concrete_shape + spec.stride = concrete_stride # pyre-ignore[8]: Attribute `stride` declared in class `TensorSpec` has type `Tuple[int]` but is used as type `List[int]` return PassResult(graph_module, True) diff --git a/exir/sym_util.py b/exir/sym_util.py index 0fd7177af2..2a55d51a81 100644 --- a/exir/sym_util.py +++ b/exir/sym_util.py @@ -29,7 +29,7 @@ def eval_expr(symint: Union[int, torch.SymInt]) -> Optional[int]: return int(output) -def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> Optional[int]: +def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> int: """ Evaluate a symint to its uppper bound value. Returns None if symint's symoblic expr's upper bound can not be evaluated to valid integer according to the constraints in shape_env. @@ -41,17 +41,24 @@ def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> Optional[int]: expr = node.expr var_range: ValueRanges = bound_sympy(expr, shape_env.var_to_range) upper_bound = var_range.upper + # This import is needed temporarily until we update the pinned torch version. + + try: + from torch.utils._sympy.numbers import int_oo # @manual # pyre-ignore + except ImportError: + int_oo = None + if isinstance(upper_bound, sympy.Integer): concrete_upper = int(var_range.upper) assert isinstance( concrete_upper, int ), f"Expect upper bound to be a concrete int but got {concrete_upper}" return concrete_upper - elif isinstance(upper_bound, sympy.oo): - return None + elif int_oo is not None and upper_bound is int_oo: # pyre-ignore + return int_oo # pyre-ignore else: raise RuntimeError( - f"Expect upper bound to be sympy.Integer or sympy.oo. but got {upper_bound}" + f"Expect upper bound to be sympy.Integer or int_oo. but got {upper_bound}" )