Add torch ops for d2go models #1509

Status: Open. Wants to merge 59 commits into base: main.

Commits (59), from all changes:
c5fe58a  handle-split-op-when-num-splits-1 (dncnbuck, Jun 7, 2022)
5542de8  handle when unpacked tuple contains only single value (dncnbuck, Jun 7, 2022)
fdd1590  roi_align implimentation (dncnbuck, Jun 7, 2022)
7b4cbd9  add torch op numel (dncnbuck, Jun 7, 2022)
b002207  add torch op nms (dncnbuck, Jun 7, 2022)
6389427  add torch op repeat_interleave (dncnbuck, Jun 7, 2022)
cecae9c  add torch op narrow (dncnbuck, Jun 7, 2022)
15185d8  add torch op logicaland (dncnbuck, Jun 7, 2022)
e1b7d0f  handle broadcasting indicies for torch index op (dncnbuck, Jun 7, 2022)
70f1954  patch torch clamp op to handle int dtype (dncnbuck, Jun 7, 2022)
9d2d092  return copy of inpt tensor if no dtype is given (dncnbuck, Jun 7, 2022)
b913630  remove accidential typo (dncnbuck, Jun 8, 2022)
bd08a2b  Merge branch 'main' into add-torch-ops-for-d2go-models (dncnbuck, Jun 8, 2022)
b0074cc  remove logicaland op and alias new logical_and op (dncnbuck, Jun 8, 2022)
a9fb7ed  consistent use of double quotes (dncnbuck, Jun 8, 2022)
29217d5  remove link to crop and resize layer in NN (dncnbuck, Jun 8, 2022)
fb0cd19  Merge branch 'main' into add-torch-ops-for-d2go-models (dncnbuck, Jun 16, 2022)
12662dd  Merge branch 'main' into add-torch-ops-for-d2go-models (dncnbuck, Jun 20, 2022)
b268f9b  6.0b1 Release (#1508) (TobyRoseman, Jun 7, 2022)
c58abbd  Add 6.0b1 install instructions to README.md (#1510) (TobyRoseman, Jun 7, 2022)
df41d90  Update README.md (#1511) (ArjunSharda, Jun 7, 2022)
573f103  remove logicaland op and alias new logical_and op (dncnbuck, Jun 8, 2022)
8834011  consistent use of double quotes (dncnbuck, Jun 8, 2022)
12b3cc1  remove link to crop and resize layer in NN (dncnbuck, Jun 8, 2022)
bdcfe40  Docs for v6 with layer_norm fix (#1514) (tonybove-apple, Jun 8, 2022)
4508f19  Update ---bug-report.md (#1513) (ArjunSharda, Jun 8, 2022)
7944178  Fix a bug when destructing coreml model (#1515) (jakesabathia2, Jun 9, 2022)
3356450  Formatting fixes and compression submenu (#1518) (tonybove-apple, Jun 9, 2022)
203b555  Update CONTRIBUTING.md (#1521) (ArjunSharda, Jun 13, 2022)
01983e6  Add torch AdaptiveAvgPool2d test. (#1502) (fukatani, Jun 13, 2022)
f181995  Update BUILDING.md (#1523) (ArjunSharda, Jun 13, 2022)
9715d07  Update ---feature-request.md (change of wording mostly) (#1524) (ArjunSharda, Jun 13, 2022)
d13735d  Torch eq and ne ops supports bool type. (#1501) (fukatani, Jun 16, 2022)
7ce9f6e  Merge branch 'add-torch-ops-for-d2go-models' of https://github.com/dn… (dncnbuck, Jul 5, 2022)
5d842ec  accept incoming changes (dncnbuck, Jul 5, 2022)
4353c4c  Add tests for numel and narrow (dncnbuck, Jul 6, 2022)
ed2f33e  Add tests for torch.op.nms (dncnbuck, Jul 6, 2022)
bf5de6b  tidy up (dncnbuck, Jul 6, 2022)
b2e8153  tidy up (dncnbuck, Sep 3, 2022)
20da0e2  handle-split-op-when-num-splits-1 (dncnbuck, Jun 7, 2022)
ca4cd92  handle when unpacked tuple contains only single value (dncnbuck, Jun 7, 2022)
c80a3a7  handle broadcasting indicies for torch index op (dncnbuck, Jun 7, 2022)
8631d1b  patch torch clamp op to handle int dtype (dncnbuck, Jun 7, 2022)
2f05538  return copy of inpt tensor if no dtype is given (dncnbuck, Jun 7, 2022)
ed02c4d  remove accidential typo (dncnbuck, Jun 8, 2022)
ec550ca  Docs for v6 with layer_norm fix (#1514) (tonybove-apple, Jun 8, 2022)
78ab5fd  Update ---bug-report.md (#1513) (ArjunSharda, Jun 8, 2022)
f8e1776  Fix a bug when destructing coreml model (#1515) (jakesabathia2, Jun 9, 2022)
d96b7d6  Formatting fixes and compression submenu (#1518) (tonybove-apple, Jun 9, 2022)
c082d4c  Update CONTRIBUTING.md (#1521) (ArjunSharda, Jun 13, 2022)
108f5da  Add torch AdaptiveAvgPool2d test. (#1502) (fukatani, Jun 13, 2022)
47debd3  Update BUILDING.md (#1523) (ArjunSharda, Jun 13, 2022)
e1aaf57  Update ---feature-request.md (change of wording mostly) (#1524) (ArjunSharda, Jun 13, 2022)
1f29b6a  Torch eq and ne ops supports bool type. (#1501) (fukatani, Jun 16, 2022)
f25a684  Add tests for numel and narrow (dncnbuck, Jul 6, 2022)
f2f795b  Add tests for torch.op.nms (dncnbuck, Jul 6, 2022)
9be029f  tidy up (dncnbuck, Jul 6, 2022)
37eef0e  resolve conflict (dncnbuck, Sep 3, 2022)
9e842a2  some code clean up (dncnbuck, Sep 3, 2022)
162 changes: 155 additions & 7 deletions coremltools/converters/mil/frontend/torch/ops.py
@@ -2598,6 +2598,10 @@ def upsample_nearest2d(context, node):
 def tupleunpack(context, node):
     inputs = _get_inputs(context, node, expected=1)
     values = inputs[0]
+
+    if len(node.outputs) == 1:
+        values = [values]
+
     # Node input could have been turned into constant array in @tupleconstruct
     if not isinstance(values, tuple) and not isinstance(values, list):
         values = values.val
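[Note: an illustrative sketch, not part of the PR. A PyTorch pattern that can produce a tuple/list unpack with a single output in the traced graph, which is the case the new `len(node.outputs) == 1` branch covers:]

import torch

class SingleChunkSplit(torch.nn.Module):
    def forward(self, x):
        # Split size equals the dim size, so split returns a one-element tuple
        (chunk,) = torch.split(x, x.shape[0], dim=0)
        return chunk

traced = torch.jit.trace(SingleChunkSplit(), torch.rand(4, 3))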
@@ -3097,8 +3101,11 @@ def index(context, node):
     # For multiple index axes case, we now assume that all the index have equal shape
     for index in valid_indices:
         if not is_compatible_symbolic_vector(index.shape, valid_indices[0].shape):
-            raise NotImplementedError("Broadcasable tensor index not supported.")
+            broadcast_inputs = _broadcast_tensors([valid_indices[0], index])
+            index = broadcast_inputs[1]
+            valid_indices[0] = broadcast_inputs[0]
+        valid_indices.append(index)
 
     # First stack the index together
     indices_rank = valid_indices[0].rank
     indices = mb.stack(values=valid_indices, axis=indices_rank)
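[Note: an illustrative sketch, not part of the PR. Advanced indexing with index tensors of different but broadcastable shapes, which previously raised NotImplementedError and is now handled via _broadcast_tensors:]

import torch

x = torch.rand(4, 5, 6)
rows = torch.tensor([[0], [1]])  # shape (2, 1)
cols = torch.tensor([0, 2, 4])   # shape (3,)
y = x[rows, cols]                # indices broadcast to (2, 3); y has shape (2, 3, 6)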
@@ -3398,6 +3405,18 @@ def _slice(context, node):
     context.add(res)
 
 
+def _num_splits_and_sizes(split_sizes):

[Reviewer comment (Collaborator): Should this just be an inner method of the split method?]

+    if split_sizes.sym_val is not None:
+        return len(split_sizes.sym_val), split_sizes.sym_val
+
+    if any_symbolic(split_sizes.shape):
+        raise ValueError("Unable to determine number of splits")
+
+    num_splits = len(split_sizes.shape)
+    sizes = [get_new_symbol() for _ in range(num_splits)]
+    return num_splits, sizes
+
+
 @register_torch_op(torch_alias=["split_with_sizes"])
 def split(context, node):
     inputs = _get_inputs(context, node, expected=3)
@@ -3425,6 +3444,14 @@ def split(context, node):
     else:
         partial_size = mb.mul(x=tmp, y=remainder)
         split_sizes = mb.concat(values=[whole_sizes, partial_size], axis=0)
+
+
+    num_splits, sizes = _num_splits_and_sizes(split_sizes=split_sizes)
+    if num_splits == 1:
+        out = mb.identity(x=x, name=node.name)
+        context.add(out, node.name)
+        return
+
     res = mb.split(x=x, split_sizes=split_sizes, axis=dim, name=node.name)
     context.add(res, torch_name=node.name)
 
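[Note: an illustrative sketch, not part of the PR. The degenerate split this branch handles: the split size covers the whole dimension, so only one chunk is produced, and the converter now emits mb.identity instead of a one-output mb.split:]

import torch

x = torch.rand(3, 2)
chunks = torch.split(x, 3, dim=0)  # split size == dim size, so a single chunk
assert len(chunks) == 1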
@@ -3482,6 +3509,13 @@ def to(context, node):
             "Received invalid arguments for PyTorch conversion of op {}".format(node)
         )
 
+    # We have to handle the case where the dtype is not set; it should be inferred from the Tensor dtype.
+    # See https://pytorch.org/docs/stable/generated/torch.Tensor.to.html?highlight=#torch.Tensor.to
+    if dtype is None:
+        out = mb.identity(x=_input, name=node.name)
+        context.add(out, node.name)
+        return
+
     torch_dtype = NUM_TO_TORCH_DTYPE[dtype]
     if isinstance(_input, Var) and _input.val is not None:
         _input = _input.val
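[Note: an illustrative sketch, not part of the PR. A dtype-less Tensor.to call, such as a device-only move, where the output keeps the input's dtype and the converter now emits mb.identity:]

import torch

x = torch.rand(2, 2)
y = x.to(torch.device("cpu"))  # no dtype argument: y keeps x's dtype
assert y.dtype == x.dtype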
@@ -3924,8 +3958,20 @@ def ceil(context, node):
 @register_torch_op
 def clamp(context, node):
     inputs = _get_inputs(context, node, expected=3)
-    min_val = inputs[1] if inputs[1] else _np.finfo(_np.float32).min
-    max_val = inputs[2] if inputs[2] else _np.finfo(_np.float32).max
+    if not inputs[1]:
+        min_val = _np.finfo(_np.float32).min
+    else:
+        min_val = inputs[1]
+        if types.builtin_to_string(min_val.dtype).startswith('int'):
+            min_val = mb.cast(x=min_val, dtype='fp32')
+
+    if not inputs[2]:
+        max_val = _np.finfo(_np.float32).max
+    else:
+        max_val = inputs[2]
+        if types.builtin_to_string(max_val.dtype).startswith('int'):
+            max_val = mb.cast(x=max_val, dtype='fp32')
+
     context.add(mb.clip(x=inputs[0], alpha=min_val, beta=max_val, name=node.name))
 
 @register_torch_op
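[Note: an illustrative sketch, not part of the PR. The integer case the patch addresses: clamping with int bounds, which are now cast to fp32 before mb.clip:]

import torch

x = torch.arange(10, dtype=torch.int32)
y = torch.clamp(x, min=2, max=7)  # int tensor with int min/max bounds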
@@ -4074,7 +4120,7 @@ def is_floating_point(context, node):
     is_float = types.is_float(inputs[0].dtype)
     context.add(mb.const(val=is_float, name=node.name))
 
-@register_torch_op()
+@register_torch_op(torch_alias=["__and_", "__and__"])
 def logical_and(context, node):
     inputs = _get_inputs(context, node, expected=2)
     x, y = inputs
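[Note: an illustrative sketch, not part of the PR. The & operator on bool tensors, which TorchScript typically records as aten::__and__ and which the new torch_alias entries route to this logical_and conversion:]

import torch

a = torch.tensor([True, False, True])
b = torch.tensor([True, True, False])
c = a & b  # tensor([True, False, False])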
@@ -4253,6 +4299,11 @@ def _make_tensor(list_of_tensor, name, rank):
         context.add(mb.identity(x=val, name=node.name))
         return
 
+    if inputs[2] is None:
+        res = mb.const(val=[val.val], name=node.name)
+        context.add(res, torch_name=node.name)
+        return
+
     # Case 2: Create a tensor filled with a single value
     val = val.val # element val to fill
     msg_prefix = 'torch::tensor {} '.format(node.name)
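[Note: an illustrative sketch, not part of the PR. The new branch lowers a scalar torch.tensor(...) call, like the one in the numel test further below, directly to a const when the corresponding optional argument is absent:]

import torch

n = torch.numel(torch.rand(2, 3))  # plain Python int
t = torch.tensor(n)                # scalar tensor, dtype inferred from the value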
@@ -4483,7 +4534,6 @@ def _scatter(context, inputs, mode, name):
                       axis=axis, mode=mode, name=name)
     context.add(result)
 
-
 @register_torch_op
 def scatter(context, node):
     inputs = _get_inputs(context, node)
@@ -4501,8 +4551,106 @@ def scatter(context, node):
 
     _scatter(context, inputs, mode, node.name)
 
 @register_torch_op
 def scatter_add(context, node):
     inputs = _get_inputs(context, node)
     _scatter(context, inputs, 'add', node.name)
 
+@register_torch_op
+def roi_align(context, node):

[Reviewer comment (Collaborator): Are there unit tests for this method?]

+    inputs = _get_inputs(context, node)
+
+    x = context[node.inputs[0]]
+    input_shape = x.shape  # (B, h_in, w_in, C)
+    if len(input_shape) != 4:
+        raise ValueError(
+            '"CropResize" op: expected input rank 4, got {}'.format(x.rank)
+        )
+
+    const_box_info = True
+    if context[node.inputs[1]].val is None or context[node.inputs[2]].val is None:
+        const_box_info = False
+
+    extrapolation_value = context[node.inputs[2]].val
+
+    # CoreML index information along with boxes
+    if const_box_info:
+        boxes = context[node.inputs[1]].val
+        # CoreML expects boxes/ROI in
+        # [N, 1, 5, 1, 1] format
+        boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
+    else:
+        boxes = inputs[1]
+        boxes = mb.reshape(x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
+    # Get Height and Width of crop
+    h_out = inputs[3]
+    w_out = inputs[4]
+
+    # Torch input format: [B, C, h_in, w_in]
+    # CoreML input format: [B, C, h_in, w_in]
+
+    # Crop Resize
+    x = mb.crop_resize(
+        x=x,
+        roi=boxes,
+        target_height=h_out.val,
+        target_width=w_out.val,
+        normalized_coordinates=True,
+        spatial_scale=extrapolation_value,
+        box_coordinate_mode="CORNERS_HEIGHT_FIRST",
+        sampling_mode='OFFSET_CORNERS',
+    )
+
+    # CoreML output format: [N, 1, C, h_out, w_out]
+    # Torch output format: [N, C, h_out, w_out]
+    x = mb.squeeze(x=x, axes=[1])
+
+    context.add(x, torch_name=node.outputs[0])
+
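[Note: an illustrative sketch, not part of the PR. The torchvision call this conversion targets; each box row is (batch_index, x1, y1, x2, y2):]

import torch
import torchvision

feat = torch.rand(1, 8, 32, 32)                # (B, C, h_in, w_in) feature map
rois = torch.tensor([[0., 4., 4., 20., 20.]])  # one box in image 0
pooled = torchvision.ops.roi_align(feat, rois, output_size=(7, 7), spatial_scale=1.0)
print(pooled.shape)                            # torch.Size([1, 8, 7, 7])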
+@register_torch_op
+def numel(context, node):
+    inputs = _get_inputs(context, node, expected=1)
+    context.add(mb.reduce_prod(x=inputs[0], name=node.name), torch_name=node.outputs[0])
+
+@register_torch_op
+def nms(context, node):
+    inputs = _get_inputs(context, node)
+    boxes = inputs[0]
+
+    num_boxes = boxes.shape[0]
+    max_boxes = num_boxes  # set max_boxes to the number of input boxes
+
+    scores = inputs[1]
+    iou_threshold = inputs[2]
+    boxes = mb.expand_dims(x=boxes, axes=[0])
+    scores = mb.expand_dims(x=scores, axes=[0, -1])
+
+    # Follow the TensorFlow op example: TensorFlow's default score_threshold is
+    # float('-inf'), which Core ML does not support, so use the minimum float32 instead
+    score_threshold = -3.4e38
+
+    _, _, x, _ = mb.non_maximum_suppression(
+        boxes=boxes,
+        scores=scores,
+        iou_threshold=iou_threshold,
+        score_threshold=score_threshold,
+        max_boxes=max_boxes
+    )
+
+    if not is_symbolic(num_boxes):
+        x = mb.squeeze(x=x, axes=[0])
+        x = mb.slice_by_index(x=x, begin=[0], end=[max_boxes], name=node.name)
+    else:
+        x = mb.squeeze(x=x, axes=[0], name=node.name)
+    context.add(x, torch_name=node.name)
+
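[Note: an illustrative sketch, not part of the PR. The torchvision call this conversion targets; it returns the indices of the boxes kept after suppression, sorted by descending score:]

import torch
import torchvision

boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [20., 20., 30., 30.]])  # (x1, y1, x2, y2)
scores = torch.tensor([0.9, 0.8, 0.7])
keep = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)  # box 1 overlaps box 0, so e.g. tensor([0, 2])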
+@register_torch_op
+def narrow(context, node):
+    data, dim, start, length = _get_inputs(context, node, expected=4)
+    data_shape = mb.shape(x=data).val
+    begin = [0] * len(data_shape)
+    end = [x for x in data_shape]
+    begin[dim.val] = start.val
+    end[dim.val] = start.val + length.val
+    out = mb.slice_by_index(x=data, begin=begin, end=end)
+    context.add(out, torch_name=node.name)
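[Note: an illustrative sketch, not part of the PR. The torch.narrow call this conversion lowers to mb.slice_by_index over the static shape:]

import torch

x = torch.arange(9).reshape(3, 3)
y = torch.narrow(x, dim=0, start=0, length=2)  # first two rows, shape (2, 3)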
75 changes: 75 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -10,6 +10,8 @@
 import pytest
 import torch.nn as nn
 
+import torchvision
+
 from .testing_utils import (
     contains_op,
     generate_input_data,
@@ -4564,3 +4566,76 @@ def forward(self, x):
             backend=backend,
             converter_input_type=converter_input_type,
         )

+class TestNumel(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "shapes, backend",
+        itertools.product(
+            [
+                [(2, 1)],
+                [(5, 1, 4, 1)],
+                [(1,)],
+            ],
+            backends
+        ),
+    )
+    def test_numel(self, shapes, backend):
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                v = torch.numel(x)
+                return torch.tensor(v)
+
+        model = Model()
+        self.run_compare_torch(shapes, model, backend=backend)
+
+
+class TestNarrow(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "shapes, dim_start_length, backend",
+        itertools.product(
+            [
+                [(3, 3)],
+            ],
+            [
+                (0, 0, 2)
+            ],
+            backends
+        ),
+    )
+    def test_narrow(self, shapes, dim_start_length, backend):
+        dim, start, length = dim_start_length
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.narrow(x, dim, start, length)
+
+        model = Model()
+        self.run_compare_torch(shapes, model, backend=backend)
+
+
+class TestNonMaximalSuppression(TorchBaseTest):
+    @pytest.mark.parametrize(
+        "shapes, scores, backend",
+        itertools.product(
+            [[(2, 4)]],
+            [(2,)],
+            backends
+        ),
+    )
+    def test_non_maximal_supression(self, shapes, scores, backend):
+        scores = torch.rand(scores)
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torchvision.ops.nms(x, scores, iou_threshold=0.7)
+
+        model = Model()
+        self.run_compare_torch(shapes, model, backend=backend)