Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update pytorch api_doc.rst #63

Merged
merged 1 commit into from
Dec 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 208 additions & 2 deletions sources/pytorch/api_doc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ PyTorch-NPU 除了提供了 PyTorch 官方算子实现之外,也提供了大
:param Int adam_mode: 选择adam模式。0表示“adam”, 1表示“mbert_adam”, 默认值为0

关键字参数:
out (Tensor,可选) - 输出张量。
out (Tensor,可选) - 输出张量。

示例:

Expand Down Expand Up @@ -309,4 +309,210 @@ PyTorch-NPU 除了提供了 PyTorch 官方算子实现之外,也提供了大
>>> output = torch_npu.npu_bounding_box_decode(rois, deltas, 0, 0, 0, 0, 1, 1, 1, 1, (10, 10), 0.1)
>>> output
tensor([[2.5000, 6.5000, 9.0000, 9.0000],
[9.0000, 9.0000, 9.0000, 9.0000]], device='npu:0')
[9.0000, 9.0000, 9.0000, 9.0000]], device='npu:0')

.. py:function:: npu_broadcast(self, size) -> Tensor
:module: torch_npu

返回self张量的新视图,其单维度扩展,结果连续。

:param Tensor self: 输入张量。
:param ListInt size: 对应扩展尺寸。

:rtype: Tensor

示例:

.. code-block:: python
:linenos:

>>> x = torch.tensor([[1], [2], [3]]).npu()
>>> x.shape
torch.Size([3, 1])
>>> x.npu_broadcast(3, 4)
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]], device='npu:0')

.. py:function:: npu_ciou(Tensor self, Tensor gtboxes, bool trans=False, bool is_cross=True, int mode=0, bool atan_sub_flag=False) -> Tensor
:module: torch_npu

应用基于NPU的CIoU操作。在DIoU的基础上增加了penalty item,并propose CIoU。

:param Tensor boxes1: 格式为xywh、shape为(4, n)的预测检测框。
:param Tensor boxes2: 相应的gt检测框,shape为(4, n)。
:param Bool trans: 是否有偏移。
:param Bool is_cross: box1和box2之间是否有交叉操作。
:param Int mode: 选择CIoU的计算方式。0表示IoU,1表示IoF。
:param Bool atan_sub_flag:是否将正向的第二个值传递给反向。

:rtype: Tensor

约束说明:
到目前为止,CIoU向后只支持当前版本中的trans==True、is_cross==False、mode==0('iou')。如果需要反向传播,确保参数正确。

示例:

.. code-block:: python
:linenos:

>>> box1 = torch.randn(4, 32).npu()
>>> box1.requires_grad = True
>>> box2 = torch.randn(4, 32).npu()
>>> box2.requires_grad = True
>>> diou = torch_npu.contrib.function.npu_ciou(box1, box2)
>>> l = ciou.sum()
>>> l.backward()

.. py:function:: npu_clear_float_status(self) -> Tensor
:module: torch_npu

在每个核中设置地址0x40000的值为0。

:param Tensor self: 数据类型为float32的张量。

:rtype: Tensor

示例:

.. code-block:: python
:linenos:

>>> x = torch.rand(2).npu()
>>> torch_npu.npu_clear_float_status(x)
tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='npu:0')

.. py:function:: npu_confusion_transpose(self, perm, shape, transpose_first) -> Tensor
:module: torch_npu

混淆reshape和transpose运算。

:param Tensor self: 数据类型:float16、float32、int8、int16、int32、int64、uint8、uint16、uint32、uint64。
:param ListInt perm: self张量的维度排列。
:param ListInt shape: 输入shape。
:param Bool transpose_first: 如果值为True,首先执行transpose,否则先执行reshape。

:rtype: Tensor

示例:

.. code-block:: python
:linenos:

>>> x = torch.rand(2, 3, 4, 6).npu()
>>> x.shape
torch.Size([2, 3, 4, 6])
>>> y = torch_npu.npu_confusion_transpose(x, (0, 2, 1, 3), (2, 4, 18), True)
>>> y.shape
torch.Size([2, 4, 18])
>>> y2 = torch_npu.npu_confusion_transpose(x, (0, 2, 1), (2, 12, 6), False)
>>> y2.shape
torch.Size([2, 6, 12])

.. py:function:: npu_conv2d(input, weight, bias, stride, padding, dilation, groups) -> Tensor
:module: torch_npu

在由多个输入平面组成的输入图像上应用一个2D卷积。

:param Tensor input: shape的输入张量,值为 (minibatch, in_channels, iH, iW)。
:param Tensor weight: shape过滤器,值为 (out_channels, in_channels/groups, kH, kW)。
:param Tensor bias: shape偏差 (out_channels)。
:param ListInt stride: 卷积核步长。
:param ListInt padding: 输入两侧的隐式填充。
:param ListInt dilation: 内核元素间距。
:param Int groups: 对输入进行分组。In_channels可被组数整除。

:rtype: Tensor

.. py:function:: npu_conv3d(input, weight, bias, stride, padding, dilation, groups) -> Tensor
:module: torch_npu

在由多个输入平面组成的输入图像上应用一个3D卷积。

:param Tensor input: shape的输入张量,值为 (minibatch, in_channels, iT, iH, iW)。
:param Tensor weight: shape过滤器,值为 (out_channels, in_channels/groups, kT, kH, kW)。
:param Tensor bias: shape偏差 (out_channels)。
:param ListInt stride: 卷积核步长。
:param ListInt padding: 输入两侧的隐式填充。
:param ListInt dilation: 内核元素间距。
:param Int groups: 对输入进行分组。In_channels可被组数整除。

