Skip to content

Commit

Permalink
Add 'Type hints', 'documentation', 'example code'
Browse files Browse the repository at this point in the history
  • Loading branch information
lihuizhao committed Feb 5, 2024
1 parent 406e46a commit bd3c34d
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 27 deletions.
11 changes: 9 additions & 2 deletions python/oneflow/nn/functional/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
"""
import warnings
import oneflow as flow
from oneflow.framework.tensor import Tensor


def relu6(input, inplace=False):
r"""relu6(input, inplace=False) -> Tensor
def relu6(input: Tensor, inplace=False) -> Tensor:
r"""relu6(input: Tensor, inplace=False) -> Tensor
Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`.
Expand All @@ -27,3 +28,9 @@ def relu6(input, inplace=False):
if inplace:
warnings.warn("relu6 do not support inplace now")
return flow._C.hardtanh(input, min_val=0.0, max_val=6.0)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
9 changes: 8 additions & 1 deletion python/oneflow/nn/functional/affine_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from typing import List

import oneflow as flow
from oneflow.framework.tensor import Tensor


def affine_grid(theta, size: List[int], align_corners: bool = False):
def affine_grid(theta: Tensor, size: List[int], align_corners: bool = False) -> Tensor:
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/1.10/generated/torch.nn.functional.affine_grid.html.
Expand Down Expand Up @@ -66,3 +67,9 @@ def affine_grid(theta, size: List[int], align_corners: bool = False):
"""
y = flow._C.affine_grid(theta, size=size, align_corners=align_corners)
return y


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
13 changes: 10 additions & 3 deletions python/oneflow/nn/functional/grid_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
"""

import oneflow as flow
from oneflow.framework.tensor import Tensor


def grid_sample(
input,
grid,
input: Tensor,
grid: Tensor,
mode: str = "bilinear",
padding_mode: str = "zeros",
align_corners: bool = False,
):
) -> Tensor:
r"""The interface is consistent with PyTorch.
The documentation is referenced from:
https://pytorch.org/docs/1.10/generated/torch.nn.functional.grid_sample.html.
Expand Down Expand Up @@ -137,3 +138,9 @@ def grid_sample(
align_corners=align_corners,
)
return y


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
25 changes: 16 additions & 9 deletions python/oneflow/nn/functional/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,17 @@
from typing import Optional, Tuple, Union

import oneflow as flow
from oneflow.framework.tensor import Tensor


def interpolate(
input,
size=None,
scale_factor=None,
mode="nearest",
align_corners=None,
recompute_scale_factor=None,
):
input: Tensor,
size: Optional[Union[int, Tuple[int, ...]]] = None,
scale_factor: Optional[Union[float, Tuple[float, ...]]] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
) -> Tensor:
r"""The interface is consistent with PyTorch.
The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/functional.html#interpolate.
Expand Down Expand Up @@ -279,12 +280,12 @@ def interpolate(


def upsample(
input,
input: Tensor,
size: Optional[Union[int, Tuple[int, ...]]] = None,
scale_factor: Optional[Union[float, Tuple[float, ...]]] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
):
) -> Tensor:
r"""
Upsamples a given multi-channel 1D (temporal), 2D (spatial) or 3D (volumetric) data.
Expand All @@ -298,3 +299,9 @@ def upsample(
mode=mode,
align_corners=align_corners,
)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
11 changes: 10 additions & 1 deletion python/oneflow/nn/functional/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional

import oneflow as flow
from oneflow.framework.tensor import Tensor


def linear(input, weight, bias=None):
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
r"""
Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.
Expand Down Expand Up @@ -46,3 +49,9 @@ def linear(input, weight, bias=None):
if bias is not None:
res += bias
return res


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
19 changes: 17 additions & 2 deletions python/oneflow/nn/functional/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Tuple, Union
import oneflow as flow
from oneflow.framework.tensor import Tensor

_shape_t = Union[int, Tuple[int], flow._oneflow_internal.Size]


def group_norm(
input: Tensor,
Expand All @@ -24,7 +27,7 @@ def group_norm(
bias: Tensor = None,
eps: float = 1e-05,
num_channels: int = None,
):
) -> Tensor:
r"""Apply Group Normalization for last certain number of dimensions.
See :class:`~oneflow.nn.GroupNorm` for details.
Expand Down Expand Up @@ -54,7 +57,13 @@ def group_norm(
return res


def layer_norm(input, normalized_shape: tuple, weight=None, bias=None, eps=1e-05):
def layer_norm(
input: Tensor,
normalized_shape: _shape_t,
weight: Tensor = None,
bias: Tensor = None,
eps: float = 1e-05,
) -> Tensor:
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape,)
normalized_shape = tuple(normalized_shape)
Expand Down Expand Up @@ -118,3 +127,9 @@ def layer_norm(input, normalized_shape: tuple, weight=None, bias=None, eps=1e-05
epsilon=eps,
)
return res


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
24 changes: 16 additions & 8 deletions python/oneflow/nn/functional/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Optional
import oneflow as flow
from oneflow.framework.tensor import Tensor


def embedding(
input,
weight,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
):
input: Tensor,
weight: Tensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
) -> Tensor:
r"""A simple lookup table that looks up embeddings in a fixed dictionary and size.
This module is often used to retrieve word embeddings using indices.
Expand Down Expand Up @@ -87,3 +89,9 @@ def embedding(
return flow._C.gather(weight, input, axis=0)
else:
return flow._C.embedding(weight, input, padding_idx, scale_grad_by_freq)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)
9 changes: 8 additions & 1 deletion python/oneflow/nn/functional/sparse_softmax_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
limitations under the License.
"""
import oneflow as flow
from oneflow.framework.tensor import Tensor


def sparse_softmax_cross_entropy(labels, logits):
def sparse_softmax_cross_entropy(labels: Tensor, logits: Tensor) -> Tensor:
r"""The interface is consistent with TensorFlow.
The documentation is referenced from:
https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits
Expand Down Expand Up @@ -68,3 +69,9 @@ def sparse_softmax_cross_entropy(labels, logits):
tensor([ 2.9751e-01, 1.1448e+00, -1.4305e-06], dtype=oneflow.float32)
"""
return flow._C.sparse_softmax_cross_entropy(logits, labels)


if __name__ == "__main__":
import doctest

doctest.testmod(raise_on_error=True)

0 comments on commit bd3c34d

Please sign in to comment.