Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error handling of decompilation tasks #391

Merged
merged 17 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ def decompile(bv: BinaryView, function: Function):
"""Decompile the target mlil_function."""
decompiler = Decompiler.from_raw(bv)
options = Options.from_gui()
task = decompiler.decompile(function, options)
task, code = decompiler.decompile(function, task_options=options)
show_html_report(
f"decompile {task.name}", DecoratedCode.generate_html_from_code(task.code, task.options.getstring("code-generator.style_plugin"))
f"decompile {task.name}",
DecoratedCode.generate_html_from_code(code, options.getstring("code-generator.style_plugin")),
)


Expand Down
48 changes: 25 additions & 23 deletions decompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
"""Main decompiler Interface."""
from __future__ import annotations

from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from typing import Collection, Optional

from decompiler.backend.codegenerator import CodeGenerator
from decompiler.frontend import BinaryninjaFrontend, Frontend
Expand Down Expand Up @@ -40,33 +41,34 @@ def from_raw(cls, data, frontend: Frontend = BinaryninjaFrontend) -> Decompiler:
"""Create a decompiler instance from existing frontend instance (e.g. a binaryninja view)."""
return cls(frontend.from_raw(data))

def decompile(self, function: str, task_options: Optional[Options] = None) -> DecompilerTask:
"""Decompile the target function."""
# Sanity check to ensure task_options is populated
def decompile_all(self, function_ids: Collection[object] | None = None, task_options: Options | None = None) -> Result:
rihi marked this conversation as resolved.
Show resolved Hide resolved
if function_ids is None: # decompile all functions when none are specified
function_ids = self._frontend.get_all_function_names()
if task_options is None:
task_options = Decompiler.create_options()
# Start decompiling
pipeline = DecompilerPipeline.from_strings(task_options.getlist("pipeline.cfg_stages"), task_options.getlist("pipeline.ast_stages"))
task = self._frontend.create_task(function, task_options)
pipeline.run(task)
task.code = self._backend.generate([task])
return task

def decompile_all(self, task_options: Optional[Options] = None) -> str:
"""Decompile all functions in the binary"""
tasks = list()
# Sanity check to ensure task_options is populated
if task_options is None:
task_options = Decompiler.create_options()
# Start decompiling

pipeline = DecompilerPipeline.from_strings(task_options.getlist("pipeline.cfg_stages"), task_options.getlist("pipeline.ast_stages"))
functions = self._frontend.get_all_function_names()
for function in functions:
task = self._frontend.create_task(function, task_options)
pipeline.run(task)

tasks = []
for func_id in function_ids:
task = DecompilerTask(str(func_id), func_id, task_options)
tasks.append(task)

self._frontend.lift(task)
pipeline.run(task)

code = self._backend.generate(tasks)
return code

return Decompiler.Result(tasks, code)

def decompile(self, function_id: object, task_options: Options | None = None) -> tuple[DecompilerTask, str]:
rihi marked this conversation as resolved.
Show resolved Hide resolved
result = self.decompile_all([function_id], task_options)
return result.tasks[0], result.code

@dataclass
class Result:
tasks: list[DecompilerTask]
code: str


"""When invoked as a script, run the commandline interface."""
Expand Down
29 changes: 20 additions & 9 deletions decompiler/backend/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,23 @@ def generate(self, tasks: Iterable[DecompilerTask], run_cleanup: bool = True):

def generate_function(self, task: DecompilerTask) -> str:
"""Generate C-Code for the function described in the given DecompilerTask."""
return self.TEMPLATE.substitute(
return_type=task.function_return_type,
name=task.name,
parameters=", ".join(
map(lambda param: CExpressionGenerator.format_variables_declaration(param.type, [param.name]), task.function_parameters)
),
local_declarations=LocalDeclarationGenerator.from_task(task) if not task.failed else "",
function_body=CodeVisitor(task).visit(task.syntax_tree.root) if not task.failed else task.failure_message,
)
if task.failed:
return self.generate_failure_message(task)
else:
return self.TEMPLATE.substitute(
return_type=task.function_return_type,
name=task.name,
parameters=", ".join(
map(lambda param: CExpressionGenerator.format_variables_declaration(param.type, [param.name]), task.function_parameters)
),
local_declarations=LocalDeclarationGenerator.from_task(task),
function_body=CodeVisitor(task).visit(task.syntax_tree.root),
)

@staticmethod
def generate_failure_message(task: DecompilerTask):
"""Returns the message to be shown for a failed task."""
msg = f"Failed to decompile {task.name}"
if origin := task.failure_origin: # checks if the string is empty (should never be None when this method is called)
msg += f" due to error during {origin}."
return msg
170 changes: 72 additions & 98 deletions decompiler/frontend/binaryninja/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@
from __future__ import annotations

import logging
from typing import List, Optional, Tuple, Union

from binaryninja import BinaryView, Function, load
import binaryninja
from binaryninja import BinaryView
from binaryninja.types import SymbolType
from decompiler.structures.graphs.cfg import ControlFlowGraph
from decompiler.structures.pseudo.complextypes import ComplexTypeMap
from decompiler.structures.pseudo.expressions import Variable
from decompiler.structures.pseudo.typing import Type
from decompiler.task import DecompilerTask
from decompiler.util.options import Options