:rtype: Tensor

.. py:function:: npu_conv_transpose2d(input, weight, bias, padding, output_padding, stride, dilation, groups) -> Tensor
:module: torch_npu

在由多个输入平面组成的输入图像上应用一个2D转置卷积算子,有时这个过程也被称为“反卷积”。

:param Tensor input: shape的输入张量,值为 (minibatch, in_channels, iH, iW)。
:param Tensor weight: shape过滤器,值为 (in_channels, out_channels/groups, kH, kW)。
:param Tensor bias: shape偏差 (out_channels)。
:param ListInt padding: (dilation * (kernel_size - 1) - padding) 用零来填充输入每个维度的两侧。
:param ListInt output_padding: 添加到输出shape每个维度一侧的附加尺寸。
:param ListInt stride: 卷积核步长。
:param ListInt dilation: 内核元素间距。
:param Int groups: 对输入进行分组。In_channels可被组数整除。

:rtype: Tensor

.. py:function:: npu_convolution(input, weight, bias, stride, padding, dilation, groups) -> Tensor
:module: torch_npu

在由多个输入平面组成的输入图像上应用一个2D或3D卷积。

:param Tensor input: shape的输入张量,值为 (minibatch, in_channels, iH, iW) 或 (minibatch, in_channels, iT, iH, iW)。
:param Tensor weight: shape过滤器,值为 (out_channels, in_channels/groups, kH, kW) 或 (out_channels, in_channels/groups, kT, kH, kW)。
:param Tensor bias: shape偏差 (out_channels)。
:param ListInt stride: 卷积核步长。
:param ListInt padding: 输入两侧的隐式填充。
:param ListInt dilation: 内核元素间距。
:param Int groups: 对输入进行分组。In_channels可被组数整除。

:rtype: Tensor

.. py:function:: npu_convolution_transpose(input, weight, bias, padding, output_padding, stride, dilation, groups) -> Tensor
:module: torch_npu

在由多个输入平面组成的输入图像上应用一个2D或3D转置卷积算子,有时这个过程也被称为“反卷积”。

:param Tensor input: shape的输入张量,值为 (minibatch, in_channels, iH, iW) 或 (minibatch, in_channels, iT, iH, iW)。
:param Tensor weight: shape过滤器,值为 (in_channels, out_channels/groups, kH, kW) 或 (in_channels, out_channels/groups, kT, kH, kW)。
:param Tensor bias: shape偏差 (out_channels)。
:param ListInt padding: (dilation * (kernel_size - 1) - padding) 用零来填充输入每个维度的两侧。
:param ListInt output_padding: 添加到输出shape每个维度一侧的附加尺寸。
:param ListInt stride: 卷积核步长。
:param ListInt dilation: 内核元素间距。
:param Int groups: 对输入进行分组。In_channels可被组数整除。

:rtype: Tensor

.. py:function:: npu_deformable_conv2d(self, weight, offset, bias, kernel_size, stride, padding, dilation=[1,1,1,1], groups=1, deformable_groups=1, modulated=True) -> (Tensor, Tensor)
:module: torch_npu

使用预期输入计算变形卷积输出(deformed convolution output)。

:param Tensor self: 输入图像的4D张量。格式为“NHWC”,数据按以下顺序存储:[batch, in_height, in_width, in_channels]。
:param Tensor weight: 可学习过滤器的4D张量。数据类型需与self相同。格式为“HWCN”,数据按以下顺序存储:[filter_height, filter_width, in_channels / groups, out_channels]。
:param Tensor offset: x-y坐标偏移和掩码的4D张量。格式为“NHWC”,数据按以下顺序存储:[batch, out_height, out_width, deformable_groups * filter_height * filter_width * 3]。bias (Tensor,可选) - 过滤器输出附加偏置(additive bias)的1D张量,数据按[out_channels]的顺序存储。
:param ListInt[2] kernel_size: 内核大小,2个整数的元组/列表。
:param ListInt stride: 4个整数的列表,表示每个输入维度的滑动窗口步长。维度顺序根据self的数据格式解释。N维和C维必须设置为1。
:param ListInt padding: 4个整数的列表,表示要添加到输入每侧(顶部、底部、左侧、右侧)的像素数。

:param ListInt dilation: 4个整数的列表,表示输入每个维度的膨胀系数(dilation factor)。维度顺序根据self的数据格式解释。N维和C维必须设置为1。
:param Int groups: int32类型单整数,表示从输入通道到输出通道的阻塞连接数。In_channels和out_channels需都可被“groups”数整除。
:param Int deformable_groups: int32类型单整数,表示可变形组分区的数量。In_channels需可被“deformable_groups”数整除。
:param Bool transpose_first: 默认值为True, 指定DeformableConv2D版本。True表示v2版本, False表示v1版本,目前仅支持v2。

:rtype: (Tensor, Tensor)

示例:

.. code-block:: python
:linenos:

>>> x = torch.rand(16, 32, 32, 32).npu()
>>> weight = torch.rand(32, 32, 5, 5).npu()
>>> offset = torch.rand(16, 75, 32, 32).npu()
>>> output, _ = torch_npu.npu_deformable_conv2d(x, weight, offset, None, kernel_size=[5, 5], stride = [1, 1, 1, 1], padding = [2, 2, 2, 2])
>>> output.shape
torch.Size([16, 32, 32, 32])
Loading