diff --git a/angr/analyses/callee_cleanup_finder.py b/angr/analyses/callee_cleanup_finder.py index 661d2b77421..5bc47d67de6 100644 --- a/angr/analyses/callee_cleanup_finder.py +++ b/angr/analyses/callee_cleanup_finder.py @@ -27,13 +27,14 @@ def __init__(self, starts=None, hook_all=False): l.error("Function at %#x has a misaligned return?", addr) continue args = size // self.project.arch.bytes - cc = self.project.factory.cc_from_arg_kinds([False]*args) + cc = self.project.factory.cc() + prototype = cc.guess_prototype([0]*args) cc.CALLEE_CLEANUP = True sym = self.project.loader.find_symbol(addr) name = sym.name if sym is not None else None lib = self.project.loader.find_object_containing(addr) libname = lib.provides if lib is not None else None - self.project.hook(addr, SIM_PROCEDURES['stubs']['ReturnUnconstrained'](cc=cc, display_name=name, library_name=libname, is_stub=True)) + self.project.hook(addr, SIM_PROCEDURES['stubs']['ReturnUnconstrained'](cc=cc, prototype=prototype, display_name=name, library_name=libname, is_stub=True)) def analyze(self, addr): seen = set() diff --git a/angr/analyses/calling_convention.py b/angr/analyses/calling_convention.py index 3f633b51d8c..8e498ae9b9e 100644 --- a/angr/analyses/calling_convention.py +++ b/angr/analyses/calling_convention.py @@ -5,7 +5,8 @@ import networkx from archinfo.arch_arm import is_arm_arch -from ..calling_conventions import SimRegArg, SimStackArg, SimCC, DefaultCC +from ..calling_conventions import SimFunctionArgument, SimRegArg, SimStackArg, SimCC, DefaultCC +from ..sim_type import SimTypeInt, SimTypeFunction, SimType, SimTypeLongLong, SimTypeShort, SimTypeChar, SimTypeBottom from ..sim_variable import SimStackVariable, SimRegisterVariable from ..knowledge_plugins.key_definitions.atoms import Register, MemoryLocation, SpOffset from ..knowledge_plugins.key_definitions.constants import OP_BEFORE, OP_AFTER @@ -46,7 +47,7 @@ class UpdateArgumentsOption: class CallingConventionAnalysis(Analysis): """ - Analyze the 
calling convention of functions. + Analyze the calling convention of a function and guess a probable prototype. The calling convention of a function can be inferred at both its call sites and the function itself. At call sites, we consider all register and stack variables that are not alive after the function call as parameters to this @@ -72,6 +73,7 @@ def __init__(self, func: 'Function', cfg: Optional['CFGModel']=None, analyze_cal self.analyze_callsites = analyze_callsites self.cc: Optional[SimCC] = None + self.prototype: Optional[SimTypeFunction] = None if self._cfg is None and 'CFGFast' in self.kb.cfgs: self._cfg = self.kb.cfgs['CFGFast'] @@ -85,28 +87,37 @@ def _analyze(self): if self._function.is_simprocedure: self.cc = self._function.calling_convention + self.prototype = self._function.prototype if self.cc is None: callsite_facts = self._analyze_callsites(max_analyzing_callsites=1) cc = DefaultCC[self.project.arch.name](self.project.arch) - cc = self._adjust_cc(cc, callsite_facts, update_arguments=UpdateArgumentsOption.AlwaysUpdate) + prototype = self._adjust_prototype(self.prototype, callsite_facts, + update_arguments=UpdateArgumentsOption.AlwaysUpdate) self.cc = cc + self.prototype = prototype return if self._function.is_plt: - self.cc = self._analyze_plt() + r = self._analyze_plt() + if r is not None: + self.cc, self.prototype = r return - cc = self._analyze_function() - if self.analyze_callsites: - # only take the first 3 because running reaching definition analysis on all functions is costly - callsite_facts = self._analyze_callsites(max_analyzing_callsites=3) - cc = self._adjust_cc(cc, callsite_facts, update_arguments=UpdateArgumentsOption.UpdateWhenCCHasNoArgs) - - if cc is None: + r = self._analyze_function() + if r is None: l.warning('Cannot determine calling convention for %r.', self._function) - - self.cc = cc - - def _analyze_plt(self) -> Optional[SimCC]: + else: + # adjust prototype if needed + cc, prototype = r + if self.analyze_callsites: + # 
only take the first 3 because running reaching definition analysis on all functions is costly + callsite_facts = self._analyze_callsites(max_analyzing_callsites=3) + prototype = self._adjust_prototype(prototype, callsite_facts, + update_arguments=UpdateArgumentsOption.UpdateWhenCCHasNoArgs) + + self.cc = cc + self.prototype = prototype + + def _analyze_plt(self) -> Optional[Tuple[SimCC,SimTypeFunction]]: """ Get the calling convention for a PLT stub. @@ -135,18 +146,20 @@ def _analyze_plt(self) -> Optional[SimCC]: if real_func is not None: if real_func.is_simprocedure and self.project.is_hooked(real_func.addr): hooker = self.project.hooked_by(real_func.addr) - if hooker is not None and (not hooker.is_stub or real_func.calling_convention.func_ty is not None): - return real_func.calling_convention + if hooker is not None and not hooker.is_stub: + return real_func.calling_convention, real_func.prototype else: - return real_func.calling_convention + return real_func.calling_convention, real_func.prototype # determine the calling convention by analyzing its callsites callsite_facts = self._analyze_callsites(max_analyzing_callsites=1) cc = DefaultCC[self.project.arch.name](self.project.arch) - cc = self._adjust_cc(cc, callsite_facts, update_arguments=UpdateArgumentsOption.AlwaysUpdate) - return cc + prototype = SimTypeFunction([ ], None) + prototype = self._adjust_prototype(prototype, callsite_facts, + update_arguments=UpdateArgumentsOption.AlwaysUpdate) + return cc, prototype - def _analyze_function(self) -> Optional[SimCC]: + def _analyze_function(self) -> Optional[Tuple[SimCC,SimTypeFunction]]: """ Go over the variable information in variable manager for this function, and return all uninitialized register/stack variables. 
@@ -174,15 +187,13 @@ def _analyze_function(self) -> Optional[SimCC]: if cc is None: l.warning('_analyze_function(): Cannot find a calling convention for %r that fits the given arguments.', self._function) + return None else: # reorder args args = self._reorder_args(input_args, cc) - cc.args = args + prototype = SimTypeFunction([self._guess_arg_type(arg) for arg in args], SimTypeInt()) - # set return value - cc.ret_val = cc.return_val - - return cc + return cc, prototype def _analyze_callsites(self, max_analyzing_callsites: int=3) -> List[CallSiteFact]: # pylint:disable=no-self-use """ @@ -327,9 +338,9 @@ def _analyze_callsite_arguments(self, defs_by_stack_offset = dict((-d.atom.addr.offset, d) for d in all_stack_defs if isinstance(d.atom, MemoryLocation) and isinstance(d.atom.addr, SpOffset)) - arg_session = default_cc.arg_session + arg_session = default_cc.arg_session(SimTypeInt().with_arch(self.project.arch)) for _ in range(30): # at most 30 arguments - arg_loc = arg_session.next_arg(False) + arg_loc = default_cc.next_arg(arg_session, SimTypeInt().with_arch(self.project.arch)) if isinstance(arg_loc, SimRegArg): reg_offset = self.project.arch.registers[arg_loc.reg_name][0] # is it initialized? @@ -347,28 +358,25 @@ def _analyze_callsite_arguments(self, else: break - @staticmethod - def _adjust_cc(cc: SimCC, facts: List[CallSiteFact], - update_arguments: int=UpdateArgumentsOption.DoNotUpdate): + def _adjust_prototype(self, proto: Optional[SimTypeFunction], facts: List[CallSiteFact], + update_arguments: int=UpdateArgumentsOption.DoNotUpdate) -> Optional[SimTypeFunction]: - if cc is None: - return cc + if proto is None: + return None # is the return value used anywhere? 
if facts and all(fact.return_value_used is False for fact in facts): - cc.ret_val = None - else: - cc.ret_val = cc.RETURN_VAL + proto.returnty = None if update_arguments == UpdateArgumentsOption.AlwaysUpdate or ( update_arguments == UpdateArgumentsOption.UpdateWhenCCHasNoArgs and - not cc.args + not proto.args ): if len(set(len(fact.args) for fact in facts)) == 1: fact = next(iter(facts)) - cc.args = fact.args + proto.args = [self._guess_arg_type(arg) for arg in fact.args] - return cc + return proto def _args_from_vars(self, variables: List, var_manager): """ @@ -516,5 +524,18 @@ def _reorder_args(self, args, cc): return reg_args + args + stack_args + def _guess_arg_type(self, arg: SimFunctionArgument) -> SimType: + if arg.size == 4: + return SimTypeInt() + elif arg.size == 8: + return SimTypeLongLong() + elif arg.size == 2: + return SimTypeShort() + elif arg.size == 1: + return SimTypeChar() + else: + # Unsupported for now + return SimTypeBottom() + register_analysis(CallingConventionAnalysis, "CallingConvention") diff --git a/angr/analyses/complete_calling_conventions.py b/angr/analyses/complete_calling_conventions.py index c9cda9fe204..81de86b5cb9 100644 --- a/angr/analyses/complete_calling_conventions.py +++ b/angr/analyses/complete_calling_conventions.py @@ -65,8 +65,10 @@ def _analyze(self): cc_analysis = self.project.analyses.CallingConvention(func, cfg=self._cfg, analyze_callsites=self._analyze_callsites) if cc_analysis.cc is not None: - _l.info("Determined calling convention for %r.", func) + _l.info("Determined calling convention and prototype for %r.", func) func.calling_convention = cc_analysis.cc + func.prototype = cc_analysis.prototype + func.is_prototype_guessed = True else: _l.info("Cannot determine calling convention for %r.", func) diff --git a/angr/analyses/decompiler/callsite_maker.py b/angr/analyses/decompiler/callsite_maker.py index 90e1732157c..2dfc9200842 100644 --- a/angr/analyses/decompiler/callsite_maker.py +++ 
b/angr/analyses/decompiler/callsite_maker.py @@ -1,4 +1,5 @@ from typing import Optional, List, Tuple, Any, Set, TYPE_CHECKING +import copy import logging import archinfo @@ -6,8 +7,8 @@ from ...procedures.stubs.format_parser import FormatParser, FormatSpecifier from ...errors import SimMemoryMissingError -from ...sim_type import SimTypeBottom, SimTypePointer, SimTypeChar -from ...calling_conventions import SimRegArg, SimStackArg +from ...sim_type import SimTypeBottom, SimTypePointer, SimTypeChar, SimTypeInt +from ...calling_conventions import SimRegArg, SimStackArg, SimCC from ...knowledge_plugins.key_definitions.constants import OP_BEFORE from ...knowledge_plugins.key_definitions.definition import Definition from .. import Analysis, register_analysis @@ -65,30 +66,28 @@ def _analyze(self): args = [ ] arg_locs = None - if func.calling_convention is None: + if cc is None: l.warning('%s has an unknown calling convention.', repr(func)) else: stackarg_sp_diff = func.calling_convention.STACKARG_SP_DIFF - if func.prototype is not None: + if prototype is not None: # Make arguments - arg_locs = func.calling_convention.arg_locs() - if func.prototype.variadic: + arg_locs = cc.arg_locs(prototype) + if prototype.variadic: # determine the number of variadic arguments - variadic_args = self._determine_variadic_arguments(func, func.calling_convention, last_stmt) + variadic_args = self._determine_variadic_arguments(func, cc, last_stmt) if variadic_args: - arg_sizes = [arg.size // self.project.arch.byte_width for arg in func.prototype.args] + \ - ([self.project.arch.bytes] * variadic_args) - is_fp = [False] * len(arg_sizes) - arg_locs = func.calling_convention.arg_locs(is_fp=is_fp, sizes=arg_sizes) - else: - if func.calling_convention.args is not None: - arg_locs = func.calling_convention.arg_locs() + callsite_ty = copy.copy(prototype) + callsite_ty.args = list(callsite_ty.args) + for i in range(variadic_args): + callsite_ty.args.append(SimTypeInt().with_arch(self.project.arch)) + 
arg_locs = cc.arg_locs(callsite_ty) if arg_locs is not None: for arg_loc in arg_locs: if type(arg_loc) is SimRegArg: size = arg_loc.size - offset = arg_loc._fix_offset(None, size, arch=self.project.arch) + offset = arg_loc.check_offset(cc.arch) _, the_arg = self._resolve_register_argument(last_stmt, arg_loc) @@ -194,7 +193,7 @@ def _find_variable_from_definition(self, def_): def _resolve_register_argument(self, call_stmt, arg_loc) -> Tuple: size = arg_loc.size - offset = arg_loc._fix_offset(None, size, arch=self.project.arch) + offset = arg_loc.check_offset(self.project.arch) if self._reaching_definitions is not None: # Find its definition @@ -310,9 +309,7 @@ def _determine_variadic_arguments_for_format_strings(self, func, cc: 'SimCC', ca fmt_str = None min_arg_count = (max(potential_fmt_args) + 1) - arg_locs = cc.arg_locs(is_fp=[False] * min_arg_count, - sizes=[self.project.arch.bytes] * min_arg_count - ) + arg_locs = cc.arg_locs(SimCC.guess_prototype([0]*min_arg_count, proto)) for fmt_arg_idx in potential_fmt_args: arg_loc = arg_locs[fmt_arg_idx] diff --git a/angr/analyses/decompiler/clinic.py b/angr/analyses/decompiler/clinic.py index 447f46e7941..79d89ff6a5b 100644 --- a/angr/analyses/decompiler/clinic.py +++ b/angr/analyses/decompiler/clinic.py @@ -431,8 +431,8 @@ def _updatedict_handler(node): @timethis def _make_argument_list(self) -> List[SimVariable]: - if self.function.calling_convention is not None: - args: List[SimFunctionArgument] = self.function.calling_convention.args + if self.function.calling_convention is not None and self.function.prototype is not None: + args: List[SimFunctionArgument] = self.function.calling_convention.arg_locs(self.function.prototype) arg_vars: List[SimVariable] = [ ] if args: for idx, arg in enumerate(args): @@ -512,9 +512,10 @@ def _make_returns(self, ail_graph: networkx.DiGraph) -> networkx.DiGraph: def _handle_Return(stmt_idx: int, stmt: ailment.Stmt.Return, block: Optional[ailment.Block]): # pylint:disable=unused-argument 
if block is not None \ and not stmt.ret_exprs \ - and self.function.calling_convention.ret_val is not None: + and self.function.prototype is not None \ + and type(self.function.prototype.returnty) is not SimTypeBottom: new_stmt = stmt.copy() - ret_val = self.function.calling_convention.ret_val + ret_val = self.function.calling_convention.return_val(self.function.prototype.returnty) if isinstance(ret_val, SimRegArg): reg = self.project.arch.registers[ret_val.reg_name] new_stmt.ret_exprs.append(ailment.Expr.Register( @@ -546,7 +547,7 @@ def _handler(block): @timethis def _make_function_prototype(self, arg_list: List[SimVariable], variable_kb): - if self.function.prototype is not None: + if self.function.prototype is not None and not self.function.is_prototype_guessed: # do not overwrite an existing function prototype # if you want to re-generate the prototype, clear the existing one first return @@ -574,12 +575,11 @@ def _make_function_prototype(self, arg_list: List[SimVariable], variable_kb): func_args.append(func_arg) - if self.function.calling_convention is not None and self.function.calling_convention.ret_val is None: - returnty = SimTypeBottom(label="void") - else: - returnty = SimTypeInt() + # TODO: need a new method of determining whether a function returns void + returnty = SimTypeInt() self.function.prototype = SimTypeFunction(func_args, returnty).with_arch(self.project.arch) + self.function.is_prototype_guessed = False @timethis def _recover_and_link_variables(self, ail_graph, arg_list): diff --git a/angr/analyses/identifier/custom_callable.py b/angr/analyses/identifier/custom_callable.py index b33a33ababa..24ac2ed5dd8 100644 --- a/angr/analyses/identifier/custom_callable.py +++ b/angr/analyses/identifier/custom_callable.py @@ -57,24 +57,31 @@ def set_base_state(self, state): self._base_state = state def __call__(self, *args): - self.perform_call(*args) + prototype = self._cc.guess_prototype(args) + self.perform_call(*args, prototype=prototype) if 
self.result_state is not None: - return self.result_state.solver.simplify(self._cc.get_return_val(self.result_state, stack_base=self.result_state.regs.sp - self._cc.STACKARG_SP_DIFF)) + loc = self._cc.return_val(prototype.returnty) + return self.result_state.solver.simplify(loc.get_value(self.result_state, stack_base=self.result_state.regs.sp - self._cc.STACKARG_SP_DIFF)) return None def get_base_state(self, *args): + prototype = self._cc.guess_prototype(args) self._base_state.ip = self._addr state = self._project.factory.call_state(self._addr, *args, + prototype=prototype, cc=self._cc, base_state=self._base_state, ret_addr=self._deadend_addr, toc=self._toc) return state - def perform_call(self, *args): + def perform_call(self, *args, prototype=None): + if prototype is None: + prototype = self._cc.guess_prototype(args) self._base_state.ip = self._addr state = self._project.factory.call_state(self._addr, *args, cc=self._cc, + prototype=prototype, base_state=self._base_state, ret_addr=self._deadend_addr, toc=self._toc) diff --git a/angr/analyses/identifier/functions/memcpy.py b/angr/analyses/identifier/functions/memcpy.py index 4aff475ef91..c865a13a207 100644 --- a/angr/analyses/identifier/functions/memcpy.py +++ b/angr/analyses/identifier/functions/memcpy.py @@ -62,8 +62,8 @@ def pre_test(self, func, runner): s = runner.get_base_call_state(func, test) s.memory.store(0x2000, "ABC\x00\x00\x00\x00\x00") inttype = SimTypeInt(runner.project.arch.bits, False) - func_ty = SimTypeFunction([inttype] * 3, inttype) - cc = runner.project.factory.cc(func_ty=func_ty) + prototype = SimTypeFunction([inttype] * 3, inttype) + cc = runner.project.factory.cc(prototype=prototype) call = IdentifierCallable(runner.project, func.startpoint.addr, concrete_only=True, cc=cc, base_state=s, max_steps=20) _ = call(*[0x2003, 0x2000, 5]) @@ -73,8 +73,8 @@ def pre_test(self, func, runner): s = runner.get_base_call_state(func, test) s.memory.store(0x2000, "\x00\x00\x00\x00\x00CBA") inttype = 
SimTypeInt(runner.project.arch.bits, False) - func_ty = SimTypeFunction([inttype] * 3, inttype) - cc = runner.project.factory.cc(func_ty=func_ty) + prototype = SimTypeFunction([inttype] * 3, inttype) + cc = runner.project.factory.cc(prototype=prototype) call = IdentifierCallable(runner.project, func.startpoint.addr, concrete_only=True, cc=cc, base_state=s, max_steps=20) _ = call(*[0x2000, 0x2003, 5]) diff --git a/angr/analyses/identifier/runner.py b/angr/analyses/identifier/runner.py index 07144ab796c..9978abf4fd8 100644 --- a/angr/analyses/identifier/runner.py +++ b/angr/analyses/identifier/runner.py @@ -205,9 +205,7 @@ def get_base_call_state(self, function, test_data, initial_state=None, concrete_ raise Exception("Expected int/bytes got %s" % type(i)) mapped_input.append(i) - inttype = SimTypeInt(self.project.arch.bits, False) - func_ty = SimTypeFunction([inttype] * len(mapped_input), inttype) - cc = self.project.factory.cc(func_ty=func_ty) + cc = self.project.factory.cc() call = IdentifierCallable(self.project, function.startpoint.addr, concrete_only=True, cc=cc, base_state=s, max_steps=test_data.max_steps) return call.get_base_state(*mapped_input) @@ -238,9 +236,7 @@ def test(self, function, test_data, concrete_rand=False, custom_offs=None): raise Exception("Expected int/str got %s" % type(i)) mapped_input.append(i) - inttype = SimTypeInt(self.project.arch.bits, False) - func_ty = SimTypeFunction([inttype] * len(mapped_input), inttype) - cc = self.project.factory.cc(func_ty=func_ty) + cc = self.project.factory.cc() try: call = IdentifierCallable(self.project, function.startpoint.addr, concrete_only=True, cc=cc, base_state=s, max_steps=test_data.max_steps) @@ -336,9 +332,7 @@ def get_out_state(self, function, test_data, initial_state=None, concrete_rand=F raise Exception("Expected int/bytes got %s" % type(i)) mapped_input.append(i) - inttype = SimTypeInt(self.project.arch.bits, False) - func_ty = SimTypeFunction([inttype] * len(mapped_input), inttype) - cc = 
self.project.factory.cc(func_ty=func_ty) + cc = self.project.factory.cc() try: call = IdentifierCallable(self.project, function.startpoint.addr, concrete_only=True, cc=cc, base_state=s, max_steps=test_data.max_steps) diff --git a/angr/analyses/reaching_definitions/engine_ail.py b/angr/analyses/reaching_definitions/engine_ail.py index f9a26d43d94..3a31d6eb03b 100644 --- a/angr/analyses/reaching_definitions/engine_ail.py +++ b/angr/analyses/reaching_definitions/engine_ail.py @@ -173,7 +173,7 @@ def _handle_Call_base(self, stmt: ailment.Stmt.Call, is_expr: bool=False): # getting used expressions from stmt.args used_exprs = stmt.args elif stmt.calling_convention is not None and ( - stmt.calling_convention.func_ty is not None or stmt.calling_convention.args is not None): + stmt.calling_convention.prototype is not None or stmt.calling_convention.args is not None): # getting used expressions from the function prototype, its arguments, and the calling convention used_exprs = [ ] for arg_loc in stmt.calling_convention.arg_locs(): @@ -243,11 +243,12 @@ def _handle_Call_base(self, stmt: ailment.Stmt.Call, is_expr: bool=False): def _ail_handle_Return(self, stmt: ailment.Stmt.Return): # pylint:disable=unused-argument codeloc = self._codeloc() - size = self.project.arch.bits // 8 cc = None + prototype = None if self.state.analysis.subject.type == SubjectType.Function: cc = self.state.analysis.subject.content.calling_convention + prototype = self.state.analysis.subject.content.prototype # import ipdb; ipdb.set_trace() if cc is None: @@ -280,9 +281,14 @@ def _ail_handle_Return(self, stmt: ailment.Stmt.Return): # pylint:disable=unuse # consume registers that are potentially useful # return value - if cc is not None and cc.ret_val is not None: - if isinstance(cc.ret_val, SimRegArg): - offset = cc.ret_val._fix_offset(None, size, arch=self.project.arch) + if cc is not None and prototype is not None and prototype.returnty is not None: + ret_val = cc.return_val(prototype.returnty) + if 
isinstance(ret_val, SimRegArg): + if ret_val.clear_entire_reg: + offset, size = cc.arch.registers[ret_val.reg_name] + else: + offset = cc.arch.registers[ret_val.reg_name][0] + ret_val.reg_offset + size = ret_val.size self.state.add_use(Register(offset, size), codeloc) else: l.error("Cannot handle CC with non-register return value location") diff --git a/angr/analyses/reaching_definitions/engine_vex.py b/angr/analyses/reaching_definitions/engine_vex.py index 562d81a3932..9c34b3ffd29 100644 --- a/angr/analyses/reaching_definitions/engine_vex.py +++ b/angr/analyses/reaching_definitions/engine_vex.py @@ -1076,6 +1076,7 @@ def _handle_function_core(self, func_addr: Optional[MultiValues], **kwargs) -> b def _handle_function_cc(self, func_addr: Optional[MultiValues]): _cc = None + proto = None func_addr_int: Optional[Union[int,Undefined]] = None if func_addr is not None and self.functions is not None: func_addr_v = func_addr.one_value() @@ -1083,6 +1084,7 @@ def _handle_function_cc(self, func_addr: Optional[MultiValues]): func_addr_int = func_addr_v._model_concrete.value if self.functions.contains_addr(func_addr_int): _cc = self.functions[func_addr_int].calling_convention + proto = self.functions[func_addr_int].prototype cc: SimCC = _cc or DEFAULT_CC.get(self.arch.name, None)(self.arch) @@ -1090,9 +1092,9 @@ def _handle_function_cc(self, func_addr: Optional[MultiValues]): # - add uses for arguments # - kill return value registers # - caller-saving registers - if cc.args: + if proto and proto.args: code_loc = self._codeloc() - for arg in cc.args: + for arg in cc.arg_locs(proto): if isinstance(arg, SimRegArg): reg_offset, reg_size = self.arch.registers[arg.reg_name] atom = Register(reg_offset, reg_size) diff --git a/angr/analyses/stack_pointer_tracker.py b/angr/analyses/stack_pointer_tracker.py index 2ddc91510d1..7c2ee1d40f5 100644 --- a/angr/analyses/stack_pointer_tracker.py +++ b/angr/analyses/stack_pointer_tracker.py @@ -504,7 +504,6 @@ def resolve_stmt(stmt): callees = 
self._find_callees(node) if callees: callee_cleanups = [callee for callee in callees if callee.calling_convention is not None and - callee.calling_convention.args is not None and callee.calling_convention.CALLEE_CLEANUP] if callee_cleanups: # found callee clean-up cases... diff --git a/angr/analyses/variable_recovery/engine_ail.py b/angr/analyses/variable_recovery/engine_ail.py index bb06bceb4a7..6a6b08c0468 100644 --- a/angr/analyses/variable_recovery/engine_ail.py +++ b/angr/analyses/variable_recovery/engine_ail.py @@ -7,7 +7,7 @@ from ...storage.memory_mixins.paged_memory.pages.multi_values import MultiValues from ...calling_conventions import SimRegArg -from ...sim_type import SimTypeFunction +from ...sim_type import SimTypeFunction, SimTypeBottom from ...engines.light import SimEngineLightAILMixin from ...errors import SimMemoryMissingError from ..typehoon import typeconsts, typevars @@ -87,8 +87,12 @@ def _ail_handle_Call(self, stmt: ailment.Stmt.Call, is_expr=False) -> Optional[R ret_reg_offset = ret_expr.reg_offset else: if stmt.calling_convention is not None: - # return value - ret_expr: SimRegArg = stmt.calling_convention.RETURN_VAL + if stmt.prototype is None: + ret_expr: SimRegArg = stmt.calling_convention.RETURN_VAL + elif stmt.prototype.returnty is None or type(stmt.prototype.returnty) is SimTypeBottom: + ret_expr = None + else: + ret_expr: SimRegArg = stmt.calling_convention.return_val(stmt.prototype.returnty) else: l.debug("Unknown calling convention for function %s. 
Fall back to default calling convention.", target) ret_expr: SimRegArg = self.project.factory.cc().RETURN_VAL @@ -101,8 +105,8 @@ def _ail_handle_Call(self, stmt: ailment.Stmt.Call, is_expr=False) -> Optional[R # discover the prototype prototype: Optional[SimTypeFunction] = None - if stmt.calling_convention is not None: - prototype = stmt.calling_convention.func_ty + if stmt.prototype is not None: + prototype = stmt.prototype elif isinstance(stmt.target, ailment.Expr.Const): func_addr = stmt.target.value if func_addr in self.kb.functions: diff --git a/angr/analyses/variable_recovery/variable_recovery_fast.py b/angr/analyses/variable_recovery/variable_recovery_fast.py index 44133215eee..b0fa1298e0d 100644 --- a/angr/analyses/variable_recovery/variable_recovery_fast.py +++ b/angr/analyses/variable_recovery/variable_recovery_fast.py @@ -428,11 +428,6 @@ def _process_block(self, state, block): # pylint:disable=no-self-use sp_v = next(iter(next(iter(sp.values.values())))) adjusted = False - cc = self._node_to_cc[block.addr] - if cc is not None and cc.sp_delta is not None: - sp_v += cc.sp_delta - adjusted = True - l.debug('Adjusting stack pointer at end of block %#x with offset %+#x.', block.addr, cc.sp_delta) if not adjusted: # make a guess diff --git a/angr/callable.py b/angr/callable.py index e2c5c131d2d..8205944a8f9 100644 --- a/angr/callable.py +++ b/angr/callable.py @@ -1,7 +1,6 @@ - import pycparser -from .calling_conventions import DEFAULT_CC +from .calling_conventions import DEFAULT_CC, SimCC class Callable(object): @@ -15,13 +14,15 @@ class Callable(object): Otherwise, you can get the resulting simulation manager at callable.result_path_group. 
""" - def __init__(self, project, addr, concrete_only=False, perform_merge=True, base_state=None, toc=None, cc=None): + def __init__(self, project, addr, prototype=None, concrete_only=False, perform_merge=True, base_state=None, toc=None, + cc=None): """ :param project: The project to operate on :param addr: The address of the function to use The following parameters are optional: + :param prototype: The signature of the calls you would like to make. This really shouldn't be optional. :param concrete_only: Throw an exception if the execution splits into multiple paths :param perform_merge: Merge all result states into one at the end (only relevant if concrete_only=False) :param base_state: The state from which to do these runs @@ -37,6 +38,7 @@ def __init__(self, project, addr, concrete_only=False, perform_merge=True, base_ self._toc = toc self._cc = cc if cc is not None else DEFAULT_CC[project.arch.name](project.arch) self._deadend_addr = project.simos.return_deadend + self._func_ty = prototype self.result_path_group = None self.result_state = None @@ -49,14 +51,19 @@ def set_base_state(self, state): self._base_state = state def __call__(self, *args): - self.perform_call(*args) - if self.result_state is not None: - return self.result_state.solver.simplify(self._cc.get_return_val(self.result_state, stack_base=self.result_state.regs.sp - self._cc.STACKARG_SP_DIFF)) + prototype = SimCC.guess_prototype(args, self._func_ty).with_arch(self._project.arch) + self.perform_call(*args, prototype=prototype) + if self.result_state is not None and prototype.returnty is not None: + loc = self._cc.return_val(prototype.returnty) + val = loc.get_value(self.result_state, stack_base=self.result_state.regs.sp - self._cc.STACKARG_SP_DIFF) + return self.result_state.solver.simplify(val) else: return None - def perform_call(self, *args): + def perform_call(self, *args, prototype=None): + prototype = SimCC.guess_prototype(args, prototype or self._func_ty).with_arch(self._project.arch) 
state = self._project.factory.call_state(self._addr, *args, + prototype=prototype, cc=self._cc, base_state=self._base_state, ret_addr=self._deadend_addr, diff --git a/angr/calling_conventions.py b/angr/calling_conventions.py index 9e404272604..0c1967ff258 100644 --- a/angr/calling_conventions.py +++ b/angr/calling_conventions.py @@ -1,75 +1,193 @@ import logging +from typing import Union, Optional, List, Dict, Type +from collections import defaultdict import claripy import archinfo from archinfo import RegisterName -from typing import Union, Optional, List, Dict, Type -from .sim_type import SimType -from .sim_type import SimTypeChar -from .sim_type import SimTypePointer -from .sim_type import SimTypeFixedSizeArray -from .sim_type import SimTypeArray -from .sim_type import SimTypeString -from .sim_type import SimTypeFunction -from .sim_type import SimTypeFloat -from .sim_type import SimTypeDouble -from .sim_type import SimTypeReg -from .sim_type import SimTypeInt -from .sim_type import SimTypeBottom -from .sim_type import SimStruct -from .sim_type import SimUnion -from .sim_type import parse_file -from .sim_type import SimTypeTop +from .sim_type import SimType, SimTypeChar, SimTypePointer, SimTypeFixedSizeArray, SimTypeArray, SimTypeString, \ + SimTypeFunction, SimTypeFloat, SimTypeDouble, SimTypeReg, SimStruct, SimStructValue, SimTypeInt, SimTypeNum, \ + SimUnion, SimTypeBottom, parse_signature from .state_plugins.sim_action_object import SimActionObject l = logging.getLogger(name=__name__) from .engines.soot.engine import SootMixin -# TODO: This file contains explicit and implicit byte size assumptions all over. A good attempt to fix them was made. -# If your architecture hails from the astral plane, and you're reading this, start fixing here. 
- class PointerWrapper: - def __init__(self, value): + def __init__(self, value, buffer=False): self.value = value + self.buffer = buffer class AllocHelper: - def __init__(self, ptrsize, reverse_result): + def __init__(self, ptrsize): self.base = claripy.BVS('alloc_base', ptrsize) self.ptr = self.base - self.reverse_result = reverse_result self.stores = {} - def dump(self, val, state, endness='Iend_BE'): - self.stores[self.ptr.cache_key] = (val, endness) + def alloc(self, size): out = self.ptr - self.ptr += val.length // state.arch.byte_width - return out.reversed if self.reverse_result else out + self.ptr += size + return out + + def dump(self, val, state, loc=None): + if loc is None: + loc = self.stack_loc(val, state.arch) + self.stores[self.ptr.cache_key] = (val, loc) + return self.alloc(self.calc_size(val, state.arch)) def translate(self, val, base): - return val.replace(self.base, base) + if type(val) is SimStructValue: + return SimStructValue(val.struct, { + field: self.translate(subval, base) for field, subval in val.fields.items() + }) + if isinstance(val, claripy.Bits): + return val.replace(self.base, base) + if type(val) is list: + return [self.translate(subval, base) for subval in val] + raise TypeError(type(val)) def apply(self, state, base): - for ptr, (val, endness) in self.stores.items(): - state.memory.store(self.translate(ptr.ast, base), self.translate(val, base), endness=endness) + for ptr, (val, loc) in self.stores.items(): + translated_val = self.translate(val, base) + translated_ptr = self.translate(ptr.ast, base) + loc.set_value(state, translated_val, stack_base=translated_ptr) def size(self): val = self.translate(self.ptr, claripy.BVV(0, len(self.ptr))) assert val.op == 'BVV' return abs(val.args[0]) + @classmethod + def calc_size(cls, val, arch): + if type(val) is SimStructValue: + return val.struct.size // arch.byte_width + if isinstance(val, claripy.Bits): + return len(val) // arch.byte_width + if type(val) is list: + # TODO real strides + 
if len(val) == 0: + return 0 + return cls.calc_size(val[0], arch) * len(val) + raise TypeError(type(val)) + + @classmethod + def stack_loc(cls, val, arch, offset=0): + if isinstance(val, claripy.Bits): + return SimStackArg(offset, len(val) // arch.byte_width) + if type(val) is list: + # TODO real strides + if len(val) == 0: + return SimArrayArg([]) + stride = cls.calc_size(val[0], arch) + return SimArrayArg([cls.stack_loc(subval, arch, offset + i * stride) for i, subval in enumerate(val)]) + if type(val) is SimStructValue: + return SimStructArg(val.struct, { + field: cls.stack_loc(subval, arch, offset + val.struct.offsets[field]) + for field, subval in val.fields.items() + }) + raise TypeError(type(val)) + +def refine_locs_with_struct_type(arch, locs, arg_type, offset=0): + # CONTRACT FOR USING THIS METHOD: locs must be a list of locs which are all wordsize + # ADDITIONAL NUANCE: this will not respect the need for big-endian integers to be stored at the end of words. + # that's why this is named with_struct_type, because it will blindly trust the offsets given to it. 
+ if isinstance(arg_type, (SimTypeReg, SimTypeNum, SimTypeFloat)): + seen_bytes = 0 + pieces = [] + while seen_bytes < arg_type.size // arch.byte_width: + start_offset = offset + seen_bytes + chunk = start_offset // arch.bytes + chunk_offset = start_offset % arch.bytes + chunk_remaining = arch.bytes - chunk_offset + type_remaining = arg_type.size // arch.byte_width - seen_bytes + use_bytes = min(chunk_remaining, type_remaining) + pieces.append(locs[chunk].refine(size=use_bytes, offset=chunk_offset)) + seen_bytes += use_bytes + + if len(pieces) == 1: + piece = pieces[0] + else: + piece = SimComboArg(pieces) + if isinstance(arg_type, SimTypeFloat): + piece.is_fp = True + return piece + if isinstance(arg_type, SimTypeFixedSizeArray): + # TODO explicit stride + locs = [ + refine_locs_with_struct_type(locs, arg_type.elem_type, offset + i * arg_type.size // arch.byte_width) + for i in range(arg_type.length) + ] + return SimArrayArg(locs) + if isinstance(arg_type, SimStruct): + locs = { + field: refine_locs_with_struct_type(locs, field_ty, offset + arg_type.offsets[field]) for field, field_ty in arg_type.fields.items() + } + return SimStructArg(arg_type, locs) + raise TypeError("I don't know how to lay out a %s" % arg_type) + +class SerializableIterator: + def __iter__(self): + return self + + def __next__(self): + raise NotImplementedError + + def getstate(self): + raise NotImplementedError + + def setstate(self, state): + raise NotImplementedError + +class SerializableListIterator(SerializableIterator): + def __init__(self, lst): + self._lst = lst + self._index = 0 + + def __next__(self): + if self._index >= len(self._lst): + raise StopIteration + result = self._lst[self._index] + self._index += 1 + return result + + def getstate(self): + return self._index + + def setstate(self, state): + self._index = state + +class SerializableCounter(SerializableIterator): + def __init__(self, start, stride, mapping=lambda x: x): + self._next = start + self._stride = stride + 
self._mapping = mapping + + def __next__(self): + result = self._mapping(self._next) + self._next += self._stride + return result + + def getstate(self): + return self._next + + def setstate(self, state): + self._next = state + class SimFunctionArgument: """ Represent a generic function argument. :ivar int size: The size of the argument, in number of bytes. + :ivar bool is_fp: Whether loads from this location should return a floating point bitvector """ - def __init__(self, size): + def __init__(self, size, is_fp=False): self.size = size + self.is_fp = is_fp def __ne__(self, other): return not self == other @@ -77,12 +195,23 @@ def __ne__(self, other): def __hash__(self): return hash(('function_argument', self.size)) - def check_value(self, value): + def check_value_set(self, value, arch): if not isinstance(value, claripy.ast.Base) and self.size is None: raise TypeError("Only claripy objects may be stored through SimFunctionArgument when size is not provided") - # TODO: this also looks byte related, change to ARCH.byte_width - if self.size is not None and isinstance(value, claripy.ast.Base) and self.size*8 < value.length: + if self.size is not None and isinstance(value, claripy.ast.Base) and self.size*arch.byte_width < value.length: raise TypeError("%s doesn't fit in an argument of size %d" % (value, self.size)) + if isinstance(value, int): + value = claripy.BVV(value, self.size * arch.byte_width) + if isinstance(value, float): + if self.size not in (4, 8): + raise ValueError("What do I do with a float %d bytes long" % self.size) + value = claripy.FPV(value, claripy.FSORT_FLOAT if self.size == 4 else claripy.FSORT_DOUBLE) + return value.raw_to_bv() + + def check_value_get(self, value): + if self.is_fp: + return value.raw_to_fp() + return value def set_value(self, state, value, **kwargs): raise NotImplementedError @@ -90,6 +219,8 @@ def set_value(self, state, value, **kwargs): def get_value(self, state, **kwargs): raise NotImplementedError + def refine(self, size, 
arch=None, offset=None, is_fp=None): + raise NotImplementedError class SimRegArg(SimFunctionArgument): """ @@ -98,50 +229,49 @@ class SimRegArg(SimFunctionArgument): :ivar string reg_name: The name of the represented register. :ivar int size: The size of the register, in number of bytes. """ - def __init__(self, reg_name: RegisterName, size: int, alt_offsets=None): - SimFunctionArgument.__init__(self, size) + def __init__(self, reg_name: RegisterName, size: int, reg_offset=0, is_fp=False, clear_entire_reg=False): + super().__init__(size, is_fp) self.reg_name = reg_name - self.alt_offsets = {} if alt_offsets is None else alt_offsets + self.reg_offset = reg_offset + self.clear_entire_reg = clear_entire_reg def __repr__(self): return "<%s>" % self.reg_name def __eq__(self, other): - return type(other) is SimRegArg and self.reg_name == other.reg_name + return type(other) is SimRegArg and self.reg_name == other.reg_name and self.reg_offset == other.reg_offset def __hash__(self): - return hash((self.size, self.reg_name, tuple(self.alt_offsets))) + return hash((self.size, self.reg_name, self.reg_offset)) + + def check_offset(self, arch): + return arch.registers[self.reg_name][0] + self.reg_offset + + def set_value(self, state, value, **kwargs): # pylint: disable=unused-argument,arguments-differ + value = self.check_value_set(value, state.arch) + offset = self.check_offset(state.arch) + if self.clear_entire_reg: + state.registers.store(self.reg_name, 0) + state.registers.store(offset, value, size=self.size) + + def get_value(self, state, **kwargs): # pylint: disable=unused-argument,arguments-differ + offset = self.check_offset(state.arch) + return self.check_value_get(state.registers.load(offset, size=self.size)) + + def refine(self, size, arch=None, offset=None, is_fp=None): + passed_offset_none = offset is None + if offset is None: + if arch is None: + raise ValueError("Need to specify either offset or arch in order to refine a register argument") + if 
arch.register_endness == 'Iend_LE': + offset = 0 + else: + offset = self.size - size + if is_fp is None: is_fp = self.is_fp + return SimRegArg(self.reg_name, size, self.reg_offset + offset, is_fp, clear_entire_reg=passed_offset_none) - def _fix_offset(self, state, size, arch=None): - """ - This is a hack to deal with small values being stored at offsets into large registers unpredictably - """ - if state is not None: - arch = state.arch - - if arch is None: - raise ValueError('Either "state" or "arch" must be specified.') - - offset = arch.registers[self.reg_name][0] - if size in self.alt_offsets: - return offset + self.alt_offsets[size] - elif size < self.size and arch.register_endness == 'Iend_BE': - return offset + (self.size - size) - return offset - - def set_value(self, state, value, endness=None, size=None, **kwargs): # pylint: disable=unused-argument,arguments-differ - self.check_value(value) - if endness is None: endness = state.arch.register_endness - if isinstance(value, int): value = claripy.BVV(value, self.size*state.arch.byte_width) - if size is None: size = min(self.size, value.length // state.arch.byte_width) - offset = self._fix_offset(state, size) - state.registers.store(offset, value, endness=endness, size=size) - - def get_value(self, state, endness=None, size=None, **kwargs): # pylint: disable=unused-argument,arguments-differ - if endness is None: endness = state.arch.register_endness - if size is None: size = self.size - offset = self._fix_offset(state, size) - return state.registers.load(offset, endness=endness, size=size) + def sse_extend(self): + return SimRegArg(self.reg_name, self.size, self.reg_offset + self.size, is_fp=self.is_fp) class SimStackArg(SimFunctionArgument): @@ -151,8 +281,8 @@ class SimStackArg(SimFunctionArgument): :var int stack_offset: The position of the argument relative to the stack pointer after the function prelude. :ivar int size: The size of the argument, in number of bytes. 
""" - def __init__(self, stack_offset, size): - SimFunctionArgument.__init__(self, size) + def __init__(self, stack_offset, size, is_fp=False): + SimFunctionArgument.__init__(self, size, is_fp) self.stack_offset = stack_offset def __repr__(self): @@ -164,22 +294,34 @@ def __eq__(self, other): def __hash__(self): return hash((self.size, self.stack_offset)) - def set_value(self, state, value, endness=None, stack_base=None): # pylint: disable=arguments-differ - self.check_value(value) - if endness is None: endness = state.arch.memory_endness + def set_value(self, state, value, stack_base=None, **kwargs): # pylint: disable=arguments-differ + value = self.check_value_set(value, state.arch) if stack_base is None: stack_base = state.regs.sp - if isinstance(value, int): value = claripy.BVV(value, self.size*state.arch.byte_width) - state.memory.store(stack_base + self.stack_offset, value, endness=endness, size=value.length//state.arch.byte_width) + state.memory.store(stack_base + self.stack_offset, value, endness=state.arch.memory_endness) - def get_value(self, state, endness=None, stack_base=None, size=None): # pylint: disable=arguments-differ - if endness is None: endness = state.arch.memory_endness + def get_value(self, state, stack_base=None, **kwargs): # pylint: disable=arguments-differ if stack_base is None: stack_base = state.regs.sp - return state.memory.load(stack_base + self.stack_offset, endness=endness, size=size or self.size) + value = state.memory.load(stack_base + self.stack_offset, endness=state.arch.memory_endness, size=self.size) + return self.check_value_get(value) + + def refine(self, size, arch=None, offset=None, is_fp=None): + if offset is None: + if arch is None: + raise ValueError("Need to specify either offset or arch in order to refine a stack argument") + if arch.register_endness == 'Iend_LE': + offset = 0 + else: + offset = self.size - size + if is_fp is None: is_fp = self.is_fp + return SimStackArg(self.stack_offset + offset, size, is_fp) class 
SimComboArg(SimFunctionArgument): - def __init__(self, locations): - super().__init__(sum(x.size for x in locations)) + """ + An argument which spans multiple storage locations. Locations should be given least-significant first. + """ + def __init__(self, locations, is_fp=False): + super().__init__(sum(x.size for x in locations), is_fp=is_fp) self.locations = locations def __repr__(self): @@ -188,28 +330,62 @@ def __repr__(self): def __eq__(self, other): return type(other) is SimComboArg and all(a == b for a, b in zip(self.locations, other.locations)) - def set_value(self, state, value, endness=None, **kwargs): # pylint:disable=arguments-differ - # TODO: This code needs to be reworked for variable byte width and the Third Endness - self.check_value(value) - if endness is None: endness = state.arch.memory_endness - if isinstance(value, int): - value = claripy.BVV(value, self.size*state.arch.byte_width) - elif isinstance(value, float): - if self.size not in (4, 8): - raise ValueError("What do I do with a float %d bytes long" % self.size) - value = claripy.FPV(value, claripy.FSORT_FLOAT if self.size == 4 else claripy.FSORT_DOUBLE) + def set_value(self, state, value, **kwargs): # pylint:disable=arguments-differ + value = self.check_value_set(value, state.arch) cur = 0 - # TODO: I have no idea if this reversed is only supposed to be applied in LE situations - for loc in reversed(self.locations): - loc.set_value(state, value[cur*state.arch.byte_width + loc.size*state.arch.byte_width - 1:cur*state.arch.byte_width], endness=endness, **kwargs) - cur += loc.size + for loc in self.locations: + size_bits = loc.size * state.arch.byte_width + loc.set_value(state, value[cur + size_bits - 1:cur], **kwargs) + cur += size_bits - def get_value(self, state, endness=None, **kwargs): # pylint:disable=arguments-differ - if endness is None: endness = state.arch.memory_endness + def get_value(self, state, **kwargs): # pylint:disable=arguments-differ vals = [] for loc in 
reversed(self.locations): - vals.append(loc.get_value(state, endness, **kwargs)) - return state.solver.Concat(*vals) + vals.append(loc.get_value(state, **kwargs)) + return self.check_value_get(state.solver.Concat(*vals)) + +class SimStructArg(SimFunctionArgument): + def __init__(self, struct, locs): + super().__init__(sum(loc.size for loc in locs)) + self.struct = struct + self.locs = locs + + def get_value(self, state, **kwargs): + return SimStructValue(self.struct, { + field: getter.get_value(state, **kwargs) for field, getter in self.locs.items() + }) + + def set_value(self, state, value, **kwargs): + for field, setter in self.locs.items(): + setter.set_value(state, value[field], **kwargs) + +class SimArrayArg(SimFunctionArgument): + def __init__(self, locs): + super().__init__(sum(loc.size for loc in locs)) + self.locs = locs + + def get_value(self, state, **kwargs): + return [getter.get_value(state, **kwargs) for getter in self.locs] + + def set_value(self, state, value, **kwargs): + if len(value) != len(self.locs): + raise TypeError("Expected %d elements, got %d" % (len(self.locs), len(value))) + for subvalue, setter in zip(value, self.locs): + setter.set_value(state, subvalue, **kwargs) + +class SimReferenceArgument(SimFunctionArgument): + def __init__(self, ptr_loc, main_loc): + super().__init__(ptr_loc.size) # ??? 
+ self.ptr_loc = ptr_loc + self.main_loc = main_loc + + def get_value(self, state, **kwargs): + ptr_val = self.ptr_loc.get_value(state, **kwargs) + return self.main_loc.get_value(state, stack_base=ptr_val, **kwargs) + + def set_value(self, state, value, **kwargs): + ptr_val = self.ptr_loc.get_value(state, **kwargs) + self.main_loc.set_value(state, value, stack_base=ptr_val, **kwargs) class ArgSession: @@ -217,65 +393,35 @@ class ArgSession: A class to keep track of the state accumulated in laying parameters out into memory """ - __slots__ = ('cc', 'real_args', 'fp_iter', 'int_iter', 'both_iter', ) + __slots__ = ('cc', 'fp_iter', 'int_iter', 'both_iter', ) def __init__(self, cc): self.cc = cc - self.real_args = None - self.fp_iter = None - self.int_iter = None - self.both_iter = None - - # these iters should only be used if real_args are not set or real_args are intentionally ignored (e.g., when - # variadic arguments are used). self.fp_iter = cc.fp_args self.int_iter = cc.int_args - self.both_iter = cc.both_args + self.both_iter = cc.memory_args - if cc.args is not None: - self.real_args = iter(cc.args) + def getstate(self): + return (self.fp_iter.getstate(), self.int_iter.getstate(), self.both_iter.getstate()) - # TODO: use safer errors than TypeError and ValueError - def next_arg(self, is_fp, size=None, ignore_real_args=False): - if self.real_args is not None and not ignore_real_args: - try: - arg = next(self.real_args) - if is_fp and self.cc.is_fp_arg(arg) is False: - raise TypeError("Can't put a float here - concrete arg positions are specified") - if not is_fp and self.cc.is_fp_arg(arg) is True: - raise TypeError("Can't put an int here - concrete arg positions are specified") - except StopIteration: - raise TypeError("Accessed too many arguments - concrete number are specified") - else: - try: - if is_fp: - arg = next(self.fp_iter) - else: - arg = next(self.int_iter) - except StopIteration: - try: - arg = next(self.both_iter) - except StopIteration: - raise 
TypeError("Accessed too many arguments - exhausted all positions?") + def setstate(self, state): + fp, int_, both = state + self.fp_iter.setstate(fp) + self.int_iter.setstate(int_) + self.both_iter.setstate(both) - if size is not None and size > arg.size: - arg = self.upsize_arg(arg, is_fp, size) - return arg +class UsercallArgSession: + __slots__ = ('cc', 'real_args', ) - def upsize_arg(self, arg, is_fp, size): - if not is_fp: - raise ValueError("You can't fit a integral value of size %d into an argument of size %d!" % (size, arg.size)) - if not isinstance(arg, SimStackArg): - raise ValueError("I don't know how to handle this? please report to @rhelmot") + def __init__(self, cc): + self.cc = cc + self.real_args = SerializableListIterator(self.cc.arg_locs) - arg_size = arg.size - locations = [arg] - while arg_size < size: - next_arg = self.next_arg(is_fp, size=None) - arg_size += next_arg.size - locations.append(next_arg) + def getstate(self): + return self.real_args.getstate() - return SimComboArg(locations) + def setstate(self, state): + self.real_args.setstate(state) class SimCC: @@ -286,67 +432,12 @@ class SimCC: this may be overridden with the `stack_base` parameter to each individual method. This is the base class for all calling conventions. - - An instance of this class allows it to be tweaked to the way a specific function should be called. 
""" - def __init__(self, - arch: archinfo.Arch, - args: Optional[List[SimFunctionArgument]]=None, - ret_val: Optional[SimFunctionArgument]=None, - sp_delta: Optional[int]=None, - func_ty: Optional[Union[SimTypeFunction, str]]=None): + def __init__(self, arch: archinfo.Arch): """ :param arch: The Archinfo arch for this CC - :param args: A list of SimFunctionArguments describing where the arguments go - :param ret_val: A SimFunctionArgument describing where the return value goes - :param sp_delta: The amount the stack pointer changes over the course of this function - CURRENTLY UNUSED - :param func_ty: A SimTypeFunction for the function itself, or a string that can be parsed into a - SimTypeFunction instance. - - Example func_ty strings: - >>> "int func(char*, int)" - >>> "int f(int, int, int*);" - Function names are ignored. - """ - if func_ty is not None: - if isinstance(func_ty, str): - if not func_ty.endswith(";"): - func_ty += ";" # Make pycparser happy - parsed = parse_file(func_ty) - parsed_decl = parsed[0] - if not parsed_decl: - raise ValueError('Cannot parse the provided function prototype.') - _, func_ty = next(iter(parsed_decl.items())) - - if not isinstance(func_ty, SimTypeFunction): - raise TypeError("Function prototype must be a SimTypeFunction instance or a string that can be parsed " - "into a SimTypeFunction instance.") - self.arch = arch - self.args = args - self.ret_val = ret_val - self.sp_delta = sp_delta - self.func_ty: Optional[SimTypeFunction] = func_ty if func_ty is None else func_ty.with_arch(arch) - - @classmethod - def from_arg_kinds(cls, arch, fp_args, ret_fp=False, sizes=None, sp_delta=None, func_ty=None): - """ - Get an instance of the class that will extract floating-point/integral args correctly. - - :param arch: The Archinfo arch for this CC - :param fp_args: A list, with one entry for each argument the function can take. True if the argument is fp, - false if it is integral. 
- :param ret_fp: True if the return value for the function is fp. - :param sizes: Optional: A list, with one entry for each argument the function can take. Each entry is the - size of the corresponding argument in bytes. - :param sp_delta: The amount the stack pointer changes over the course of this function - CURRENTLY UNUSED - :parmm func_ty: A SimType for the function itself - """ - basic = cls(arch, sp_delta=sp_delta, func_ty=func_ty) - basic.args = basic.arg_locs(fp_args, sizes) - basic.ret_val = basic.fp_return_val if ret_fp else basic.return_val - return basic # # Here are all the things a subclass needs to specify! @@ -374,41 +465,34 @@ def from_arg_kinds(cls, arch, fp_args, ret_fp=False, sizes=None, sp_delta=None, @property def int_args(self): """ - Iterate through all the possible arg positions that can only be used to store integer or pointer values - Does not take into account customizations. + Iterate through all the possible arg positions that can only be used to store integer or pointer values. Returns an iterator of SimFunctionArguments """ if self.ARG_REGS is None: raise NotImplementedError() - for reg in self.ARG_REGS: # pylint: disable=not-an-iterable - yield SimRegArg(reg, self.arch.bytes) + return SerializableListIterator([SimRegArg(reg, self.arch.bytes) for reg in self.ARG_REGS]) @property - def both_args(self): + def memory_args(self): """ - Iterate through all the possible arg positions that can be used to store any kind of argument - Does not take into account customizations. + Iterate through all the possible arg positions that can be used to store any kind of argument. 
Returns an iterator of SimFunctionArguments """ - turtle = self.STACKARG_SP_BUFF + self.STACKARG_SP_DIFF - while True: - yield SimStackArg(turtle, self.arch.bytes) - turtle += self.arch.bytes + start = self.STACKARG_SP_BUFF + self.STACKARG_SP_DIFF + return SerializableCounter(start, self.arch.bytes, lambda offset: SimStackArg(offset, self.arch.bytes)) @property def fp_args(self): """ - Iterate through all the possible arg positions that can only be used to store floating point values - Does not take into account customizations. + Iterate through all the possible arg positions that can only be used to store floating point values. Returns an iterator of SimFunctionArguments """ if self.FP_ARG_REGS is None: raise NotImplementedError() - for reg in self.FP_ARG_REGS: # pylint: disable=not-an-iterable - yield SimRegArg(reg, self.arch.registers[reg][1]) + return SerializableListIterator([SimRegArg(reg, self.arch.bytes) for reg in self.FP_ARG_REGS]) def is_fp_arg(self, arg): """ @@ -426,8 +510,8 @@ def is_fp_arg(self, arg): return None ArgSession = ArgSession # import this from global scope so SimCC subclasses can subclass it if they like - @property - def arg_session(self): + + def arg_session(self, ret_ty: Optional[SimType]): """ Return an arg session. @@ -435,8 +519,16 @@ def arg_session(self): laid out into memory. The default behavior is that there are a finite list of int-only and fp-only argument slots, and an infinite number of generic slots, and when an argument of a given type is requested, the most slot available is used. If you need different behavior, subclass ArgSession. + + You need to provide the return type of the function in order to kick off an arg layout session. 
""" - return self.ArgSession(self) + session = self.ArgSession(self) + if self.return_in_implicit_outparam(ret_ty): + self.next_arg(session, SimTypePointer(SimTypeBottom())) + return session + + def return_in_implicit_outparam(self, ty): + return False def stack_space(self, args): """ @@ -453,26 +545,28 @@ def stack_space(self, args): out += self.STACKARG_SP_BUFF return out - @property - def return_val(self): + def return_val(self, ty, perspective_returned=False): """ - The location the return value is stored. + The location the return value is stored, based on its type. """ - # pylint: disable=unsubscriptable-object - if self.ret_val is not None: - return self.ret_val - - if self.func_ty is not None and \ - self.func_ty.returnty is not None and \ - self.OVERFLOW_RETURN_VAL is not None and \ - self.func_ty.returnty.size not in (None, NotImplemented) and \ - self.func_ty.returnty.size > self.RETURN_VAL.size * self.arch.byte_width: - return SimComboArg([self.RETURN_VAL, self.OVERFLOW_RETURN_VAL]) - return self.RETURN_VAL + if ty._arch is None: + ty = ty.with_arch(self.arch) + if isinstance(ty, (SimStruct, SimUnion, SimTypeFixedSizeArray)): + raise TypeError(f"{self} doesn't know how to return aggregate types. 
Consider overriding return_val to " + "implement its ABI logic") + if self.return_in_implicit_outparam(ty): + if perspective_returned: + ptr_loc = self.RETURN_VAL + else: + ptr_loc = self.next_arg(self.ArgSession(self), SimTypePointer(SimTypeBottom())) + return SimReferenceArgument(ptr_loc, SimStackArg(0, ty.size // self.arch.byte_width, is_fp=isinstance(ty, SimTypeFloat))) - @property - def fp_return_val(self): - return self.FP_RETURN_VAL if self.ret_val is None else self.ret_val + if isinstance(ty, SimTypeFloat): + return self.FP_RETURN_VAL.refine(size=ty.size // self.arch.byte_width, arch=self.arch, is_fp=True) + + if ty.size > self.RETURN_VAL.size * self.arch.byte_width: + return SimComboArg([self.RETURN_VAL, self.OVERFLOW_RETURN_VAL]) + return self.RETURN_VAL.refine(size=ty.size // self.arch.byte_width, arch=self.arch, is_fp=False) @property def return_addr(self): @@ -481,6 +575,36 @@ def return_addr(self): """ return self.RETURN_ADDR + def next_arg(self, session, arg_type): + if isinstance(arg_type, (SimStruct, SimUnion, SimTypeFixedSizeArray)): + raise TypeError(f"{self} doesn't know how to store aggregate types. Consider overriding next_arg to " + "implement its ABI logic") + is_fp = isinstance(arg_type, SimTypeFloat) + size = arg_type.size // self.arch.byte_width + try: + if is_fp: + arg = next(session.fp_iter) + else: + arg = next(session.int_iter) + except StopIteration: + try: + arg = next(session.both_iter) + except StopIteration: + raise TypeError("Accessed too many arguments - exhausted all positions?") + + if size > arg.size: + if isinstance(arg, SimStackArg): + arg_size = arg.size + locations = [arg] + while arg_size < size: + next_arg = next(session.both_iter) + arg_size += next_arg.size + locations.append(next_arg) + return SimComboArg(locations, is_fp=is_fp) + raise ValueError(f"{self} doesn't know how to store large types. 
Consider overriding" + " next_arg to implement its ABI logic") + return arg.refine(size, is_fp=is_fp, arch=self.arch) + # # Useful functions! # @@ -491,120 +615,62 @@ def is_fp_value(val): (isinstance(val, claripy.ast.Base) and val.op.startswith('fp')) or \ (isinstance(val, claripy.ast.Base) and val.op == 'Reverse' and val.args[0].op.startswith('fp')) - def arg_locs(self, is_fp=None, sizes=None): + @staticmethod + def guess_prototype(args, prototype=None): """ - Pass this a list of whether each parameter is floating-point or not, and get back a list of - SimFunctionArguments. Optionally, pass a list of argument sizes (in bytes) as well. + Come up with a plausible SimTypeFunction for the given args (as would be passed to e.g. setup_callsite). - If you've customized this CC, this will sanity-check the provided locations with the given list. + You can pass a variadic function prototype in the `base_type` parameter and all its arguments will be used, + only guessing types for the variadic arguments. """ - session = self.arg_session - ignore_real_args = False - if self.func_ty is None and self.args is None: - # No function prototype is provided, no args is provided. `is_fp` must be provided. - if is_fp is None: - raise ValueError('"is_fp" must be provided when no function prototype is available.') - ignore_real_args = True - else: - # let's rely on the func_ty or self.args for the number of arguments and whether each argument is FP or not - if self.func_ty is not None: - args = [ a.with_arch(self.arch) for a in self.func_ty.args ] - else: - args = self.args - # FIXME: Hack: Replacing structs with primitive types since we don't yet support passing structs as - # FIXME: arguments. - args = [ SimTypeInt().with_arch(self.arch) if isinstance(a, (SimStruct, SimUnion)) else a for a in args ] - # FIXME: Hack: Replacing long ints with shorter types since we don't yet support passing integers longer - # FIXME: than GPR sizes. 
- args = [ SimTypeInt().with_arch(self.arch) if not isinstance(a, SimTypeBottom) and a.size > self.arch.bytes - else a for a in args ] - if is_fp is None: - is_fp = [ isinstance(arg, (SimTypeFloat, SimTypeDouble)) or self.is_fp_arg(arg) for arg in args ] - else: - ignore_real_args = True - if sizes is None: - # initialize sizes from args - sizes = [ ] - for a in args: - if isinstance(a, SimType): - if isinstance(a, SimTypeFixedSizeArray) or a.size is NotImplemented: - sizes.append(self.arch.bytes) - else: - sizes.append(a.size // self.arch.byte_width) # SimType.size is in bits - elif isinstance(a, SimFunctionArgument): - sizes.append(a.size) # SimFunctionArgument.size is in bytes - else: - # fallback to use self.arch.bytes - sizes.append(self.arch.bytes) + if type(prototype) is str: + prototype = parse_signature(prototype) + elif prototype is None: + l.warning("Guessing call prototype. Please specify prototype.") + + charp = SimTypePointer(SimTypeChar()) + result = prototype if prototype is not None else SimTypeFunction([], charp) + for arg in args[len(result.args):]: + if type(arg) in (int, bytes, PointerWrapper): + result.args.append(charp) + elif type(arg) is float: + result.args.append(SimTypeDouble()) + elif isinstance(arg, claripy.ast.BV): + result.args.append(SimTypeNum(len(arg), False)) + elif isinstance(arg, claripy.ast.FP): + if arg.sort == claripy.FSORT_FLOAT: + result.args.append(SimTypeFloat()) + elif arg.sort == claripy.FSORT_DOUBLE: + result.args.append(SimTypeDouble()) + else: + raise TypeError("WHAT kind of floating point is this") else: - ignore_real_args = True - - if sizes is None: sizes = [self.arch.bytes] * len(is_fp) - return [session.next_arg(ifp, size=sz, ignore_real_args=ignore_real_args) for ifp, sz in zip(is_fp, sizes)] - - def arg(self, state, index, stack_base=None): - """ - Returns a bitvector expression representing the nth argument of a function. - - `stack_base` is an optional pointer to the top of the stack at the function start. 
If it is not - specified, use the current stack pointer. + raise TypeError("Cannot guess FFI type for %s" % type(arg)) - WARNING: this assumes that none of the arguments are floating-point and they're all single-word-sized, unless - you've customized this CC. - """ - session = self.arg_session - if self.args is None or index >= len(self.args): - # self.args may not be provided, or args is incorrect or includes variadic arguments that we must get the - # proper argument according to the default calling convention - arg_loc = [session.next_arg(False, ignore_real_args=True) for _ in range(index + 1)][-1] - else: - arg_loc = self.args[index] - - return arg_loc.get_value(state, stack_base=stack_base) - - def get_args(self, state, is_fp=None, sizes=None, stack_base=None): - """ - `is_fp` should be a list of booleans specifying whether each corresponding argument is floating-point - - True for fp and False for int. For a shorthand to assume that all the parameters are int, pass the number of - parameters as an int. - - If you've customized this CC, you may omit this parameter entirely. If it is provided, it is used for - sanity-checking. - - `sizes` is an optional list of argument sizes, in bytes. Be careful about using this if you've made explicit - the arg locations, since it might decide to combine two locations into one if an arg is too big. + return result - `stack_base` is an optional pointer to the top of the stack at the function start. If it is not - specified, use the current stack pointer. - - Returns a list of bitvector expressions representing the arguments of a function. 
- """ - if sizes is None and self.func_ty is not None: - sizes = [arg.size for arg in self.func_ty.args] - if is_fp is None: - if self.args is None: - if self.func_ty is None: - raise ValueError("You must either customize this CC or pass a value to is_fp!") - arg_locs = self.arg_locs([False]*len(self.func_ty.args)) - else: - arg_locs = self.args - - elif type(is_fp) is int: - if self.args is not None and len(self.args) != is_fp: - raise ValueError("Bad number of args requested: got %d, expected %d" % (is_fp, len(self.args))) - arg_locs = self.arg_locs([False]*is_fp, sizes) - else: - arg_locs = self.arg_locs(is_fp, sizes) + def arg_locs(self, prototype): + if prototype._arch is None: + prototype = prototype.with_arch(self.arch) + session = self.arg_session(prototype.returnty) + return [self.next_arg(session, arg_ty) for arg_ty in prototype.args] + def get_args(self, state, prototype, stack_base=None): + arg_locs = self.arg_locs(prototype) return [loc.get_value(state, stack_base=stack_base) for loc in arg_locs] - def setup_callsite(self, state, ret_addr, args, stack_base=None, alloc_base=None, grow_like_stack=True): + def set_return_val(self, state, val, ty, stack_base=None, perspective_returned=False): + loc = self.return_val(ty, perspective_returned=perspective_returned) + loc.set_value(state, val, stack_base=stack_base) + + def setup_callsite(self, state, ret_addr, args, prototype, stack_base=None, alloc_base=None, grow_like_stack=True): """ This function performs the actions of the caller getting ready to jump into a function. :param state: The SimState to operate on :param ret_addr: The address to return to when the called function finishes :param args: The list of arguments that that the called function will see + :param prototype: The signature of the call you're making. Should include variadic args concretely. 
:param stack_base: An optional pointer to use as the top of the stack, circa the function entry point :param alloc_base: An optional pointer to use as the place to put excess argument data :param grow_like_stack: When allocating data at alloc_base, whether to allocate at decreasing addresses @@ -612,8 +678,7 @@ def setup_callsite(self, state, ret_addr, args, stack_base=None, alloc_base=None The idea here is that you can provide almost any kind of python type in `args` and it'll be translated to a binary format to be placed into simulated memory. Lists (representing arrays) must be entirely elements of the same type and size, while tuples (representing structs) can be elements of any type and size. - If you'd like there to be a pointer to a given value, wrap the value in a `PointerWrapper`. Any value - that can't fit in a register will be automatically put in a PointerWrapper. + If you'd like there to be a pointer to a given value, wrap the value in a `PointerWrapper`. If stack_base is not provided, the current stack pointer will be used, and it will be updated. If alloc_base is not provided, the stack base will be used and grow_like_stack will implicitly be True. 
@@ -626,11 +691,11 @@ def setup_callsite(self, state, ret_addr, args, stack_base=None, alloc_base=None # STEP 0: clerical work - if isinstance(self, SimCCSoot): - SootMixin.setup_callsite(state, args, ret_addr) - return - - allocator = AllocHelper(self.arch.bits, self.arch.memory_endness == 'Iend_LE') + allocator = AllocHelper(self.arch.bits) + if type(prototype) is str: + prototype = parse_signature(prototype, arch=self.arch) + elif prototype._arch is None: + prototype = prototype.with_arch(self.arch) # # STEP 1: convert all values into serialized form @@ -639,26 +704,25 @@ def setup_callsite(self, state, ret_addr, args, stack_base=None, alloc_base=None # This is also where we compute arg locations (arg_locs) # - if self.func_ty is not None: - vals = [self._standardize_value(arg, ty, state, allocator.dump) for arg, ty in zip(args, self.func_ty.args)] - else: - vals = [self._standardize_value(arg, None, state, allocator.dump) for arg in args] - - arg_session = self.arg_session - arg_locs = [None]*len(args) - for i, (arg, val) in enumerate(zip(args, vals)): - if self.is_fp_value(arg) or \ - (self.func_ty is not None and isinstance(self.func_ty.args[i], SimTypeFloat)): - arg_locs[i] = arg_session.next_arg(is_fp=True, size=val.length // state.arch.byte_width) + vals = [self._standardize_value(arg, ty, state, allocator.dump) for arg, ty in zip(args, prototype.args)] + arg_locs = self.arg_locs(prototype) + + # step 1.5, gotta handle the SimReferenceArguments correctly + for i, (loc, val) in enumerate(zip(arg_locs, vals)): + if not isinstance(loc, SimReferenceArgument): continue - if val.length > state.arch.bits or (self.func_ty is None and isinstance(arg, (bytes, str, list, tuple))): - vals[i] = allocator.dump(val, state) - elif val.length < state.arch.bits: - if self.arch.memory_endness == 'Iend_LE': - vals[i] = val.concat(claripy.BVV(0, state.arch.bits - val.length)) - else: - vals[i] = claripy.BVV(0, state.arch.bits - val.length).concat(val) - arg_locs[i] = 
arg_session.next_arg(is_fp=False, size=vals[i].length // state.arch.byte_width) + dumped = allocator.dump(val, state, loc=val.main_loc) + vals[i] = dumped + arg_locs[i] = val.ptr_loc + + # step 1.75 allocate implicit outparam stuff + if self.return_in_implicit_outparam(prototype.returnty): + loc = self.return_val(prototype.returnty) + assert isinstance(loc, SimReferenceArgument) + # hack: because the allocator gives us a pointer that needs to be translated, we need to shove it into + # the args list so it'll be translated and stored once everything is laid out + vals.append(allocator.alloc(loc.main_loc.size)) + arg_locs.append(loc.ptr_loc) # # STEP 2: decide on memory storage locations @@ -702,103 +766,45 @@ def setup_callsite(self, state, ret_addr, args, stack_base=None, alloc_base=None allocator.apply(state, alloc_base) for loc, val in zip(arg_locs, vals): - if val.length > loc.size * state.arch.byte_width: - raise ValueError("Can't fit value {} into location {}".format(repr(val), repr(loc))) - loc.set_value(state, val, endness='Iend_BE', stack_base=stack_base) + loc.set_value(state, val, stack_base=stack_base) self.return_addr.set_value(state, ret_addr, stack_base=stack_base) - def teardown_callsite(self, state, return_val=None, arg_types=None, force_callee_cleanup=False): + def teardown_callsite(self, state, return_val=None, prototype=None, force_callee_cleanup=False): """ This function performs the actions of the callee as it's getting ready to return. It returns the address to return to. :param state: The state to mutate :param return_val: The value to return - :param arg_types: The fp-ness of each of the args. Used to calculate sizes to clean up + :param prototype: The prototype of the given function :param force_callee_cleanup: If we should clean up the stack allocation for the arguments even if it's not the callee's job to do so TODO: support the stack_base parameter from setup_callsite...? Does that make sense in this context? 
Maybe it could make sense by saying that you pass it in as something like the "saved base pointer" value? """ - if return_val is not None: - self.set_return_val(state, return_val) + if return_val is not None and not isinstance(prototype.returnty, SimTypeBottom): + self.set_return_val(state, return_val, prototype.returnty) + # ummmmmmmm hack + loc = self.return_val(prototype.returnty) + if isinstance(loc, SimReferenceArgument): + self.RETURN_VAL.set_value(state, loc.ptr_loc.get_value(state)) ret_addr = self.return_addr.get_value(state) if state.arch.sp_offset is not None: if force_callee_cleanup or self.CALLEE_CLEANUP: - if arg_types is not None: - session = self.arg_session - state.regs.sp += self.stack_space([session.next_arg(x) for x in arg_types]) - elif self.args is not None: - state.regs.sp += self.stack_space(self.args) + session = self.arg_session(prototype.returnty) + if self.return_in_implicit_outparam(prototype.returnty): + extra = [self.return_val(prototype.returnty).ptr_loc] else: - l.warning("Can't perform callee cleanup when I have no idea how many arguments there are! Assuming 0") - state.regs.sp += self.STACKARG_SP_DIFF + extra = [] + state.regs.sp += self.stack_space(extra + [self.next_arg(session, x) for x in prototype.args]) else: state.regs.sp += self.STACKARG_SP_DIFF return ret_addr - def set_func_type_with_arch(self, func_ty: Optional[SimTypeFunction]) -> None: - """ - Set self.func_ty to another function type and set its arch to self.arch. - - :param func_ty: The SimTypeFunction type to set. 
- :return: None - """ - if func_ty is None or self.arch is None: - self.func_ty = func_ty - else: - self.func_ty = func_ty.with_arch(self.arch) - - # pylint: disable=unused-argument - def get_return_val(self, state, is_fp=None, size=None, stack_base=None): - """ - Get the return value out of the given state - """ - ty = self.func_ty.returnty if self.func_ty is not None else None - if self.ret_val is not None: - loc = self.ret_val - elif is_fp is not None: - loc = self.FP_RETURN_VAL if is_fp else self.RETURN_VAL - elif ty is not None: - loc = self.FP_RETURN_VAL if isinstance(ty, SimTypeFloat) else self.RETURN_VAL - else: - loc = self.RETURN_VAL - - if loc is None: - raise NotImplementedError("This SimCC doesn't know how to get this value - should be implemented") - - val = loc.get_value(state, stack_base=stack_base, size=None if ty is None else ty.size//state.arch.byte_width) - if self.is_fp_arg(loc) or self.is_fp_value(val) or isinstance(ty, SimTypeFloat): - val = val.raw_to_fp() - return val - - def set_return_val(self, state, val, is_fp=None, size=None, stack_base=None): - """ - Set the return value into the given state - """ - ty = self.func_ty.returnty if self.func_ty is not None else None - try: - betterval = self._standardize_value(val, ty, state, None) - except AttributeError as ex: - raise ValueError("Can't fit value %s into a return value" % repr(val)) from ex - - if self.ret_val is not None: - loc = self.ret_val - elif is_fp is not None: - loc = self.fp_return_val if is_fp else self.return_val - elif ty is not None: - loc = self.fp_return_val if isinstance(ty, SimTypeFloat) else self.return_val - else: - loc = self.fp_return_val if self.is_fp_value(val) else self.return_val - - if loc is None: - raise NotImplementedError("This SimCC doesn't know how to store this value - should be implemented") - loc.set_value(state, betterval, endness='Iend_BE', stack_base=stack_base) - # # Helper functions @@ -806,174 +812,132 @@ def set_return_val(self, state, val, 
is_fp=None, size=None, stack_base=None): @staticmethod def _standardize_value(arg, ty, state, alloc): - check = ty is not None - if check: - ty = ty.with_arch(state.arch) - if isinstance(arg, SimActionObject): return SimCC._standardize_value(arg.ast, ty, state, alloc) elif isinstance(arg, PointerWrapper): - if check and not isinstance(ty, SimTypePointer): + if not isinstance(ty, SimTypePointer): raise TypeError("Type mismatch: expected %s, got pointer-wrapper" % ty.name) - real_value = SimCC._standardize_value(arg.value, ty.pts_to if check else None, state, alloc) + if arg.buffer: + if isinstance(arg.value, claripy.Bits): + real_value = arg.value.chop(state.arch.byte_width) + elif type(arg.value) in (bytes, str): + real_value = claripy.BVV(arg.value).chop(8) + else: + raise TypeError("PointerWrapper(buffer=True) can only be used with a bitvector or a bytestring") + else: + child_type = SimTypeArray(ty.pts_to) if type(arg.value) in (str, bytes, list) else ty.pts_to + real_value = SimCC._standardize_value(arg.value, child_type, state, alloc) return alloc(real_value, state) elif isinstance(arg, (str, bytes)): if type(arg) is str: arg = arg.encode() arg += b'\0' - ref = False - if check: - if isinstance(ty, SimTypePointer) and \ - isinstance(ty.pts_to, SimTypeChar): - ref = True - elif isinstance(ty, SimTypeFixedSizeArray) and \ - isinstance(ty.elem_type, SimTypeChar): - ref = False + if isinstance(ty, SimTypePointer) and \ + isinstance(ty.pts_to, SimTypeChar): + ref = True + elif isinstance(ty, SimTypeFixedSizeArray) and \ + isinstance(ty.elem_type, SimTypeChar): + ref = False + if len(arg) > ty.length: + raise TypeError("String %s is too long for %s" % (repr(arg), ty)) + arg = arg.ljust(ty.length, b'\0') + elif isinstance(ty, SimTypeArray) and \ + isinstance(ty.elem_type, SimTypeChar): + ref = True + if ty.length is not None: if len(arg) > ty.length: - raise TypeError("String %s is too long for %s" % (repr(arg), ty.name)) + raise TypeError("String %s is too long for 
%s" % (repr(arg), ty)) arg = arg.ljust(ty.length, b'\0') - elif isinstance(ty, SimTypeArray) and \ - isinstance(ty.elem_type, SimTypeChar): - ref = True - if ty.length is not None: - if len(arg) > ty.length: - raise TypeError("String %s is too long for %s" % (repr(arg), ty.name)) - arg = arg.ljust(ty.length, b'\0') - elif isinstance(ty, SimTypeString): - ref = False - if len(arg) > ty.length + 1: - raise TypeError("String %s is too long for %s" % (repr(arg), ty.name)) - arg = arg.ljust(ty.length + 1, b'\0') - else: - raise TypeError("Type mismatch: Expected %s, got char*" % ty.name) + elif isinstance(ty, SimTypeString): + ref = False + if len(arg) > ty.length + 1: + raise TypeError("String %s is too long for %s" % (repr(arg), ty)) + arg = arg.ljust(ty.length + 1, b'\0') + else: + raise TypeError("Type mismatch: Expected %s, got char*" % ty.name) val = SimCC._standardize_value(list(arg), SimTypeFixedSizeArray(SimTypeChar(), len(arg)), state, alloc) if ref: val = alloc(val, state) return val elif isinstance(arg, list): - ref = False - subty = None - if check: - if isinstance(ty, SimTypePointer): - ref = True - subty = ty.pts_to - elif isinstance(ty, SimTypeFixedSizeArray): - ref = False - subty = ty.elem_type + if isinstance(ty, SimTypePointer): + ref = True + subty = ty.pts_to + elif isinstance(ty, SimTypeFixedSizeArray): + ref = False + subty = ty.elem_type + if len(arg) != ty.length: + raise TypeError("Array %s is the wrong length for %s" % (repr(arg), ty)) + elif isinstance(ty, SimTypeArray): + ref = True + subty = ty.elem_type + if ty.length is not None: if len(arg) != ty.length: - raise TypeError("Array %s is the wrong length for %s" % (repr(arg), ty.name)) - elif isinstance(ty, SimTypeArray): - ref = True - subty = ty.elem_type - if ty.length is not None: - if len(arg) != ty.length: - raise TypeError("Array %s is the wrong length for %s" % (repr(arg), ty.name)) - else: - raise TypeError("Type mismatch: Expected %s, got char*" % ty.name) + raise 
TypeError("Array %s is the wrong length for %s" % (repr(arg), ty)) else: - types = list(map(type, arg)) - if types[1:] != types[:-1]: - raise TypeError("All elements of list must be of same type") + raise TypeError("Type mismatch: Expected %s, got char*" % ty.name) - val = claripy.Concat(*[SimCC._standardize_value(sarg, subty, state, alloc) for sarg in arg]) + val = [SimCC._standardize_value(sarg, subty, state, alloc) for sarg in arg] if ref: - val = alloc(val, state) + val = alloc(claripy.Concat(*val), state) return val - elif isinstance(arg, tuple): - if check: - if not isinstance(ty, SimStruct): - raise TypeError("Type mismatch: Expected %s, got tuple (i.e. struct)" % ty.name) + elif isinstance(arg, (tuple, dict, SimStructValue)): + if not isinstance(ty, SimStruct): + raise TypeError("Type mismatch: Expected %s, got %s (i.e. struct)" % (ty.name, type(arg))) + if type(arg) is not SimStructValue: if len(arg) != len(ty.fields): raise TypeError("Wrong number of fields in struct, expected %d got %d" % (len(ty.fields), len(arg))) - return claripy.Concat(*[SimCC._standardize_value(sarg, sty, state, alloc) - for sarg, sty - in zip(arg, ty.fields.values())]) - else: - return claripy.Concat(*[SimCC._standardize_value(sarg, None, state, alloc) for sarg in arg]) + arg = SimStructValue(ty, arg) + return SimStructValue(ty, [SimCC._standardize_value(arg[field], ty.fields[field], state, alloc) for field in ty.fields]) elif isinstance(arg, int): - if check and isinstance(ty, SimTypeFloat): + if isinstance(ty, SimTypeFloat): return SimCC._standardize_value(float(arg), ty, state, alloc) - val = state.solver.BVV(arg, ty.size if check else state.arch.bits) - if state.arch.memory_endness == 'Iend_LE': - val = val.reversed + val = state.solver.BVV(arg, ty.size) return val elif isinstance(arg, float): - sort = claripy.FSORT_FLOAT - if check: - if isinstance(ty, SimTypeDouble): - sort = claripy.FSORT_DOUBLE - elif isinstance(ty, SimTypeFloat): - pass - else: - raise TypeError("Type 
mismatch: expectd %s, got float" % ty.name) + if isinstance(ty, SimTypeDouble): + sort = claripy.FSORT_DOUBLE + elif isinstance(ty, SimTypeFloat): + sort = claripy.FSORT_FLOAT else: - sort = claripy.FSORT_DOUBLE if state.arch.bits == 64 else claripy.FSORT_FLOAT + raise TypeError("Type mismatch: expected %s, got float" % ty) - val = claripy.fpToIEEEBV(claripy.FPV(arg, sort)) - if state.arch.memory_endness == 'Iend_LE': - val = val.reversed # pylint: disable=no-member - return val + return claripy.FPV(arg, sort) elif isinstance(arg, claripy.ast.FP): - val = claripy.fpToIEEEBV(arg) - if state.arch.memory_endness == 'Iend_LE': - val = val.reversed # pylint: disable=no-member - return val - - elif isinstance(arg, claripy.ast.Base): - endswap = False - bypass_sizecheck = False - if check: - if isinstance(ty, SimTypePointer): - # we have been passed an AST as a pointer argument. is this supposed to be the pointer or the - # content of the pointer? - # in the future (a breaking change) we should perhaps say it ALWAYS has to be the pointer itself - # but for now use the heuristic that if it's the right size for the pointer it is the pointer - endswap = True - elif isinstance(ty, SimTypeReg): - # definitely endswap. - # TODO: should we maybe pad the value to the type size here? 
- endswap = True - bypass_sizecheck = True - else: - # if we know nothing about the type assume it's supposed to be an int if it looks like an int - endswap = True - - # yikes - if endswap and state.arch.memory_endness == 'Iend_LE' and (bypass_sizecheck or arg.length == state.arch.bits): - arg = arg.reversed - return arg + if isinstance(ty, SimTypeFloat): + if len(arg) != ty.size: + raise TypeError("Type mismatch: expected %s, got %s" % (ty, arg.sort)) + return arg + if isinstance(ty, (SimTypeReg, SimTypeNum)): + return arg.val_to_bv(ty.size, ty.signed) + raise TypeError("Type mismatch: expected %s, got %s" % (ty, arg.sort)) + + elif isinstance(arg, claripy.ast.BV): + if isinstance(ty, (SimTypeReg, SimTypeNum)): + if len(arg) != ty.size: + raise TypeError("Type mismatch: expected %s, got %d bits" % (ty, len(arg))) + return arg + if isinstance(ty, (SimTypeFloat)): + raise TypeError("It's unclear how to coerce a bitvector to %s. " + "Do you want .raw_to_fp or .val_to_fp, and signed or unsigned?") + raise TypeError("Type mismatch: expected %s, got bitvector" % ty) else: raise TypeError("I don't know how to serialize %s." 
% repr(arg)) def __repr__(self): - return "<{}: {}->{}, sp_delta={}>".format(self.__class__.__name__, - self.args, - self.ret_val, - self.sp_delta) + return "<{}>".format(self.__class__.__name__) def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - - def _compare_args(args0, args1): - if args0 is None and args1 is None: - return True - if args0 is None or args1 is None: - return False - return set(args0) == set(args1) - - return _compare_args(self.args, other.args) and \ - self.ret_val == other.ret_val and \ - self.sp_delta == other.sp_delta + return isinstance(other, self.__class__) @classmethod def _match(cls, arch, args: List, sp_delta): @@ -985,7 +949,7 @@ def _match(cls, arch, args: List, sp_delta): sample_inst = cls(arch) all_fp_args = list(sample_inst.fp_args) all_int_args = list(sample_inst.int_args) - both_iter = sample_inst.both_args + both_iter = sample_inst.memory_args some_both_args = [next(both_iter) for _ in range(len(args))] new_args = [ ] @@ -1020,58 +984,64 @@ def find_cc(arch: 'archinfo.Arch', args: List[SimFunctionArgument], sp_delta: in possible_cc_classes = CC[arch.name] for cc_cls in possible_cc_classes: if cc_cls._match(arch, args, sp_delta): - return cc_cls(arch, args=args, sp_delta=sp_delta) + return cc_cls(arch) return None - def get_arg_info(self, state, is_fp=None, sizes=None): + def get_arg_info(self, state, prototype): """ This is just a simple wrapper that collects the information from various locations - is_fp and sizes are passed to self.arg_locs and self.get_args + prototype is as passed to self.arg_locs and self.get_args :param angr.SimState state: The state to evaluate and extract the values from :return: A list of tuples, where the nth tuple is (type, name, location, value) of the nth argument """ - argument_locations = self.arg_locs(is_fp=is_fp, sizes=sizes) - argument_values = self.get_args(state, is_fp=is_fp, sizes=sizes) + argument_locations = self.arg_locs(prototype) + argument_values = 
self.get_args(state, prototype) - if self.func_ty: - argument_types = self.func_ty.args - argument_names = self.func_ty.arg_names if self.func_ty.arg_names else ['unknown'] * len(self.func_ty.args) - else: - argument_types = [SimTypeTop] * len(argument_locations) - argument_names = ['unknown'] * len(argument_locations) + argument_types = prototype.args + argument_names = prototype.arg_names if prototype.arg_names else ['unknown'] * len(prototype.args) return list(zip(argument_types, argument_names, argument_locations, argument_values)) class SimLyingRegArg(SimRegArg): """ A register that LIES about the types it holds """ - def __init__(self, name): - # TODO: This looks byte-related. Make sure to use Arch.byte_width + def __init__(self, name, size=8): super().__init__(name, 8) + self._real_size = size - def get_value(self, state, size=None, endness=None, **kwargs): # pylint:disable=arguments-differ + def get_value(self, state, **kwargs): # pylint:disable=arguments-differ #val = super(SimLyingRegArg, self).get_value(state, **kwargs) - val = getattr(state.regs, self.reg_name) - if endness and endness != state.arch.register_endness: - val = val.reversed - if size == 4: + val = state.registers.load(self.reg_name).raw_to_fp() + if self._real_size == 4: val = claripy.fpToFP(claripy.fp.RM.RM_NearestTiesEven, val.raw_to_fp(), claripy.FSORT_FLOAT) return val - def set_value(self, state, val, size=None, endness=None, **kwargs): # pylint:disable=arguments-differ - if size == 4: - if state.arch.register_endness == 'IEnd_LE' and endness == 'IEnd_BE': - # pylint: disable=no-member - val = claripy.fpToFP(claripy.fp.RM.RM_NearestTiesEven, val.reversed.raw_to_fp(), claripy.FSORT_DOUBLE).reversed - else: - val = claripy.fpToFP(claripy.fp.RM.RM_NearestTiesEven, val.raw_to_fp(), claripy.FSORT_DOUBLE) - if endness and endness != state.arch.register_endness: - val = val.reversed - setattr(state.regs, self.reg_name, val) + def set_value(self, state, val, **kwargs): # 
pylint:disable=arguments-differ + val = self.check_value_set(val, state.arch) + if self._real_size == 4: + val = claripy.fpToFP(claripy.fp.RM.RM_NearestTiesEven, val.raw_to_fp(), claripy.FSORT_DOUBLE) + state.registers.store(self.reg_name, val) #super(SimLyingRegArg, self).set_value(state, val, endness=endness, **kwargs) + def refine(self, size, arch=None, offset=None, is_fp=None): + return SimLyingRegArg(self.reg_name, size) + + +class SimCCUsercall(SimCC): + def __init__(self, arch, arg_locs, ret_loc): + super().__init__(arch) + self.arg_locs = arg_locs + self.ret_loc = ret_loc + + ArgSession = UsercallArgSession + def next_arg(self, session, arg_type): + return next(session.real_args) + + def return_val(self, ty, **kwargs): + return self.ret_loc + class SimCCCdecl(SimCC): ARG_REGS = [] # All arguments are passed in stack FP_ARG_REGS = [] @@ -1083,7 +1053,47 @@ class SimCCCdecl(SimCC): RETURN_ADDR = SimStackArg(0, 4) ARCH = archinfo.ArchX86 -class SimCCStdcall(SimCCCdecl): + def next_arg(self, session, arg_type): + locs_size = 0 + byte_size = arg_type.size // self.arch.byte_width + locs = [] + while locs_size < byte_size: + locs.append(next(session.both_iter)) + locs_size += locs[-1].size + + return refine_locs_with_struct_type(self.arch, locs, arg_type) + + STRUCT_RETURN_THRESHOLD = 32 + + def return_val(self, ty, perspective_returned=False): + if ty._arch is None: + ty = ty.with_arch(self.arch) + if not isinstance(ty, SimStruct): + return super().return_val(ty, perspective_returned) + + if ty.size > self.STRUCT_RETURN_THRESHOLD: + # TODO this code is duplicated a ton of places. how should it be a function? 
+ byte_size = ty.size // self.arch.byte_width + referenced_locs = [SimStackArg(offset, self.arch.bytes) for offset in range(0, byte_size, self.arch.bytes)] + referenced_loc = refine_locs_with_struct_type(self.arch, referenced_locs, ty) + if perspective_returned: + ptr_loc = self.RETURN_VAL + else: + ptr_loc = SimStackArg(0, 4) + reference_loc = SimReferenceArgument(ptr_loc, referenced_loc) + return reference_loc + + return refine_locs_with_struct_type(self.arch, [self.RETURN_VAL, self.OVERFLOW_RETURN_VAL], ty) + + def return_in_implicit_outparam(self, ty): + if isinstance(ty, SimTypeBottom): + return False + return isinstance(ty, SimStruct) and ty.size > self.STRUCT_RETURN_THRESHOLD + +class SimCCMicrosoftCdecl(SimCCCdecl): + STRUCT_RETURN_THRESHOLD = 64 + +class SimCCStdcall(SimCCMicrosoftCdecl): CALLEE_CLEANUP = True class SimCCMicrosoftFastcall(SimCC): @@ -1093,6 +1103,13 @@ class SimCCMicrosoftFastcall(SimCC): RETURN_ADDR = SimStackArg(0, 4) ARCH = archinfo.ArchX86 +class MicrosoftAMD64ArgSession: + def __init__(self, cc): + self.cc = cc + self.int_iter = cc.int_args + self.fp_iter = cc.fp_args + self.both_iter = cc.both_args + class SimCCMicrosoftAMD64(SimCC): ARG_REGS = ['rcx', 'rdx', 'r8', 'r9'] FP_ARG_REGS = ['xmm0', 'xmm1', 'xmm2', 'xmm3'] @@ -1105,6 +1122,32 @@ class SimCCMicrosoftAMD64(SimCC): ARCH = archinfo.ArchAMD64 STACK_ALIGNMENT = 16 + ArgSession = MicrosoftAMD64ArgSession + def next_arg(self, session, arg_type): + try: + int_loc = next(session.int_iter) + fp_loc = next(session.fp_iter) + except StopIteration: + int_loc = fp_loc = next(session.both_iter) + + byte_size = arg_type.size // self.arch.byte_width + + if isinstance(arg_type, SimTypeFloat): + return fp_loc.refine(size=byte_size, is_fp=True, arch=self.arch) + + if byte_size <= int_loc.size: + return int_loc.refine(size=byte_size, is_fp=False, arch=self.arch) + + referenced_locs = [SimStackArg(offset, self.arch.bytes) for offset in range(0, byte_size, self.arch.bytes)] + referenced_loc = 
refine_locs_with_struct_type(self.arch, referenced_locs, arg_type) + reference_loc = SimReferenceArgument(int_loc, referenced_loc) + return reference_loc + + def return_in_implicit_outparam(self, ty): + if isinstance(ty, SimTypeBottom): + return False + return not isinstance(ty, SimTypeFloat) and ty.size > 64 + class SimCCSyscall(SimCC): """ @@ -1146,10 +1189,10 @@ def linux_syscall_update_error_reg(self, state, expr): self.ERROR_REG.set_value(state, error_reg_val) return expr - def set_return_val(self, state, val, is_fp=None, size=None, stack_base=None): + def set_return_val(self, state, val, ty, **kwargs): if self.ERROR_REG is not None: val = self.linux_syscall_update_error_reg(state, val) - super().set_return_val(state, val, is_fp=is_fp, size=size, stack_base=stack_base) + super().set_return_val(state, val, ty, **kwargs) class SimCCX86LinuxSyscall(SimCCSyscall): @@ -1195,17 +1238,11 @@ class SimCCSystemVAMD64(SimCC): RETURN_ADDR = SimStackArg(0, 8) RETURN_VAL = SimRegArg('rax', 8) OVERFLOW_RETURN_VAL = SimRegArg('rdx', 8) - FP_RETURN_VAL = SimRegArg('xmm0', 32) + FP_RETURN_VAL = SimRegArg('xmm0', 128) + OVERFLOW_FP_RETURN_VAL = SimRegArg('xmm1', 128) ARCH = archinfo.ArchAMD64 STACK_ALIGNMENT = 16 - def __init__(self, arch, args=None, ret_val=None, sp_delta=None, func_ty=None): - super().__init__(arch, args, ret_val, sp_delta, func_ty) - - # Remove the ret address on stack - if self.args is not None: - self.args = [ i for i in self.args if not (isinstance(i, SimStackArg) and i.stack_offset == 0x0) ] - @classmethod def _match(cls, arch, args, sp_delta): if cls.ARCH is not None and not isinstance(arch, cls.ARCH): @@ -1216,7 +1253,7 @@ def _match(cls, arch, args, sp_delta): sample_inst = cls(arch) all_fp_args = list(sample_inst.fp_args) all_int_args = list(sample_inst.int_args) - both_iter = sample_inst.both_args + both_iter = sample_inst.memory_args some_both_args = [next(both_iter) for _ in range(len(args))] for arg in args: @@ -1227,6 +1264,152 @@ def _match(cls, 
arch, args, sp_delta): return True + # https://raw.githubusercontent.com/wiki/hjl-tools/x86-psABI/x86-64-psABI-1.0.pdf + # section 3.2.3 + def next_arg(self, session, arg_type): + state = session.getstate() + classification = self._classify(arg_type) + try: + mapped_classes = [] + for cls in classification: + if cls == 'SSEUP': + mapped_classes.append(mapped_classes[-1].sse_extend(self.arch.bytes)) + elif cls == 'NO_CLASS': + raise NotImplementedError("Bug. Report to @rhelmot") + elif cls == 'MEMORY': + mapped_classes.append(next(session.both_iter)) + elif cls == 'INTEGER': + mapped_classes.append(next(session.int_iter)) + elif cls == 'SSE': + mapped_classes.append(next(session.fp_iter)) + else: + raise NotImplementedError("Bug. Report to @rhelmot") + except StopIteration: + session.setstate(state) + mapped_classes = [next(session.both_iter) for _ in classification] + + return refine_locs_with_struct_type(self.arch, mapped_classes, arg_type) + + def return_val(self, ty: Optional[SimType], perspective_returned=False): + if ty is None: + return None + if ty._arch is None: + ty = ty.with_arch(self.arch) + classification = self._classify(ty) + if any(cls == 'MEMORY' for cls in classification): + assert all(cls == 'MEMORY' for cls in classification) + byte_size = ty.size // self.arch.byte_width + referenced_locs = [SimStackArg(offset, self.arch.bytes) for offset in range(0, byte_size, self.arch.bytes)] + referenced_loc = refine_locs_with_struct_type(self.arch, referenced_locs, ty) + if perspective_returned: + ptr_loc = self.RETURN_VAL + else: + ptr_loc = SimRegArg('rdi', 8) + reference_loc = SimReferenceArgument(ptr_loc, referenced_loc) + return reference_loc + else: + mapped_classes = [] + int_iter = iter([self.RETURN_VAL, self.OVERFLOW_RETURN_VAL]) + fp_iter = iter([self.FP_RETURN_VAL, self.OVERFLOW_FP_RETURN_VAL]) + for cls in classification: + if cls == 'SSEUP': + mapped_classes.append(mapped_classes[-1].sse_extend(self.arch.bytes)) + elif cls == 'NO_CLASS': + raise 
NotImplementedError("Bug. Report to @rhelmot") + elif cls == 'INTEGER': + mapped_classes.append(next(int_iter)) + elif cls == 'SSE': + mapped_classes.append(next(fp_iter)) + else: + raise NotImplementedError("Bug. Report to @rhelmot") + + return refine_locs_with_struct_type(self.arch, mapped_classes, ty) + + def return_in_implicit_outparam(self, ty): + if isinstance(ty, SimTypeBottom): + return False + # :P + return isinstance(self.return_val(ty), SimReferenceArgument) + + def _classify(self, ty, chunksize=None): + if chunksize is None: + chunksize = self.arch.bytes + nchunks = (ty.size // self.arch.byte_width + chunksize - 1) // chunksize + if isinstance(ty, (SimTypeInt, SimTypeChar, SimTypePointer, SimTypeNum)): + return ['INTEGER'] * nchunks + elif isinstance(ty, (SimTypeFloat,)): + return ['SSE'] + ['SSEUP'] * (nchunks - 1) + elif isinstance(ty, (SimStruct, SimTypeFixedSizeArray, SimUnion)): + if ty.size > 512: + return ['MEMORY'] * nchunks + flattened = self._flatten(ty) + if flattened is None: + return ['MEMORY'] * nchunks + result = ['NO_CLASS'] * nchunks + for offset, subty_list in flattened.items(): + for subty in subty_list: + # is the smaller chunk size necessary? Genuinely unsure + subresult = self._classify(subty, chunksize=1) + idx_start = offset // chunksize + idx_end = (offset + (subty.size // self.arch.byte_width) - 1) // chunksize + for i, idx in enumerate(range(idx_start, idx_end + 1)): + subclass = subresult[i*chunksize] + result[idx] = self._combine_classes(result[idx], subclass) + if any(subresult == 'MEMORY' for subresult in result): + return ['MEMORY'] * nchunks + if nchunks > 2 and (result[0] != 'SSE' or any(subresult != 'SSEUP' for subresult in result[1:])): + return ['MEMORY'] * nchunks + for i in range(1, nchunks): + if result[i] == 'SSEUP' and result[i-1] not in ('SSE', 'SSEUP'): + result[i] = 'SSE' + return result + else: + raise NotImplementedError("Ummmmm... not sure what goes here. 
report bug to @rhelmot") + + def _flatten(self, ty): + result = defaultdict(list) + if isinstance(ty, SimStruct): + if ty.packed: + return None + for field, subty in ty.fields.items(): + offset = ty.offsets[field] + subresult = self._flatten(subty) + if subresult is None: + return None + for suboffset, subsubty in subresult.items(): + result[offset + suboffset].append(subsubty) + elif isinstance(ty, SimTypeFixedSizeArray): + subresult = self._flatten(ty.elem_type) + if subresult is None: + return None + for suboffset, subsubty in subresult.items(): + for idx in range(ty.length): + # TODO I think we need an explicit stride field on array types + result[idx * ty.elem_type.size // self.arch.byte_width + suboffset].append(subsubty) + elif isinstance(ty, SimUnion): + for field, subty in ty.members.items(): + subresult = self._flatten(subty) + if subresult is None: + return None + for suboffset, subsubty in subresult.items(): + result[suboffset].append(subsubty) + else: + result[0].append(ty) + return result + + def _combine_classes(self, cls1, cls2): + if cls1 == cls2: + return cls1 + if cls1 == 'NO_CLASS': + return cls2 + if cls2 == 'NO_CLASS': + return cls1 + if cls1 == 'MEMORY' or cls2 == 'MEMORY': + return 'MEMORY' + if cls1 == 'INTEGER' or cls2 == 'INTEGER': + return 'INTEGER' + return 'SSE' + class SimCCAMD64LinuxSyscall(SimCCSyscall): ARG_REGS = ['rdi', 'rsi', 'rdx', 'r10', 'r8', 'r9'] @@ -1454,6 +1637,14 @@ class SimCCSoot(SimCC): ARCH = archinfo.ArchSoot ARG_REGS = [] + def setup_callsite(self, state, ret_addr, args, prototype, stack_base=None, alloc_base=None, grow_like_stack=True): + SootMixin.setup_callsite(state, args, ret_addr) + + @staticmethod + def guess_prototype(args, prototype=None): + # uhhhhhhhhhhhhhhhh + return None + class SimCCUnknown(SimCC): """ @@ -1465,7 +1656,7 @@ def _match(arch, args, sp_delta): # pylint: disable=unused-argument return True def __repr__(self): - return f"" + return f"" class SimCCS390X(SimCC): diff --git 
a/angr/engines/soot/engine.py b/angr/engines/soot/engine.py index 0e0f29ec75e..132bc13bbf3 100644 --- a/angr/engines/soot/engine.py +++ b/angr/engines/soot/engine.py @@ -9,6 +9,7 @@ from ...errors import SimEngineError, SimTranslationError from cle import CLEError from ...state_plugins.inspect import BP_AFTER, BP_BEFORE +from ...sim_type import SimTypeNum, SimTypeFunction, parse_type from ..engine import SuccessorsMixin from ..procedure import ProcedureMixin from .exceptions import BlockTerminationNotice, IncorrectLocationException @@ -350,7 +351,9 @@ def prepare_native_return_state(native_state): if ret_var is not None: # get return symbol from native state native_cc = javavm_simos.get_native_cc() - ret_symbol = native_cc.get_return_val(native_state).to_claripy() + ret_symbol = native_cc.return_val( + javavm_simos.get_native_type(ret_var.type) + ).get_value(native_state).to_claripy() # convert value to java type if ret_var.type in ArchSoot.primitive_types: # return value has a primitive type @@ -399,7 +402,14 @@ def _setup_native_callsite(cls, state, native_addr, java_method, args, ret_addr, # add to args final_args = [jni_env, ref] + args + # Step 3: generate C prototype from java_method + voidp = parse_type('void*') + arg_types = [voidp, voidp] + [state.project.simos.get_native_type(ty) for ty in java_method.params] + ret_type = state.project.simos.get_native_type(java_method.ret) + prototype = SimTypeFunction(args=arg_types, returnty=ret_type) + # Step 3: create native invoke state return state.project.simos.state_call(native_addr, *final_args, base_state=state, + prototype=prototype, ret_type=java_method.ret) diff --git a/angr/exploration_techniques/director.py b/angr/exploration_techniques/director.py index b3cf63cf78e..861c08e88fe 100644 --- a/angr/exploration_techniques/director.py +++ b/angr/exploration_techniques/director.py @@ -271,12 +271,12 @@ def _check_arguments(self, arch, state): # TODO: add calling convention detection to individual functions, and 
use that instead of the # TODO: default calling convention of the platform - cc = DEFAULT_CC[arch.name](arch) # type: s_cc.SimCC + cc = DEFAULT_CC[arch.name](arch) + real_args = cc.get_args(state, cc.guess_prototype([0]*len(self.arguments))) - for i, expected_arg in enumerate(self.arguments): + for i, (expected_arg, real_arg) in enumerate(zip(self.arguments, real_args)): if expected_arg is None: continue - real_arg = cc.arg(state, i) expected_arg_type, expected_arg_value = expected_arg r = self._compare_arguments(state, expected_arg_type, expected_arg_value, real_arg) diff --git a/angr/factory.py b/angr/factory.py index f2defc716f9..2d7d887b86c 100644 --- a/angr/factory.py +++ b/angr/factory.py @@ -186,12 +186,13 @@ def simgr(self, *args, **kwargs): """ return self.simulation_manager(*args, **kwargs) - def callable(self, addr, concrete_only=False, perform_merge=True, base_state=None, toc=None, cc=None): + def callable(self, addr, prototype=None, concrete_only=False, perform_merge=True, base_state=None, toc=None, cc=None): """ A Callable is a representation of a function in the binary that can be interacted with like a native python function. 
:param addr: The address of the function to use + :param prototype: The prototype of the call to use, as a string or a SimTypeFunction :param concrete_only: Throw an exception if the execution splits into multiple states :param perform_merge: Merge all result states into one at the end (only relevant if concrete_only=False) :param base_state: The state from which to do these runs @@ -203,26 +204,16 @@ def callable(self, addr, concrete_only=False, perform_merge=True, base_state=Non """ return Callable(self.project, addr=addr, + prototype=prototype, concrete_only=concrete_only, perform_merge=perform_merge, base_state=base_state, toc=toc, cc=cc) - def cc(self, args=None, ret_val=None, sp_delta=None, func_ty=None): + def cc(self): """ - Return a SimCC (calling convention) parametrized for this project and, optionally, a given function. - - :param args: A list of argument storage locations, as SimFunctionArguments. - :param ret_val: The return value storage location, as a SimFunctionArgument. - :param sp_delta: Does this even matter?? - :param func_ty: The prototype for the given function, as a SimType or a C-style function declaration that - can be parsed into a SimTypeFunction instance. - - Example func_ty strings: - >>> "int func(char*, int)" - >>> "int f(int, int, int*);" - Function names are ignored. + Return a SimCC (calling convention) parametrized for this project. Relevant subclasses of SimFunctionArgument are SimRegArg and SimStackArg, and shortcuts to them can be found on this `cc` object. @@ -230,38 +221,7 @@ def cc(self, args=None, ret_val=None, sp_delta=None, func_ty=None): For stack arguments, offsets are relative to the stack pointer on function entry. 
""" - return self._default_cc(arch=self.project.arch, - args=args, - ret_val=ret_val, - sp_delta=sp_delta, - func_ty=func_ty) - - def cc_from_arg_kinds(self, fp_args, ret_fp=None, sizes=None, sp_delta=None, func_ty=None): - """ - Get a SimCC (calling convention) that will extract floating-point/integral args correctly. - - :param arch: The Archinfo arch for this CC - :param fp_args: A list, with one entry for each argument the function can take. True if the argument is fp, - false if it is integral. - :param ret_fp: True if the return value for the function is fp. - :param sizes: Optional: A list, with one entry for each argument the function can take. Each entry is the - size of the corresponding argument in bytes. - :param sp_delta: The amount the stack pointer changes over the course of this function - CURRENTLY UNUSED - :param func_ty: A SimType for the function itself or a C-style function declaration that can be parsed into - a SimTypeFunction instance. - - Example func_ty strings: - >>> "int func(char*, int)" - >>> "int f(int, int, int*);" - Function names are ignored. 
- - """ - return self._default_cc.from_arg_kinds(arch=self.project.arch, - fp_args=fp_args, - ret_fp=ret_fp, - sizes=sizes, - sp_delta=sp_delta, - func_ty=func_ty) + return self._default_cc(arch=self.project.arch) #pylint: disable=unused-argument, no-self-use, function-redefined @overload diff --git a/angr/knowledge_plugins/functions/function.py b/angr/knowledge_plugins/functions/function.py index b6344ab21fd..d31ff07865a 100644 --- a/angr/knowledge_plugins/functions/function.py +++ b/angr/knowledge_plugins/functions/function.py @@ -39,10 +39,10 @@ class Function(Serializable): 'is_syscall', '_project', 'is_plt', 'addr', 'is_simprocedure', '_name', 'is_default_name', 'from_signature', 'binary_name', '_argument_registers', '_argument_stack_variables', - 'bp_on_stack', 'retaddr_on_stack', 'sp_delta', '_cc', '_prototype', '_returning', + 'bp_on_stack', 'retaddr_on_stack', 'sp_delta', 'calling_convention', 'prototype', '_returning', 'prepared_registers', 'prepared_stack_variables', 'registers_read_afterwards', 'startpoint', '_addr_to_block_node', '_block_sizes', '_block_cache', '_local_blocks', - '_local_block_addrs', 'info', 'tags', 'alignment', + '_local_block_addrs', 'info', 'tags', 'alignment', 'is_prototype_guessed', ) def __init__(self, function_manager, addr, name=None, syscall=None, is_simprocedure=None, binary_name=None, @@ -94,9 +94,10 @@ def __init__(self, function_manager, addr, name=None, syscall=None, is_simproced self.retaddr_on_stack = False self.sp_delta = 0 # Calling convention - self._cc = None # type: Optional[SimCC] + self.calling_convention = None # type: Optional[SimCC] # Function prototype - self._prototype = None # type: Optional[SimTypeFunction] + self.prototype = None # type: Optional[SimTypeFunction] + self.is_prototype_guessed: bool = True # Whether this function returns or not. 
`None` means it's not determined yet self._returning = None self.prepared_registers = set() @@ -192,12 +193,6 @@ def __init__(self, function_manager, addr, name=None, syscall=None, is_simproced if self.project.arch.name in DEFAULT_CC: cc = DEFAULT_CC[arch.name](arch) - # update cc.args according to num_args - # TODO: Handle non-traditional arguments like fp - if cc is not None and not cc.args and simproc.num_args: - args = cc.arg_locs(is_fp=[False] * simproc.num_args) # arg_locs() uses cc.args - cc.args = args - self.calling_convention = cc else: self.calling_convention = None @@ -334,65 +329,6 @@ def code_constants(self): # TODO: remove link register values return [const.value for block in self.blocks for const in block.vex.constants] - @property - def calling_convention(self) -> Optional[SimCC]: - """ - Get the calling convention of this function. - - :return: The calling convention of this function. - """ - return self._cc - - @calling_convention.setter - def calling_convention(self, v): - """ - Set the calling convention of this function. If the new cc has a function prototype, we will clear - self._prototype. Otherwise, if self.prototype is set, we will use it to update the function prototype of the new - cc, and then clear self._prototype. A warning message will be generated in either case. - - :param Optional[SimCC] v: The new calling convention. - :return: None - """ - self._cc = v - - if self._cc is not None: - if self._cc.func_ty is None and self._prototype is not None: - l.warning("The new calling convention for %r does not have a prototype associated. Using the existing " - "function prototype to update the new calling convention. The existing function prototype " - "will be removed.", self) - self._cc.set_func_type_with_arch(self._prototype) - self._prototype = None - elif self._cc.func_ty is not None and self._prototype is not None: - l.warning("The new calling convention for %r already has a prototype associated. 
The existing function " - "prototype will be removed.", self) - self._prototype = None - - @property - def prototype(self) -> Optional[SimTypeFunction]: - """ - Get the prototype of this function. We prioritize the function prototype that is set in self.calling_convention. - - :return: The function prototype. - """ - if self._cc: - return self._cc.func_ty - else: - return self._prototype - - @prototype.setter - def prototype(self, proto): - """ - Set a new prototype to this function. If a calling convention is already set to this function, the new prototype - will be set to this calling convention instead. - - :param Optional[SimTypeFunction] proto: The new prototype. - :return: None - """ - if self._cc: - self._cc.set_func_type_with_arch(proto) - else: - self._prototype = proto - @classmethod def _get_cmsg(cls): return function_pb2.Function() @@ -1436,11 +1372,7 @@ def find_declaration(self): return proto = library.get_prototype(self.name) - self.prototype = proto - if self.calling_convention is not None: - self.calling_convention.args = self.calling_convention.arg_locs( - is_fp=[isinstance(arg, (SimTypeFloat, SimTypeDouble)) for arg in proto.args]) - self.calling_convention.set_func_type_with_arch(proto) + self.prototype = proto.with_arch(self.project.arch) @staticmethod def _addr_to_funcloc(addr): @@ -1471,12 +1403,13 @@ def apply_definition(self, definition, calling_convention=None): """ if not definition.endswith(";"): definition += ";" - func_def = parse_defns(definition) + func_def = parse_defns(definition, arch=self.project.arch) if len(func_def.keys()) > 1: raise Exception("Too many definitions: %s " % list(func_def.keys())) name, ty = func_def.popitem() # type: str, SimTypeFunction self.name = name + self.prototype = ty # setup the calling convention # If a SimCC object is passed assume that this is sane and just use it if isinstance(calling_convention, SimCC): @@ -1484,11 +1417,11 @@ def apply_definition(self, definition, calling_convention=None): # If 
it is a subclass of SimCC we can instantiate it elif isinstance(calling_convention, type) and issubclass(calling_convention, SimCC): - self.calling_convention = calling_convention(self.project.arch, func_ty=ty) + self.calling_convention = calling_convention(self.project.arch) # If none is specified default to something elif calling_convention is None: - self.calling_convention = self.project.factory.cc(func_ty=ty) + self.calling_convention = self.project.factory.cc() else: raise TypeError("calling_convention has to be one of: [SimCC, type(SimCC), None]") diff --git a/angr/procedures/definitions/__init__.py b/angr/procedures/definitions/__init__.py index 37fb6d8e670..1437fff2d57 100644 --- a/angr/procedures/definitions/__init__.py +++ b/angr/procedures/definitions/__init__.py @@ -174,25 +174,24 @@ def add_alias(self, name, *alt_names): new_procedure = copy.deepcopy(old_procedure) new_procedure.display_name = alt self.procedures[alt] = new_procedure + if name in self.prototypes: + self.prototypes[alt] = self.prototypes[name] + if name in self.non_returning: + self.non_returning.add(alt) + def _apply_metadata(self, proc, arch): if proc.cc is None and arch.name in self.default_ccs: proc.cc = self.default_ccs[arch.name](arch) - if proc.cc.func_ty is not None: - # Use inspect to extract the parameters from the run python function - proc.cc.func_ty.arg_names = inspect.getfullargspec(proc.run).args[1:] if proc.cc is None and arch.name in self.fallback_cc: proc.cc = self.fallback_cc[arch.name](arch) if proc.display_name in self.prototypes: - proc.cc.func_ty = self.prototypes[proc.display_name].with_arch(arch) - if proc.cc.func_ty.arg_names is None: + proc.prototype = self.prototypes[proc.display_name].with_arch(arch) + if proc.prototype.arg_names is None: # Use inspect to extract the parameters from the run python function - proc.cc.func_ty.arg_names = inspect.getfullargspec(proc.run).args[1:] - proc.cc.args = proc.cc.arg_locs( - is_fp=[isinstance(arg, (SimTypeFloat, 
SimTypeDouble)) for arg in proc.cc.func_ty.args]) + proc.prototype.arg_names = inspect.getfullargspec(proc.run).args[1:] if not proc.ARGS_MISMATCH: - proc.cc.num_args = len(proc.cc.func_ty.args) - proc.num_args = len(proc.cc.func_ty.args) + proc.num_args = len(proc.prototype.args) if proc.display_name in self.non_returning: proc.returns = False proc.library_name = self.name @@ -295,13 +294,13 @@ def _try_demangle(name): return name @staticmethod - def _proto_from_demangled_name(name: str) -> Optional[SimCC]: + def _proto_from_demangled_name(name: str) -> Optional[SimTypeFunction]: """ Attempt to extract arguments and calling convention information for a C++ function whose name was mangled according to the Itanium C++ ABI symbol mangling language. :param name: The demangled function name. - :return: A calling convention or None if a calling convention cannot be found. + :return: A prototype or None if a prototype cannot be found. """ try: @@ -342,12 +341,12 @@ def get_stub(self, name, arch): # try to determine a prototype from the function name if possible if demangled_name != name: # itanium-mangled function name - stub.cc.set_func_type_with_arch(self._proto_from_demangled_name(demangled_name)) - stub.cc.args = stub.cc.arg_locs( - is_fp=[isinstance(arg, (SimTypeFloat, SimTypeDouble)) for arg in stub.cc.func_ty.args]) - if stub.cc.func_ty is not None and not stub.ARGS_MISMATCH: - stub.cc.num_args = len(stub.cc.func_ty.args) - stub.num_args = len(stub.cc.func_ty.args) + stub.prototype = self._proto_from_demangled_name(demangled_name) + if stub.prototype is not None: + stub.prototype = stub.prototype.with_arch(arch) + if not stub.ARGS_MISMATCH: + stub.cc.num_args = len(stub.prototype.args) + stub.num_args = len(stub.prototype.args) return stub def get_prototype(self, name: str, arch=None) -> Optional[SimTypeFunction]: @@ -516,13 +515,11 @@ def _apply_numerical_metadata(self, proc, number, arch, abi): proc.abi = abi if abi in self.default_cc_mapping: cc = 
self.default_cc_mapping[abi](arch) - if proc.cc is not None: - cc.set_func_type_with_arch(proc.cc.func_ty) proc.cc = cc # a bit of a hack. name = proc.display_name - if self.syscall_prototypes[abi].get(name, None) is not None and proc.cc is not None: - proc.cc.func_ty = self.syscall_prototypes[abi][name].with_arch(arch) + if self.syscall_prototypes[abi].get(name, None) is not None: + proc.prototype = self.syscall_prototypes[abi][name].with_arch(arch) # pylint: disable=arguments-differ def get(self, number, arch, abi_list=()): diff --git a/angr/procedures/definitions/glibc.py b/angr/procedures/definitions/glibc.py index 44bdd614fb6..12c0439eeff 100644 --- a/angr/procedures/definitions/glibc.py +++ b/angr/procedures/definitions/glibc.py @@ -21,12 +21,8 @@ libc.add_all_from_dict(P['posix']) libc.add_all_from_dict(P['glibc']) libc.add_all_from_dict(P['uclibc']) # gotta do this since there's no distinguishing different libcs without analysis. there should be no naming conflicts in the functions. 
-libc.add_alias('abort', '__assert_fail', '__stack_chk_fail') -libc.add_alias('memcpy', 'memmove', 'bcopy') -libc.add_alias('getc', '_IO_getc') -libc.add_alias('putc', '_IO_putc') libc.set_non_returning('exit_group', 'exit', 'abort', 'pthread_exit', '__assert_fail', - 'longjmp', 'siglongjmp', '__longjmp_chk', '__siglongjmp_chk') + 'longjmp', 'siglongjmp', '__longjmp_chk', '__siglongjmp_chk') # @@ -3087,6 +3083,11 @@ _l.debug("Libc provides %d function prototypes, and has %d unsupported function prototypes.", proto_count, unsupported_count) +libc.add_alias('abort', '__assert_fail', '__stack_chk_fail') +libc.add_alias('memcpy', 'memmove', 'bcopy') +libc.add_alias('getc', '_IO_getc') +libc.add_alias('putc', '_IO_putc') + # # function prototypes in strings diff --git a/angr/procedures/glibc/__libc_init.py b/angr/procedures/glibc/__libc_init.py index 7464efb92bb..04716826c06 100644 --- a/angr/procedures/glibc/__libc_init.py +++ b/angr/procedures/glibc/__libc_init.py @@ -27,7 +27,8 @@ def run(self, raw_args, unused, slingshot, structors): self.argv = self.state.memory.load(raw_args + 1 * offset, readlen, endness=endness) self.envp= self.state.memory.load(raw_args + (1 + argc_val + 1) * offset, readlen, endness=endness) # TODO: __cxa_atexit calls for various at-exit needs - self.call(self.main, (self.argc, self.argv, self.envp), 'after_slingshot') + self.call(self.main, (self.argc, self.argv, self.envp), 'after_slingshot', + prototype='int main(int argc, char **argv, char **envp)') def after_slingshot(self, raw_args, unused, slingshot, structors, exit_addr=0): self.exit(0) diff --git a/angr/procedures/glibc/__libc_start_main.py b/angr/procedures/glibc/__libc_start_main.py index 3475ceae9b3..af1ea8ebcb7 100644 --- a/angr/procedures/glibc/__libc_start_main.py +++ b/angr/procedures/glibc/__libc_start_main.py @@ -125,7 +125,6 @@ def envp(self): def run(self, main, argc, argv, init, fini): # TODO: handle symbolic and static modes - # TODO: add argument types
self._initialize_ctype_table() self._initialize_errno() @@ -135,10 +134,12 @@ def run(self, main, argc, argv, init, fini): # TODO: __cxa_atexit calls for various at-exit needs - self.call(self.init, (self.argc, self.argv, self.envp), 'after_init') + self.call(self.init, (self.argc[31:0], self.argv, self.envp), 'after_init', + prototype = 'int main(int argc, char **argv, char **envp)') def after_init(self, main, argc, argv, init, fini, exit_addr=0): - self.call(self.main, (self.argc, self.argv, self.envp), 'after_main') + self.call(self.main, (self.argc[31:0], self.argv, self.envp), 'after_main', + prototype='int main(int argc, char **argv, char **envp)') def after_main(self, main, argc, argv, init, fini, exit_addr=0): self.exit(0) @@ -161,7 +162,8 @@ def static_exits(self, blocks): break cc = angr.DEFAULT_CC[self.arch.name](self.arch) - args = [ cc.arg(state, _) for _ in range(5) ] + ty = angr.sim_type.parse_signature('void x(void*, void*, void*, void*, void*)').with_arch(self.arch) + args = cc.get_args(state, ty) main, _, _, init, fini = self._extract_args(blank_state, *args) all_exits = [ diff --git a/angr/procedures/java_jni/__init__.py b/angr/procedures/java_jni/__init__.py index ef09901b3d1..ead488e0111 100644 --- a/angr/procedures/java_jni/__init__.py +++ b/angr/procedures/java_jni/__init__.py @@ -32,9 +32,10 @@ def execute(self, state, successors=None, arguments=None, ret_to=None): if not self.return_ty: raise ValueError("Classes implementing JNISimProcedure's must set the return type.") elif self.return_ty != 'void': - func_ty = SimTypeFunction(args=[], + prototype = SimTypeFunction(args=self.prototype.args, returnty=state.project.simos.get_native_type(self.return_ty)) - self.cc = DefaultCC[state.arch.name](state.arch, func_ty=func_ty) + self.cc = DefaultCC[state.arch.name](state.arch) + self.prototype = prototype super(JNISimProcedure, self).execute(state, successors, arguments, ret_to) # diff --git a/angr/procedures/java_jni/method_calls.py 
b/angr/procedures/java_jni/method_calls.py index 849625b9150..7bd84052de7 100644 --- a/angr/procedures/java_jni/method_calls.py +++ b/angr/procedures/java_jni/method_calls.py @@ -62,7 +62,7 @@ def _invoke(self, method_id, obj=None, dynamic_dispatch=True, args_in_array=None self.call(invoke_addr, java_args, "return_from_invocation", cc=SimCCSoot(ArchSoot())) def _get_arg_values(self, no_of_args): - return [ self.arg(self.num_args+idx).to_claripy() for idx in range(no_of_args) ] + return [self.va_arg('void*') for _ in range(no_of_args)] def _get_arg_values_from_array(self, array, no_of_args): return self._load_from_native_memory(addr=array, data_size=self.arch.bytes, diff --git a/angr/procedures/libc/access.py b/angr/procedures/libc/access.py index 2d8acbde11d..18ddeea060d 100644 --- a/angr/procedures/libc/access.py +++ b/angr/procedures/libc/access.py @@ -10,6 +10,6 @@ class access(angr.SimProcedure): def run(self, path, mode): - ret = self.state.solver.BVS('access', self.arch.bits) + ret = self.state.solver.BVS('access', self.arch.sizeof['int']) self.state.add_constraints(self.state.solver.Or(ret == 0, ret == -1)) return ret diff --git a/angr/procedures/libc/atoi.py b/angr/procedures/libc/atoi.py index 38b59f0a7de..b779be4f1aa 100644 --- a/angr/procedures/libc/atoi.py +++ b/angr/procedures/libc/atoi.py @@ -8,4 +8,6 @@ class atoi(angr.SimProcedure): #pylint:disable=arguments-differ def run(self, s): strtol = angr.SIM_PROCEDURES['libc']['strtol'] - return strtol.strtol_inner(s, self.state, self.state.memory, 10, True)[1] + val = strtol.strtol_inner(s, self.state, self.state.memory, 10, True)[1] + val = val[self.arch.sizeof['int'] - 1:0] + return val diff --git a/angr/procedures/libc/feof.py b/angr/procedures/libc/feof.py index a20865b1f4a..4b1f211eaf3 100644 --- a/angr/procedures/libc/feof.py +++ b/angr/procedures/libc/feof.py @@ -16,6 +16,6 @@ def run(self, file_ptr): simfd = self.state.posix.get_fd(fileno) if simfd is None: return None - return 
self.state.solver.If(simfd.eof(), self.state.solver.BVV(1, self.state.arch.bits), 0) + return self.state.solver.If(simfd.eof(), self.state.solver.BVV(1, self.arch.sizeof['int']), 0) feof_unlocked = feof diff --git a/angr/procedures/libc/fflush.py b/angr/procedures/libc/fflush.py index e4e6025136a..2dc5033a43a 100644 --- a/angr/procedures/libc/fflush.py +++ b/angr/procedures/libc/fflush.py @@ -7,6 +7,6 @@ class fflush(angr.SimProcedure): #pylint:disable=arguments-differ,unused-argument def run(self, fd): - return self.state.solver.BVV(0, self.state.arch.bits) + return 0 fflush_unlocked = fflush diff --git a/angr/procedures/libc/fgetc.py b/angr/procedures/libc/fgetc.py index f124ffbed69..501b024be19 100644 --- a/angr/procedures/libc/fgetc.py +++ b/angr/procedures/libc/fgetc.py @@ -17,7 +17,7 @@ def run(self, stream, simfd=None): return -1 data, real_length, = simfd.read_data(1) - return self.state.solver.If(real_length == 0, -1, data.zero_extend(self.state.arch.bits - 8)) + return self.state.solver.If(real_length == 0, -1, data.zero_extend(self.arch.sizeof['int'] - 8)) getc = fgetc fgetc_unlocked = fgetc diff --git a/angr/procedures/libc/fgets.py b/angr/procedures/libc/fgets.py index 7ba28411c9e..acc35ebfce9 100644 --- a/angr/procedures/libc/fgets.py +++ b/angr/procedures/libc/fgets.py @@ -12,6 +12,8 @@ class fgets(angr.SimProcedure): #pylint:disable=arguments-differ def run(self, dst, size, file_ptr): + size = size.zero_extend(self.arch.bits - self.arch.sizeof['int']) + # let's get the memory back for the file we're interested in and find the newline fd_offset = io_file_data_for_arch(self.state.arch)['fd'] fd = self.state.mem[file_ptr + fd_offset:].int.resolved diff --git a/angr/procedures/libc/fprintf.py b/angr/procedures/libc/fprintf.py index b4a8a030ccf..9258ca668d7 100644 --- a/angr/procedures/libc/fprintf.py +++ b/angr/procedures/libc/fprintf.py @@ -20,8 +20,8 @@ def run(self, file_ptr, fmt): # pylint:disable=unused-argument return -1 # The format str is at 
index 1 - fmt_str = self._parse(1) - out_str = fmt_str.replace(2, self.arg) + fmt_str = self._parse(fmt) + out_str = fmt_str.replace(self.va_arg) simfd.write_data(out_str, out_str.size() // 8) diff --git a/angr/procedures/libc/fscanf.py b/angr/procedures/libc/fscanf.py index 56248408840..7883c362d80 100644 --- a/angr/procedures/libc/fscanf.py +++ b/angr/procedures/libc/fscanf.py @@ -14,6 +14,6 @@ def run(self, file_ptr, fmt): # pylint:disable=unused-argument if simfd is None: return -1 - fmt_str = self._parse(1) - items = fmt_str.interpret(2, self.arg, simfd=simfd) + fmt_str = self._parse(fmt) + items = fmt_str.interpret(self.va_arg, simfd=simfd) return items diff --git a/angr/procedures/libc/fseek.py b/angr/procedures/libc/fseek.py index dc6a38832ef..f5ec847ebef 100644 --- a/angr/procedures/libc/fseek.py +++ b/angr/procedures/libc/fseek.py @@ -29,6 +29,6 @@ def run(self, file_ptr, offset, whence): simfd = self.state.posix.get_fd(fd) if simfd is None: return -1 - return self.state.solver.If(simfd.seek(offset, whence), self.state.solver.BVV(0, self.state.arch.bits), -1) + return self.state.solver.If(simfd.seek(offset, whence), self.state.solver.BVV(0, self.arch.sizeof['int']), -1) fseeko = fseek diff --git a/angr/procedures/libc/memcmp.py b/angr/procedures/libc/memcmp.py index 27a3144a3d7..3d5cf8377e7 100644 --- a/angr/procedures/libc/memcmp.py +++ b/angr/procedures/libc/memcmp.py @@ -19,10 +19,11 @@ def run(self, s1_addr, s2_addr, n): l.debug("Definite size %s and conditional size: %s", definite_size, conditional_size) + int_bits = self.arch.sizeof['int'] if definite_size > 0: s1_part = self.state.memory.load(s1_addr, definite_size, endness='Iend_BE') s2_part = self.state.memory.load(s2_addr, definite_size, endness='Iend_BE') - cases = [ [s1_part == s2_part, self.state.solver.BVV(0, self.state.arch.bits)], [self.state.solver.ULT(s1_part, s2_part), self.state.solver.BVV(-1, self.state.arch.bits)], [self.state.solver.UGT(s1_part, s2_part), self.state.solver.BVV(1, 
self.state.arch.bits) ] ] + cases = [ [s1_part == s2_part, self.state.solver.BVV(0, int_bits)], [self.state.solver.ULT(s1_part, s2_part), self.state.solver.BVV(-1, int_bits)], [self.state.solver.UGT(s1_part, s2_part), self.state.solver.BVV(1, int_bits) ] ] definite_answer = self.state.solver.ite_cases(cases, 2) constraint = self.state.solver.Or(*[c for c,_ in cases]) self.state.add_constraints(constraint) @@ -31,7 +32,7 @@ def run(self, s1_addr, s2_addr, n): l.debug("Created constraint: %s", constraint) l.debug("... crom cases: %s", cases) else: - definite_answer = self.state.solver.BVV(0, self.state.arch.bits) + definite_answer = self.state.solver.BVV(0, int_bits) if not self.state.solver.symbolic(definite_answer) and self.state.solver.eval(definite_answer) != 0: return definite_answer @@ -44,7 +45,7 @@ def run(self, s1_addr, s2_addr, n): for byte, bit in zip(range(conditional_size), range(conditional_size*8, 0, -8)): s1_part = s1_all[conditional_size*8-1 : bit-8] s2_part = s2_all[conditional_size*8-1 : bit-8] - cases = [ [s1_part == s2_part, self.state.solver.BVV(0, self.state.arch.bits)], [self.state.solver.ULT(s1_part, s2_part), self.state.solver.BVV(-1, self.state.arch.bits)], [self.state.solver.UGT(s1_part, s2_part), self.state.solver.BVV(1, self.state.arch.bits) ] ] + cases = [ [s1_part == s2_part, self.state.solver.BVV(0, int_bits)], [self.state.solver.ULT(s1_part, s2_part), self.state.solver.BVV(-1, int_bits)], [self.state.solver.UGT(s1_part, s2_part), self.state.solver.BVV(1, int_bits) ] ] conditional_rets[byte+1] = self.state.solver.ite_cases(cases, 0) self.state.add_constraints(self.state.solver.Or(*[c for c,_ in cases])) diff --git a/angr/procedures/libc/printf.py b/angr/procedures/libc/printf.py index 95a8ee10ac8..04910e497d2 100644 --- a/angr/procedures/libc/printf.py +++ b/angr/procedures/libc/printf.py @@ -5,27 +5,27 @@ l = logging.getLogger(name=__name__) class printf(FormatParser): - def run(self): + def run(self, fmt): stdout = 
self.state.posix.get_fd(1) if stdout is None: return -1 # The format str is at index 0 - fmt_str = self._parse(0) - out_str = fmt_str.replace(1, self.arg) + fmt_str = self._parse(fmt) + out_str = fmt_str.replace(self.va_arg) stdout.write_data(out_str, out_str.size() // 8) return out_str.size() // 8 class __printf_chk(FormatParser): - def run(self): + def run(self, _, fmt): stdout = self.state.posix.get_fd(1) if stdout is None: return -1 # The format str is at index 1 - fmt_str = self._parse(1) - out_str = fmt_str.replace(2, self.arg) + fmt_str = self._parse(fmt) + out_str = fmt_str.replace(self.va_arg) stdout.write_data(out_str, out_str.size() // 8) return out_str.size() // 8 diff --git a/angr/procedures/libc/puts.py b/angr/procedures/libc/puts.py index 53c8237977d..fbc412162ff 100644 --- a/angr/procedures/libc/puts.py +++ b/angr/procedures/libc/puts.py @@ -17,4 +17,4 @@ def run(self, string): length = self.inline_call(strlen, string).ret_expr out = stdout.write(string, length) stdout.write_data(self.state.solver.BVV(b'\n')) - return out + 1 + return (out + 1)[31:0] diff --git a/angr/procedures/libc/rand.py b/angr/procedures/libc/rand.py index 18319ed009d..f00e4b9e068 100644 --- a/angr/procedures/libc/rand.py +++ b/angr/procedures/libc/rand.py @@ -3,4 +3,4 @@ class rand(angr.SimProcedure): def run(self): rval = self.state.solver.BVS('rand', 31, key=('api', 'rand')) - return rval.zero_extend(self.state.arch.bits - 31) + return rval.zero_extend(self.arch.sizeof['int'] - 31) diff --git a/angr/procedures/libc/scanf.py b/angr/procedures/libc/scanf.py index cca79747128..35d463ce5d1 100644 --- a/angr/procedures/libc/scanf.py +++ b/angr/procedures/libc/scanf.py @@ -8,12 +8,12 @@ class scanf(ScanfFormatParser): #pylint:disable=arguments-differ,unused-argument def run(self, fmt): - fmt_str = self._parse(0) + fmt_str = self._parse(fmt) # we're reading from stdin so the region is the file's content simfd = self.state.posix.get_fd(0) if simfd is None: return -1 - items = 
fmt_str.interpret(1, self.arg, simfd=simfd) + items = fmt_str.interpret(self.va_arg, simfd=simfd) return items diff --git a/angr/procedures/libc/snprintf.py b/angr/procedures/libc/snprintf.py index 45c4ef63bd0..f5209e464fd 100644 --- a/angr/procedures/libc/snprintf.py +++ b/angr/procedures/libc/snprintf.py @@ -15,16 +15,14 @@ def run(self, dst_ptr, size, fmt): # pylint:disable=arguments-differ,unused-arg if self.state.solver.eval(size) == 0: return size - # The format str is at index 2 - fmt_str = self._parse(2) - out_str = fmt_str.replace(3, self.arg) + fmt_str = self._parse(fmt) + out_str = fmt_str.replace(self.va_arg) self.state.memory.store(dst_ptr, out_str) # place the terminating null byte - self.state.memory.store(dst_ptr + (out_str.size() // 8), self.state.solver.BVV(0, 8)) + self.state.memory.store(dst_ptr + (out_str.size() // self.arch.byte_width), self.state.solver.BVV(0, 8)) - # size_t has size arch.bits - return self.state.solver.BVV(out_str.size()//8, self.state.arch.bits) + return out_str.size()//self.arch.byte_width ###################################### # __snprintf_chk @@ -35,12 +33,11 @@ class __snprintf_chk(FormatParser): def run(self, dst_ptr, maxlen, size, fmt): # pylint:disable=arguments-differ,unused-argument # The format str is at index 4 - fmt_str = self._parse(4) - out_str = fmt_str.replace(5, self.arg) + fmt_str = self._parse(fmt) + out_str = fmt_str.replace(self.va_arg) self.state.memory.store(dst_ptr, out_str) # place the terminating null byte - self.state.memory.store(dst_ptr + (out_str.size() // 8), self.state.solver.BVV(0, 8)) + self.state.memory.store(dst_ptr + (out_str.size() // self.arch.byte_width), self.state.solver.BVV(0, 8)) - # size_t has size arch.bits - return self.state.solver.BVV(out_str.size()//8, self.state.arch.bits) + return out_str.size()//self.arch.byte_width diff --git a/angr/procedures/libc/sprintf.py b/angr/procedures/libc/sprintf.py index 7a3bceeba9d..919387ce0b4 100644 --- a/angr/procedures/libc/sprintf.py +++ 
b/angr/procedures/libc/sprintf.py @@ -14,12 +14,11 @@ class sprintf(FormatParser): def run(self, dst_ptr, fmt): # pylint:disable=unused-argument # The format str is at index 1 - fmt_str = self._parse(1) - out_str = fmt_str.replace(2, self.arg) + fmt_str = self._parse(fmt) + out_str = fmt_str.replace(self.va_arg) self.state.memory.store(dst_ptr, out_str) # place the terminating null byte - self.state.memory.store(dst_ptr + (out_str.size() // 8), self.state.solver.BVV(0, 8)) + self.state.memory.store(dst_ptr + (out_str.size() // self.arch.byte_width), self.state.solver.BVV(0, self.arch.byte_width)) - # size_t has size arch.bits - return self.state.solver.BVV(out_str.size()//8, self.state.arch.bits) + return out_str.size()//self.arch.byte_width diff --git a/angr/procedures/libc/sscanf.py b/angr/procedures/libc/sscanf.py index 7675f195ae6..1b754b6a72e 100644 --- a/angr/procedures/libc/sscanf.py +++ b/angr/procedures/libc/sscanf.py @@ -7,6 +7,6 @@ class sscanf(ScanfFormatParser): #pylint:disable=arguments-differ,unused-argument def run(self, data, fmt): - fmt_str = self._parse(1) - items = fmt_str.interpret(2, self.arg, addr=data) + fmt_str = self._parse(fmt) + items = fmt_str.interpret(self.va_arg, addr=data) return items diff --git a/angr/procedures/libc/strlen.py b/angr/procedures/libc/strlen.py index 76a53ff3f1e..e95b6baf9d3 100644 --- a/angr/procedures/libc/strlen.py +++ b/angr/procedures/libc/strlen.py @@ -33,7 +33,8 @@ def run(self, s, wchar=False, maxlen=None): # Make sure to convert s to ValueSet addr_desc: AbstractAddressDescriptor = self.state.memory._normalize_address(s) - length = self.state.solver.ESI(self.state.arch.bits) + # size_t + length = self.state.solver.ESI(self.arch.bits) for s_aw in self.state.memory._concretize_address_descriptor(addr_desc, None): s_ptr = s_aw.to_valueset(self.state) diff --git a/angr/procedures/libc/strncmp.py b/angr/procedures/libc/strncmp.py index 0492a57d5fa..2301890bae5 100644 --- a/angr/procedures/libc/strncmp.py +++ 
b/angr/procedures/libc/strncmp.py @@ -18,7 +18,7 @@ def run(self, a_addr, b_addr, limit, a_len=None, b_len=None, wchar=False, ignore match_constraints = [ ] variables = a_len.variables | b_len.variables | limit.variables - ret_expr = self.state.solver.Unconstrained("strncmp_ret", self.state.arch.bits, key=('api', 'strncmp')) + ret_expr = self.state.solver.Unconstrained("strncmp_ret", 32, key=('api', 'strncmp')) # determine the maximum number of bytes to compare concrete_run = False @@ -50,19 +50,19 @@ def run(self, a_addr, b_addr, limit, a_len=None, b_len=None, wchar=False, ignore if self.state.solver.single_valued(limit) and self.state.solver.eval(limit) == 0: # limit is 0 l.debug("returning equal for 0-limit") - return self.state.solver.BVV(0, self.state.arch.bits) + return self.state.solver.BVV(0, 32) elif self.state.solver.single_valued(a_len) and self.state.solver.single_valued(b_len) and \ self.state.solver.eval(a_len) == self.state.solver.eval(b_len) == 0: # two empty strings l.debug("returning equal for two empty strings") - return self.state.solver.BVV(0, self.state.arch.bits) + return self.state.solver.BVV(0, 32) else: # all other cases fall into this branch l.debug("returning non-equal for comparison of an empty string and a non-empty string") if a_strlen.max_null_index == 0: - return self.state.solver.BVV(-1, self.state.arch.bits) + return self.state.solver.BVV(-1, 32) else: - return self.state.solver.BVV(1, self.state.arch.bits) + return self.state.solver.BVV(1, 32) # the bytes max_byte_len = maxlen * char_size @@ -96,9 +96,9 @@ def run(self, a_addr, b_addr, limit, a_len=None, b_len=None, wchar=False, ignore if a_conc != b_conc: l.debug("... 
found mis-matching concrete bytes 0x%x and 0x%x", a_conc, b_conc) if a_conc < b_conc: - return self.state.solver.BVV(-1, self.state.arch.bits) + return self.state.solver.BVV(-1, 32) else: - return self.state.solver.BVV(1, self.state.arch.bits) + return self.state.solver.BVV(1, 32) else: if self.state.mode == 'static': @@ -131,14 +131,14 @@ def run(self, a_addr, b_addr, limit, a_len=None, b_len=None, wchar=False, ignore if concrete_run: l.debug("concrete run made it to the end!") - return self.state.solver.BVV(0, self.state.arch.bits) + return self.state.solver.BVV(0, 32) if self.state.mode == 'static': ret_expr = self.state.solver.ESI(8) for expr in return_values: ret_expr = ret_expr.union(expr) - ret_expr = ret_expr.sign_extend(self.state.arch.bits - 8) + ret_expr = ret_expr.sign_extend(24) else: # make the constraints diff --git a/angr/procedures/libc/system.py b/angr/procedures/libc/system.py index c2252d7c0e5..c29feab3c0e 100644 --- a/angr/procedures/libc/system.py +++ b/angr/procedures/libc/system.py @@ -7,4 +7,4 @@ class system(angr.SimProcedure): #pylint:disable=arguments-differ,unused-argument def run(self, cmd): retcode = self.state.solver.Unconstrained('system_returncode', 8, key=('api', 'system')) - return retcode.zero_extend(self.state.arch.bits - 8) + return retcode.zero_extend(self.arch.sizeof['int'] - 8) diff --git a/angr/procedures/linux_kernel/futex.py b/angr/procedures/linux_kernel/futex.py index ff61a96f0ac..aa5e8468d30 100644 --- a/angr/procedures/linux_kernel/futex.py +++ b/angr/procedures/linux_kernel/futex.py @@ -16,4 +16,4 @@ def run(self, uaddr, futex_op, val, timeout, uaddr2, val3): return 0 else: l.debug('futex(futex_op=%d)', op) - return self.state.solver.Unconstrained("futex", self.state.arch.bits, key=('api', 'futex')) + return self.state.solver.Unconstrained("futex", self.arch.sizeof['int'], key=('api', 'futex')) diff --git a/angr/procedures/linux_kernel/getrlimit.py b/angr/procedures/linux_kernel/getrlimit.py index 
6f2d13dcfff..5f2904b7d99 100644 --- a/angr/procedures/linux_kernel/getrlimit.py +++ b/angr/procedures/linux_kernel/getrlimit.py @@ -18,7 +18,7 @@ def run(self, resource, rlim): return 0 else: l.debug('running getrlimit(other)') - return self.state.solver.Unconstrained("rlimit", self.state.arch.bits, key=('api', 'rlimit', 'other')) + return self.state.solver.Unconstrained("rlimit", self.arch.sizeof['int'], key=('api', 'rlimit', 'other')) class ugetrlimit(getrlimit): pass diff --git a/angr/procedures/linux_kernel/munmap.py b/angr/procedures/linux_kernel/munmap.py index f6a4b04c527..e196efba686 100644 --- a/angr/procedures/linux_kernel/munmap.py +++ b/angr/procedures/linux_kernel/munmap.py @@ -4,4 +4,4 @@ class munmap(angr.SimProcedure): def run(self, addr, length): #pylint:disable=arguments-differ,unused-argument # TODO: actually do something - return self.state.solver.BVV(0, self.state.arch.bits) + return 0 diff --git a/angr/procedures/linux_kernel/sigaction.py b/angr/procedures/linux_kernel/sigaction.py index e770dd8ce2d..534947bb6e3 100644 --- a/angr/procedures/linux_kernel/sigaction.py +++ b/angr/procedures/linux_kernel/sigaction.py @@ -3,7 +3,7 @@ class sigaction(angr.SimProcedure): def run(self, signum, act, oldact): #pylint:disable=arguments-differ,unused-argument # TODO: actually do something - return self.state.solver.BVV(0, self.state.arch.bits) + return self.state.solver.BVV(0, self.arch.sizeof['int']) class rt_sigaction(angr.SimProcedure): def run(self, signum, act, oldact, sigsetsize): #pylint:disable=arguments-differ,unused-argument @@ -11,4 +11,4 @@ def run(self, signum, act, oldact, sigsetsize): #pylint:disable=arguments-differ # ...hack if self.state.solver.is_true(signum == 33): return self.state.libc.ret_errno('EINVAL') - return self.state.solver.BVV(0, self.state.arch.bits) + return self.state.solver.BVV(0, self.arch.sizeof['int']) diff --git a/angr/procedures/linux_kernel/sigprocmask.py b/angr/procedures/linux_kernel/sigprocmask.py index 
66caf1a0908..902b1b3d16b 100644 --- a/angr/procedures/linux_kernel/sigprocmask.py +++ b/angr/procedures/linux_kernel/sigprocmask.py @@ -8,10 +8,12 @@ def run(self, how, set_, oldset, sigsetsize): self.state.posix.sigprocmask(how, self.state.memory.load(set_, sigsetsize), sigsetsize, valid_ptr=set_!=0) # TODO: EFAULT - return self.state.solver.If(self.state.solver.And( - how != self.state.posix.SIG_BLOCK, - how != self.state.posix.SIG_UNBLOCK, - how != self.state.posix.SIG_SETMASK), - self.state.solver.BVV(self.state.posix.EINVAL, self.state.arch.bits), - self.state.solver.BVV(0, self.state.arch.bits), + return self.state.solver.If( + self.state.solver.And( + how != self.state.posix.SIG_BLOCK, + how != self.state.posix.SIG_UNBLOCK, + how != self.state.posix.SIG_SETMASK + ), + self.state.solver.BVV(self.state.posix.EINVAL, self.arch.sizeof['int']), + 0, ) diff --git a/angr/procedures/linux_kernel/tgkill.py b/angr/procedures/linux_kernel/tgkill.py index 53837467a1f..8e16304b770 100644 --- a/angr/procedures/linux_kernel/tgkill.py +++ b/angr/procedures/linux_kernel/tgkill.py @@ -2,6 +2,6 @@ class tgkill(angr.SimProcedure): - def run(self, addr, length): #pylint:disable=arguments-differ,unused-argument + def run(self, tgid, tid, sig): #pylint:disable=arguments-differ,unused-argument # TODO: actually do something - return self.state.solver.BVV(0, self.state.arch.bits) + return self.state.solver.BVV(0, self.arch.sizeof['int']) diff --git a/angr/procedures/linux_kernel/time.py b/angr/procedures/linux_kernel/time.py index 71a388dd3c4..1d4708d13ac 100644 --- a/angr/procedures/linux_kernel/time.py +++ b/angr/procedures/linux_kernel/time.py @@ -14,6 +14,7 @@ def last_time(self, v): self.state.globals[self.KEY] = v def run(self, pointer): + # TODO lord have mercy. how big is time_t? 
if angr.options.USE_SYSTEM_TIMES in self.state.options: ts = int(_time.time()) result = self.state.solver.BVV(ts, self.state.arch.bits) diff --git a/angr/procedures/linux_loader/sim_loader.py b/angr/procedures/linux_loader/sim_loader.py index 35dc5872ad0..6a9037dbd82 100644 --- a/angr/procedures/linux_loader/sim_loader.py +++ b/angr/procedures/linux_loader/sim_loader.py @@ -20,7 +20,8 @@ def run_initializer(self): else: addr = self.initializers[0] self.initializers = self.initializers[1:] - self.call(addr, (self.state.posix.argc, self.state.posix.argv, self.state.posix.environ), 'run_initializer') + self.call(addr, (self.state.posix.argc, self.state.posix.argv, self.state.posix.environ), 'run_initializer', + prototype = 'int main(int argc, char **argv, char **envp)') class IFuncResolver(angr.SimProcedure): NO_RET = True @@ -29,10 +30,10 @@ class IFuncResolver(angr.SimProcedure): # pylint: disable=arguments-differ,unused-argument def run(self, funcaddr=None, gotaddr=None, funcname=None): self.saved_regs = {reg.name: self.state.registers.load(reg.name) for reg in self.arch.register_list if reg.argument} - self.call(funcaddr, (), continue_at='after_call') + self.call(funcaddr, (), continue_at='after_call', cc=self.cc, prototype='void *x()') def after_call(self, funcaddr=None, gotaddr=None, funcname=None): - value = self.cc.return_val.get_value(self.state) + value = self.cc.return_val(angr.sim_type.SimTypePointer(angr.sim_type.SimTypeBottom())).get_value(self.state) for name, val in self.saved_regs.items(): self.state.registers.store(name, val) diff --git a/angr/procedures/msvcr/_initterm.py b/angr/procedures/msvcr/_initterm.py index 141478d574e..f2252040244 100644 --- a/angr/procedures/msvcr/_initterm.py +++ b/angr/procedures/msvcr/_initterm.py @@ -35,4 +35,4 @@ def do_callbacks(self, fp_a, fp_z): # pylint:disable=unused-argument else: callback_addr = self.callbacks.pop(0) l.debug("Calling %#x", callback_addr) - self.call(callback_addr, [], continue_at='do_callbacks') 
+ self.call(callback_addr, [], continue_at='do_callbacks', prototype='void x()') diff --git a/angr/procedures/ntdll/exceptions.py b/angr/procedures/ntdll/exceptions.py index 44f2f7a2c7b..5f5348e2a99 100644 --- a/angr/procedures/ntdll/exceptions.py +++ b/angr/procedures/ntdll/exceptions.py @@ -51,6 +51,6 @@ def dispatch(self, record, context): self.cur_ptr = next_ptr # as far as I can tell it doesn't actually matter whether the callback is stdcall or cdecl - self.call(func_ptr, (record, cur_ptr, context, 0xBADF00D), 'dispatch') + self.call(func_ptr, (record, cur_ptr, context, 0xBADF00D), 'dispatch', prototype='void x(int, int, int, int)') # bonus! after we've done the call, mutate the state even harder so ebp is pointing to some fake args self.successors.successors[0].regs.ebp = self.successors.successors[0].regs.esp - 4 diff --git a/angr/procedures/posix/bind.py b/angr/procedures/posix/bind.py index 44250721b56..1d77a8a685c 100644 --- a/angr/procedures/posix/bind.py +++ b/angr/procedures/posix/bind.py @@ -10,4 +10,4 @@ class bind(angr.SimProcedure): #pylint:disable=arguments-differ def run(self, fd, addr_ptr, addr_len): #pylint:disable=unused-argument - return self.state.solver.Unconstrained('bind', self.state.arch.bits, key=('api', 'bind')) + return self.state.solver.Unconstrained('bind', self.arch.sizeof['int'], key=('api', 'bind')) diff --git a/angr/procedures/posix/fcntl.py b/angr/procedures/posix/fcntl.py index 33ef7dc5dcf..a2a220697ea 100644 --- a/angr/procedures/posix/fcntl.py +++ b/angr/procedures/posix/fcntl.py @@ -9,4 +9,4 @@ class fcntl(angr.SimProcedure): def run(self, fd, cmd): # this is a stupid stub that does not do anything besides returning an unconstrained variable. 
- return self.state.solver.BVS('fcntl_retval', self.state.arch.bits) + return self.state.solver.BVS('sys_fcntl', self.arch.sizeof['int'], key=('api', 'fcntl')) diff --git a/angr/procedures/posix/fileno.py b/angr/procedures/posix/fileno.py index f2e20459fb4..4c5622dfe17 100644 --- a/angr/procedures/posix/fileno.py +++ b/angr/procedures/posix/fileno.py @@ -15,7 +15,6 @@ def run(self, f): io_file_data = io_file_data_for_arch(self.state.arch) # Get the file descriptor from FILE struct - result = self.state.mem[f + io_file_data['fd']].int.resolved - return result.sign_extend(self.arch.bits - len(result)) + return self.state.mem[f + io_file_data['fd']].int.resolved fileno_unlocked = fileno diff --git a/angr/procedures/posix/fork.py b/angr/procedures/posix/fork.py index 7e7c7177030..614e9f25ffc 100644 --- a/angr/procedures/posix/fork.py +++ b/angr/procedures/posix/fork.py @@ -3,5 +3,5 @@ class fork(angr.SimProcedure): def run(self): return self.state.solver.If(self.state.solver.BoolS('fork_parent'), - self.state.solver.BVV(1338, self.state.arch.bits), - self.state.solver.BVV(0, self.state.arch.bits)) + self.state.solver.BVV(1338, self.arch.sizeof['int']), + self.state.solver.BVV(0, self.arch.sizeof['int'])) diff --git a/angr/procedures/posix/listen.py b/angr/procedures/posix/listen.py index 6b39768fd9a..d86ff6d1988 100644 --- a/angr/procedures/posix/listen.py +++ b/angr/procedures/posix/listen.py @@ -10,5 +10,4 @@ class listen(angr.SimProcedure): #pylint:disable=arguments-differ def run(self, sockfd, backlog): #pylint:disable=unused-argument - return self.state.solver.Unconstrained('listen', self.state.arch.bits, key=('api', 'listen')) - + return self.state.solver.Unconstrained('listen', self.arch.sizeof['int'], key=('api', 'listen')) diff --git a/angr/procedures/posix/pthread.py b/angr/procedures/posix/pthread.py index ece630fff3e..ee96604b289 100644 --- a/angr/procedures/posix/pthread.py +++ b/angr/procedures/posix/pthread.py @@ -11,8 +11,8 @@ class 
pthread_create(angr.SimProcedure): # pylint: disable=unused-argument,arguments-differ def run(self, thread, attr, start_routine, arg): - self.call(start_routine, (arg,), 'terminate_thread') - self.ret(self.state.solver.BVV(0, self.state.arch.bits)) + self.call(start_routine, (arg,), 'terminate_thread', prototype='void *start_routine(void*)') + return 0 def terminate_thread(self, thread, attr, start_routine, arg): self.exit(0) @@ -30,9 +30,8 @@ def static_exits(self, blocks): else: break - cc = angr.DEFAULT_CC[self.arch.name](self.arch) - callfunc = cc.arg(state, 2) - retaddr = state.memory.load(state.regs.sp, self.arch.bytes) + callfunc = self.cc.get_args(state, self.prototype)[2] + retaddr = state.memory.load(state.regs.sp, size=self.arch.bytes) all_exits = [ {'address': callfunc, 'jumpkind': 'Ijk_Call', 'namehint': 'thread_entry'}, @@ -86,7 +85,7 @@ def run(self, control, func): controlword |= 2 self.state.mem[control].char = controlword - self.call(func, (), 'retsite') + self.call(func, (), 'retsite', prototype='void x()') def retsite(self, control, func): return 0 diff --git a/angr/procedures/posix/readdir.py b/angr/procedures/posix/readdir.py index 4d589897f24..6309f393090 100644 --- a/angr/procedures/posix/readdir.py +++ b/angr/procedures/posix/readdir.py @@ -21,7 +21,7 @@ def run(self, dirp): # pylint: disable=arguments-differ malloc = angr.SIM_PROCEDURES['libc']['malloc'] pointer = self.inline_call(malloc, 19 + 256).ret_expr self._store_amd64(pointer) - return self.state.solver.If(self.condition, pointer, self.state.solver.BVV(0, self.state.arch.bits)) + return self.state.solver.If(self.condition, pointer, 0) def instrument(self): """ diff --git a/angr/procedures/posix/syslog.py b/angr/procedures/posix/syslog.py index 04f0b01933f..ae95809d375 100644 --- a/angr/procedures/posix/syslog.py +++ b/angr/procedures/posix/syslog.py @@ -6,9 +6,9 @@ l.setLevel('INFO') class syslog(FormatParser): - def run(self, priority): - fmt = self._parse(1) - formatted = 
fmt.replace(2, self.arg) + def run(self, priority, fmt): + fmt = self._parse(fmt) + formatted = fmt.replace(self.va_arg) if not formatted.symbolic: formatted = self.state.solver.eval(formatted, cast_to=bytes) - l.info("Syslog priority %s: %s", priority, formatted) \ No newline at end of file + l.info("Syslog priority %s: %s", priority, formatted) diff --git a/angr/procedures/stubs/ReturnUnconstrained.py b/angr/procedures/stubs/ReturnUnconstrained.py index 2b01da9a4c0..40fccc9e9e9 100644 --- a/angr/procedures/stubs/ReturnUnconstrained.py +++ b/angr/procedures/stubs/ReturnUnconstrained.py @@ -13,7 +13,12 @@ def run(self, *args, **kwargs): #pylint:disable=arguments-differ return_val = kwargs.pop('return_val', None) if return_val is None: - o = self.state.solver.Unconstrained("unconstrained_ret_%s" % self.display_name, self.state.arch.bits, key=('api', '?', self.display_name)) + size = self.prototype.returnty.size + # ummmmm do we really want to rely on this behavior? + if size is NotImplemented: + o = None + else: + o = self.state.solver.Unconstrained("unconstrained_ret_%s" % self.display_name, size, key=('api', '?', self.display_name)) else: o = return_val diff --git a/angr/procedures/stubs/caller.py b/angr/procedures/stubs/caller.py index 661223a0922..43760e05656 100644 --- a/angr/procedures/stubs/caller.py +++ b/angr/procedures/stubs/caller.py @@ -11,7 +11,7 @@ class Caller(angr.SimProcedure): """ def run(self, target_addr=None, target_cc=None): - self.call(target_addr, [ ], 'after_call', cc=target_cc) + self.call(target_addr, [ ], 'after_call', cc=target_cc, prototype='void x()') def after_call(self, target_addr=None, target_cc=None): pass diff --git a/angr/procedures/stubs/format_parser.py b/angr/procedures/stubs/format_parser.py index d6310b7b481..f190695047d 100644 --- a/angr/procedures/stubs/format_parser.py +++ b/angr/procedures/stubs/format_parser.py @@ -57,16 +57,14 @@ def _get_str_at(self, str_addr, max_length=None): return 
self.parser.state.memory.load(str_addr, max_length) - def replace(self, startpos, args): + def replace(self, va_arg): """ Implement printf - based on the stored format specifier information, format the values from the arg getter function `args` into a string. - :param startpos: The index of the first argument to be used by the first element of the format string - :param args: A function which, given an argument index, returns the integer argument to the current function at that index + :param va_arg: A function which takes a type and returns the next argument of that type :return: The result formatted string """ - argpos = startpos string = None for component in self.components: @@ -83,15 +81,15 @@ def replace(self, startpos, args): fmt_spec = component if fmt_spec.spec_type == b's': if fmt_spec.length_spec == b".*": - str_length = args(argpos) - argpos += 1 + str_length = va_arg('size_t') else: str_length = None - str_ptr = args(argpos) + str_ptr = va_arg('char*') string = self._add_to_string(string, self._get_str_at(str_ptr, max_length=str_length)) # integers, for most of these we'll end up concretizing values.. else: - i_val = args(argpos) + # ummmmmmm this is a cheap translation but I think it should work + i_val = va_arg('void*') c_val = int(self.parser.state.solver.eval(i_val)) c_val &= (1 << (fmt_spec.size * 8)) - 1 if fmt_spec.signed and (c_val & (1 << ((fmt_spec.size * 8) - 1))): @@ -117,23 +115,20 @@ def replace(self, startpos, args): string = self._add_to_string(string, self.parser.state.solver.BVV(s_val.encode())) - argpos += 1 - return string - def interpret(self, startpos, args, addr=None, simfd=None): + def interpret(self, va_arg, addr=None, simfd=None): """ implement scanf - extract formatted data from memory or a file according to the stored format specifiers and store them into the pointers extracted from `args`. 
- :param startpos: The index of the first argument corresponding to the first format element - :param args: A function which, given the index of an argument to the function, returns that argument + :param va_arg: A function which, given a type, returns the next argument of that type :param addr: The address in the memory to extract data from, or... :param simfd: A file descriptor to use for reading data from :return: The number of arguments parsed """ + num_args = 0 if simfd is not None and isinstance(simfd.read_storage, SimPackets): - argnum = startpos for component in self.components: if type(component) is bytes: sdata, _ = simfd.read_data(len(component), short_reads=False) @@ -148,13 +143,14 @@ def interpret(self, startpos, args, addr=None, simfd=None): sdata, slen = simfd.read_data(component.length_spec) for byte in sdata.chop(8): self.state.add_constraints(claripy.And(*[byte != char for char in self.SCANF_DELIMITERS])) - self.state.memory.store(args(argnum), sdata, size=slen) - self.state.memory.store(args(argnum) + slen, claripy.BVV(0, 8)) - argnum += 1 + ptr = va_arg('char*') + self.state.memory.store(ptr, sdata, size=slen) + self.state.memory.store(ptr + slen, claripy.BVV(0, 8)) + num_args += 1 elif component.spec_type == b'c': sdata, _ = simfd.read_data(1, short_reads=False) - self.state.memory.store(args(argnum), sdata) - argnum += 1 + self.state.memory.store(va_arg('char*'), sdata) + num_args += 1 else: bits = component.size * 8 if component.spec_type == b'x': @@ -166,7 +162,7 @@ def interpret(self, startpos, args, addr=None, simfd=None): # here's the variable representing the result of the parsing target_variable = self.state.solver.BVS('scanf_' + component.string.decode(), bits, - key=('api', 'scanf', argnum - startpos, component.string)) + key=('api', 'scanf', num_args, component.string)) negative = claripy.SLT(target_variable, 0) # how many digits does it take to represent this variable fully? 
@@ -214,10 +210,11 @@ def interpret(self, startpos, args, addr=None, simfd=None): self.state.add_constraints(digit == digit_ascii[7:0]) - self.state.memory.store(args(argnum), target_variable, endness=self.state.arch.memory_endness) - argnum += 1 + # again, a cheap hack + self.state.memory.store(va_arg('void*'), target_variable, endness=self.state.arch.memory_endness) + num_args += 1 - return argnum - startpos + return num_args if simfd is not None: region = simfd.read_storage @@ -226,8 +223,7 @@ def interpret(self, startpos, args, addr=None, simfd=None): region = self.parser.state.memory bits = self.parser.state.arch.bits - failed = self.parser.state.solver.BVV(0, bits) - argpos = startpos + failed = self.parser.state.solver.BVV(0, 32) position = addr for component in self.components: if isinstance(component, bytes): @@ -237,7 +233,7 @@ def interpret(self, startpos, args, addr=None, simfd=None): else: fmt_spec = component try: - dest = args(argpos) + dest = va_arg('void*') except SimProcedureArgumentError: dest = None if fmt_spec.spec_type == b's': @@ -297,13 +293,13 @@ def interpret(self, startpos, args, addr=None, simfd=None): i = self.parser.state.solver.Extract(fmt_spec.size*8-1, 0, i) self.parser.state.memory.store(dest, i, size=fmt_spec.size, endness=self.parser.state.arch.memory_endness) - argpos += 1 + num_args += 1 if simfd is not None: _, realsize = simfd.read_data(position - addr) self.state.add_constraints(realsize == position - addr) - return (argpos - startpos) - failed + return num_args - failed def __repr__(self): outstr = "" @@ -569,17 +565,15 @@ def _sim_strlen(self, str_addr): return self.inline_call(strlen, str_addr).ret_expr - def _parse(self, fmt_idx): + def _parse(self, fmtstr_ptr): """ Parse format strings. - :param fmt_idx: The index of the (pointer to the) format string in the arguments list. + :param fmt_idx: The pointer to the format string from the arguments list. 
:returns: A FormatString object which can be used for replacing the format specifiers with arguments or for scanning into arguments. """ - fmtstr_ptr = self.arg(fmt_idx) - if self.state.solver.symbolic(fmtstr_ptr): raise SimProcedureError("Symbolic pointer to (format) string :(") diff --git a/angr/procedures/stubs/syscall_stub.py b/angr/procedures/stubs/syscall_stub.py index 3fc5b58dfa2..619519c0cb1 100644 --- a/angr/procedures/stubs/syscall_stub.py +++ b/angr/procedures/stubs/syscall_stub.py @@ -7,7 +7,7 @@ #pylint:disable=redefined-builtin,arguments-differ class syscall(angr.SimProcedure): - def run(self, resolves=None): + def run(self, *args, resolves=None): self.resolves = resolves # pylint:disable=attribute-defined-outside-init diff --git a/angr/sim_procedure.py b/angr/sim_procedure.py index 372191ed4c1..f65e2491323 100644 --- a/angr/sim_procedure.py +++ b/angr/sim_procedure.py @@ -1,4 +1,5 @@ import inspect +import typing import copy import itertools import logging @@ -8,6 +9,7 @@ if TYPE_CHECKING: import angr + import archinfo l = logging.getLogger(name=__name__) symbolic_count = itertools.count() @@ -60,7 +62,7 @@ class SimProcedure: :ivar kwargs: Any extra keyword arguments used to construct the procedure; will be passed to ``run`` :ivar display_name: See the eponymous parameter :ivar library_name: See the eponymous parameter - :ivar abi: + :ivar abi: If this is a syscall simprocedure, which ABI are we using to map the syscall numbers? :ivar symbolic_return: See the eponymous parameter :ivar syscall_number: If this procedure is a syscall, the number will be populated here. 
:ivar returns: See eponymous parameter and NO_RET cvar @@ -79,6 +81,8 @@ class SimProcedure: :ivar state: The SimState we should be mutating to perform the procedure :ivar successors: The SimSuccessors associated with the current step :ivar arguments: The function arguments, deserialized from the state + :ivar arg_session: The ArgSession that was used to parse arguments out of the state, in case you need it for + varargs :ivar use_state_arguments: Whether we're using arguments extracted from the state or manually provided :ivar ret_to: The current return address @@ -86,19 +90,24 @@ class SimProcedure: :ivar call_ret_expr: The return value from having used ``self.call()`` :ivar inhibit_autoret: Whether we should avoid automatically adding an exit for returning once the run function ends + :ivar arg_session: The ArgSession object that was used to extract the runtime argument values. Useful for if + you want to extract variadic args. """ def __init__( - self, project=None, cc=None, symbolic_return=None, + self, project=None, cc=None, prototype=None, symbolic_return=None, returns=None, is_syscall=False, is_stub=False, num_args=None, display_name=None, library_name=None, is_function=None, **kwargs ): # WE'LL FIGURE IT OUT self.project = project # type: angr.Project - self.arch = project.arch if project is not None else None + self.arch = project.arch if project is not None else None # type: archinfo.arch.Arch self.addr = None self.cc = cc # type: angr.SimCC + if type(prototype) is str: + prototype = parse_signature(prototype) + self.prototype = prototype # type: angr.sim_type.SimTypeFunction self.canonical = self self.kwargs = kwargs @@ -124,6 +133,10 @@ def __init__( else: self.num_args = num_args + if self.prototype is None: + charp = SimTypePointer(SimTypeChar()) + self.prototype = SimTypeFunction([charp] * self.num_args, charp) + # runtime values self.state = None self.successors = None @@ -133,6 +146,7 @@ def __init__( self.ret_expr = None self.call_ret_expr = None 
self.inhibit_autoret = None + self.arg_session: typing.Union[None, ArgSession, int] = None def __repr__(self): return "" % self._describe_me() @@ -169,6 +183,8 @@ def execute(self, state, successors=None, arguments=None, ret_to=None): else: raise SimProcedureError('There is no default calling convention for architecture %s.' ' You must specify a calling convention.' % self.arch.name) + if self.prototype._arch is None: + self.prototype = self.prototype.with_arch(self.arch) inst = copy.copy(self) inst.state = state @@ -218,12 +234,14 @@ def execute(self, state, successors=None, arguments=None, ret_to=None): else: if arguments is None: inst.use_state_arguments = True - sim_args = [ inst.arg(_) for _ in range(inst.num_args) ] + inst.arg_session = inst.cc.arg_session(inst.prototype.returnty) + sim_args = [inst.cc.next_arg(inst.arg_session, ty).get_value(inst.state) for ty in inst.prototype.args] inst.arguments = sim_args else: inst.use_state_arguments = False sim_args = arguments[:inst.num_args] inst.arguments = arguments + inst.arg_session = 0 # run it l.debug("Executing %s%s%s%s%s with %s, %s", *(inst._describe_me() + (sim_args, inst.kwargs))) @@ -309,30 +327,26 @@ def _compute_ret_addr(self, expr): #pylint:disable=unused-argument,no-self-use raise SimProcedureError("the java-specific _compute_ret_addr() method was invoked on a non-Java SimProcedure.") def set_args(self, args): - arg_session = self.cc.arg_session - for arg in args: - if self.cc.is_fp_value(args): - arg_session.next_arg(True).set_value(self.state, arg) - else: - arg_session.next_arg(False).set_value(self.state, arg) + arg_session = self.cc.arg_session(self.prototype.returnty) + for arg, ty in zip(args, self.prototype.args): + self.cc.next_arg(arg_session, ty).set_value(self.state, arg) - def arg(self, i): - """ - Returns the ith argument. Raise a SimProcedureArgumentError if we don't have such an argument available. 
+ def va_arg(self, ty, index=None): + if not self.use_state_arguments: + if index is not None: + return self.arguments[self.num_args + index] + + result = self.arguments[self.num_args + self.arg_session] + self.arg_session += 1 + return result - :param int i: The index of the argument to get - :return: The argument - :rtype: object - """ - if self.use_state_arguments: - r = self.cc.arg(self.state, i) - else: - if i >= len(self.arguments): - raise SimProcedureArgumentError("Argument %d does not exist." % i) - r = self.arguments[i] # pylint: disable=unsubscriptable-object - l.debug("returning argument") - return r + if index is not None: + raise Exception("you think you're so fucking smart? you implement this logic then") + + if type(ty) is str: + ty = parse_type(ty, arch=self.arch) + return self.cc.next_arg(self.arg_session, ty).get_value(self.state) # # Control Flow @@ -385,15 +399,7 @@ def ret(self, expr=None): if isinstance(self.addr, SootAddressDescriptor): ret_addr = self._compute_ret_addr(expr) #pylint:disable=assignment-from-no-return elif self.use_state_arguments: - if self.cc.args is not None: - arg_types = [isinstance(arg, (SimTypeFloat, SimTypeDouble)) for arg in self.cc.args] - else: - # fall back to using self.num_args - arg_types = [False] * self.num_args - ret_addr = self.cc.teardown_callsite( - self.state, - expr, - arg_types=arg_types) + ret_addr = self.cc.teardown_callsite(self.state, expr, prototype=self.prototype) if not self.should_add_successors: l.debug("Returning without setting exits due to 'internal' call.") @@ -411,7 +417,7 @@ def ret(self, expr=None): self.successors.add_successor(self.state, ret_addr, self.state.solver.true, 'Ijk_Ret') - def call(self, addr, args, continue_at, cc=None): + def call(self, addr, args, continue_at, cc=None, prototype=None): """ Add an exit representing calling another function via pointer. @@ -421,11 +427,13 @@ def call(self, addr, args, continue_at, cc=None): procedure will continue in the named method. 
:param cc: Optional: use this calling convention for calling the new function. Default is to use the current convention. + :param prototype: Optional: The prototype to use for the call. Will default to all-ints. """ self.inhibit_autoret = True if cc is None: cc = self.cc + prototype = cc.guess_prototype(args, prototype) call_state = self.state.copy() ret_addr = self.make_continuation(continue_at) @@ -435,7 +443,7 @@ def call(self, addr, args, continue_at, cc=None): saved_local_vars, self.state.regs.lr if self.state.arch.lr_offset is not None else None, ret_addr) - cc.setup_callsite(call_state, ret_addr, args) + cc.setup_callsite(call_state, ret_addr, args, prototype) call_state.callstack.top.procedure_data = simcallstack_entry # TODO: Move this to setup_callsite? @@ -453,7 +461,7 @@ def call(self, addr, args, continue_at, cc=None): if o.DO_RET_EMULATION in self.state.options: # we need to set up the call because the continuation will try to tear it down ret_state = self.state.copy() - cc.setup_callsite(ret_state, ret_addr, args) + cc.setup_callsite(ret_state, ret_addr, args, prototype) ret_state.callstack.top.procedure_data = simcallstack_entry guard = ret_state.solver.true if o.TRUE_RET_EMULATION_GUARD in ret_state.options else ret_state.solver.false self.successors.add_successor(ret_state, ret_addr, guard, 'Ijk_FakeRet') @@ -504,7 +512,7 @@ def argument_types(self): # pylint: disable=no-self-use @argument_types.setter def argument_types(self, v): # pylint: disable=unused-argument,no-self-use - l.critical("SimProcedure.argument_types is deprecated. specify the function signature in the cc") + l.critical("SimProcedure.argument_types is deprecated. 
specify the function signature in the prototype param") @property def return_type(self): # pylint: disable=no-self-use @@ -512,12 +520,12 @@ def return_type(self): # pylint: disable=no-self-use @return_type.setter def return_type(self, v): # pylint: disable=unused-argument,no-self-use - l.critical("SimProcedure.return_type is deprecated. specify the function signature in the cc") + l.critical("SimProcedure.return_type is deprecated. specify the function signature in the prototype param") from . import sim_options as o from angr.errors import SimProcedureError, SimProcedureArgumentError, SimShadowStackError -from angr.sim_type import SimTypePointer from angr.state_plugins.sim_action import SimActionExit -from angr.calling_conventions import DEFAULT_CC, SimTypeFloat, SimTypeDouble +from angr.calling_conventions import DEFAULT_CC, SimTypeFloat, SimTypeFunction, SimTypePointer, SimTypeChar, ArgSession from .state_plugins import BP_AFTER, BP_BEFORE, NO_OVERRIDE +from .sim_type import parse_signature, parse_type diff --git a/angr/sim_type.py b/angr/sim_type.py index b9390d552e9..37d41fcf0ad 100644 --- a/angr/sim_type.py +++ b/angr/sim_type.py @@ -119,6 +119,15 @@ def c_repr(self, name=None, full=0, memo=None, indent=0): def copy(self): raise NotImplementedError() + def extract_claripy(self, bits): + """ + Given a bitvector `bits` which was loaded from memory in a big-endian fashion, return a more appropriate or + structured representation of the data. + + A type must have an arch associated in order to use this method. + """ + raise NotImplementedError(f"extract_claripy is not implemented for {self}") + class TypeRef(SimType): """ A TypeRef is a reference to a type with a name. 
This allows for interactivity in type analysis, by storing a type @@ -1067,6 +1076,10 @@ def __init__(self, fields: Union[Dict[str,SimType], OrderedDict], name=None, pac self._arch_memo = {} + @property + def packed(self): + return self._pack + @property def offsets(self) -> Dict[str,int]: offsets = {} @@ -1209,8 +1222,15 @@ def __init__(self, struct, values=None): :param values: A mapping from struct fields to values """ self._struct = struct + # since the keys are specified, also support specifying the values as just a list + if values is not None and type(values) is not dict and hasattr(values, '__iter__'): + values = dict(zip(struct.fields.keys(), values)) self._values = defaultdict(lambda: None, values or ()) + @property + def struct(self): + return self._struct + def __indented_repr__(self, indent=0): fields = [] for name in self._struct.fields: @@ -1651,6 +1671,22 @@ def do_preprocess(defn, include_path=()): return ''.join(tok.value for tok in p.parser if tok.type not in p.ignore) +def parse_signature(defn, preprocess=True, predefined_types=None, arch=None): + """ + Parse a single function prototype and return its type + """ + try: + parsed = parse_file( + defn.strip(' \n\t;') + ';', + preprocess=preprocess, + predefined_types=predefined_types, + arch=arch + ) + return next(iter(parsed[0].values())) + except StopIteration as e: + raise ValueError("No declarations found") from e + + def parse_defns(defn, preprocess=True, predefined_types=None, arch=None): """ Parse a series of C definitions, returns a mapping from variable name to variable type object @@ -1688,7 +1724,7 @@ def parse_file(defn, preprocess=True, predefined_types=None, arch=None): if isinstance(piece, pycparser.c_ast.FuncDef): out[piece.decl.name] = _decl_to_type(piece.decl.type, extra_types, arch=arch) elif isinstance(piece, pycparser.c_ast.Decl): - ty = _decl_to_type(piece.type, extra_types) + ty = _decl_to_type(piece.type, extra_types, arch=arch) if piece.name is not None: out[piece.name] = 
ty @@ -1702,7 +1738,7 @@ def parse_file(defn, preprocess=True, predefined_types=None, arch=None): i.members = ty.members elif isinstance(piece, pycparser.c_ast.Typedef): - extra_types[piece.name] = copy.copy(_decl_to_type(piece.type, extra_types)) + extra_types[piece.name] = copy.copy(_decl_to_type(piece.type, extra_types, arch=arch)) extra_types[piece.name].label = piece.name return out, extra_types diff --git a/angr/simos/javavm.py b/angr/simos/javavm.py index b8fafc69ce7..a06b2cebca6 100644 --- a/angr/simos/javavm.py +++ b/angr/simos/javavm.py @@ -6,6 +6,7 @@ SootNullConstant) from claripy import BVS, BVV, StringS, StringV, FSORT_FLOAT, FSORT_DOUBLE, FPV, FPS from claripy.ast.fp import FP, fpToIEEEBV +from claripy.ast.bv import BV from ..calling_conventions import DEFAULT_CC, SimCCSoot from ..engines.soot import SootMixin @@ -17,7 +18,7 @@ from ..errors import AngrSimOSError from ..procedures.java_jni import jni_functions from ..sim_state import SimState -from ..sim_type import SimTypeFunction, SimTypeReg +from ..sim_type import SimTypeFunction, SimTypeNum from .simos import SimOS l = logging.getLogger('angr.simos.JavaVM') @@ -212,14 +213,29 @@ def state_call(self, addr, *args, **kwargs): else: state = state.copy() state.regs.ip = addr - cc.setup_callsite(state, ret_addr, args) + cc.setup_callsite(state, ret_addr, args, kwargs.pop('prototype', None)) return state else: # NATIVE CALLSITE + + # setup native return type + # TODO roll this into protytype + ret_type = kwargs.pop('ret_type') + native_ret_type = self.get_native_type(ret_type) + + # setup function prototype, so the SimCC know how to init the callsite + prototype = kwargs.pop('prototype', None) + if prototype is None: + arg_types = [self.get_native_type(arg.type) for arg in args] + prototype = SimTypeFunction(args=arg_types, returnty=native_ret_type) + native_cc = kwargs.pop('cc', None) + if native_cc is None: + native_cc = self.get_native_cc() + # setup native argument values native_arg_values = [] - 
for arg in args: + for arg, arg_ty in zip(args, prototype.args): if arg.type in ArchSoot.primitive_types or \ arg.type == "JNIEnv": # the value of primitive types and the JNIEnv pointer @@ -228,7 +244,9 @@ def state_call(self, addr, *args, **kwargs): if self.arch.bits == 32 and arg.type == "long": # On 32 bit architecture, long values (w/ 64 bit) are copied # as two 32 bit integer - # TODO is this correct? + # TODO I _think_ all this logic can go away as long as the cc knows how to store large values + # TODO this has been mostly implemented 11 Dec 2021 + # unfortunately no test cases hit this branch so I don't wanna touch it :( upper = native_arg_value.get_bytes(0, 4) lower = native_arg_value.get_bytes(4, 4) idx = args.index(arg) @@ -237,6 +255,9 @@ def state_call(self, addr, *args, **kwargs): + args[idx+1:] native_arg_values += [upper, lower] continue + if type(arg.value) is BV and len(arg.value) > arg_ty.size: + # hack??? all small primitives are passed around as 32bit but cc won't like that + native_arg_value = native_arg_value[arg_ty.size - 1:0] else: # argument has a relative type # => map Java reference to an opaque reference, which the native code @@ -244,20 +265,11 @@ def state_call(self, addr, *args, **kwargs): native_arg_value = state.jni_references.create_new_reference(obj=arg.value) native_arg_values += [native_arg_value] - # setup native return type - ret_type = kwargs.pop('ret_type') - native_ret_type = self.get_native_type(ret_type) - - # setup function prototype, so the SimCC know how to init the callsite - arg_types = [self.get_native_type(arg.type) for arg in args] - prototype = SimTypeFunction(args=arg_types, returnty=native_ret_type) - native_cc = self.get_native_cc(func_ty=prototype) - # setup native invoke state return self.native_simos.state_call(addr, *native_arg_values, base_state=state, ret_addr=self.native_return_hook_addr, - cc=native_cc, **kwargs) + cc=native_cc, prototype=prototype, **kwargs) # # MISC @@ -269,9 +281,9 @@ def 
get_default_value_by_type(type_, state): Java specify defaults values for primitive and reference types. This method returns the default value for a given type. - :param str type_: Name of type. - :param str state: Current SimState. - :return: Default value for this type. + :param str type_: Name of type. + :param SimState state: Current SimState. + :return: Default value for this type. """ if options.ZERO_FILL_UNCONSTRAINED_MEMORY not in state.options: return SimJavaVM._get_default_symbolic_value_by_type(type_, state) @@ -420,7 +432,10 @@ def get_native_type(self, java_type): else: # if it's not a primitive type, we treat it as a reference jni_type_size = self.native_simos.arch.bits - return SimTypeReg(size=jni_type_size) + return SimTypeNum(size=jni_type_size) + + def get_method_native_type(self, method): + return SimTypeFunction @property def native_arch(self): @@ -429,12 +444,12 @@ def native_arch(self): """ return self.native_simos.arch - def get_native_cc(self, func_ty=None): + def get_native_cc(self): """ :return: SimCC object for the native simos. 
""" native_cc_cls = DEFAULT_CC[self.native_simos.arch.name] - return native_cc_cls(self.native_simos.arch, func_ty=func_ty) + return native_cc_cls(self.native_simos.arch) def prepare_native_return_state(native_state): """ diff --git a/angr/simos/linux.py b/angr/simos/linux.py index 5ae027fd8e1..8a0430db1aa 100644 --- a/angr/simos/linux.py +++ b/angr/simos/linux.py @@ -257,9 +257,9 @@ def state_entry(self, args=None, env=None, argc=None, **kwargs): # Prepare argc if argc is None: - argc = claripy.BVV(len(args), state.arch.bits) + argc = claripy.BVV(len(args), 32) elif type(argc) is int: # pylint: disable=unidiomatic-typecheck - argc = claripy.BVV(argc, state.arch.bits) + argc = claripy.BVV(argc, 32) # Make string table for args/env/auxv table = StringTableSpec() @@ -373,6 +373,14 @@ def set_entry_register_values(self, state): else: _l.error('What the ass kind of default value is %s?', val) + if state.arch.name == 'PPC64': + # store argc at the top of the stack if the program is statically linked, otherwise 0 + # see sysdeps/powerpc/powerpc64/dl-machine.h, _dl_start_user + #stack_top = state.posix.argc.sign_extend(32) if state.project.loader.linux_loader_object is None else 0 + # UMMMMMM actually nvm we're going to lie about it + stack_top = state.posix.argc.sign_extend(32) + state.mem[state.regs.sp].qword = stack_top + def state_full_init(self, **kwargs): kwargs['addr'] = self._loader_addr return super(SimLinux, self).state_full_init(**kwargs) diff --git a/angr/simos/simos.py b/angr/simos/simos.py index 13ee143d691..dde5484e462 100644 --- a/angr/simos/simos.py +++ b/angr/simos/simos.py @@ -58,7 +58,12 @@ def irelative_resolver(resolver_addr): base_state = self.state_blank(addr=0, add_options={o.SYMBOL_FILL_UNCONSTRAINED_MEMORY, o.SYMBOL_FILL_UNCONSTRAINED_REGISTERS}) - resolver = self.project.factory.callable(resolver_addr, concrete_only=True, base_state=base_state) + prototype = 'void *x(long)' if isinstance(self.arch, ArchS390X) else 'void *x(void)' + resolver = 
self.project.factory.callable( + resolver_addr, + concrete_only=True, + base_state=base_state, + prototype=prototype) try: if isinstance(self.arch, ArchS390X): # On s390x ifunc resolvers expect hwcaps. @@ -246,6 +251,7 @@ def state_call(self, addr, *args, **kwargs): stack_base = kwargs.pop('stack_base', None) alloc_base = kwargs.pop('alloc_base', None) grow_like_stack = kwargs.pop('grow_like_stack', True) + prototype = angr.calling_conventions.SimCC.guess_prototype(args, kwargs.pop('prototype', None)).with_arch(self.arch) if state is None: if stack_base is not None: @@ -254,7 +260,7 @@ def state_call(self, addr, *args, **kwargs): else: state = state.copy() state.regs.ip = addr - cc.setup_callsite(state, ret_addr, args, stack_base, alloc_base, grow_like_stack) + cc.setup_callsite(state, ret_addr, args, prototype, stack_base, alloc_base, grow_like_stack) if state.arch.name == 'PPC64' and toc is not None: state.regs.r2 = toc diff --git a/angr/simos/userland.py b/angr/simos/userland.py index 997199d1dca..35e0faa6dff 100644 --- a/angr/simos/userland.py +++ b/angr/simos/userland.py @@ -81,8 +81,6 @@ def syscall(self, state, allow_unsupported=True): raise AngrUnsupportedSyscallError("Got a symbolic syscall number") proc = self.syscall_from_number(num, allow_unsupported=allow_unsupported, abi=abi) - if proc.cc is not None: - cc.func_ty = proc.cc.func_ty proc.cc = cc return proc diff --git a/tests/test_callable.py b/tests/test_callable.py index 6b5c3b79b23..701c9c33f5c 100644 --- a/tests/test_callable.py +++ b/tests/test_callable.py @@ -38,8 +38,7 @@ def run_fauxware(arch): p = angr.Project(os.path.join(location, 'tests', arch, 'fauxware')) charstar = SimTypePointer(SimTypeChar()) prototype = SimTypeFunction((charstar, charstar), SimTypeInt(False)) - cc = p.factory.cc(func_ty=prototype) - authenticate = p.factory.callable(addr, toc=0x10018E80 if arch == 'ppc64' else None, concrete_only=True, cc=cc) + authenticate = p.factory.callable(addr, toc=0x10018E80 if arch == 'ppc64' 
else None, concrete_only=True, prototype=prototype) nose.tools.assert_equal(authenticate("asdf", "SOSNEAKY")._model_concrete.value, 1) nose.tools.assert_raises(AngrCallableMultistateError, authenticate, "asdf", "NOSNEAKY") @@ -47,8 +46,7 @@ def run_fauxware(arch): def run_callable_c_fauxware(arch): addr = addresses_fauxware[arch] p = angr.Project(os.path.join(location, 'tests', arch, 'fauxware')) - cc = p.factory.cc(func_ty="int f(char*, char*)") - authenticate = p.factory.callable(addr, toc=0x10018E80 if arch == 'ppc64' else None, concrete_only=True, cc=cc) + authenticate = p.factory.callable(addr, toc=0x10018E80 if arch == 'ppc64' else None, concrete_only=True, prototype="int f(char*, char*)") retval = authenticate.call_c('("asdf", "SOSNEAKY")') nose.tools.assert_equal(retval._model_concrete.value, 1) nose.tools.assert_raises(AngrCallableMultistateError, authenticate, "asdf", "NOSNEAKY") @@ -59,8 +57,7 @@ def run_manysum(arch): p = angr.Project(os.path.join(location, 'tests', arch, 'manysum')) inttype = SimTypeInt() prototype = SimTypeFunction([inttype]*11, inttype) - cc = p.factory.cc(func_ty=prototype) - sumlots = p.factory.callable(addr, cc=cc) + sumlots = p.factory.callable(addr, prototype=prototype) result = sumlots(1,2,3,4,5,6,7,8,9,10,11) nose.tools.assert_false(result.symbolic) nose.tools.assert_equal(result._model_concrete.value, sum(range(12))) @@ -69,8 +66,7 @@ def run_manysum(arch): def run_callable_c_manysum(arch): addr = addresses_manysum[arch] p = angr.Project(os.path.join(location, 'tests', arch, 'manysum')) - cc = p.factory.cc(func_ty="int f(int, int, int, int, int, int, int, int, int, int, int)") - sumlots = p.factory.callable(addr, cc=cc) + sumlots = p.factory.callable(addr, prototype="int f(int, int, int, int, int, int, int, int, int, int, int)") result = sumlots.call_c("(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)") nose.tools.assert_false(result.symbolic) nose.tools.assert_equal(result._model_concrete.value, sum(range(12))) @@ -85,11 +81,10 @@ def 
run_manyfloatsum(arch): p = angr.Project(os.path.join(location, 'tests', arch, 'manyfloatsum')) for function in ('sum_floats', 'sum_combo', 'sum_segregated', 'sum_doubles', 'sum_combo_doubles', 'sum_segregated_doubles'): - cc = p.factory.cc(func_ty=type_cache[function]) - args = list(range(len(cc.func_ty.args))) + args = list(range(len(type_cache[function].args))) answer = float(sum(args)) addr = p.loader.main_object.get_symbol(function).rebased_addr - my_callable = p.factory.callable(addr, cc=cc) + my_callable = p.factory.callable(addr, prototype=type_cache[function]) result = my_callable(*args) nose.tools.assert_false(result.symbolic) result_concrete = result.args[0] @@ -104,10 +99,9 @@ def run_manyfloatsum_symbolic(arch): p = angr.Project(os.path.join(location, 'tests', arch, 'manyfloatsum')) function = 'sum_doubles' - cc = p.factory.cc(func_ty=type_cache[function]) args = [claripy.FPS('arg_%d' % i, claripy.FSORT_DOUBLE) for i in range(len(type_cache[function].args))] addr = p.loader.main_object.get_symbol(function).rebased_addr - my_callable = p.factory.callable(addr, cc=cc) + my_callable = p.factory.callable(addr, prototype=type_cache[function]) result = my_callable(*args) nose.tools.assert_true(result.symbolic) @@ -156,12 +150,12 @@ def test_callable_c_manyfloatsum(): def test_setup_callsite(): p = angr.load_shellcode(b'b', arch=archinfo.ArchX86()) - s = p.factory.call_state(0, "hello", stack_base=0x1234, alloc_base=0x5678, grow_like_stack=False) + s = p.factory.call_state(0, "hello", prototype='void x(char*)', stack_base=0x1234, alloc_base=0x5678, grow_like_stack=False) assert (s.regs.sp == 0x1234).is_true() assert (s.mem[0x1234 + 4].long.resolved == 0x5678).is_true() assert (s.memory.load(0x5678, 5) == b'hello').is_true() - s = p.factory.call_state(0, "hello", stack_base=0x1234) + s = p.factory.call_state(0, "hello", prototype='void x(char*)', stack_base=0x1234) assert (s.regs.sp == 0x1234).is_true() assert (s.mem[0x1234 + 4].long.resolved == 0x1234 + 
8).is_true() assert (s.memory.load(0x1234 + 8, 5) == b'hello').is_true() @@ -169,10 +163,10 @@ def test_setup_callsite(): if __name__ == "__main__": - print('testing manyfloatsum with symbolic arguments') - for func, march in test_manyfloatsum_symbolic(): - print('* testing ' + march) - func(march) + #print('testing manyfloatsum with symbolic arguments') + #for func, march in test_manyfloatsum_symbolic(): + # print('* testing ' + march) + # func(march) print('testing manyfloatsum') for func, march in test_manyfloatsum(): print('* testing ' + march) diff --git a/tests/test_calling_convention_analysis.py b/tests/test_calling_convention_analysis.py index 51c66415035..73f22bc7f6b 100644 --- a/tests/test_calling_convention_analysis.py +++ b/tests/test_calling_convention_analysis.py @@ -46,19 +46,10 @@ def test_fauxware(): args = { 'i386': [ - ('authenticate', SimCCCdecl( - archinfo.arch_from_id('i386'), - args=[SimStackArg(4, 4), SimStackArg(8, 4)], sp_delta=4, ret_val=SimRegArg('eax', 4), - ) - ), + ('authenticate', SimCCCdecl( archinfo.arch_from_id('i386'), ) ), ], 'x86_64': [ - ('authenticate', SimCCSystemVAMD64( - amd64, - args=[SimRegArg('rdi', 8), SimRegArg('rsi', 8)], - sp_delta=8, - ret_val=SimRegArg('rax', 8), - ) + ('authenticate', SimCCSystemVAMD64( amd64, ) ), ], } @@ -107,7 +98,8 @@ def check_args(func_name, args, expected_arg_strs): def _a(funcs, func_name): - return funcs[func_name].calling_convention.args + func = funcs[func_name] + return func.calling_convention.arg_locs(func.prototype) def test_x8664_dir_gcc_O0(): @@ -189,10 +181,11 @@ def test_x8664_void(): if func.name in groundtruth: r = groundtruth[func.name] if r is None: - assert func.calling_convention.ret_val is None + assert func.prototype.returnty is None else: - assert isinstance(func.calling_convention.ret_val, SimRegArg) - assert func.calling_convention.ret_val.reg_name == r + ret_val = func.calling_convention.return_val(func.prototype.returnty) + assert isinstance(ret_val, SimRegArg) + 
assert ret_val.reg_name == r def test_x86_saved_regs(): @@ -208,27 +201,30 @@ def test_x86_saved_regs(): proj.analyses.VariableRecoveryFast(func) cca = proj.analyses.CallingConvention(func) cc = cca.cc + prototype = cca.prototype assert cc is not None, "Calling convention analysis failed to determine the calling convention of function " \ "0x80494f0." assert isinstance(cc, SimCCCdecl) - assert len(cc.args) == 3 - assert cc.args[0] == SimStackArg(4, 4) - assert cc.args[1] == SimStackArg(8, 4) - assert cc.args[2] == SimStackArg(12, 4) + assert len(prototype.args) == 3 + arg_locs = cc.arg_locs(prototype) + assert arg_locs[0] == SimStackArg(4, 4) + assert arg_locs[1] == SimStackArg(8, 4) + assert arg_locs[2] == SimStackArg(12, 4) func_exit = cfg.functions[0x804a1a9] # exit proj.analyses.VariableRecoveryFast(func_exit) cca = proj.analyses.CallingConvention(func_exit) cc = cca.cc + prototype = cca.prototype assert func_exit.returning is False assert cc is not None, "Calling convention analysis failed to determine the calling convention of function " \ "0x804a1a9." 
assert isinstance(cc, SimCCCdecl) - assert len(cc.args) == 1 - assert cc.args[0] == SimStackArg(4, 4) + assert len(prototype.args) == 1 + assert cc.arg_locs(prototype)[0] == SimStackArg(4, 4) def test_callsite_inference_amd64(): @@ -241,7 +237,7 @@ def test_callsite_inference_amd64(): func = cfg.functions.function(name='mosquitto_publish', plt=True) cca = proj.analyses.CallingConvention(func) - assert len(cca.cc.args) == 6 + assert len(cca.prototype.args) == 6 def run_all(): diff --git a/tests/test_callsite_maker.py b/tests/test_callsite_maker.py index 67f45381e51..37f6ecf0ffd 100644 --- a/tests/test_callsite_maker.py +++ b/tests/test_callsite_maker.py @@ -23,6 +23,7 @@ def test_callsite_maker(): cc_analysis = project.analyses.CallingConvention(func) if cc_analysis.cc is not None: func.calling_convention = cc_analysis.cc + func.prototype = cc_analysis.prototype new_cc_found = True main_func = cfg.kb.functions['main'] diff --git a/tests/test_cc.py b/tests/test_cc.py deleted file mode 100644 index 9d2b38d3f61..00000000000 --- a/tests/test_cc.py +++ /dev/null @@ -1,46 +0,0 @@ -import nose -from angr import SimState, SIM_PROCEDURES - -FAKE_ADDR = 0x100000 - -def test_calling_conventions(): - - # - # SimProcedures - # - - from angr.calling_conventions import SimCCCdecl, SimCCMicrosoftFastcall - - args = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 1000, 100000, 1000000, 2000000, 14, 15, 16 ] - arches = [ - ('X86', SimCCCdecl), - ('X86', SimCCMicrosoftFastcall), - ('AMD64', None), - ('ARMEL', None), - ('MIPS32', None), - ('PPC32', None), - ('PPC64', None), - ] - - # x86, cdecl - for arch, cc in arches: - s = SimState(arch=arch) - for reg, val, _, _ in s.arch.default_register_values: - s.registers.store(reg, val) - - if cc is not None: - manyargs = SIM_PROCEDURES['testing']['manyargs'](cc=cc(s.arch)).execute(s) - else: - manyargs = SIM_PROCEDURES['testing']['manyargs']().execute(s) - - # Simulate a call - if s.arch.call_pushes_ret: - s.regs.sp = s.regs.sp + s.arch.stack_change - 
manyargs.set_args(args) - - - for index, arg in enumerate(args): - nose.tools.assert_true(s.solver.is_true(manyargs.arg(index) == arg)) - -if __name__ == '__main__': - test_calling_conventions() diff --git a/tests/test_decompiler.py b/tests/test_decompiler.py index 67ff437a836..d58086e5e0a 100644 --- a/tests/test_decompiler.py +++ b/tests/test_decompiler.py @@ -362,7 +362,9 @@ def test_decompiling_1after909_verify_password(): f = cfg.functions['verify_password'] # recover calling convention p.analyses.VariableRecoveryFast(f) - f.calling_convention = p.analyses.CallingConvention(f).cc + cca = p.analyses.CallingConvention(f) + f.calling_convention = cca.cc + f.prototype = cca.prototype dec = p.analyses.Decompiler(f, cfg=cfg.model) if dec.codegen is None: print("Failed to decompile function %r." % f) @@ -485,6 +487,7 @@ def test_decompiling_strings_local_strlen(): _ = p.analyses.VariableRecoveryFast(func) cca = p.analyses.CallingConvention(func, cfg=cfg.model) func.calling_convention = cca.cc + func.prototype = cca.prototype dec = p.analyses.Decompiler(func, cfg=cfg.model) assert dec.codegen is not None, "Failed to decompile function %r." % func @@ -506,6 +509,7 @@ def test_decompiling_strings_local_strcat(): _ = p.analyses.VariableRecoveryFast(func) cca = p.analyses.CallingConvention(func, cfg=cfg.model) func.calling_convention = cca.cc + func.prototype = cca.prototype dec = p.analyses.Decompiler(func, cfg=cfg.model) assert dec.codegen is not None, "Failed to decompile function %r." 
% func @@ -527,6 +531,7 @@ def test_decompiling_strings_local_strcat_with_local_strlen(): _ = p.analyses.VariableRecoveryFast(func_strlen) cca = p.analyses.CallingConvention(func_strlen, cfg=cfg.model) func_strlen.calling_convention = cca.cc + func_strlen.prototype = cca.prototype p.analyses.Decompiler(func_strlen, cfg=cfg.model) func = cfg.functions['local_strcat'] @@ -534,6 +539,7 @@ def test_decompiling_strings_local_strcat_with_local_strlen(): _ = p.analyses.VariableRecoveryFast(func) cca = p.analyses.CallingConvention(func, cfg=cfg.model) func.calling_convention = cca.cc + func.prototype = cca.prototype dec = p.analyses.Decompiler(func, cfg=cfg.model) assert dec.codegen is not None, "Failed to decompile function %r." % func diff --git a/tests/test_function.py b/tests/test_function.py index 5f94ddfbf8f..36cc6144911 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -22,6 +22,7 @@ def test_function_serialization(): f = angr.knowledge_plugins.Function.parse(s) nose.tools.assert_equal(func_main.addr, f.addr) nose.tools.assert_equal(func_main.name, f.name) + nose.tools.assert_equal(func_main.is_prototype_guessed, f.is_prototype_guessed) def test_function_definition_application(): p = angr.Project(os.path.join(test_location, 'x86_64', 'fauxware'), auto_load_libs=False) diff --git a/tests/test_java.py b/tests/test_java.py index c5de5f8651a..940622c312d 100644 --- a/tests/test_java.py +++ b/tests/test_java.py @@ -248,7 +248,7 @@ def test_jni_primitive_datatypes(): run_method(project=project, method="MixedJava.test_short", - assert_locals={'s3': 0x1000, 's0': 11, 's5': 0xfffff000, 's9': 0}) + assert_locals={'s3': 0x1000, 's5': 0xfffff000, 's0': 11, 's9': 0}) run_method(project=project, method="MixedJava.test_int", diff --git a/tests/test_prototypes.py b/tests/test_prototypes.py index 24537f0f7b8..e781edbea9b 100644 --- a/tests/test_prototypes.py +++ b/tests/test_prototypes.py @@ -15,10 +15,7 @@ def test_function_prototype(): func = 
angr.knowledge_plugins.Function(proj.kb.functions, 0x100000, name='strcmp') func.prototype = angr.SIM_LIBRARIES['libc.so.6'].prototypes[func.name] - func.calling_convention = angr.calling_conventions.DEFAULT_CC[proj.arch.name]( - proj.arch, - func_ty=func.prototype, - ) + func.calling_convention = angr.calling_conventions.DEFAULT_CC[proj.arch.name](proj.arch) def test_find_prototype(): @@ -29,12 +26,9 @@ def test_find_prototype(): func = cfg.kb.functions.function(name='strcmp', plt=False) func.calling_convention = angr.calling_conventions.DEFAULT_CC[proj.arch.name](proj.arch) - # Calling SimCC.arg_locs() should fail when the function prototype is not provided. - nose.tools.assert_raises(ValueError, func.calling_convention.arg_locs) - func.find_declaration() - arg_locs = func.calling_convention.arg_locs() # now it won't fail + arg_locs = func.calling_convention.arg_locs(func.prototype) nose.tools.assert_equal(len(arg_locs), 2) nose.tools.assert_equal(arg_locs[0].reg_name, 'rdi') diff --git a/tests/test_sim_procedure.py b/tests/test_sim_procedure.py index 32367af1812..f47ae539f11 100644 --- a/tests/test_sim_procedure.py +++ b/tests/test_sim_procedure.py @@ -7,29 +7,50 @@ BIN_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', 'binaries') def test_ret_float(): - p = angr.load_shellcode(b'X', arch='i386') - class F1(angr.SimProcedure): def run(self): return 12.5 - p.hook(0x1000, F1(cc=p.factory.cc(func_ty=angr.sim_type.parse_file('float (x)();')[0]['x']))) - p.hook(0x2000, F1(cc=p.factory.cc(func_ty=angr.sim_type.parse_file('double (x)();')[0]['x']))) + p = angr.load_shellcode(b'X', arch='i386') + + p.hook(0x1000, F1(prototype='float (x)();')) + p.hook(0x2000, F1(prototype='double (x)();')) - s = p.factory.call_state(addr=0x1000, ret_addr=0) + s = p.factory.call_state(addr=0x1000, ret_addr=0, prototype='float(x)()') succ = s.step() nose.tools.assert_equal(len(succ.successors), 1) s2 = succ.flat_successors[0] 
nose.tools.assert_false(s2.regs.st0.symbolic) - nose.tools.assert_equal(s2.solver.eval(s2.regs.st0.get_bytes(4, 4).raw_to_fp()), 12.5) + nose.tools.assert_equal(s2.solver.eval(s2.regs.st0.raw_to_fp()), 12.5) - s = p.factory.call_state(addr=0x2000, ret_addr=0) + s = p.factory.call_state(addr=0x2000, ret_addr=0, prototype='double(x)()') succ = s.step() nose.tools.assert_equal(len(succ.successors), 1) s2 = succ.flat_successors[0] nose.tools.assert_false(s2.regs.st0.symbolic) nose.tools.assert_equal(s2.solver.eval(s2.regs.st0.raw_to_fp()), 12.5) + p = angr.load_shellcode(b'X', arch='amd64') + + p.hook(0x1000, F1(prototype='float (x)();')) + p.hook(0x2000, F1(prototype='double (x)();')) + + s = p.factory.call_state(addr=0x1000, ret_addr=0, prototype='float(x)()') + succ = s.step() + nose.tools.assert_equal(len(succ.successors), 1) + s2 = succ.flat_successors[0] + res = s2.registers.load('xmm0', 4).raw_to_fp() + nose.tools.assert_false(res.symbolic) + nose.tools.assert_equal(s2.solver.eval(res), 12.5) + + s = p.factory.call_state(addr=0x2000, ret_addr=0, prototype='double(x)()') + succ = s.step() + nose.tools.assert_equal(len(succ.successors), 1) + s2 = succ.flat_successors[0] + res = s2.registers.load('xmm0', 8).raw_to_fp() + nose.tools.assert_false(res.symbolic) + nose.tools.assert_equal(s2.solver.eval(res), 12.5) + def test_syscall_and_simprocedure(): bin_path = os.path.join(BIN_PATH, 'tests', 'cgc', 'CADET_00002') proj = angr.Project(bin_path, auto_load_libs=False) diff --git a/tests/test_stack_alignment.py b/tests/test_stack_alignment.py index b5af8828b63..fced45d6fb1 100644 --- a/tests/test_stack_alignment.py +++ b/tests/test_stack_alignment.py @@ -22,7 +22,7 @@ def test_alignment(): st.regs.sp = -1 # setup callsite with one argument (0x1337), "returning" to 0 - cc.setup_callsite(st, 0, [0x1337]) + cc.setup_callsite(st, 0, [0x1337], 'void foo(int x)') # ensure stack alignment is correct nose.tools.assert_true(st.solver.is_true(((st.regs.sp + cc.STACKARG_SP_DIFF) % 
cc.STACK_ALIGNMENT == 0)), @@ -37,7 +37,7 @@ def test_sys_v_abi_compliance(): st.regs.sp = -1 # setup callsite with one argument (0x1337), "returning" to 0 - cc.setup_callsite(st, 0, [0x1337]) + cc.setup_callsite(st, 0, [0x1337], 'void foo(int x)') # (rsp+8) must be aligned to 16 as required by System V ABI. # ref: https://raw.githubusercontent.com/wiki/hjl-tools/x86-psABI/x86-64-psABI-1.0.pdf , page 18t @@ -48,7 +48,7 @@ def test_initial_allocation(): # not strictly about alignment but it's about stack initialization so whatever p = Project(os.path.join(os.path.dirname(__file__), '../../binaries/tests/x86_64/true'), auto_load_libs=False) s = p.factory.entry_state(add_options={o.STRICT_PAGE_ACCESS}) - s.memory.load(s.regs.sp - 0x10000, 4) + s.memory.load(s.regs.sp - 0x10000, size=4) if __name__ == "__main__": test_alignment() diff --git a/tests/test_string.py b/tests/test_string.py index ee53e9c4480..8fa9a0e568b 100644 --- a/tests/test_string.py +++ b/tests/test_string.py @@ -329,7 +329,7 @@ def test_memcpy(): s = SimState(arch="AMD64", mode="symbolic", remove_options=angr.options.simplification) s.memory._maximum_symbolic_size = 0x2000000 size = s.solver.BVV(0x1000000, 64) - data = s.solver.BVS('giant', 8*0x1000000) + data = s.solver.BVS('giant', 8*0x1_000_000) dst_addr = s.solver.BVV(0x2000000, 64) src_addr = s.solver.BVV(0x4000000, 64) s.memory.store(src_addr, data) @@ -920,7 +920,7 @@ def test_strcmp(): s.memory.store(b_addr, b"heck\x00") r = strcmp(s, arguments=[a_addr, b_addr]) - nose.tools.assert_equal(s.solver.eval_upto(r, 2), [0xffffffffffffffff]) + nose.tools.assert_equal(s.solver.eval_upto(r, 2), [0xffffffff]) l.info("empty a, empty b") s = SimState(arch="AMD64", mode="symbolic") diff --git a/tests/test_stub_procedure_args.py b/tests/test_stub_procedure_args.py index 019098276d7..8c676b8b275 100644 --- a/tests/test_stub_procedure_args.py +++ b/tests/test_stub_procedure_args.py @@ -21,8 +21,8 @@ def test_stub_procedure_args(): stub = 
lib.get_stub('____a_random_stdcall_function__', archinfo.ArchX86()) stub.cc = SimCCStdcall(archinfo.ArchX86()) lib._apply_metadata(stub, archinfo.ArchX86()) - assert len(stub.cc.args) == 3 - assert all(isinstance(arg, SimStackArg) for arg in stub.cc.args) + assert len(stub.prototype.args) == 3 + assert all(isinstance(arg, SimStackArg) for arg in stub.cc.arg_locs(stub.prototype)) proj = angr.Project(os.path.join(binaries_base, "i386", "all"), auto_load_libs=False) state = proj.factory.blank_state() diff --git a/tests/test_unicorn.py b/tests/test_unicorn.py index 336b06caa98..9e9a1a8a998 100644 --- a/tests/test_unicorn.py +++ b/tests/test_unicorn.py @@ -46,7 +46,7 @@ def test_stops(): nose.tools.assert_equal(p_normal_angr.history.bbl_addrs.hardcopy, p_normal.history.bbl_addrs.hardcopy) # test STOP_STOPPOINT on an address that is not a basic block start - s_stoppoints = p.factory.call_state(p.loader.find_symbol("main").rebased_addr, 1, [], add_options=so.unicorn) + s_stoppoints = p.factory.call_state(p.loader.find_symbol("main").rebased_addr, 1, angr.PointerWrapper([]), add_options=so.unicorn) # this address is right before/after the bb for the stop_normal() function ends # we should not stop there, since that code is never hit @@ -181,11 +181,10 @@ def test_fp(): p = angr.Project(os.path.join(test_location, 'binaries', 'tests', 'i386', 'manyfloatsum'), auto_load_libs=False) for function in ('sum_floats', 'sum_combo', 'sum_segregated', 'sum_doubles', 'sum_combo_doubles', 'sum_segregated_doubles'): - cc = p.factory.cc(func_ty=type_cache[function]) - args = list(range(len(cc.func_ty.args))) + args = list(range(len(type_cache[function].args))) answer = float(sum(args)) addr = p.loader.find_symbol(function).rebased_addr - my_callable = p.factory.callable(addr, cc=cc) + my_callable = p.factory.callable(addr, prototype=type_cache[function]) my_callable.set_base_state(p.factory.blank_state(add_options=so.unicorn)) result = my_callable(*args) 
nose.tools.assert_false(result.symbolic) @@ -255,7 +254,7 @@ def test_inspect(): def main_state(argc, add_options=None): add_options = add_options or so.unicorn main_addr = p.loader.find_symbol("main").rebased_addr - return p.factory.call_state(main_addr, argc, [], add_options=add_options) + return p.factory.call_state(main_addr, argc, angr.PointerWrapper([]), add_options=add_options) # test breaking on specific addresses s_break_addr = main_state(1) @@ -299,7 +298,7 @@ def test_explore(): def main_state(argc, add_options=None): add_options = add_options or so.unicorn main_addr = p.loader.find_symbol("main").rebased_addr - return p.factory.call_state(main_addr, argc, [], add_options=add_options) + return p.factory.call_state(main_addr, argc, angr.PointerWrapper([]), add_options=add_options) addr = 0x08048479 s_explore = main_state(1) @@ -321,7 +320,7 @@ def test_single_step(): def main_state(argc, add_options=None): add_options = add_options or so.unicorn main_addr = p.loader.find_symbol("main").rebased_addr - return p.factory.call_state(main_addr, argc, [], add_options=add_options) + return p.factory.call_state(main_addr, argc, angr.PointerWrapper([]), add_options=add_options) s_main = main_state(1) @@ -364,4 +363,4 @@ def main_state(argc, add_options=None): fo = ft[0] fa = ft[1:] print('...', fa) - fo(*fa) \ No newline at end of file + fo(*fa)