Skip to content

Commit

Permalink
Handling of Named Type references (#99)
Browse files Browse the repository at this point in the history
* Handling of named references

Fix the handling of type cache and read bytes from memory

* review changes

* Add vector type lifting

* add remill compat header for vector type

Co-authored-by: AkshayK <[email protected]>
  • Loading branch information
kumarak and AkshayK authored Feb 16, 2021
1 parent 659ac45 commit 0036c42
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 59 deletions.
93 changes: 55 additions & 38 deletions lib/Lift.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <llvm/IR/Module.h>
#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Utils.h>

#include <remill/BC/Compat/VectorType.h>
#include <remill/BC/Util.h>

#include <algorithm>
Expand Down Expand Up @@ -510,8 +512,10 @@ namespace {
static llvm::APInt ReadValueFromMemory(const uint64_t addr, const uint64_t size,
const remill::Arch *arch,
const Program &program) {
llvm::APInt result(size, 0);
for (auto i = 0u; i < (size / 8); ++i) {

// Create an arbitrary-precision integer that is size * 8 bits wide
llvm::APInt result(size * 8, 0);
for (auto i = 0u; i < size; ++i) {
auto byte_val = program.FindByte(addr + i).Value();
if (remill::IsError(byte_val)) {
LOG(ERROR) << "Unable to read value of byte at " << std::hex << addr + i
Expand All @@ -524,8 +528,8 @@ static llvm::APInt ReadValueFromMemory(const uint64_t addr, const uint64_t size,
}

// NOTE(artem): LLVM's APInt does not handle byteSwap()
// for size 8, leading to a segfault. Guard against it here.
if (arch->MemoryAccessIsLittleEndian() && size > 8) {
// for size 1, leading to a segfault. Guard against it here.
if (arch->MemoryAccessIsLittleEndian() && size > 1) {
result = result.byteSwap();
}

Expand All @@ -540,61 +544,74 @@ CreateConstFromMemory(const uint64_t addr, llvm::Type *type,
llvm::Constant *result{nullptr};
switch (type->getTypeID()) {
case llvm::Type::IntegerTyID: {
const auto size = dl.getTypeSizeInBits(type);
const auto size = dl.getTypeAllocSize(type);
auto val = ReadValueFromMemory(addr, size, arch, program);
result = llvm::ConstantInt::get(type, val);
} break;

case llvm::Type::PointerTyID: {
const auto pointer_type = llvm::dyn_cast<llvm::PointerType>(type);
const auto size = dl.getTypeAllocSize(type);
auto val = ReadValueFromMemory(addr, size, arch, program);
result = llvm::Constant::getIntegerValue(pointer_type, val);
} break;

case llvm::Type::StructTyID: {

// Take apart the structure type, recursing into each element
// so that we can create a constant structure
auto struct_type = llvm::dyn_cast<llvm::StructType>(type);

auto num_elms = struct_type->getNumElements();
auto elm_offset = 0;
const auto struct_type = llvm::dyn_cast<llvm::StructType>(type);
const auto layout = dl.getStructLayout(struct_type);
const auto num_elms = struct_type->getStructNumElements();
std::vector<llvm::Constant *> initializer_list;
initializer_list.reserve(num_elms);

std::vector<llvm::Constant *> const_list;

for (std::uint64_t i = 0U; i < num_elms; ++i) {
auto elm_type = struct_type->getElementType(i);
auto elm_size = dl.getTypeSizeInBits(elm_type);

auto const_elm =
CreateConstFromMemory(addr + elm_offset, elm_type, arch,
program, module);

const_list.push_back(const_elm);
elm_offset += elm_size / 8;
for (auto i = 0u; i < num_elms; ++i) {
const auto elm_type = struct_type->getStructElementType(i);
const auto offset = layout->getElementOffset(i);
auto const_elm = CreateConstFromMemory(addr + offset, elm_type, arch,
program, module);
initializer_list.push_back(const_elm);
}

result = llvm::ConstantStruct::get(struct_type,
llvm::ArrayRef(const_list));
result = llvm::ConstantStruct::get(struct_type, initializer_list);
} break;

case llvm::Type::ArrayTyID: {

// Traverse through all the elements of array and create the initializer
const auto array_type = llvm::dyn_cast<llvm::ArrayType>(type);
const auto elm_type = type->getArrayElementType();
const auto elm_size = dl.getTypeSizeInBits(elm_type);
const auto elm_size = dl.getTypeAllocSize(elm_type);
const auto num_elms = type->getArrayNumElements();
std::string bytes(dl.getTypeSizeInBits(type) / 8, '\0');
std::vector<llvm::Constant *> initializer_list;
initializer_list.reserve(num_elms);

for (auto i = 0u; i < num_elms; ++i) {
const auto elm_offset = i * (elm_size / 8);
const auto src =
ReadValueFromMemory(addr + elm_offset, elm_size, arch, program)
.getRawData();
const auto dst = bytes.data() + elm_offset;
std::memcpy(dst, src, elm_size / 8);
}
if (elm_size == 8) {
result = llvm::ConstantDataArray::getString(module.getContext(), bytes,
/*AddNull=*/false);
} else {
result = llvm::ConstantDataArray::getRaw(bytes, num_elms, elm_type);
const auto elm_offset = i * elm_size;
auto const_elm = CreateConstFromMemory(addr + elm_offset, elm_type,
arch, program, module);
initializer_list.push_back(const_elm);
}
result = llvm::ConstantArray::get(array_type, initializer_list);
} break;

case llvm::GetFixedVectorTypeId(): {
const auto vec_type = llvm::dyn_cast<llvm::FixedVectorType>(type);
const auto num_elms = vec_type->getNumElements();
const auto elm_type = vec_type->getElementType();
const auto elm_size = dl.getTypeAllocSize(elm_type);
std::vector<llvm::Constant *> initializer_list;
initializer_list.reserve(num_elms);

for (auto i = 0u; i < num_elms; ++i) {
const auto elm_offset = i * elm_size;
auto const_elm = CreateConstFromMemory(addr + elm_offset, elm_type,
arch, program, module);
initializer_list.push_back(const_elm);
}
result = llvm::ConstantVector::get(initializer_list);
}break;

default:
LOG(FATAL) << "Unhandled LLVM Type: " << remill::LLVMThingToString(type);
break;
Expand Down
107 changes: 86 additions & 21 deletions python/anvill/binja.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,66 @@ def _convert_bn_llil_type(
ret = IntegerType(reg_size_bytes, True)
return ret

def _cache_key(tinfo: bn.types.Type) -> str:
""" Convert bn Type instance to cache key"""
return str(tinfo)

def _convert_named_type_reference(
bv, tinfo: bn.types.Type, cache) -> Type:
""" Convert named type references into a `Type` instance"""
if tinfo.type_class != bn.TypeClass.NamedTypeReferenceClass:
return

named_tinfo = tinfo.named_type_reference
if (named_tinfo.type_class
== bn.NamedTypeReferenceClass.StructNamedTypeClass):
# Get the bn struct type and recursively recover the elements
ref_type = bv.get_type_by_name(named_tinfo.name);
struct_type = ref_type.structure
ret = StructureType()
cache[_cache_key(struct_type)] = ret
for elem in struct_type.members:
ret.add_element_type(_convert_bn_type(bv, elem.type, cache))
return ret

elif (named_tinfo.type_class
== bn.NamedTypeReferenceClass.UnionNamedTypeClass):
# Get the union type and recover the member elements
ref_type = bv.get_type_by_name(named_tinfo.name);
struct_type = ref_type.structure
ret = UnionType()
cache[_cache_key(struct_type)] = ret
for elem in struct_type.union.members:
ret.add_element_type(_convert_bn_type(bv, elem.type, cache))
return ret

elif (named_tinfo.type_class
== bn.NamedTypeReferenceClass.TypedefNamedTypeClass):
ref_type = bv.get_type_by_name(named_tinfo.name);
ret = TypedefType()
cache[_cache_key(ref_type)] = ret
ret.set_underlying_type(_convert_bn_type(bv, ref_type, cache))
return ret

def _convert_bn_type(tinfo: bn.types.Type, cache):
elif (named_tinfo.type_class
== bn.NamedTypeReferenceClass.EnumNamedTypeClass):
# Set the underlying type to an integer of the referenced type's width
ref_type = bv.get_type_by_name(named_tinfo.name);
ret = EnumType()
cache[_cache_key(ref_type)] = ret
ret.set_underlying_type(IntegerType(ref_type.width, False))
return ret

else:
DEBUG("WARNING: Unknown named type {} not handled".format(named_tinfo))


def _convert_bn_type(bv, tinfo: bn.types.Type, cache):
"""Convert an bn `Type` instance into a `Type` instance."""
if str(tinfo) in cache:
return cache[str(tinfo)]

cache_key = _cache_key(tinfo)
if cache_key in cache:
return cache[cache_key]

# Void type.
if tinfo.type_class == bn.TypeClass.VoidTypeClass:
Expand All @@ -186,17 +241,17 @@ def _convert_bn_type(tinfo: bn.types.Type, cache):
# Pointer, array, or function.
elif tinfo.type_class == bn.TypeClass.PointerTypeClass:
ret = PointerType()
cache[str(tinfo)] = ret
ret.set_element_type(_convert_bn_type(tinfo.element_type, cache))
cache[cache_key] = ret
ret.set_element_type(_convert_bn_type(bv, tinfo.element_type, cache))
return ret

elif tinfo.type_class == bn.TypeClass.FunctionTypeClass:
ret = FunctionType()
cache[str(tinfo)] = ret
ret.set_return_type(_convert_bn_type(tinfo.return_value, cache))
cache[cache_key] = ret
ret.set_return_type(_convert_bn_type(bv, tinfo.return_value, cache))

for var in tinfo.parameters:
ret.add_parameter_type(_convert_bn_type(var.type, cache))
ret.add_parameter_type(_convert_bn_type(bv, var.type, cache))

if tinfo.has_variable_arguments:
ret.set_is_variadic()
Expand All @@ -205,19 +260,26 @@ def _convert_bn_type(tinfo: bn.types.Type, cache):

elif tinfo.type_class == bn.TypeClass.ArrayTypeClass:
ret = ArrayType()
cache[str(tinfo)] = ret
ret.set_element_type(_convert_bn_type(tinfo.element_type, cache))
cache[cache_key] = ret
ret.set_element_type(_convert_bn_type(bv, tinfo.element_type, cache))
ret.set_num_elements(tinfo.count)
return ret

elif tinfo.type_class == bn.TypeClass.StructureTypeClass:
ret = StructureType()
cache[str(tinfo)] = ret
cache[cache_key] = ret

for elem in tinfo.structure.members:
ret.add_element_type(_convert_bn_type(bv, elem.type, cache))

return ret

elif tinfo.type_class == bn.TypeClass.EnumerationTypeClass:
# The underlying type of an enum will be an Integer of size
# tinfo.width
ret = EnumType()
cache[str(tinfo)] = ret
cache[cache_key] = ret
ret.set_underlying_type(IntegerType(tinfo.width, False))
return ret

elif tinfo.type_class == bn.TypeClass.BoolTypeClass:
Expand All @@ -238,16 +300,18 @@ def _convert_bn_type(tinfo: bn.types.Type, cache):
width = tinfo.width
return FloatingPointType(width)

elif tinfo.type_class == bn.TypeClass.NamedTypeReferenceClass:
ret = _convert_named_type_reference(bv, tinfo, cache)
return ret

elif tinfo.type_class in [
bn.TypeClass.VarArgsTypeClass,
bn.TypeClass.ValueTypeClass,
bn.TypeClass.NamedTypeReferenceClass,
bn.TypeClass.WideCharTypeClass,
]:
err_type_class = {
bn.TypeClass.VarArgsTypeClass : "VarArgsTypeClass",
bn.TypeClass.ValueTypeClass : "ValueTypeClass",
bn.TypeClass.NamedTypeReferenceClass : "NamedTypeReferenceClass",
bn.TypeClass.WideCharTypeClass : "WideCharTypeClass",
}
DEBUG("WARNING: Unhandled type class {}".format(err_type_class[tinfo.type_class]))
Expand All @@ -256,7 +320,7 @@ def _convert_bn_type(tinfo: bn.types.Type, cache):
raise UnhandledTypeException("Unhandled type: {}".format(str(tinfo)), tinfo)


def get_type(ty):
def get_type(bv, ty):
"""Type class that gives access to type sizes, printings, etc."""

if isinstance(ty, Type):
Expand All @@ -266,7 +330,7 @@ def get_type(ty):
return ty.type()

elif isinstance(ty, bn.Type):
return _convert_bn_type(ty, {})
return _convert_bn_type(bv, ty, {})

if not ty:
return VoidType()
Expand Down Expand Up @@ -416,7 +480,7 @@ def _extract_types_mlil(
):
reg_name = bv.arch.get_reg_name(item_or_list.storage)
results.append(
(reg_name, _convert_bn_type(item_or_list.type, {}), None)
(reg_name, _convert_bn_type(bv, item_or_list.type, {}), None)
)
return results

Expand Down Expand Up @@ -525,7 +589,8 @@ def get_variable_impl(self, address):

arch = self._arch
bn_var = self._bv.get_data_var_at(address)
var_type = get_type(bn_var.type)
var_type = get_type(self._bv, bn_var.type)

# fall back onto an array of bytes type for variables
# of an unknown (void) type.
if isinstance(var_type, VoidType):
Expand All @@ -550,15 +615,15 @@ def get_function_impl(self, address):
"No function defined at or containing address {:x}".format(address)
)

func_type = get_type(bn_func.function_type)
func_type = get_type(self._bv, bn_func.function_type)
calling_conv = CallingConvention(arch, bn_func)

index = 0
param_list = []
for var in bn_func.parameter_vars:
source_type = var.source_type
var_type = var.type
arg_type = get_type(var_type)
arg_type = get_type(self._bv, var_type)

if source_type == bn.VariableSourceType.RegisterVariableSourceType:
if (
Expand Down Expand Up @@ -590,7 +655,7 @@ def get_function_impl(self, address):
index += 1

ret_list = []
retTy = get_type(bn_func.return_type)
retTy = get_type(self._bv, bn_func.return_type)
if not isinstance(retTy, VoidType):
for reg in calling_conv.return_regs:
loc = Location()
Expand Down

0 comments on commit 0036c42

Please sign in to comment.