Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Oct 14, 2023
1 parent 236255e commit 5c767f4
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 12 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ repos:
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --no-cache]
exclude: .cmake-format.py
- repo: https://github.com/asottile/yesqa
rev: v1.4.0
hooks:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ select = [
]
ignore = [
"UP015",
"Q000"
"Q000",
"F405"
]
target-version = "py38"

Expand Down
4 changes: 2 additions & 2 deletions src/paddlefx/legacy_module/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def call_function(self, fn, args, kwargs):
res = self.output.create_node('call_function', fn, args, kwargs)
self.stack.push(res)
elif is_custom_call:
raise NotImplementedError(f"custom_call is not supported")
raise NotImplementedError("custom_call is not supported")
else:
raise NotImplementedError(f"call function {fn} is not supported")

Expand Down Expand Up @@ -237,7 +237,7 @@ def STORE_SUBSCR(self, inst):
self.output.create_node('call_method', "__setitem__", [root, idx, value], {})

def POP_TOP(self, inst: Instruction):
value = self.stack.pop()
self.stack.pop()

def STORE_FAST(self, inst: Instruction):
self.f_locals[inst.argval] = self.stack.pop()
Expand Down
4 changes: 2 additions & 2 deletions src/paddlefx/pyeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def wrapper(self: PyEval, inst: Instruction):
state = self.get_state()
try:
return inner_fn(self, inst)
except (BreakGraphError, NotImplementedError) as e:
except (BreakGraphError, NotImplementedError):
# TODO: remove NotImplementedError
logger.debug(
f"break_graph_if_unsupported triggered compile", exc_info=True
"break_graph_if_unsupported triggered compile", exc_info=True
)

if not isinstance(self, PyEval):
Expand Down
2 changes: 1 addition & 1 deletion src/paddlefx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def hashable(obj) -> bool:
try:
hash(obj)
return True
except TypeError as e:
except TypeError:
return False


Expand Down
12 changes: 6 additions & 6 deletions src/paddlefx/variables/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __call__(self, tx: PyEvalBase, *args: VariableBase, **kwargs) -> VariableBas
return result
else: # basic layer
ot = type(args[0].var)
obj_cls = type(args[0])

target = ''
model = (
tx.f_locals['self'] if 'self' in tx.f_locals else globals()['self']
Expand All @@ -82,7 +82,7 @@ def __call__(self, tx: PyEvalBase, *args: VariableBase, **kwargs) -> VariableBas
elif fn.__module__.startswith("paddle"):
# TODO: support multiple ouputs and containers
ot = type(args[0].var)
obj_cls = type(args[0])

output = graph.call_function(fn, args, kwargs, ot)
return TensorVariable(None, node=output)
elif inspect.isbuiltin(fn):
Expand All @@ -95,7 +95,7 @@ def __call__(self, tx: PyEvalBase, *args: VariableBase, **kwargs) -> VariableBas
if isinstance(attr, types.MethodType):
# For method variables
ot = type(args[0].var)
obj_cls = type(args[0])

return CallableVariable(fn, tx=tx)
else:
# the attr could be callable function
Expand All @@ -110,17 +110,17 @@ def __call__(self, tx: PyEvalBase, *args: VariableBase, **kwargs) -> VariableBas
operator.iadd,
]:
ot = type(args[0].var)
obj_cls = type(args[0])

output = graph.call_function(fn, args, kwargs, ot)
return TensorVariable(None, node=output)
elif fn in [operator.gt, operator.lt, operator.ge, operator.le]:
ot = type(args[0].var)
obj_cls = type(args[0])

output = graph.call_function(fn, args, kwargs, ot)
return TensorVariable(None, node=output)
elif fn in [operator.is_, operator.is_not]:
ot = type(args[0].var)
obj_cls = type(args[0])

output = graph.call_function(fn, args, kwargs, ot)
return TensorVariable(None, node=output)
else:
Expand Down

0 comments on commit 5c767f4

Please sign in to comment.