# Debug helpers taken from a patch to
# onnxscript/function_libs/torch_lib/ops/core.py (hunk following
# _aten_embedding_bag_1d_padding_idx_onnx).
# They only print results; nothing is asserted.


def test_embedding_bag_onnx():
    """Print the torch_lib embedding_bag results for one fixed float16 sample.

    Calls ``aten_embedding_bag`` and ``aten_embedding_bag_padding_idx`` with the
    same weight/indices/offsets/per_sample_weights and prints the first output
    of each so the two can be compared by eye.
    """
    import numpy as np

    # https://github.com/microsoft/onnxscript/issues/1056
    weight = np.array(
        [
            [-2.7199, -1.7691, -8.5981, -5.9605, -3.7100],
            [0.3334, 3.5580, 5.4002, -6.1015, -3.9192],
            [3.2690, 7.4735, -1.8522, 6.7348, -1.4507],
            [0.9523, 8.1493, -8.3490, -5.6658, -2.2785],
            [-3.5082, 7.7760, -5.8336, -4.1430, -6.2878],
            [-8.4290, -5.2537, 7.7364, 4.0160, 4.3621],
            [0.4733, -4.6142, 1.5227, -8.4033, -6.5031],
            [-4.6398, 5.6784, 5.2769, -3.9915, -0.3247],
            [5.7560, 8.9472, 3.5719, 1.2158, 6.0344],
            [-5.2992, 1.6771, -6.9777, -6.2378, -4.6493],
        ],
        dtype=np.float16,
    )
    indices = np.array([4, 9, 3, 0, 3], dtype=np.int64)
    offsets = np.array([0, 3], dtype=np.int64)
    mode = 0  # sum
    per_sample_weights = np.array(
        [2.4134, -0.1783, 7.1360, -0.7987, 2.3815], dtype=np.float16
    )

    # Both ops are unpacked as 4-tuples (as in the original code); only the
    # first element is inspected here.
    result1, _offset2bag, _bag_size, _max_indices = aten_embedding_bag(
        weight, indices, offsets, mode=mode, per_sample_weights=per_sample_weights
    )
    result2, _offset2bag, _bag_size, _max_indices = aten_embedding_bag_padding_idx(
        weight, indices, offsets, mode=mode, per_sample_weights=per_sample_weights
    )
    print("result from onnx-script:")
    print(result1)
    print(result2)


def test_embedding_bag_aten():
    """Print all four outputs of ``torch.ops.aten.embedding_bag`` for a fixed sample.

    Uses a 5x5 float16 weight, six indices, offsets ``[0, 3, 6]`` and
    ``include_last_offset=True``; no per-sample weights are passed.
    """
    import torch as t

    weight = t.tensor(
        [
            [1.9951, -1.1777, -3.7695, -3.3125, 8.5078],
            [-3.9648, -3.2617, 4.5430, -6.7500, 1.1953],
            [1.8193, -4.9297, 8.3438, 1.2217, 0.0352],
            [-5.2812, -5.9414, -0.7295, 2.4785, -3.8496],
            [7.2070, -0.1582, 3.8047, 1.9248, -1.8018],
        ],
        dtype=t.float16,
    )
    indices = t.tensor([2, 3, 1, 4, 3, 0], dtype=t.int64)
    offsets = t.tensor([0, 3, 6], dtype=t.int64)
    mode = 0  # sum
    include_last_offset = True

    result, offset2bag, bag_size, max_indices = t.ops.aten.embedding_bag(
        weight,
        indices,
        offsets,
        mode=mode,
        include_last_offset=include_last_offset,
    )
    print("result from aten:")
    print(result)
    print(offset2bag)
    print(bag_size)
    print(max_indices)


def test_embedding_bag_nn_function():
    """Print ``torch.nn.functional.embedding_bag`` output for a fixed sample.

    Runs in ``"sum"`` mode with float16 weights and per-sample weights.
    """
    import torch as t

    weight = t.tensor(
        [
            [-6.5664, 6.6250, 7.0664, -3.7344, 0.6152],
            [4.1484, -3.7266, 3.4805, -6.2422, -2.8047],
            [4.2734, -4.1562, -8.2344, -7.4688, 5.2734],
            [-1.5381, 5.9492, -4.2812, -1.5732, -8.3672],
            [-2.1719, 8.0469, -7.9883, -0.4219, -2.3633],
            [6.2305, 8.9844, 7.4453, 3.7891, -5.0625],
            [-1.5293, -8.1328, 8.6484, 1.5557, -2.3633],
            [-1.9951, -3.2070, 1.2920, -1.0020, -5.2812],
            [2.5312, 8.4453, 2.3281, -2.8750, -3.3828],
            [-4.2188, -4.2266, -2.7246, -6.8555, -7.6719],
        ],
        dtype=t.float16,
    )
    indices = t.tensor([4, 9, 3, 0, 3], dtype=t.int64)
    offsets = t.tensor([0, 3], dtype=t.int64)
    per_sample_weights = t.tensor(
        [-2.2930, 6.2148, 3.1562, 0.0791, 6.3555], dtype=t.float16
    )

    # Note the argument order: the functional API takes indices before weight.
    result = t.nn.functional.embedding_bag(
        indices,
        weight,
        offsets=offsets,
        mode="sum",
        per_sample_weights=per_sample_weights,
    )
    print("result from nn.functional:")
    print(result)
# Manual driver for the embedding_bag debug helpers.
#
# Guarded so that importing this module neither runs the repro nor terminates
# the interpreter; the unguarded version called exit(0) at import time.
if __name__ == "__main__":
    test_embedding_bag_onnx()
    # The torch reference helpers (test_embedding_bag_aten,
    # test_embedding_bag_nn_function) are not run by default.

    # Stop here when run as a script, as the original exit(0) did.
    # SystemExit is used directly because the `exit` helper is only installed
    # by the `site` module.
    raise SystemExit(0)

# Trailing patch context (unchanged by the patch, truncated in this view):
#   def aten_embedding_dense_backward(grad_output: TensorType, indices: TensorType, ...