Update lintrunner libraries (pytorch#3963)
Summary: Pull Request resolved: pytorch#3963

Reviewed By: kirklandsign

Differential Revision: D58479009

fbshipit-source-id: 31624ace3ccb6ca8c6de6497f8b239eea8d31ad4
mergennachin authored and facebook-github-bot committed Jun 13, 2024
1 parent 22b063d commit 2d35b30
Showing 13 changed files with 35 additions and 32 deletions.
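Nearly all of the code changes in this commit follow one pattern: exact type comparisons written as type(x) == SomeClass are rewritten as isinstance(x, SomeClass), presumably to keep the updated flake8/pycodestyle pins further down happy (pycodestyle reports exact type comparisons as E721). A minimal sketch of the behavioral difference; Base and Derived are illustrative names, not classes from this diff:

# Sketch: exact type comparison vs. isinstance(). Base/Derived are made-up names.
class Base:
    pass

class Derived(Base):
    pass

obj = Derived()
print(type(obj) == Base)      # False: only the exact class matches
print(isinstance(obj, Base))  # True: subclasses are accepted as well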
2 changes: 1 addition & 1 deletion backends/arm/tosa_utils.py
@@ -33,7 +33,7 @@ def dbg_node(node):
     logger.info(" node.meta = ")
     for k, v in node.meta.items():
         logger.info(f" '{k}' = {v}")
-        if type([]) == type(v):
+        if isinstance(v, list):
             for i in v:
                 logger.info(f" {i} ")

6 changes: 5 additions & 1 deletion backends/qualcomm/passes/convert_to_linear.py
@@ -192,7 +192,11 @@ def _convert(self, graph_module: torch.fx.GraphModule):
         for _, src_partitions in partitions.items():
             for src_partition in src_partitions:
                 op_cnt = Counter(
-                    [n.target for n in src_partition.nodes if type(n.target) == edge_op]
+                    [
+                        n.target
+                        for n in src_partition.nodes
+                        if isinstance(n.target, edge_op)
+                    ]
                 )
                 if self.linear in op_cnt:
                     continue
2 changes: 1 addition & 1 deletion backends/qualcomm/passes/fold_qdq.py
@@ -45,7 +45,7 @@ def _fold(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:

                 # collecting quant nodes to be removed
                 for i in range(1, len(n.args)):
-                    if type(n.args[i]) == torch.fx.node.Node:
+                    if isinstance(n.args[i], torch.fx.node.Node):
                         to_be_removed.append(n.args[i])
                         # could be a commonly shared attribute between q & dq
                         if n.args[i].target == exir_ops.edge.aten._to_copy.default:
4 changes: 3 additions & 1 deletion backends/qualcomm/passes/insert_io_qdq.py
@@ -47,7 +47,9 @@ def _ceate_args(self, target: torch.fx.node.Target, quant_attrs: Dict):
             if name == "out_dtype":
                 continue
             value = quant_attrs[name]
-            if type(arg_schema.type) == torch.tensor and type(value) in [int, float]:
+            if isinstance(arg_schema.type, torch.tensor) and (
+                isinstance(value, int) or isinstance(value, float)
+            ):
                 value = torch.tensor(value)
             ret.append(value)
         return ret
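The hunk above spells the numeric check as two isinstance() calls joined with or. For reference, isinstance() also accepts a tuple of types, so an equivalent spelling (a sketch, not what was committed) would be:

# Sketch: isinstance() with a tuple of types behaves like the chained `or` above.
for value in (3, 2.5, "not a number"):
    chained = isinstance(value, int) or isinstance(value, float)
    assert isinstance(value, (int, float)) == chained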
2 changes: 1 addition & 1 deletion backends/qualcomm/passes/utils.py
@@ -32,7 +32,7 @@ def get_quant_attrs(
         attr_n = quant_node.args[i]

         value = attr_n
-        if type(attr_n) == torch.fx.node.Node:
+        if isinstance(attr_n, torch.fx.node.Node):
             # could be a commonly shared attribute between q & dq
             if attr_n.target == exir_ops.edge.aten._to_copy.default:
                 value = get_parameter(attr_n.args[0], edge_program)
2 changes: 1 addition & 1 deletion backends/qualcomm/quantizer/quantizer.py
@@ -295,7 +295,7 @@ def _get_quant_config(self, op: str | OpOverload) -> Optional[QuantizationConfig
         1. is one of use_per_channel_weight_quant_ops
         2. int8 / int16 config
         """
-        if type(op) == str:
+        if isinstance(op, str):
             return

         if op in self.use_per_channel_weight_quant_ops:
2 changes: 1 addition & 1 deletion examples/qualcomm/scripts/dummy_llama2.py
@@ -27,7 +27,7 @@ def create_device_inputs(example_inputs, use_kv_cache):
     input_list = ""
     if use_kv_cache:
         for i, d in enumerate(inputs[0]):
-            if type(d) == list:
+            if isinstance(d, list):
                 d = torch.stack(d)
             d.numpy().tofile(f"{args.artifact}/input_0_0.raw")
             input_list = f"input_0_{i}.raw "
19 changes: 9 additions & 10 deletions exir/emit/_emitter.py
@@ -285,18 +285,17 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
         NOTE: When symbool and symfloat are supported bool and float lists will be stored boxed.
         """
-        elem_type = type(val_type)

-        if elem_type == torch.BoolType:
+        if isinstance(val_type, torch.BoolType):
             return EValue(BoolList(typing.cast(List[bool], val)))

-        if elem_type == torch.IntType:
+        if isinstance(val_type, torch.IntType):
             return self._emit_int_list(val)

-        if elem_type == torch.FloatType:
+        if isinstance(val_type, torch.FloatType):
             return EValue(DoubleList(typing.cast(List[float], val)))

-        if elem_type == torch.TensorType:
+        if isinstance(val_type, torch.TensorType):
             values = []
             for v in val:
                 assert isinstance(v, _AbstractValue)
@@ -308,10 +307,10 @@ def _emit_list(self, val: List[_Argument], val_type: _SchemaType) -> EValue:
                     values.append(v.id)
             return EValue(TensorList(values))

-        if elem_type == torch.OptionalType:
+        if isinstance(val_type, torch.OptionalType):
             # refine further
-            actual_type = typing.cast(torch.OptionalType, val_type).getElementType()
-            if type(actual_type) == torch.TensorType:
+            actual_type = val_type.getElementType()
+            if isinstance(actual_type, torch.TensorType):
                 vals = []
                 for v in val:
                     if v is None:
@@ -437,9 +436,9 @@ def _constant_to_evalue( # noqa: C901
                 val_type = torch.ListType(
                     self._get_list_tuple_jit_type(val) # pyre-ignore
                 )
-            if type(val_type) == torch.OptionalType:
+            if isinstance(val_type, torch.OptionalType):
                 val_type = val_type.getElementType()
-            assert type(val_type) == torch.ListType
+            assert isinstance(val_type, torch.ListType)
             return self._emit_list(
                 typing.cast(List[_Argument], val),
                 typing.cast(_SchemaType, val_type.getElementType()),
2 changes: 1 addition & 1 deletion exir/passes/remove_mixed_type_operators.py
@@ -61,7 +61,7 @@ def call_operator(self, op, args, kwargs, meta: NodeMetadata): # noqa: C901
             )[1]

         def try_coerce(value: PyTree, arg: torch.Argument) -> PyTree:
-            if type(arg.type) != torch.TensorType:
+            if not isinstance(arg.type, torch.TensorType):
                 return value

             if isinstance(value, ProxyValue):
2 changes: 1 addition & 1 deletion exir/passes/scalar_to_tensor_pass.py
@@ -24,7 +24,7 @@ def try_coerce(value, arg):
             return (
                 torch.tensor(value)
                 if isinstance(value, (float, int, bool))
-                and type(arg.type) == torch.TensorType
+                and isinstance(arg.type, torch.TensorType)
                 else value
             )

6 changes: 3 additions & 3 deletions exir/tracer.py
@@ -345,7 +345,7 @@ def __torch_dispatch__( # noqa: C901

         # Kind of a hacky way to test if an op is in-place or not
         if func.__name__[-1] == "_" and func.__name__[0] != "_":
-            if type(args[0]) == PythonTensor:
+            if isinstance(args[0], PythonTensor):
                 args[0].proxy = proxy_out

         if not torch.fx.traceback.has_preserved_node_meta():
@@ -361,13 +361,13 @@ def wrap_with_proxy(e: LeafValue, proxy: torch.fx.Proxy) -> LeafValue:
             if e is None:
                 e = torch.empty(())

-            if type(e) == torch.Tensor:
+            if isinstance(e, torch.Tensor):
                 return PythonTensor(e, proxy)

             # Inplace and out-variant ops may return one of their arguments, which is already
             # a PythonTensor. In this case, we need to update the PythonTensor's associated
             # proxy to the newly created proxy.
-            if type(e) == PythonTensor:
+            if isinstance(e, PythonTensor):
                 e.update_proxy(proxy)
                 return e

12 changes: 6 additions & 6 deletions requirements-lintrunner.txt
@@ -3,18 +3,18 @@ lintrunner==0.11.0
 lintrunner-adapters==0.11.0

 # Flake 8 and its dependencies
-flake8==6.0.0
+flake8==6.1.0
 flake8-breakpoint==1.1.0
-flake8-bugbear==23.6.5
-flake8-comprehensions==3.12.0
+flake8-bugbear==23.9.16
+flake8-comprehensions==3.14.0
 flake8-pyi==23.5.0
 mccabe==0.7.0
-pycodestyle==2.10.0
+pycodestyle==2.11.1
 torchfix==0.5.0

 # UFMT
-black==24.2.0
-ufmt==2.5.1
+black==24.4.2
+ufmt==2.6.0
 usort==1.0.5

 # Other linters
6 changes: 2 additions & 4 deletions sdk/bundled_program/test/test_config.py
@@ -31,10 +31,8 @@ def assertIOListEqual(
     ) -> None:
         self.assertEqual(len(tl1), len(tl2))
         for t1, t2 in zip(tl1, tl2):
-            if type(t1) == torch.Tensor:
-                assert type(t1) == type(t2)
-                # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
-                # `Union[bool, float, int, Tensor]`.
+            if isinstance(t1, torch.Tensor):
+                assert isinstance(t2, torch.Tensor)
                 self.assertTensorEqual(t1, t2)
             else:
                 self.assertTrue(t1 == t2)
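One practical effect of the isinstance() check in the hunk above: Tensor subclasses such as torch.nn.Parameter now take the tensor-comparison branch, whereas the old exact-type check sent them to the plain equality branch. A small sketch (requires torch; not part of the test file):

import torch

p = torch.nn.Parameter(torch.ones(2))
print(type(p) == torch.Tensor)      # False: Parameter is a subclass, not Tensor itself
print(isinstance(p, torch.Tensor))  # True: the subclass is accepted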
