Merge branch 'main' into rama/pattern-ext
gramalingam authored Jan 9, 2025
2 parents 267f578 + a942e95 commit 2bc0450
Showing 2 changed files with 222 additions and 18 deletions.
169 changes: 151 additions & 18 deletions onnxscript/optimizer/_constant_folding.py
@@ -133,6 +133,18 @@ class Replacement:
new_nodes: Sequence[ir.Node]


# The optimizer tracks an optional symbolic value for each value in the model.
# The symbolic value attached to a value X can be:
# - another IR value Y (indicating that X is equal to Y)
# - a list of IR values [Y1, Y2, ...] (indicating that X is a sequence of values Y1, Y2, ...)
# - a Shape object (indicating that X is a shape value)
# A Shape object as a symbolic value indicates that the corresponding value is
# a 1-D (or 0-D) tensor of INT64 values. The entries in this object may be integer
# constants or symbolic dimension values (like "batch_size", "sequence_length", etc.).
# Currently, symbolic dimensions are assumed to be non-negative.
# TODO: Add support for negative symbolic dimensions.
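For illustration, a minimal sketch of the three kinds of symbolic values (the values x and y here are hypothetical; Value and Shape come from the onnxscript IR):

from onnxscript import ir

state = OptimizerState()
x = ir.Value(name="x")
y = ir.Value(name="y")

# Kind 1: x is known to be equal to another IR value y.
state.set_sym_value(x, y)

# Kind 2: x is a sequence of IR values.
state.set_sym_value(x, [y])

# Kind 3: x is a 1-D INT64 tensor holding the shape [batch_size, 1024],
# where "batch_size" is a symbolic (assumed non-negative) dimension.
state.set_sym_value(x, ir.Shape(["batch_size", 1024]))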


class OptimizerState:
def __init__(self):
self._sym_value_map: dict[ir.Value, Any] = {}
@@ -159,6 +171,18 @@ def add_initializer_input(self, value: ir.Value) -> None:
def is_initializer_input(self, value: ir.Value) -> bool:
return any(value in inputs for inputs in self._initializer_inputs)

def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
if const_value is not None:
if const_value.ndim == 1:
return ir.Shape(const_value.tolist())
return None
sym_value = self.get_sym_value(value)
if isinstance(sym_value, ir.Shape):
return sym_value
# TODO use shape of value if available
return None
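A usage sketch with hypothetical values (ir.tensor is assumed here as the IR's tensor-construction convenience): a small constant 1-D INT64 tensor and a previously tracked symbolic shape both resolve to an ir.Shape.

import numpy as np
from onnxscript import ir

v1 = ir.Value(name="v1")
v1.const_value = ir.tensor(np.array([2, 3], dtype=np.int64))
state.get_shape_value(v1)  # -> ir.Shape with dims (2, 3)

v2 = ir.Value(name="v2")  # no constant value, but a tracked symbolic shape
state.set_sym_value(v2, ir.Shape(["batch_size", 1024]))
state.get_shape_value(v2)  # -> ir.Shape with dims (batch_size, 1024)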


# The "partial evaluators" below are non-standard evaluators. They are used to perform
# partial evaluation and/or static program analysis (abstract interpretation).
@@ -235,11 +259,33 @@ def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction:
register = registry.register


def _get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
def _same_shape(shape1: ir.Shape, shape2: ir.Shape) -> bool:
# Shapes can be compared as tuples of dims, except that a dimension whose
# value is None represents an unknown dimension, and two unknown dimensions
# are never known to be equal. Thus (Batch, 1024) and (Batch, 1024) are
# considered equal (the named dim "Batch" matches itself), but (None, 1024)
# and (None, 1024) are not.
if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in shape1):
return False
return shape1.dims == shape2.dims
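A sketch of the distinction: named symbolic dimensions compare by name, while unnamed (None) dimensions are never provably equal.

from onnxscript import ir

s1 = ir.Shape(["Batch", 1024])
s2 = ir.Shape(["Batch", 1024])
_same_shape(s1, s2)  # True: "Batch" denotes the same dimension in both shapes.

u1 = ir.Shape([None, 1024])
u2 = ir.Shape([None, 1024])
_same_shape(u1, u2)  # False: two unknown dimensions need not be equal.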


def _get_numpy_value(
val: ir.Value | None, dtype: ir.DataType | None = None, size_limit: int | None = None
) -> np.ndarray | None:
"""Returns the numpy value of a constant value, if available.
Returns None if the value is not a constant, if its element dtype differs from
the specified dtype, or if its size exceeds the specified size_limit.
"""
if val is None:
return None
const_value = val.const_value
if const_value is not None:
if dtype is not None and const_value.dtype != dtype:
return None
if size_limit is not None and const_value.size > size_limit:
return None
try:
array = const_value.numpy()
except FileNotFoundError:
@@ -256,7 +302,7 @@ def _get_bool_value(val: ir.Value | None) -> bool | None:
value = _get_numpy_value(val)
if value is None:
return None
if value.size == 1 and value.dtype == np.bool_:
if value.size == 1 and value.dtype == bool:
return value.item(0)
return None

@@ -300,6 +346,54 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
return default


@register("Abs")
def abs(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace an Abs node by Identity when applicable.
Currently, this handles Abs applied to symbolic shape values.
"""
input = _get_input(node, 0)
input_sym_value = state.get_shape_value(input)
if input_sym_value is None:
return None
if any(isinstance(d, int) and d < 0 for d in input_sym_value):
return None
# E.g., Abs applied to a symbolic shape of the form [1, 1, SequenceLength]:
# since symbolic dimensions (like SequenceLength) are assumed non-negative
# and no known dimension is negative, the Abs op is redundant.
return op.Identity(input)
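A sketch of the guard on a hypothetical tracked shape: only a dimension known to be a negative integer blocks the rewrite, since symbolic dimensions are assumed non-negative.

sym_shape = ir.Shape(["B", 1, 256])
blocked = any(isinstance(d, int) and d < 0 for d in sym_shape)
# blocked is False here, so Abs(input) is rewritten to Identity(input).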


@register("Gather")
def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Gather node by a constant when applicable.
Currently, this handles gathering from a shape tensor.
"""
input = _get_input(node, 0)
indices = _get_input(node, 1)
if input is None or indices is None:
return None
input_sym_value = state.get_shape_value(input)
if input_sym_value is None:
return None
axis = _get_int_attribute(node, "axis", None)
if axis != 0:
return None
indices_numpy_value = _get_numpy_value(indices)
if indices_numpy_value is None:
return None
if indices_numpy_value.ndim != 1:
return None
gathered = [input_sym_value[i] for i in indices_numpy_value]
output = _get_output(node, 0)
if output is not None:
state.set_sym_value(output, ir.Shape(gathered))
if all(isinstance(d, int) for d in gathered):
return op.Constant(value_ints=gathered)
return None
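A sketch with a hypothetical tracked shape [B, 128]: gathering index 0 yields [B]; because B is symbolic, no Constant is emitted, but the recorded symbolic value lets downstream rewrites (e.g. Reshape-to-Identity) succeed.

import numpy as np

input_sym_value = ir.Shape(["B", 128])
indices_numpy_value = np.array([0], dtype=np.int64)
gathered = [input_sym_value[i] for i in indices_numpy_value]  # [SymbolicDim("B")]
all(isinstance(d, int) for d in gathered)  # False -> tracked via set_sym_value only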


@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Reshape node by Identity when applicable."""
@@ -310,15 +404,16 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input_shape = input.shape
if input_shape is None:
return None
input_shape_dims = list(input_shape.dims)
if any(not isinstance(dim, int) for dim in input_shape_dims):
return None
shape_value = _get_numpy_value(shape)
# input_shape_dims = list(input_shape.dims)
# if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in input_shape_dims):
# return None
shape_value = state.get_shape_value(shape)
if shape_value is None:
return None
target_shape_dims = shape_value.tolist()
if input_shape_dims == target_shape_dims:
# No need to check for special values like -1, 0, etc. here
# target_shape_dims = list(shape_value.dims)
# if input_shape_dims == target_shape_dims:
# No need to check for special values like -1, 0, etc. here
if _same_shape(input_shape, shape_value):
return op.Identity(input)
return None

@@ -373,6 +468,9 @@ def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
start = _get_int_attribute(node, "start", 0)
end = _get_int_attribute(node, "end", None)
shape_slice = shape[start:end]
output = _get_output(node, 0)
if output is not None:
state.set_sym_value(output, ir.Shape(shape_slice))
if all(isinstance(d, int) for d in shape_slice):
return op.Constant(value_ints=list(shape_slice))
return None
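For example, for an input x of shape (B, 256), Shape <start=0, end=1> (x) slices out (B,): not all ints, so no Constant is produced, but the symbolic value [B] is recorded. With a fully static input shape (128, 256), the node folds to Constant <value_ints=[128, 256]>. A sketch of the slicing step:

shape = ir.Shape(["B", 256])
shape_slice = shape[0:1]  # (SymbolicDim("B"),)
all(isinstance(d, int) for d in shape_slice)  # False -> symbolic tracking only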
@@ -459,6 +557,19 @@ def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
inputs = node.inputs
if len(inputs) == 1:
return op.Identity(inputs[0])
# Track value of tensors that carry a shape value:
output = node.outputs[0]
if output is None:
return None
# Check axis attribute is 0
axis = _get_int_attribute(node, "axis", None)
if axis != 0:
return None
shapes = [state.get_shape_value(input) for input in inputs]
if any(shape is None for shape in shapes):
return None
concatenated = ir.Shape(dim for shape in shapes for dim in shape.dims) # type: ignore[union-attr]
state.set_sym_value(output, concatenated)
return None
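For example, concatenating a tracked [B] with a constant [256] along axis 0 records [B, 256] as the output's symbolic value; this is the link that lets the Reshape and Expand rewrites recognize identities. A sketch mirroring the code above:

shapes = [ir.Shape(["B"]), ir.Shape([256])]
concatenated = ir.Shape(dim for shape in shapes for dim in shape.dims)
# concatenated has dims (B, 256); it is recorded via state.set_sym_value(output, ...).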


@@ -507,7 +618,10 @@ def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
return None
if (expanded_shape := _get_numpy_value(node.inputs[1])) is None:
# Target shape is not known.
return None
expanded_sym_shape = state.get_shape_value(node.inputs[1])
if expanded_sym_shape is None or not _same_shape(input_shape, expanded_sym_shape):
return None
return op.Identity(input)
if expanded_shape.ndim != 1:
# Target shape must be a 1D tensor. Erroneous model.
return None
@@ -658,6 +772,27 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
return None


def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
def merge_dims(dim1, dim2):
if dim1 == dim2:
return dim1
if not isinstance(dim1, ir.SymbolicDim):
return dim1 # Prefer int value over symbolic dim
if not isinstance(dim2, ir.SymbolicDim):
return dim2
if dim1.value is None:
return dim2
return dim1

if shape1 is None:
return shape2
if shape2 is None:
return shape1
if len(shape1) != len(shape2):
raise ValueError("Shapes must have the same rank.")
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
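A sketch of the merge preference on hypothetical shapes: concrete integers win over symbolic dims, and named symbolic dims win over unnamed ones.

_merge_shapes(ir.Shape(["Batch", 1024]), ir.Shape([None, 1024]))  # -> (Batch, 1024)
_merge_shapes(ir.Shape(["Batch"]), ir.Shape([8]))                 # -> (8,)
_merge_shapes(None, ir.Shape([None, 1024]))                       # -> (None, 1024)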


class ConstantFolder:
opset_imports: dict[str, int]

@@ -723,7 +858,10 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
if output.name in output_types:
inferred_type = output_types[output.name]
# TODO: merge types, check for conflicts
output.shape = ir.serde.deserialize_type_proto_for_shape(inferred_type)
inferred_shape = ir.serde.deserialize_type_proto_for_shape(
inferred_type
)
output.shape = _merge_shapes(output.shape, inferred_shape)
output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
except Exception as e:
logger.debug(
@@ -763,13 +901,8 @@ def new_constant(self, irvalue: ir.Value, value):
value.shape,
)

node = ir.Node(
"",
"Constant",
inputs=[],
attributes=ir.convenience.convert_attributes({"value": tensor}),
num_outputs=1,
)
attributes = ir.convenience.convert_attributes({"value": tensor})
node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
return node

def process_node(self, node: ir.Node):
71 changes: 71 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -486,6 +486,77 @@ def test_expand_identity(self):
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_expand_identity_symdim(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[B, 256] x) => (float[B, 256] z)
{
b = Shape <start=0, end=1> (x)
const_256 = Constant <value_ints=[256]> ()
shape = Concat <axis=0> (b, const_256)
z = Expand (x, shape)
}
"""
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_abs_symdim(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[B, 256] x) => (float[B, 256] z)
{
b = Shape <start=0, end=1> (x)
const_256 = Constant <value_ints=[256]> ()
b_256 = Concat <axis=0> (b, const_256)
shape = Abs (b_256)
z = Expand (x, shape)
}
"""
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_reshape_identity(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[128, 256] x) => (float[128, 256] z)
{
shape = Constant <value_ints=[128, 256]> ()
z = Reshape (x, shape)
}
"""
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_reshape_identity_symdim(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
{
b = Shape <start=0, end=1> (y)
const_256 = Constant <value_ints=[256]> ()
shape = Concat <axis=0> (b, const_256)
z = Reshape (x, shape)
}
"""
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_gather_symdim(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
{
b_128 = Shape (y)
index_0 = Constant <value_ints=[0]> ()
b = Gather <axis=0> (b_128, index_0)
const_256 = Constant <value_ints=[256]> ()
shape = Concat <axis=0> (b, const_256)
z = Reshape (x, shape)
}
"""
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")


if __name__ == "__main__":
unittest.main()
