Optimizer extensions #2003

Merged · 9 commits · Jan 9, 2025
169 changes: 151 additions & 18 deletions onnxscript/optimizer/_constant_folding.py
@@ -133,6 +133,18 @@ class Replacement:
new_nodes: Sequence[ir.Node]


# The optimizer tracks an optional symbolic value for each value in the model.
# The symbolic value attached to a value X can be:
# - another IR value Y (indicating that X is equal to Y)
# - a list of IR values [Y1, Y2, ...] (indicating that X is a sequence of values Y1, Y2, ...)
# - a Shape object (indicating that X is a shape value)
# A Shape object as a symbolic value indicates that the corresponding value is a
# 1-D (or 0-D) tensor of INT64 values. The dimensions in this object may be constants
# or symbolic dimension values (like "batch_size", "sequence_length", etc.).
# Currently, we assume that symbolic dimensions are guaranteed to be non-negative.
# TODO: Add support for negative symbolic dimensions.
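#
# A minimal sketch of the three kinds of symbolic values (illustrative only;
# x, y, y1, and y2 stand for hypothetical ir.Value objects):
#   state.set_sym_value(x, y)                            # x is equal to the value y
#   state.set_sym_value(x, [y1, y2])                     # x is the sequence (y1, y2)
#   state.set_sym_value(x, ir.Shape([1, "batch_size"]))  # x is a 1-D INT64 shape tensor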


class OptimizerState:
def __init__(self):
self._sym_value_map: dict[ir.Value, Any] = {}
@@ -159,6 +171,18 @@ def add_initializer_input(self, value: ir.Value) -> None:
def is_initializer_input(self, value: ir.Value) -> bool:
return any(value in inputs for inputs in self._initializer_inputs)

def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
const_value = _get_numpy_value(value, ir.DataType.INT64, size_limit=10)
if const_value is not None:
if const_value.ndim == 1:
return ir.Shape(const_value.tolist())
return None
sym_value = self.get_sym_value(value)
if isinstance(sym_value, ir.Shape):
return sym_value
# TODO use shape of value if available
return None
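# For example (a sketch): if value `v` is produced by `Concat <axis=0> (b, c)`,
# where `b` carries the symbolic shape ["B"] and `c` is the constant [256], then
# get_shape_value(v) returns ir.Shape(["B", 256]) once the Concat evaluator
# below has recorded that symbolic value.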


# The "partial evaluators" below are non-standard evaluators. They are used to perform
# partial evaluation and/or static program analysis (abstract interpretation).
@@ -235,11 +259,33 @@ def decorator(function: PartialEvaluatorFunction) -> PartialEvaluatorFunction:
register = registry.register


def _get_numpy_value(val: ir.Value | None) -> np.ndarray | None:
def _same_shape(shape1: ir.Shape, shape2: ir.Shape) -> bool:
# Comparison of shapes as tuples works except when a dimension is None
# (which represents an unknown dimension value). Thus, two shapes such
# as (Batch, 1024) and (Batch, 1024) are considered equal, but (None, 1024)
# and (None, 1024) are not, since two unknown dimensions need not be equal.
if any(isinstance(dim, ir.SymbolicDim) and dim.value is None for dim in shape1):
return False
return shape1.dims == shape2.dims
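# A sketch of the intended behavior:
#   _same_shape(ir.Shape(["B", 256]), ir.Shape(["B", 256]))    # True: named dims match
#   _same_shape(ir.Shape([None, 256]), ir.Shape([None, 256]))  # False: two unknown
#   dimensions may denote different runtime values, so equality cannot be assumed.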


def _get_numpy_value(
val: ir.Value | None, dtype: ir.DataType | None = None, size_limit: int | None = None
) -> np.ndarray | None:
"""Returns the numpy value of a constant value, if available.

It returns None if the value is not a constant value, or if the value is not of
the specified element dtype, or if the size of the value exceeds the specified
size_limit.
"""
if val is None:
return None
const_value = val.const_value
if const_value is not None:
if dtype is not None and const_value.dtype != dtype:
return None
if size_limit is not None and const_value.size > size_limit:
return None
try:
array = const_value.numpy()
except FileNotFoundError:
@@ -256,7 +302,7 @@ def _get_bool_value(val: ir.Value | None) -> bool | None:
value = _get_numpy_value(val)
if value is None:
return None
if value.size == 1 and value.dtype == bool:
return value.item(0)
return None

@@ -300,6 +346,54 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
return default


@register("Abs")
def abs(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace an Abs node by Identity when applicable.

Currently, addresses Abs applied to symbolic shapes.
"""
input = _get_input(node, 0)
input_sym_value = state.get_shape_value(input)
if input_sym_value is None:
return None
if any(isinstance(d, int) and d < 0 for d in input_sym_value):
return None
# Example: Abs applied to a symbolic shape of the form [1, 1, "sequence_length"].
# All constant dimensions were checked to be non-negative above, and symbolic
# dimensions are assumed non-negative, so the Abs op is redundant.
return op.Identity(input)


@register("Gather")
def gather(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Gather node by a constant when applicable.

Currently, handles the case of Gathering from a shape tensor.
"""
input = _get_input(node, 0)
indices = _get_input(node, 1)
if input is None or indices is None:
return None
input_sym_value = state.get_shape_value(input)
if input_sym_value is None:
return None
axis = _get_int_attribute(node, "axis", None)
if axis != 0:
return None
indices_numpy_value = _get_numpy_value(indices)
if indices_numpy_value is None:
return None
if indices_numpy_value.ndim != 1:
return None
gathered = [input_sym_value[i] for i in indices_numpy_value]
output = _get_output(node, 0)
if output is not None:
state.set_sym_value(output, ir.Shape(gathered))
if all(isinstance(d, int) for d in gathered):
return op.Constant(value_ints=gathered)
return None
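# Example (sketch, mirroring test_gather_symdim below): if the input carries the
# symbolic shape ["B", 128] and the indices are the constant [0], then
# gathered == ["B"]; the output is tagged with ir.Shape(["B"]), and no Constant
# is emitted since "B" is not an int.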


@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Reshape node by Identity when applicable."""
@@ -310,15 +404,16 @@ def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
input_shape = input.shape
if input_shape is None:
return None
shape_value = state.get_shape_value(shape)
if shape_value is None:
return None
# No need to check for special values like -1, 0, etc. here.
if _same_shape(input_shape, shape_value):
return op.Identity(input)
return None

@@ -373,6 +468,9 @@ def shape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
start = _get_int_attribute(node, "start", 0)
end = _get_int_attribute(node, "end", None)
shape_slice = shape[start:end]
output = _get_output(node, 0)
if output is not None:
state.set_sym_value(output, ir.Shape(shape_slice))
if all(isinstance(d, int) for d in shape_slice):
return op.Constant(value_ints=list(shape_slice))
return None
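# Example (sketch): for an input of shape (B, 256), Shape <start=0, end=1> yields
# shape_slice == ["B"]; the output is tagged with ir.Shape(["B"]) and, since "B"
# is symbolic, no Constant is emitted.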
@@ -459,6 +557,19 @@ def concat(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
inputs = node.inputs
if len(inputs) == 1:
return op.Identity(inputs[0])
# Track value of tensors that carry a shape value:
output = node.outputs[0]
if output is None:
return None
# Check axis attribute is 0
axis = _get_int_attribute(node, "axis", None)
if axis != 0:
return None
shapes = [state.get_shape_value(input) for input in inputs]
if any(shape is None for shape in shapes):
return None
concatenated = ir.Shape(dim for shape in shapes for dim in shape.dims) # type: ignore[union-attr]
state.set_sym_value(output, concatenated)
return None
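# Example (sketch): for Concat <axis=0> (b, c) where b carries ir.Shape(["B"]) and
# c holds the constant [256], the output is tagged with ir.Shape(["B", 256]), which
# later lets Reshape/Expand nodes with this target shape be reduced to Identity.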


@@ -507,7 +618,10 @@ def expand(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
return None
expanded_sym_shape = state.get_shape_value(node.inputs[1])
if expanded_sym_shape is None or not _same_shape(input_shape, expanded_sym_shape):
# Target shape is unknown or differs from the input shape.
return None
return op.Identity(input)
@@ -658,6 +772,27 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
return None


def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
def merge_dims(dim1, dim2):
if dim1 == dim2:
return dim1
if not isinstance(dim1, ir.SymbolicDim):
return dim1 # Prefer int value over symbolic dim
if not isinstance(dim2, ir.SymbolicDim):
return dim2
if dim1.value is None:
return dim2
return dim1

if shape1 is None:
return shape2
if shape2 is None:
return shape1
if len(shape1) != len(shape2):
raise ValueError("Shapes must have the same rank.")
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
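# A sketch of the preference order implemented by merge_dims:
#   _merge_shapes(ir.Shape([None, 256]), ir.Shape(["B", 256]))  # -> ir.Shape(["B", 256])
#   _merge_shapes(ir.Shape(["B", 256]), None)                   # -> ir.Shape(["B", 256])
# Concrete int dimensions win over symbolic ones, and named symbolic dimensions
# win over unnamed (None) ones.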


class ConstantFolder:
opset_imports: dict[str, int]

@@ -723,7 +858,10 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
if output.name in output_types:
inferred_type = output_types[output.name]
# TODO: merge types, check for conflicts
inferred_shape = ir.serde.deserialize_type_proto_for_shape(
inferred_type
)
output.shape = _merge_shapes(output.shape, inferred_shape)
output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
except Exception as e:
logger.debug(
@@ -763,13 +901,8 @@ def new_constant(self, irvalue: ir.Value, value):
value.shape,
)

attributes = ir.convenience.convert_attributes({"value": tensor})
node = ir.Node("", "Constant", inputs=[], attributes=attributes, num_outputs=1)
return node

def process_node(self, node: ir.Node):
71 changes: 71 additions & 0 deletions onnxscript/optimizer/_constant_folding_test.py
@@ -486,6 +486,77 @@ def test_expand_identity(self):
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_expand_identity_symdim(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[B, 256] x) => (float[B, 256] z)
{
b = Shape <start=0, end=1> (x)
const_256 = Constant <value_ints=[256]> ()
shape = Concat <axis=0> (b, const_256)
z = Expand (x, shape)
}
"""
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_abs_symdim(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[B, 256] x) => (float[B, 256] z)
{
b = Shape <start=0, end=1> (x)
const_256 = Constant <value_ints=[256]> ()
b_256 = Concat <axis=0> (b, const_256)
shape = Abs (b_256)
z = Expand (x, shape)
}
"""
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_reshape_identity(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[128, 256] x) => (float[128, 256] z)
{
shape = Constant <value_ints=[128, 256]> ()
z = Reshape (x, shape)
}
"""
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_reshape_identity_symdim(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
{
b = Shape <start=0, end=1> (y)
const_256 = Constant <value_ints=[256]> ()
shape = Concat <axis=0> (b, const_256)
z = Reshape (x, shape)
}
"""
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")

def test_gather_symdim(self):
model = """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[B, 256] x, float[B, 128] y) => (float[B, 256] z)
{
b_128 = Shape (y)
index_0 = Constant <value_ints=[0]> ()
b = Gather <axis=0> (b_128, index_0)
const_256 = Constant <value_ints=[256]> ()
shape = Concat <axis=0> (b, const_256)
z = Reshape (x, shape)
}
"""
optimized = self._fold(model)
self.assertEqual(optimized.graph.node(-1).op_type, "Identity")


if __name__ == "__main__":
unittest.main()