diff --git a/decompiler/frontend/binaryninja/handlers/types.py b/decompiler/frontend/binaryninja/handlers/types.py index 5b0305b1..929162a7 100644 --- a/decompiler/frontend/binaryninja/handlers/types.py +++ b/decompiler/frontend/binaryninja/handlers/types.py @@ -1,6 +1,6 @@ import logging from abc import abstractmethod -from typing import Union +from typing import Optional, Union from binaryninja import BinaryView, StructureVariant from binaryninja.types import ( @@ -22,7 +22,7 @@ ) from decompiler.frontend.lifter import Handler from decompiler.structures.pseudo import ArrayType as PseudoArrayType -from decompiler.structures.pseudo import CustomType, Float, FunctionTypeDef, Integer, Pointer, UnknownType +from decompiler.structures.pseudo import CustomType, Float, FunctionTypeDef, Integer, Pointer, UnknownType, Variable from decompiler.structures.pseudo.complextypes import Class, ComplexTypeMember, ComplexTypeName, Enum, Struct from decompiler.structures.pseudo.complextypes import Union as Union_ @@ -101,25 +101,21 @@ def lift_struct(self, struct: StructureType, name: str = None, **kwargs) -> Unio return cached_type """Lift struct or union type.""" - match struct.type: - case StructureVariant.StructStructureType: - keyword, type, to_members = "struct", Struct, lambda members: {m.offset: m for m in members} - case StructureVariant.UnionStructureType: - keyword, type, to_members = "union", Union_, lambda members: members - case StructureVariant.ClassStructureType: - keyword, type, to_members = "class", Class, lambda members: members - case _: - raise RuntimeError(f"Unknown struct type {struct.type.name}") + if struct.type == StructureVariant.StructStructureType: + keyword, type, members = "struct", Struct, {} + elif struct.type == StructureVariant.UnionStructureType: + keyword, type, members = "union", Union_, [] + elif struct.type == StructureVariant.ClassStructureType: + keyword, type, members = "class", Class, {} + else: + raise RuntimeError(f"Unknown struct type {struct.type.name}") type_name = self._get_data_type_name(struct, keyword=keyword, provided_name=name) - - members: list[ComplexTypeMember] = [] - for member in struct.members: - members.append(self.lift_struct_member(member, type_name)) - - lifted_struct = type(struct.width * self.BYTE_SIZE, type_name, to_members(members)) + lifted_struct = type(struct.width * self.BYTE_SIZE, type_name, members) self._lifter.complex_types.add(lifted_struct, type_id) + for member in struct.members: + lifted_struct.add_member(self.lift_struct_member(member, type_name)) return lifted_struct @abstractmethod diff --git a/decompiler/structures/pseudo/complextypes.py b/decompiler/structures/pseudo/complextypes.py index 61e86353..3c304053 100644 --- a/decompiler/structures/pseudo/complextypes.py +++ b/decompiler/structures/pseudo/complextypes.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional from decompiler.structures.pseudo.typing import Type -from pydot import frozendict class ComplexTypeSpecifier(Enum): @@ -58,9 +57,12 @@ def declaration(self) -> str: class _BaseStruct(ComplexType): """Class representing a struct type.""" - members: frozendict[int, ComplexTypeMember] = field(compare=False) + members: Dict[int, ComplexTypeMember] = field(compare=False) type_specifier: ComplexTypeSpecifier + def add_member(self, member: ComplexTypeMember): + self.members[member.offset] = member + def get_member_by_offset(self, offset: int) -> Optional[ComplexTypeMember]: return self.members.get(offset) @@ -93,6 +95,9 @@ class Union(ComplexType): members: List[ComplexTypeMember] = field(compare=False) type_specifier = ComplexTypeSpecifier.UNION + def add_member(self, member: ComplexTypeMember): + self.members.append(member) + def declaration(self) -> str: members = ";\n\t".join(x.declaration() for x in self.members) + ";" return f"{self.type_specifier.value} {self.name} {{\n\t{members}\n}}" @@ -118,6 +123,9 @@ class Enum(ComplexType): members: Dict[int, ComplexTypeMember] = field(compare=False) type_specifier = ComplexTypeSpecifier.ENUM + def add_member(self, member: ComplexTypeMember): + self.members[member.value] = member + def get_name_by_value(self, value: int) -> Optional[str]: member = self.members.get(value) return member.name if member is not None else None diff --git a/tests/structures/pseudo/test_complextypes.py b/tests/structures/pseudo/test_complextypes.py index 3fc0b970..7b09e82b 100644 --- a/tests/structures/pseudo/test_complextypes.py +++ b/tests/structures/pseudo/test_complextypes.py @@ -11,7 +11,6 @@ Union, UniqueNameProvider, ) -from pydot import frozendict class TestStruct: @@ -61,17 +60,11 @@ def test_get_complex_type_name(self, book): class TestClass: - def test_declaration(self, record_id: Union): - m = ComplexTypeMember(size=64, name="id", offset=12, type=record_id) - class_book = Struct( - name="Book", - members=frozendict({ - 0: ComplexTypeMember(size=32, name="title", offset=0, type=Pointer(Integer.char())), - 4: ComplexTypeMember(size=32, name="num_pages", offset=4, type=Integer.int32_t()), - 8: ComplexTypeMember(size=32, name="author", offset=8, type=Pointer(Integer.char())), - 12: m - }), - size=96, + def test_declaration(self, class_book: Struct, record_id: Union): + assert class_book.declaration() == "class ClassBook {\n\tchar * title;\n\tint num_pages;\n\tchar * author;\n}" + # nest complex type + class_book.add_member( + m := ComplexTypeMember(size=64, name="id", offset=12, type=record_id), ) result = f"class ClassBook {{\n\tchar * title;\n\tint num_pages;\n\tchar * author;\n\t{m.declaration()};\n}}" assert class_book.declaration() == result