Skip to content

Commit

Permalink
numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jan 2, 2025
1 parent ff7a0a2 commit 06d556e
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 11 deletions.
13 changes: 4 additions & 9 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,12 @@ def _get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
const_value = val.const_value
if const_value is not None:
try:
return const_value.numpy()
array = const_value.numpy()
except FileNotFoundError:
# External data is not available.
return None
assert isinstance(array, np.ndarray)
return array
return None


Expand All @@ -255,14 +257,7 @@ def _get_bool_value(val: ir.Value | None) -> bool | None:
value = _get_numpy_value(val)
if value is None:
return None
# TODO: cleanup following checks, which seem redundant. But need to also ensure
# the invariant when setting the value (and also use clearly defined representation
# types in evaluators, such a reference-evaluator).
if isinstance(value, bool):
return value
if isinstance(value, np.bool_):
return bool(value)
if isinstance(value, np.ndarray) and value.size == 1 and value.dtype == bool:
if value.size == 1 and value.dtype is bool:
return value.item(0)
return None

Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/broadcast_to_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def check_if_not_need_reshape(
return False
input_a_shape = input_a_shape.numpy() # type: ignore[assignment]
input_b_shape = input_b_shape.numpy() # type: ignore[assignment]
shape_c = shape_c_tensor.numpy().tolist()
shape_c = shape_c_tensor.numpy().tolist() # type: ignore[assignment]

a_rank = len(input_a_shape)
b_rank = len(input_b_shape)
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
setuptools>=61.0.0
numpy
numpy<2.2
onnx-weekly>=1.17.0.dev20240325
onnxruntime>=1.17.0
typing_extensions
Expand Down

0 comments on commit 06d556e

Please sign in to comment.