-
Notifications
You must be signed in to change notification settings - Fork 647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add torch ops for d2go models #1509
Open
dncnbuck
wants to merge
59
commits into
apple:main
Choose a base branch
from
dncnbuck:add-torch-ops-for-d2go-models
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
59 commits
Select commit
Hold shift + click to select a range
c5fe58a
handle-split-op-when-num-splits-1
dncnbuck 5542de8
handle when unpacked tuple contains only single value
dncnbuck fdd1590
roi_align implimentation
dncnbuck 7b4cbd9
add torch op numel
dncnbuck b002207
add torch op nms
dncnbuck 6389427
add torch op repeat_interleave
dncnbuck cecae9c
add torch op narrow
dncnbuck 15185d8
add torch op logicaland
dncnbuck e1b7d0f
handle broadcasting indicies for torch index op
dncnbuck 70f1954
patch torch clamp op to handle int dtype
dncnbuck 9d2d092
return copy of inpt tensor if no dtype is given
dncnbuck b913630
remove accidential typo
dncnbuck bd08a2b
Merge branch 'main' into add-torch-ops-for-d2go-models
dncnbuck b0074cc
remove logicaland op and alias new logical_and op
dncnbuck a9fb7ed
consistent use of double quotes
dncnbuck 29217d5
remove link to crop and resize layer in NN
dncnbuck fb0cd19
Merge branch 'main' into add-torch-ops-for-d2go-models
dncnbuck 12662dd
Merge branch 'main' into add-torch-ops-for-d2go-models
dncnbuck b268f9b
6.0b1 Release (#1508)
TobyRoseman c58abbd
Add 6.0b1 install instructions to README.md (#1510)
TobyRoseman df41d90
Update README.md (#1511)
ArjunSharda 573f103
remove logicaland op and alias new logical_and op
dncnbuck 8834011
consistent use of double quotes
dncnbuck 12b3cc1
remove link to crop and resize layer in NN
dncnbuck bdcfe40
Docs for v6 with layer_norm fix (#1514)
tonybove-apple 4508f19
Update ---bug-report.md (#1513)
ArjunSharda 7944178
Fix a bug when destructing coreml model (#1515)
jakesabathia2 3356450
Formatting fixes and compression submenu (#1518)
tonybove-apple 203b555
Update CONTRIBUTING.md (#1521)
ArjunSharda 01983e6
Add torch AdaptiveAvgPool2d test. (#1502)
fukatani f181995
Update BUILDING.md (#1523)
ArjunSharda 9715d07
Update ---feature-request.md (change of wording mostly) (#1524)
ArjunSharda d13735d
Torch eq and ne ops supports bool type. (#1501)
fukatani 7ce9f6e
Merge branch 'add-torch-ops-for-d2go-models' of https://github.com/dn…
dncnbuck 5d842ec
accept incoming changes
dncnbuck 4353c4c
Add tests for numel and narrow
dncnbuck ed2f33e
Add tests for torch.op.nms
dncnbuck bf5de6b
tidy up
dncnbuck b2e8153
tidy up
dncnbuck 20da0e2
handle-split-op-when-num-splits-1
dncnbuck ca4cd92
handle when unpacked tuple contains only single value
dncnbuck c80a3a7
handle broadcasting indicies for torch index op
dncnbuck 8631d1b
patch torch clamp op to handle int dtype
dncnbuck 2f05538
return copy of inpt tensor if no dtype is given
dncnbuck ed02c4d
remove accidential typo
dncnbuck ec550ca
Docs for v6 with layer_norm fix (#1514)
tonybove-apple 78ab5fd
Update ---bug-report.md (#1513)
ArjunSharda f8e1776
Fix a bug when destructing coreml model (#1515)
jakesabathia2 d96b7d6
Formatting fixes and compression submenu (#1518)
tonybove-apple c082d4c
Update CONTRIBUTING.md (#1521)
ArjunSharda 108f5da
Add torch AdaptiveAvgPool2d test. (#1502)
fukatani 47debd3
Update BUILDING.md (#1523)
ArjunSharda e1aaf57
Update ---feature-request.md (change of wording mostly) (#1524)
ArjunSharda 1f29b6a
Torch eq and ne ops supports bool type. (#1501)
fukatani f25a684
Add tests for numel and narrow
dncnbuck f2f795b
Add tests for torch.op.nms
dncnbuck 9be029f
tidy up
dncnbuck 37eef0e
resolve conflict
dncnbuck 9e842a2
some code clean up
dncnbuck File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2598,6 +2598,10 @@ def upsample_nearest2d(context, node): | |
def tupleunpack(context, node): | ||
inputs = _get_inputs(context, node, expected=1) | ||
values = inputs[0] | ||
|
||
if len(node.outputs) == 1: | ||
values = [values] | ||
|
||
# Node input could have been turned into constant array in @tupleconstruct | ||
if not isinstance(values, tuple) and not isinstance(values, list): | ||
values = values.val | ||
|
@@ -3097,8 +3101,11 @@ def index(context, node): | |
# For multiple index axes case, we now assume that all the index have equal shape | ||
for index in valid_indices: | ||
if not is_compatible_symbolic_vector(index.shape, valid_indices[0].shape): | ||
raise NotImplementedError("Broadcasable tensor index not supported.") | ||
|
||
broadcast_inputs = _broadcast_tensors([valid_indices[0], index]) | ||
index = broadcast_inputs[1] | ||
valid_indices[0] = broadcast_inputs[0] | ||
valid_indices.append(index) | ||
|
||
# First stack the index together | ||
indices_rank = valid_indices[0].rank | ||
indices = mb.stack(values=valid_indices, axis=indices_rank) | ||
|
@@ -3398,6 +3405,18 @@ def _slice(context, node): | |
context.add(res) | ||
|
||
|
||
def _num_splits_and_sizes(split_sizes): | ||
if split_sizes.sym_val is not None: | ||
return len(split_sizes.sym_val), split_sizes.sym_val | ||
|
||
if any_symbolic(split_sizes.shape): | ||
raise ValueError("Unable to determine number of splits") | ||
|
||
num_splits = len(split_sizes.shape) | ||
sizes = [get_new_symbol() for _ in range(num_splits)] | ||
return num_splits, sizes | ||
|
||
|
||
@register_torch_op(torch_alias=["split_with_sizes"]) | ||
def split(context, node): | ||
inputs = _get_inputs(context, node, expected=3) | ||
|
@@ -3425,6 +3444,14 @@ def split(context, node): | |
else: | ||
partial_size = mb.mul(x=tmp, y=remainder) | ||
split_sizes = mb.concat(values=[whole_sizes, partial_size], axis=0) | ||
|
||
|
||
num_splits, sizes = _num_splits_and_sizes(split_sizes=split_sizes) | ||
if num_splits == 1: | ||
out = mb.identity(x=x, name=node.name) | ||
context.add(out, node.name) | ||
return | ||
|
||
res = mb.split(x=x, split_sizes=split_sizes, axis=dim, name=node.name) | ||
context.add(res, torch_name=node.name) | ||
|
||
|
@@ -3482,6 +3509,13 @@ def to(context, node): | |
"Received invalid arguments for PyTorch conversion of op {}".format(node) | ||
) | ||
|
||
# We have to handle the case where the dtype is not set, this should be inferred from the Tensor dtype | ||
# see, https://pytorch.org/docs/stable/generated/torch.Tensor.to.html?highlight=#torch.Tensor.to | ||
if dtype is None: | ||
out = mb.identity(x=_input, name=node.name) | ||
context.add(out, node.name) | ||
return | ||
|
||
torch_dtype = NUM_TO_TORCH_DTYPE[dtype] | ||
if isinstance(_input, Var) and _input.val is not None: | ||
_input = _input.val | ||
|
@@ -3924,8 +3958,20 @@ def ceil(context, node): | |
@register_torch_op | ||
def clamp(context, node): | ||
inputs = _get_inputs(context, node, expected=3) | ||
min_val = inputs[1] if inputs[1] else _np.finfo(_np.float32).min | ||
max_val = inputs[2] if inputs[2] else _np.finfo(_np.float32).max | ||
if not inputs[1]: | ||
min_val = _np.finfo(_np.float32).min | ||
else: | ||
min_val = inputs[1] | ||
if types.builtin_to_string(min_val.dtype).startswith('int'): | ||
min_val = mb.cast(x=min_val, dtype='fp32') | ||
|
||
if not inputs[2]: | ||
max_val = _np.finfo(_np.float32).max | ||
else: | ||
max_val = inputs[2] | ||
if types.builtin_to_string(max_val.dtype).startswith('int'): | ||
max_val = mb.cast(x=max_val, dtype='fp32') | ||
|
||
context.add(mb.clip(x=inputs[0], alpha=min_val, beta=max_val, name=node.name)) | ||
|
||
@register_torch_op | ||
|
@@ -4074,7 +4120,7 @@ def is_floating_point(context, node): | |
is_float = types.is_float(inputs[0].dtype) | ||
context.add(mb.const(val=is_float, name=node.name)) | ||
|
||
@register_torch_op() | ||
@register_torch_op(torch_alias=["__and_", "__and__"]) | ||
def logical_and(context, node): | ||
inputs = _get_inputs(context, node, expected=2) | ||
x, y = inputs | ||
|
@@ -4253,6 +4299,11 @@ def _make_tensor(list_of_tensor, name, rank): | |
context.add(mb.identity(x=val, name=node.name)) | ||
return | ||
|
||
if inputs[2] is None: | ||
res = mb.const(val=[val.val], name=node.name) | ||
context.add(res, torch_name=node.name) | ||
return | ||
|
||
# Case 2: Create a tensor filled with a single value | ||
val = val.val # element val to fill | ||
msg_prefix = 'torch::tensor {} '.format(node.name) | ||
|
@@ -4483,7 +4534,6 @@ def _scatter(context, inputs, mode, name): | |
axis=axis, mode=mode, name=name) | ||
context.add(result) | ||
|
||
|
||
@register_torch_op | ||
def scatter(context, node): | ||
inputs = _get_inputs(context, node) | ||
|
@@ -4501,8 +4551,106 @@ def scatter(context, node): | |
|
||
_scatter(context, inputs, mode, node.name) | ||
|
||
|
||
@register_torch_op | ||
def scatter_add(context, node): | ||
inputs = _get_inputs(context, node) | ||
_scatter(context, inputs, 'add', node.name) | ||
|
||
@register_torch_op | ||
def roi_align(context, node): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are there unit tests for this method? |
||
inputs = _get_inputs(context, node) | ||
|
||
x = context[node.inputs[0]] | ||
input_shape = x.shape # (B, h_in, w_in, C) | ||
if len(input_shape) != 4: | ||
raise ValueError( | ||
'"CropResize" op: expected input rank 4, got {}'.format(x.rank) | ||
) | ||
|
||
const_box_info = True | ||
if context[node.inputs[1]].val is None or context[node.inputs[2]].val is None: | ||
const_box_info = False | ||
|
||
extrapolation_value = context[node.inputs[2]].val | ||
|
||
# CoreML index information along with boxes | ||
if const_box_info: | ||
boxes = context[node.inputs[1]].val | ||
# CoreML expects boxes/ROI in | ||
# [N, 1, 5, 1, 1] format | ||
boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1) | ||
else: | ||
boxes = inputs[1] | ||
boxes = mb.reshape(x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1]) | ||
# Get Height and Width of crop | ||
h_out = inputs[3] | ||
w_out = inputs[4] | ||
|
||
# Torch input format: [B, C, h_in, w_in] | ||
# CoreML input format: [B, C, h_in, w_in] | ||
|
||
# Crop Resize | ||
x = mb.crop_resize( | ||
x=x, | ||
roi=boxes, | ||
target_height=h_out.val, | ||
target_width=w_out.val, | ||
normalized_coordinates=True, | ||
spatial_scale=extrapolation_value, | ||
box_coordinate_mode="CORNERS_HEIGHT_FIRST", | ||
sampling_mode='OFFSET_CORNERS', | ||
) | ||
|
||
# CoreML output format: [N, 1, C, h_out, w_out] | ||
# Torch output format: [N, C, h_out, w_out] | ||
x = mb.squeeze(x=x, axes=[1]) | ||
|
||
context.add(x, torch_name=node.outputs[0]) | ||
|
||
@register_torch_op | ||
def numel(context, node): | ||
inputs = _get_inputs(context, node, expected=1) | ||
context.add(mb.reduce_prod(x=inputs[0], name=node.name), torch_name=node.outputs[0]) | ||
|
||
@register_torch_op | ||
def nms(context, node): | ||
inputs = _get_inputs(context, node) | ||
boxes = inputs[0] | ||
|
||
num_boxes = boxes.shape[0] | ||
max_boxes = num_boxes # we set the max_boxes just to be # input boxes | ||
|
||
scores = inputs[1] | ||
iou_threshold = inputs[2] | ||
boxes = mb.expand_dims(x=boxes, axes=[0]) | ||
scores = mb.expand_dims(x=scores, axes=[0, -1]) | ||
|
||
# Follow tensorflow op example: TensorFlow's default value for score_threshold, Core ML does not | ||
# have float('-inf') support, converted to minimum float32 instead | ||
score_threshold = -3.4e38 | ||
|
||
_, _, x, _ = mb.non_maximum_suppression( | ||
boxes=boxes, | ||
scores=scores, | ||
iou_threshold=iou_threshold, | ||
score_threshold=score_threshold, | ||
max_boxes=max_boxes | ||
) | ||
|
||
if not is_symbolic(num_boxes): | ||
x = mb.squeeze(x=x, axes=[0]) | ||
x = mb.slice_by_index(x=x, begin=[0], end=[max_boxes], name=node.name) | ||
else: | ||
x = mb.squeeze(x=x, axes=[0], name=node.name) | ||
context.add(x, torch_name=node.name) | ||
|
||
@register_torch_op | ||
def narrow(context, node): | ||
data, dim, start, length = _get_inputs(context, node, expected=4) | ||
data_shape = mb.shape(x=data).val | ||
begin = [0]*len(data_shape) | ||
end = [x for x in data_shape] | ||
begin[dim.val] = start.val | ||
end[dim.val] = start.val+length.val | ||
out = mb.slice_by_index(x=data, begin=begin, end=end) | ||
context.add(out, torch_name=node.name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this just be an inner method of the
split
method?