Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update Pytorch/api_doc.rst #58

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions sources/pytorch/api_doc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,155 @@ PyTorch-NPU 除了提供了 PyTorch 官方算子实现之外,也提供了大
>>> y = torch_npu.npu_anchor_response_flags(x, [60, 60], [2, 2], 9)
>>> y.shape
torch.Size([32400])

.. py:function:: npu_apply_adam(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, use_locking, use_nesterov, out = (var, m, v))
:module: torch_npu

adam结果计数。

:param Scalar beta1_power: beta1的幂
:param Scalar beta2_power: beta2的幂
:param Scalar lr: 学习率
:param Scalar beta1: 一阶矩估计值的指数衰减率
:param Scalar beta2: 二阶矩估计值的指数衰减率
:param Scalar epsilon: 添加到分母中以提高数值稳定性的项数
:param Tensor grad: 梯度
:param Bool use_locking: 设置为True时使用lock进行更新操作
:param Bool use_nesterov: 设置为True时采用nesterov更新
:param Tensor var: 待优化变量。
:param Tensor m: 变量平均值。
:param Tensor v: 变量方差。

.. py:function:: npu_batch_nms(self, scores, score_threshold, iou_threshold, max_size_per_class, max_total_size, change_coordinate_frame=False, transpose_box=False) -> (Tensor, Tensor, Tensor, Tensor)

:module: torch_npu

根据batch分类计算输入框评分,通过评分排序,删除评分高于阈值(iou_threshold)的框,支持多批多类处理。通过NonMaxSuppression(nms)操作可有效删除冗余的输入框,提高检测精度。NonMaxSuppression:抑制不是极大值的元素,搜索局部的极大值,常用于计算机视觉任务中的检测类模型。

:param Tensor self: 必填值,输入框的tensor,包含batch大小,数据类型Float16,输入示例:[batch_size, num_anchors, q, 4],其中q=1或q=num_classes
:param Tensor scores: 必填值,输入tensor,数据类型Float16,输入示例:[batch_size, num_anchors, num_classes]
:param Float32 score_threshold: 必填值,指定评分过滤器的iou_threshold,用于筛选框,去除得分较低的框,数据类型Float32
:param Float32 iou_threshold: 必填值,指定nms的iou_threshold,用于设定阈值,去除高于阈值的的框,数据类型Float32
:param Int max_size_per_class: 必填值,指定每个类别的最大可选的框数,数据类型Int
:param Int max_total_size: 必填值,指定每个batch最大可选的框数,数据类型Int
:param Bool change_coordinate_frame: 可选值, 是否正则化输出框坐标矩阵,数据类型Bool(默认False)
:param Bool transpose_box: 可选值,确定是否在此op之前插入转置,数据类型Bool。True表示boxes使用4,N排布。 False表示boxes使用过N,4排布

输出说明:
:param Tensor nmsed_boxes: shape为(batch, max_total_size, 4)的3D张量,指定每批次输出的nms框,数据类型Float16
:param Tensor nmsed_scores: shape为(batch, max_total_size)的2D张量,指定每批次输出的nms分数,数据类型Float16
:param Tensor nmsed_classes: shape为(batch, max_total_size)的2D张量,指定每批次输出的nms类,数据类型Float16
:param Tensor nmsed_num: shape为(batch)的1D张量,指定nmsed_boxes的有效数量,数据类型Int32

:rtype: Tensor

示例:

.. code-block:: python
:linenos:

>>> boxes = torch.randn(8, 2, 4, 4, dtype = torch.float32).to("npu")
>>> scores = torch.randn(3, 2, 4, dtype = torch.float32).to("npu")
>>> nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = torch_npu.npu_batch_nms(boxes, scores, 0.3, 0.5, 3, 4)
>>> nmsed_boxes
>>> nmsed_scores
>>> nmsed_classes
>>> nmsed_num

.. py:function:: npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, step_size=None, adam_mode=0, *, out=(var,m,v))

:module: torch_npu

adam结果计数

:param Tensor var: float16或float32类型张量
:param Tensor m: 数据类型和shape与exp_avg相同
:param Tensor v: 数据类型和shape与exp_avg相同
:param Scalar lr: 数据类型与exp_avg相同
:param Scalar beta1: 数据类型与exp_avg相同
:param Scalar beta2: 数据类型与exp_avg相同
:param Scalar epsilon: 数据类型与exp_avg相同
:param Tensor grad: 数据类型和shape与exp_avg相同
:param Scalar max_grad_norm: 数据类型与exp_avg相同
:param Scalar global_grad_norm: 数据类型与exp_avg相同
:param Scalar weight_decay: 数据类型与exp_avg相同
:param Tensor step_size: 默认值为None - shape为(1, ),数据类型与exp_avg一致
:param Int adam_mode: 选择adam模式。0表示“adam”, 1表示“mbert_adam”, 默认值为0

