-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Aten roi align #209
base: master
Are you sure you want to change the base?
Aten roi align #209
Conversation
…must also equal K. Converter still does not work for aligned=True
…ap which demonstrates corrupted input.
…t output height and width (original tests were symmetrical). Still failing some tests
# batch_indices = batch_indices2 | ||
elif boxes_ref.shape().lens()[1] == 4: | ||
# batch_indices3=range(boxes_ref.shape().lens()[0]) | ||
# boxes2 = boxes_ref |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clean up comments please, this just looks like unused debug code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See below; some of this code is necessary for this case if we decide to support it. If you pass a list of tensors, you must supply default indices [0, 1, 2...]. If we put it off for the future, I can make it into a tidier looking TODO comment.
py/torch_migraphx/fx/mgx_module.py
Outdated
@@ -100,7 +100,6 @@ def _initialize(self): | |||
if not self.program.is_compiled(): | |||
if self.quantize_fp16: | |||
migraphx.quantize_fp16(self.program) | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
revert
# batch_indices3=range(boxes_ref.shape().lens()[0]) | ||
# boxes2 = boxes_ref | ||
# This isn't supported at this time because torchvision roi_align.default() doesn't support it |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Clean up debug comments. Also can we change the isnt supported comment to a TODO
comment. Its not supported at this time because we are having trouble consuming a list input in the test cases atm. "torchvision roi_align.default() doesn't support it" is not a valid reason not to support it in the acc converter.
@@ -33,6 +33,7 @@ | |||
|
|||
import torch | |||
from typing import cast, Iterable, List, Sequence | |||
import torchvision |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
import torchvision | |
try: | |
import torchvision | |
except ImportError: | |
pass |
This is not a mandatory prerequisite for torch_migraphx so wrap it in a try/except
@@ -373,6 +374,27 @@ def clamp(*, input, min=None, max=None): | |||
return torch.clamp(input=input, min=min, max=max) | |||
|
|||
|
|||
@register_acc_op_mapping( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@register_acc_op_mapping( | |
if 'torchvision' in sys.modules: | |
@register_acc_op_mapping( |
wrap the roi_align function definition in this if so that it doesnt blow up if torchvision isnt installed
]), | ||
[7, 6])] | ||
) | ||
def test_roialign(input, boxes, output_size, spatial_scale, sampling_ratio, aligned): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you move this to test_torchvision_models_fx.py
and follow the example there to skip if torchvision is not installed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved to test_torchvision_ops_fx.py
[7, 6]) | ||
] | ||
) | ||
def test_roialign(op_alias, input, boxes, output_size, spatial_scale, sampling_ratio, aligned): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the fx test comment, this should be in a test_torchvision_models_dynamo.py
file, and it should mark to skip in the same way if torchvision is not installed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
created test_torchvision_ops_dynamo.py
…erent test files, other cleanup
Add aten and acc converters for the RoiAlign op.
This change is dependent on Migraphx PR 3482. It can't be tested or merged until that branch is merged. (Test locally using MigraphX branch
roialign_fix
As of date of PR creation, the aten converter test
test_roialign_d
is still incomplete.