diff --git a/decompiler/frontend/binaryninja/handlers/assignments.py b/decompiler/frontend/binaryninja/handlers/assignments.py index b34efc670..f68a77f53 100644 --- a/decompiler/frontend/binaryninja/handlers/assignments.py +++ b/decompiler/frontend/binaryninja/handlers/assignments.py @@ -18,7 +18,7 @@ RegisterPair, UnaryOperation, ) -from decompiler.structures.pseudo.complextypes import Struct, Union +from decompiler.structures.pseudo.complextypes import Class, Struct, Union from decompiler.structures.pseudo.operations import MemberAccess @@ -67,9 +67,8 @@ def lift_set_field(self, assignment: mediumlevelil.MediumLevelILSetVarField, is_ """ # case 1 (struct), avoid set field of named integers: dest_type = self._lifter.lift(assignment.dest.type) - if isinstance(assignment.dest.type, binaryninja.NamedTypeReferenceType) and not ( - isinstance(dest_type, Pointer) and isinstance(dest_type.type, Integer) - ): + if isinstance(assignment.dest.type, binaryninja.NamedTypeReferenceType) and ( + isinstance(dest_type, Struct) or isinstance(dest_type, Class)): # otherwise get_member_by_offset not available struct_variable = self._lifter.lift(assignment.dest, is_aliased=True, parent=assignment) destination = MemberAccess( offset=assignment.offset, @@ -216,8 +215,14 @@ def lift_store_struct(self, instruction: mediumlevelil.MediumLevelILStoreStruct, """Lift a MLIL_STORE_STRUCT_SSA instruction to pseudo (e.g. object->field = x).""" vartype = self._lifter.lift(instruction.dest.expr_type) struct_variable = self._lifter.lift(instruction.dest, is_aliased=True, parent=instruction) + member = vartype.type.get_member_by_offset(instruction.offset) + if member is not None: + name = member.name + else: + name = f"__offset_{instruction.offset}" + name.replace("-", "minus_") struct_member_access = MemberAccess( - member_name=vartype.type.members.get(instruction.offset), + member_name=name, offset=instruction.offset, operands=[struct_variable], vartype=vartype, diff --git a/decompiler/frontend/binaryninja/handlers/types.py b/decompiler/frontend/binaryninja/handlers/types.py index 353a5922a..eeee02724 100644 --- a/decompiler/frontend/binaryninja/handlers/types.py +++ b/decompiler/frontend/binaryninja/handlers/types.py @@ -22,7 +22,7 @@ ) from decompiler.frontend.lifter import Handler from decompiler.structures.pseudo import CustomType, Float, FunctionTypeDef, Integer, Pointer, UnknownType, Variable -from decompiler.structures.pseudo.complextypes import ComplexTypeMember, ComplexTypeName, Enum, Struct +from decompiler.structures.pseudo.complextypes import Class, ComplexTypeMember, ComplexTypeName, Enum, Struct from decompiler.structures.pseudo.complextypes import Union as Union_ @@ -75,39 +75,60 @@ def lift_named_type_reference_type(self, custom: NamedTypeReferenceType, **kwarg def lift_enum(self, binja_enum: EnumerationType, name: str = None, **kwargs) -> Enum: """Lift enum type.""" - enum_name = name if name else self._get_data_type_name(binja_enum, keyword="enum") + type_id = hash(binja_enum) + enum_name = self._get_data_type_name(binja_enum, keyword="enum", provided_name=name) enum = Enum(binja_enum.width * self.BYTE_SIZE, enum_name, {}) for member in binja_enum.members: enum.add_member(self._lifter.lift(member)) - self._lifter.complex_types.add(enum) + self._lifter.complex_types.add(enum, type_id) return enum def lift_enum_member(self, enum_member: EnumerationMember, **kwargs) -> ComplexTypeMember: """Lift enum member type.""" return ComplexTypeMember(size=0, name=enum_member.name, offset=-1, type=Integer(32), value=int(enum_member.value)) - def lift_struct(self, struct: StructureType, name: str = None, **kwargs) -> Union[Struct, ComplexTypeName]: + def lift_struct(self, struct: StructureType, name: str = None, **kwargs) -> Union[Struct, Union_, Class, ComplexTypeName]: + type_id = hash(struct) + cached_type = self._lifter.complex_types.retrieve_by_id(type_id) + if cached_type is not None: + return cached_type + """Lift struct or union type.""" if struct.type == StructureVariant.StructStructureType: - type_name = name if name else self._get_data_type_name(struct, keyword="struct") - lifted_struct = Struct(struct.width * self.BYTE_SIZE, type_name, {}) + keyword, type, members = "struct", Struct, {} elif struct.type == StructureVariant.UnionStructureType: - type_name = name if name else self._get_data_type_name(struct, keyword="union") - lifted_struct = Union_(struct.width * self.BYTE_SIZE, type_name, []) + keyword, type, members = "union", Union_, [] + elif struct.type == StructureVariant.ClassStructureType: + keyword, type, members = "class", Class, {} else: raise RuntimeError(f"Unknown struct type {struct.type.name}") + + type_name = self._get_data_type_name(struct, keyword=keyword, provided_name=name) + lifted_struct = type(struct.width * self.BYTE_SIZE, type_name, members) + + self._lifter.complex_types.add(lifted_struct, type_id) for member in struct.members: lifted_struct.add_member(self.lift_struct_member(member, type_name)) - self._lifter.complex_types.add(lifted_struct) return lifted_struct @abstractmethod - def _get_data_type_name(self, complex_type: Union[StructureType, EnumerationType], keyword: str) -> str: - """Parse out the name of complex type.""" - string = complex_type.get_string() - if keyword in string: - return complex_type.get_string().split(keyword)[1] - return string + def _get_data_type_name(self, complex_type: Union[StructureType, EnumerationType], keyword: str, provided_name:str) -> str: + """Parse out the name of complex type. Empty and duplicate names are changed. + Calling this function has the side effect of incrementing a counter in the UniqueNameProvider.""" + if provided_name: + name = provided_name + else: + type_string = complex_type.get_string() + if keyword in type_string: + name = complex_type.get_string().split(keyword)[1] + else: + name = type_string + + if name.strip() == "": + name = f"__anonymous_{keyword}" + name = self._lifter.unique_name_provider.get_unique_name(name) + + return name def lift_struct_member(self, member: StructureMember, parent_struct_name: str = None) -> ComplexTypeMember: """Lift struct or union member.""" @@ -117,7 +138,7 @@ def lift_struct_member(self, member: StructureMember, parent_struct_name: str = else: # if member is an embedded struct/union, the name is already available member_type = self._lifter.lift(member.type, name=member.name) - return ComplexTypeMember(0, name=member.name, offset=member.offset, type=member_type) + return ComplexTypeMember(member_type.size, name=member.name, offset=member.offset, type=member_type) @abstractmethod def _get_member_pointer_on_the_parent_struct(self, member: StructureMember, parent_struct_name: str) -> ComplexTypeMember: diff --git a/decompiler/frontend/binaryninja/handlers/unary.py b/decompiler/frontend/binaryninja/handlers/unary.py index 180aecfd0..824f23f7c 100644 --- a/decompiler/frontend/binaryninja/handlers/unary.py +++ b/decompiler/frontend/binaryninja/handlers/unary.py @@ -99,7 +99,12 @@ def _lift_load_struct(self, instruction: mediumlevelil.MediumLevelILLoadStruct, struct_variable = self._lifter.lift(instruction.src) struct_ptr: Pointer = self._lifter.lift(instruction.src.expr_type) struct_member = struct_ptr.type.get_member_by_offset(instruction.offset) - return MemberAccess(vartype=struct_ptr, operands=[struct_variable], offset=struct_member.offset, member_name=struct_member.name) + if struct_member is not None: + name = struct_member.name + else: + name = f"__offset_{instruction.offset}" + name.replace("-", "minus_") + return MemberAccess(vartype=struct_ptr, operands=[struct_variable], offset=instruction.offset, member_name=name) def _lift_ftrunc(self, instruction: mediumlevelil.MediumLevelILFtrunc, **kwargs) -> UnaryOperation: """Lift a MLIL_FTRUNC operation.""" diff --git a/decompiler/frontend/binaryninja/lifter.py b/decompiler/frontend/binaryninja/lifter.py index e42761763..244df6ed4 100644 --- a/decompiler/frontend/binaryninja/lifter.py +++ b/decompiler/frontend/binaryninja/lifter.py @@ -6,7 +6,7 @@ from decompiler.frontend.lifter import ObserverLifter from decompiler.structures.pseudo import DataflowObject, Tag, UnknownExpression, UnknownType -from ...structures.pseudo.complextypes import ComplexTypeMap +from ...structures.pseudo.complextypes import ComplexTypeMap, UniqueNameProvider from .handlers import HANDLERS @@ -17,6 +17,7 @@ def __init__(self, no_bit_masks: bool = True, bv: BinaryView = None): self.no_bit_masks = no_bit_masks self.bv: BinaryView = bv self.complex_types: ComplexTypeMap = ComplexTypeMap() + self.unique_name_provider: UniqueNameProvider = UniqueNameProvider() for handler in HANDLERS: handler(self).register() diff --git a/decompiler/structures/pseudo/complextypes.py b/decompiler/structures/pseudo/complextypes.py index b32528b4a..764143fa2 100644 --- a/decompiler/structures/pseudo/complextypes.py +++ b/decompiler/structures/pseudo/complextypes.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from decompiler.structures.pseudo.typing import Type @@ -28,6 +28,10 @@ def copy(self, **kwargs) -> Type: def declaration(self) -> str: raise NotImplementedError + @property + def complex_type_name(self): + return ComplexTypeName(0, self.name) + @dataclass(frozen=True, order=True) class ComplexTypeMember(ComplexType): @@ -54,16 +58,16 @@ def declaration(self) -> str: @dataclass(frozen=True, order=True) -class Struct(ComplexType): +class _BaseStruct(ComplexType): """Class representing a struct type.""" members: Dict[int, ComplexTypeMember] = field(compare=False) - type_specifier: ComplexTypeSpecifier = ComplexTypeSpecifier.STRUCT + type_specifier: ComplexTypeSpecifier def add_member(self, member: ComplexTypeMember): self.members[member.offset] = member - def get_member_by_offset(self, offset: int) -> ComplexTypeMember: + def get_member_by_offset(self, offset: int) -> Optional[ComplexTypeMember]: return self.members.get(offset) def declaration(self) -> str: @@ -71,6 +75,16 @@ def declaration(self) -> str: return f"{self.type_specifier.value} {self.name} {{\n\t{members}\n}}" +@dataclass(frozen=True, order=True) +class Struct(_BaseStruct): + type_specifier: ComplexTypeSpecifier = ComplexTypeSpecifier.STRUCT + + +@dataclass(frozen=True, order=True) +class Class(_BaseStruct): + type_specifier: ComplexTypeSpecifier = ComplexTypeSpecifier.CLASS + + @dataclass(frozen=True, order=True) class Union(ComplexType): members: List[ComplexTypeMember] = field(compare=False) @@ -98,8 +112,9 @@ class Enum(ComplexType): def add_member(self, member: ComplexTypeMember): self.members[member.value] = member - def get_name_by_value(self, value: int) -> str: - return self.members.get(value).name + def get_name_by_value(self, value: int) -> Optional[str]: + member = self.members.get(value) + return member.name if member is not None else None def declaration(self) -> str: members = ",\n\t".join(f"{x.name} = {x.value}" for x in self.members.values()) @@ -117,19 +132,46 @@ def __str__(self) -> str: return self.name +class UniqueNameProvider: + """ The purpose of this class is to provide unique names for types, as duplicate names can potentially be encountered in the lifting stage (especially anonymous structs, etc.) + This class keeps track of all the names already used. If duplicates are found, they are renamed by appending suffixes with incrementing numbers. + E.g. `classname`, `classname__2`, `classname__3`, ... + """ + + def __init__(self): + self._name_to_count: Dict[str, int] = {} + + def get_unique_name(self, name: str) -> str: + """ This method returns the input name if it was unique so far. + Otherwise it returns the name with an added incrementing suffix. + In any case, the name occurence of the name is counted. + """ + if name not in self._name_to_count: + self._name_to_count[name] = 1 + return name + else: + self._name_to_count[name] += 1 + return f"{name}__{self._name_to_count[name]}" + + class ComplexTypeMap: """A class in charge of storing complex custom/user defined types by their string representation""" def __init__(self): self._name_to_type_map: Dict[ComplexTypeName, ComplexType] = {} + self._id_to_type_map: Dict[int, ComplexType] = {} - def retrieve_by_name(self, typename: ComplexTypeName) -> ComplexType: + def retrieve_by_name(self, typename: ComplexTypeName) -> Optional[ComplexType]: """Get complex type by name; used to avoid recursion.""" return self._name_to_type_map.get(typename, None) - def add(self, complex_type: ComplexType): + def retrieve_by_id(self, id: int) -> Optional[ComplexType]: + return self._id_to_type_map.get(id, None) + + def add(self, complex_type: ComplexType, type_id: int): """Add complex type to the mapping.""" - self._name_to_type_map[ComplexTypeName(0, complex_type.name)] = complex_type + self._id_to_type_map[type_id] = complex_type + self._name_to_type_map[complex_type.complex_type_name] = complex_type def pretty_print(self): for t in self._name_to_type_map.values(): diff --git a/decompiler/structures/pseudo/expressions.py b/decompiler/structures/pseudo/expressions.py index 5c59afaf2..4f9774ac6 100644 --- a/decompiler/structures/pseudo/expressions.py +++ b/decompiler/structures/pseudo/expressions.py @@ -191,7 +191,10 @@ def __str__(self) -> str: Constants of type Enum are represented as strings (corresponding enumerator identifiers). """ if isinstance(self._type, Enum): - return self._type.get_name_by_value(self.value) + name = self._type.get_name_by_value(self.value) + if name is not None: + return name + # otherwise, i.e. if value is not found in Enum class, fall through if self._type.is_boolean: return "true" if self.value else "false" if isinstance(self.value, str): diff --git a/decompiler/structures/pseudo/instructions.py b/decompiler/structures/pseudo/instructions.py index 362aecb72..d7dfdc8fe 100644 --- a/decompiler/structures/pseudo/instructions.py +++ b/decompiler/structures/pseudo/instructions.py @@ -146,7 +146,7 @@ def writes_memory(self) -> Optional[int]: """Return the memory version generated by this assignment, if any.""" if isinstance(self.value, Call): return self.value.writes_memory - if isinstance(self.destination, UnaryOperation) and self.destination.operation == OperationType.dereference: + if isinstance(self.destination, UnaryOperation) and self.destination.operation in {OperationType.member_access, OperationType.dereference}: return self.destination.writes_memory for variable in self.definitions: if variable.is_aliased: diff --git a/tests/structures/pseudo/test_complextypes.py b/tests/structures/pseudo/test_complextypes.py index 3bad97d60..c5a7d5c13 100644 --- a/tests/structures/pseudo/test_complextypes.py +++ b/tests/structures/pseudo/test_complextypes.py @@ -190,7 +190,7 @@ def blue(): class TestComplexTypeMap: def test_declarations(self, complex_types: ComplexTypeMap, book: Struct, color: Enum, record_id: Union): assert complex_types.declarations() == f"{book.declaration()};\n{color.declaration()};\n{record_id.declaration()};" - complex_types.add(book) + complex_types.add(book, 0) assert complex_types.declarations() == f"{book.declaration()};\n{color.declaration()};\n{record_id.declaration()};" def test_retrieve_by_name(self, complex_types: ComplexTypeMap, book: Struct, color: Enum, record_id: Union): @@ -201,7 +201,7 @@ def test_retrieve_by_name(self, complex_types: ComplexTypeMap, book: Struct, col @pytest.fixture def complex_types(self, book: Struct, color: Enum, record_id: Union): complex_types = ComplexTypeMap() - complex_types.add(book) - complex_types.add(color) - complex_types.add(record_id) + complex_types.add(book, 0) + complex_types.add(color, 1) + complex_types.add(record_id, 2) return complex_types