关键字参数:
out (Tensor,可选) - 输出张量。

示例:

.. code-block:: python
:linenos:

>>> var_in = torch.rand(321538).uniform_(-32., 21.).npu()
>>> m_in = torch.zeros(321538).npu()
>>> v_in = torch.zeros(321538).npu()
>>> grad = torch.rand(321538).uniform_(-0.05, 0.03).npu()
>>> max_grad_norm = -1.
>>> beta1 = 0.9
>>> beta2 = 0.99
>>> weight_decay = 0.
>>> lr = 0.
>>> epsilon = 1e-06
>>> global_grad_norm = 0.
>>> var_out, m_out, v_out = torch_npu.npu_bert_apply_adam(lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay, out=(var_in, m_in, v_in))
>>> var_out
tensor([ 14.7733, -30.1218, -1.3647, ..., -16.6840, 7.1518, 8.4872], device='npu:0')

.. py:function:: npu_bmmV2(self, mat2, output_sizes) -> Tensor
:module: torch_npu

将矩阵“a”乘以矩阵“b”,生成“a*b”。支持FakeTensor模式

:param Tensor self: 2D或更高维度矩阵张量。数据类型:float16、float32、int32。格式:[ND, NHWC, FRACTAL_NZ]
:param Tensor mat2: 2D或更高维度矩阵张量。数据类型:float16、float32、int32。格式:[ND, NHWC, FRACTAL_NZ]
:param ListInt[] output_sizes: 输出的shape,用于matmul的反向传播

:rtype: Tensor

示例:

.. code-block:: python
:linenos:

>>> mat1 = torch.randn(10, 3, 4).npu()
>>> mat2 = torch.randn(10, 4, 5).npu()
>>> res = torch_npu.npu_bmmV2(mat1, mat2, [])
>>> res.shape
torch.Size([10, 3, 5])

.. py:function:: npu_bounding_box_decode(rois, deltas, means0, means1, means2, means3, stds0, stds1, stds2, stds3, max_shape, wh_ratio_clip) -> Tensor
:module: torch_npu

根据rois和deltas生成标注框。自定义FasterRcnn算子

:param Tensor rois: 区域候选网络(RPN)生成的region of interests(ROI)。shape为(N,4)数据类型为float32或float16的2D张量。“N”表示ROI的数量, “4”表示“x0”、“x1”、“y0”和“y1”
:param Tensor deltas: RPN生成的ROI和真值框之间的绝对变化。shape为(N,4)数据类型为float32或float16的2D张量。“N”表示错误数,“4”表示“dx”、“dy”、“dw”和“dh”
:param Float means0: index
:param Float means1: index
:param Float means2: index
:param Float means33: index, 默认值为[0,0,0,0], "deltas" = "deltas" x "stds" + "means"
:param Float stds0: index
:param Float stds1: index
:param Float stds2: index
:param Float stds3: index, 默认值:[1.0,1.0,1.0,1.0], deltas" = "deltas" x "stds" + "means"
:param ListInt[2] max_shape: shape[h, w], 指定传输到网络的图像大小。用于确保转换后的bbox shape不超过“max_shape”
:param Float wh_ratio_clip: 当前水平的步长
:param Int num_base_anchors: “dw”和“dh”的值在(-wh_ratio_clip, wh_ratio_clip)范围内

:rtype: Tensor

示例:

.. code-block:: python
:linenos:

>>> rois = torch.tensor([[1., 2., 3., 4.], [3.,4., 5., 6.]], dtype = torch.float32).to("npu")
>>> deltas = torch.tensor([[5., 6., 7., 8.], [7.,8., 9., 6.]], dtype = torch.float32).to("npu")
>>> output = torch_npu.npu_bounding_box_decode(rois, deltas, 0, 0, 0, 0, 1, 1, 1, 1, (10, 10), 0.1)
>>> output
tensor([[2.5000, 6.5000, 9.0000, 9.0000],
[9.0000, 9.0000, 9.0000, 9.0000]], device='npu:0')