diff --git a/plai/dialect/torch_dialect.py b/plai/dialect/torch_dialect.py index 052ee4d..cd1961f 100644 --- a/plai/dialect/torch_dialect.py +++ b/plai/dialect/torch_dialect.py @@ -17,8 +17,7 @@ def from_torch(args: list, attrs: dict, loc: Location = None): @classmethod def register_torch_overload(cls, register: Callable[[str, Optional[Callable]], None]): - name = cls.get_op_name('::') - register(name, cls.from_torch) + register(cls.get_op_name('::'), cls.from_torch) convertion_function_dict = {}