diff --git a/py/torch_migraphx/fx/converters/acc_ops_converters.py b/py/torch_migraphx/fx/converters/acc_ops_converters.py
index dbab3ca..8a0ce57 100644
--- a/py/torch_migraphx/fx/converters/acc_ops_converters.py
+++ b/py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -1193,8 +1193,9 @@ def acc_ops_embedding(mgx_module, node, args, kwargs):
         mgx_module.add_instruction(migraphx.op('gather', axis=0),
                                    [weight.instr_ref, inp.instr_ref]))
 
-@migraphx_converter(acc_ops.gather)
-def acc_ops_gather(mgx_module, node, args, kwargs):
+## MIGraphX cannot optimize gathernd well in some cases
+@migraphx_converter(acc_ops.gather, enabled=False)
+def acc_ops_gather_legacy(mgx_module, node, args, kwargs):
     inp = kwargs['input']
     dim = kwargs['dim']
     index = kwargs['index']
@@ -1219,6 +1220,55 @@ def acc_ops_gather(mgx_module, node, args, kwargs):
     new_shape = tuple(list(index_lens) + [len(index_lens)])
     coords = acc_ops_reshape(mgx_module, node, (), {"input": coords, "shape": new_shape})
     return MGXInstruction(mgx_module.add_instruction(migraphx.op('gathernd'), [inp.instr_ref, coords.instr_ref]))
+
+
+@migraphx_converter(acc_ops.gather)
+def acc_ops_gather(mgx_module, node, args, kwargs):
+    inp = kwargs['input']
+    dim = kwargs['dim']
+    idx = kwargs['index']
+
+    assert not inp.is_quantized() and not idx.is_quantized()
+
+    inp_ref = mgx_module.add_instruction(migraphx.op("contiguous"), [inp.instr_ref])
+    idx_ref = mgx_module.add_instruction(migraphx.op("contiguous"), [idx.instr_ref])
+
+    inp_lens, inp_strides = inp_ref.shape().lens(), inp_ref.shape().strides()
+    idx_lens, idx_strides = idx_ref.shape().lens(), idx_ref.shape().strides()
+    idx_dtype = get_arg_dtype(idx.instr_ref)
+
+    assert len(idx_lens) == len(inp_lens)
+    if dim < 0:
+        dim = len(idx_lens) + dim
+
+    base_indices = torch.zeros(idx_lens, dtype=idx_dtype)
+    for a in range(len(idx_lens)):
+        if a == dim:
+            continue
+
+        a_shp = [1] * len(inp_lens)
+        a_shp[a] = inp_lens[a]
+        a_inds = torch.arange(inp_lens[a]) * inp_strides[a]
+        a_inds = a_inds.reshape(a_shp).broadcast_to(idx_lens)
+        base_indices += a_inds
+
+    base_indices_lit = mgx_module.add_literal(base_indices.numpy())
+    dim_stride = mgx_module.add_literal(
+        torch.tensor(inp_strides[dim], dtype=idx_dtype).numpy())
+    dim_stride = mgx_module.add_instruction(
+        migraphx.op('multibroadcast', out_lens=idx_lens), [dim_stride])
+
+    dim_indices = mgx_module.add_instruction(migraphx.op("mul"),
+                                             [idx_ref, dim_stride])
+    data_indices = mgx_module.add_instruction(migraphx.op("add"),
+                                              [base_indices_lit, dim_indices])
+
+    flat_inp = mgx_module.add_instruction(
+        migraphx.op('reshape', dims=[inp.shape().elements()]), [inp_ref])
+
+    return MGXInstruction(
+        mgx_module.add_instruction(migraphx.op('gather', axis=0),
+                                   [flat_inp, data_indices]))
 
 
 @migraphx_converter(acc_ops.reshape)
@@ -1759,18 +1809,54 @@ def acc_ops_getitem(mgx_module, node, args, kwargs):
     elif num_tensor_dims > 1:
         idx_tensors = [idx[ax] for ax in tensor_dims]
         idx_tensors = broadcast_tensors(mgx_module, *idx_tensors)
-        unsq_idx_tensors = []
-        for t in idx_tensors:
-            unsq_idx_tensors.append(
-                mgx_module.add_instruction(migraphx.op('unsqueeze', axes=[-1]),
-                                           [t]))
-        gather_idx = mgx_module.add_instruction(migraphx.op('concat', axis=-1),
-                                                unsq_idx_tensors)
-
-        out_mgx = mgx_module.add_instruction(migraphx.op('gathernd'),
-                                             [out_mgx, gather_idx])
-
-        idx_rank = len(gather_idx.shape().lens()) - 1
+        idx_rank = len(idx_tensors[0].shape().lens())
+
+        idx_dtype = get_arg_dtype(idx_tensors[0])
+        lens = out_mgx.shape().lens()
+        out_lens = idx_tensors[0].shape().lens() + lens[num_tensor_dims:]
+        axial_indices = []
+        for ax, dim in enumerate(lens):
+            post_dims = len(lens) - len(idx_tensors)
+            unsq_dims = list(range(-1, -post_dims - 1, -1))
+            if ax < num_tensor_dims:
+                ax_idx = idx_tensors[ax]
+                ax_idx = normalize_neg_indices(mgx_module, ax_idx, dim)
+                ax_idx = mgx_module.add_instruction(
+                    migraphx.op("unsqueeze", axes=unsq_dims), [ax_idx])
+                ax_idx = insert_mbroadcast(mgx_module, ax_idx, out_lens)
+            else:
+                shp = [1] * len(out_lens)
+                shp[ax - len(lens)] = dim
+                ax_idx = torch.arange(dim).reshape(shp).broadcast_to(out_lens)
+                ax_idx = mgx_module.add_literal(ax_idx.to(idx_dtype).numpy())
+
+            axial_indices.append(ax_idx)
+
+        out_mgx = mgx_module.add_instruction(
+            migraphx.op('reshape', dims=[out_mgx.shape().elements()]),
+            [out_mgx])
+
+        ## Compute indices for the new flattened tensor
+        gather_indices = axial_indices[-1]
+        multiplier = mgx_module.add_literal(torch.tensor(1, dtype=idx_dtype).numpy())
+        multiplier = insert_mbroadcast(mgx_module, multiplier, out_lens)
+
+        for i in range(len(lens)-2, -1, -1):
+            prev_len = mgx_module.add_literal(
+                torch.tensor(lens[i+1], dtype=idx_dtype).numpy())
+            prev_len = insert_mbroadcast(mgx_module, prev_len, multiplier.shape().lens())
+            multiplier = mgx_module.add_instruction(migraphx.op("mul"),
+                                                    [multiplier, prev_len])
+
+            offset = mgx_module.add_instruction(
+                migraphx.op("mul"), [axial_indices[i], multiplier])
+            gather_indices = mgx_module.add_instruction(
+                migraphx.op("add"), [gather_indices, offset])
+
+
+        out_mgx = mgx_module.add_instruction(migraphx.op('gather', axis=0),
+                                             [out_mgx, gather_indices])
+
         offset = num_tensor_dims - idx_rank
 
         # Remove squeezed dimensions from original permutation
diff --git a/py/torch_migraphx/fx/converters/utils.py b/py/torch_migraphx/fx/converters/utils.py
index f3fd9fd..a33f14b 100644
--- a/py/torch_migraphx/fx/converters/utils.py
+++ b/py/torch_migraphx/fx/converters/utils.py
@@ -51,6 +51,28 @@ def broadcast_tensors(mgx_module, *tensors):
     return outs
 
 
+def insert_mbroadcast(mgx_module, ins, dims):
+    return mgx_module.add_instruction(
+        migraphx.op("multibroadcast", out_lens=dims), [ins])
+
+
+def normalize_neg_indices(mgx_module, idx_ins, dim_val):
+    dtype = get_arg_dtype(idx_ins)
+    # find locations of negative indices
+    zeros = mgx_module.add_literal(torch.tensor(0, dtype=dtype).numpy())
+    zeros = insert_mbroadcast(mgx_module, zeros, idx_ins.shape().lens())
+    neg_idx = mgx_module.add_instruction(migraphx.op('less'), [idx_ins, zeros])
+
+    dim_size = mgx_module.add_literal(
+        torch.tensor(dim_val, dtype=dtype).numpy())
+    dim_size = insert_mbroadcast(mgx_module, dim_size, idx_ins.shape().lens())
+    offset_idx = mgx_module.add_instruction(migraphx.op('add'),
+                                            [idx_ins, dim_size])
+
+    return mgx_module.add_instruction(migraphx.op('where'),
+                                      [neg_idx, offset_idx, idx_ins])
+
+
 def get_arg_dtype(arg):
     if isinstance(arg, migraphx.instruction_ref):
         dtype = torch_dtype_from_mgx(arg.shape().type_string())
diff --git a/tests/dynamo/converters/test_gather_dynamo.py b/tests/dynamo/converters/test_gather_dynamo.py
index 70df445..12e9c25 100644
--- a/tests/dynamo/converters/test_gather_dynamo.py
+++ b/tests/dynamo/converters/test_gather_dynamo.py
@@ -16,19 +16,19 @@ def forward(self, x, idx):
 
 
 @pytest.mark.parametrize('op_alias', [torch.ops.aten.gather.default])
-@pytest.mark.parametrize("input_dim, dim", [((3, 3), 0),
-                                            ((3, 3), 1),
-                                            ((3, 3), -1),
-                                            ((3, 3), -2),
-                                            ((10, 5), -2),
-                                            ((2, 3, 4, 5, 6), -3),
-                                            ((2, 3, 4, 5, 6), -4)])
+@pytest.mark.parametrize("input_dim, dim", [((3, 2), 0),
+                                            ((3, 2, 4), 1),
+                                            ((1, 1), -1),
+                                            ((2, 4, 6), -2),
+                                            ((4, 2), -2),
+                                            ((2, 3, 4, 1, 3), -3),
+                                            ((3, 3, 2, 4, 5), -4)])
 def test_gather(op_alias, input_dim, dim):
     input = torch.rand(input_dim).cuda()
     dim_size = input.size(dim)
 
     index_shape = list(input.size())
-    index_shape[dim] = np.random.randint(1, dim_size)
+    index_shape[dim] = np.random.randint(1, dim_size*2)
     index = torch.randint(0, dim_size, index_shape).cuda()
 
     mod = GatherModule(op_alias, dim).cuda()
diff --git a/tests/fx/converters/test_gather_fx.py b/tests/fx/converters/test_gather_fx.py
index e529734..0738552 100644
--- a/tests/fx/converters/test_gather_fx.py
+++ b/tests/fx/converters/test_gather_fx.py
@@ -24,7 +24,7 @@ def test_gather(input_dim, dim):
     dim_size = input.size(dim)
 
     index_shape = list(input.size())
-    index_shape[dim] = np.random.randint(1, dim_size)
+    index_shape[dim] = np.random.randint(1, dim_size*2)
     index = torch.randint(0, dim_size, index_shape)
 
     mod = GatherModule(dim)
diff --git a/tests/fx/converters/test_loss_fx.py b/tests/fx/converters/test_loss_fx.py
index e869e83..a7a1360 100644
--- a/tests/fx/converters/test_loss_fx.py
+++ b/tests/fx/converters/test_loss_fx.py
@@ -26,7 +26,6 @@ def test_nll_loss_fx(inp_size, no_weight, reduction, ignore_index):
 
 @pytest.mark.parametrize('reduction', [('mean'), ('sum'), ('none')])
 @pytest.mark.parametrize('C, no_weight, target, ignore_index', [
-    (3, True, 0, 0),
     (3, False, 1, -100),
     (3, True, 2, 1),
 ])
@@ -40,7 +39,4 @@ def test_nll_loss_1d_fx(C, no_weight, reduction, target, ignore_index):
     mod = FuncModule(torch.nn.functional.nll_loss, target=target, weight=weight,
                      reduction = reduction, ignore_index = ignore_index)
     mgx_mod = convert_to_mgx(mod, [inp])
-    # Output is nan when ignore_idx == target (div by 0)
-    # MIGraphX creates a kernel that ends up outputting a tensor of len 1 instead of a scalar
-    # TODO: fused kernels in migraphx should respect the original output shape
-    verify_outputs(mod, mgx_mod, [inp], equal_nan=True, scalar=True)
\ No newline at end of file
+    verify_outputs(mod, mgx_mod, [inp])
\ No newline at end of file
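
Note: both rewritten converters replace gathernd with a single flat gather by linearizing indices with row-major strides (base offsets for every axis except dim, plus index * stride[dim]). The snippet below is an illustrative sketch of that arithmetic in plain PyTorch, not part of the patch; the helper name gather_via_flat_indices is made up for the example, and it assumes a contiguous input whose sizes match the index on all axes except dim, mirroring the assumptions of the new acc_ops_gather.

import torch

def gather_via_flat_indices(inp, dim, index):
    # Sketch of the converter's lowering: torch.gather(inp, dim, index) as a flat 1-D gather.
    dim = dim % inp.dim()
    strides = inp.stride()  # row-major strides of the contiguous input, in elements
    base = torch.zeros(index.shape, dtype=torch.long)
    for a in range(inp.dim()):
        if a == dim:
            continue
        shp = [1] * inp.dim()
        shp[a] = inp.shape[a]
        # offset contributed by every axis except `dim` (the base_indices literal)
        base += (torch.arange(inp.shape[a]) * strides[a]).reshape(shp).broadcast_to(index.shape)
    flat_idx = base + index * strides[dim]  # add index scaled by the dim stride
    return inp.reshape(-1)[flat_idx]

x = torch.arange(12.).reshape(3, 4)
idx = torch.tensor([[0, 2, 1, 3], [3, 3, 0, 0], [1, 0, 2, 2]])
assert torch.equal(gather_via_flat_indices(x, 1, idx), torch.gather(x, 1, idx))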