Expand All @@ -20,72 +16,6 @@
from .tagging import CompilerIdiomsTagging


class FunctionObject:
"""Wrapper class for dealing with Binaryninja Functions"""

def __init__(self, function: Function):
self._function = function
self._lifter = BinaryninjaLifter()
self._name = self._lifter.lift(self._function.symbol).name

@classmethod
def get(cls, bv: BinaryView, identifier: Union[str, Function]) -> FunctionObject:
"""Get a function object from the given identifier."""
if isinstance(identifier, Function):
return cls(identifier)
if isinstance(identifier, str):
return cls.from_string(bv, identifier)
raise ValueError(f"Could not parse function identifier of type {type(identifier)}.")

@classmethod
def from_string(cls, bv: BinaryView, function_name: str) -> FunctionObject:
"""Given a function identifier, locate Function object in BinaryView"""
if (function := cls._resolve_by_identifier_name(bv, function_name)) is not None:
return cls(function)
if (function := cls._resolve_by_address(bv, function_name)) is not None:
return cls(function)
raise RuntimeError(f"Frontend could not resolve function '{function_name}'")

@property
def function(self) -> Function:
"""Function object"""
return self._function

@property
def name(self) -> str:
"""Name of function object"""
return self._name

@property
def return_type(self) -> Type:
"""Lifted return type of function"""
return self._lifter.lift(self._function.type.return_value)

@property
def params(self) -> List[Variable]:
"""Lifted function parameters"""
return [self._lifter.lift(param) for param in self._function.type.parameters]

@staticmethod
def _resolve_by_identifier_name(bv: BinaryView, function_name: str) -> Optional[Function]:
"""
Iterate BinaryView.functions and compare matching names.

note: we take this approach since bv.get_functions_by_name() may return wrong functions.
"""
return next(filter(lambda f: f.name == function_name, bv.functions), None)

@staticmethod
def _resolve_by_address(bv: BinaryView, hex_str: str) -> Optional[Function]:
"""Get Function object by hex address or 'sub_<address>'"""
try:
hex_address = hex_str[4:] if hex_str.startswith("sub_") else hex_str
address = int(hex_address, 16)
return bv.get_function_at(address)
except ValueError:
logging.info(f"{hex_str} does not contain hex value")


class BinaryninjaFrontend(Frontend):
"""Frontend implementation for binaryninja."""

Expand Down Expand Up @@ -114,7 +44,7 @@ def __init__(self, bv: BinaryView):
def from_path(cls, path: str, options: Options):
"""Create a frontend object by invoking binaryninja on the given sample."""
file_options = {"analysis.limits.maxFunctionSize": options.getint("binaryninja.max_function_size")}
if (bv := load(path, options=file_options)) is not None:
if (bv := binaryninja.load(path, options=file_options)) is not None:
return cls(bv)
raise RuntimeError("Failed to create binary view")

Expand All @@ -123,30 +53,28 @@ def from_raw(cls, view: BinaryView):
"""Create a binaryninja frontend instance based on an initialized binary view."""
return cls(view)

def create_task(self, function_identifier: Union[str, Function], options: Options) -> DecompilerTask:
"""Create a task from the given function identifier."""
function = FunctionObject.get(self._bv, function_identifier)
tagging = CompilerIdiomsTagging(self._bv, function.function.start, options)
tagging.run()
def lift(self, task: DecompilerTask):
ebehner marked this conversation as resolved.
Show resolved Hide resolved
if task.failed:
ebehner marked this conversation as resolved.
Show resolved Hide resolved
return

try:
cfg, complex_types = self._extract_cfg(function.function, options)
task = DecompilerTask(
function.name,
cfg,
function_return_type=function.return_type,
function_parameters=function.params,
options=options,
complex_types=complex_types,
)
function = self._get_binninja_function(task.function_identifier)
lifter, parser = self._create_lifter_parser(task.options)

task.function_return_type = lifter.lift(function.return_type)
task.function_parameters = [lifter.lift(param_type) for param_type in function.type.parameters]

tagging = CompilerIdiomsTagging(self._bv, function.start, task.options)
tagging.run()
rihi marked this conversation as resolved.
Show resolved Hide resolved

task.cfg = parser.parse(function)
task.complex_types = parser.complex_types
except Exception as e:
task = DecompilerTask(
function.name, None, function_return_type=function.return_type, function_parameters=function.params, options=options
)
task.fail(origin="CFG creation")
logging.error(f"Failed to decompile {task.name}, error during CFG creation: {e}")
if options.getboolean("pipeline.debug", fallback=False):
task.fail("Function lifting")
logging.exception(f"Failed to decompile {task.name}, error during function lifting")

if task.options.getboolean("pipeline.debug", fallback=False):
raise e
return task

def get_all_function_names(self):
"""Returns the entire list of all function names in the binary. Ignores blacklisted functions and imported functions."""
Expand All @@ -159,9 +87,55 @@ def get_all_function_names(self):
functions.append(function.name)
return functions

