Skip to content

Commit

Permalink
handle function calls on architectures that save return addresses in …
Browse files Browse the repository at this point in the history
…a register
  • Loading branch information
Kyle-Kyle committed Feb 14, 2024
1 parent a86e4b9 commit 4b30ec2
Show file tree
Hide file tree
Showing 8 changed files with 77 additions and 58 deletions.
4 changes: 4 additions & 0 deletions angrop/chain_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def _build_reg_setting_chain(self, gadgets, modifiable_memory_range, register_di
badbytes=self.badbytes)

# iterate through the stack values that need to be in the chain
# HACK: handle jump register separately because of angrop's broken
# assumptions on x86's ret behavior
if gadgets[-1].transit_type == 'jmp_reg':
stack_change += arch_bytes
for i in range(stack_change // bytes_per_pop):
sym_word = test_symbolic_state.memory.load(sp + bytes_per_pop*i, bytes_per_pop,
endness=self.project.arch.memory_endness)
Expand Down
79 changes: 37 additions & 42 deletions angrop/chain_builder/func_caller.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging

import angr
from angr.calling_conventions import SimRegArg, SimStackArg

from .builder import Builder
from .. import rop_utils
from ..errors import RopException
from ..rop_gadget import RopGadget

Expand All @@ -14,13 +16,19 @@ class FuncCaller(Builder):
calling convention
"""

def _func_call(self, func_gadget, cc, args, extra_regs=None, modifiable_memory_range=None, preserve_regs=None,
use_partial_controllers=False, needs_return=True):
def _func_call(self, func_gadget, cc, args, extra_regs=None, preserve_regs=None, needs_return=True):
"""
func_gadget: the address of the function to invoke
cc: calling convention
args: the arguments to the function
extra_regs: what extra registers to set besides the function arguments, useful for invoking system calls
preserve_res: what registers preserve
needs_return: whether we need to cleanup stack after the function invocation, setting this to False will result in a shorter chain
"""
assert type(args) in [list, tuple], "function arguments must be a list or tuple!"

preserve_regs = set(preserve_regs) if preserve_regs else set()
arch_bytes = self.project.arch.bytes
registers = {} if extra_regs is None else extra_regs
if preserve_regs is None:
preserve_regs = []

# distinguish register and stack arguments
register_arguments = args
Expand All @@ -30,55 +38,42 @@ def _func_call(self, func_gadget, cc, args, extra_regs=None, modifiable_memory_r
stack_arguments = args[len(cc.ARG_REGS):]

# set register arguments
registers = {} if extra_regs is None else extra_regs
for arg, reg in zip(register_arguments, cc.ARG_REGS):
registers[reg] = arg
for reg in preserve_regs:
registers.pop(reg, None)
chain = self.chain_builder.set_regs(modifiable_memory_range=modifiable_memory_range,
use_partial_controllers=use_partial_controllers,
**registers)
chain = self.chain_builder.set_regs(**registers)

# invoke the function
chain.add_gadget(func_gadget)
for _ in range(func_gadget.stack_change//arch_bytes-1):
chain.add_value(self._get_fill_val())

# we are done here if there is no stack arguments
if not stack_arguments:
# we are done here if we don't need to return
if not needs_return:
return chain

# handle stack arguments:
# 1. we need to pop the arguments after use
# 2. push the stack arguments

# step 1: find a stack cleaner (a gadget that can pop all the stack args)
# with the smallest stack change
stack_cleaner = None
if needs_return:
for g in self.chain_builder.gadgets:
# just pop plz
if g.mem_reads or g.mem_writes or g.mem_changes:
continue
# at least can pop all the args
if g.stack_change < arch_bytes * (len(stack_arguments)+1):
continue

if stack_cleaner is None or g.stack_change < stack_cleaner.stack_change:
stack_cleaner = g

if stack_cleaner is None:
raise RopException(f"Fail to find a stack cleaner that can pop {len(stack_arguments)} words!")

# in case we can't find a stack_cleaner and we don't need to return
if stack_cleaner is None:
stack_cleaner = RopGadget(self._get_fill_val())
stack_cleaner.stack_change = arch_bytes * (len(stack_arguments)+1)

chain.add_gadget(stack_cleaner)
stack_arguments += [self._get_fill_val()]*(stack_cleaner.stack_change//arch_bytes - len(stack_arguments)-1)
for arg in stack_arguments:
chain.add_value(arg)

# now we need to cleanly finish the calling convention
# 1. handle stack arguments
# 2. handle function return address to maintain the control flow
if stack_arguments:
cleaner = self.chain_builder.shift((len(stack_arguments)+1)*arch_bytes) # +1 for itself
chain.add_gadget(cleaner._gadgets[0])
for arg in stack_arguments:
chain.add_value(arg)

# handle return address
if not isinstance(cc.RETURN_ADDR, (SimStackArg, SimRegArg)):
raise RopException(f"What is the calling convention {cc} I'm dealing with?")
if isinstance(cc.RETURN_ADDR, SimRegArg):
# now we know this function will return to a specific register
# so we need to set the return address before invoking the function
reg_name = cc.RETURN_ADDR.reg_name
shifter = self.chain_builder._shifter.shift(self.project.arch.bytes)
next_ip = rop_utils.cast_rop_value(shifter._gadgets[0].addr, self.project)
pre_chain = self.chain_builder.set_regs(**{reg_name: next_ip})
chain = pre_chain + chain
return chain

def func_call(self, address, args, **kwargs):
Expand Down
14 changes: 14 additions & 0 deletions angrop/chain_builder/reg_setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@ def verify(self, chain, preserve_regs, registers):
return False
return True

def _maybe_fix_jump_chain(self, chain, preserve_regs):
all_changed_regs = set()
for g in chain._gadgets[:-1]:
all_changed_regs.update(g.changed_regs)
jump_reg = chain._gadgets[-1].jump_reg
if jump_reg in all_changed_regs:
return chain
shifter = self.chain_builder._shifter.shift(self.project.arch.bytes)
next_ip = rop_utils.cast_rop_value(shifter._gadgets[0].addr, self.project)
new = self.run(preserve_regs=preserve_regs, **{jump_reg: next_ip})
return new + chain

def run(self, modifiable_memory_range=None, use_partial_controllers=False, preserve_regs=None, **registers):
if len(registers) == 0:
return RopChain(self.project, None, badbytes=self.badbytes)
Expand Down Expand Up @@ -94,6 +106,8 @@ def run(self, modifiable_memory_range=None, use_partial_controllers=False, pres
chain = self._build_reg_setting_chain(gadgets, modifiable_memory_range,
registers, stack_change)
chain._concretize_chain_values()
if chain._gadgets[-1].transit_type == 'jmp_reg':
chain = self._maybe_fix_jump_chain(chain, preserve_regs)
if self.verify(chain, preserve_regs, registers):
#self._chain_cache[reg_tuple].append(gadgets)
return chain
Expand Down
4 changes: 2 additions & 2 deletions angrop/chain_builder/shifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def _filter_gadgets(self, gadgets):
"""
filter gadgets having the same effect
"""
# we don't like gadgets with any memory accesses
gadgets = [x for x in gadgets if x.num_mem_access == 0]
# we don't like gadgets with any memory accesses or jump gadgets
gadgets = [x for x in gadgets if x.num_mem_access == 0 and x.transit_type != 'jmp_reg']

# now do the standard filtering
gadgets = set(gadgets)
Expand Down
9 changes: 4 additions & 5 deletions angrop/rop_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,14 @@ def payload_code(self, constraints=None, print_instructions=True):
payload = ""
payload += 'chain = b""\n'

gadget_dict = {g.addr:g for g in self._gadgets}
concrete_vals = self._concretize_chain_values(constraints)
for value, rebased in concrete_vals:

instruction_code = ""
if print_instructions:
value_in_gadget = value
if value_in_gadget in gadget_dict:
asmstring = rop_utils.gadget_to_asmstring(self._p, gadget_dict[value_in_gadget])
if print_instructions and rebased:
seg = self._p.loader.find_segment_containing(value)
if seg and seg.is_executable:
asmstring = rop_utils.addr_to_asmstring(self._p, value)
if asmstring != "":
instruction_code = "\t# " + asmstring

Expand Down
2 changes: 2 additions & 0 deletions angrop/rop_gadget.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def reg_set_same_effect(self, other):
return False
if self.reg_dependencies != other.reg_dependencies:
return False
if self.transit_type != other.transit_type:
return False
return True

def reg_set_better_than(self, other):
Expand Down
6 changes: 4 additions & 2 deletions angrop/rop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
from .errors import RegNotFoundException, RopException
from .rop_value import RopValue

def addr_to_asmstring(project, addr):
block = project.factory.block(addr)
return "; ".join(["%s %s" %(i.mnemonic, i.op_str) for i in block.capstone.insns])

def gadget_to_asmstring(project, gadget):
if not gadget.block_length:
return ""
block = project.factory.block(gadget.addr)
return "; ".join(["%s %s" %(i.mnemonic, i.op_str) for i in block.capstone.insns])
return addr_to_asmstring(project, gadget.addr)

def get_ast_dependency(ast):
"""
Expand Down
17 changes: 10 additions & 7 deletions tests/test_chainbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,23 @@ def test_i386_func_call():
def test_arm_func_call():
cache_path = os.path.join(CACHE_DIR, "armel_glibc_2.31")
proj = angr.Project(os.path.join(BIN_DIR, "tests", "armel", "libc-2.31.so"), auto_load_libs=False)
proj.hook_symbol('puts', angr.SIM_PROCEDURES['libc']['puts']())
rop = proj.analyses.ROP(fast_mode=False, only_check_near_rets=False, is_thumb=True)
if os.path.exists(cache_path):
rop.load_gadgets(cache_path)
else:
rop.find_gadgets()
rop.save_gadgets(cache_path)

chain = rop.func_call("realloc", [0xcafebabe, 0xa])
print(chain)
#import IPython; IPython.embed()
#chain = rop.set_regs(r4=0x4141412c, r5=0x42424242)
#state = chain.exec()
#assert state.regs.r4.concrete_value == 0x4141412c
#assert state.regs.r5.concrete_value == 0x42424242
proj.hook_symbol('write', angr.SIM_PROCEDURES['posix']['write']())
chain = rop.func_call("write", [1, 0x4E15F0, 9])
state = chain.exec()
assert state.posix.dumps(1) == b'malloc.c\x00'

proj.hook_symbol('write', angr.SIM_PROCEDURES['posix']['write']())
chain = rop.func_call("puts", [0x4E15F0])
state = chain.exec()
assert state.posix.dumps(1) == b'malloc.c\n'

def test_i386_syscall():
cache_path = os.path.join(CACHE_DIR, "bronze_ropchain")
Expand Down

0 comments on commit 4b30ec2

Please sign in to comment.