Gather update (#220)
* working gather

* update test cases to cover more cases

* update advanced indexing to use gather instead of gathernd

* handle negative indices

* remove comment
shivadbhavsar authored Jan 10, 2025
1 parent af14cd9 commit fd75e20
Showing 5 changed files with 132 additions and 28 deletions.
114 changes: 100 additions & 14 deletions py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -1193,8 +1193,9 @@ def acc_ops_embedding(mgx_module, node, args, kwargs):
mgx_module.add_instruction(migraphx.op('gather', axis=0),
                           [weight.instr_ref, inp.instr_ref]))

@migraphx_converter(acc_ops.gather)
def acc_ops_gather(mgx_module, node, args, kwargs):
## MIGraphX cannot optimize gathernd well in some cases
@migraphx_converter(acc_ops.gather, enabled=False)
def acc_ops_gather_legacy(mgx_module, node, args, kwargs):
    inp = kwargs['input']
    dim = kwargs['dim']
    index = kwargs['index']
@@ -1219,6 +1220,55 @@ def acc_ops_gather(mgx_module, node, args, kwargs):
    new_shape = tuple(list(index_lens) + [len(index_lens)])
    coords = acc_ops_reshape(mgx_module, node, (), {"input": coords, "shape": new_shape})
    return MGXInstruction(mgx_module.add_instruction(migraphx.op('gathernd'), [inp.instr_ref, coords.instr_ref]))


@migraphx_converter(acc_ops.gather)
def acc_ops_gather(mgx_module, node, args, kwargs):
    inp = kwargs['input']
    dim = kwargs['dim']
    idx = kwargs['index']

    assert not inp.is_quantized() and not idx.is_quantized()

    inp_ref = mgx_module.add_instruction(migraphx.op("contiguous"), [inp.instr_ref])
    idx_ref = mgx_module.add_instruction(migraphx.op("contiguous"), [idx.instr_ref])

    inp_lens, inp_strides = inp_ref.shape().lens(), inp_ref.shape().strides()
    idx_lens, idx_strides = idx_ref.shape().lens(), idx_ref.shape().strides()
    idx_dtype = get_arg_dtype(idx.instr_ref)

    assert len(idx_lens) == len(inp_lens)
    if dim < 0:
        dim = len(idx_lens) + dim

    base_indices = torch.zeros(idx_lens, dtype=idx_dtype)
    for a in range(len(idx_lens)):
        if a == dim:
            continue

        a_shp = [1] * len(inp_lens)
        a_shp[a] = inp_lens[a]
        a_inds = torch.arange(inp_lens[a]) * inp_strides[a]
        a_inds = a_inds.reshape(a_shp).broadcast_to(idx_lens)
        base_indices += a_inds

    base_indices_lit = mgx_module.add_literal(base_indices.numpy())
    dim_stride = mgx_module.add_literal(
        torch.tensor(inp_strides[dim], dtype=idx_dtype).numpy())
    dim_stride = mgx_module.add_instruction(
        migraphx.op('multibroadcast', out_lens=idx_lens), [dim_stride])

    dim_indices = mgx_module.add_instruction(migraphx.op("mul"),
                                             [idx_ref, dim_stride])
    data_indices = mgx_module.add_instruction(migraphx.op("add"),
                                              [base_indices_lit, dim_indices])

    flat_inp = mgx_module.add_instruction(
        migraphx.op('reshape', dims=[inp.shape().elements()]), [inp_ref])

    return MGXInstruction(
        mgx_module.add_instruction(migraphx.op('gather', axis=0),
                                   [flat_inp, data_indices]))
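
The converter above replaces gathernd with a plain gather over a flattened input: it precomputes, from the contiguous strides, the linear offset of every element the index tensor selects. A minimal PyTorch-only sketch of the same arithmetic (illustrative only, not part of this diff; like the converter, it assumes the index matches the input shape on every non-gather dimension):

import torch

def gather_via_flatten(inp, dim, idx):
    # Sketch of the flattened-gather trick: build linear offsets, then index the flat input.
    inp = inp.contiguous()
    dim = dim % inp.dim()
    strides = inp.stride()                       # row-major strides of the contiguous input
    base = torch.zeros(idx.shape, dtype=torch.long)
    for a in range(inp.dim()):
        if a == dim:
            continue
        shp = [1] * inp.dim()
        shp[a] = inp.size(a)
        # linear offset contributed by axis `a` at every output position
        base = base + (torch.arange(inp.size(a)) * strides[a]).reshape(shp).broadcast_to(idx.shape)
    flat_idx = base + idx * strides[dim]         # add the offset selected along `dim`
    return inp.reshape(-1)[flat_idx]

x = torch.rand(3, 4)
i = torch.randint(0, 4, (3, 6))                  # index may be longer than the input along dim=1
assert torch.equal(gather_via_flatten(x, 1, i), torch.gather(x, 1, i))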


@migraphx_converter(acc_ops.reshape)
@@ -1759,18 +1809,54 @@ def acc_ops_getitem(mgx_module, node, args, kwargs):
elif num_tensor_dims > 1:
    idx_tensors = [idx[ax] for ax in tensor_dims]
    idx_tensors = broadcast_tensors(mgx_module, *idx_tensors)
    unsq_idx_tensors = []
    for t in idx_tensors:
        unsq_idx_tensors.append(
            mgx_module.add_instruction(migraphx.op('unsqueeze', axes=[-1]),
                                       [t]))
    gather_idx = mgx_module.add_instruction(migraphx.op('concat', axis=-1),
                                            unsq_idx_tensors)

    out_mgx = mgx_module.add_instruction(migraphx.op('gathernd'),
                                         [out_mgx, gather_idx])

    idx_rank = len(gather_idx.shape().lens()) - 1
    idx_rank = len(idx_tensors[0].shape().lens())

    idx_dtype = get_arg_dtype(idx_tensors[0])
    lens = out_mgx.shape().lens()
    out_lens = idx_tensors[0].shape().lens() + lens[num_tensor_dims:]
    axial_indices = []
    for ax, dim in enumerate(lens):
        post_dims = len(lens) - len(idx_tensors)
        unsq_dims = list(range(-1, -post_dims - 1, -1))
        if ax < num_tensor_dims:
            ax_idx = idx_tensors[ax]
            ax_idx = normalize_neg_indices(mgx_module, ax_idx, dim)
            ax_idx = mgx_module.add_instruction(
                migraphx.op("unsqueeze", axes=unsq_dims), [ax_idx])
            ax_idx = insert_mbroadcast(mgx_module, ax_idx, out_lens)
        else:
            shp = [1] * len(out_lens)
            shp[ax - len(lens)] = dim
            ax_idx = torch.arange(dim).reshape(shp).broadcast_to(out_lens)
            ax_idx = mgx_module.add_literal(ax_idx.to(idx_dtype).numpy())

        axial_indices.append(ax_idx)

    out_mgx = mgx_module.add_instruction(
        migraphx.op('reshape', dims=[out_mgx.shape().elements()]),
        [out_mgx])

    ## Compute indices for the new flattened tensor
    gather_indices = axial_indices[-1]
    multiplier = mgx_module.add_literal(torch.tensor(1, dtype=idx_dtype).numpy())
    multiplier = insert_mbroadcast(mgx_module, multiplier, out_lens)

    for i in range(len(lens)-2, -1, -1):
        prev_len = mgx_module.add_literal(
            torch.tensor(lens[i+1], dtype=idx_dtype).numpy())
        prev_len = insert_mbroadcast(mgx_module, prev_len, multiplier.shape().lens())
        multiplier = mgx_module.add_instruction(migraphx.op("mul"),
                                                [multiplier, prev_len])

        offset = mgx_module.add_instruction(
            migraphx.op("mul"), [axial_indices[i], multiplier])
        gather_indices = mgx_module.add_instruction(
            migraphx.op("add"), [gather_indices, offset])


    out_mgx = mgx_module.add_instruction(migraphx.op('gather', axis=0),
                                         [out_mgx, gather_indices])

offset = num_tensor_dims - idx_rank

# Remove squeezed dimensions from original permutation
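
For reference, a small PyTorch-only sketch (not part of this diff) of the linearization idea the new multi-tensor indexing path above uses: per-axis index grids are broadcast to the output shape, combined into flat row-major offsets, and gathered from the flattened tensor. The tensor names and shapes below are made up for illustration:

import torch

x = torch.rand(4, 5, 6)
idx0 = torch.tensor([0, 3, 2])
idx1 = torch.tensor([1, 1, 4])

# Output of x[idx0, idx1]: broadcasted index shape followed by the remaining dims.
out_shape = idx0.shape + x.shape[2:]                                   # (3, 6)
ax0 = idx0.reshape(-1, 1).broadcast_to(out_shape)                      # index grid for axis 0
ax1 = idx1.reshape(-1, 1).broadcast_to(out_shape)                      # index grid for axis 1
ax2 = torch.arange(x.size(2)).reshape(1, -1).broadcast_to(out_shape)   # arange grid for the untouched axis

# Row-major linearization, equivalent to accumulating per-axis multipliers in the loop above.
flat_idx = (ax0 * x.size(1) + ax1) * x.size(2) + ax2
out = x.reshape(-1)[flat_idx]
assert torch.equal(out, x[idx0, idx1])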
22 changes: 22 additions & 0 deletions py/torch_migraphx/fx/converters/utils.py
@@ -51,6 +51,28 @@ def broadcast_tensors(mgx_module, *tensors):
    return outs


def insert_mbroadcast(mgx_module, ins, dims):
    return mgx_module.add_instruction(
        migraphx.op("multibroadcast", out_lens=dims), [ins])


def normalize_neg_indices(mgx_module, idx_ins, dim_val):
    dtype = get_arg_dtype(idx_ins)
    # find locations of negative indices
    zeros = mgx_module.add_literal(torch.tensor(0, dtype=dtype).numpy())
    zeros = insert_mbroadcast(mgx_module, zeros, idx_ins.shape().lens())
    neg_idx = mgx_module.add_instruction(migraphx.op('less'), [idx_ins, zeros])

    dim_size = mgx_module.add_literal(
        torch.tensor(dim_val, dtype=dtype).numpy())
    dim_size = insert_mbroadcast(mgx_module, dim_size, idx_ins.shape().lens())
    offset_idx = mgx_module.add_instruction(migraphx.op('add'),
                                            [idx_ins, dim_size])

    return mgx_module.add_instruction(migraphx.op('where'),
                                      [neg_idx, offset_idx, idx_ins])
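
A hedged PyTorch equivalent (for illustration only) of what normalize_neg_indices assembles out of MIGraphX less/add/where instructions: wrap negative indices into the valid [0, dim_val) range.

import torch

def normalize_neg_indices_ref(idx, dim_val):
    # where(idx < 0, idx + dim_val, idx), the same select the MIGraphX 'where' performs
    return torch.where(idx < 0, idx + dim_val, idx)

print(normalize_neg_indices_ref(torch.tensor([0, -1, 2, -3]), 5))   # tensor([0, 4, 2, 2])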


def get_arg_dtype(arg):
    if isinstance(arg, migraphx.instruction_ref):
        dtype = torch_dtype_from_mgx(arg.shape().type_string())
16 changes: 8 additions & 8 deletions tests/dynamo/converters/test_gather_dynamo.py
@@ -16,19 +16,19 @@ def forward(self, x, idx):


@pytest.mark.parametrize('op_alias', [torch.ops.aten.gather.default])
@pytest.mark.parametrize("input_dim, dim", [((3, 3), 0),
((3, 3), 1),
((3, 3), -1),
((3, 3), -2),
((10, 5), -2),
((2, 3, 4, 5, 6), -3),
((2, 3, 4, 5, 6), -4)])
@pytest.mark.parametrize("input_dim, dim", [((3, 2), 0),
((3, 2, 4), 1),
((1, 1), -1),
((2, 4, 6), -2),
((4, 2), -2),
((2, 3, 4, 1, 3), -3),
((3, 3, 2, 4, 5), -4)])
def test_gather(op_alias, input_dim, dim):
    input = torch.rand(input_dim).cuda()

    dim_size = input.size(dim)
    index_shape = list(input.size())
    index_shape[dim] = np.random.randint(1, dim_size)
    index_shape[dim] = np.random.randint(1, dim_size*2)
    index = torch.randint(0, dim_size, index_shape).cuda()

    mod = GatherModule(op_alias, dim).cuda()
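
A brief note on why the test can now draw index_shape[dim] up to dim_size*2: torch.gather only constrains the index to be no larger than the input on the non-gather dimensions, so the index may be longer than the input along dim. A tiny illustrative example (values chosen arbitrarily, not part of the diff):

import torch

x = torch.rand(3, 4)
idx = torch.randint(0, 4, (3, 8))    # 8 > x.size(1) is allowed when gathering along dim=1
out = torch.gather(x, 1, idx)        # shape (3, 8)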
2 changes: 1 addition & 1 deletion tests/fx/converters/test_gather_fx.py
@@ -24,7 +24,7 @@ def test_gather(input_dim, dim):

    dim_size = input.size(dim)
    index_shape = list(input.size())
    index_shape[dim] = np.random.randint(1, dim_size)
    index_shape[dim] = np.random.randint(1, dim_size*2)
    index = torch.randint(0, dim_size, index_shape)

    mod = GatherModule(dim)
6 changes: 1 addition & 5 deletions tests/fx/converters/test_loss_fx.py
@@ -26,7 +26,6 @@ def test_nll_loss_fx(inp_size, no_weight, reduction, ignore_index):

@pytest.mark.parametrize('reduction', [('mean'), ('sum'), ('none')])
@pytest.mark.parametrize('C, no_weight, target, ignore_index', [
    (3, True, 0, 0),
    (3, False, 1, -100),
    (3, True, 2, 1),
])
@@ -40,7 +39,4 @@ def test_nll_loss_1d_fx(C, no_weight, reduction, target, ignore_index):
    mod = FuncModule(torch.nn.functional.nll_loss, target=target, weight=weight,
                     reduction = reduction, ignore_index = ignore_index)
    mgx_mod = convert_to_mgx(mod, [inp])
    # Output is nan when ignore_idx == target (div by 0)
    # MIGraphX creates a kernel that ends up outputting a tensor of len 1 instead of a scalar
    # TODO: fused kernels in migraphx should respect the original output shape
    verify_outputs(mod, mgx_mod, [inp], equal_nan=True, scalar=True)
    verify_outputs(mod, mgx_mod, [inp])