def _extract_cfg(self, function: Function, options: Options) -> Tuple[ControlFlowGraph, ComplexTypeMap]:
"""Extract a control flow graph utilizing the parser and fixing it afterwards."""
def _get_binninja_function(self, function_identifier: object) -> binaryninja.function.Function:
rihi marked this conversation as resolved.
Show resolved Hide resolved
function: binaryninja.function.Function | None
ebehner marked this conversation as resolved.
Show resolved Hide resolved
match function_identifier:
case str():
function = self._get_binninja_function_from_string(function_identifier)
case binaryninja.function.Function():
function = function_identifier
case _:
raise ValueError(f"BNinja frontend can't handle function identifier of type {type(function_identifier)}")

if function is None:
raise RuntimeError(f"BNinja frontend could not resolve function with identifier '{function_identifier}'")

if function.analysis_skipped:
raise RuntimeError(
f"BNinja skipped function analysis for function '{function.name}' with reason '{function.analysis_skip_reason.name}'"
)

return function

def _get_binninja_function_from_string(self, function_name: str) -> binaryninja.function.Function | None:
"""Given a function string identifier, locate Function object in BinaryView"""
if (function := self._resolve_by_identifier_name(function_name)) is not None:
return function
if (function := self._resolve_by_address(function_name)) is not None:
return function

return None

def _resolve_by_identifier_name(self, function_name: str) -> binaryninja.function.Function | None:
"""
Iterate BinaryView.functions and compare matching names.

note: we take this approach since bv.get_functions_by_name() may return wrong functions.
"""
return next(filter(lambda f: f.name == function_name, self._bv.functions), None)

def _resolve_by_address(self, hex_str: str) -> binaryninja.function.Function | None:
"""Get Function object by hex address or 'sub_<address>'"""
try:
hex_address = hex_str[4:] if hex_str.startswith("sub_") else hex_str
address = int(hex_address, 16)
return self._bv.get_function_at(address)
except ValueError:
logging.info(f"{hex_str} does not contain hex value")

def _create_lifter_parser(self, options: Options) -> tuple[BinaryninjaLifter, BinaryninjaParser]:
report_threshold = options.getint("lifter.report_threshold", fallback=3)
no_masks = options.getboolean("lifter.no_bit_masks", fallback=True)
parser = BinaryninjaParser(BinaryninjaLifter(no_masks, bv=function.view), report_threshold)
return parser.parse(function), parser.complex_types
lifter = BinaryninjaLifter(no_masks, bv=self._bv)
parser = BinaryninjaParser(lifter, report_threshold)
return lifter, parser
4 changes: 2 additions & 2 deletions decompiler/frontend/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def from_path(cls, path: str, options: Options) -> Frontend:
"""

@abstractmethod
def create_task(self, function_identifier: str, options: Options) -> DecompilerTask:
"""Create a task from the given function identifier."""
def lift(self, task: DecompilerTask):
"""Lift function data into task object."""

@abstractmethod
def get_all_function_names(self) -> List[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def run(self, task: DecompilerTask):
for_loop_names: List[str] = task.options.getlist("loop-name-generator.for_loop_variable_names", fallback=[])

if rename_while_loops:
WhileLoopVariableRenamer(task._ast).rename()
WhileLoopVariableRenamer(task.ast).rename()

if for_loop_names:
ForLoopVariableRenamer(task._ast, for_loop_names).rename()
ForLoopVariableRenamer(task.ast, for_loop_names).rename()
4 changes: 2 additions & 2 deletions decompiler/pipeline/controlflowanalysis/restructuring.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def run(self, task: DecompilerTask):
assert len(self.t_cfg) == 1, f"The Transition Graph can only have one node after the restructuring."
self.asforest.set_current_root(self.t_cfg.root.ast)
assert (roots := len(self.asforest.get_roots)) == 1, f"After the restructuring the forest should have one root, but it has {roots}!"
task._ast = AbstractSyntaxTree.from_asforest(self.asforest, self.asforest.current_root)
task._cfg = None
task.ast = AbstractSyntaxTree.from_asforest(self.asforest, self.asforest.current_root)
task.cfg = None

def restructure_cfg(self) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional

from decompiler.pipeline.stage import PipelineStage
from decompiler.structures.ast.ast_nodes import ConditionNode, LoopNode
Expand Down Expand Up @@ -75,9 +75,9 @@ class RenamingScheme(ABC):

def __init__(self, task: DecompilerTask) -> None:
"""Collets all needed variables for renaming + filters already renamed + function arguments out"""
collector = VariableCollector(task._ast.condition_map)
collector.visit_ast(task._ast)
self._params: List[Variable] = task._function_parameters
collector = VariableCollector(task.ast.condition_map)
collector.visit_ast(task.ast)
self._params: List[Variable] = task.function_parameters
self._loop_vars: List[Variable] = collector.get_loop_variables()
self._variables: List[Variable] = list(filter(self._filter_variables, collector.get_variables()))

Expand Down
Loading