diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 4053bb2a1..c9a61475f 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -242,10 +242,12 @@ def _get_numpy_value(val: ir.Value | None) -> np.ndarray | None: const_value = val.const_value if const_value is not None: try: - return const_value.numpy() + array = const_value.numpy() except FileNotFoundError: # External data is not available. return None + assert isinstance(array, np.ndarray) + return array return None @@ -255,14 +257,7 @@ def _get_bool_value(val: ir.Value | None) -> bool | None: value = _get_numpy_value(val) if value is None: return None - # TODO: cleanup following checks, which seem redundant. But need to also ensure - # the invariant when setting the value (and also use clearly defined representation - # types in evaluators, such a reference-evaluator). - if isinstance(value, bool): - return value - if isinstance(value, np.bool_): - return bool(value) - if isinstance(value, np.ndarray) and value.size == 1 and value.dtype == bool: + if value.size == 1 and value.dtype is bool: return value.item(0) return None diff --git a/onnxscript/rewriter/broadcast_to_matmul.py b/onnxscript/rewriter/broadcast_to_matmul.py index df216d977..4ce77c855 100644 --- a/onnxscript/rewriter/broadcast_to_matmul.py +++ b/onnxscript/rewriter/broadcast_to_matmul.py @@ -55,7 +55,7 @@ def check_if_not_need_reshape( return False input_a_shape = input_a_shape.numpy() # type: ignore[assignment] input_b_shape = input_b_shape.numpy() # type: ignore[assignment] - shape_c = shape_c_tensor.numpy().tolist() + shape_c = shape_c_tensor.numpy().tolist() # type: ignore[assignment] a_rank = len(input_a_shape) b_rank = len(input_b_shape) diff --git a/requirements-dev.txt b/requirements-dev.txt index 103fab8ab..466de2c71 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ setuptools>=61.0.0 -numpy +numpy<2.2 onnx-weekly>=1.17.0.dev20240325 onnxruntime>=1.17.0 typing_extensions