-
Notifications
You must be signed in to change notification settings - Fork 57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Upsample #115
Upsample #115
Changes from 12 commits
fa50d1d
19b048a
8f519c0
99bbe85
600c754
aa5ebec
43d58d8
5d53f8b
5bc36b3
9921747
22ab3e5
b4c1d47
07d06f6
ed9da8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3645,3 +3645,20 @@ class SizeAverageMatcher(BaseMatcher): | |
def generate_code(self, kwargs): | ||
process_reduce_and_size_average(kwargs) | ||
return GenericMatcher.generate_code(self, kwargs) | ||
|
||
|
||
class UtilsCppExtensionMatcher(BaseMatcher): | ||
def generate_code(self, kwargs): | ||
new_kwargs = {} | ||
for k in kwargs.keys(): | ||
if "name" in k: | ||
continue | ||
new_kwargs[k] = kwargs[k] | ||
return GenericMatcher.generate_code(self, new_kwargs) | ||
|
||
|
||
class TensorIsSpareMatcher(BaseMatcher): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可以起个公用一些的名字,这样其他人也可以复用了 就叫Attribute2Func吧 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
def get_paddle_class_attribute_nodes(self, node): | ||
self.parse_func(node) | ||
code = "{}()".format(self.paddle_api) | ||
return ast.parse(code).body[0].value |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.Tensor.is_sparse") | ||
|
||
|
||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
a = torch.tensor([[ 0.9254, -0.6213]]) | ||
result = a.is_sparse | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
test_case_1() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个需要删掉 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.nn.Upsample") | ||
|
||
|
||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], | ||
[-1.2533, -0.9829, -1.0981], | ||
[ 0.1507, -1.1431, -2.0361]], | ||
|
||
[[ 0.1024, -0.4482, 0.4137], | ||
[ 0.9385, 0.4565, 0.7702], | ||
[ 0.4135, -0.2587, 0.0482]]]]) | ||
m = torch.nn.Upsample(scale_factor=2, mode='nearest') | ||
result = m(input) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_2(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], | ||
[-1.2533, -0.9829, -1.0981], | ||
[ 0.1507, -1.1431, -2.0361]], | ||
|
||
[[ 0.1024, -0.4482, 0.4137], | ||
[ 0.9385, 0.4565, 0.7702], | ||
[ 0.4135, -0.2587, 0.0482]]]]) | ||
m = torch.nn.Upsample(scale_factor=2, mode='bilinear') | ||
result = m(input) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_3(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], | ||
[-1.2533, -0.9829, -1.0981], | ||
[ 0.1507, -1.1431, -2.0361]], | ||
|
||
[[ 0.1024, -0.4482, 0.4137], | ||
[ 0.9385, 0.4565, 0.7702], | ||
[ 0.4135, -0.2587, 0.0482]]]]) | ||
m = torch.nn.Upsample(scale_factor=2, mode='bilinear',align_corners=True) | ||
result = m(input) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_4(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], | ||
[-1.2533, -0.9829, -1.0981], | ||
[ 0.1507, -1.1431, -2.0361]], | ||
|
||
[[ 0.1024, -0.4482, 0.4137], | ||
[ 0.9385, 0.4565, 0.7702], | ||
[ 0.4135, -0.2587, 0.0482]]]]) | ||
m = torch.nn.Upsample(size=(2,2)) | ||
result = m(input) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_5(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], | ||
[-1.2533, -0.9829, -1.0981], | ||
[ 0.1507, -1.1431, -2.0361]], | ||
|
||
[[ 0.1024, -0.4482, 0.4137], | ||
[ 0.9385, 0.4565, 0.7702], | ||
[ 0.4135, -0.2587, 0.0482]]]]) | ||
m = torch.nn.Upsample(scale_factor=2, mode='bilinear',align_corners=False) | ||
result = m(input) | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_6(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
import torch | ||
input = torch.tensor([[[[ 1.1524, 0.4714, 0.2857], | ||
[-1.2533, -0.9829, -1.0981], | ||
[ 0.1507, -1.1431, -2.0361]], | ||
|
||
[[ 0.1024, -0.4482, 0.4137], | ||
[ 0.9385, 0.4565, 0.7702], | ||
[ 0.4135, -0.2587, 0.0482]]]]) | ||
m = torch.nn.Upsample(scale_factor=2, mode='bilinear',recompute_scale_factor=True) | ||
result = m(input) | ||
""" | ||
) | ||
obj.run(pytorch_code, unsupport=True, reason="paddle unsupport") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个原因要写具体一点 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.utils.cpp_extension.CUDAExtension") | ||
|
||
|
||
# The cuda compile not supports | ||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
from torch.utils.cpp_extension import CUDAExtension | ||
|
||
CUDAExtension( | ||
name='cuda_extension', | ||
sources=['extension.cpp', 'extension_kernel.cu'], | ||
extra_compile_args={'cxx': ['-g'], | ||
'nvcc': ['-O2']}) | ||
result = True | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import textwrap | ||
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.utils.cpp_extension.CppExtension") | ||
|
||
|
||
# The cpp compile not supports | ||
def test_case_1(): | ||
pytorch_code = textwrap.dedent( | ||
""" | ||
from torch.utils.cpp_extension import CppExtension | ||
|
||
CppExtension( | ||
name='cuda_extension', | ||
sources=['extension.cpp'], | ||
extra_compile_args=['-g']) | ||
result = True | ||
""" | ||
) | ||
obj.run(pytorch_code, ["result"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个直接用genericamatcher,如果你要删掉某个参数,可设置kwargs_change: name: ""
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done