Skip to content

Commit

Permalink
change register_torch_overload typehint.
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00-pl committed Jan 3, 2025
1 parent c568960 commit ca3f1a6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions plai/dialect/aten_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def from_torch_overload_dim(args: list, attrs: dict, loc: Location = None):
return Sum(args[0], args[1], args[2], loc)

@classmethod
def register_torch_overload(cls, register: Callable[[str, Optional[Callable]], None]):
def register_torch_overload(cls, register: Callable[[str, Callable], None]):
register('torch::sum', cls.from_torch)
register('torch::sum.dim_IntList', cls.from_torch_overload_dim)

Expand Down Expand Up @@ -78,7 +78,7 @@ def from_torch_overload_dim(args: list, attrs: dict, loc: Location = None):
return Max(args[0], args[1], args[2], loc)

@classmethod
def register_torch_overload(cls, register: Callable[[str, Optional[Callable]], None]):
def register_torch_overload(cls, register: Callable[[str, Callable], None]):
register(f'{cls.get_namespace()}::max.dim', cls.from_torch_overload_dim)


Expand Down Expand Up @@ -109,7 +109,7 @@ def from_torch(args: list, attrs: dict, loc: Location = None):
return Transpose(args[0], loc)

@classmethod
def register_torch_overload(cls, register: Callable[[str, Optional[Callable]], None]):
def register_torch_overload(cls, register: Callable[[str, Callable], None]):
register(f'{cls.get_namespace()}::t', cls.from_torch)


Expand Down
12 changes: 6 additions & 6 deletions plai/dialect/torch_dialect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Callable, Optional
from typing import Callable

from plai.core import module
from plai.core.location import Location
Expand All @@ -16,15 +16,15 @@ def from_torch(args: list, attrs: dict, loc: Location = None):
pass

@classmethod
def register_torch_overload(cls, register: Callable[[str, Optional[Callable]], None]):
def register_torch_overload(cls, register: Callable[[str, Callable], None]):
register(cls.get_op_name('::'), cls.from_torch)

convertion_function_dict = {}

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)

def _register_torch_overload_inner(name, func):
def _register_torch_overload_inner(name: str, func: Callable) -> None:
assert name not in TorchNode.convertion_function_dict, f'Duplicate key: {name}'
TorchNode.convertion_function_dict[name] = func

Expand All @@ -40,7 +40,7 @@ def from_torch(args: list, attrs: dict, loc: Location = None):
return GetItem(args[0], args[1], loc)

@classmethod
def register_torch_overload(cls, register: Callable[[str, Optional[Callable]], None]):
def register_torch_overload(cls, register: Callable[[str, Callable], None]):
name = '_operator.getitem'
register(name, cls.from_torch)

Expand All @@ -57,7 +57,7 @@ def from_torch(args: list, attrs: dict, loc: Location = None):
return Linear(args[0], args[1], args[2], loc)

@classmethod
def register_torch_overload(cls, register: Callable[[str, Optional[Callable]], None]):
def register_torch_overload(cls, register: Callable[[str, Callable], None]):
name = f'{cls.get_namespace()}._C._nn.linear'
register(name, cls.from_torch)

Expand All @@ -71,6 +71,6 @@ def from_torch(args: list, attrs: dict, loc: Location = None):
return Relu(args[0], loc)

@classmethod
def register_torch_overload(cls, register: Callable[[str, Optional[Callable]], None]):
def register_torch_overload(cls, register: Callable[[str, Callable], None]):
name = f'{cls.get_namespace()}.relu'
register(name, cls.from_torch)

0 comments on commit ca3f1a6

Please sign in to comment.