Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove redundant copy method from type classes #429

Merged
merged 20 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions decompiler/frontend/binaryninja/handlers/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def lift_call(self, call: mediumlevelil.MediumLevelILCall, ssa: bool = False, **
Call(
dest := self._lifter.lift(call.dest, parent=call),
[self._lifter.lift(parameter, parent=call) for parameter in call.params],
vartype=dest.type.copy(),
vartype=dest.type,
writes_memory=call.output_dest_memory if ssa else None,
meta_data={"param_names": self._lift_call_parameter_names(call), "is_tailcall": isinstance(call, Tailcall)},
),
Expand All @@ -52,7 +52,7 @@ def lift_syscall(self, call: mediumlevelil.MediumLevelILSyscall, ssa: bool = Fal
Call(
dest := ImportedFunctionSymbol("Syscall", value=-1),
[self._lifter.lift(parameter, parent=call) for parameter in call.params],
vartype=dest.type.copy(),
vartype=dest.type,
writes_memory=call.output_dest_memory if ssa else None,
meta_data={"param_names": self._lift_syscall_parameter_names(call)},
),
Expand Down
2 changes: 1 addition & 1 deletion decompiler/frontend/binaryninja/handlers/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def lift_branch(self, branch: mediumlevelil.MediumLevelILIf, **kwargs) -> Branch
"""Lift a branch instruction by lifting its condition."""
condition = self._lifter.lift(branch.condition, parent=branch)
if not isinstance(condition, Condition):
condition = Condition(OperationType.not_equal, [condition, Constant(0, condition.type.copy())])
condition = Condition(OperationType.not_equal, [condition, Constant(0, condition.type)])
return Branch(condition)

def lift_branch_indirect(self, branch: mediumlevelil.MediumLevelILJumpTo, **kwargs) -> IndirectBranch:
Expand Down
2 changes: 1 addition & 1 deletion decompiler/frontend/binaryninja/handlers/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def lift_address_of_field(self, operation: mediumlevelil.MediumLevelILAddressOfF
OperationType.plus,
[
UnaryOperation(OperationType.address, [operand := self._lifter.lift(operation.src, parent=operation)]),
Constant(operation.offset, vartype=operand.type.copy()),
Constant(operation.offset, vartype=operand.type),
],
vartype=self._lifter.lift(operation.expr_type),
)
Expand Down
2 changes: 1 addition & 1 deletion decompiler/pipeline/preprocessing/coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _set_variables_type(self, variables: List[Variable]) -> None:
"""Harmonize the variable type of the given non-empty list of variables."""
group_type = variables[0].type
for variable in variables:
variable._type = group_type.copy()
variable._type = group_type

def _set_variables_aliased(self, variables: List) -> None:
"""Set all variables in the given list as aliased."""
Expand Down
2 changes: 1 addition & 1 deletion decompiler/pipeline/preprocessing/missing_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _insert_label_zero_for_aliased_if_missing(self, var_name: str, variable: Var
"""
first_copy = self.get_smallest_label_copy(var_name)
if first_copy.ssa_label > 0 and first_copy.is_aliased:
first_copy = variable.copy(vartype=first_copy.type.copy(), is_aliased=True, ssa_label=0)
first_copy = variable.copy(vartype=first_copy.type, is_aliased=True, ssa_label=0)
self._sorted_copies_of[var_name].insert(0, first_copy)

def get_smallest_label_copy(self, variable: Union[str, Variable]):
Expand Down
20 changes: 14 additions & 6 deletions decompiler/structures/pseudo/complextypes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import copy
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional

from decompiler.structures.pseudo.typing import Type

Expand All @@ -14,17 +13,14 @@ class ComplexTypeSpecifier(Enum):
CLASS = "class"


@dataclass(frozen=True, order=True)
@dataclass(frozen=True, order=True, slots=True)
class ComplexType(Type):
size = 0
name: str

def __str__(self):
return self.name

def copy(self, **kwargs) -> Type:
return copy.deepcopy(self)

def declaration(self) -> str:
raise NotImplementedError

Expand Down Expand Up @@ -83,6 +79,10 @@ def declaration(self) -> str:
members = ";\n\t".join(self.members[k].declaration() for k in sorted(self.members.keys())) + ";"
return f"{self.type_specifier.value} {self.name} {{\n\t{members}\n}}"

def __hash__(self) -> int:
# Because dict is not hashable, we need our own hash implementation
return hash(repr(self))


@dataclass(frozen=True, order=True)
class Struct(_BaseStruct):
Expand Down Expand Up @@ -121,6 +121,10 @@ def get_member_name_by_type(self, _type: Type) -> str:
logging.warning(f"Cannot get member name for union {self}")
return "unknown_field"

def __hash__(self) -> int:
# Because list is not hashable, we need our own hash implementation
return hash(repr(self))


@dataclass(frozen=True, order=True)
class Enum(ComplexType):
Expand All @@ -138,6 +142,10 @@ def declaration(self) -> str:
members = ",\n\t".join(f"{x.name} = {x.value}" for x in self.members.values())
return f"{self.type_specifier.value} {self.name} {{\n\t{members}\n}}"

def __hash__(self) -> int:
# Because dict is not hashable, we need our own hash implementation
return hash(repr(self))


@dataclass(frozen=True, order=True)
class ComplexTypeName(Type):
Expand Down
14 changes: 7 additions & 7 deletions decompiler/structures/pseudo/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def pointee(self) -> Optional[Constant]:

def copy(self) -> Constant:
"""Generate a Constant with the same value and type."""
return Constant(self.value, self._type.copy(), self._pointee.copy() if self._pointee else None, self.tags)
return Constant(self.value, self._type, self._pointee.copy() if self._pointee else None, self.tags)

def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T:
"""Invoke the appropriate visitor for this Expression."""
Expand Down Expand Up @@ -289,7 +289,7 @@ def __repr__(self):
raise ValueError(f"Unknown symbol type {type(self.value)}")

def copy(self) -> Symbol:
return Symbol(self.name, self.value, self._type.copy(), self.tags)
return Symbol(self.name, self.value, self._type, self.tags)


class FunctionSymbol(Symbol):
Expand All @@ -302,7 +302,7 @@ def __hash__(self):
return super().__hash__()

def copy(self) -> FunctionSymbol:
return FunctionSymbol(self.name, self.value, self._type.copy(), self.tags)
return FunctionSymbol(self.name, self.value, self._type, self.tags)


class ImportedFunctionSymbol(FunctionSymbol):
Expand All @@ -315,7 +315,7 @@ def __hash__(self):
return super().__hash__()

def copy(self) -> ImportedFunctionSymbol:
return ImportedFunctionSymbol(self._name, self.value, self._type.copy(), self.tags)
return ImportedFunctionSymbol(self._name, self.value, self._type, self.tags)


class IntrinsicSymbol(FunctionSymbol):
Expand Down Expand Up @@ -410,7 +410,7 @@ def copy(
"""Provide a copy of the current Variable."""
return self.__class__(
self._name[:] if name is None else name,
self._type.copy() if vartype is None else vartype,
self._type if vartype is None else vartype,
self.ssa_label if ssa_label is None else ssa_label,
self.is_aliased if is_aliased is None else is_aliased,
self.ssa_name if ssa_name is None else ssa_name,
Expand Down Expand Up @@ -467,7 +467,7 @@ def copy(

return self.__class__(
self._name[:] if name is None else name,
self._type.copy() if vartype is None else vartype,
self._type if vartype is None else vartype,
self.initial_value.copy() if initial_value is None else initial_value.copy(),
self.ssa_label if ssa_label is None else ssa_label,
self.is_aliased if is_aliased is None else is_aliased,
Expand Down Expand Up @@ -552,7 +552,7 @@ def substitute(self, replacee: Variable, replacement: Variable) -> None:

def copy(self) -> RegisterPair:
"""Return a copy of the current register pair."""
return RegisterPair(self._high.copy(), self._low.copy(), self._type.copy(), self.tags)
return RegisterPair(self._high.copy(), self._low.copy(), self._type, self.tags)

def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T:
"""Invoke the appropriate visitor for this Expression."""
Expand Down
8 changes: 4 additions & 4 deletions decompiler/structures/pseudo/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def copy(self) -> UnaryOperation:
return UnaryOperation(
self._operation,
[operand.copy() for operand in self._operands],
self._type.copy(),
self._type,
writes_memory=self._writes_memory,
contraction=self.contraction,
array_info=ArrayInfo(self.array_info.base, self.array_info.index, self.array_info.confidence) if self.array_info else None,
Expand Down Expand Up @@ -457,7 +457,7 @@ def copy(self) -> MemberAccess:
self.member_offset,
self.member_name,
[operand.copy() for operand in self._operands],
self._type.copy(),
self._type,
writes_memory=self.writes_memory,
)

Expand Down Expand Up @@ -499,7 +499,7 @@ def right(self) -> Expression:

def copy(self) -> BinaryOperation:
"""Generate a deep copy of the current binary operation."""
return self.__class__(self._operation, [operand.copy() for operand in self._operands], self._type.copy(), self.tags)
return self.__class__(self._operation, [operand.copy() for operand in self._operands], self._type, self.tags)

def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T:
"""Invoke the appropriate visitor for this Operation."""
Expand Down Expand Up @@ -583,7 +583,7 @@ def copy(self) -> Call:
return Call(
self._function,
[operand.copy() for operand in self._operands],
self._type.copy(),
self._type,
self._writes_memory,
self._meta_data.copy() if self._meta_data is not None else None,
self.tags,
Expand Down
47 changes: 15 additions & 32 deletions decompiler/structures/pseudo/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import Tuple
from typing import Tuple, TypeVar

_T = TypeVar("_T", bound="Type")


@dataclass(frozen=True, order=True)
Expand All @@ -18,34 +20,22 @@ def is_boolean(self) -> bool:
"""Check whether the given value is a boolean."""
return str(self) == "bool"

def copy(self, **kwargs) -> Type:
"""Generate a copy of the current type."""
return replace(self, **kwargs)

def resize(self, new_size: int) -> Type:
def resize(self: _T, new_size: int) -> _T:
"""Create an object of the type with a different size."""
return self.copy(size=new_size)
return replace(self, size=new_size)

@abstractmethod
def __str__(self) -> str:
"""Every type should provide a c-like string representation."""

def __add__(self, other) -> Type:
blattm marked this conversation as resolved.
Show resolved Hide resolved
"""Add two types to generate one type of bigger size."""
return self.copy(size=self.size + other.size)

def __hash__(self) -> int:
"""Return a hash value for the given type."""
return hash(repr(self))


@dataclass(frozen=True, order=True)
class UnknownType(Type):
"""Represent an unknown type, mostly utilized for testing purposes."""

def __init__(self, size: int = 0):
"""Create a type with size 0."""
super().__init__(size)
object.__setattr__(self, "size", size)
blattm marked this conversation as resolved.
Show resolved Hide resolved

def __str__(self):
"""Return the representation of the unknown type."""
Expand Down Expand Up @@ -137,10 +127,6 @@ class Float(Type):

SIZE_TYPES = {8: "quarter", 16: "half", 32: "float", 64: "double", 80: "long double", 128: "quadruple", 256: "octuple"}

def __init__(self, size: int):
"""Create a new float type with the given size."""
super().__init__(size)

@classmethod
def float(cls) -> Float:
"""Return a float type (IEEE 754)."""
Expand All @@ -167,16 +153,17 @@ def __init__(self, basetype: Type, size: int = 32):
object.__setattr__(self, "type", basetype)
object.__setattr__(self, "size", size)

def resize(self, new_size: int) -> Pointer:
# Needs custom implementation, because construction parameter 'basetype' differs in name to field 'type'.
# This causes dataclasses.replace to not work
return Pointer(self.type, new_size)

def __str__(self) -> str:
"""Return a nice string representation."""
if isinstance(self.type, Pointer):
return f"{self.type}*"
return f"{self.type} *"

def copy(self, **kwargs) -> Pointer:
"""Generate a copy of the current pointer."""
return Pointer(self.type.copy(), self.size)


@dataclass(frozen=True, order=True)
class ArrayType(Type):
Expand All @@ -191,14 +178,14 @@ def __init__(self, basetype: Type, elements: int):
object.__setattr__(self, "size", basetype.size * elements)
object.__setattr__(self, "elements", elements)

def resize(self, new_size: int) -> ArrayType:
# Overridden because of backwards compatibility... ArrayType was not able to be resized.
return self

def __str__(self) -> str:
"""Return a nice string representation."""
return f"{self.type} [{self.elements}]"

def copy(self, **kwargs) -> Pointer:
"""Generate a copy of the current pointer."""
return ArrayType(self.type.copy(), self.elements)


@dataclass(frozen=True, order=True)
class CustomType(Type):
Expand Down Expand Up @@ -235,10 +222,6 @@ def __str__(self) -> str:
"""Return the given string representation."""
return self.text

def copy(self, **kwargs) -> CustomType:
"""Generate a copy of the current custom type."""
return CustomType(self.text, self.size)


@dataclass(frozen=True, order=True)
class FunctionTypeDef(Type):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def branch(operation, *args):


def contract(_type: Type, var):
t = _type.copy()
t = _type
_field = UnaryOperation(OperationType.cast, [var], vartype=t, contraction=True)
return _field

Expand Down
18 changes: 9 additions & 9 deletions tests/pipeline/preprocessing/test-coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
[
(
{
"a": {0: [Variable("a", i32.copy()), Variable("a", u32.copy())], 1: [Variable("a", u32.copy())]},
"a": {0: [Variable("a", i32), Variable("a", u32)], 1: [Variable("a", u32)]},
"b": {42: [Variable("b")]},
"c": {0: [Variable("c", u32.copy())], 2: [Variable("c", i64.copy())]},
"c": {0: [Variable("c", u32)], 2: [Variable("c", i64)]},
},
{
"a": {0: [Variable("a", i32.copy()), Variable("a", i32.copy())], 1: [Variable("a", u32.copy())]},
"a": {0: [Variable("a", i32), Variable("a", i32)], 1: [Variable("a", u32)]},
"b": {42: [Variable("b")]},
"c": {0: [Variable("c", u32.copy())], 2: [Variable("c", i64.copy())]},
"c": {0: [Variable("c", u32)], 2: [Variable("c", i64)]},
},
)
],
Expand Down Expand Up @@ -72,16 +72,16 @@ def test_acceptance():
BasicBlock(
0,
instructions=[
Assignment(x01 := Variable("x", i32.copy(), ssa_label=0), Constant(0x1337, i32.copy())),
Assignment(x01 := Variable("x", i32, ssa_label=0), Constant(0x1337, i32)),
Assignment(
x10 := Variable("x", i32.copy(), is_aliased=True, ssa_label=1),
Call(FunctionSymbol("foo", 0x42), [x02 := Variable("x", u32.copy(), ssa_label=0)]),
x10 := Variable("x", i32, is_aliased=True, ssa_label=1),
Call(FunctionSymbol("foo", 0x42), [x02 := Variable("x", u32, ssa_label=0)]),
),
Return([x12 := Variable("x", i32.copy(), is_aliased=False, ssa_label=1)]),
Return([x12 := Variable("x", i32, is_aliased=False, ssa_label=1)]),
],
)
]
)
Coherence().run(DecompilerTask(name="test", function_identifier="", cfg=cfg))
assert {variable.type for variable in [x01, x02]} == {i32.copy()}
assert {variable.type for variable in [x01, x02]} == {i32}
assert {variable.is_aliased for variable in [x01, x02, x10, x12]} == {True}
Loading
Loading