diff --git a/src/python/txtai/pipeline/train/mlonnx.py b/src/python/txtai/pipeline/train/mlonnx.py index 1e3bac653..126e025c7 100644 --- a/src/python/txtai/pipeline/train/mlonnx.py +++ b/src/python/txtai/pipeline/train/mlonnx.py @@ -54,8 +54,9 @@ def __call__(self, model, task="default", output=None, opset=12): # Find probabilities output node and rename to logits for node in model.graph.node: - if node.output[0] == "probabilities": - node.output[0] = "logits" + for x, _ in enumerate(node.output): + if node.output[x] == "probabilities": + node.output[x] = "logits" # Save model to specified output path or return bytes model = save_onnx_model(model, output)