Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove redunant copy method from type classes #429

Merged
merged 20 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions decompiler/frontend/binaryninja/handlers/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def lift_call(self, call: mediumlevelil.MediumLevelILCall, ssa: bool = False, **
Call(
dest := self._lifter.lift(call.dest, parent=call),
[self._lifter.lift(parameter, parent=call) for parameter in call.params],
vartype=dest.type.copy(),
vartype=dest.type,
writes_memory=call.output_dest_memory if ssa else None,
meta_data={"param_names": self._lift_call_parameter_names(call), "is_tailcall": isinstance(call, Tailcall)},
),
Expand All @@ -52,7 +52,7 @@ def lift_syscall(self, call: mediumlevelil.MediumLevelILSyscall, ssa: bool = Fal
Call(
dest := ImportedFunctionSymbol("Syscall", value=-1),
[self._lifter.lift(parameter, parent=call) for parameter in call.params],
vartype=dest.type.copy(),
vartype=dest.type,
writes_memory=call.output_dest_memory if ssa else None,
meta_data={"param_names": self._lift_syscall_parameter_names(call)},
),
Expand Down
2 changes: 1 addition & 1 deletion decompiler/frontend/binaryninja/handlers/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def lift_branch(self, branch: mediumlevelil.MediumLevelILIf, **kwargs) -> Branch
"""Lift a branch instruction by lifting its condition."""
condition = self._lifter.lift(branch.condition, parent=branch)
if not isinstance(condition, Condition):
condition = Condition(OperationType.not_equal, [condition, Constant(0, condition.type.copy())])
condition = Condition(OperationType.not_equal, [condition, Constant(0, condition.type)])
return Branch(condition)

def lift_branch_indirect(self, branch: mediumlevelil.MediumLevelILJumpTo, **kwargs) -> IndirectBranch:
Expand Down
31 changes: 18 additions & 13 deletions decompiler/frontend/binaryninja/handlers/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from abc import abstractmethod
from typing import Optional, Union
from typing import Union

from binaryninja import BinaryView, StructureVariant
from binaryninja.types import (
Expand All @@ -22,9 +22,10 @@
)
from decompiler.frontend.lifter import Handler
from decompiler.structures.pseudo import ArrayType as PseudoArrayType
from decompiler.structures.pseudo import CustomType, Float, FunctionTypeDef, Integer, Pointer, UnknownType, Variable
from decompiler.structures.pseudo import CustomType, Float, FunctionTypeDef, Integer, Pointer, UnknownType
from decompiler.structures.pseudo.complextypes import Class, ComplexTypeMember, ComplexTypeName, Enum, Struct
from decompiler.structures.pseudo.complextypes import Union as Union_
from decompiler.util.frozen_dict import FrozenDict


class TypeHandler(Handler):
Expand Down Expand Up @@ -95,21 +96,25 @@ def lift_struct(self, struct: StructureType, name: str = None, **kwargs) -> Unio
return cached_type

"""Lift struct or union type."""
if struct.type == StructureVariant.StructStructureType:
keyword, type, members = "struct", Struct, {}
elif struct.type == StructureVariant.UnionStructureType:
keyword, type, members = "union", Union_, []
elif struct.type == StructureVariant.ClassStructureType:
keyword, type, members = "class", Class, {}
else:
raise RuntimeError(f"Unknown struct type {struct.type.name}")
match struct.type:
case StructureVariant.StructStructureType:
keyword, type, to_members = "struct", Struct, lambda members: FrozenDict({m.offset: m for m in members})
case StructureVariant.UnionStructureType:
keyword, type, to_members = "union", Union_, lambda members: members
case StructureVariant.ClassStructureType:
keyword, type, to_members = "class", Class, lambda members: members
case _:
raise RuntimeError(f"Unknown struct type {struct.type.name}")

type_name = self._get_data_type_name(struct, keyword=keyword, provided_name=name)
lifted_struct = type(struct.width * self.BYTE_SIZE, type_name, members)

self._lifter.complex_types.add(lifted_struct, type_id)
members: list[ComplexTypeMember] = []
for member in struct.members:
lifted_struct.add_member(self.lift_struct_member(member, type_name))
members.append(self.lift_struct_member(member, type_name))
rihi marked this conversation as resolved.
Show resolved Hide resolved

lifted_struct = type(struct.width * self.BYTE_SIZE, type_name, to_members(members))

self._lifter.complex_types.add(lifted_struct, type_id)
return lifted_struct

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion decompiler/frontend/binaryninja/handlers/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def lift_address_of_field(self, operation: mediumlevelil.MediumLevelILAddressOfF
OperationType.plus,
[
UnaryOperation(OperationType.address, [operand := self._lifter.lift(operation.src, parent=operation)]),
Constant(operation.offset, vartype=operand.type.copy()),
Constant(operation.offset, vartype=operand.type),
],
vartype=self._lifter.lift(operation.expr_type),
)
Expand Down
2 changes: 1 addition & 1 deletion decompiler/pipeline/preprocessing/coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _set_variables_type(self, variables: List[Variable]) -> None:
"""Harmonize the variable type of the given non-empty list of variables."""
group_type = variables[0].type
for variable in variables:
variable._type = group_type.copy()
variable._type = group_type

def _set_variables_aliased(self, variables: List) -> None:
"""Set all variables in the given list as aliased."""
Expand Down
2 changes: 1 addition & 1 deletion decompiler/pipeline/preprocessing/missing_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _insert_label_zero_for_aliased_if_missing(self, var_name: str, variable: Var
"""
first_copy = self.get_smallest_label_copy(var_name)
if first_copy.ssa_label > 0 and first_copy.is_aliased:
first_copy = variable.copy(vartype=first_copy.type.copy(), is_aliased=True, ssa_label=0)
first_copy = variable.copy(vartype=first_copy.type, is_aliased=True, ssa_label=0)
self._sorted_copies_of[var_name].insert(0, first_copy)

def get_smallest_label_copy(self, variable: Union[str, Variable]):
Expand Down
20 changes: 4 additions & 16 deletions decompiler/structures/pseudo/complextypes.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import copy
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional

from decompiler.structures.pseudo.typing import Type
from decompiler.util.frozen_dict import FrozenDict


class ComplexTypeSpecifier(Enum):
Expand All @@ -14,17 +14,14 @@ class ComplexTypeSpecifier(Enum):
CLASS = "class"


@dataclass(frozen=True, order=True)
@dataclass(frozen=True, order=True, slots=True)
class ComplexType(Type):
size = 0
name: str

def __str__(self):
return self.name

def copy(self, **kwargs) -> Type:
return copy.deepcopy(self)

def declaration(self) -> str:
raise NotImplementedError

Expand Down Expand Up @@ -61,12 +58,9 @@ def declaration(self) -> str:
class _BaseStruct(ComplexType):
"""Class representing a struct type."""

members: Dict[int, ComplexTypeMember] = field(compare=False)
members: FrozenDict[int, ComplexTypeMember] = field(compare=False)
type_specifier: ComplexTypeSpecifier

def add_member(self, member: ComplexTypeMember):
self.members[member.offset] = member

def get_member_by_offset(self, offset: int) -> Optional[ComplexTypeMember]:
return self.members.get(offset)

Expand Down Expand Up @@ -99,9 +93,6 @@ class Union(ComplexType):
members: List[ComplexTypeMember] = field(compare=False)
type_specifier = ComplexTypeSpecifier.UNION

def add_member(self, member: ComplexTypeMember):
self.members.append(member)

def declaration(self) -> str:
members = ";\n\t".join(x.declaration() for x in self.members) + ";"
return f"{self.type_specifier.value} {self.name} {{\n\t{members}\n}}"
Expand All @@ -127,9 +118,6 @@ class Enum(ComplexType):
members: Dict[int, ComplexTypeMember] = field(compare=False)
type_specifier = ComplexTypeSpecifier.ENUM

def add_member(self, member: ComplexTypeMember):
self.members[member.value] = member

def get_name_by_value(self, value: int) -> Optional[str]:
member = self.members.get(value)
return member.name if member is not None else None
Expand Down
16 changes: 8 additions & 8 deletions decompiler/structures/pseudo/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def pointee(self) -> Optional[Constant]:

def copy(self) -> Constant:
"""Generate a Constant with the same value and type."""
return Constant(self.value, self._type.copy(), self._pointee.copy() if self._pointee else None, self.tags)
return Constant(self.value, self._type, self._pointee.copy() if self._pointee else None, self.tags)

def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T:
"""Invoke the appropriate visitor for this Expression."""
Expand Down Expand Up @@ -289,7 +289,7 @@ def __repr__(self):
raise ValueError(f"Unknown symbol type {type(self.value)}")

def copy(self) -> Symbol:
return Symbol(self.name, self.value, self._type.copy(), self.tags)
return Symbol(self.name, self.value, self._type, self.tags)


class FunctionSymbol(Symbol):
Expand All @@ -302,7 +302,7 @@ def __hash__(self):
return super().__hash__()

def copy(self) -> FunctionSymbol:
return FunctionSymbol(self.name, self.value, self._type.copy(), self.tags)
return FunctionSymbol(self.name, self.value, self._type, self.tags)


class ImportedFunctionSymbol(FunctionSymbol):
Expand All @@ -315,7 +315,7 @@ def __hash__(self):
return super().__hash__()

def copy(self) -> ImportedFunctionSymbol:
return ImportedFunctionSymbol(self._name, self.value, self._type.copy(), self.tags)
return ImportedFunctionSymbol(self._name, self.value, self._type, self.tags)


class IntrinsicSymbol(FunctionSymbol):
Expand Down Expand Up @@ -410,7 +410,7 @@ def copy(
"""Provide a copy of the current Variable."""
return self.__class__(
self._name[:] if name is None else name,
self._type.copy() if vartype is None else vartype,
self._type if vartype is None else vartype,
self.ssa_label if ssa_label is None else ssa_label,
self.is_aliased if is_aliased is None else is_aliased,
self.ssa_name if ssa_name is None else ssa_name,
Expand Down Expand Up @@ -467,7 +467,7 @@ def copy(

return self.__class__(
self._name[:] if name is None else name,
self._type.copy() if vartype is None else vartype,
self._type if vartype is None else vartype,
self.initial_value.copy() if initial_value is None else initial_value.copy(),
self.ssa_label if ssa_label is None else ssa_label,
self.is_aliased if is_aliased is None else is_aliased,
Expand Down Expand Up @@ -552,7 +552,7 @@ def substitute(self, replacee: Variable, replacement: Variable) -> None:

def copy(self) -> RegisterPair:
"""Return a copy of the current register pair."""
return RegisterPair(self._high.copy(), self._low.copy(), self._type.copy(), self.tags)
return RegisterPair(self._high.copy(), self._low.copy(), self._type, self.tags)

def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T:
"""Invoke the appropriate visitor for this Expression."""
Expand Down Expand Up @@ -580,7 +580,7 @@ def __str__(self) -> str:

def copy(self) -> ConstantComposition:
"""Generate a copy of the UnknownExpression with the same message."""
return ConstantComposition([x.copy() for x in self.value], self._type.copy())
return ConstantComposition([x.copy() for x in self.value], self._type)

def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T:
"""Invoke the appropriate visitor for this Expression."""
Expand Down
8 changes: 4 additions & 4 deletions decompiler/structures/pseudo/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def copy(self) -> UnaryOperation:
return UnaryOperation(
self._operation,
[operand.copy() for operand in self._operands],
self._type.copy(),
self._type,
writes_memory=self._writes_memory,
contraction=self.contraction,
array_info=ArrayInfo(self.array_info.base, self.array_info.index, self.array_info.confidence) if self.array_info else None,
Expand Down Expand Up @@ -457,7 +457,7 @@ def copy(self) -> MemberAccess:
self.member_offset,
self.member_name,
[operand.copy() for operand in self._operands],
self._type.copy(),
self._type,
writes_memory=self.writes_memory,
)

Expand Down Expand Up @@ -499,7 +499,7 @@ def right(self) -> Expression:

def copy(self) -> BinaryOperation:
"""Generate a deep copy of the current binary operation."""
return self.__class__(self._operation, [operand.copy() for operand in self._operands], self._type.copy(), self.tags)
return self.__class__(self._operation, [operand.copy() for operand in self._operands], self._type, self.tags)

def accept(self, visitor: DataflowObjectVisitorInterface[T]) -> T:
"""Invoke the appropriate visitor for this Operation."""
Expand Down Expand Up @@ -583,7 +583,7 @@ def copy(self) -> Call:
return Call(
self._function,
[operand.copy() for operand in self._operands],
self._type.copy(),
self._type,
self._writes_memory,
self._meta_data.copy() if self._meta_data is not None else None,
self.tags,
Expand Down
47 changes: 15 additions & 32 deletions decompiler/structures/pseudo/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import Tuple
from typing import Tuple, TypeVar

_T = TypeVar("_T", bound="Type")


@dataclass(frozen=True, order=True)
Expand All @@ -18,34 +20,22 @@ def is_boolean(self) -> bool:
"""Check whether the given value is a boolean."""
return str(self) == "bool"

def copy(self, **kwargs) -> Type:
"""Generate a copy of the current type."""
return replace(self, **kwargs)

def resize(self, new_size: int) -> Type:
def resize(self: _T, new_size: int) -> _T:
"""Create an object of the type with a different size."""
return self.copy(size=new_size)
return replace(self, size=new_size)

@abstractmethod
def __str__(self) -> str:
"""Every type should provide a c-like string representation."""

def __add__(self, other) -> Type:
blattm marked this conversation as resolved.
Show resolved Hide resolved
"""Add two types to generate one type of bigger size."""
return self.copy(size=self.size + other.size)

def __hash__(self) -> int:
"""Return a hash value for the given type."""
return hash(repr(self))


@dataclass(frozen=True, order=True)
class UnknownType(Type):
"""Represent an unknown type, mostly utilized for testing purposes."""

def __init__(self, size: int = 0):
"""Create a type with size 0."""
super().__init__(size)
object.__setattr__(self, "size", size)
blattm marked this conversation as resolved.
Show resolved Hide resolved

def __str__(self):
"""Return the representation of the unknown type."""
Expand Down Expand Up @@ -137,10 +127,6 @@ class Float(Type):

SIZE_TYPES = {8: "quarter", 16: "half", 32: "float", 64: "double", 80: "long double", 128: "quadruple", 256: "octuple"}

def __init__(self, size: int):
"""Create a new float type with the given size."""
super().__init__(size)

@classmethod
def float(cls) -> Float:
"""Return a float type (IEEE 754)."""
Expand All @@ -167,16 +153,17 @@ def __init__(self, basetype: Type, size: int = 32):
object.__setattr__(self, "type", basetype)
object.__setattr__(self, "size", size)

def resize(self, new_size: int) -> Pointer:
# Needs custom implementation, because construction parameter 'basetype' differs in name to field 'type'.
# This causes dataclasses.replace to not work
return Pointer(self.type, new_size)

def __str__(self) -> str:
"""Return a nice string representation."""
if isinstance(self.type, Pointer):
return f"{self.type}*"
return f"{self.type} *"

def copy(self, **kwargs) -> Pointer:
"""Generate a copy of the current pointer."""
return Pointer(self.type.copy(), self.size)


@dataclass(frozen=True, order=True)
class ArrayType(Type):
Expand All @@ -191,14 +178,14 @@ def __init__(self, basetype: Type, elements: int):
object.__setattr__(self, "size", basetype.size * elements)
object.__setattr__(self, "elements", elements)

def resize(self, new_size: int) -> ArrayType:
# Overridden because of backwards compatibility... ArrayType was not able to be resized.
return self

def __str__(self) -> str:
"""Return a nice string representation."""
return f"{self.type} [{self.elements}]"

def copy(self, **kwargs) -> Pointer:
"""Generate a copy of the current pointer."""
return ArrayType(self.type.copy(), self.elements)


@dataclass(frozen=True, order=True)
class CustomType(Type):
Expand Down Expand Up @@ -235,10 +222,6 @@ def __str__(self) -> str:
"""Return the given string representation."""
return self.text

def copy(self, **kwargs) -> CustomType:
"""Generate a copy of the current custom type."""
return CustomType(self.text, self.size)


@dataclass(frozen=True, order=True)
class FunctionTypeDef(Type):
Expand Down
Loading
Loading