Skip to content

Commit

Permalink
[tools/onnx-subgraph] support arbitrary input shape for onnx
Browse files Browse the repository at this point in the history
1. fixed input shape as default value
2. user can set the input shape dynamically with guide

ONE-DCO-1.0-Signed-off-by: Youxin Chen <[email protected]>
  • Loading branch information
chenyx113 committed Feb 27, 2025
1 parent 58324bf commit e8aadae
Showing 1 changed file with 49 additions and 3 deletions.
52 changes: 49 additions & 3 deletions tools/onnx_subgraph/single_vs_multiple_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ModelInference:
Description:
Subgraphsiostxt_path is a txt file that describes the structure of the model graph and
is used to get input/output node names.The model_path contains paths to multiple onnx files.
The load_sessions function will sort the onnx models in the model_path according to the
The load_sessions function will sort the onnx models in the model_path according to the
order specified in subgraphsiostxt_path.
"""
def __init__(self, model_path, subgraphsiostxt_path):
Expand All @@ -45,6 +45,52 @@ def infer_single_onnx_model(model_file, input_data):
return output_dict


def prepare_initial_input_data(onnx_model_path, default_input_data):
"""
Prepares initial input data for inference.
Args:
onnx_model_path (str): Path to the ONNX model file.
default_input_data (dict): Dictionary containing default input data.
Returns:
dict: Dictionary with user-specified or default shaped and typed input data.
"""
session = ort.InferenceSession(onnx_model_path)
input_info = {input.name: input.shape for input in session.get_inputs()}

initial_input_data = {}
dtype_map = {'f': np.float32, 'i': np.int64}

for input_name, shape in input_info.items():
custom_shape_str = input(
f"Enter new shape for input '{input_name}' (comma-separated integers), or press Enter to use default: "
)
custom_dtype_str = input(
f"Enter data type for input '{input_name}' ('f' for float32, 'i' for int64), or press Enter to use default: "
)

if not custom_shape_str:
new_shape = default_input_data[input_name].shape
else:
try:
new_shape = [int(dim) for dim in custom_shape_str.split(',')]
except ValueError:
print("Invalid input, please ensure you enter comma-separated integers.")
continue

if not custom_dtype_str:
dtype = default_input_data[input_name].dtype
else:
dtype = dtype_map.get(custom_dtype_str.strip(), None)
if dtype is None:
print("Invalid data type, please enter 'f' or 'i'.")
continue

input_data = np.random.rand(*new_shape).astype(dtype)
initial_input_data[input_name] = input_data

return initial_input_data


if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('-s',
Expand All @@ -68,8 +114,8 @@ def infer_single_onnx_model(model_file, input_data):
default_input_data = {
"x": np.random.rand(1, 3, 256, 256).astype(np.float32),
}

initial_input_data = prepare_initial_input_data(args.single, default_input_data)
# Perform inference using a single ONNX model
output_single = ModelInference.infer_single_onnx_model(args.single,
default_input_data)
initial_input_data)
print("Single model inference completed!")

0 comments on commit e8aadae

Please sign in to comment.