diff --git a/.clang-tidy b/.clang-tidy index d509f2c..1ce5810 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -11,7 +11,12 @@ Checks: "*,\ -llvm-include-order,\ -llvmlibc-*,\ -modernize-use-nodiscard,\ - -misc-non-private-member-variables-in-classes" + -misc-non-private-member-variables-in-classes,\ + -modernize-use-trailing-return-type,\ + -readability-convert-member-functions-to-static, \ + -misc-no-recursion, \ + -google-build-using-namespace, \ + -cppcoreguidelines-owning-memory" WarningsAsErrors: '' CheckOptions: - key: 'bugprone-argument-comment.StrictMode' @@ -47,15 +52,15 @@ CheckOptions: value: 'true' # These seem to be the most common identifier styles - key: 'readability-identifier-naming.AbstractClassCase' - value: 'lower_case' + value: 'CamelCase' - key: 'readability-identifier-naming.ClassCase' - value: 'lower_case' + value: 'CamelCase' - key: 'readability-identifier-naming.ClassConstantCase' - value: 'lower_case' + value: 'camelBack' - key: 'readability-identifier-naming.ClassMemberCase' - value: 'lower_case' + value: 'camelBack' - key: 'readability-identifier-naming.ClassMethodCase' - value: 'lower_case' + value: 'camelBack' - key: 'readability-identifier-naming.ConstantCase' value: 'lower_case' - key: 'readability-identifier-naming.ConstantMemberCase' @@ -71,9 +76,9 @@ CheckOptions: - key: 'readability-identifier-naming.ConstexprVariableCase' value: 'lower_case' - key: 'readability-identifier-naming.EnumCase' - value: 'lower_case' + value: 'CamelCase' - key: 'readability-identifier-naming.EnumConstantCase' - value: 'lower_case' + value: 'UPPER_CASE' - key: 'readability-identifier-naming.FunctionCase' value: 'lower_case' - key: 'readability-identifier-naming.GlobalConstantCase' @@ -101,7 +106,7 @@ CheckOptions: - key: 'readability-identifier-naming.MemberCase' value: 'lower_case' - key: 'readability-identifier-naming.MethodCase' - value: 'lower_case' + value: 'camelBack' - key: 'readability-identifier-naming.NamespaceCase' value: 'lower_case' - key: 'readability-identifier-naming.ParameterCase' @@ -113,13 +118,13 @@ CheckOptions: - key: 'readability-identifier-naming.PrivateMemberCase' value: 'lower_case' - key: 'readability-identifier-naming.PrivateMemberPrefix' - value: 'm_' + value: '' - key: 'readability-identifier-naming.PrivateMethodCase' value: 'lower_case' - key: 'readability-identifier-naming.ProtectedMemberCase' value: 'lower_case' - key: 'readability-identifier-naming.ProtectedMemberPrefix' - value: 'm_' + value: '' - key: 'readability-identifier-naming.ProtectedMethodCase' value: 'lower_case' - key: 'readability-identifier-naming.PublicMemberCase' @@ -133,23 +138,23 @@ CheckOptions: - key: 'readability-identifier-naming.StaticVariableCase' value: 'lower_case' - key: 'readability-identifier-naming.StructCase' - value: 'lower_case' + value: 'CamelCase' - key: 'readability-identifier-naming.TemplateParameterCase' value: 'CamelCase' - key: 'readability-identifier-naming.TemplateTemplateParameterCase' value: 'CamelCase' - key: 'readability-identifier-naming.TypeAliasCase' - value: 'lower_case' + value: 'CamelCase' - key: 'readability-identifier-naming.TypedefCase' - value: 'lower_case' + value: 'CamelCase' - key: 'readability-identifier-naming.TypeTemplateParameterCase' value: 'CamelCase' - key: 'readability-identifier-naming.UnionCase' - value: 'lower_case' + value: 'CamelCase' - key: 'readability-identifier-naming.ValueTemplateParameterCase' value: 'CamelCase' - key: 'readability-identifier-naming.VariableCase' value: 'lower_case' - key: 'readability-identifier-naming.VirtualMethodCase' - value: 'lower_case' + value: 'camelBack' ... diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d09aa74..48d6a87 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,5 @@ # # Copyright (c) 2024, Trail of Bits, Inc. -# All rights reserved. # # This source code is licensed in accordance with the terms specified in # the LICENSE file found in the root directory of this source tree. diff --git a/.github/workflows/devcontainer.yml b/.github/workflows/devcontainer.yml index b1dbe8d..018090d 100644 --- a/.github/workflows/devcontainer.yml +++ b/.github/workflows/devcontainer.yml @@ -1,6 +1,5 @@ # # Copyright (c) 2024, Trail of Bits, Inc. -# All rights reserved. # # This source code is licensed in accordance with the terms specified in # the LICENSE file found in the root directory of this source tree. diff --git a/.github/workflows/prerelease.yml b/.github/workflows/prerelease.yml index f445abc..c50c025 100644 --- a/.github/workflows/prerelease.yml +++ b/.github/workflows/prerelease.yml @@ -1,6 +1,5 @@ # # Copyright (c) 2024, Trail of Bits, Inc. -# All rights reserved. # # This source code is licensed in accordance with the terms specified in # the LICENSE file found in the root directory of this source tree. diff --git a/CMakeLists.txt b/CMakeLists.txt index 90a5948..fca8f81 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,7 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. cmake_minimum_required(VERSION 3.25) @@ -94,6 +95,17 @@ list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") include(AddMLIR) +find_package(Clang ${LLVM_PACKAGE_VERSION} CONFIG REQUIRED) +message(STATUS "Using ClangConfig.cmake in: ${Clang_DIR}") +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") + +find_package(VAST CONFIG REQUIRED) +message(STATUS "Using VASTConfig.cmake in: ${VAST_DIR}") + + +find_package(gap CONFIG REQUIRED) +message(STATUS "Using gapConfig.cmake in: ${gap_DIR}") + set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) @@ -252,4 +264,4 @@ if (PATCHESTRY_INSTALL) DESTINATION ${PATCHESTRY_CMAKE_INSTALL_DIR} ) -endif() \ No newline at end of file +endif() diff --git a/include/CMakeLists.txt b/include/CMakeLists.txt index 23f3625..999a428 100644 --- a/include/CMakeLists.txt +++ b/include/CMakeLists.txt @@ -1,5 +1,6 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. add_subdirectory(patchestry) diff --git a/include/patchestry/AST/ASTConsumer.hpp b/include/patchestry/AST/ASTConsumer.hpp new file mode 100644 index 0000000..1c44bf7 --- /dev/null +++ b/include/patchestry/AST/ASTConsumer.hpp @@ -0,0 +1,342 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace patchestry::ast { + using namespace patchestry::ghidra; + + using ASTTypeMap = std::unordered_map< std::string, clang::QualType >; + using ASTDeclMap = std::unordered_map< std::string, clang::Decl * >; + + class PcodeASTConsumer : public clang::ASTConsumer + { + public: + explicit PcodeASTConsumer( + clang::CompilerInstance &ci, Program &prog, std::string &outfile + ) + : program(prog) + , ci(ci) + , outfile(outfile) + , codegen(std::make_unique< CodeGenerator >(ci)) + , type_builder(std::make_unique< TypeBuilder >(ci.getASTContext())) {} + + void HandleTranslationUnit(clang::ASTContext &ctx) override; + + private: + void set_sema_context(clang::DeclContext *dc); + + void write_to_file(void); + + void create_globals(clang::ASTContext &ctx, VariableMap &serialized_variables); + + void create_functions( + clang::ASTContext &ctx, FunctionMap &serialized_functions, TypeMap &serialized_types + ); + + clang::QualType + create_function_prototype(clang::ASTContext &ctx, const FunctionPrototype &proto); + + std::vector< clang::ParmVarDecl * > create_default_paramaters( + clang::ASTContext &ctx, clang::FunctionDecl *func_decl, + const FunctionPrototype &proto + ); + + clang::FunctionDecl *create_function_declaration( + clang::ASTContext &ctx, const Function &function, bool is_definition = false + ); + + clang::FunctionDecl * + create_function_definition(clang::ASTContext &ctx, const Function &function); + + std::vector< clang::Stmt * > create_function_body( + clang::ASTContext &ctx, clang::FunctionDecl *func_decl, const Function &function + ); + + void create_label_for_basic_blocks( + clang::ASTContext &ctx, clang::FunctionDecl *func_decl, const Function &function + ); + + std::vector< clang::Stmt * > create_basic_block( + clang::ASTContext &ctx, const Function &function, const BasicBlock &block + ); + + std::pair< clang::Stmt *, bool > + create_operation(clang::ASTContext &ctx, const Function &function, const Operation &op); + + clang::DeclStmt *create_decl_stmt(clang::ASTContext &ctx, clang::Decl *decl); + + clang::Stmt *create_call_stmt(clang::ASTContext &ctx, const Operation &op); + + clang::Stmt *create_branch_stmt(clang::ASTContext &ctx, const Operation &branch); + + clang::Stmt *create_return_stmt( + clang::ASTContext &ctx, const Function &function, const Operation &ret_op + ); + + clang::QualType get_varnode_type(clang::ASTContext &ctx, const Varnode &vnode); + + clang::Stmt *create_varnode( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, + bool is_input = true + ); + + clang::Stmt *create_function( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, + bool is_input = true + ); + + clang::Stmt *create_local( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, + bool is_input = true + ); + + clang::Stmt *create_constant(clang::ASTContext &ctx, const Varnode &vnode); + + clang::Stmt *create_parameter( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, + bool is_input = true + ); + + clang::Stmt *create_global( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, + bool is_input = true + ); + + clang::Stmt *create_temporary( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, + bool is_input = true + ); + + // List of functions to generate AST node for Pcode operations + + // OP_DECLARE_LOCAL + std::pair< clang::Stmt *, bool > create_declare_local( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + // OP_DECLARE_PARAMETER + std::pair< clang::Stmt *, bool > create_declare_parameter( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + std::pair< clang::Stmt *, bool > create_declare_temporary( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + std::pair< clang::Stmt *, bool > + create_copy(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_load(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_store(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_branch(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_cbranch(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_branchind(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_call(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_callind(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > create_userdefined( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + std::pair< clang::Stmt *, bool > + create_return(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_piece(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_subpiece(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_int_equal(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > create_int_notequal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + std::pair< clang::Stmt *, bool > + create_int_less(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_int_sless(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > create_int_lessequal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + std::pair< clang::Stmt *, bool > create_int_slessequal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + std::pair< clang::Stmt *, bool > + create_int_zext(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_int_sext(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_int_add(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_int_sub(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_int_carry(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > create_int_scarry( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + std::pair< clang::Stmt *, bool > create_int_sborrow( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + std::pair< clang::Stmt *, bool > + create_int_2comp(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_int_mult(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_int_div(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_int_rem(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_int_sdiv(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_int_srem(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > create_bool_negate( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + std::pair< clang::Stmt *, bool > + create_bool_or(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > create_float_equal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + std::pair< clang::Stmt *, bool > create_float_notequal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + std::pair< clang::Stmt *, bool > create_float_less( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + std::pair< clang::Stmt *, bool > create_float_lessequal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + std::pair< clang::Stmt *, bool > + create_float_add(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_float_sub(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > create_float_mult( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + std::pair< clang::Stmt *, bool > + create_float_div(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_float_neg(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_float_abs(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > create_float_sqrt( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + std::pair< clang::Stmt *, bool > create_float_ceil( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + std::pair< clang::Stmt *, bool > create_float_floor( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + std::pair< clang::Stmt *, bool > create_float_round( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + std::pair< clang::Stmt *, bool > + create_float_nan(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > + create_int2float(clang::ASTContext &ctx, const Function &function, const Operation &op); + std::pair< clang::Stmt *, bool > create_float2float( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + std::pair< clang::Stmt *, bool > + create_trunc(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_ptrsub(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_ptradd(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > + create_cast(clang::ASTContext &ctx, const Function &function, const Operation &op); + + std::pair< clang::Stmt *, bool > create_address_of( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + template< clang::BinaryOperatorKind Kind > + std::pair< clang::Stmt *, bool > create_binary_operation( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + template< clang::UnaryOperatorKind Kind > + std::pair< clang::Stmt *, bool > create_unary_operation( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + Program &get_program(void) const { return program.get(); } + + clang::Sema &get_sema(void) const { return ci.get().getSema(); } + + std::reference_wrapper< Program > program; + std::reference_wrapper< clang::CompilerInstance > ci; + + std::string outfile; + std::unique_ptr< CodeGenerator > codegen; + std::unique_ptr< TypeBuilder > type_builder; + + std::unordered_map< std::string, clang::FunctionDecl * > function_declarations; + + /* Map of basic block label decls and stmt for creating branch instructions */ + std::unordered_map< std::string, clang::LabelDecl * > basic_block_labels; + + std::unordered_map< std::string, clang::Stmt * > function_operation_stmts; + std::unordered_map< std::string, clang::VarDecl * > local_variable_declarations; + std::unordered_map< std::string, clang::VarDecl * > global_variable_declarations; + + std::unordered_map< std::string, std::vector< clang::Stmt * > > basic_block_stmts; + }; + +} // namespace patchestry::ast diff --git a/include/patchestry/AST/Codegen.hpp b/include/patchestry/AST/Codegen.hpp new file mode 100644 index 0000000..d8528c4 --- /dev/null +++ b/include/patchestry/AST/Codegen.hpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace llvm { + class raw_fd_ostream; +} + +namespace patchestry::ast { + class CodeGenerator + { + public: + explicit CodeGenerator(clang::CompilerInstance &ci) : opts(vast::cc::options(ci)) {} + + CodeGenerator(const CodeGenerator &) = delete; + CodeGenerator &operator=(const CodeGenerator &) = delete; + CodeGenerator(CodeGenerator &&) noexcept = delete; + CodeGenerator &operator=(CodeGenerator &&) noexcept = delete; + + virtual ~CodeGenerator() {} + + void generate_source_ir(clang::ASTContext &ctx, llvm::raw_fd_ostream &os); + + private: + vast::cc::action_options opts; + }; + +} // namespace patchestry::ast diff --git a/include/patchestry/AST/TypeBuilder.hpp b/include/patchestry/AST/TypeBuilder.hpp new file mode 100644 index 0000000..13c9b75 --- /dev/null +++ b/include/patchestry/AST/TypeBuilder.hpp @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +namespace patchestry::ast { + using namespace patchestry::ghidra; + + using ASTTypeMap = std::unordered_map< std::string, clang::QualType >; + + class TypeBuilder + { + public: + explicit TypeBuilder(clang::ASTContext &ctx) : context(ctx), serialized_types({}) {} + + TypeBuilder &operator=(const TypeBuilder &) = delete; + TypeBuilder(const TypeBuilder &) = delete; + TypeBuilder &operator=(const TypeBuilder &&) = delete; + TypeBuilder(const TypeBuilder &&) = delete; + + virtual ~TypeBuilder() = default; + + ASTTypeMap &get_serialized_types(void) { return serialized_types; } + + void create_types(clang::ASTContext &ctx, TypeMap &lifted_types); + + private: + clang::QualType + create_type(clang::ASTContext &ctx, const std::shared_ptr< VarnodeType > &vnode_type); + + clang::QualType + create_typedef_type(clang::ASTContext &ctx, const TypedefType &typedef_type); + + clang::QualType + create_pointer_type(clang::ASTContext &ctx, const PointerType &pointer_type); + + clang::QualType create_array_type(clang::ASTContext &ctx, const ArrayType &array_type); + + clang::QualType + create_composite_type(clang::ASTContext &ctx, const VarnodeType &composite_type); + + clang::QualType + create_undefined_type(clang::ASTContext &ctx, const UndefinedType &undefined_type); + + void create_record_definition( + clang::ASTContext &ctx, const CompositeType &varnode, clang::Decl *prev_decl, + const ASTTypeMap &clang_types + ); + + clang::QualType create_enum_type(clang::ASTContext &ctx, const EnumType &enum_type); + + clang::ASTContext &get_context(void) { return context.get(); } + + std::unordered_map< std::string, clang::Decl * > missing_type_definition; + + std::reference_wrapper< clang::ASTContext > context; + ASTTypeMap serialized_types; + }; +} // namespace patchestry::ast diff --git a/include/patchestry/AST/Utils.hpp b/include/patchestry/AST/Utils.hpp new file mode 100644 index 0000000..7baaf30 --- /dev/null +++ b/include/patchestry/AST/Utils.hpp @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace patchestry::ast { + clang::SourceLocation source_location_from_key(clang::ASTContext &ctx, std::string key); + + clang::QualType get_type_for_size( + clang::ASTContext &ctx, unsigned bit_size, bool is_signed, bool is_integer + ); + + std::string label_name_from_key(std::string key); + +} // namespace patchestry::ast diff --git a/include/patchestry/CMakeLists.txt b/include/patchestry/CMakeLists.txt index a8adc4a..f8ed379 100644 --- a/include/patchestry/CMakeLists.txt +++ b/include/patchestry/CMakeLists.txt @@ -1,5 +1,6 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. add_subdirectory(Dialect) \ No newline at end of file diff --git a/include/patchestry/Dialect/CMakeLists.txt b/include/patchestry/Dialect/CMakeLists.txt index 52868d9..f89e135 100644 --- a/include/patchestry/Dialect/CMakeLists.txt +++ b/include/patchestry/Dialect/CMakeLists.txt @@ -1,5 +1,6 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. add_subdirectory(Pcode) diff --git a/include/patchestry/Dialect/Pcode/CMakeLists.txt b/include/patchestry/Dialect/Pcode/CMakeLists.txt index 84164c8..5050db0 100644 --- a/include/patchestry/Dialect/Pcode/CMakeLists.txt +++ b/include/patchestry/Dialect/Pcode/CMakeLists.txt @@ -1,6 +1,7 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. add_mlir_dialect(Pcode pc) diff --git a/include/patchestry/Dialect/Pcode/Deserialize.hpp b/include/patchestry/Dialect/Pcode/Deserialize.hpp index a670053..2156882 100644 --- a/include/patchestry/Dialect/Pcode/Deserialize.hpp +++ b/include/patchestry/Dialect/Pcode/Deserialize.hpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. @@ -10,32 +9,41 @@ #include +#include #include #include -#include -namespace patchestry::pc -{ +namespace patchestry::pc { + + struct program; + struct function; + struct basic_block; + struct instruction; + struct pcode; + using json_arr = llvm::json::Array; using json_obj = llvm::json::Object; using json_val = llvm::json::Value; - struct deserializer { + struct deserializer + { mlir_builder bld; - explicit deserializer(mlir::ModuleOp mod) - : bld(mod) - { + explicit deserializer(mlir::ModuleOp mod) : bld(mod) { assert(mod->getNumRegions() > 0 && "Module has no regions."); auto ® = mod->getRegion(0); assert(reg.hasOneBlock() && "Region has unexpected blocks."); bld.setInsertionPointToStart(&*reg.begin()); } - void process(const json_obj &json); - void process_function(const json_obj &json); - void process_block(const json_obj &json); - void process_instruction(const json_obj &json); + void process(const program &prog); + void process_function(const function &func); + void process_block(const basic_block &block); + void process_instruction(const instruction &inst); + void process_pcode(const pcode &code); + + mlir_operation create_int_const(uint32_t offset, uint32_t size); + mlir_operation create_varnode(std::string type, uint32_t offset, uint32_t size); }; mlir::OwningOpRef< mlir::ModuleOp > deserialize(const json_obj &json, mcontext_t *mctx); diff --git a/include/patchestry/Dialect/Pcode/Json.hpp b/include/patchestry/Dialect/Pcode/Json.hpp new file mode 100644 index 0000000..561c5fd --- /dev/null +++ b/include/patchestry/Dialect/Pcode/Json.hpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include "llvm/Support/JSON.h" +#include +#include +#include + +namespace patchestry::pc { + + struct pcode + { + std::string mnemonic; + + struct + { + std::string type; + std::optional< int64_t > offset; + std::optional< int64_t > size; + } output; + + struct input + { + std::string type; + std::optional< int64_t > offset; + std::optional< int64_t > size; + }; + + std::vector< input > inputs; + }; + + struct instruction + { + std::string mnemonic; + std::string address; + std::vector< pcode > pcodes; + }; + + struct basic_block + { + std::string label; + std::vector< instruction > instructions; + }; + + struct function + { + std::string name; + std::vector< basic_block > basic_blocks; + }; + + struct program + { + std::string arch; + std::string os; + std::vector< function > functions; + }; + + class json_parser + { + public: + std::optional< program > parse_program(const llvm::json::Object &root); + + private: + // Function to parse Pcode + std::optional< pcode > parse_pcode(const llvm::json::Object &pcode_obj); + + // Function to parse Instructions + std::optional< instruction > parse_instruction(const llvm::json::Object &inst_obj); + + // Function to parse Basic Blocks + std::optional< basic_block > parse_basic_block(const llvm::json::Object &block_obj); + + // Function to parse Functions + std::optional< function > parse_function(const llvm::json::Object &func_obj); + }; + +} // namespace patchestry::pc diff --git a/include/patchestry/Dialect/Pcode/Pcode.hpp b/include/patchestry/Dialect/Pcode/Pcode.hpp new file mode 100644 index 0000000..215340c --- /dev/null +++ b/include/patchestry/Dialect/Pcode/Pcode.hpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "PcodeDef.h" + +namespace patchestry::pc { + +enum class PCodeMnemonic { +#define X(name) name, + PCODE_MNEMONIC_LIST +#undef X + UNKNOWN +}; + +enum class PCodeVarnodeType { +#define X(name) name##_, + PCODE_VARNODE_TYPE +#undef X + UNKNOWN +}; + +template +struct PCodeStringMapper { + std::array, N> mappings; + + //Convert enum to string + constexpr std::string_view to_string(EnumType val) const { + for (const auto& [pcode, str] : mappings) { + if (pcode == val) { + return str; + } + } + return "UNKNOWN"; + } + + constexpr EnumType from_string(std::string_view s) const { + for (const auto& [val, str] : mappings) { + if (str == s) { + return val; + } + } + return EnumType::UNKNOWN; + } +}; + +// Calculate the number of mnemonics +constexpr size_t NumPCodeMnemonics = []() constexpr { + size_t count = 0; +#define X(name) ++count; + PCODE_MNEMONIC_LIST +#undef X + return count; +}(); + +constexpr size_t NumVarNodeType = []() constexpr { + size_t count = 0; +#define X(name) ++count; + PCODE_VARNODE_TYPE +#undef X + return count; +}(); + +// Instantiate the EnumStringMapper for PCodeMnemonic +constexpr PCodeStringMapper PCodeMnemonicMapper{{ +#define X(name) std::pair{PCodeMnemonic::name, #name}, + PCODE_MNEMONIC_LIST +#undef X +}}; + +constexpr PCodeStringMapper PCodeVarNodeMapper{{ +#define X(name) std::pair{PCodeVarnodeType::name##_, #name}, + PCODE_VARNODE_TYPE +#undef X +}}; + +constexpr std::string_view to_string(PCodeMnemonic mnemonic) { + return PCodeMnemonicMapper.to_string(mnemonic); +} + +constexpr PCodeMnemonic from_string(llvm::StringRef mnemonic_str) { + return PCodeMnemonicMapper.from_string(mnemonic_str); +} + +constexpr std::string_view varnode_to_string(PCodeVarnodeType ty) { + return PCodeVarNodeMapper.to_string(ty); +} + +constexpr PCodeVarnodeType varnode_from_string(llvm::StringRef ty_str) { + return PCodeVarNodeMapper.from_string(ty_str); +} + +} \ No newline at end of file diff --git a/include/patchestry/Dialect/Pcode/Pcode.td b/include/patchestry/Dialect/Pcode/Pcode.td index 09095db..8465d30 100644 --- a/include/patchestry/Dialect/Pcode/Pcode.td +++ b/include/patchestry/Dialect/Pcode/Pcode.td @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. diff --git a/include/patchestry/Dialect/Pcode/PcodeDef.h b/include/patchestry/Dialect/Pcode/PcodeDef.h new file mode 100644 index 0000000..bffb863 --- /dev/null +++ b/include/patchestry/Dialect/Pcode/PcodeDef.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + + // Definition of ghidra pcode mnemonics + + #pragma once + + #define PCODE_MNEMONIC_LIST \ + X(COPY) \ + X(LOAD) \ + X(STORE) \ + X(BRANCH) \ + X(CBRANCH) \ + X(BRANCHIND) \ + X(CALL) \ + X(CALLIND) \ + X(USERDEFINED) \ + X(RETURN) \ + X(PIECE) \ + X(SUBPIECE) \ + X(INT_EQUAL) \ + X(INT_NOTEQUAL) \ + X(INT_LESS) \ + X(INT_SLESS) \ + X(INT_LESSEQUAL) \ + X(INT_SLESSEQUAL) \ + X(INT_ZEXT) \ + X(INT_SEXT) \ + X(INT_ADD) \ + X(INT_SUB) \ + X(INT_CARRY) \ + X(INT_SCARRY) \ + X(INT_SBORROW) \ + X(INT_2COMP) \ + X(INT_NEGATE) \ + X(INT_XOR) \ + X(INT_AND) \ + X(INT_OR) \ + X(INT_LEFT) \ + X(INT_RIGHT) \ + X(INT_SRIGHT) \ + X(INT_MULT) \ + X(INT_DIV) \ + X(INT_REM) \ + X(INT_SDIV) \ + X(INT_SREM) \ + X(BOOL_NEGATE) \ + X(BOOL_OR) \ + X(FLOAT_EQUAL) \ + X(FLOAT_NOTEQUAL) \ + X(FLOAT_LESS) \ + X(FLOAT_LESSEQUAL) \ + X(FLOAT_ADD) \ + X(FLOAT_SUB) \ + X(FLOAT_MULT) \ + X(FLOAT_DIV) \ + X(FLOAT_NEG) \ + X(FLOAT_ABS) \ + X(FLOAT_SQRT) \ + X(FLOAT_CEIL) \ + X(FLOAT_FLOOR) \ + X(FLOAT_ROUND) \ + X(FLOAT_NAN) \ + X(INT2FLOAT) \ + X(FLOAT2FLOAT) \ + X(TRUNC) + +#define PCODE_VARNODE_TYPE \ + X(unique) \ + X(const) \ + X(register) \ + X(ram) \ No newline at end of file diff --git a/include/patchestry/Dialect/Pcode/PcodeDialect.hpp b/include/patchestry/Dialect/Pcode/PcodeDialect.hpp index 5a84924..c130239 100644 --- a/include/patchestry/Dialect/Pcode/PcodeDialect.hpp +++ b/include/patchestry/Dialect/Pcode/PcodeDialect.hpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. diff --git a/include/patchestry/Dialect/Pcode/PcodeOps.hpp b/include/patchestry/Dialect/Pcode/PcodeOps.hpp index a3e9c36..34061d5 100644 --- a/include/patchestry/Dialect/Pcode/PcodeOps.hpp +++ b/include/patchestry/Dialect/Pcode/PcodeOps.hpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. diff --git a/include/patchestry/Dialect/Pcode/PcodeOps.td b/include/patchestry/Dialect/Pcode/PcodeOps.td index 0056be8..5779ac2 100644 --- a/include/patchestry/Dialect/Pcode/PcodeOps.td +++ b/include/patchestry/Dialect/Pcode/PcodeOps.td @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. @@ -35,7 +34,7 @@ def Pcode_BlockOp } def Pcode_InstOp - : Pcode_Op< "instruction" > + : Pcode_Op< "instruction", [NoTerminator] > , Arguments<( ins StrAttr:$inst_mnemonic )> { let regions = (region SizedRegion<1>:$semantics); diff --git a/include/patchestry/Dialect/Pcode/PcodeTypes.hpp b/include/patchestry/Dialect/Pcode/PcodeTypes.hpp index ddfebcc..197c41d 100644 --- a/include/patchestry/Dialect/Pcode/PcodeTypes.hpp +++ b/include/patchestry/Dialect/Pcode/PcodeTypes.hpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. diff --git a/include/patchestry/Dialect/Pcode/PcodeTypes.td b/include/patchestry/Dialect/Pcode/PcodeTypes.td index e387fde..2519d65 100644 --- a/include/patchestry/Dialect/Pcode/PcodeTypes.td +++ b/include/patchestry/Dialect/Pcode/PcodeTypes.td @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. @@ -17,6 +16,7 @@ class Pcode_Type< string type_name, string _mnemonic, list< Trait > traits = [] let mnemonic = _mnemonic; } +def PCode_ConstType : Pcode_Type< "Const", "const" >; def Pcode_RegType : Pcode_Type< "Reg", "reg" >; def Pcode_MemType : Pcode_Type< "Mem", "mem" >; def Pcode_VarType : Pcode_Type< "Var", "var" >; diff --git a/include/patchestry/Ghidra/JsonDeserialize.hpp b/include/patchestry/Ghidra/JsonDeserialize.hpp new file mode 100644 index 0000000..ac0dfa4 --- /dev/null +++ b/include/patchestry/Ghidra/JsonDeserialize.hpp @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include + +namespace patchestry::ghidra { + + class JsonParser + { + public: + std::optional< Program > deserialize_program(const JsonObject &root); + + private: + // Create varnode type for each type object + std::shared_ptr< VarnodeType > create_vnode_type(const JsonObject &type_obj); + + // Process types from Json object + void deserialize_types(const JsonObject &type_obj, TypeMap &serialized_types); + + void deserialize_buildin( + BuiltinType &varnode, const JsonObject &builtin_obj, const TypeMap &serialized_types + ); + + void deserialize_array( + ArrayType &varnode, const JsonObject *array_obj, const TypeMap &serialized_types + ); + + void deserialize_pointer( + PointerType &varnode, const JsonObject &pointer_obj, const TypeMap &serialized_types + ); + + void deserialize_typedef( + TypedefType &varnode, const JsonObject &typedef_obj, const TypeMap &serialized_types + ); + + void deserialize_composite( + CompositeType &varnode, const JsonObject &composite_obj, + const TypeMap &serialized_types + ); + + void deserialize_enum( + EnumType &varnode, const JsonObject &enum_obj, const TypeMap &serialized_types + ); + + void deserialize_function_type( + FunctionType &varnode, const JsonObject &func_obj, const TypeMap &serialized_types + ); + + void deserialize_undefined_type( + UndefinedType &varnode, const JsonObject &undef_obj, const TypeMap &serialized_types + ); + + void deserialize_call_operation(const JsonObject &call_obj, Operation &op); + + void deserialize_branch_operation(const JsonObject &branch_obj, Operation &op); + + std::optional< Varnode > create_varnode(const JsonObject &var_obj); + + std::optional< Function > create_function(const JsonObject &func_obj); + + std::optional< FunctionPrototype > create_function_prototype(const JsonObject &proto_obj + ); + + // Function to parse Basic Blocks + std::optional< BasicBlock > + create_basic_block(const std::string &block_key, const JsonObject &block_obj); + + // Function to parse Pcode + std::optional< Operation > create_operation(const JsonObject &pcode_obj); + + // Deserialize functions + void deserialize_functions( + const JsonObject &function_array, FunctionMap &serialized_functions + ); + + void deserialize_blocks( + const JsonObject &blocks_array, BasicBlockMap &serialized_blocks, + std::string &entry_block + ); + + // Deserialize globals + void + deserialize_globals(const JsonObject &global_array, VariableMap &serialized_globals); + }; + +} // namespace patchestry::ghidra diff --git a/include/patchestry/Ghidra/Pcode.def b/include/patchestry/Ghidra/Pcode.def new file mode 100644 index 0000000..86fc974 --- /dev/null +++ b/include/patchestry/Ghidra/Pcode.def @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + + // Definition of ghidra pcode mnemonics + +#pragma once + +#define PCODE_MNEMONICS \ + X(COPY) \ + X(LOAD) \ + X(STORE) \ + X(BRANCH) \ + X(CBRANCH) \ + X(BRANCHIND) \ + X(CALL) \ + X(CALLIND) \ + X(USERDEFINED) \ + X(RETURN) \ + X(PIECE) \ + X(SUBPIECE) \ + X(INT_EQUAL) \ + X(INT_NOTEQUAL) \ + X(INT_LESS) \ + X(INT_SLESS) \ + X(INT_LESSEQUAL) \ + X(INT_SLESSEQUAL) \ + X(INT_ZEXT) \ + X(INT_SEXT) \ + X(INT_ADD) \ + X(INT_SUB) \ + X(INT_CARRY) \ + X(INT_SCARRY) \ + X(INT_SBORROW) \ + X(INT_2COMP) \ + X(INT_NEGATE) \ + X(INT_XOR) \ + X(INT_AND) \ + X(INT_OR) \ + X(INT_LEFT) \ + X(INT_RIGHT) \ + X(INT_SRIGHT) \ + X(INT_MULT) \ + X(INT_DIV) \ + X(INT_REM) \ + X(INT_SDIV) \ + X(INT_SREM) \ + X(BOOL_NEGATE) \ + X(BOOL_OR) \ + X(BOOL_AND) \ + X(FLOAT_EQUAL) \ + X(FLOAT_NOTEQUAL) \ + X(FLOAT_LESS) \ + X(FLOAT_LESSEQUAL) \ + X(FLOAT_ADD) \ + X(FLOAT_SUB) \ + X(FLOAT_MULT) \ + X(FLOAT_DIV) \ + X(FLOAT_NEG) \ + X(FLOAT_ABS) \ + X(FLOAT_SQRT) \ + X(FLOAT_CEIL) \ + X(FLOAT_FLOOR) \ + X(FLOAT_ROUND) \ + X(FLOAT_NAN) \ + X(INT2FLOAT) \ + X(FLOAT2FLOAT) \ + X(TRUNC) \ + X(PTRSUB) \ + X(PTRADD) \ + X(CAST) \ + X(DECLARE_PARAMETER) \ + X(DECLARE_LOCAL) \ + X(DECLARE_TEMPORARY) \ + X(ADDRESS_OF) diff --git a/include/patchestry/Ghidra/Pcode.hpp b/include/patchestry/Ghidra/Pcode.hpp new file mode 100644 index 0000000..e42d7b8 --- /dev/null +++ b/include/patchestry/Ghidra/Pcode.hpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include "Pcode.def" + +namespace patchestry::ghidra { + + enum class Mnemonic : int { +#define X(name) OP_##name, // NOLINT(cppcoreguidelines-macro-usage) + PCODE_MNEMONICS +#undef X + OP_UNKNOWN + }; + + template< typename EnumType, size_t N > + struct PCodeStringMapper + { + std::array< std::pair< EnumType, std::string_view >, N > mappings; + + // Convert enum to string + constexpr std::string_view to_string(EnumType val) const { + for (const auto &[pcode, str] : mappings) { + if (pcode == val) { + return str; + } + } + return "UNKNOWN"; + } + + constexpr EnumType from_string(std::string_view s) const { + for (const auto &[val, str] : mappings) { + if (str == s) { + return val; + } + } + return EnumType::OP_UNKNOWN; + } + }; + + // Calculate the number of mnemonics + constexpr size_t num_mnemonics = []() constexpr { + size_t count = 0; +#define X(name) ++count; // NOLINT(cppcoreguidelines-macro-usage) + PCODE_MNEMONICS +#undef X + return count; + }(); + + // Instantiate the EnumStringMapper for PCodeMnemonic + constexpr PCodeStringMapper< Mnemonic, num_mnemonics > mnemonic_mapper{ { +#define X(name) std::pair{ Mnemonic::OP_##name, #name }, + PCODE_MNEMONICS +#undef X + } }; + + constexpr std::string_view to_string(Mnemonic mnemonic) { + return mnemonic_mapper.to_string(mnemonic); + } + + constexpr Mnemonic from_string(const std::string_view &mnemonic_str) { + return mnemonic_mapper.from_string(mnemonic_str); + } + +} // namespace patchestry::ghidra diff --git a/include/patchestry/Ghidra/PcodeOperations.hpp b/include/patchestry/Ghidra/PcodeOperations.hpp new file mode 100644 index 0000000..58e020e --- /dev/null +++ b/include/patchestry/Ghidra/PcodeOperations.hpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +namespace patchestry::ghidra { + struct Varnode; + struct Variable; + struct Operation; + struct BasicBlock; + struct FunctionPrototype; + struct Function; + struct Program; + +} // namespace patchestry::ghidra + +namespace patchestry::ghidra { + using TypeMap = std::unordered_map< std::string, std::shared_ptr< VarnodeType > >; + + using FunctionMap = std::unordered_map< std::string, Function >; + + using BasicBlockMap = std::unordered_map< std::string, BasicBlock >; + + using VariableMap = std::unordered_map< std::string, Variable >; + + struct Varnode + { + enum Kind { + VARNODE_UNKNOWN = 0, + VARNODE_GLOBAL, + VARNODE_LOCAL, + VARNODE_PARAM, + VARNODE_FUNCTION, + VARNODE_TEMPORARY, + VARNODE_CONSTANT + }; + + static Varnode::Kind convertToKind(const std::string &kdd) { + static const std::unordered_map< std::string, Varnode::Kind > kind_map = { + { "unknown", VARNODE_UNKNOWN}, + { "global", VARNODE_GLOBAL}, + { "local", VARNODE_LOCAL}, + {"parameter", VARNODE_PARAM}, + { "function", VARNODE_FUNCTION}, + {"temporary", VARNODE_TEMPORARY}, + { "constant", VARNODE_CONSTANT} + }; + + // if kind is not present in the map, return varnode_unknown + auto iter = kind_map.find(kdd); + return iter != kind_map.end() ? iter->second : VARNODE_UNKNOWN; + } + + Kind kind; + uint32_t size; + std::string type_key; + + std::optional< std::string > operation; + std::optional< std::string > function; + std::optional< uint32_t > value; + std::optional< std::string > global; + }; + + struct Variable + { + std::string name; + std::string type; + uint32_t size; + std::string key; + }; + + struct OperationTarget + + { + Varnode::Kind kind; + std::optional< std::string > function; + std::optional< std::string > operation; + bool is_noreturn; + }; + + struct Operation + { + Mnemonic mnemonic; + std::optional< Varnode > output; + std::vector< Varnode > inputs; + + std::string key; + std::string parent_block_key; + + // Parameter/variable declaration + std::optional< std::string > name; + std::optional< std::string > type; + std::optional< uint32_t > index; + + // Call Operation + std::optional< OperationTarget > target; + + // Branch Operation + std::optional< std::string > target_block; + + // Cond Branch + std::optional< std::string > taken_block; + std::optional< std::string > not_taken_block; + std::optional< Varnode > condition; + std::optional< std::string > address; + }; + + struct BasicBlock + { + std::shared_ptr< BasicBlock > parent; + std::string key; + std::unordered_map< std::string, Operation > operations; + std::vector< std::string > ordered_operations; + bool is_entry_block; + }; + + struct FunctionPrototype + { + std::vector< std::string > parameters; + std::string rttype_key; + bool is_variadic; + bool is_noreturn; + }; + + struct Function + { + std::string name; + FunctionPrototype prototype; + std::string key; + std::string entry_block; + std::unordered_map< std::string, BasicBlock > basic_blocks; + }; + + struct Program + { + std::string arch; + std::string format; + std::unordered_map< std::string, Function > serialized_functions; + std::unordered_map< std::string, std::shared_ptr< VarnodeType > > serialized_types; + std::unordered_map< std::string, Variable > serialized_globals; + }; +} // namespace patchestry::ghidra diff --git a/include/patchestry/Ghidra/PcodeTranslation.hpp b/include/patchestry/Ghidra/PcodeTranslation.hpp index 173a933..ae0dc75 100644 --- a/include/patchestry/Ghidra/PcodeTranslation.hpp +++ b/include/patchestry/Ghidra/PcodeTranslation.hpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. diff --git a/include/patchestry/Ghidra/PcodeTypes.hpp b/include/patchestry/Ghidra/PcodeTypes.hpp new file mode 100644 index 0000000..f03c946 --- /dev/null +++ b/include/patchestry/Ghidra/PcodeTypes.hpp @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include + +#include "llvm/Support/JSON.h" + +namespace patchestry::ghidra { + using JsonArray = llvm::json::Array; + using JsonObject = llvm::json::Object; + using JsonValue = llvm::json::Value; + + struct VarnodeType + { + enum Kind { + VT_INVALID = 0, + VT_BOOLEAN, + VT_INTEGER, + VT_FLOAT, + VT_CHAR, + VT_POINTER, + VT_FUNCTION, + VT_ARRAY, + VT_STRUCT, + VT_UNION, + VT_ENUM, + VT_TYPEDEF, + VT_UNDEFINED, + VT_VOID + }; + + static VarnodeType::Kind convertToKind(const std::string &kind) { + static const std::unordered_map< std::string, VarnodeType::Kind > kind_map = { + { "bool", VT_BOOLEAN }, + { "integer", VT_INTEGER }, + { "float", VT_FLOAT }, + { "pointer", VT_POINTER }, + { "function", VT_FUNCTION }, + { "array", VT_ARRAY }, + { "struct", VT_STRUCT }, + { "union", VT_UNION }, + { "enum", VT_ENUM }, + { "typedef", VT_TYPEDEF }, + { "undefined", VT_UNDEFINED }, + { "void", VT_VOID } + }; + + // if kind is not present in the map, return vt_invalid + auto iter = kind_map.find(kind); + return iter != kind_map.end() ? iter->second : VT_INVALID; + } + + VarnodeType() = default; + + VarnodeType(std::string &name, Kind kind, uint32_t size) + : kind(kind), size(size), key({}), name(name) {} + + VarnodeType(const VarnodeType &) = default; + VarnodeType &operator=(const VarnodeType &) = default; + VarnodeType(VarnodeType &&) noexcept = default; + VarnodeType &operator=(VarnodeType &&) noexcept = default; + virtual ~VarnodeType() = default; + + void set_key(std::string &key) { this->key = key; } + + Kind kind{}; + uint32_t size{}; + std::string key; + std::string name; + }; + + // BuiltinType + struct BuiltinType : public VarnodeType + { + BuiltinType(std::string &name, Kind kind, uint32_t size) + : VarnodeType(name, kind, size) {} + }; + + // ArrayType + struct ArrayType : public VarnodeType + { + ArrayType(std::string name, Kind kind, uint32_t size) + : VarnodeType(name, kind, size), num_elements(0), element_type(nullptr) {} + + uint32_t get_element_count(void) const { return num_elements; } + + std::shared_ptr< VarnodeType > get_element_type(void) const { return element_type; } + + void set_element_type(const std::shared_ptr< VarnodeType > &element) { + element_type = element; + } + + void set_element_count(uint32_t count) { num_elements = count; } + + private: + uint32_t num_elements; + std::shared_ptr< VarnodeType > element_type; + }; + + // PointerType + struct PointerType : public VarnodeType + { + PointerType( + std::string name, Kind kind, uint32_t size, + std::shared_ptr< VarnodeType > pointee = nullptr + ) + : VarnodeType(name, kind, size), pointee_type(std::move(pointee)) {} + + std::shared_ptr< VarnodeType > get_pointee_type() const { return pointee_type; } + + void set_pointee_type(const VarnodeType &pointee) { + pointee_type = std::make_shared< VarnodeType >(pointee); + } + + void set_pointee_type(const std::shared_ptr< VarnodeType > &pointee) { + pointee_type = pointee; + } + + private: + std::shared_ptr< VarnodeType > pointee_type; + }; + + // TypedefType + struct TypedefType : public VarnodeType + { + TypedefType( + std::string name, Kind kind, uint32_t size, + std::shared_ptr< VarnodeType > base = nullptr + ) + : VarnodeType(name, kind, size), base_type(std::move(base)) {} + + std::shared_ptr< VarnodeType > get_base_type() const { return base_type; } + + void set_base_type(const std::shared_ptr< VarnodeType > &base) { base_type = base; } + + private: + std::shared_ptr< VarnodeType > base_type; + }; + + // UndefinedType + struct UndefinedType : public VarnodeType + { + UndefinedType(std::string name, Kind kind, uint32_t size) + : VarnodeType(name, kind, size) {} + }; + + // FunctionType + struct FunctionType : public VarnodeType + { + FunctionType(std::string name, Kind kind, uint32_t size) + : VarnodeType(name, kind, size) {} + }; + + // EnumType + struct EnumType : public VarnodeType + { + EnumType(std::string name, Kind kind, uint32_t size) : VarnodeType(name, kind, size) {} + }; + + // CompositeType + struct CompositeType : public VarnodeType + { + struct Component + { + std::string name; + uint32_t offset; + std::shared_ptr< VarnodeType > type; + }; + + CompositeType(std::string name, Kind kind, uint32_t size) + : VarnodeType(name, kind, size) {} + + void add_components(std::string &name, const VarnodeType &type, uint32_t offset) { + components.emplace_back( + Component(name, offset, std::make_shared< VarnodeType >(type)) + ); + } + + std::vector< Component > get_components(void) const { return components; } + + private: + std::vector< Component > components; + }; + +} // namespace patchestry::ghidra diff --git a/include/patchestry/Util/Common.hpp b/include/patchestry/Util/Common.hpp index eda37d9..a9e82cd 100644 --- a/include/patchestry/Util/Common.hpp +++ b/include/patchestry/Util/Common.hpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 23f3625..999a428 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,5 +1,6 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. add_subdirectory(patchestry) diff --git a/lib/patchestry/AST/ASTConsumer.cpp b/lib/patchestry/AST/ASTConsumer.cpp new file mode 100644 index 0000000..c60be08 --- /dev/null +++ b/lib/patchestry/AST/ASTConsumer.cpp @@ -0,0 +1,404 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace patchestry::ast { + + namespace { + + std::vector< std::string > __attribute__((unused)) + get_keys(const std::unordered_map< std::string, BasicBlock > &map) { + std::vector< std::string > keys; + keys.reserve(map.size()); + + for (const auto &[key, _] : map) { + keys.push_back(key); + } + + std::sort(keys.begin(), keys.end()); + return keys; + } + + std::vector< std::shared_ptr< Operation > > __attribute__((unused)) + get_parameter_operations(const Function &function) { + auto entry_block_key = function.entry_block; + if (entry_block_key.empty() && function.basic_blocks.empty()) { + return {}; + } + + auto iter = function.basic_blocks.find(entry_block_key); + if (iter == function.basic_blocks.end()) { + llvm::errs() << "Function entry block " << entry_block_key + << " not present in basic block list\n"; + assert(false); + return {}; + } + + std::vector< std::shared_ptr< Operation > > ops_vec; + auto entry_block = iter->second; + for (const auto &operation_key : entry_block.ordered_operations) { + auto iter = entry_block.operations.find(operation_key); + if (iter != entry_block.operations.end()) { + auto operation = iter->second; + if (operation.mnemonic == Mnemonic::OP_DECLARE_PARAMETER) { + ops_vec.push_back(std::make_shared< Operation >(operation)); + } + } + } + return ops_vec; + } + } // namespace + + void PcodeASTConsumer::HandleTranslationUnit(clang::ASTContext &ctx) { + if (!get_program().serialized_types.empty()) { + type_builder->create_types(ctx, get_program().serialized_types); + } + + if (!get_program().serialized_globals.empty()) { + create_globals(ctx, get_program().serialized_globals); + } + + if (!get_program().serialized_functions.empty()) { + create_functions( + ctx, get_program().serialized_functions, get_program().serialized_types + ); + } + + std::error_code ec; + auto out = + std::make_unique< llvm::raw_fd_ostream >(outfile, ec, llvm::sys::fs::OF_Text); + + llvm::errs() << "Print AST dump\n"; + ctx.getTranslationUnitDecl()->dumpColor(); + + ctx.getTranslationUnitDecl()->print( + *llvm::dyn_cast< llvm::raw_ostream >(out), ctx.getPrintingPolicy(), 0 + ); + + llvm::errs() << "Generate mlir\n"; + llvm::raw_fd_ostream file_os(outfile + ".mlir", ec); + codegen->generate_source_ir(ctx, file_os); + } + + void PcodeASTConsumer::set_sema_context(clang::DeclContext *dc) { + get_sema().CurContext = dc; + } + + void PcodeASTConsumer::write_to_file(void) {} + + clang::QualType PcodeASTConsumer::create_function_prototype( + clang::ASTContext &ctx, const FunctionPrototype &proto + ) { + auto return_key = proto.rttype_key; + auto iter = type_builder->get_serialized_types().find(return_key); + if (iter == type_builder->get_serialized_types().end()) { + llvm::errs() << "Function return type is not found\n"; + assert(false); + return clang::QualType(); + } + auto rttype = iter->second; + + std::vector< clang::QualType > args_vec; + for (const auto ¶m : proto.parameters) { + auto param_iter = type_builder->get_serialized_types().find(param); + if (param_iter == type_builder->get_serialized_types().end()) { + assert(false); + } + args_vec.push_back(param_iter->second); + } + clang::FunctionProtoType::ExtProtoInfo proto_info; + proto_info.Variadic = proto.is_variadic; + if (proto.is_noreturn) { + proto_info.ExceptionSpec.Type = clang::EST_DependentNoexcept; + } + + return ctx.getFunctionType(rttype, args_vec, proto_info); + } + + std::vector< clang::ParmVarDecl * > PcodeASTConsumer::create_default_paramaters( + clang::ASTContext &ctx, clang::FunctionDecl *func_decl, const FunctionPrototype &proto + ) { + if (proto.parameters.empty()) { + return {}; + } + + std::vector< clang::ParmVarDecl * > params; + int index = 0; + for (const auto ¶m_key : proto.parameters) { + auto param_type = type_builder->get_serialized_types().at(param_key); + std::stringstream ss; + ss << "param_" << index++; + auto param_name = ss.str(); + auto *param_decl = clang::ParmVarDecl::Create( + ctx, func_decl, clang::SourceLocation(), clang::SourceLocation(), + &ctx.Idents.get(param_name), param_type, + ctx.getTrivialTypeSourceInfo(param_type, clang::SourceLocation()), + clang::SC_None, nullptr + ); + params.emplace_back(param_decl); + } + + return params; + } + + void PcodeASTConsumer::create_functions( + clang::ASTContext &ctx, FunctionMap &serialized_functions, TypeMap &serialized_types + ) { + for (const auto &[key, function] : serialized_functions) { + auto *function_decl = create_function_declaration(ctx, function); + if (function_decl != nullptr) { + function_declarations.emplace(key, function_decl); + } + + // TODO: Create global variables + } + + // Create definition for declared functions + for (const auto &[key, decl] : function_declarations) { + auto iter = serialized_functions.find(key); + assert(iter != serialized_functions.end()); + const auto &parsed_function = iter->second; + auto *func_def = create_function_definition(ctx, parsed_function); + if (func_def != nullptr) { + func_def->setPreviousDecl(decl); + } + } + (void) serialized_types; + } + + clang::FunctionDecl *PcodeASTConsumer::create_function_declaration( + clang::ASTContext &ctx, const Function &function, bool is_definition + ) { + if (function.name.empty()) { + llvm::errs() << "Function name is empty. function key " << function.key << "\n"; + return nullptr; + } + + auto function_type = create_function_prototype(ctx, function.prototype); + auto *func_decl = clang::FunctionDecl::Create( + ctx, ctx.getTranslationUnitDecl(), source_location_from_key(ctx, function.key), + source_location_from_key(ctx, function.key), &ctx.Idents.get(function.name), + function_type, nullptr, clang::SC_None + ); + + // Add function declaration to tralsation unit + func_decl->setDeclContext(ctx.getTranslationUnitDecl()); + ctx.getTranslationUnitDecl()->addDecl(func_decl); + + // Set asm label attribute to symbol name + if (!is_definition) { + auto *asm_attr = clang::AsmLabelAttr::Create( + ctx, function.name, true, func_decl->getSourceRange() + ); + if (asm_attr != nullptr) { + func_decl->addAttr(asm_attr); + } + } + + // Create parameters for function declarations; + auto num_params = function.prototype.parameters.size(); + auto parameter_operations = get_parameter_operations(function); + if (parameter_operations.size() == num_params) { + std::vector< clang::ParmVarDecl * > params; + for (const auto ¶m_op : parameter_operations) { + auto type_iter = type_builder->get_serialized_types().find(*param_op->type); + assert(type_iter != type_builder->get_serialized_types().end()); + + auto *param_decl = clang::ParmVarDecl::Create( + ctx, func_decl, source_location_from_key(ctx, param_op->key), + source_location_from_key(ctx, param_op->key), + &ctx.Idents.get(*param_op->name), type_iter->second, nullptr, + clang::SC_None, nullptr + ); + params.push_back(param_decl); + local_variable_declarations.emplace(param_op->key, param_decl); + } + + func_decl->setParams(params); + return func_decl; + } + + func_decl->setParams(create_default_paramaters(ctx, func_decl, function.prototype)); + return func_decl; + } + + clang::FunctionDecl *PcodeASTConsumer::create_function_definition( + clang::ASTContext &ctx, const Function &function + ) { + if (function.name.empty() || function.basic_blocks.empty()) { + return nullptr; + } + + function_operation_stmts.clear(); + local_variable_declarations.clear(); + basic_block_stmts.clear(); + + auto *func_def = create_function_declaration(ctx, function, true); + if (func_def != nullptr) { + set_sema_context(func_def); + auto body_vec = create_function_body(ctx, func_def, function); + func_def->setBody(clang::CompoundStmt::Create( + ctx, body_vec, clang::FPOptionsOverride(), clang::SourceLocation(), + clang::SourceLocation() + )); + } + + return func_def; + } + + std::vector< clang::Stmt * > PcodeASTConsumer::create_function_body( + clang::ASTContext &ctx, clang::FunctionDecl *func_decl, const Function &function + ) { + if (function.basic_blocks.empty()) { + llvm::errs() << "Function " << function.name << " doesn't have body\n"; + return {}; + } + + // Create label decl for all basic blocks + create_label_for_basic_blocks(ctx, func_decl, function); + + std::vector< clang::Stmt * > stmts; + + // If function has entry block, create it first to ensure we have local variables and + // parameter variables declared + if (!function.entry_block.empty()) { + auto iter = function.basic_blocks.find(function.entry_block); + assert(iter != function.basic_blocks.end()); + auto entry_stmts = create_basic_block(ctx, function, iter->second); + stmts.insert(stmts.end(), entry_stmts.begin(), entry_stmts.end()); + } + + // get lexicographically sorted keys for basic blocks + auto block_keys = get_keys(function.basic_blocks); + for (const auto &block_key : block_keys) { + llvm::errs() << "Processing basic block with key " << block_key << "\n"; + const auto &bb = function.basic_blocks.at(block_key); + if (bb.is_entry_block) { + continue; + } + + auto block_stmts = create_basic_block(ctx, function, bb); + basic_block_stmts.emplace(block_key, block_stmts); + } + + for (auto &[key, block_stmts] : basic_block_stmts) { + if (!block_stmts.empty()) { + auto *label_stmt = new (ctx) clang::LabelStmt( + clang::SourceLocation(), basic_block_labels.at(key), block_stmts[0] + ); + // replace first stmt of block with label stmts + block_stmts[0] = label_stmt; + stmts.insert(stmts.end(), block_stmts.begin(), block_stmts.end()); + } + } + + return stmts; + } + + void PcodeASTConsumer::create_label_for_basic_blocks( + clang::ASTContext &ctx, clang::FunctionDecl *func_decl, const Function &function + ) { + if (function.basic_blocks.empty()) { + llvm::errs() << "Function " << function.name << " does not have any basic block\n"; + return; + } + + for (const auto &[key, block] : function.basic_blocks) { + // entry block is custom added to each function; we don't need to make labels for + // entry block; + if (block.is_entry_block) { + continue; + } + + auto *label_decl = clang::LabelDecl::Create( + ctx, func_decl, clang::SourceLocation(), + &ctx.Idents.get(label_name_from_key(key)) + ); + + label_decl->setDeclContext(func_decl); + if (clang::DeclContext *dc = label_decl->getLexicalDeclContext()) { + dc->addDecl(label_decl); + } + + basic_block_labels.emplace(key, label_decl); + } + } + + std::vector< clang::Stmt * > PcodeASTConsumer::create_basic_block( + clang::ASTContext &ctx, const Function &function, const BasicBlock &block + ) { + std::vector< clang::Stmt * > stmt_vec; + for (const auto &operation_key : block.ordered_operations) { + auto iter = block.operations.find(operation_key); + if (iter == block.operations.end()) { + assert(false); + continue; + } + auto operation = iter->second; + auto [stmt, should_merge_to_next] = create_operation(ctx, function, operation); + if (stmt != nullptr) { + function_operation_stmts.emplace(operation.key, stmt); + if (!should_merge_to_next) { + stmt_vec.push_back(stmt); + } + } + } + + return stmt_vec; + } + + void PcodeASTConsumer::create_globals( + clang::ASTContext &ctx, VariableMap &serialized_variables + ) { + for (auto &[key, variable] : serialized_variables) { + if (variable.name.empty() || variable.type.empty()) { + continue; + } + + auto var_type = type_builder->get_serialized_types().at(variable.type); + + auto *var_decl = clang::VarDecl::Create( + ctx, ctx.getTranslationUnitDecl(), clang::SourceLocation(), + clang::SourceLocation(), &ctx.Idents.get(variable.name), var_type, + ctx.getTrivialTypeSourceInfo(var_type), clang::SC_Static + ); + + var_decl->setDeclContext(ctx.getTranslationUnitDecl()); + ctx.getTranslationUnitDecl()->addDecl(var_decl); + global_variable_declarations.emplace(variable.key, var_decl); + } + } + +} // namespace patchestry::ast diff --git a/lib/patchestry/AST/CMakeLists.txt b/lib/patchestry/AST/CMakeLists.txt new file mode 100644 index 0000000..60b79c8 --- /dev/null +++ b/lib/patchestry/AST/CMakeLists.txt @@ -0,0 +1,36 @@ +# Copyright (c) 2024, Trail of Bits, Inc. This source code is licensed +# in accordance with the terms specified in the LICENSE file found +# in the root directory of this source tree. + +add_library(patchestry_ast STATIC + ASTConsumer.cpp + Codegen.cpp + OperationBuilder.cpp + OperationStmt.cpp + TypeBuilder.cpp + Utils.cpp +) + +set(VAST_LIBS + VAST::VASTTargetLLVMIR + VAST::VASTToLLVMConversionPasses + VAST::VASTAliasTypeInterface + VAST::VASTElementTypeInterface + VAST::VASTCodeGen + VAST::VASTFrontend + VAST::VASTSymbolInterface + VAST::VASTSymbolTableInterface + VAST::VASTSymbolRefInterface + VAST::VASTTypeDefinitionInterface +) + +add_library(patchestry::ast ALIAS patchestry_ast) + +target_link_libraries(patchestry_ast + PUBLIC + clangFrontend + PRIVATE + LLVMSupport + patchestry_settings + ${VAST_LIBS} +) \ No newline at end of file diff --git a/lib/patchestry/AST/Codegen.cpp b/lib/patchestry/AST/Codegen.cpp new file mode 100644 index 0000000..5e6e5f3 --- /dev/null +++ b/lib/patchestry/AST/Codegen.cpp @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#define VAST_ENABLE_EXCEPTIONS +#include + +VAST_RELAX_WARNINGS +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +VAST_UNRELAX_WARNINGS + +#define GAP_ENABLE_COROUTINES + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace patchestry::ast { + + class MLIRRegistryInitializer + { + public: + explicit MLIRRegistryInitializer(mlir::DialectRegistry ®istry) { + ::vast::registerAllDialects(registry); + ::mlir::registerAllDialects(registry); + } + }; + + class MLIRInitializer + { + private: + MLIRInitializer(void) = delete; + + mlir::DialectRegistry registry; + MLIRRegistryInitializer registry_initializer; + mutable mlir::MLIRContext ctx; + + public: + explicit MLIRInitializer(int); + + inline mlir::MLIRContext &context(void) const noexcept { return ctx; } + + ~MLIRInitializer(void); + }; + + MLIRInitializer::MLIRInitializer(int) + : registry() + , registry_initializer(registry) + , ctx(registry, mlir::MLIRContext::Threading::ENABLED) { + ctx.disableMultithreading(); + ctx.loadAllAvailableDialects(); + ctx.enableMultithreading(); + } + + MLIRInitializer::~MLIRInitializer(void) { ctx.disableMultithreading(); } + + static const MLIRInitializer kMLIR(0); + + // Custom meta generator that disable all location setting to unknown + class MetaGenerator final : public vast::cg::meta_generator + { + // MLIR context for generating the MLIR location from source location + mlir::MLIRContext *mctx; + + mlir::Location unknown_location; + + public: + explicit MetaGenerator(mlir::MLIRContext &mctx_) + : mctx(&mctx_), unknown_location(mlir::UnknownLoc::get(mctx)) {} + + mlir::Location location(const clang::Decl *data) const override { + return location_impl(data); + } + + mlir::Location location(const clang::Expr *data) const override { + return location_impl(data); + } + + mlir::Location location(const clang::Stmt *data) const override { + return location_impl(data); + } + + private: + template< typename T > + mlir::Location location_impl(const T *data) const { + return unknown_location; + } + }; + + static std::optional< vast::owning_mlir_module_ref > + create_module(clang::ASTContext &ctx, vast::cc::action_options &opts) { + auto &mctx = kMLIR.context(); + auto bld = vast::cg::mk_codegen_builder(mctx); + auto mg = std::make_shared< vast::cg::default_meta_gen >(&ctx, &mctx); + auto sg = + std::make_shared< vast::cg::default_symbol_generator >(ctx.createMangleContext()); + auto cp = std::make_shared< vast::cg::default_policy >(opts); + using vast::cg::as_node; + using vast::cg::as_node_with_list_ref; + + auto visitors = std::make_shared< vast::cg::visitor_list >() + | as_node_with_list_ref< vast::cg::attr_visitor_proxy >() + | as_node< vast::cg::type_caching_proxy >() + | as_node_with_list_ref< vast::cg::default_visitor >(mctx, ctx, *bld, mg, sg, cp) + | as_node_with_list_ref< vast::cg::unsup_visitor >(mctx, *bld, mg) + | as_node< vast::cg::fallthrough_visitor >(); + + vast::cg::driver driver(ctx, mctx, std::move(bld), visitors); + driver.enable_verifier(true); + for (auto &decl : ctx.getTranslationUnitDecl()->noload_decls()) { + driver.emit(clang::dyn_cast< clang::Decl >(decl)); + } + + driver.finalize(); + return std::make_optional(driver.freeze()); + } + + void CodeGenerator::generate_source_ir(clang::ASTContext &ctx, llvm::raw_fd_ostream &os) { + auto mod = create_module(ctx, opts); + auto flags = mlir::OpPrintingFlags(); + flags.enableDebugInfo(true, false); + (*mod)->print(os, flags); + } +} // namespace patchestry::ast diff --git a/lib/patchestry/AST/OperationBuilder.cpp b/lib/patchestry/AST/OperationBuilder.cpp new file mode 100644 index 0000000..6f7138c --- /dev/null +++ b/lib/patchestry/AST/OperationBuilder.cpp @@ -0,0 +1,448 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace patchestry::ast { + + extern clang::QualType get_type_for_size( + clang::ASTContext &ctx, unsigned bit_size, bool is_signed, bool is_integer + ); + + std::optional< Operation > + operation_from_key(const Function &function, const std::string &lookup_key) { + if (function.basic_blocks.empty()) { + return std::nullopt; + } + + for (const auto &[_, block] : function.basic_blocks) { + for (const auto &[operation_key, operation] : block.operations) { + if (operation_key == lookup_key) { + return operation; + } + } + } + + assert(false); // assert if failed to find operation + return std::nullopt; + } + + clang::CallExpr *create_function_call(clang::ASTContext &ctx, clang::FunctionDecl *decl) { + auto *ref_expr = clang::DeclRefExpr::Create( + ctx, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), decl, false, + clang::SourceLocation(), decl->getType(), clang::VK_LValue + ); + + return clang::CallExpr::Create( + ctx, ref_expr, {}, decl->getReturnType(), clang::VK_PRValue, + clang::SourceLocation(), clang::FPOptionsOverride() + ); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_operation( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic == Mnemonic::OP_UNKNOWN) { + llvm::errs() << "Operation with unknown mnemonic. operation key ( " << op.key + << " )\n"; + return std::make_pair(nullptr, true); + } + + switch (op.mnemonic) { + case Mnemonic::OP_COPY: + return create_copy(ctx, function, op); + case Mnemonic::OP_LOAD: + return create_load(ctx, function, op); + case Mnemonic::OP_STORE: + return create_store(ctx, function, op); + case Mnemonic::OP_BRANCH: + return create_branch(ctx, function, op); + case Mnemonic::OP_CBRANCH: + return create_cbranch(ctx, function, op); + case Mnemonic::OP_BRANCHIND: + return create_branchind(ctx, function, op); + case Mnemonic::OP_CALL: + return create_call(ctx, function, op); + case Mnemonic::OP_CALLIND: + return create_callind(ctx, function, op); + case Mnemonic::OP_USERDEFINED: + return create_userdefined(ctx, function, op); + case Mnemonic::OP_RETURN: + return create_return(ctx, function, op); + case Mnemonic::OP_PIECE: + return create_piece(ctx, function, op); + case Mnemonic::OP_SUBPIECE: + return create_subpiece(ctx, function, op); + case Mnemonic::OP_INT_EQUAL: + return create_binary_operation< clang::BO_EQ >(ctx, function, op); + case Mnemonic::OP_INT_NOTEQUAL: + return create_binary_operation< clang::BO_NE >(ctx, function, op); + case Mnemonic::OP_INT_LESS: + case Mnemonic::OP_INT_SLESS: + return create_binary_operation< clang::BO_LT >(ctx, function, op); + case Mnemonic::OP_INT_LESSEQUAL: + case Mnemonic::OP_INT_SLESSEQUAL: + return create_binary_operation< clang::BO_LE >(ctx, function, op); + case Mnemonic::OP_INT_ZEXT: + return create_int_zext(ctx, function, op); + case Mnemonic::OP_INT_SEXT: + return create_int_sext(ctx, function, op); + case Mnemonic::OP_INT_ADD: + return create_binary_operation< clang::BO_Add >(ctx, function, op); + case Mnemonic::OP_INT_SUB: + return create_int_sub(ctx, function, op); + case Mnemonic::OP_INT_CARRY: + return create_int_carry(ctx, function, op); + case Mnemonic::OP_INT_SCARRY: + return create_int_scarry(ctx, function, op); + case Mnemonic::OP_INT_SBORROW: + return create_int_sborrow(ctx, function, op); + case Mnemonic::OP_INT_2COMP: + return create_int_2comp(ctx, function, op); + case Mnemonic::OP_INT_NEGATE: + return create_unary_operation< clang::UO_LNot >(ctx, function, op); + case Mnemonic::OP_INT_XOR: + return create_binary_operation< clang::BO_Xor >(ctx, function, op); + case Mnemonic::OP_INT_AND: + return create_binary_operation< clang::BO_And >(ctx, function, op); + case Mnemonic::OP_INT_OR: + return create_binary_operation< clang::BO_Or >(ctx, function, op); + case Mnemonic::OP_INT_LEFT: + return create_binary_operation< clang::BO_Shl >(ctx, function, op); + case Mnemonic::OP_INT_RIGHT: + case Mnemonic::OP_INT_SRIGHT: + return create_binary_operation< clang::BO_Shr >(ctx, function, op); + case Mnemonic::OP_INT_MULT: + return create_binary_operation< clang::BO_Mul >(ctx, function, op); + case Mnemonic::OP_INT_DIV: + return create_binary_operation< clang::BO_Div >(ctx, function, op); + case Mnemonic::OP_INT_REM: + return create_binary_operation< clang::BO_Rem >(ctx, function, op); + case Mnemonic::OP_INT_SDIV: + return create_binary_operation< clang::BO_Div >(ctx, function, op); + case Mnemonic::OP_INT_SREM: + return create_binary_operation< clang::BO_Rem >(ctx, function, op); + case Mnemonic::OP_BOOL_NEGATE: + return create_unary_operation< clang::UO_LNot >(ctx, function, op); + case Mnemonic::OP_BOOL_OR: + return create_binary_operation< clang::BO_Or >(ctx, function, op); + case Mnemonic::OP_BOOL_AND: + case Mnemonic::OP_FLOAT_EQUAL: + return create_float_equal(ctx, function, op); + case Mnemonic::OP_FLOAT_NOTEQUAL: + return create_float_notequal(ctx, function, op); + case Mnemonic::OP_FLOAT_LESS: + return create_float_less(ctx, function, op); + case Mnemonic::OP_FLOAT_LESSEQUAL: + return create_float_lessequal(ctx, function, op); + case Mnemonic::OP_FLOAT_ADD: + return create_float_add(ctx, function, op); + case Mnemonic::OP_FLOAT_SUB: + return create_float_sub(ctx, function, op); + case Mnemonic::OP_FLOAT_MULT: + return create_float_mult(ctx, function, op); + case Mnemonic::OP_FLOAT_DIV: + return create_float_div(ctx, function, op); + case Mnemonic::OP_FLOAT_NEG: + return create_float_neg(ctx, function, op); + case Mnemonic::OP_FLOAT_ABS: + return create_float_abs(ctx, function, op); + case Mnemonic::OP_FLOAT_SQRT: + return create_float_sqrt(ctx, function, op); + case Mnemonic::OP_FLOAT_CEIL: + return create_float_ceil(ctx, function, op); + case Mnemonic::OP_FLOAT_FLOOR: + return create_float_floor(ctx, function, op); + case Mnemonic::OP_FLOAT_ROUND: + return create_float_round(ctx, function, op); + case Mnemonic::OP_FLOAT_NAN: + return create_float_nan(ctx, function, op); + case Mnemonic::OP_INT2FLOAT: + return create_int2float(ctx, function, op); + case Mnemonic::OP_FLOAT2FLOAT: + return create_float2float(ctx, function, op); + case Mnemonic::OP_TRUNC: + return create_trunc(ctx, function, op); + case Mnemonic::OP_PTRSUB: + return create_ptrsub(ctx, function, op); + case Mnemonic::OP_PTRADD: + return create_ptradd(ctx, function, op); + case Mnemonic::OP_CAST: + return create_cast(ctx, function, op); + case Mnemonic::OP_DECLARE_LOCAL: + return create_declare_local(ctx, function, op); + case Mnemonic::OP_DECLARE_PARAMETER: + return create_declare_parameter(ctx, function, op); + case Mnemonic::OP_DECLARE_TEMPORARY: + return create_declare_temporary(ctx, function, op); + case Mnemonic::OP_ADDRESS_OF: + return create_address_of(ctx, function, op); + case Mnemonic::OP_UNKNOWN: + assert(false); + break; + } + + // Fallback to returning the stmt; + return std::make_pair(nullptr, true); + } + + clang::Stmt * + PcodeASTConsumer::create_call_stmt(clang::ASTContext &ctx, const Operation &op) { + if (op.mnemonic != Mnemonic::OP_CALL) { + assert(false); + return nullptr; + } + + auto call_target = op.target; + if (!call_target.has_value()) { + return nullptr; + } + auto function_key = call_target->function; + auto iter = function_declarations.find(function_key.value()); + if (iter == function_declarations.end()) { + return nullptr; + } + auto *func_decl = iter->second; + return create_function_call(ctx, func_decl); + } + + clang::QualType + PcodeASTConsumer::get_varnode_type(clang::ASTContext &ctx, const Varnode &vnode) { + if (!vnode.type_key.empty()) { + auto iter = type_builder->get_serialized_types().find(vnode.type_key); + assert(iter != type_builder->get_serialized_types().end()); + return iter->second; + } + + if (vnode.size != 0U) { + return get_type_for_size(ctx, vnode.size, /*is_signed=*/false, /*is_integer=*/true); + } + + return clang::QualType(); + } + + clang::Stmt *PcodeASTConsumer::create_varnode( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, bool is_input + ) { + switch (vnode.kind) { + case Varnode::VARNODE_UNKNOWN: + break; + case Varnode::VARNODE_GLOBAL: + return create_global(ctx, function, vnode, is_input); + case Varnode::VARNODE_PARAM: + return create_parameter(ctx, function, vnode, is_input); + case Varnode::VARNODE_FUNCTION: + return create_function(ctx, function, vnode); + case Varnode::VARNODE_LOCAL: + return create_local(ctx, function, vnode, is_input); + case Varnode::VARNODE_TEMPORARY: + return create_temporary(ctx, function, vnode, is_input); + case Varnode::VARNODE_CONSTANT: + return create_constant(ctx, vnode); + } + + return nullptr; + } + + clang::Stmt *PcodeASTConsumer::create_parameter( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, bool is_input + ) { + if (!vnode.operation || vnode.kind != Varnode::VARNODE_PARAM) { + assert(false && "Invalid parameter varnode"); + return nullptr; + } + + auto iter = local_variable_declarations.find(vnode.operation.value()); + assert( + iter != local_variable_declarations.end() + && "Failed to find parameter variable declaration" + ); + auto *param_decl = clang::dyn_cast< clang::ParmVarDecl >(iter->second); + return clang::DeclRefExpr::Create( + ctx, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), param_decl, false, + clang::SourceLocation(), param_decl->getType(), + is_input ? clang::VK_PRValue : clang::VK_LValue + ); + (void) function; + } + + clang::Stmt *PcodeASTConsumer::create_global( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, bool is_input + ) { + if (!vnode.global || vnode.kind != Varnode::VARNODE_GLOBAL) { + assert(false && "Invalid global varnode"); + return nullptr; + } + + auto iter = global_variable_declarations.find(vnode.global.value()); + assert( + iter != global_variable_declarations.end() + && "Failed to find global variable declaration" + ); + + auto *var_decl = clang::dyn_cast< clang::VarDecl >(iter->second); + return clang::DeclRefExpr::Create( + ctx, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), var_decl, false, + clang::SourceLocation(), var_decl->getType(), + is_input ? clang::VK_PRValue : clang::VK_LValue + ); + (void) function; + } + + clang::Stmt *PcodeASTConsumer::create_temporary( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, bool is_input + ) { + if (vnode.kind != Varnode::VARNODE_TEMPORARY) { + assert(false && "Invalid temporary varnode"); + return nullptr; + } + + if (!vnode.operation) { + return nullptr; + } + + auto var_iter = local_variable_declarations.find(vnode.operation.value()); + if (var_iter != local_variable_declarations.end()) { + assert(var_iter->second != nullptr); + return clang::DeclRefExpr::Create( + ctx, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), var_iter->second, + false, clang::SourceLocation(), var_iter->second->getType(), + is_input ? clang::VK_PRValue : clang::VK_LValue + ); + } + + auto stmt_iter = function_operation_stmts.find(vnode.operation.value()); + if (stmt_iter != function_operation_stmts.end()) { + assert(stmt_iter->second != nullptr); + return stmt_iter->second; + } + + if (auto maybe_operation = operation_from_key(function, vnode.operation.value())) { + auto [stmt, _] = create_operation(ctx, function, *maybe_operation); + return stmt; + } + + assert(false && "Failed to get operation for key"); + return nullptr; + } + + clang::Stmt *PcodeASTConsumer::create_function( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, bool is_input + ) { + if (!vnode.function || vnode.kind != Varnode::VARNODE_FUNCTION) { + assert(false && "Invalid function varnode"); + return nullptr; + } + + auto iter = function_declarations.find(vnode.function.value()); + assert(iter != function_declarations.end() && "Failed to find function declaration"); + auto *func_decl = clang::dyn_cast< clang::FunctionDecl >(iter->second); + return clang::DeclRefExpr::Create( + ctx, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), func_decl, false, + clang::SourceLocation(), func_decl->getType(), clang::VK_PRValue + ); + (void) function, is_input; + } + + clang::Stmt *PcodeASTConsumer::create_local( + clang::ASTContext &ctx, const Function &function, const Varnode &vnode, bool is_input + ) { + if (!vnode.operation || vnode.kind != Varnode::VARNODE_LOCAL) { + assert(false && "Invalid local varnode"); + return nullptr; + } + + auto iter = local_variable_declarations.find(vnode.operation.value()); + if (iter != local_variable_declarations.end()) { + assert(iter != local_variable_declarations.end()); + auto *var_decl = clang::dyn_cast< clang::VarDecl >(iter->second); + return clang::DeclRefExpr::Create( + ctx, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), var_decl, false, + clang::SourceLocation(), var_decl->getType(), + is_input ? clang::VK_PRValue : clang::VK_LValue + ); + } + + auto maybe_operation = operation_from_key(function, vnode.operation.value()); + if (maybe_operation) { + auto [stmt, _] = create_operation(ctx, function, *maybe_operation); + return stmt; + } + + return nullptr; + } + + clang::Stmt * + PcodeASTConsumer::create_constant(clang::ASTContext &ctx, const Varnode &vnode) { + if (vnode.kind != Varnode::VARNODE_CONSTANT) { + assert(false && "Invalid constant varnode"); + return nullptr; + } + + clang::QualType type = get_varnode_type(ctx, vnode); + + if (type->isIntegerType()) { + auto value = vnode.value; + return new (ctx) clang::IntegerLiteral( + ctx, llvm::APInt(static_cast< uint32_t >(ctx.getTypeSize(type)), *value), type, + clang::SourceLocation() + ); + } + + if (type->isVoidType()) { + auto value = vnode.value; + auto *literal = new (ctx) clang::IntegerLiteral( + ctx, llvm::APInt(32U, *value), ctx.IntTy, clang::SourceLocation() + ); + return clang::CStyleCastExpr::Create( + ctx, type, clang::VK_PRValue, clang::CK_ToVoid, literal, nullptr, + clang::FPOptionsOverride(), ctx.getTrivialTypeSourceInfo(type), + clang::SourceLocation(), clang::SourceLocation() + ); + } + + if (type->isPointerType()) { + auto value = vnode.value; + auto *literal = new (ctx) clang::IntegerLiteral( + ctx, llvm::APInt(32U, *value), ctx.IntTy, clang::SourceLocation() + ); + return clang::CStyleCastExpr::Create( + ctx, type, clang::VK_PRValue, clang::CK_IntegralToPointer, literal, nullptr, + clang::FPOptionsOverride(), ctx.getTrivialTypeSourceInfo(type), + clang::SourceLocation(), clang::SourceLocation() + ); + } + + if (type->isFloatingType()) { + auto value = static_cast< double >(*vnode.value); + return clang::FloatingLiteral::Create( + ctx, llvm::APFloat(value), true, type, clang::SourceLocation() + ); + } + + return nullptr; + } + +} // namespace patchestry::ast diff --git a/lib/patchestry/AST/OperationStmt.cpp b/lib/patchestry/AST/OperationStmt.cpp new file mode 100644 index 0000000..e38d62a --- /dev/null +++ b/lib/patchestry/AST/OperationStmt.cpp @@ -0,0 +1,1336 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace patchestry::ast { + + extern std::optional< Operation > + operation_from_key(const Function &function, const std::string &lookup_key); + + namespace { + clang::CallExpr *create_function_call( + clang::ASTContext &ctx, clang::FunctionDecl *decl, + std::vector< clang::Expr * > &args + ) { + auto *ref_expr = clang::DeclRefExpr::Create( + ctx, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), decl, false, + clang::SourceLocation(), decl->getType(), clang::VK_LValue + ); + + return clang::CallExpr::Create( + ctx, ref_expr, args, decl->getReturnType(), clang::VK_PRValue, + clang::SourceLocation(), clang::FPOptionsOverride() + ); + } + + clang::VarDecl *create_variable_decl( + clang::ASTContext &ctx, clang::DeclContext *dc, const std::string &name, + clang::QualType type, clang::SourceLocation loc + ) { + return clang::VarDecl::Create( + ctx, dc, loc, loc, &ctx.Idents.get(name), type, + ctx.getTrivialTypeSourceInfo(type), clang::SC_None + ); + } + } // namespace + + clang::DeclStmt * + PcodeASTConsumer::create_decl_stmt(clang::ASTContext &ctx, clang::Decl *decl) { + auto decl_group = clang::DeclGroupRef(decl); + return new (ctx) + clang::DeclStmt(decl_group, clang::SourceLocation(), clang::SourceLocation()); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_declare_local( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_DECLARE_LOCAL) { + assert(false); + return std::make_pair(nullptr, false); + } + + // Get type of the declared variable + auto type_iter = type_builder->get_serialized_types().find(*op.type); + assert( + (type_iter != type_builder->get_serialized_types().end()) + && "Failed to find type for declared variable." + ); + + auto *var_decl = create_variable_decl( + ctx, get_sema().CurContext, *op.name, type_iter->second, + source_location_from_key(ctx, op.key) + ); + + // add variable declaration to list for future references + local_variable_declarations.emplace(op.key, var_decl); + (void) function; + return std::make_pair(create_decl_stmt(ctx, var_decl), false); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_declare_temporary( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_DECLARE_TEMPORARY) { + assert(false); + return std::make_pair(nullptr, false); + } + + // Get type of the declared variable + auto type_iter = type_builder->get_serialized_types().find(*op.type); + assert( + (type_iter != type_builder->get_serialized_types().end()) + && "Failed to find type for declared variable." + ); + + auto *var_decl = create_variable_decl( + ctx, get_sema().CurContext, *op.name, type_iter->second, + source_location_from_key(ctx, op.key) + ); + + // add variable declaration to list for future references + local_variable_declarations.emplace(op.key, var_decl); + return std::make_pair(create_decl_stmt(ctx, var_decl), false); + (void) function; + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_declare_parameter( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_copy( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_COPY) { + assert(false && "Invalid copy operation"); + return std::make_pair(nullptr, false); + } + + // Copy operation does not have output varnode. Create stmt that will be merged to next + // operation + auto *input_expr = + clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs.front())); + if (!op.output) { + assert((op.inputs.size() == 1) && "Invalid input for copy operation"); + return std::make_pair(input_expr, true); + } + + auto *output_expr = clang::dyn_cast< clang::Expr >( + create_varnode(ctx, function, *op.output, /*is_input=*/false) + ); + + if (clang::dyn_cast< clang::Expr >(input_expr)->getType() != output_expr->getType()) { + auto cast_result = get_sema().BuildCStyleCastExpr( + clang::SourceLocation(), ctx.getTrivialTypeSourceInfo(output_expr->getType()), + clang::SourceLocation(), clang::dyn_cast< clang::Expr >(input_expr) + ); + + assert(!cast_result.isInvalid() && "Invalid cstyle cast to output expr"); + input_expr = cast_result.getAs< clang::Expr >(); + } + + auto assign_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, output_expr, input_expr + ); + assert(!assign_result.isInvalid()); + + return std::make_pair(assign_result.getAs< clang::Stmt >(), false); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_load( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_LOAD) { + assert(false); + return std::make_pair(nullptr, false); + } + + auto merge_to_next = !op.output.has_value(); + + auto *input0_expr = create_varnode(ctx, function, op.inputs[0]); + assert(input0_expr != nullptr); + + auto deref_result = get_sema().CreateBuiltinUnaryOp( + clang::SourceLocation(), clang::UO_Deref, + clang::dyn_cast< clang::Expr >(input0_expr) + ); + + if (merge_to_next) { + return std::make_pair(deref_result.getAs< clang::Expr >(), merge_to_next); + } + + auto *result_expr = deref_result.getAs< clang::Expr >(); + // auto is_lvalue = result_expr->isLValue(); + + auto *output_expr = clang::dyn_cast< clang::Expr >( + create_varnode(ctx, function, *op.output, /*is_input=*/false) + ); + + // auto *result_expr = deref_result.getAs< clang::Expr >(); + + if (result_expr->getType() != output_expr->getType()) { + auto cast_result = get_sema().BuildCStyleCastExpr( + clang::SourceLocation(), ctx.getTrivialTypeSourceInfo(output_expr->getType()), + clang::SourceLocation(), result_expr + ); + + assert(!cast_result.isInvalid() && "Invalid cstyle cast to output expr"); + result_expr = cast_result.getAs< clang::Expr >(); + } + + auto assign_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, output_expr, result_expr + ); + assert(!assign_result.isInvalid()); + + return std::make_pair(assign_result.getAs< clang::Stmt >(), false); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_store( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_STORE) { + assert(false); + return std::make_pair(nullptr, false); + } + + assert(op.inputs.size() >= 2); + + auto *input0_expr = + clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs[0])); + + auto *input1_expr = + clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs[1])); + + if (op.inputs.size() == 2) { + auto deref_result = get_sema().CreateBuiltinUnaryOp( + clang::SourceLocation(), clang::UO_Deref, + clang::dyn_cast< clang::Expr >(input0_expr) + ); + + auto *result_expr = deref_result.getAs< clang::Expr >(); + + if (result_expr->getType() != input1_expr->getType()) { + auto cast_result = get_sema().BuildCStyleCastExpr( + clang::SourceLocation(), + ctx.getTrivialTypeSourceInfo(result_expr->getType()), + clang::SourceLocation(), input1_expr + ); + + assert(!cast_result.isInvalid() && "Invalid cstyle cast to output expr"); + input1_expr = cast_result.getAs< clang::Expr >(); + } + + auto store_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, result_expr, + input1_expr + ); + + return std::make_pair(store_result.getAs< clang::Expr >(), false); + } + + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_branch( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_BRANCH) { + assert(false && "Invalid branch operation."); + return std::make_pair(nullptr, false); + } + + assert(op.target_block); + auto iter = basic_block_labels.find(*op.target_block); + assert(iter != basic_block_labels.end()); + + (void) function; + // Create GotoStmt for branch operation + return std::make_pair( + new (ctx) clang::GotoStmt( + iter->second, source_location_from_key(ctx, op.key), + source_location_from_key(ctx, *op.target_block) + ), + false + ); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_cbranch( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_CBRANCH) { + assert(false && "Invalid cbranch operation"); + return std::make_pair(nullptr, false); + } + + // TODO(kumarak): Could there be case where conditional statement is missing?? In + // such case treat it as branch instruction. + auto *condition_expr = create_varnode(ctx, function, *op.condition); + clang::Stmt *taken_stmt = nullptr; + clang::Stmt *not_taken_stmt = nullptr; + + if (op.taken_block && !op.taken_block->empty()) { + auto taken_block_key = *op.taken_block; + auto label_iter = basic_block_labels.find(taken_block_key); + assert(label_iter != basic_block_labels.end()); + + taken_stmt = new (ctx) clang::GotoStmt( + label_iter->second, source_location_from_key(ctx, op.key), + source_location_from_key(ctx, *op.target_block) + ); + } else { + taken_stmt = new (ctx) clang::NullStmt(clang::SourceLocation(), false); + } + + if (op.not_taken_block && !op.not_taken_block->empty()) { + auto not_taken_block_key = *op.not_taken_block; + auto label_iter = basic_block_labels.find(not_taken_block_key); + assert(label_iter != basic_block_labels.end()); + + not_taken_stmt = new (ctx) clang::GotoStmt( + label_iter->second, source_location_from_key(ctx, op.key), + source_location_from_key(ctx, *op.target_block) + ); + } else { + not_taken_stmt = new (ctx) clang::NullStmt(clang::SourceLocation(), false); + } + + return std::make_pair( + clang::IfStmt::Create( + ctx, clang::SourceLocation(), clang::IfStatementKind::Ordinary, nullptr, + nullptr, clang::dyn_cast< clang::Expr >(condition_expr), + clang::SourceLocation(), clang::SourceLocation(), taken_stmt, + clang::SourceLocation(), not_taken_stmt + ), + false + ); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_branchind( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_call( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_CALL) { + assert(false && "Invalid call operation."); + return std::make_pair(nullptr, false); + } + + auto call_target = op.target; + if (!call_target.has_value()) { + return std::make_pair(nullptr, false); + } + + if (!call_target->function && !call_target->operation) { + assert(false); + return std::make_pair(nullptr, false); + } + + clang::Expr *call_expr = nullptr; + + std::vector< clang::Expr * > arguments; + for (const auto &input : op.inputs) { + auto *arg_expr = create_varnode(ctx, function, input); + assert(arg_expr != nullptr); + arguments.push_back(clang::dyn_cast< clang::Expr >(arg_expr)); + } + + if (call_target->function) { + auto iter = function_declarations.find(call_target->function.value()); + if (iter == function_declarations.end()) { + return std::make_pair(nullptr, false); + } + call_expr = create_function_call(ctx, iter->second, arguments); + if (!op.output || iter->second->getReturnType()->isVoidType()) { + return std::make_pair(clang::dyn_cast< clang::Expr >(call_expr), false); + } + + } else if (call_target->operation) { + auto op = operation_from_key(function, call_target->operation.value()); + auto [stmt, _] = create_operation(ctx, function, op.value()); + auto result = get_sema().ActOnCallExpr( + nullptr, clang::dyn_cast< clang::Expr >(stmt), clang::SourceLocation(), + arguments, clang::SourceLocation() + ); + call_expr = result.getAs< clang::Expr >(); + if (!op->output) { + return std::make_pair(clang::dyn_cast< clang::Expr >(call_expr), false); + } + } + + auto *out_expr = create_varnode(ctx, function, *op.output, false); + set_sema_context(ctx.getTranslationUnitDecl()); + + auto rty_type = type_builder->get_serialized_types().at(*op.type); + + auto cast_result = + get_sema().ImpCastExprToType(call_expr, rty_type, clang::CastKind::CK_BitCast); + + auto out_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(out_expr), cast_result.getAs< clang::Expr >() + ); + + return std::make_pair(out_result.getAs< clang::Expr >(), false); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_callind( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_CALLIND) { + assert(false); + return std::make_pair(nullptr, false); + } + + auto call_target = op.target; + if (!call_target.has_value()) { + return std::make_pair(nullptr, false); + } + clang::Expr *call_expr = nullptr; + + std::vector< clang::Expr * > arguments; + for (const auto &input : op.inputs) { + auto *arg_expr = create_varnode(ctx, function, input); + assert(arg_expr != nullptr); + arguments.push_back(clang::dyn_cast< clang::Expr >(arg_expr)); + } + + if (call_target->function) { + auto iter = function_declarations.find(call_target->function.value()); + if (iter == function_declarations.end()) { + return std::make_pair(nullptr, false); + } + call_expr = create_function_call(ctx, iter->second, arguments); + if (!op.output || iter->second->getReturnType()->isVoidType()) { + return std::make_pair(clang::dyn_cast< clang::Expr >(call_expr), false); + } + } else if (call_target->operation) { + auto op = operation_from_key(function, call_target->operation.value()); + auto [stmt, _] = create_operation(ctx, function, op.value()); + auto result = get_sema().ActOnCallExpr( + nullptr, clang::dyn_cast< clang::Expr >(stmt), clang::SourceLocation(), + arguments, clang::SourceLocation() + ); + call_expr = result.getAs< clang::Expr >(); + if (!op->output) { + return std::make_pair(clang::dyn_cast< clang::Expr >(call_expr), false); + } + } + + auto *out_expr = create_varnode(ctx, function, *op.output, /*is_input=*/false); + auto rty_type = type_builder->get_serialized_types().at(*op.type); + + auto cast_result = + get_sema().ImpCastExprToType(call_expr, rty_type, clang::CastKind::CK_BitCast); + + auto out_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(out_expr), cast_result.getAs< clang::Expr >() + ); + + return std::make_pair(out_result.getAs< clang::Expr >(), false); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_userdefined( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_return( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_RETURN) { + assert(false && "Invalid return operation"); + return std::make_pair(nullptr, false); + } + + // Assert if number + // assert(ret_op.inputs.size() < 2); + if (!op.inputs.empty()) { + auto varnode = op.inputs.size() == 1 ? op.inputs.front() : op.inputs.at(1); + auto *ret_expr = create_varnode(ctx, function, varnode); + return std::make_pair( + clang::ReturnStmt::Create( + ctx, clang::SourceLocation(), llvm::dyn_cast< clang::Expr >(ret_expr), + nullptr + ), + false + ); + } + return std::make_pair( + clang::ReturnStmt::Create(ctx, clang::SourceLocation(), nullptr, nullptr), false + ); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_piece( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_PIECE) { + assert(false); + return std::make_pair(nullptr, false); + } + + auto merge_to_next = !op.output.has_value(); + + unsigned low_width = 8U; + + auto *shift_value = clang::IntegerLiteral::Create( + ctx, llvm::APInt(32, low_width), ctx.IntTy, clang::SourceLocation() + ); + + auto *input0_expr = create_varnode(ctx, function, op.inputs[0]); + assert(input0_expr != nullptr); + + auto *input1_expr = create_varnode(ctx, function, op.inputs[1]); + assert(input1_expr != nullptr); + + auto shifted_high_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Shl, + clang::dyn_cast< clang::Expr >(input0_expr), + clang::dyn_cast< clang::Expr >(shift_value) + ); + + if (shifted_high_result.isInvalid()) { + assert(false); + return std::make_pair(nullptr, false); + } + + auto or_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Or, + shifted_high_result.getAs< clang::Expr >(), + clang::dyn_cast< clang::Expr >(input1_expr) + ); + + if (or_result.isInvalid()) { + assert(false); + return std::make_pair(nullptr, false); + } + + if (merge_to_next) { + return std::make_pair(or_result.getAs< clang::Expr >(), merge_to_next); + } + + auto *output_expr = create_varnode(ctx, function, *op.output, /*is_input=*/false); + auto out_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(output_expr), or_result.getAs< clang::Expr >() + ); + + if (out_result.isInvalid()) { + assert(false); + return std::make_pair(nullptr, false); + } + + return std::make_pair(out_result.getAs< clang::Stmt >(), merge_to_next); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_subpiece( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_SUBPIECE) { + assert(false); + return std::make_pair(nullptr, false); + } + + auto merge_to_next = !op.output.has_value(); + assert(op.inputs.size() == 2); + + auto *shift_value = create_varnode(ctx, function, op.inputs[1]); + assert(shift_value != nullptr); + + auto *expr = create_varnode(ctx, function, op.inputs[0]); + assert(expr != nullptr); + + auto *expr_with_paren = new (ctx) clang::ParenExpr( + clang::SourceLocation(), clang::SourceLocation(), + clang::dyn_cast< clang::Expr >(expr) + ); + + auto shifted_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Shr, + clang::dyn_cast< clang::Expr >(expr_with_paren), + clang::dyn_cast< clang::Expr >(shift_value) + ); + + if (shifted_result.isInvalid()) { + assert(false); + return std::make_pair(nullptr, false); + } + + auto *shifted_expr = new (ctx) clang::ParenExpr( + clang::SourceLocation(), clang::SourceLocation(), + shifted_result.getAs< clang::Expr >() + ); + + auto mask_value = llvm::APInt::getAllOnes(32); + auto *mask = + clang::IntegerLiteral::Create(ctx, mask_value, ctx.IntTy, clang::SourceLocation()); + + auto result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_And, shifted_expr, + clang::dyn_cast< clang::Expr >(mask) + ); + + if (result.isInvalid()) { + assert(false); + return std::make_pair(nullptr, false); + } + + auto *result_expr = new (ctx) clang::ParenExpr( + clang::SourceLocation(), clang::SourceLocation(), result.getAs< clang::Expr >() + ); + + if (merge_to_next) { + return std::make_pair(result_expr, merge_to_next); + } + + auto *out_expr = create_varnode(ctx, function, *op.output, /*is_input=*/false); + auto out_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(out_expr), result_expr + ); + + if (out_result.isInvalid()) { + assert(false); + return std::make_pair(nullptr, false); + } + + return std::make_pair(out_result.getAs< clang::Stmt >(), merge_to_next); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_equal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_EQUAL) { + assert(false); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BinaryOperatorKind::BO_EQ >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_notequal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_NOTEQUAL) { + assert(false); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BinaryOperatorKind::BO_NE >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_less( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_LESS) { + assert(false && "Invalid int_less operation"); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_LT >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_sless( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_SLESS) { + assert(false && "Invalid int_sless operation"); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_LT >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_lessequal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_LESSEQUAL) { + assert(false && "Invalid int_lessequal operation"); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_LE >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_slessequal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_SLESSEQUAL) { + assert(false && "Invalid int_slessequal operation"); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_LE >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_zext( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_ZEXT) { + assert(false); + return std::make_pair(nullptr, true); + } + + auto merge_to_next = !op.output.has_value(); + auto *input_expr = create_varnode(ctx, function, op.inputs[0]); + assert(input_expr != nullptr); + + auto target_type = type_builder->get_serialized_types().at(*op.type); + + auto result = get_sema().BuildCStyleCastExpr( + clang::SourceLocation(), ctx.getTrivialTypeSourceInfo(target_type), + clang::SourceLocation(), clang::dyn_cast< clang::Expr >(input_expr) + ); + + if (result.isInvalid()) { + assert(false); + return std::make_pair(nullptr, false); + } + + if (merge_to_next) { + return std::make_pair(result.getAs< clang::Stmt >(), merge_to_next); + } + + auto *out_expr = create_varnode(ctx, function, *op.output, /*is_input=*/false); + auto out_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(out_expr), result.getAs< clang::Expr >() + ); + + return std::make_pair(out_result.getAs< clang::Stmt >(), merge_to_next); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_sext( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_SEXT) { + assert(false); + return std::make_pair(nullptr, true); + } + + auto merge_to_next = !op.output.has_value(); + auto *input_expr = create_varnode(ctx, function, op.inputs[0]); + assert(input_expr != nullptr); + + auto target_type = type_builder->get_serialized_types().at(*op.type); + + auto result = get_sema().BuildCStyleCastExpr( + clang::SourceLocation(), ctx.getTrivialTypeSourceInfo(target_type), + clang::SourceLocation(), clang::dyn_cast< clang::Expr >(input_expr) + ); + + if (result.isInvalid()) { + assert(false); + return std::make_pair(nullptr, false); + } + + if (merge_to_next) { + return std::make_pair(result.getAs< clang::Stmt >(), merge_to_next); + } + + auto *out_expr = create_varnode(ctx, function, *op.output, /*is_input=*/false); + auto out_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(out_expr), result.getAs< clang::Expr >() + ); + + return std::make_pair(out_result.getAs< clang::Stmt >(), merge_to_next); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_add( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_ADD) { + assert(false && "Invalid int_add operation"); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_Add >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_sub( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_SUB) { + assert(false && "Invalid int_add operation"); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_Sub >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_carry( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_scarry( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_sborrow( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_2comp( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + template std::pair< clang::Stmt *, bool > + PcodeASTConsumer::create_unary_operation< clang::UO_LNot >( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + template std::pair< clang::Stmt *, bool > + PcodeASTConsumer::create_binary_operation< clang::BO_Xor >( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + template std::pair< clang::Stmt *, bool > + PcodeASTConsumer::create_binary_operation< clang::BO_And >( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + template std::pair< clang::Stmt *, bool > + PcodeASTConsumer::create_binary_operation< clang::BO_Or >( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + template std::pair< clang::Stmt *, bool > + PcodeASTConsumer::create_binary_operation< clang::BO_Shl >( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + template std::pair< clang::Stmt *, bool > + PcodeASTConsumer::create_binary_operation< clang::BO_Shr >( + clang::ASTContext &ctx, const Function &function, const Operation &op + ); + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_mult( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_MULT) { + assert(false && "Invalid int_add operation"); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_Mul >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_div( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_DIV) { + assert(false && "Invalid int_add operation"); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_Div >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_rem( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_REM) { + assert(false); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_Rem >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int_sdiv( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT_SDIV) { + assert(false); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_Div >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_equal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_notequal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_less( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_lessequal( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_add( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_sub( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_mult( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_div( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_FLOAT_DIV) { + assert(false); + return std::make_pair(nullptr, false); + } + + return create_binary_operation< clang::BO_Div >(ctx, function, op); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_neg( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_abs( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_sqrt( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_ceil( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_floor( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_round( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float_nan( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_int2float( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_INT2FLOAT) { + assert(false && "Invalid int2float operation"); + return std::make_pair(nullptr, false); + } + + auto type_iter = type_builder->get_serialized_types().find(*op.type); + assert(type_iter != type_builder->get_serialized_types().end()); + + auto *lhs = clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs[0])); + auto result = get_sema().BuildCStyleCastExpr( + clang::SourceLocation(), ctx.getTrivialTypeSourceInfo(type_iter->second), + clang::SourceLocation(), lhs + ); + assert(!result.isInvalid() && "Invalid cast expr result"); + + if (!op.output) { + return std::make_pair(result.getAs< clang::Stmt >(), true); + } + + auto *output = clang::dyn_cast< clang::Expr >( + create_varnode(ctx, function, *op.output, /*is_input=*/false) + ); + + auto output_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, output, + result.getAs< clang::Expr >() + ); + assert(!output_result.isInvalid() && "Invalid assignment result"); + + return std::make_pair(output_result.getAs< clang::Expr >(), false); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_float2float( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + (void) ctx, (void) function, (void) op; + return std::make_pair(nullptr, true); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_trunc( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_TRUNC) { + assert(false); + return std::make_pair(nullptr, false); + } + + auto merge_to_next = !op.output.has_value(); + assert(op.inputs.size() == 1u); + + auto type_iter = type_builder->get_serialized_types().find(*op.type); + assert(type_iter != type_builder->get_serialized_types().end()); + + auto *src_expr = create_varnode(ctx, function, op.inputs[0]); + + set_sema_context(ctx.getTranslationUnitDecl()); + auto result = get_sema().ImpCastExprToType( + clang::dyn_cast< clang::Expr >(src_expr), type_iter->second, clang::CK_IntegralCast, + clang::VK_PRValue, nullptr + ); + + if (result.isInvalid()) { + llvm::errs() << "Failed to create operation for trunc\n"; + return std::make_pair(nullptr, true); + } + + if (merge_to_next) { + return std::make_pair(result.getAs< clang::Stmt >(), true); + } + + // If output varnode is avaiable + auto *dest_expr = create_varnode(ctx, function, *op.output, /*is_input=*/false); + auto out_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(dest_expr), result.getAs< clang::Expr >() + ); + + return std::make_pair(out_result.getAs< clang::Stmt >(), merge_to_next); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_ptrsub( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_PTRSUB) { + assert(false && "Invalid PTRSUB operation."); + return std::make_pair(nullptr, false); + } + + auto merge_to_next = !op.output.has_value(); + auto *input0_expr = create_varnode(ctx, function, op.inputs[0]); + auto type_iter = type_builder->get_serialized_types().find(*op.type); + assert(type_iter != type_builder->get_serialized_types().end()); + auto ptr_type = type_iter->second; + + auto *ptr_expr = clang::ImplicitCastExpr::Create( + ctx, ptr_type, clang::CK_BitCast, clang::dyn_cast< clang::Expr >(input0_expr), + nullptr, clang::VK_PRValue, clang::FPOptionsOverride() + ); + + auto *byte_offset = + clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs[1])); + + auto *ptr_add_expr = clang::BinaryOperator::Create( + ctx, ptr_expr, byte_offset, clang::BO_Add, ptr_type, clang::VK_PRValue, + clang::OK_Ordinary, clang::SourceLocation(), clang::FPOptionsOverride() + ); + + auto *result_expr = clang::ImplicitCastExpr::Create( + ctx, ptr_type, clang::CK_BitCast, ptr_add_expr, nullptr, clang::VK_PRValue, + clang::FPOptionsOverride() + ); + + return std::make_pair(result_expr, merge_to_next); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_ptradd( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_PTRADD) { + assert(false); + return std::make_pair(nullptr, true); + } + + auto merge_to_next = !op.output.has_value(); + assert(op.inputs.size() == 3U); + + auto *base = + clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs[0])); + auto *index = + clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs[1])); + auto *scale = + clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs[2])); + + auto mult_result = + get_sema().CreateBuiltinBinOp(clang::SourceLocation(), clang::BO_Mul, index, scale); + assert(!mult_result.isInvalid()); + + auto result = get_sema().CreateBuiltinBinOp( + clang::SourceLocation(), clang::BO_Add, base, mult_result.getAs< clang::Expr >() + ); + assert(!result.isInvalid()); + + if (merge_to_next) { + return std::make_pair(result.getAs< clang::Stmt >(), merge_to_next); + } + + auto *output_stmt = create_varnode(ctx, function, *op.output, /*is_input=*/false); + if (output_stmt->getStmtClass() == clang::Stmt::DeclStmtClass) { + auto *decl = clang::dyn_cast< clang::DeclStmt >(output_stmt)->getSingleDecl(); + auto *ref_expr = clang::DeclRefExpr::Create( + ctx, clang::NestedNameSpecifierLoc(), clang::SourceLocation(), + clang::dyn_cast< clang::VarDecl >(decl), false, clang::SourceLocation(), + clang::dyn_cast< clang::VarDecl >(decl)->getType(), clang::VK_LValue + ); + + auto ref_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(ref_expr), result.getAs< clang::Expr >() + ); + + assert(!ref_result.isInvalid()); + return std::make_pair(ref_result.getAs< clang::Stmt >(), false); + } + + auto output_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(output_stmt), result.getAs< clang::Expr >() + ); + + if (output_result.isInvalid()) { + assert(false && "Invalid result from assignment operation"); + return std::make_pair(nullptr, false); + } + + return std::make_pair(output_result.getAs< clang::Stmt >(), false); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_cast( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_CAST) { + assert(false); + return std::make_pair(nullptr, true); + } + + auto merge_to_next = !op.output.has_value(); + assert(op.inputs.size() == 1U); + + auto type_iter = type_builder->get_serialized_types().find(*op.type); + assert(type_iter != type_builder->get_serialized_types().end()); + + auto *input_expr = create_varnode(ctx, function, op.inputs[0]); + auto result = get_sema().BuildCStyleCastExpr( + clang::SourceLocation(), ctx.getTrivialTypeSourceInfo(type_iter->second), + clang::SourceLocation(), clang::dyn_cast< clang::Expr >(input_expr) + ); + assert(!result.isInvalid()); + + return std::make_pair(result.getAs< clang::Stmt >(), merge_to_next); + } + + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_address_of( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + if (op.mnemonic != Mnemonic::OP_ADDRESS_OF) { + assert(false); + return std::make_pair(nullptr, false); + } + + auto merge_to_next = !op.output.has_value(); + auto *input_expr = create_varnode(ctx, function, op.inputs[0]); + assert(input_expr != nullptr); + + auto result = get_sema().CreateBuiltinUnaryOp( + clang::SourceLocation(), clang::UO_AddrOf, + clang::dyn_cast< clang::Expr >(input_expr) + ); + assert(!result.isInvalid()); + + if (merge_to_next) { + return std::make_pair(result.getAs< clang::Expr >(), merge_to_next); + } + + auto *output_expr = create_varnode(ctx, function, *op.output, /*is_input=*/false); + + auto *result_expr = result.getAs< clang::Expr >(); + auto output_type = clang::dyn_cast< clang::Expr >(output_expr)->getType(); + + if (result_expr->getType() != output_type) { + auto cast_result = get_sema().BuildCStyleCastExpr( + clang::SourceLocation(), ctx.getTrivialTypeSourceInfo(output_type), + clang::SourceLocation(), result_expr + ); + + assert(!cast_result.isInvalid() && "Invalid cstyle cast to output expr"); + result_expr = cast_result.getAs< clang::Expr >(); + } + + auto output_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(output_expr), result_expr + ); + assert(!output_result.isInvalid()); + + return std::make_pair(output_result.getAs< clang::Expr >(), merge_to_next); + } + + template< clang::BinaryOperatorKind Kind > + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_binary_operation( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + auto merge_to_next = !op.output.has_value(); + assert(op.inputs.size() == 2 && "Insufficient input operators"); + + auto *lhs = clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs[0])); + + auto *rhs = clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs[1])); + + auto result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), Kind, clang::dyn_cast< clang::Expr >(lhs), + clang::dyn_cast< clang::Expr >(rhs) + ); + + assert(!result.isInvalid() && "Invalid result from binary operation"); + + if (merge_to_next) { + return std::make_pair(result.getAs< clang::Stmt >(), merge_to_next); + } + + auto *output_expr = clang::dyn_cast< clang::Expr >( + create_varnode(ctx, function, *op.output, /*is_input=*/false) + ); + + auto output_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, + clang::dyn_cast< clang::Expr >(output_expr), result.getAs< clang::Expr >() + ); + assert( + !output_result.isInvalid() && "Invalid assignment operation after binary operator" + ); + + return std::make_pair(output_result.getAs< clang::Expr >(), merge_to_next); + } + + template< clang::UnaryOperatorKind Kind > + std::pair< clang::Stmt *, bool > PcodeASTConsumer::create_unary_operation( + clang::ASTContext &ctx, const Function &function, const Operation &op + ) { + // If output varnode is emptry, the stmt will get merged to next operation. + auto merge_to_next = !op.output.has_value(); + auto *input = + clang::dyn_cast< clang::Expr >(create_varnode(ctx, function, op.inputs[0])); + + auto result = get_sema().CreateBuiltinUnaryOp(clang::SourceLocation(), Kind, input); + assert(!result.isInvalid() && "Invalid unary operation"); + + if (merge_to_next) { + // merge to next operation + return std::make_pair(result.getAs< clang::Stmt >(), merge_to_next); + } + + auto *output = clang::dyn_cast< clang::Expr >( + create_varnode(ctx, function, *op.output, /*is_input=*/false) + ); + + auto *result_expr = result.getAs< clang::Expr >(); + if (result_expr->getType() != output->getType()) { + auto cast_result = get_sema().BuildCStyleCastExpr( + clang::SourceLocation(), ctx.getTrivialTypeSourceInfo(output->getType()), + clang::SourceLocation(), result_expr + ); + + assert(!cast_result.isInvalid() && "Invalid cstyle cast to output expr"); + result_expr = cast_result.getAs< clang::Expr >(); + } + + auto output_result = get_sema().CreateBuiltinBinOp( + source_location_from_key(ctx, op.key), clang::BO_Assign, output, result_expr + ); + assert(!output_result.isInvalid()); + + return std::make_pair(output_result.getAs< clang::Expr >(), false); + } + +} // namespace patchestry::ast diff --git a/lib/patchestry/AST/TypeBuilder.cpp b/lib/patchestry/AST/TypeBuilder.cpp new file mode 100644 index 0000000..0fb0908 --- /dev/null +++ b/lib/patchestry/AST/TypeBuilder.cpp @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include +#include +#include + +namespace patchestry::ast { + + void TypeBuilder::create_types(clang::ASTContext &ctx, TypeMap &lifted_types) { + for (auto &[key, vnode_type] : lifted_types) { + serialized_types.emplace(key, create_type(ctx, vnode_type)); + } + + // Traverse through missing_type_definition list and complete definitions + for (auto &[key, decl] : missing_type_definition) { + if (const auto *record_decl = llvm::dyn_cast< clang::RecordDecl >(decl)) { + auto iter = lifted_types.find(key); + if (iter == lifted_types.end()) { + llvm::errs() << "Key not found in type map\n"; + assert(false); + continue; + } + auto vnode_type = iter->second; + create_record_definition( + ctx, dynamic_cast< CompositeType & >(*vnode_type), decl, serialized_types + ); + } + } + } + + clang::QualType TypeBuilder::create_type( + clang::ASTContext &ctx, const std::shared_ptr< VarnodeType > &vnode_type + ) { + auto type_iter = serialized_types.find(vnode_type->key); + if (type_iter != serialized_types.end()) { + return type_iter->second; + } + + switch (vnode_type->kind) { + case VarnodeType::VT_INVALID: + return ctx.CharTy; + case VarnodeType::VT_BOOLEAN: + return ctx.BoolTy; + case VarnodeType::VT_INTEGER: + return ctx.IntTy; + case VarnodeType::VT_CHAR: + return ctx.CharTy; + case VarnodeType::VT_FLOAT: + return ctx.FloatTy; + case VarnodeType::VT_ARRAY: + return create_array_type(ctx, dynamic_cast< const ArrayType & >(*vnode_type)); + case VarnodeType::VT_POINTER: + return create_pointer_type( + ctx, dynamic_cast< const PointerType & >(*vnode_type) + ); + case VarnodeType::Kind::VT_FUNCTION: + return ctx.VoidPtrTy; + case VarnodeType::VT_STRUCT: + case VarnodeType::VT_UNION: + return create_composite_type(ctx, *vnode_type); + case VarnodeType::VT_ENUM: + return create_enum_type(ctx, dynamic_cast< const EnumType & >(*vnode_type)); + case VarnodeType::VT_TYPEDEF: + return create_typedef_type( + ctx, dynamic_cast< const TypedefType & >(*vnode_type) + ); + case VarnodeType::VT_UNDEFINED: + return create_undefined_type( + ctx, dynamic_cast< const UndefinedType & >(*vnode_type) + ); + case VarnodeType::VT_VOID: { + return ctx.VoidTy; + } + } + } + + clang::QualType + TypeBuilder::create_typedef_type(clang::ASTContext &ctx, const TypedefType &typedef_type) { + auto &identifier = ctx.Idents.get(typedef_type.name); + auto base_type = typedef_type.get_base_type(); + if (!base_type) { + llvm::errs() << "Base Type of a typedef shouldn't be empty. key: " + << typedef_type.key << "\n"; + assert(false); + return clang::QualType(); + } + + if (base_type->key == typedef_type.key) { + llvm::errs() << "Base Type of typedef is pointing to itself. key: " + << typedef_type.key << "\n"; + assert(false); + return clang::QualType(); + } + + auto underlying_type = create_type(ctx, base_type); + serialized_types.emplace(base_type->key, underlying_type); + auto *tinfo = ctx.getTrivialTypeSourceInfo(underlying_type); + auto *typedef_decl = clang::TypedefDecl::Create( + ctx, ctx.getTranslationUnitDecl(), clang::SourceLocation(), clang::SourceLocation(), + &identifier, tinfo + ); + + typedef_decl->setDeclContext(ctx.getTranslationUnitDecl()); + ctx.getTranslationUnitDecl()->addDecl(typedef_decl); + + return ctx.getTypedefType(typedef_decl); + } + + clang::QualType + TypeBuilder::create_pointer_type(clang::ASTContext &ctx, const PointerType &pointer_type) { + auto pointee = pointer_type.get_pointee_type(); + if (!pointee) { + llvm::errs() << "No pointee type in pointer with key " << pointer_type.key << "\n"; + assert(false); + return ctx.VoidPtrTy; + } + if (pointee->key == pointer_type.key) { + llvm::errs() << "Pointer type shouldn't have itself as pointee. key: " + << pointer_type.key << "\n"; + assert(false); + return clang::QualType(); + } + + auto pointee_type = create_type(ctx, pointee); + serialized_types.emplace(pointee->key, pointee_type); + return ctx.getPointerType(pointee_type); + } + + clang::QualType + TypeBuilder::create_array_type(clang::ASTContext &ctx, const ArrayType &array_type) { + auto element = array_type.get_element_type(); + if (!element) { + llvm::errs() << "No element types for array\n"; + assert(false); + return clang::QualType(); + } + + // If element key is same as array_type key, it will lead to infinite recursive. If it + // happens something is wrong and need to check ghidra scripts + if (element->key != array_type.key) { + auto element_type = create_type(ctx, element); + serialized_types.emplace(element->key, element_type); + auto size = array_type.get_element_count(); + auto num_bits = 32u; + return ctx.getConstantArrayType( + element_type, llvm::APInt(num_bits, size), nullptr, + clang::ArraySizeModifier::Normal, 0 + ); + } + assert(false); + return clang::QualType(); + } + + void TypeBuilder::create_record_definition( + clang::ASTContext &ctx, const CompositeType &varnode, clang::Decl *prev_decl, + const ASTTypeMap &clang_types + ) { + auto &identifier = ctx.Idents.get(varnode.name); + auto *record_decl = clang::RecordDecl::Create( + ctx, clang::TagDecl::TagKind::Struct, ctx.getTranslationUnitDecl(), + source_location_from_key(ctx, varnode.key), + source_location_from_key(ctx, varnode.key), &identifier, + llvm::dyn_cast< clang::RecordDecl >(prev_decl) + ); + + record_decl->completeDefinition(); + auto components = varnode.get_components(); + for (auto &comp : components) { + auto type_key = comp.type->key; + auto iter = clang_types.find(type_key); + if (iter == clang_types.end()) { + assert(false); + continue; + } + + auto field_type = iter->second; + auto *field_decl = clang::FieldDecl::Create( + ctx, record_decl, clang::SourceLocation(), clang::SourceLocation(), + &ctx.Idents.get(comp.name), field_type, nullptr, nullptr, false, + clang::ICIS_NoInit + ); + record_decl->addDecl(field_decl); + } + + record_decl->setDeclContext(ctx.getTranslationUnitDecl()); + ctx.getTranslationUnitDecl()->addDecl(record_decl); + } + + clang::QualType TypeBuilder::create_composite_type( + clang::ASTContext &ctx, const VarnodeType &composite_type + ) { + auto tag_kind = [&]() -> clang::TagDecl::TagKind { + switch (composite_type.kind) { + case VarnodeType::Kind::VT_STRUCT: + return clang::TagDecl::TagKind::Struct; + case VarnodeType::Kind::VT_UNION: + return clang::TagDecl::TagKind::Union; + default: + assert(false); + return clang::TagDecl::TagKind::Struct; + } + }(); + auto *decl = clang::RecordDecl::Create( + ctx, tag_kind, ctx.getTranslationUnitDecl(), + source_location_from_key(ctx, composite_type.key), + source_location_from_key(ctx, composite_type.key), + &ctx.Idents.get(composite_type.name) + ); + + decl->setDeclContext(ctx.getTranslationUnitDecl()); + ctx.getTranslationUnitDecl()->addDecl(decl); + missing_type_definition.emplace(composite_type.key, decl); + return ctx.getRecordType(decl); + } + + clang::QualType + TypeBuilder::create_enum_type(clang::ASTContext &ctx, const EnumType &enum_type) { + auto &identifier = ctx.Idents.get(enum_type.name); + auto *enum_decl = clang::EnumDecl::Create( + ctx, ctx.getTranslationUnitDecl(), source_location_from_key(ctx, enum_type.key), + source_location_from_key(ctx, enum_type.key), &identifier, nullptr, true, false, + false + ); + + enum_decl->setDeclContext(ctx.getTranslationUnitDecl()); + ctx.getTranslationUnitDecl()->addDecl(enum_decl); + return ctx.getEnumType(enum_decl); + } + + clang::QualType TypeBuilder::create_undefined_type( + clang::ASTContext &ctx, const UndefinedType &undefined_type + ) { + if (undefined_type.kind != VarnodeType::Kind::VT_UNDEFINED) { + assert(false); + return clang::QualType(); + } + + auto base_type = get_type_for_size( + ctx, undefined_type.size * 8, /*is_signed=*/false, /*is_integer=*/true + ); + + if (base_type.isNull()) { + base_type = ctx.IntTy; + } + + auto *typedef_decl = clang::TypedefDecl::Create( + ctx, ctx.getTranslationUnitDecl(), clang::SourceLocation(), clang::SourceLocation(), + &ctx.Idents.get(undefined_type.name), ctx.getTrivialTypeSourceInfo(base_type) + ); + typedef_decl->setDeclContext(ctx.getTranslationUnitDecl()); + ctx.getTranslationUnitDecl()->addDecl(typedef_decl); + return ctx.getTypedefType(typedef_decl); + } + +} // namespace patchestry::ast diff --git a/lib/patchestry/AST/Utils.cpp b/lib/patchestry/AST/Utils.cpp new file mode 100644 index 0000000..e9b7ae3 --- /dev/null +++ b/lib/patchestry/AST/Utils.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +namespace patchestry::ast { + + clang::SourceLocation source_location_from_key(clang::ASTContext &ctx, std::string key) { + (void) key; + (void)ctx; + return clang::SourceLocation(); + } + + clang::QualType get_type_for_size( + clang::ASTContext &ctx, unsigned bit_size, bool is_signed, bool is_integer + ) { + if (is_integer) { + return ctx.getIntTypeForBitwidth(bit_size, static_cast< unsigned int >(is_signed)); + } + + switch (bit_size) { + case 32: + return ctx.FloatTy; + case 64: + return ctx.DoubleTy; + case 80: + return ctx.LongDoubleTy; + default: + assert(false); + return clang::QualType(); + } + } + + std::string label_name_from_key(std::string key) { + std::replace(key.begin(), key.end(), ':', '_'); + return key; + } + +} // namespace patchestry::ast diff --git a/lib/patchestry/CMakeLists.txt b/lib/patchestry/CMakeLists.txt index 3a62ba1..d5a7e3a 100644 --- a/lib/patchestry/CMakeLists.txt +++ b/lib/patchestry/CMakeLists.txt @@ -1,6 +1,8 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. +add_subdirectory(AST) add_subdirectory(Dialect) add_subdirectory(Ghidra) diff --git a/lib/patchestry/Dialect/CMakeLists.txt b/lib/patchestry/Dialect/CMakeLists.txt index 52868d9..f89e135 100644 --- a/lib/patchestry/Dialect/CMakeLists.txt +++ b/lib/patchestry/Dialect/CMakeLists.txt @@ -1,5 +1,6 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. add_subdirectory(Pcode) diff --git a/lib/patchestry/Dialect/Pcode/CMakeLists.txt b/lib/patchestry/Dialect/Pcode/CMakeLists.txt index 14f6768..470e39d 100644 --- a/lib/patchestry/Dialect/Pcode/CMakeLists.txt +++ b/lib/patchestry/Dialect/Pcode/CMakeLists.txt @@ -1,6 +1,7 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. add_mlir_dialect_library(MLIRPcode PcodeDialect.cpp diff --git a/lib/patchestry/Dialect/Pcode/Deserialize.cpp b/lib/patchestry/Dialect/Pcode/Deserialize.cpp index 84535fb..4a497bc 100644 --- a/lib/patchestry/Dialect/Pcode/Deserialize.cpp +++ b/lib/patchestry/Dialect/Pcode/Deserialize.cpp @@ -1,79 +1,286 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. */ +#include "patchestry/Util/Common.hpp" +#include +#include + #include +#include +#include #include +#include namespace patchestry::pc { + mlir_value + create_bitcast_op(mlir_builder &bld, mlir_value &input_val, mlir_type &output_type) { + return bld.create< mlir::arith::BitcastOp >( + bld.getUnknownLoc(), output_type, input_val + ); + } + + mlir_value + create_truc_op(mlir_builder &bld, mlir_value &input_val, mlir_type &output_type) { + return bld.create< mlir::arith::TruncIOp >(bld.getUnknownLoc(), output_type, input_val); + } + + std::optional< program > json_parser::parse_program(const llvm::json::Object &root) { + program program; + program.arch = root.getString("arch").value_or(""); + program.os = root.getString("os").value_or(""); + + if (const auto *function_array = root.getArray("functions")) { + for (const auto &function : *function_array) { + if (const auto *func_obj = function.getAsObject()) { + if (auto parsed_func = parse_function(*func_obj)) { + program.functions.push_back(*parsed_func); + } + } + } + } + + return program; + } + + std::optional< pcode > json_parser::parse_pcode(const llvm::json::Object &pcode_obj) { + pcode pcode; + pcode.mnemonic = pcode_obj.getString("mnemonic").value_or(""); + + if (const auto *output_obj = pcode_obj.getObject("output")) { + pcode.output.type = output_obj->getString("type").value_or(""); + pcode.output.offset = output_obj->getInteger("offset"); + pcode.output.size = output_obj->getInteger("size"); + } + + if (const auto *inputs_array = pcode_obj.getArray("inputs")) { + for (const auto &input : *inputs_array) { + if (const auto *input_obj = input.getAsObject()) { + pcode::input input; + input.type = input_obj->getString("type").value_or(""); + input.offset = input_obj->getInteger("offset"); + input.size = input_obj->getInteger("size"); + pcode.inputs.push_back(input); + } + } + } + + return pcode; + } + + std::optional< instruction > + json_parser::parse_instruction(const llvm::json::Object &inst_obj) { + instruction inst; + inst.mnemonic = inst_obj.getString("mnemonic").value_or(""); + inst.address = inst_obj.getString("address").value_or(""); + + if (const auto *pcode_array = inst_obj.getArray("pcode")) { + for (const auto &pcode : *pcode_array) { + if (const auto *pcode_obj = pcode.getAsObject()) { + if (auto parsed_pcode = parse_pcode(*pcode_obj)) { + inst.pcodes.push_back(*parsed_pcode); + } + } + } + } + + return inst; + } + + std::optional< basic_block > + json_parser::parse_basic_block(const llvm::json::Object &block_obj) { + basic_block block; + block.label = block_obj.getString("label").value_or(""); + + if (const auto *instructions_array = block_obj.getArray("instructions")) { + for (const auto &instruction : *instructions_array) { + if (const auto *inst_obj = instruction.getAsObject()) { + if (const auto parsed_inst = parse_instruction(*inst_obj)) { + block.instructions.push_back(*parsed_inst); + } + } + } + } + + return block; + } + + std::optional< function > json_parser::parse_function(const llvm::json::Object &func_obj) { + function func; + func.name = func_obj.getString("name").value_or(""); + + if (const auto *blocks_array = func_obj.getArray("basic_blocks")) { + for (const auto &block : *blocks_array) { + if (const auto *block_obj = block.getAsObject()) { + if (auto parsed_block = parse_basic_block(*block_obj)) { + func.basic_blocks.push_back(*parsed_block); + } + } + } + } + + return func; + } + mlir::OwningOpRef< mlir::ModuleOp > deserialize(const json_obj &json, mcontext_t *mctx) { // FIXME: use implicit module creation auto loc = mlir::UnknownLoc::get(mctx); auto mod = mlir::OwningOpRef< mlir::ModuleOp >(mlir::ModuleOp::create(loc)); deserializer des(mod.get()); - des.process(json); + auto program = json_parser().parse_program(json); + if (program.has_value()) { + des.process(program.value()); + } else { + mlir::emitError(loc, "Failed to parse JSON object."); + } return mod; } - void deserializer::process(const json_obj &json) { - // FIXME: implement multi-function support - process_function(json); + mlir_operation deserializer::create_int_const(uint32_t offset, uint32_t size) { + auto const_type = mlir::IntegerType::get(bld.getContext(), size * 8); + auto const_attr = mlir::IntegerAttr::get(const_type, offset); + return bld.create< ConstOp >(bld.getUnknownLoc(), const_attr); + } + + mlir_operation + deserializer::create_varnode(std::string type, uint32_t offset, uint32_t size) { + auto varnode_type = varnode_from_string(type); + switch (varnode_type) { + case PCodeVarnodeType::unique_: { + auto mlir_type = bld.getType< VarType >(); + return bld.create< VarOp >(bld.getUnknownLoc(), mlir_type, type, offset, size); + } + case PCodeVarnodeType::const_: { + return bld.create< ConstOp >( + bld.getUnknownLoc(), + mlir::IntegerAttr::get( + mlir::IntegerType::get(bld.getContext(), size * 8), offset + ) + ); + } + case PCodeVarnodeType::register_: { + auto mlir_type = bld.getType< RegType >(); + auto int_type = bld.getI32Type(); + return bld.create< RegOp >(bld.getUnknownLoc(), int_type, type, offset, size); + } + case PCodeVarnodeType::ram_: { + auto mlir_type = bld.getType< MemType >(); + return bld.create< RegOp >(bld.getUnknownLoc(), mlir_type, type, offset, size); + } + default: + break; + } + return {}; + } + + void deserializer::process(const program &prog) { + if (prog.functions.empty()) { + mlir::emitError(bld.getUnknownLoc(), "No function to process!"); + return; + } + + for (const auto &func : prog.functions) { + process_function(func); + } } - void deserializer::process_function(const json_obj &json) { - if (!json.getString("name")) { - mlir::emitError(bld.getUnknownLoc(), "Function JSON missing 'name' field."); + void deserializer::process_function(const function &func) { + if (func.name.empty()) { + mlir::emitError(bld.getUnknownLoc(), "Function name is missing."); return; } - auto _ = insertion_guard(bld); - auto fn = bld.create< pc::FuncOp >( - bld.getUnknownLoc(), - json.getString("name").value() - ); + auto _ = insertion_guard(bld); + auto fn = bld.create< pc::FuncOp >(bld.getUnknownLoc(), func.name); bld.setInsertionPointToStart(bld.createBlock(&fn.getBlocks())); - if (auto blocks = json.getArray("basic_blocks")) { - for (const auto &block : *blocks) { - process_block(*block.getAsObject()); - } + for (const auto &block : func.basic_blocks) { + process_block(block); } } - void deserializer::process_block(const json_obj &json) { - if (!json.getString("label")) { - mlir::emitError(bld.getUnknownLoc(), "Block JSON missing 'label' field."); + void deserializer::process_block(const basic_block &block) { + if (block.label.empty()) { + mlir::emitError(bld.getUnknownLoc(), "Basic block is missing label name."); return; } - auto _ = insertion_guard(bld); - auto block = bld.create< pc::BlockOp >( - bld.getUnknownLoc(), - json.getString("label").value() - ); + auto _ = insertion_guard(bld); + auto mlir_block = bld.create< pc::BlockOp >(bld.getUnknownLoc(), block.label); + + bld.createBlock(&mlir_block.getInstructions()); + if (block.instructions.empty()) { + mlir::emitError(bld.getUnknownLoc(), "Block does not have instruction."); + return; + } + + for (const auto &inst : block.instructions) { + process_instruction(inst); + } + } + + void deserializer::process_instruction(const instruction &inst) { + if (inst.mnemonic.empty()) { + mlir::emitError(bld.getUnknownLoc(), "Instruction mnemonic is missing."); + return; + } - bld.createBlock(&block.getInstructions()); + auto _ = insertion_guard(bld); + auto block = bld.create< pc::InstOp >(bld.getUnknownLoc(), inst.mnemonic); - const auto *insts = json.getArray("instructions"); - if (insts == nullptr) { - mlir::emitError(bld.getUnknownLoc(), "Block JSON missing 'instructions' field."); + bld.createBlock(&block.getSemantics()); + if (inst.pcodes.empty()) { + mlir::emitError(bld.getUnknownLoc(), "Instruction has no pcode"); return; } - for (const auto &inst : *insts) { - process_instruction(*inst.getAsObject()); + for (const auto &pcode : inst.pcodes) { + process_pcode(pcode); } } - void deserializer::process_instruction(const json_obj &json) { + void deserializer::process_pcode(const pcode &code) { + if (code.mnemonic.empty()) { + mlir::emitError(bld.getUnknownLoc(), "Pcode mnemonic is missing."); + return; + } + + switch (from_string(code.mnemonic)) { + case PCodeMnemonic::COPY: { + const auto &output = code.output; + const auto &input0 = code.inputs.front(); + + auto *output_op = + create_varnode(output.type, output.offset.value(), output.size.value()); + auto *input_op = + create_varnode(input0.type, input0.offset.value(), input0.size.value()); + mlir::Type var_type = bld.getI32Type(); + mlir::Value var_result = + bld.create< VarOp >(bld.getUnknownLoc(), var_type, "input", 8, 8) + .getResult(); + bld.create< CopyOp >(bld.getUnknownLoc(), bld.getI32Type(), var_result); + break; + } + case PCodeMnemonic::LOAD: { + break; + } + case PCodeMnemonic::RETURN: { + const auto &input0 = code.inputs.front(); + auto *input_op = + create_varnode(input0.type, input0.offset.value(), input0.size.value()); + bld.create< ReturnOp >(bld.getUnknownLoc(), input_op->getResult(0)); + break; + } + default: + break; + } } } // namespace patchestry::pc diff --git a/lib/patchestry/Dialect/Pcode/PcodeDialect.cpp b/lib/patchestry/Dialect/Pcode/PcodeDialect.cpp index d736812..1b74345 100644 --- a/lib/patchestry/Dialect/Pcode/PcodeDialect.cpp +++ b/lib/patchestry/Dialect/Pcode/PcodeDialect.cpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. diff --git a/lib/patchestry/Dialect/Pcode/PcodeOps.cpp b/lib/patchestry/Dialect/Pcode/PcodeOps.cpp index 03584e2..4426d05 100644 --- a/lib/patchestry/Dialect/Pcode/PcodeOps.cpp +++ b/lib/patchestry/Dialect/Pcode/PcodeOps.cpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. diff --git a/lib/patchestry/Dialect/Pcode/PcodeTypes.cpp b/lib/patchestry/Dialect/Pcode/PcodeTypes.cpp index 5f734e9..3267ea5 100644 --- a/lib/patchestry/Dialect/Pcode/PcodeTypes.cpp +++ b/lib/patchestry/Dialect/Pcode/PcodeTypes.cpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. diff --git a/lib/patchestry/Ghidra/CMakeLists.txt b/lib/patchestry/Ghidra/CMakeLists.txt index 07c4f5a..1793aa9 100644 --- a/lib/patchestry/Ghidra/CMakeLists.txt +++ b/lib/patchestry/Ghidra/CMakeLists.txt @@ -1,9 +1,11 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. add_library(patchestry_ghidra STATIC PcodeTranslation.cpp + JsonDeserialize.cpp ) add_library(patchestry::ghidra ALIAS patchestry_ghidra) diff --git a/lib/patchestry/Ghidra/JsonDeserialize.cpp b/lib/patchestry/Ghidra/JsonDeserialize.cpp new file mode 100644 index 0000000..5a2d7ad --- /dev/null +++ b/lib/patchestry/Ghidra/JsonDeserialize.cpp @@ -0,0 +1,603 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include "patchestry/Ghidra/Pcode.hpp" +#include "patchestry/Ghidra/PcodeOperations.hpp" +#include +#include +#include + +#include +#include + +#include + +namespace patchestry::ghidra { + + template< typename ObjectType > + constexpr std::optional< std::string > + get_string_if_valid(ObjectType &obj, const char *field) { + if (auto value = (obj.getString)(field)) { + if (!value->empty()) { + return value->str(); + } + } + return std::nullopt; + } + + std::optional< Program > JsonParser::deserialize_program(const JsonObject &root) { + Program program; + program.arch = root.getString("arch").value_or(""); + program.format = root.getString("format").value_or(""); + + // Check if root object has types array; if yes then deserialize types + if (const auto *types_array = root.getObject("types")) { + deserialize_types(*types_array, program.serialized_types); + } + + llvm::outs() << "No of types recovered: " << program.serialized_types.size() << "\n"; + + if (const auto *function_array = root.getObject("functions")) { + deserialize_functions(*function_array, program.serialized_functions); + } + + llvm::outs() << "No of functions recovered: " << program.serialized_functions.size() + << "\n"; + + if (const auto *globals = root.getObject("globals")) { + deserialize_globals(*globals, program.serialized_globals); + } + + llvm::outs() << "No of globals recovered: " << program.serialized_globals.size() + << "\n"; + + return program; + } + + // Create varnode type from the json object + std::shared_ptr< VarnodeType > JsonParser::create_vnode_type(const JsonObject &type_obj) { + auto name = type_obj.getString("name").value_or("").str(); + auto size = static_cast< uint32_t >(type_obj.getInteger("size").value_or(0)); + auto kind = VarnodeType::convertToKind(type_obj.getString("kind").value_or("").str()); + switch (kind) { + case VarnodeType::Kind::VT_INVALID: { + // assert(false); // assert if invalid type is found + return std::make_shared< VarnodeType >(name, kind, size); + } + case VarnodeType::Kind::VT_BOOLEAN: + case VarnodeType::Kind::VT_INTEGER: + case VarnodeType::Kind::VT_FLOAT: + case VarnodeType::Kind::VT_CHAR: + return std::make_shared< BuiltinType >(name, kind, size); + case VarnodeType::Kind::VT_ARRAY: + return std::make_shared< ArrayType >(name, kind, size); + case VarnodeType::Kind::VT_POINTER: + return std::make_shared< PointerType >(name, kind, size); + case VarnodeType::Kind::VT_FUNCTION: + return std::make_shared< FunctionType >(name, kind, size); + case VarnodeType::Kind::VT_STRUCT: + case VarnodeType::Kind::VT_UNION: + return std::make_shared< CompositeType >(name, kind, size); + case VarnodeType::Kind::VT_ENUM: + return std::make_shared< EnumType >(name, kind, size); + case VarnodeType::VT_TYPEDEF: + return std::make_shared< TypedefType >(name, kind, size); + case VarnodeType::VT_UNDEFINED: + return std::make_shared< UndefinedType >(name, kind, size); + case VarnodeType::Kind::VT_VOID: + return std::make_shared< BuiltinType >(name, kind, size); + } + } + + // Deserialize types + void JsonParser::deserialize_types(const JsonObject &type_obj, TypeMap &serialized_types) { + if (type_obj.size() == 0) { + llvm::errs() << "No type objects to deserialize\n"; + return; + } + + std::unordered_map< std::string, const JsonValue & > types_value_map; + + for (const auto &type : type_obj) { + auto key = type.getFirst().str(); + const auto &value = type.getSecond(); + auto vnode_type = create_vnode_type(*value.getAsObject()); + if (!vnode_type) { + llvm::errs() << "Failed to create varnode type\n"; + assert(false); + continue; + } + + // Set type key for map lookup at later point + vnode_type->set_key(key); + serialized_types.emplace(key, std::move(vnode_type)); + types_value_map.emplace(key, value); + } + + llvm::errs() << "Number of entry in serialized types: " << serialized_types.size() + << "\n"; + // Post process varnodes from the map and resolve recursive references of labels + for (const auto &[key, vnode_type] : serialized_types) { + auto iter = types_value_map.find(key); + if (iter == types_value_map.end()) { + assert(false); + continue; + } + const auto &json_value = iter->second; + switch (vnode_type->kind) { + case VarnodeType::Kind::VT_BOOLEAN: + case VarnodeType::Kind::VT_INTEGER: + case VarnodeType::Kind::VT_FLOAT: + case VarnodeType::Kind::VT_CHAR: + case VarnodeType::Kind::VT_VOID: + deserialize_buildin( + *dynamic_cast< BuiltinType * >(vnode_type.get()), + *json_value.getAsObject(), serialized_types + ); + break; + case VarnodeType::Kind::VT_ARRAY: { + deserialize_array( + *dynamic_cast< ArrayType * >(vnode_type.get()), + json_value.getAsObject(), serialized_types + ); + break; + } + case VarnodeType::Kind::VT_POINTER: { + deserialize_pointer( + *dynamic_cast< PointerType * >(vnode_type.get()), + *json_value.getAsObject(), serialized_types + ); + break; + } + case VarnodeType::Kind::VT_FUNCTION: { + deserialize_function_type( + *dynamic_cast< FunctionType * >(vnode_type.get()), + *json_value.getAsObject(), serialized_types + ); + break; + } + case VarnodeType::Kind::VT_STRUCT: + case VarnodeType::Kind::VT_UNION: { + deserialize_composite( + *dynamic_cast< CompositeType * >(vnode_type.get()), + *json_value.getAsObject(), serialized_types + ); + break; + } + case VarnodeType::Kind::VT_ENUM: { + deserialize_enum( + *dynamic_cast< EnumType * >(vnode_type.get()), + *json_value.getAsObject(), serialized_types + ); + break; + } + case VarnodeType::Kind::VT_TYPEDEF: { + deserialize_typedef( + *dynamic_cast< TypedefType * >(vnode_type.get()), + *json_value.getAsObject(), serialized_types + ); + break; + } + case VarnodeType::Kind::VT_UNDEFINED: + deserialize_undefined_type( + *dynamic_cast< UndefinedType * >(vnode_type.get()), + *json_value.getAsObject(), serialized_types + ); + break; + case VarnodeType::Kind::VT_INVALID: + break; + } + } + } + + void + JsonParser::deserialize_buildin(BuiltinType &varnode, const JsonObject &, const TypeMap &) { + assert( + varnode.kind == VarnodeType::Kind::VT_BOOLEAN + || varnode.kind == VarnodeType::Kind::VT_INTEGER + || varnode.kind == VarnodeType::Kind::VT_CHAR + || varnode.kind == VarnodeType::Kind::VT_FLOAT + || varnode.kind == VarnodeType::Kind::VT_VOID + ); + (void) varnode; + } + + void JsonParser::deserialize_array( + ArrayType &varnode, const JsonObject *array_obj, const TypeMap &serialized_types + ) { + auto element_label = array_obj->getString("element_type").value_or("").str(); + if (element_label.empty()) { + llvm::errs() << "Element type of an array is empty. key: " << varnode.key << "\n"; + assert(false); + return; + } + + auto iter = serialized_types.find(element_label); + if (iter == serialized_types.end()) { + llvm::errs() << "Element type key " << element_label + << " not found in serialized types." + << " deserializing array with key " << varnode.key << "\n"; + assert(false); + return; + } + + varnode.set_element_type(iter->second); + auto num_elem = + static_cast< uint32_t >(array_obj->getInteger("num_elements").value_or(0)); + varnode.set_element_count(num_elem); + } + + void JsonParser::deserialize_pointer( + PointerType &varnode, const JsonObject &pointer_obj, const TypeMap &serialized_types + ) { + auto pointee_key = pointer_obj.getString("element_type").value_or("").str(); + if (pointee_key.empty()) { + llvm::errs() << "Pointer type with empty pointee key. pointer key: " << varnode.key + << "\n"; + assert(false); + return; + } + + // Check for the pointee label in serialized types + auto iter = serialized_types.find(pointee_key); + if (iter == serialized_types.end()) { + llvm::errs() << "Pointee type is not availe in serialized types. Pointer key: " + << varnode.key << "\n"; + assert(false); + return; + } + + varnode.set_pointee_type(iter->second); + } + + void JsonParser::deserialize_typedef( + TypedefType &varnode, const JsonObject &typedef_obj, const TypeMap &serialized_types + ) { + auto base_key = typedef_obj.getString("base_type").value_or("").str(); + if (base_key.empty()) { + llvm::errs() << "Base type for the tyepdef is not set. key: " << varnode.key + << "\n"; + assert(false); + return; + } + + auto iter = serialized_types.find(base_key); + if (iter == serialized_types.end()) { + llvm::errs() << "Base type key is not found in serialized types " << base_key + << " for typedef key " << varnode.key << "\n"; + assert(false); + return; + } + + varnode.set_base_type(iter->second); + } + + void JsonParser::deserialize_composite( + CompositeType &varnode, const JsonObject &composite_obj, const TypeMap &serialized_types + ) { + const auto *field_array = composite_obj.getArray("fields"); + for (const auto &field : *field_array) { + const auto *field_obj = field.getAsObject(); + auto field_label = field_obj->getString("type").value_or("").str(); + if (field_label.empty()) { + continue; + } + + auto iter = serialized_types.find(field_label); + if (iter == serialized_types.end()) { + llvm::errs() << "Field component is not found on serialized types"; + continue; + } + + auto field_type = iter->second; + auto field_offset = field_obj->getInteger("offset").value_or(-1); + if (field_offset < 0) { + continue; + } + + auto field_name = field_obj->getString("name").value_or("").str(); + varnode.add_components( + field_name, *field_type, static_cast< uint32_t >(field_offset) + ); + } + } + + void JsonParser::deserialize_enum( + EnumType &varnode, const JsonObject &enum_obj, const TypeMap &serialized_types + ) { + assert(varnode.kind == VarnodeType::Kind::VT_ENUM); + (void) serialized_types; + (void) varnode; + (void) enum_obj; + } + + void JsonParser::deserialize_function_type( + FunctionType &varnode, const JsonObject &func_obj, const TypeMap &serialized_types + ) { + assert(varnode.kind == VarnodeType::Kind::VT_FUNCTION); + (void) serialized_types; + (void) varnode; + (void) func_obj; + } + + void JsonParser::deserialize_undefined_type( + UndefinedType &varnode, const JsonObject &undef_obj, const TypeMap &serialized_types + ) { + assert(varnode.kind == VarnodeType::Kind::VT_UNDEFINED); + (void) serialized_types; + (void) varnode; + (void) undef_obj; + } + + // Deserialize operations + std::optional< Varnode > JsonParser::create_varnode(const JsonObject &var_obj) { + auto type_key = var_obj.getString("type").value_or("").str(); + auto size = var_obj.getInteger("size").value_or(0); + auto kind = Varnode::convertToKind(var_obj.getString("kind").value_or("").str()); + + Varnode vnode(kind, static_cast< uint32_t >(size), type_key); + auto operation_key = var_obj.getString("operation").value_or("").str(); + if (!operation_key.empty()) { + vnode.operation = operation_key; + } + + auto function_key = var_obj.getString("function"); + if (function_key && !function_key->empty()) { + vnode.function = function_key->str(); + } + + auto value = var_obj.getInteger("value"); + if (value) { + vnode.value = static_cast< uint32_t >(*value); + } + + auto global_key = var_obj.getString("global"); + if (global_key && !global_key->empty()) { + vnode.global = global_key->str(); + } + + return vnode; + } + + std::optional< Function > JsonParser::create_function(const JsonObject &func_obj) { + Function func; + func.name = func_obj.getString("name").value_or(""); + if (const auto *proto_obj = func_obj.getObject("type")) { + if (auto maybe_prototype = create_function_prototype(*proto_obj)) { + func.prototype = *maybe_prototype; + } + } + + auto entry_block = func_obj.getString("entry_block"); + if (entry_block && !entry_block->empty()) { + func.entry_block = entry_block->str(); + } + + if (const auto *blocks_array = func_obj.getObject("basic_blocks")) { + deserialize_blocks(*blocks_array, func.basic_blocks, func.entry_block); + } + + return func; + } + + void JsonParser::deserialize_call_operation(const JsonObject &call_obj, Operation &op) { + if (const auto *maybe_target = call_obj.getObject("target")) { + OperationTarget target; + target.kind = + Varnode::convertToKind(maybe_target->getString("kind").value_or("").str()); + + auto function = maybe_target->getString("function"); + if (function.has_value() && !function->empty()) { + target.function = function->str(); + } + auto call_op = maybe_target->getString("operation"); + if (call_op.has_value() && !call_op->empty()) { + target.operation = call_op->str(); + } + + target.is_noreturn = maybe_target->getBoolean("is_noreturn").value_or(false); + op.target = target; + } + } + + void JsonParser::deserialize_branch_operation(const JsonObject &branch_obj, Operation &op) { + auto target_block = branch_obj.getString("target_block"); + if (target_block && !target_block->empty()) { + op.target_block = target_block->str(); + } + + auto taken_block = branch_obj.getString("taken_block"); + if (taken_block && !taken_block->empty()) { + op.taken_block = taken_block->str(); + } + + auto not_taken_block = branch_obj.getString("not_taken_block"); + if (not_taken_block && !not_taken_block->empty()) { + op.not_taken_block = not_taken_block->str(); + } + + if (const auto *maybe_output = branch_obj.getObject("condition")) { + if (auto maybe_varnode = create_varnode(*maybe_output)) { + op.condition = *maybe_varnode; + } + } + } + + std::optional< Operation > JsonParser::create_operation(const JsonObject &pcode_obj) { + auto mnemonic = + patchestry::ghidra::from_string(pcode_obj.getString("mnemonic").value_or("").str()); + if (mnemonic == Mnemonic::OP_UNKNOWN) { + llvm::errs() << "Pcode with unknown operation\n"; + assert(false); + return std::nullopt; + } + + Operation operation; + operation.mnemonic = mnemonic; + if (const auto *maybe_output = pcode_obj.getObject("output")) { + if (auto maybe_varnode = create_varnode(*maybe_output)) { + operation.output = *maybe_varnode; + } + } + + if (const auto *input_array = pcode_obj.getArray("inputs")) { + for (auto input : *input_array) { + if (auto maybe_varnode = create_varnode(*input.getAsObject())) { + operation.inputs.emplace_back(*maybe_varnode); + } + } + } + + switch (operation.mnemonic) { + case Mnemonic::OP_CALL: + case Mnemonic::OP_CALLIND: + deserialize_call_operation(pcode_obj, operation); + break; + case Mnemonic::OP_CBRANCH: + case Mnemonic::OP_BRANCH: + case Mnemonic::OP_BRANCHIND: + deserialize_branch_operation(pcode_obj, operation); + break; + default: + break; + } + + operation.name = get_string_if_valid(pcode_obj, "name"); + operation.type = get_string_if_valid(pcode_obj, "type"); + operation.address = get_string_if_valid(pcode_obj, "address"); + + auto index = pcode_obj.getInteger("index"); + if (index) { + operation.index = static_cast< uint32_t >(*index); + } + + return operation; + } + + std::optional< BasicBlock > + JsonParser::create_basic_block(const std::string &block_key, const JsonObject &block_obj) { + if (const auto *operations = block_obj.getObject("operations")) { + BasicBlock block; + + for (const auto &operation : *operations) { + auto operation_key = operation.getFirst().str(); + const auto *operation_object = operation.getSecond().getAsObject(); + if (auto maybe_operation = create_operation(*operation_object)) { + maybe_operation->key = operation_key; + maybe_operation->parent_block_key = block_key; + block.operations.emplace(operation_key, *maybe_operation); + } + } + + if (const auto *ordered_operations = block_obj.getArray("ordered_operations")) { + for (const auto &operation : *ordered_operations) { + auto operation_label = operation.getAsString(); + if (operation_label && !operation_label->empty()) { + block.ordered_operations.push_back(operation_label->str()); + } + } + } + + return block; + } + return std::nullopt; + } + + std::optional< FunctionPrototype > + JsonParser::create_function_prototype(const JsonObject &proto_obj) { + FunctionPrototype proto; + const auto return_type = proto_obj.getString("return_type").value_or("").str(); + if (return_type.empty()) { + llvm::errs() << "FunctionProtoType return type is empty\n"; + assert(false); + return std::nullopt; + } + + proto.rttype_key = return_type; + + proto.is_variadic = proto_obj.getBoolean("is_variadic").value_or(false); + proto.is_noreturn = proto_obj.getBoolean("is_noreturn").value_or(false); + + if (const auto *parameters = proto_obj.getArray("parameter_types")) { + for (const auto ¶meter : *parameters) { + auto parameter_key = parameter.getAsString(); + if (parameter_key && !parameter_key->empty()) { + proto.parameters.push_back(parameter_key->str()); + } + } + } + + return proto; + } + + void JsonParser::deserialize_functions( + const JsonObject &function_array, FunctionMap &serialized_functions + ) { + if (function_array.empty()) { + llvm::errs() << "No functions to deserialize"; + return; + } + + for (const auto &func_obj : function_array) { + auto function_key = func_obj.getFirst().str(); + auto function = create_function(*func_obj.getSecond().getAsObject()); + if (!function) { + llvm::errs() << "Failed to get function for the key " << function_key << "\n"; + continue; + } + function->key = function_key; + serialized_functions.emplace(function_key, *function); + } + } + + void JsonParser::deserialize_blocks( + const JsonObject &blocks_array, BasicBlockMap &serialized_blocks, + std::string &entry_block + ) { + if (blocks_array.empty()) { + llvm::errs() << "No blocks in function\n"; + return; + } + + for (const auto &block : blocks_array) { + auto block_key = block.getFirst().str(); + const auto *block_obj = block.getSecond().getAsObject(); + if (auto maybe_block = create_basic_block(block_key, *block_obj)) { + if (block_key == entry_block) { + maybe_block->is_entry_block = true; + } + maybe_block->key = block_key; + serialized_blocks.emplace(block_key, *maybe_block); + } + } + } + + void JsonParser::deserialize_globals( + const JsonObject &global_array, VariableMap &serialized_globals + ) { + if (global_array.empty()) { + llvm::errs() << "No global variable to serialize\n"; + return; + } + for (const auto &global : global_array) { + Variable variable; + variable.key = global.getFirst().str(); + const auto *global_obj = global.getSecond().getAsObject(); + if (auto maybe_name = global_obj->getString("name")) { + variable.name = *maybe_name; + } + if (auto maybe_type = global_obj->getString("type")) { + variable.type = *maybe_type; + } + if (auto maybe_size = global_obj->getInteger("size")) { + variable.size = static_cast< uint32_t >(*maybe_size); + } + serialized_globals.emplace(variable.key, variable); + } + } + +} // namespace patchestry::ghidra diff --git a/lib/patchestry/Ghidra/PcodeTranslation.cpp b/lib/patchestry/Ghidra/PcodeTranslation.cpp index a7f644e..ab42848 100644 --- a/lib/patchestry/Ghidra/PcodeTranslation.cpp +++ b/lib/patchestry/Ghidra/PcodeTranslation.cpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree. @@ -20,12 +19,15 @@ #include #include +#include + namespace patchestry::ghidra { static mlir::OwningOpRef< mlir_operation > deserialize( const llvm::MemoryBuffer *buffer, mcontext_t *mctx ) { mctx->loadAllAvailableDialects(); + std::cout << buffer->getBuffer().str() << std::endl; auto json = llvm::json::parse(buffer->getBuffer()); if (!json) { diff --git a/scripts/ghidra/PatchestryDecompileFunctions.java b/scripts/ghidra/PatchestryDecompileFunctions.java index 91d629f..79724a8 100644 --- a/scripts/ghidra/PatchestryDecompileFunctions.java +++ b/scripts/ghidra/PatchestryDecompileFunctions.java @@ -7,43 +7,87 @@ import ghidra.app.script.GhidraScript; +import ghidra.app.cmd.function.CallDepthChangeInfo; + +import ghidra.app.decompiler.component.DecompilerUtils; + import ghidra.app.decompiler.DecompInterface; import ghidra.app.decompiler.DecompileOptions; import ghidra.app.decompiler.DecompileResults; +import ghidra.app.plugin.processors.sleigh.SleighLanguage; + +import ghidra.program.database.symbol.CodeSymbol; + import ghidra.program.model.address.Address; +import ghidra.program.model.address.AddressFactory; +import ghidra.program.model.address.AddressIterator; +import ghidra.program.model.address.AddressSet; +import ghidra.program.model.address.AddressSetView; +import ghidra.program.model.address.AddressSpace; import ghidra.program.model.block.BasicBlockModel; import ghidra.program.model.block.CodeBlock; import ghidra.program.model.block.CodeBlockIterator; +import ghidra.program.model.data.AbstractStringDataType; +import ghidra.program.model.data.BitFieldDataType; +import ghidra.program.model.data.DataType; +import ghidra.program.model.data.DataTypeManager; +import ghidra.program.model.data.StringDataType; + +import ghidra.program.model.lang.CompilerSpec; +import ghidra.program.model.lang.Language; +import ghidra.program.model.lang.Register; +import ghidra.program.model.lang.RegisterManager; + import ghidra.program.model.listing.Function; import ghidra.program.model.listing.FunctionIterator; import ghidra.program.model.listing.FunctionManager; - +import ghidra.program.model.listing.FunctionSignature; import ghidra.program.model.listing.Instruction; import ghidra.program.model.listing.InstructionIterator; - +import ghidra.program.model.listing.Listing; +import ghidra.program.model.listing.Parameter; import ghidra.program.model.listing.Program; +import ghidra.program.model.listing.StackFrame; +import ghidra.program.model.listing.Variable; +import ghidra.program.model.listing.VariableStorage; + +import ghidra.program.model.mem.MemBuffer; +import ghidra.program.model.mem.Memory; import ghidra.program.model.pcode.FunctionPrototype; +import ghidra.program.model.pcode.GlobalSymbolMap; import ghidra.program.model.pcode.HighFunction; +import ghidra.program.model.pcode.HighCodeSymbol; +import ghidra.program.model.pcode.HighConstant; +import ghidra.program.model.pcode.HighGlobal; +import ghidra.program.model.pcode.HighLocal; +import ghidra.program.model.pcode.HighOther; import ghidra.program.model.pcode.HighParam; +import ghidra.program.model.pcode.HighSymbol; import ghidra.program.model.pcode.HighVariable; +import ghidra.program.model.pcode.LocalSymbolMap; +import ghidra.program.model.pcode.PartialUnion; import ghidra.program.model.pcode.PcodeBlock; import ghidra.program.model.pcode.PcodeBlockBasic; import ghidra.program.model.pcode.PcodeOp; import ghidra.program.model.pcode.SequenceNumber; +import ghidra.program.model.pcode.SymbolEntry; import ghidra.program.model.pcode.Varnode; import ghidra.program.model.data.AbstractFloatDataType; import ghidra.program.model.data.AbstractIntegerDataType; import ghidra.program.model.data.Array; +import ghidra.program.model.data.ArrayStringable; import ghidra.program.model.data.BooleanDataType; +import ghidra.program.model.data.BuiltIn; import ghidra.program.model.data.Composite; import ghidra.program.model.data.CategoryPath; import ghidra.program.model.data.DataType; import ghidra.program.model.data.DataTypeComponent; +import ghidra.program.model.data.DefaultDataType; import ghidra.program.model.data.Enum; import ghidra.program.model.data.FunctionDefinition; import ghidra.program.model.data.ParameterDefinition; @@ -53,8 +97,14 @@ import ghidra.program.model.data.Undefined; import ghidra.program.model.data.Union; import ghidra.program.model.data.VoidDataType; +import ghidra.program.model.data.WideCharDataType; import ghidra.program.model.symbol.ExternalManager; +import ghidra.program.model.symbol.Namespace; +import ghidra.program.model.symbol.Reference; +import ghidra.program.model.symbol.ReferenceManager; +import ghidra.program.model.symbol.Symbol; +import ghidra.program.model.symbol.SymbolType; import ghidra.util.UniversalID; @@ -75,655 +125,2428 @@ import java.util.List; import java.util.Collections; import java.util.Iterator; +import java.util.Map; import java.util.Set; import java.util.TreeSet; +import java.util.TreeMap; public class PatchestryDecompileFunctions extends GhidraScript { - static final int decompilation_timeout = 30; + protected static final int DECOMPILATION_TIMEOUT = 30; + + protected static final int MIN_CALLOTHER = 0x100000; + protected static final int DECLARE_PARAM_VAR = MIN_CALLOTHER + 0; + protected static final int DECLARE_LOCAL_VAR = MIN_CALLOTHER + 1; + protected static final int DECLARE_TEMP_VAR = MIN_CALLOTHER + 2; + protected static final int ADDRESS_OF = MIN_CALLOTHER + 3; + + // A custom `Varnode` used to represent the output of a `CALLOTHER` that + // we have invented. + protected class DefinitionVarnode extends Varnode { + private PcodeOp def; + private HighVariable high; + + public DefinitionVarnode(Address address, int size) { + super(address, size); + } + + public void setDef(HighVariable high, PcodeOp def) { + this.def = def; + this.high = high; + } + + @Override + public PcodeOp getDef() { + return def; + } + + @Override + public HighVariable getHigh() { + return high; + } + + @Override + public boolean isInput() { + return false; + } + }; + + // A custome `Varnode` used to represent a rewritten input of a `PTRSUB` + // or other operation referencing a local variable that was not exactly + // correctly understood in the high p-code. + protected class UseVarnode extends Varnode { + private HighVariable high; + + public UseVarnode(Address address, int size) { + super(address, size); + } + + public void setHigh(HighVariable high) { + this.high = high; + } + + @Override + public HighVariable getHigh() { + return high; + } + + @Override + public boolean isInput() { + return true; + } + }; + + // A manually-created temporary variable with a single use. + protected class HighTemporary extends HighOther { + public HighTemporary(DataType type, Varnode vn, Varnode[] inst, Address pc, HighFunction func) { + super(type, vn, inst, pc, func); + } + } + + private class PcodeSerializer extends JsonWriter { + private Program program; + private String arch; + private AddressSpace extern_space; + private AddressSpace ram_space; + private AddressSpace stack_space; + private AddressSpace constant_space; + private AddressSpace unique_space; + private FunctionManager fm; + private ExternalManager em; + private DecompInterface ifc; + private BasicBlockModel bbm; + + // Tracks which functions to recover. The size of `functions` is + // monotonically non-decreasing, with newly discovered functions + // added to the end. The first `original_functions_size` functions in + // `functions` are meant to have their definitions (i.e. high p-code) + // serialized to JSON. + private List functions; + private int original_functions_size; + private Set
seen_functions; + + // The seen globals. + private Map seen_globals; + private Map address_of_global; + + // The seen types. The size of `types_to_serialize` is monotonically + // non-decreasing, so that as we add new things to `seen_types`, we add + // to the end of `types_to_serialize`. This lets us properly handle + // tracking what recursive types need to be serialized. + private Set seen_types; + private List types_to_serialize; + + // Current function being serialized, and current block within that + // function being serialized. + private HighFunction current_function; + private PcodeBlockBasic current_block; + + // We invent an entry block for each `HighFunction` to be serialized. + // The operations within this entry block are custom `CALLOTHER`s, that + // "declare" variables of various forms. The way to think about this is + // with a visual analogy: when looking at a decompilation in Ghidra, the + // first thing we see in the body of a function are the local variable + // declarations. In our JSON output, we try to mimic this, and then + // canonicalize accesses of things to target those variables, doing a + // kind of de-SSAing. + private List entry_block; + + // When creating the `CALLOTHER`s for the `entry_block`, we need to + // synthesize addresses in the unique address space, and so we need to + // keep track of what unique addresses we've already used/generated. + private long next_unique; + private int next_seqnum; + + private SleighLanguage language; + + // Stack pointer for this program's architecture. High p-code can have + // two forms of stack references: `Varnode`s of whose `Address` is part + // of the stack address space, and `Varnode`s representing registers, + // where some of those are the stack pointer. In this latter case, we + // need to be able to identify those and convert them into the former + // case. + private Register stack_pointer; + + // Maps names of missing locals to invented `HighLocal`s used to + // represent them. `Function`s often have many `Variable`s, not all of + // which become `HighLocal`s or `HighParam`s. Sometimes when something + // can't be precisely recognized, it is represented as a `HighOther` + // connected to a `HighSymbol`. Confusingly, the `DataType` associated + // with the `HighSymbol` is more representative of what the decompiler + // actually shows, and the `HighOther` more representative of the + // data type in the low `Variable` sourced from the `StackFrame`. + private Map missing_locals; + private Map old_locals; + + // Maps `HighVariables` (really, `HighOther`s) that are attached to + // register `Varnode`s to the `PcodeOp` containing those nodes. We + // The same-named temporary/register may be associated with many such + // independent `HighVariable`s, so to distinguish them to downstream + // readers of the JSON, we want to 'version' the register variables by + // their initial user. + private Map temporary_address; + + // Replacement operations. Sometimes we have something that we actually + // need to replace, and so this mapping allows us to do that without + // having to aggressively rewrite things, especially output operands. + private Map replacement_operations; + + // Sometimes we need to arrange for some operations to exist prior to + // another one, e.g. if there is a `CALL foo, SP` that decompiles to + // `foo(&local_x)`, then we really want to be able to represent `SP`, + // the stack pointer, as a reference to the address of `local_x`, rather + // than whatever it is. + private Map> prefix_operations; + + // A mapping of `CALLOTHER` locations operating with named intrinsics to + // the `PcodeOp`s representing those `CALLOTHER`s. + private List callother_uses; + + public PcodeSerializer(java.io.BufferedWriter writer, + String arch_, FunctionManager fm_, + ExternalManager em_, DecompInterface ifc_, + BasicBlockModel bbm_, + List functions_) { + super(writer); + + this.program = fm_.getProgram(); + + this.language = (SleighLanguage) program.getLanguage(); + AddressFactory address_factory = program.getAddressFactory(); + + this.arch = arch_; + this.extern_space = address_factory.getAddressSpace("extern"); + this.ram_space = address_factory.getAddressSpace("ram"); + this.stack_space = address_factory.getStackSpace(); + this.constant_space = address_factory.getConstantSpace(); + this.unique_space = address_factory.getUniqueSpace(); + this.fm = fm_; + this.em = em_; + this.ifc = ifc_; + this.bbm = bbm_; + this.functions = functions_; + this.original_functions_size = functions.size(); + this.seen_functions = new TreeSet<>(); + this.seen_types = new HashSet<>(); + this.seen_globals = new HashMap<>(); + this.types_to_serialize = new ArrayList<>(); + this.current_function = null; + this.current_block = null; + this.next_unique = language.getUniqueBase(); + this.next_seqnum = 0; + this.entry_block = new ArrayList<>(); + this.stack_pointer = program.getCompilerSpec().getStackPointer(); + this.missing_locals = new HashMap<>(); + this.old_locals = new HashMap<>(); + this.temporary_address = new HashMap<>(); + this.replacement_operations = new HashMap<>(); + this.prefix_operations = new HashMap<>(); + this.address_of_global = new HashMap<>(); + this.callother_uses = new ArrayList<>(); + } + + private static String label(HighFunction function) throws Exception { + return label(function.getFunction()); + } + + private static String label(Function function) throws Exception { + return label(function.getEntryPoint()); + } + + private static String label(Address address) throws Exception { + return address.toString(true /* show address space prefix */); + } + + private static String label(SequenceNumber sn) throws Exception { + return label(sn.getTarget()) + Address.SEPARATOR + + Integer.toString(sn.getTime()) + Address.SEPARATOR + + Integer.toString(sn.getOrder()); + } + + private static String label(PcodeBlock block) throws Exception { + return label(block.getStart()) + Address.SEPARATOR + + Integer.toString(block.getIndex()) + Address.SEPARATOR + + PcodeBlock.typeToName(block.getType()); + } + + private static String label(PcodeOp op) throws Exception { + return label(op.getSeqnum()); + } + + private String label(DataType type) throws Exception { + // In type is null, assign VoidDataType in all cases. + // We assume it as void type. + if (type == null) { + type = VoidDataType.dataType; + } + + String name = type.getName(); + CategoryPath category = type.getCategoryPath(); + String concat_type = category.toString() + name + Integer.toString(type.getLength()); + String type_id = Integer.toHexString(concat_type.hashCode()); + + UniversalID uid = type.getUniversalID(); + if (uid != null) { + type_id += Address.SEPARATOR + uid.toString(); + } + + if (seen_types.add(type_id)) { + types_to_serialize.add(type); + } + return type_id; + } + + // Figure out the return type of an intrinsic op. + private DataType intrinsicReturnType(PcodeOp op) { + DataType ret_type = null; + Varnode ret_val = op.getOutput(); + if (ret_val == null) { + return VoidDataType.dataType; + } + + HighVariable var = ret_val.getHigh(); + if (var != null) { + return var.getDataType(); + } + + return Undefined.getUndefinedDataType(ret_val.getSize()); + } + + // Return the label of an intrinsic with `CALLOTHER`. This is based + // off of the return value. + private String intrinsicLabel(PcodeOp op) throws Exception { + int index = (int) op.getInput(0).getOffset(); + String name = language.getUserDefinedOpName(index); + return intrinsicLabel(name, intrinsicReturnType(op)); + } + + private String intrinsicLabel( + String name, DataType ret_type) throws Exception { + return name + Address.SEPARATOR + label(ret_type); + } + + private void serializePointerType(Pointer ptr) throws Exception { + name("kind").value("pointer"); + name("size").value(ptr.getLength()); + name("element_type").value(label(ptr.getDataType())); + } + + private void serializeTypedefType(TypeDef typedef) throws Exception { + name("name").value(typedef.getDisplayName()); + name("kind").value("typedef"); + name("size").value(typedef.getLength()); + name("base_type").value(label(typedef.getBaseDataType())); + } + + private void serializeArrayType(Array arr) throws Exception { + name("kind").value("array"); + name("size").value(arr.getLength()); + name("num_elements").value(arr.getNumElements()); + name("element_type").value(label(arr.getDataType())); + } + + private void serializeBuiltinType( + DataType data_type, String kind) throws Exception { + + String display_name = null; + if (data_type instanceof AbstractIntegerDataType) { + AbstractIntegerDataType adt = (AbstractIntegerDataType) data_type; + display_name = adt.getCDeclaration(); + } + + if (display_name == null) { + display_name = data_type.getDisplayName(); + } + + name("name").value(display_name); + name("size").value(data_type.getLength()); + name("kind").value(kind); + } + + private void serializeCompositeType( + Composite data_type, String kind) throws Exception { + name("name").value(data_type.getDisplayName()); + name("kind").value(kind); + name("size").value(data_type.getLength()); + name("fields").beginArray(); + + for (int i = 0; i < data_type.getNumComponents(); i++) { + DataTypeComponent dtc = data_type.getComponent(i); + beginObject(); + name("type").value(label(dtc.getDataType())); + name("offset").value(dtc.getOffset()); + + if (dtc.getFieldName() != null) { + name("name").value(dtc.getFieldName()); + } + endObject(); + } + endArray(); + } + + private void serialize(DataType data_type) throws Exception { + if (data_type == null) { + nullValue(); + return; + } + + if (data_type instanceof Pointer) { + serializePointerType((Pointer) data_type); + + } else if (data_type instanceof TypeDef) { + serializeTypedefType((TypeDef) data_type); + + } else if (data_type instanceof Array) { + serializeArrayType((Array) data_type); + + } else if (data_type instanceof Structure) { + serializeCompositeType((Composite) data_type, "struct"); + + } else if (data_type instanceof Union) { + serializeCompositeType((Composite) data_type, "union"); + + } else if (data_type instanceof AbstractIntegerDataType){ + serializeBuiltinType(data_type, "integer"); + + } else if (data_type instanceof AbstractFloatDataType){ + serializeBuiltinType(data_type, "float"); + + } else if (data_type instanceof BooleanDataType){ + serializeBuiltinType(data_type, "boolean"); + + } else if (data_type instanceof Enum) { + serializeBuiltinType(data_type, "enum"); + + } else if (data_type instanceof VoidDataType) { + serializeBuiltinType(data_type, "void"); + + } else if (data_type instanceof Undefined || data_type instanceof DefaultDataType) { + serializeBuiltinType(data_type, "undefined"); + + } else if (data_type instanceof FunctionDefinition) { + name("kind").value("function"); + serializePrototype((FunctionSignature) data_type); + + } else if (data_type instanceof PartialUnion) { + name("kind").value("todo"); // TODO(pag): Implement this + name("size").value(data_type.getLength()); + + } else if (data_type instanceof BitFieldDataType) { + name("kind").value("todo"); // TODO(pag): Implement this + name("size").value(data_type.getLength()); + + } else if (data_type instanceof WideCharDataType) { + name("kind").value("todo"); // TODO(pag): Implement this + name("size").value(data_type.getLength()); + + } else if (data_type instanceof StringDataType) { + name("kind").value("todo"); // TODO(pag): Implement this + name("size").value(data_type.getLength()); + + } else { + throw new Exception("Unhandled type: " + data_type.getClass().getName()); + } + } + + private void serializeTypes() throws Exception { + for (int i = 0; i < types_to_serialize.size(); i++) { + DataType type = types_to_serialize.get(i); + name(label(type)).beginObject(); + serialize(type); + endObject(); + } + + println("Total serialized types: " + types_to_serialize.size()); + } + + private int serializePrototype() throws Exception { + name("return_type").value(label((DataType) null)); + name("is_variadic").value(false); + name("is_noreturn").value(false); + name("parameter_types").beginArray().endArray(); + return 0; + } + + private int serializePrototype(FunctionPrototype proto) throws Exception { + if (proto == null) { + return serializePrototype(); + } + + name("return_type").value(label(proto.getReturnType())); + name("is_variadic").value(proto.isVarArg()); + name("is_noreturn").value(proto.hasNoReturn()); + + name("parameter_types").beginArray(); + int num_params = proto.getNumParams(); + for (int i = 0; i < num_params; i++) { + value(label(proto.getParam(i).getDataType())); + } + endArray(); // End of `parameter_types`. + return num_params; + } + + private int serializePrototype(FunctionSignature proto) throws Exception { + if (proto == null) { + return serializePrototype(); + } + + name("return_type").value(label(proto.getReturnType())); + name("is_variadic").value(proto.hasVarArgs()); + name("is_noreturn").value(proto.hasNoReturn()); + name("calling_convention").value(proto.getCallingConventionName()); + + ParameterDefinition[] arguments = proto.getArguments(); + name("parameter_types").beginArray(); + int num_params = (int) arguments.length; + for (int i = 0; i < num_params; i++) { + value(label(arguments[i].getDataType())); + } + endArray(); // End of `parameter_types`. + return num_params; + } + + private void serialize(HighVariable high_var) throws Exception { + if (high_var == null) { + nullValue(); + return; + } + + beginObject(); + name("name").value(high_var.getName()); + name("type").value(label(high_var.getDataType())); + endObject(); + } + + // Return the r-value of a varnode. + private Varnode rValueOf(Varnode node) throws Exception { + return node; +// HighVariable high = node.getHigh(); +// if (high == null) { +// return node; +// } +// +// Varnode rep = high.getRepresentative(); +// return rep == null ? node : rep; + } + + private enum VariableClassification { + UNKNOWN, + PARAMETER, + LOCAL, + NAMED_TEMPORARY, + TEMPORARY, + GLOBAL, + FUNCTION, + CONSTANT + }; + + // Returns `true` if a given representative is an original + // representative. + private boolean isOriginalRepresentative(Varnode node) { + if (node.isInput()) { + return true; + } + + // NOTE(pag): Don't use `resolveOp` here because that screws up the + // variable creation logic. + PcodeOp op = node.getDef(); + if (op == null) { + return true; + } + + if (op.getOpcode() != PcodeOp.CALLOTHER) { + return true; + } + + if (op.getInput(0).getOffset() < MIN_CALLOTHER) { + return true; + } + + return false; + } + + // Resolve an operation to a replacement operation, if any. + private PcodeOp resolveOp(PcodeOp op) { + if (op == null) { + return null; + } + + PcodeOp replacement_op = replacement_operations.get(op); + if (replacement_op != null) { + return replacement_op; + } + return op; + } + + // Get the representative of a `HighVariable`, or if we've re-written + // the representative with a `CALLOTHER`, then get the original + // representative. + private Varnode originalRepresentativeOf(HighVariable var) { + if (var == null) { + return null; + } + + Varnode rep = var.getRepresentative(); + if (isOriginalRepresentative(rep)) { + return rep; + } + + Varnode[] instances = var.getInstances(); + if (instances.length <= 1) { + return null; + } + + return instances[1]; + } + + // Return the address of a high global variable. + private Address addressOfGlobal(HighVariable var) throws Exception { + HighSymbol sym = var.getSymbol(); + if (sym != null && sym.isGlobal()) { + SymbolEntry entry = sym.getFirstWholeMap(); + VariableStorage storage = entry.getStorage(); + if (storage != VariableStorage.BAD_STORAGE && + storage != VariableStorage.UNASSIGNED_STORAGE && + storage != VariableStorage.VOID_STORAGE) { + return storage.getMinAddress(); + } + } + + Varnode rep = var.getRepresentative(); + int type = AddressSpace.ID_TYPE_MASK & rep.getSpace(); + if (type == AddressSpace.TYPE_RAM) { + return ram_space.getAddress(rep.getOffset()); + + } else if (type == AddressSpace.TYPE_EXTERNAL) { + return extern_space.getAddress(rep.getOffset()); + } + + Address fixed_address = address_of_global.get(var); + if (fixed_address != null) { + return fixed_address; + } + + println("Could not get address of variable " + var.toString()); + return null; + } + + // Try to distinguish "local" variables from global ones. Roughly, we + // want to make sure that the backing storage for a given variable + // *isn't* RAM. Thus, UNIQUE, STACK, CONST, etc. are all in-scope for + // locals. + private VariableClassification classifyVariable(HighVariable var) throws Exception { + if (var == null) { + return VariableClassification.UNKNOWN; + } + + if (var instanceof HighParam) { + return VariableClassification.PARAMETER; + + } else if (var instanceof HighLocal) { + return VariableClassification.LOCAL; + + } else if (var instanceof HighConstant) { + return VariableClassification.CONSTANT; + + } else if (var instanceof HighGlobal) { + seen_globals.put(addressOfGlobal(var), var); + return VariableClassification.GLOBAL; + + } else if (var instanceof HighTemporary) { + return VariableClassification.TEMPORARY; + } + + HighSymbol symbol = var.getSymbol(); + if (symbol != null) { + if (symbol.isGlobal()) { + seen_globals.put(addressOfGlobal(var), var); + return VariableClassification.GLOBAL; + + } else if (symbol.isParameter() || symbol.isThisPointer()) { + return VariableClassification.PARAMETER; + } + } + + Varnode rep = originalRepresentativeOf(var); + if (rep != null) { + + // TODO(pag): Consider checking if all uses of the unique + // belong to the same block. We don't want to + // introduce a kind of code motion risk into the + // lifted representation. + if (rep.isRegister() || rep.isUnique() || var instanceof HighOther) { + if (rep.getLoneDescend() != null) { + return VariableClassification.TEMPORARY; + } else { + return VariableClassification.NAMED_TEMPORARY; + } + } + } + + return VariableClassification.UNKNOWN; + } + + // Serialize an input or output varnode. + private void serializeInput(PcodeOp op, Varnode node) throws Exception { + assert !node.isFree(); + assert node.isInput(); + + PcodeOp def = resolveOp(node.getDef()); + HighVariable var = variableOf(node.getHigh()); + + beginObject(); + + if (var != null) { + if (def == null) { + def = var.getRepresentative().getDef(); + } + + name("type").value(label(var.getDataType())); + } else { + name("size").value(node.getSize()); + } + + switch (classifyVariable(var)) { + case UNKNOWN: + if (def != null && !node.isInput() && def == op) { + if (node.isUnique()) { + name("kind").value("temporary"); + + // TODO(pag): Figure this out. + } else { + assert false; + name("kind").value("unknown"); + } + + // NOTE(pag): Should be a `TEMPORARY` classification. + } else if (node.isUnique()) { + assert false; + assert def != null; + name("kind").value("temporary"); + name("operation").value(label(def)); + + // NOTE(pag): Should be a `REGISTER` classification. + } else if (node.isConstant()) { + assert false; + name("kind").value("constant"); + name("value").value(node.getOffset()); + + } else { + assert false; + name("kind").value("unknown"); + } + break; + case PARAMETER: + name("kind").value("parameter"); + name("operation").value(label(getOrCreateLocalVariable(var, op))); + break; + case LOCAL: + name("kind").value("local"); + name("operation").value(label(getOrCreateLocalVariable(var, op))); + break; + case NAMED_TEMPORARY: + name("kind").value("temporary"); + name("operation").value(label(getOrCreateLocalVariable(var, op))); + break; + case TEMPORARY: + assert def != null; + name("kind").value("temporary"); + name("operation").value(label(def)); + break; + case GLOBAL: + name("kind").value("global"); + name("global").value(label(addressOfGlobal(var))); + break; + case FUNCTION: + name("kind").value("function"); + name("function").value(label(var.getHighFunction())); + break; + case CONSTANT: + if (node.isConstant()) { + name("kind").value("constant"); + name("value").value(node.getOffset()); + } else { + assert false; + name("kind").value("unknown"); + } + break; + } + + endObject(); + } + + // Returns the index of the first input `Varnode` referncing the stack + // pointer, or `-1` if no direct references are found. + private int referencesStackPointer(PcodeOp op) throws Exception { + int input_index = 0; + for (Varnode node : op.getInputs()) { + if (node.isRegister()) { + Register reg = language.getRegister(node.getAddress(), 0); + if (reg == null) { + continue; + } + + // TODO(pag): This doesn't seem to work? All `typeFlags` for + // all registers seem to be zero, at least for + // x86. + // + // NOTE(pag): NCC group blog post on "earlyremoval" also + // notes this curiosity. + if ((reg.getTypeFlags() & Register.TYPE_SP) != 0) { + return input_index; + } + + if (reg == stack_pointer) { + return input_index; + } + + // TODO(pag): Should we consider references to the frame + // pointer, e.g. using the `CompilerSpec` or + // `reg.isDefaultFramePointer()`? + } + + ++input_index; + } + + return -1; + } + + // Given a `PTRSUB SP, offset` that resolves to the base of a local + // variable, or a `PTRSUB 0, addr` that resolves to the address of a + // global variable, generate and `ADDRESS_OF var`. + private PcodeOp createAddressOf( + Varnode def, SequenceNumber loc, Varnode input_var) { + Varnode inputs[] = new Varnode[2]; + inputs[0] = new Varnode(constant_space.getAddress(ADDRESS_OF), 4); + inputs[1] = input_var; + return new PcodeOp(loc, PcodeOp.CALLOTHER, inputs, def); + } + + // Given an offset `var_offset` from the stack pointer in `op`, return + // two `Varnode`s, the first referencing the relevant `HighVariable` + // that contains the byte at that stack offset, and the second being + // a constant byte displacement from the base of the stack variable. + private Varnode[] createStackPointerVarnodes( + HighFunction high_function, PcodeOp op, + int var_offset) throws Exception { + + Function function = high_function.getFunction(); + StackFrame frame = function.getStackFrame(); + + int frame_size = frame.getFrameSize(); + int adjust_offset = 0; + + // Given the local symbol mapping for the high function, go find + // a `HighSymbol` corresponding to `local_118`. This high symbol + // will generally have a much better `DataType`, but initially + // and confusingly won't have a corresponding `HighVariable`. + LocalSymbolMap symbols = high_function.getLocalSymbolMap(); + Address pc = op.getSeqnum().getTarget(); + + // Given a stack pointer offset, e.g. `-0x118`, go find the low + // `Variable` representing `local_118`. + Variable var = frame.getVariableContaining(var_offset); + Address stack_address = stack_space.getAddress(var_offset); + HighSymbol sym = null; + if (var != null) { + sym = symbols.findLocal(var.getVariableStorage(), pc); + stack_address = stack_space.getAddress(var.getStackOffset()); + + } else { + sym = symbols.findLocal(stack_address, pc); + } + + // Try to recover by locating the parameter containing the stack + // address. + if (sym == null) { + for (Variable param : frame.getParameters()) { + VariableStorage storage = param.getVariableStorage(); + if (!storage.contains(stack_address)) { + continue; + } + + int index = ((Parameter) param).getOrdinal(); + if (index >= symbols.getNumParams()) { + break; + } + + sym = symbols.getParamSymbol(index); + break; + } + } + + // This is usually for one of a few reasons: + // - Trying to lift `_start` + // - Trying to lift a variadic function using `va_list`. + if (sym == null) { + return null; + } + + Varnode var_node = op.getInput(0); + UseVarnode new_var_node = new UseVarnode( + stack_address, sym.getDataType().getLength()); + + // We've already got a high variable for this missing local. + HighVariable new_var = sym.getHighVariable(); + String sym_name = sym.getName(); + if (new_var != null && !new_var.getName().equals("UNNAMED")) { + // println("Using existing high sym " + sym_name + " with var named " + new_var.getName() + " and type " + new_var.getDataType().toString()); + + // We need to invent a new `HighVariable` for this `HighSymbol`. + // Unfortunately we can't use `HighSymbol.setHighVariable` for the + // caching, so we need `missing_locals`. + } else { + HighLocal local_var = old_locals.get(new_var); + if (local_var == null) { + local_var = missing_locals.get(sym_name); + } + + if (local_var == null) { + local_var = new HighLocal( + sym.getDataType(), new_var_node, null, pc, sym); + missing_locals.put(sym_name, local_var); + + // println("Created " + local_var.getName() + " with type " + local_var.getDataType().toString()); + + // Remap old-to-new. + if (new_var != null) { + old_locals.put(new_var, local_var); + } + + } + new_var = local_var; + } + + new_var_node.setHigh(new_var); + + if (var != null) { + adjust_offset = (var_offset - var.getStackOffset()); + } + + Varnode[] nodes = new Varnode[2]; + nodes[0] = new_var_node; + nodes[1] = new Varnode(constant_space.getAddress(adjust_offset), + ram_space.getSize() / 8); + return nodes; + } + + // Update a `PTRSUB 0, addr` or a `PTRSUB SP, offset` to be prefixed + // by an `ADDRESS_OF`, then operate on the `ADDRESS_OF` in the first + // input, and use a modified offset in the second input. + private boolean prefixPtrSubcomponentWithAddressOf( + HighFunction high_function, PcodeOp op, + Varnode[] nodes) throws Exception { + + Address op_loc = op.getSeqnum().getTarget(); + List ops = getOrCreatePrefixOperations(op); + + // Figure out the tye of the pointer to the local variable being + // referenced. + DataTypeManager dtm = program.getDataTypeManager(); + DataType var_type = nodes[0].getHigh().getDataType(); + DataType node_type = dtm.getPointer(var_type); + + // Create a unique address for this `Varnode`. + Address address = nextUniqueAddress(); + SequenceNumber loc = new SequenceNumber(address, next_seqnum++); + + // Make the `Varnode` instances. + DefinitionVarnode def = new DefinitionVarnode( + address, node_type.getLength()); + UseVarnode use = new UseVarnode(address, def.getSize()); + + // Create a prefix `ADDRESS_OF` for the local variable. + PcodeOp address_of = this.createAddressOf(def, loc, nodes[0]); + ops.add(address_of); + + // Track the logical value using a `HighOther`. + Varnode[] instances = new Varnode[2]; + instances[0] = address_of.getOutput(); + instances[1] = use; + HighVariable tracker = new HighTemporary( + node_type, instances[0], instances, op_loc, high_function); + + def.setDef(tracker, address_of); + use.setHigh(tracker); + + println(label(op)); + println(" Rewriting " + op.getSeqnum().toString() + ": " + op.toString()); + + // Rewrite the stack reference to point to the `HighVariable`. + op.setInput(use, 0); + + // Rewrite the offset. + op.setInput(nodes[1], 1); + + println(" to: " + op.toString()); + + return true; + } + + // Given a `PTRSUB SP, offset`, try to invent a local variable at + // `offset` in a similar way to how the decompiler would. + private boolean createLocalForPtrSubcomponent( + HighFunction high_function, PcodeOp op, + CallDepthChangeInfo cdci) throws Exception { + + Varnode offset = op.getInput(1); + if (!offset.isConstant()) { + return false; + } + + Varnode[] nodes = createStackPointerVarnodes( + high_function, op, (int) offset.getOffset()); + if (nodes == null) { + return false; + } + + // We can replace the `PTRSUB SP, offset` with an + // `ADDRESS_OF local`. + if (nodes[1].getOffset() == 0) { + PcodeOp new_op = createAddressOf( + op.getOutput(), op.getSeqnum(), nodes[0]); + replacement_operations.put(op, new_op); + return true; + } + + // We need to get the `ADDRESS_OF local`, then pass that to a + // fixed-up `PTRSUB`. + return prefixPtrSubcomponentWithAddressOf(high_function, op, nodes); + } + + // Return the next referenced address after `start`, or the maximum + // address in `start`'s address space. + private Address getNextReferencedAddressOrMax(Address start) { + Address end = start.getAddressSpace().getMaxAddress(); + AddressSet range = new AddressSet(start, end); + ReferenceManager references = program.getReferenceManager(); + AddressIterator it = references.getReferenceDestinationIterator(range, true); + if (!it.hasNext()) { + return end; + } + + Address referenced_address = it.next(); + if (!start.equals(referenced_address)) { + return referenced_address; + } + + if (it.hasNext()) { + return it.next(); + } + + return end; + } + + // Given a `PTRSUB const, const`, try to recognize it as a global variable + // reference, or a field reference within a global variable. + private boolean createGlobalForPtrSubcomponent( + HighFunction high_function, PcodeOp op) throws Exception { + + HighVariable zero = op.getInput(0).getHigh(); + if (!(zero instanceof HighOther)) { + return false; + } + + if (!zero.getName().equals("UNNAMED")) { + return false; + } + + if (zero.getOffset() != -1) { + return false; + } + + Varnode offset_node = op.getInput(1); + HighVariable offset_var = offset_node.getHigh(); + if (!(offset_var instanceof HighConstant)) { + return false; + } + + HighSymbol high_sym = offset_var.getSymbol(); + if (high_sym == null) { + return false; + } + + // println("Found variable use " + high_sym.getName()); + + SymbolEntry entry = high_sym.getFirstWholeMap(); + VariableStorage storage = entry.getStorage(); + Address address = null; + + if (storage == VariableStorage.BAD_STORAGE || + storage == VariableStorage.UNASSIGNED_STORAGE || + storage == VariableStorage.VOID_STORAGE) { + + address = ram_space.getAddress(offset_node.getOffset()); + } else { + address = storage.getMinAddress(); + } + + DataType type = high_sym.getDataType(); + + // Get the size in bytes. This might require calculating the length + // of a string, for which we use the heuristic that the string + // probably ends at the next referenced address. + // + // TODO(pag): This isn't a great heuristic because it's fairly + // common for compilers to do suffix compression of + // strings, i.e. given `"c"` can be a suffix of `"bc"` + // which can be a suffix of `"abc"`, and so every string + // in this case would show as having a maximum length of + // `1` by this heuristic. + int size_in_bytes = type.getLength(); + if (size_in_bytes == -1 && type instanceof AbstractStringDataType) { + Listing listing = program.getListing(); + MemBuffer memory = listing.getCodeUnitAt(address); + Address next_address = getNextReferencedAddressOrMax(address); + size_in_bytes = ((AbstractStringDataType) type).getLength( + memory, (int) next_address.subtract(address)); + } + + UseVarnode new_var_node = new UseVarnode(address, size_in_bytes); + HighVariable global_var = seen_globals.get(address); + if (global_var == null) { + global_var = high_sym.getHighVariable(); + if (global_var == null) { + global_var = new HighGlobal(high_sym, new_var_node, null); + } + + seen_globals.put(address, global_var); + } + + address_of_global.put(global_var, address); + + // Rewrite the offset. + Address offset_as_address = address.getAddressSpace().getAddress(offset_node.getOffset()); + int sub_offset = (int) offset_as_address.subtract(address); + + new_var_node.setHigh(global_var); + if (sub_offset == 0) { + PcodeOp new_op = createAddressOf( + op.getOutput(), op.getSeqnum(), new_var_node); + replacement_operations.put(op, new_op); + return true; + } + + Varnode[] nodes = new Varnode[2]; + nodes[0] = new_var_node; + nodes[1] = new Varnode(constant_space.getAddress(sub_offset), + offset_node.getSize()); + + // We need to get the `ADDRESS_OF global`, then pass that to a + // fixed-up `PTRSUB`. + return prefixPtrSubcomponentWithAddressOf(high_function, op, nodes); + } + + // Try to rewrite/mutate a `PTRSUB`. + private boolean rewritePtrSubcomponent( + HighFunction high_function, PcodeOp op, + CallDepthChangeInfo cdci) throws Exception { + + // Look for `PTRSUB SP, offset` and convert into `PTRSUB local_N, M`. + if (referencesStackPointer(op) == 0) { + return createLocalForPtrSubcomponent(high_function, op, cdci); + } + + Varnode base_node = op.getInput(0); + Varnode offset_node = op.getInput(1); + if (base_node.isConstant() && base_node.getOffset() == 0 && + offset_node.isConstant()) { + return createGlobalForPtrSubcomponent(high_function, op); + } + + return true; + } + + // Get or create a prefix operations list. These are operations that + // will precede `op` in our serialization, regardless of whether or + // not `op` is elided. + private List getOrCreatePrefixOperations(PcodeOp op) { + List ops = prefix_operations.get(op); + if (ops == null) { + ops = new ArrayList(); + prefix_operations.put(op, ops); + } + return ops; + } + + // Try to fixup direct stack pointer references in `op`. + private boolean tryFixupStackVarnode( + HighFunction high_function, PcodeOp op, + CallDepthChangeInfo cdci) throws Exception { + + int offset = referencesStackPointer(op); + if (offset == -1) { + return true; + } + + // Figure out what stack offset is pointed to by `SP`. + Address op_loc = op.getSeqnum().getTarget(); + int stack_offset = cdci.getDepth(op_loc); + if (stack_offset == Function.UNKNOWN_STACK_DEPTH_CHANGE) { + + Function function = high_function.getFunction(); + StackFrame frame = function.getStackFrame(); + Variable[] stack_vars = frame.getStackVariables(); + if (stack_vars == null || stack_vars.length == 0) { + return false; + } + + if (frame.growsNegative()) { + stack_offset = stack_vars[0].getStackOffset(); + } else { + stack_offset = stack_vars[stack_vars.length - 1].getStackOffset(); + } + } + + Varnode sp_ref = op.getInput(offset); + Varnode[] nodes = createStackPointerVarnodes( + high_function, op, stack_offset); + if (nodes == null) { + return false; + } + + if (nodes[1].getOffset() != 0) { + printf("??? " + Long.toHexString(nodes[1].getOffset())); + return false; + } + + List ops = getOrCreatePrefixOperations(op); + + // Figure out the tye of the pointer to the local variable being + // referenced. + DataTypeManager dtm = program.getDataTypeManager(); + DataType var_type = nodes[0].getHigh().getDataType(); + DataType node_type = dtm.getPointer(var_type); + + // Create a unique address for this `Varnode`. + Address address = nextUniqueAddress(); + SequenceNumber loc = new SequenceNumber(address, next_seqnum++); + + // Make the `Varnode` instances. + DefinitionVarnode def = new DefinitionVarnode( + address, node_type.getLength()); + UseVarnode use = new UseVarnode(address, def.getSize()); + + // Create a prefix `ADDRESS_OF` for the local variable. + PcodeOp address_of = this.createAddressOf(def, loc, nodes[0]); + ops.add(address_of); + + // Track the logical value using a `HighOther`. + Varnode[] instances = new Varnode[2]; + instances[0] = address_of.getOutput(); + instances[1] = use; + HighVariable tracker = new HighTemporary( + node_type, instances[0], instances, op_loc, high_function); + + def.setDef(tracker, address_of); + use.setHigh(tracker); + + // println("Rewriting " + label(op)); + // println(" From: " + op.toString()); + + op.setInput(instances[1], offset); + + // println(" To: " + op.toString()); + + return tryFixupStackVarnode(high_function, op, cdci); + } + + // The data model of high P-CODE is fundamentally value based. Lets + // focus on the following example: + // + // extern int do_with_int(int *); + // + // int main() { + // int x; + // return do_with_int(&x); + // } + // + // Ignoring `INDIRECT`s, we might expect to see the following high + // P-CODE for the above C code: + // + // (unique res) CALL do_with_in (register RSP) + // --- RETURN 0 (unique res) + // + // Or: + // + // (unique addr_of_x) PTRSUB (register RSP) (const NNN) + // (unique res) CALL do_with_in (unique addr_of_x) + // --- RETURN 0 (unique res) + // + // At first this is confusing: why not reference `x`? Why instead go + // through the stack pointer register, `RSP`? The reason is that there + // are no uses of the *value of x* in this code. Operations such as + // `PTRSUB` operate on the address of things, and P-CODE doesn't + // natively have an `ADDRESS_OF` operation (though we add one). + // + // The purpose of this method is to go and find the "real" `HighVariable` + // if it exists by way of mining them from `MULTIEQUAL`, `COPY`, and + // `INDIRECT`` operations, which exist to encode SSA form, as well as to + // represent control-flow barriers in terms of data flow dependencies. + private void mineForVarNodes(PcodeOp op) { + for (Varnode node : op.getInputs()) { + HighVariable var = node.getHigh(); + if (var == null || !(var instanceof HighLocal)) { + continue; + } + + HighSymbol symbol = var.getSymbol(); + String var_name = var.getName(); + if (var_name == null || var_name.equals("UNNAMED")) { + if (symbol != null) { + var_name = symbol.getName(); + } + } + + if (var_name == null || var_name.equals("UNNAMED")) { + continue; + } + + missing_locals.put(var_name, (HighLocal) var); + } + } + + // Create missing local variables. High p-code still includes things + // like `PTRSUB SP, -offset` instead of treating the unrecognized data + // as `local_`. The decompiler, however, does these + // automatic variable inventions. + // + // Returns `false` on failure. + // + // NOTE(pag): This function is very much inspired by the `MakeStackRefs` + // script embedded in the Ghidra source. + private boolean fixupOperations( + HighFunction high_function, int num_params) throws Exception { + Function function = high_function.getFunction(); + FunctionSignature signature = function.getSignature(); + FunctionPrototype proto = high_function.getFunctionPrototype(); + LocalSymbolMap symbols = high_function.getLocalSymbolMap(); + CallDepthChangeInfo cdci = new CallDepthChangeInfo(function); + + // Fill in the parameters first so that they are the first + // things added to `entry_block`. + for (int i = 0; i < num_params; ++i) { + HighParam param = symbols.getParam(i); + if (param == null) { + HighSymbol param_sym = proto.getParam(i); + // println("Inventing HighParam for " + param_sym.getName() + " in " + function.getName()); + param = new HighParam(param_sym.getDataType(), null, null, i, param_sym); + missing_locals.put(param.getName(), param); + } + + createParamVarDecl(param); + } + + // Now go look for operations directly referencing the stack pointer. + for (PcodeBlockBasic block : high_function.getBasicBlocks()) { + Iterator op_iterator = block.getIterator(); + while (op_iterator.hasNext()) { + PcodeOp op = op_iterator.next(); + + switch (op.getOpcode()) { + case PcodeOp.CALLOTHER: + if (op.getInput(0).getOffset() < MIN_CALLOTHER) { + int index = (int) op.getInput(0).getOffset(); + String name = language.getUserDefinedOpName(index); + if (name != null) { + callother_uses.add(op); + } else { + println("Unsupported CALLOTHER at " + label(op) + ": " + op.toString()); + return false; + } + } + break; + case PcodeOp.PTRSUB: + if (!rewritePtrSubcomponent(high_function, op, cdci)) { + println("Unsupported PTRSUB at " + label(op) + ": " + op.toString()); + return false; + } + break; +// case PcodeOp.PTRADD: +// if (!markPtrAddForElision(high_function, op)) { +// println("Unsupported PTRADD at " + label(op) + ": " + op.toString()); +// return false; +// } +// break; + case PcodeOp.MULTIEQUAL: + if (!canElideMultiEqual(op)) { + println("Unsupported MULTIEQUAL at " + label(op) + ": " + op.toString()); + return false; + } + // Fall-through. + case PcodeOp.INDIRECT: + mineForVarNodes(op); + break; + + case PcodeOp.COPY: + if (canElideCopy(op)) { + mineForVarNodes(op); + break; + } + // Fall-through. + default: + if (!tryFixupStackVarnode(high_function, op, cdci)) { + println("Unsupported stack pointer reference at " + label(op) + ": " + op.toString()); + return false; + } + break; + } + } + } + + return true; + } + + private HighVariable variableOf(HighVariable var) { + if (var == null) { + return null; + } + + HighLocal fixed_var = old_locals.get(var); + return fixed_var != null ? fixed_var : var; + } + + // Return the variable of a given `Varnode`. This applies local fixups. + private HighVariable variableOf(Varnode node) { + return node == null ? null : variableOf(node.getHigh()); + } + + private HighVariable variableOf(PcodeOp op) { + return variableOf(op.getOutput()); + } + + // Handles serializing the output, if any, of `op`. We only actually + // serialize the named outputs. + private void serializeOutput(PcodeOp op) throws Exception { + Varnode output = op.getOutput(); + if (output == null) { + return; + } + + HighVariable var = variableOf(output); + if (var != null) { + name("type").value(label(var.getDataType())); + } else { + name("size").value(output.getSize()); + } + + // Only record an output node when the target is something named. + // Otherwise, this p-code operation will be used as part of an + // operand to something else. + // + // TODO(pag): Probably need some kind of verifier downstream to + // ensure no code motion happens. + VariableClassification klass = classifyVariable(var); + switch (klass) { + case PARAMETER: + case LOCAL: + case NAMED_TEMPORARY: + case GLOBAL: + break; + default: + return; + } + + name("output").beginObject(); + if (klass == VariableClassification.PARAMETER) { + name("kind").value("parameter"); + name("operation").value(label(getOrCreateLocalVariable(var, op))); + } else if (klass == VariableClassification.LOCAL) { + name("kind").value("local"); + name("operation").value(label(getOrCreateLocalVariable(var, op))); + } else if (klass == VariableClassification.NAMED_TEMPORARY) { + name("kind").value("temporary"); + name("operation").value(label(getOrCreateLocalVariable(var, op))); + } else if (klass == VariableClassification.GLOBAL) { + name("kind").value("global"); + name("global").value(label(addressOfGlobal(var))); + } else { + assert false; + } + endObject(); + } + + // The address of a `LOAD` or `STORE` is spread across two operands: + // the first being a constant representing the address space, and the + // second being the actual address. + private void serializeLoadStoreAddress(PcodeOp op) throws Exception { + Varnode address = rValueOf(op.getInput(1)); + if (!address.isConstant()) { + serializeInput(op, address); + return; + } + + Varnode aspace = op.getInput(0); + assert aspace.isConstant(); + + beginObject(); + name("size").value(op.getInput(1).getSize()); + println("!!! " + label(op) + ": " + op.toString()); + endObject(); + } + + // Serialize a `LOAD space, address` op, eliding the address space. + private void serializeLoadOp(PcodeOp op) throws Exception { + serializeOutput(op); + name("inputs").beginArray(); + serializeLoadStoreAddress(op); + endArray(); + } + + // Serialize a `STORE space, address, value` op, eliditing the address + // space. + private void serializeStoreOp(PcodeOp op) throws Exception { + serializeOutput(op); + name("inputs").beginArray(); + serializeLoadStoreAddress(op); + serializeInput(op, rValueOf(op.getInput(2))); + endArray(); + } + + // Product a new address in the `UNIQUE` address space. + private Address nextUniqueAddress() throws Exception { + Address address = unique_space.getAddress(next_unique); + next_unique += unique_space.getAddressableUnitSize(); + return address; + } + + // Creates a pseudo p-code op using a `CALLOTHER` that logically + // represents the definition of a parameter variable. + // + // NOTE(pag): The `HighParam` may have been invented and not have a + // representative. + // + // TODO(pag): Do we want the `.getLength()` or `.getAlignedLength()` + // for the parameter size in the absence of a representative? + private PcodeOp createParamVarDecl(HighVariable var) throws Exception { + HighParam param = (HighParam) var; + Address address = nextUniqueAddress(); + DefinitionVarnode def = new DefinitionVarnode(address, var.getDataType().getAlignedLength()); + Varnode[] ins = new Varnode[2]; + SequenceNumber loc = new SequenceNumber(address, next_seqnum++); + PcodeOp op = new PcodeOp(loc, PcodeOp.CALLOTHER, 2, def); + op.insertInput(new Varnode(constant_space.getAddress(DECLARE_PARAM_VAR), 4), 0); + op.insertInput(new Varnode(constant_space.getAddress(param.getSlot()), 4), 1); + def.setDef(var, op); + + Varnode[] instances = var.getInstances(); + Varnode[] new_instances = new Varnode[instances.length + 1]; + System.arraycopy(instances, 0, new_instances, 1, instances.length); + new_instances[0] = def; + + var.attachInstances(new_instances, def); + + entry_block.add(op); + + return op; + } + + // Creates a pseudo p-code op using a `CALLOTHER` that logically + // represents the definition of a local variable. + private PcodeOp createLocalVarDecl(HighVariable var) throws Exception { + Address address = nextUniqueAddress(); + DefinitionVarnode def = new DefinitionVarnode(address, var.getSize()); + Varnode[] ins = new Varnode[1]; + SequenceNumber loc = new SequenceNumber(address, next_seqnum++); + PcodeOp op = new PcodeOp(loc, PcodeOp.CALLOTHER, 1, def); + op.insertInput(new Varnode(constant_space.getAddress(DECLARE_LOCAL_VAR), 4), 0); + def.setDef(var, op); + + Varnode[] instances = var.getInstances(); + Varnode[] new_instances = new Varnode[instances.length + 1]; + System.arraycopy(instances, 0, new_instances, 1, instances.length); + new_instances[0] = def; + + var.attachInstances(new_instances, def); + + HighSymbol sym = var.getSymbol(); + entry_block.add(op); + + return op; + } + + // Creates a pseudo p-code op using a `CALLOTHER` that logically + // represents the definition of a variable that stands in for a register. + private PcodeOp createNamedTemporaryDecl( + HighVariable var, PcodeOp user_op) throws Exception { + Varnode rep = originalRepresentativeOf(var); + assert rep.isRegister(); + + Address address = nextUniqueAddress(); + DefinitionVarnode def = new DefinitionVarnode(address, var.getSize()); + Varnode[] ins = new Varnode[1]; + SequenceNumber loc = new SequenceNumber(address, next_seqnum++); + PcodeOp op = new PcodeOp(loc, PcodeOp.CALLOTHER, 1, def); + op.insertInput(new Varnode(constant_space.getAddress(DECLARE_TEMP_VAR), 4), 0); + def.setDef(var, op); + + Varnode[] instances = var.getInstances(); + Varnode[] new_instances = new Varnode[instances.length + 1]; + System.arraycopy(instances, 0, new_instances, 1, instances.length); + new_instances[0] = def; + + var.attachInstances(new_instances, def); + + entry_block.add(op); + temporary_address.put(var, user_op); + + return op; + } + + // Get or create a local variable pseudo definition op for the high + // variable `var`. + private PcodeOp getOrCreateLocalVariable( + HighVariable var, PcodeOp user_op) throws Exception { + + Varnode representative = var.getRepresentative(); + PcodeOp def = null; + if (representative != null) { + def = resolveOp(representative.getDef()); + if (!isOriginalRepresentative(representative)) { + return def; + } + } + + switch (classifyVariable(var)) { + case PARAMETER: + println("Creating late parameter for " + label(user_op) + ": " + user_op.toString()); + return createParamVarDecl(var); + case LOCAL: + return createLocalVarDecl(var); + case NAMED_TEMPORARY: + return createNamedTemporaryDecl(var, user_op); + default: + break; + } + + return def; + } + + // Serialize a direct call. This enqueues the targeted for type lifting + // `Function` if it can be resolved. + private void serializeCallOp(PcodeOp op) throws Exception { + Address caller_address = current_function.getFunction().getEntryPoint(); + Varnode target_node = op.getInput(0); + Function callee = null; + + name("has_return_value").value(op.getOutput() != null); + + if (target_node.isAddress()) { + Address target_address = caller_address.getNewAddress(target_node.getOffset()); + String target_label = label(target_address); + callee = fm.getFunctionAt(target_address); - private class PcodeSerializer extends JsonWriter { - private String arch; - private FunctionManager fm; - private ExternalManager em; - private DecompInterface ifc; - private BasicBlockModel bbm; - private List functions; - private int original_functions_size; - private Set
seen_functions; - private Set seen_types; - private List types_to_serialize; - - public PcodeSerializer(java.io.BufferedWriter writer, - String arch_, FunctionManager fm_, - ExternalManager em_, DecompInterface ifc_, - BasicBlockModel bbm_, - List functions_) { - super(writer); - this.arch = arch_; - this.fm = fm_; - this.em = em_; - this.ifc = ifc_; - this.bbm = bbm_; - this.functions = functions_; - this.original_functions_size = functions.size(); - this.seen_functions = new TreeSet<>(); - this.seen_types = new HashSet<>(); - this.types_to_serialize = new ArrayList<>(); - } - - private static String label(Address address) throws Exception { - return address.toString(true /* show address space prefix */); - } - - private static String label(SequenceNumber sn) throws Exception { - return label(sn.getTarget()) + Address.SEPARATOR + - Integer.toString(sn.getTime()) + Address.SEPARATOR + - Integer.toString(sn.getOrder()); - } - - private static String label(PcodeBlock block) throws Exception { - return label(block.getStart()) + Address.SEPARATOR + - Integer.toString(block.getIndex()) + Address.SEPARATOR + - PcodeBlock.typeToName(block.getType()); - } - - private static String label(PcodeOp op) throws Exception { - return label(op.getSeqnum()); - } - - private String label(DataType type) throws Exception { - // In type is null, assign VoidDataType in all cases. - // We assume it as void type. - if (type == null) { - type = VoidDataType.dataType; - } - String name = type.getName(); - CategoryPath category = type.getCategoryPath(); - String concat_type = category.toString() + name + Integer.toString(type.getLength()); - String type_id = Integer.toHexString(concat_type.hashCode()); - - UniversalID uid = type.getUniversalID(); - if (uid != null) { - type_id += Address.SEPARATOR + uid.toString(); - } - - if (seen_types.add(type_id)) { - types_to_serialize.add(type); - } - return type_id; - } - - // Return the r-value of a varnode. - private Varnode rValueOf(Varnode node) throws Exception { - while (true) { - PcodeOp def = node.getDef(); - if (def == null) { - break; - } - - if (def.getOpcode() != PcodeOp.INDIRECT) { - break; - } - - Varnode i0 = def.getInput(0); - if (!i0.isConstant()) { - break; - } - - assert i0.getOffset() == 0; - node = def.getInput(1); - } - - return node; - } - - // Return the l-value of a varnode. - private HighVariable lValueOf(Varnode node) throws Exception { - HighVariable var = node.getHigh(); - while (var == null) { - PcodeOp def = node.getDef(); - if (def == null) { - break; - } - - if (def.getOpcode() != PcodeOp.INDIRECT) { - break; - } - - Varnode i0 = def.getInput(0); - - // A constant varnode for input 0 in an INDIRECT op means that - // the referenced operation producing input 1 is the producer - // of the value. - if (i0.isConstant()) { - assert i0.getOffset() == 0; - break; - } - - node = i0; - var = node.getHigh(); - } - - return var; - } - - private void serializePointerType(Pointer ptr) throws Exception { - name("name").value(ptr.getDisplayName()); - name("kind").value("pointer"); - name("size").value(ptr.getLength()); - name("element_type").value(label(ptr.getDataType())); - } - - private void serializeTypedefType(TypeDef typedef) throws Exception { - name("name").value(typedef.getDisplayName()); - name("kind").value("typedef"); - name("size").value(typedef.getLength()); - name("base_type").value(label(typedef.getBaseDataType())); - } - - private void serializeArrayType(Array arr) throws Exception { - name("name").value(arr.getDisplayName()); - name("kind").value("array"); - name("size").value(arr.getLength()); - name("num_elements").value(arr.getNumElements()); - name("element_type").value(label(arr.getDataType())); - } - - private void serializeBuiltinType(DataType data_type, String kind) throws Exception { - name("name").value(data_type.getDisplayName()); - name("size").value(data_type.getLength()); - name("kind").value(kind); - } - - private void serializeCompositeType(Composite data_type, String kind) throws Exception { - name("name").value(data_type.getDisplayName()); - name("kind").value(kind); - name("size").value(data_type.getLength()); - name("fields").beginArray(); - - for (int i = 0; i < data_type.getNumComponents(); i++) { - DataTypeComponent dtc = data_type.getComponent(i); - beginObject(); - - name("type").value(label(dtc.getDataType())); - name("offset").value(dtc.getOffset()); - - if (dtc.getFieldName() != null) { - name("name").value(dtc.getFieldName()); - } - endObject(); - } - endArray(); - } - - private void serializeFunctionDefinition(FunctionDefinition fd) throws Exception { - name("name").value(fd.getDisplayName()); - name("kind").value("function"); - name("return_type").value(label(fd.getReturnType())); - name("is_variadic").value(fd.hasVarArgs()); - ParameterDefinition[] arguments = fd.getArguments(); - name("parameters").beginArray(); - for (int i = 0; i < arguments.length; i++) { - beginObject(); - String name = arguments[i].getName(); - if (name != null && !name.isEmpty()) { - name("name").value(arguments[i].getName()); - } - name("size").value(arguments[i].getLength()); - name("type").value(label(arguments[i].getDataType())); - endObject(); - } - endArray(); - } - - private void serialize(DataType data_type) throws Exception { - if (data_type == null) { - nullValue(); - return; - } - - if (data_type instanceof Pointer) { - serializePointerType((Pointer) data_type); - } else if (data_type instanceof TypeDef) { - serializeTypedefType((TypeDef) data_type); - } else if (data_type instanceof Array) { - serializeArrayType((Array) data_type); - } else if (data_type instanceof Structure) { - serializeCompositeType((Composite) data_type, "struct"); - } else if (data_type instanceof Union) { - serializeCompositeType((Composite) data_type, "union"); - } else if (data_type instanceof AbstractIntegerDataType){ - serializeBuiltinType(data_type, "integer"); - } else if (data_type instanceof AbstractFloatDataType){ - serializeBuiltinType(data_type, "float"); - } else if (data_type instanceof BooleanDataType){ - serializeBuiltinType(data_type, "boolean"); - } else if (data_type instanceof Enum) { - serializeBuiltinType(data_type, "enum"); - } else if (data_type instanceof VoidDataType) { - serializeBuiltinType(data_type, "void"); - } else if (data_type instanceof Undefined || data_type.toString().contains("undefined")) { - serializeBuiltinType(data_type, "undefined"); - } else if (data_type instanceof FunctionDefinition) { - serializeFunctionDefinition((FunctionDefinition) data_type); - } else { - throw new Exception("Unhandled type: " + data_type.toString()); - } - } - - private void serializeTypes() throws Exception { - for (int i = 0; i < types_to_serialize.size(); i++) { - DataType type = types_to_serialize.get(i); - name(label(type)).beginObject(); - serialize(type); - endObject(); - } - - println("Total serialized types: " + types_to_serialize.size()); - } - - private void serialize(FunctionPrototype proto) throws Exception { - if (proto == null) { - nullValue(); - return; - } - - name("parameters").beginArray(); - for (int i = 0; i < proto.getNumParams(); i++) { - HighVariable hv = proto.getParam(i).getHighVariable(); - // Assert if hv is not an instance of HighParam - assert hv instanceof HighParam; - if (hv != null) { - beginObject(); - String hv_name = hv.getName(); - if (hv_name != null && !hv_name.isEmpty()) { - name("name").value(hv_name); - } - name("type").value(label(hv.getDataType())); - endObject(); - } - } - endArray(); - } - - private void serialize(HighVariable high_var) throws Exception { - if (high_var == null) { - nullValue(); - return; - } - - beginObject(); - name("name").value(high_var.getName()); - name("type").value(label(high_var.getDataType())); - endObject(); - } - - private void serialize(Varnode node) throws Exception { - if (node == null) { - assert false; - nullValue(); - return; - } - - // Make sure INDIRECTs don't leak back into our output. We won't - // have the ability to reference them. - PcodeOp def = node.getDef(); - if (def != null) { - assert def.getOpcode() != PcodeOp.INDIRECT; - } - - beginObject(); - - Address address = node.getAddress(); - name("space").value(address.getAddressSpace().getName()); - name("offset").value(node.getOffset()); - name("size").value(node.getSize()); - name("address").value(label(node.getAddress())); - HighVariable high_var = node.getHigh(); - if (high_var != null) { - name("variable").beginObject(); - name("name").value(high_var.getName()); - name("type").value(label(high_var.getDataType())); - endObject(); - } - endObject(); - } - - // Serialize a direct call. This enqueues the targeted for type lifting - // `Function` if it can be resolved. - private void serializeDirectCallOp(Address caller_address, PcodeOp op) throws Exception { - Varnode target_address_node = op.getInput(0); - if (!target_address_node.isAddress()) { - throw new Exception("Unexpected non-address input to CALL"); - } - - Address target_address = caller_address.getNewAddress(target_address_node.getOffset()); - String target_label = label(target_address); - Function callee = fm.getFunctionAt(target_address); - - // `target_address` may be a pointer to an external. Figure out - // what we're calling. - if (callee == null) { - callee = fm.getReferencedFunction(target_address); - if (callee != null) { - target_address = callee.getEntryPoint(); - println("Call through " + target_label + - " targets " + callee.getName() + - " at " + label(target_address)); - target_label = label(target_address); - } - } - - name("target_address").value(target_label); - - if (callee != null) { - functions.add(callee); - } else { - println("Could not find function at address " + target_label + - " called by " + caller_address.toString()); - } - } - - // Serialize an unconditional branch. This records the targeted block. - private void serializeBranchOp(PcodeBlockBasic block, PcodeOp op) throws Exception { - assert block.getOutSize() == 1; - name("target_block").value(label(block.getOut(0))); - } - - // Serialize a conditional branch. This records the targeted blocks. - // - // TODO(pag): How does p-code handle delay slots? Are they separate - // blocks? - // - // XREF(pag): https://github.com/NationalSecurityAgency/ghidra/issues/2736 - // describes how the `op` meaning of the branch, i.e. whether - // it branches on zero or not zero, can change over the course - // of simplification, and so the inputs representing the - // branch targets may not actually represent the `true` or - // `false` outputs in the traditional sense. - private void serializeCondBranchOp(PcodeBlockBasic block, PcodeOp op) throws Exception { - name("taken_block").value(label(block.getTrueOut())); - name("not_taken_block").value(label(block.getFalseOut())); - name("condition"); - serialize(rValueOf(op.getInput(1))); - } - - // Serialize a generic multi-input, single-output p-code operation. - private void serializeGenericOp(PcodeOp op) throws Exception { - name("output"); - serialize(op.getOutput()); - name("inputs").beginArray(); - for (var input : op.getInputs()) { - serialize(rValueOf(input)); - } - endArray(); - } - - private void serialize(HighFunction function, PcodeBlockBasic block, PcodeOp op) throws Exception { - Address function_address = function.getFunction().getEntryPoint(); - beginObject(); - name("mnemonic").value(op.getMnemonic()); - name("name").value(label(op.getSeqnum())); - switch (op.getOpcode()) { - case PcodeOp.CALL: - serializeDirectCallOp(function_address, op); - break; - case PcodeOp.CBRANCH: - serializeCondBranchOp(block, op); - break; - case PcodeOp.BRANCH: - serializeBranchOp(block, op); - break; - default: - serializeGenericOp(op); - break; - } - endObject(); - } - - // Serialize a high p-code basic block. This iterates over the p-code - // operations within the block and serializes them individually. - private void serialize(HighFunction function, PcodeBlockBasic block) throws Exception { - PcodeBlock parent_block = block.getParent(); - if (parent_block != null) { - name("parent_block").value(label(parent_block)); - } - - PcodeOp op = null; - Iterator op_iterator = block.getIterator(); - name("operations").beginObject(); - while (op_iterator.hasNext()) { - op = op_iterator.next(); - - // NOTE(pag): INDIRECTs seem like a good way of modelling may- - // alias relations, as well as embedding control - // dependencies into the dataflow graph, e.g. to - // ensure code motion cannot happen from after a CALL - // to before a CALL, especially for stuff operating - // on stack slots. The idea at the time of this - // comment is that we will assume that eventual - // codegen also should not do any reordering, though - // enforcing that is also tricky. - if (op.getOpcode() != PcodeOp.INDIRECT) { - name(label(op)); - serialize(function, block, op); - } - } - endObject(); - - // List out the operations in their order. - op_iterator = block.getIterator(); - name("ordered_operations").beginArray(); - while (op_iterator.hasNext()) { - value(label(op_iterator.next())); - } - endArray(); - } - - // Serialize `function`. If we have `high_function` (the decompilation - // of function) then we will serialize its type information. Otherwise, - // we will serialize the type information of `function`. If - // `visit_pcode` is true, then this is a function for which we want to - // fully lift, i.e. visit all the high p-code. - private void serialize(HighFunction high_function, Function function, boolean visit_pcode) throws Exception { - - name("name").value(function.getName()); - FunctionPrototype proto = high_function.getFunctionPrototype(); - name("prototype").beginObject(); - if (proto != null) { - serialize(proto); - } - endObject(); - - // If we have a high P-Code function, then serialize the blocks. - if (high_function != null) { - if (visit_pcode) { - name("basic_blocks").beginObject(); - for (PcodeBlockBasic block : high_function.getBasicBlocks()) { - name(label(block)).beginObject(); - serialize(high_function, block); - endObject(); - } - endObject(); - } - } - } - - // Serialize the input function list to JSON. This function will also - // serialize type information related to referenced functions and - // variables. - public void serialize() throws Exception { - - beginObject(); - name("arch").value(getArch()); - name("format").value(currentProgram.getExecutableFormat()); - name("functions").beginObject(); - - for (int i = 0; i < functions.size(); ++i) { - Function function = functions.get(i); - Address function_address = function.getEntryPoint(); - if (!seen_functions.add(function_address)) { - continue; - } - - DecompileResults res = ifc.decompileFunction(function, decompilation_timeout, null); - HighFunction high_function = res.getHighFunction(); - if (high_function == null) { - continue; - } - - name(label(function_address)).beginObject(); - serialize(high_function, function, i < original_functions_size); - endObject(); - } - - // Serialize Types - name("types").beginObject(); - serializeTypes(); - endObject(); - - endObject().endObject(); - } - } - - private String getArch() throws Exception { - if (currentProgram.getLanguage() == null || - currentProgram.getLanguage().getProcessor() == null) { - return "unknown"; - } - return currentProgram.getLanguage().getProcessor().toString(); - } - - private DecompInterface getDecompilerInterface() throws Exception { - if (currentProgram == null) { - throw new Exception("Unable to initialize decompiler: invalid current program."); - } - DecompInterface decompiler = new DecompInterface(); - decompiler.setOptions(new DecompileOptions()); - if (!decompiler.openProgram(currentProgram)) { - throw new Exception("Unable to initialize decompiler: " + decompiler.getLastMessage()); - } - return decompiler; - } - - private void serializeToFile(Path file, List functions) throws Exception { - if (file == null || functions == null || functions.isEmpty()) { - throw new IllegalArgumentException("Invalid file path or empty function list"); - } - - final var serializer = new PcodeSerializer( - Files.newBufferedWriter(file), getArch(), - currentProgram.getFunctionManager(), currentProgram.getExternalManager(), - getDecompilerInterface(), new BasicBlockModel(currentProgram), functions); - serializer.serialize(); - serializer.close(); - } - - private List getAllFunctions() { - if (currentProgram == null || currentProgram.getFunctionManager() == null) { - return Collections.emptyList(); - } - FunctionIterator functionIter = currentProgram.getFunctionManager().getFunctions(true); - List functions = new ArrayList<>(); - while (functionIter.hasNext() && !monitor.isCancelled()) { - functions.add(functionIter.next()); - } - return functions; - } - - private void decompileSingleFunction() throws Exception { - if (getScriptArgs().length < 3) { - throw new IllegalArgumentException("Insufficient arguments. Expected: as argument"); - } - serializeToFile(Path.of(getScriptArgs()[2]), getGlobalFunctions(getScriptArgs()[1])); - } - - private void decompileAllFunctions() throws Exception { - if (getScriptArgs().length < 2) { - throw new IllegalArgumentException("Insufficient arguments. Expected: as argument"); - } - serializeToFile(Path.of(getScriptArgs()[1]), getAllFunctions()); - } - - private void runHeadless() throws Exception { - if (getScriptArgs().length < 1) { - throw new IllegalArgumentException("mode is not specified for headless execution"); - } - - // Execution mode - String mode = getScriptArgs()[0]; - println("Running in mode: " + mode); - switch (mode.toLowerCase()) { - case "single": - decompileSingleFunction(); - break; - case "all": - decompileAllFunctions(); - break; - default: - throw new IllegalArgumentException("Invalid mode: " + mode); - } - } - - private void decompileSingleFunctionInGUI() throws Exception { - List functions = null; - if (currentProgram != null) { - FunctionManager manager = currentProgram.getFunctionManager(); - if (manager != null) { - Function function = manager.getFunctionContaining(currentAddress); - if (function != null) { - functions = new ArrayList<>(); - functions.add(function); - } - } - } - - if (functions == null) { - String functionNameArg = askString("functionNameArg", "Function name to decompile: "); - functions = getGlobalFunctions(functionNameArg); - } - - File outputDirectory = askDirectory("outputFilePath", "Select output directory"); - File outputFilePath = new File(outputDirectory, "patchestry.json"); - serializeToFile(outputFilePath.toPath(), functions); - } - - private void decompileAllFunctionsInGUI() throws Exception { - File outputDirectory = askDirectory("outputFilePath", "Select output directory"); - File outputFilePath = new File(outputDirectory, "patchestry.json"); - serializeToFile(outputFilePath.toPath(), getAllFunctions()); - } - - // GUI mode execution - private void runGUI() throws Exception { - String mode = askString("mode", "Please enter mode:"); - println("Running in mode: " + mode); - switch (mode.toLowerCase()) { - case "single": - decompileSingleFunctionInGUI(); - break; - case "all": - decompileAllFunctionsInGUI(); - break; - default: - throw new IllegalArgumentException("Invalid mode: " + mode); - } - } - - // Script entry point - @Override - public void run() throws Exception { - try { - if (isRunningHeadless()) { - runHeadless(); - } else { - runGUI(); - } - } catch (Exception e) { - println("Error: " + e.getMessage()); - e.printStackTrace(new PrintWriter(new OutputStreamWriter(System.err))); - throw e; - } - } + // `target_address` may be a pointer to an external. Figure out + // what we're calling. + + if (callee == null) { + callee = fm.getReferencedFunction(target_address); + if (callee != null) { + target_address = callee.getEntryPoint(); + target_label = label(target_address); + } + } + } + + name("target"); + if (callee != null) { + functions.add(callee); + + beginObject(); + name("kind").value("function"); + name("function").value(label(callee)); + name("is_variadic").value(callee.hasVarArgs()); + name("is_noreturn").value(callee.hasNoReturn()); + endObject(); + + } else { + serializeInput(op, rValueOf(target_node)); + } + + name("inputs").beginArray(); + Varnode[] inputs = op.getInputs(); + for (int i = 1; i < inputs.length; ++i) { + serializeInput(op, rValueOf(inputs[i])); + } + endArray(); + } + + // Serialize an unconditional branch. This records the targeted block. + private void serializeBranchOp(PcodeOp op) throws Exception { + assert current_block.getOutSize() == 1; + name("target_block").value(label(current_block.getOut(0))); + } + + // Serialize a conditional branch. This records the targeted blocks. + // + // TODO(pag): How does p-code handle delay slots? Are they separate + // blocks? + // + // XREF(pag): https://github.com/NationalSecurityAgency/ghidra/issues/2736 + // describes how the `op` meaning of the branch, i.e. whether + // it branches on zero or not zero, can change over the course + // of simplification, and so the inputs representing the + // branch targets may not actually represent the `true` or + // `false` outputs in the traditional sense. + private void serializeCondBranchOp(PcodeOp op) throws Exception { + name("taken_block").value(label(current_block.getTrueOut())); + name("not_taken_block").value(label(current_block.getFalseOut())); + name("condition"); + serializeInput(op, rValueOf(op.getInput(1))); + } + + // Serialize a generic multi-input, single-output p-code operation. + private void serializeGenericOp(PcodeOp op) throws Exception { + name("inputs").beginArray(); + for (Varnode input : op.getInputs()) { + serializeInput(op, rValueOf(input)); + } + endArray(); + } + + // Serializes a pseudo-op `DECLARE_PARAM_VAR`, which is actually encoded + // as a `CALLOTHER`. + private void serializeDeclareParamVar(PcodeOp op) throws Exception { + HighVariable var = variableOf(op); + name("name").value(var.getName()); + name("type").value(label(var.getDataType())); + name("kind").value("parameter"); // So that it also looks like an input/output. + if (var instanceof HighParam) { + name("index").value(((HighParam) var).getSlot()); + } + } + + // Serializes a pseudo-op `DECLARE_LOCAL_VAR`, which is actually encoded + // as a `CALLOTHER`. + private void serializeDeclareLocalVar(PcodeOp op) throws Exception { + HighVariable var = variableOf(op); + HighSymbol sym = var.getSymbol(); + name("kind").value("local"); // So that it also looks like an input/output. + if (sym != null && var.getOffset() == -1 && var.getName().equals("UNNAMED")) { + name("name").value(sym.getName()); + name("type").value(label(sym.getDataType())); + } else { + name("name").value(var.getName()); + name("type").value(label(var.getDataType())); + } + } + + private void serializeDeclareNamedTemporary(PcodeOp op) throws Exception { + HighVariable var = variableOf(op); + Varnode rep = originalRepresentativeOf(var); + + // NOTE(pag): In practice, the `HighOther`s name associated with + // this register is probably `UNNAMED`, which happens in + // `HighOther.decode`; however, we'll be cautious and + // only canonicalize on the register name if the name + // it is the default. + if (var.getName().equals("UNNAMED")) { + if (rep.isRegister()) { + Register reg = language.getRegister(rep.getAddress(), 0); + if (reg != null) { + name("name").value(reg.getName()); + } else { + name("name").value("reg" + Address.SEPARATOR + + Long.toHexString(rep.getOffset())); + } + } else { + name("name").value("temp"); + } + } else { + name("name").value(var.getName()); + } + + name("kind").value("temporary"); // So that it also looks like an input/output. + name("type").value(label(var.getDataType())); + + // NOTE(pag): The same register might appear multiple times, though + // we can't guarantee that they will appear with the + // same names. Thus, we want to record the address of + // the operation using the original register as a kind of + // SSA-like version number downstream, e.g. in a Clang + // AST. + PcodeOp user_op = temporary_address.get(var); + if (user_op != null) { + name("address").value(label(user_op)); + } + } + + // Serialize an `ADDRESS_OF`, used to the get the address of a local or + // global variable. These are created from `PTRSUB` nodes. + private void serializeAddressOfOp(PcodeOp op) throws Exception { + serializeOutput(op); + name("inputs").beginArray(); + serializeInput(op, rValueOf(op.getInputs()[1])); + endArray(); + } + + // Serialize a `CALLOTHER` as a call to an intrinsic. + private void serializeIntrinsicCallOp(PcodeOp op) throws Exception { + serializeOutput(op); + + name("target").beginObject(); + name("kind").value("intrinsic"); + name("function").value(intrinsicLabel(op)); + name("is_variadic").value(true); + name("is_noreturn").value(false); + endObject(); // End of `target`. + + name("inputs").beginArray(); + Varnode[] inputs = op.getInputs(); + for (int i = 1; i < inputs.length; ++i) { + serializeInput(op, rValueOf(inputs[i])); + } + endArray(); + } + + // Serialize a `CALLOTHER`. The first input operand is a constant + // representing the user-defined opcode number. In our case, we have + // our own user-defined opcodes for making things better mirror the + // structure/needs of MLIR. + private void serializeCallOtherOp(PcodeOp op) throws Exception { + switch ((int) op.getInput(0).getOffset()) { + case DECLARE_PARAM_VAR: + serializeDeclareParamVar(op); + break; + case DECLARE_LOCAL_VAR: + serializeDeclareLocalVar(op); + break; + case DECLARE_TEMP_VAR: + serializeDeclareNamedTemporary(op); + break; + case ADDRESS_OF: + serializeAddressOfOp(op); + break; + default: + serializeIntrinsicCallOp(op); + break; + } + } + + // Serialize a `RETURN N[, val]` as logically being `RETURN val`. + private void serializeReturnOp(PcodeOp op) throws Exception { + Varnode inputs[] = op.getInputs(); + name("inputs").beginArray(); + if (inputs.length == 2) { + serializeInput(op, rValueOf(inputs[1])); + } + endArray(); + } + + // Get the mnemonic for a p-code operation. We have some custom + // operations encoded as `CALLOTHER`s, so we get their names manually + // here. + // + // TODO(pag): There is probably a way to register the name of a + // `CALLOTHER` via `Language.getSymbolTable()` using a + // `UseropSymbol`. It's not clear if there's really value in + // doing this, though. + private static String mnemonic(PcodeOp op) { + if (op.getOpcode() == PcodeOp.CALLOTHER) { + switch ((int) op.getInput(0).getOffset()) { + case DECLARE_PARAM_VAR: + return "DECLARE_PARAMETER"; + case DECLARE_LOCAL_VAR: + return "DECLARE_LOCAL"; + case DECLARE_TEMP_VAR: + return "DECLARE_TEMPORARY"; + case ADDRESS_OF: + return "ADDRESS_OF"; + default: + break; + } + } + return op.getMnemonic(); + } + + private void serialize(PcodeOp op) throws Exception { + + beginObject(); + name("mnemonic").value(mnemonic(op)); + + switch (op.getOpcode()) { + case PcodeOp.CALL: + case PcodeOp.CALLIND: + serializeOutput(op); + serializeCallOp(op); + break; + case PcodeOp.CALLOTHER: + serializeCallOtherOp(op); + break; + case PcodeOp.CBRANCH: + serializeCondBranchOp(op); + break; + case PcodeOp.BRANCH: + serializeBranchOp(op); + break; + case PcodeOp.RETURN: + serializeReturnOp(op); + break; + case PcodeOp.LOAD: + serializeLoadOp(op); + break; + case PcodeOp.STORE: + serializeStoreOp(op); + break; +// case PcodeOp.COPY: +// case PcodeOp.CAST: + default: + serializeOutput(op); + serializeGenericOp(op); + break; + } + + endObject(); + } + + // Returns `true` if we can elide a `MULTIEQUAL` operation. If all + // inputs are of the identical `HighVariable`, then we can elide. + private boolean canElideMultiEqual(PcodeOp op) throws Exception { + HighVariable high = variableOf(op); + if (high == null) { + return false; + } + + for (Varnode node : op.getInputs()) { + if (high != variableOf(node)) { + return false; + } + } + + return true; + } + + // Returns `true` if we can elide a copy operation. This only happens + // when we copy a variable into itself. + // + // NOTE(pag): I think this comes about as a sort of "pre-PHI" operation. + // + // TODO(pag): This is toally unsafe if there's an intervening write to + // relevant variable. Probably should investigate this case. + private boolean canElideCopy(PcodeOp op) throws Exception { + HighVariable high = variableOf(op); + return high != null && high == variableOf(op.getInput(0)); + } + + // Returns `true` if `op` is a branch operator. + private static boolean isBranch(PcodeOp op) throws Exception { + switch (op.getOpcode()) { + case PcodeOp.BRANCH: + case PcodeOp.CBRANCH: + case PcodeOp.BRANCHIND: + return true; + default: + return false; + } + } + + // Serialize a high p-code basic block. This iterates over the p-code + // operations within the block and serializes them individually. + private void serialize(PcodeBlockBasic block) throws Exception { + PcodeBlock parent_block = block.getParent(); + if (parent_block != null) { + name("parent_block").value(label(parent_block)); + } + + boolean last_is_branch = false; + Iterator op_iterator = block.getIterator(); + ArrayList ordered_operations = new ArrayList<>(); + + while (op_iterator.hasNext()) { + PcodeOp op = resolveOp(op_iterator.next()); + + // Inject the prefix operations into the ordered operations + // list. These are to handle things like stack pointer + // references flowing into `CALL` arguments. + List prefix_ops = prefix_operations.get(op); + if (prefix_ops != null) { + for (PcodeOp prefix_op : prefix_ops) { + ordered_operations.add(prefix_op); + } + } + + switch (op.getOpcode()) { + // NOTE(pag): INDIRECTs seem like a good way of modelling + // may- alias relations, as well as embedding + // control dependencies into the dataflow graph, + // e.g. to ensure code motion cannot happen from + // after a CALL to before a CALL, especially for + // stuff operating on stack slots. The idea at + // the time of this comment is that we will + // assume that eventual codegen also should not + // do any reordering, though enforcing that is + // also tricky. + case PcodeOp.INDIRECT: + continue; + + // MULTIEQUALs are Ghidra's form of SSA-form PHI nodes. + case PcodeOp.MULTIEQUAL: + if (canElideMultiEqual(op)) { + continue; + } + break; + + // Some copies end up imlpementing the kind of forward edge + // of a phi node (i.e. `MULTIEQUAL`) and can be elided. + case PcodeOp.COPY: + if (canElideCopy(op)) { + continue; + } + break; + + default: + break; + } + + ordered_operations.add(op); + } + + // Serialize the operations. + name("operations").beginObject(); + for (PcodeOp op : ordered_operations) { + name(label(op)); + current_block = block; + serialize(op); + last_is_branch = isBranch(op); + current_block = null; + } + + // Synthesize a fake `BRANCH` operation to the fall-through block. + // We'll have a fall-through if we don't already end in a branch, + // and if the last operation isn't a `RETURN` or a `CALL*` to a + // `noreturn`-attributed function. + String fall_through_label = ""; + if (!last_is_branch && block.getOutSize() == 1) { + fall_through_label = label(block) + ".exit"; + name(fall_through_label).beginObject(); + name("mnemonic").value("BRANCH"); + name("target_block").value(label(block.getOut(0))); + endObject(); // End of BRANCH to `first_block`. + } + + endObject(); // End of `operations`. + + // List out the operations in their order. + name("ordered_operations").beginArray(); + for (PcodeOp op : ordered_operations) { + value(label(op)); + } + if (!fall_through_label.equals("")) { + value(fall_through_label); + } + endArray(); // End of `ordered_operations`. + } + + // Emit a pseudo entry block to represent + private void serializeEntryBlock( + String label, PcodeBlockBasic first_block) throws Exception { + name(label).beginObject(); + name("operations").beginObject(); + for (PcodeOp pseudo_op : entry_block) { + name(label(pseudo_op)); + serialize(pseudo_op); + } + + // If there is a proper entry block, then invent a branch to it. + if (first_block != null) { + name("entry.exit").beginObject(); + name("mnemonic").value("BRANCH"); + name("target_block").value(label(first_block)); + endObject(); // End of BRANCH to `first_block`. + } + + endObject(); // End of operations. + + name("ordered_operations").beginArray(); + for (PcodeOp pseudo_op : entry_block) { + value(label(pseudo_op)); + } + value("entry.exit"); + endArray(); // End of `ordered_operations`. + endObject(); // End of `entry` block. + } + + // Serialize `function`. If we have `high_function` (the decompilation + // of function) then we will serialize its type information. Otherwise, + // we will serialize the type information of `function`. If + // `visit_pcode` is true, then this is a function for which we want to + // fully lift, i.e. visit all the high p-code. + private void serialize( + HighFunction high_function, Function function, + boolean visit_pcode) throws Exception { + + temporary_address.clear(); + old_locals.clear(); + missing_locals.clear(); + entry_block.clear(); + replacement_operations.clear(); + prefix_operations.clear(); + + FunctionPrototype proto = null; + name("name").value(function.getName()); + name("is_intrinsic").value(false); + + // If we have a high P-Code function, then serialize the blocks. + if (high_function != null) { + proto = high_function.getFunctionPrototype(); + + name("type").beginObject(); + + int num_params = 0; + if (proto != null) { + num_params = serializePrototype(proto); + } else { + num_params = serializePrototype(function.getSignature()); + } + endObject(); // End `type`. + + if (visit_pcode && fixupOperations(high_function, num_params)) { + + String entry_label = null; + PcodeBlockBasic first_block = null; + current_function = high_function; + + name("basic_blocks").beginObject(); + for (PcodeBlockBasic block : high_function.getBasicBlocks()) { + if (first_block == null) { + first_block = block; + } + + name(label(block)).beginObject(); + serialize(block); + endObject(); + } + + // If we created a fake entry block to represent variable + // declarations then emit that here. + if (!entry_block.isEmpty()) { + entry_label = entryBlockLabel(); + serializeEntryBlock(entry_label, first_block); + } + + endObject(); // End of `basic_blocks`. + current_function = null; + + if (entry_label != null) { + name("entry_block").value(entry_label); + + } else if (first_block != null) { + name("entry_block").value(label(first_block)); + } + } + } else { + name("type").beginObject(); + serializePrototype(function.getSignature()); + endObject(); // End `type`. + } + } + + private String entryBlockLabel() throws Exception { + return label(current_function) + Address.SEPARATOR + "entry"; + } + + // Serialize the global variable declarations. + private void serializeGlobals() throws Exception { + for (Map.Entry entry : seen_globals.entrySet()) { + Address address = entry.getKey(); + HighVariable global = entry.getValue(); + + // Try to get the global's name. + String name = global.getName(); + if (name == null || (name.equals("UNNAMED") && global.getOffset() == -1)) { + HighSymbol sym = global.getSymbol(); + if (sym != null) { + name = sym.getName(); + } + } + + name(label(address)).beginObject(); + name("name").value(name); + name("size").value(Integer.toString(global.getSize())); + name("type").value(label(global.getDataType())); + endObject(); + } + + println("Total serialized globals: " + Integer.toString(seen_globals.size())); + } + + // Don't try to decompile some functions. + private static final Set IGNORED_NAMES = Set.of( + "_start", "__libc_csu_fini", "__libc_csu_init", "__libc_start_main", + "__data_start", "__dso_handle", "_IO_stdin_used", + "_dl_relocate_static_pie", "__DTOR_END__", "__ashlsi3", + "__ashldi3", "__ashlti3", "__ashrsi3", "__ashrdi3", "__ashrti3", + "__divsi3", "__divdi3", "__divti3", "__lshrsi3", "__lshrdi3", + "__lshrti3", "__modsi3", "__moddi3", "__modti3", "__mulsi3", + "__muldi3", "__multi3", "__negdi2", "__negti2", "__udivsi3", + "__udivdi3", "__udivti3", "__udivmoddi4", "__udivmodti4", + "__umodsi3", "__umoddi3", "__umodti3", "__cmpdi2", "__cmpti2", + "__ucmpdi2", "__ucmpti2", "__absvsi2", "__absvdi2", "__addvsi3", + "__addvdi3", "__mulvsi3", "__mulvdi3", "__negvsi2", "__negvdi2", + "__subvsi3", "__subvdi3", "__clzsi2", "__clzdi2", "__clzti2", + "__ctzsi2", "__ctzdi2", "__ctzti2", "__ffsdi2", "__ffsti2", + "__paritysi2", "__paritydi2", "__parityti2", "__popcountsi2", + "__popcountdi2", "__popcountti2", "__bswapsi2", "__bswapdi2", + "frame_dummy", "call_frame_dummy", "__do_global_dtors", + "__do_global_dtors_aux", "call___do_global_dtors_aux", + "__do_global_ctors", "__do_global_ctors_1", "__do_global_ctors_aux", + "call___do_global_ctors_aux", "__gmon_start__", "_init_proc", + ".init_proc", "_term_proc", ".term_proc", "__uClibc_main", + "abort", "exit", "_Exit", "panic", "terminate", + "_Jv_RegisterClasses", + "__deregister_frame_info_bases", "__deregister_frame_info", + "__register_frame_info", "__cxa_throw", "__cxa_finalize", + "__cxa_allocate_exception", "__cxa_free_exception", + "__cxa_begin_catch", "__cxa_end_catch", "__cxa_new_handler", + "__cxa_get_globals", "__cxa_get_globals_fast", + "__cxa_current_exception_type", "__cxa_rethrow", "__cxa_bad_cast", + "__cxa_bad_typeid", "__allocate_exception", "__throw", + "__free_exception", + "__Unwind_RaiseException", "_Unwind_RaiseException", "_Unwind_Resume", + "_Unwind_DeleteException", "_Unwind_GetGR", "_Unwind_SetGR", + "_Unwind_GetIP", "_Unwind_SetIP", "_Unwind_GetRegionStart", + "_Unwind_GetLanguageSpecificData", "_Unwind_ForcedUnwind", + "__unw_getcontext", + "longjmp", "siglongjmp", "setjmp", "sigsetjmp", + "__register_frame_info_bases", "__assert_fail", + "_init", "_fini", "_ITM_registerTMCloneTable", + "_ITM_deregisterTMCloneTable", "register_tm_clones", + "deregister_tm_clones" + ); + + // Serialize all `CALLOTHER` intrinsics. + private void serializeIntrinsics() throws Exception { + Set seen_intrinsics = new HashSet<>(); + int num_intrinsics = 0; + + for (PcodeOp op : callother_uses) { + int index = (int) op.getInput(0).getOffset(); + String name = language.getUserDefinedOpName(index); + DataType ret_type = intrinsicReturnType(op); + String label = intrinsicLabel(name, ret_type); + if (!seen_intrinsics.add(label)) { + continue; + } + + name(label).beginObject(); + name("name").value(name); + name("is_intrinsic").value(true); + name("type").beginObject(); + name("return_type").value(label(ret_type)); + name("is_variadic").value(true); + name("is_noreturn").value(false); + name("parameter_types").beginArray().endArray(); + endObject(); // End of `type`. + endObject(); + + ++num_intrinsics; + } + + println("Total serialized intrinsics: " + Integer.toString(num_intrinsics)); + } + + // Serialize all functions. + // + // NOTE(pag): As we serialize functions, we might discover references + // to other functions, causing `functions` will grow over + // time. + private void serializeFunctions() throws Exception { + for (int i = 0; i < functions.size(); ++i) { + Function function = functions.get(i); + Address function_address = function.getEntryPoint(); + if (!seen_functions.add(function_address)) { + continue; + } + + boolean visit_pcode = i < original_functions_size && + !IGNORED_NAMES.contains(function.getName()); + + DecompileResults res = ifc.decompileFunction(function, DECOMPILATION_TIMEOUT, null); + HighFunction high_function = res.getHighFunction(); + + name(label(function)).beginObject(); + serialize(high_function, function, visit_pcode); + endObject(); + } + + println("Total serialized functions: " + Integer.toString(functions.size())); + + if (!callother_uses.isEmpty()) { + serializeIntrinsics(); + } + } + + // Serialize the input function list to JSON. This function will also + // serialize type information related to referenced functions and + // variables. + public void serialize() throws Exception { + + beginObject(); + name("arch").value(getArch()); + name("format").value(currentProgram.getExecutableFormat()); + + name("functions").beginObject(); + serializeFunctions(); + endObject(); // End of functions. + + name("globals").beginObject(); + serializeGlobals(); + endObject(); // End of globals. + + name("types").beginObject(); + serializeTypes(); + endObject(); // End of types. + + endObject(); + } + } + + private String getArch() throws Exception { + if (currentProgram.getLanguage() == null || + currentProgram.getLanguage().getProcessor() == null) { + return "unknown"; + } + return currentProgram.getLanguage().getProcessor().toString(); + } + + private DecompInterface getDecompilerInterface() throws Exception { + if (currentProgram == null) { + throw new Exception("Unable to initialize decompiler: invalid current program."); + } + + DecompileOptions options = DecompilerUtils.getDecompileOptions(state.getTool(), currentProgram); + DecompInterface decompiler = new DecompInterface(); + + decompiler.setOptions(options); + decompiler.toggleCCode(false); + decompiler.toggleSyntaxTree(true); + decompiler.toggleJumpLoads(true); + decompiler.toggleParamMeasures(false); + decompiler.setSimplificationStyle("decompile"); + + if (!decompiler.openProgram(currentProgram)) { + throw new Exception("Unable to initialize decompiler: " + decompiler.getLastMessage()); + } + return decompiler; + } + + private void serializeToFile(Path file, List functions) throws Exception { + if (file == null || functions == null || functions.isEmpty()) { + throw new IllegalArgumentException("Invalid file path or empty function list"); + } + + final var serializer = new PcodeSerializer( + Files.newBufferedWriter(file), getArch(), + currentProgram.getFunctionManager(), currentProgram.getExternalManager(), + getDecompilerInterface(), new BasicBlockModel(currentProgram), functions); + serializer.serialize(); + serializer.close(); + } + + private List getAllFunctions() { + if (currentProgram == null || currentProgram.getFunctionManager() == null) { + return Collections.emptyList(); + } + FunctionIterator functionIter = currentProgram.getFunctionManager().getFunctions(true); + List functions = new ArrayList<>(); + while (functionIter.hasNext() && !monitor.isCancelled()) { + functions.add(functionIter.next()); + } + return functions; + } + + private void decompileSingleFunction() throws Exception { + if (getScriptArgs().length < 3) { + throw new IllegalArgumentException("Insufficient arguments. Expected: as argument"); + } + serializeToFile(Path.of(getScriptArgs()[2]), getGlobalFunctions(getScriptArgs()[1])); + } + + private void decompileAllFunctions() throws Exception { + if (getScriptArgs().length < 2) { + throw new IllegalArgumentException("Insufficient arguments. Expected: as argument"); + } + serializeToFile(Path.of(getScriptArgs()[1]), getAllFunctions()); + } + + private void runHeadless() throws Exception { + if (getScriptArgs().length < 1) { + throw new IllegalArgumentException("mode is not specified for headless execution"); + } + + // Execution mode + String mode = getScriptArgs()[0]; + println("Running in mode: " + mode); + switch (mode.toLowerCase()) { + case "single": + decompileSingleFunction(); + break; + case "all": + decompileAllFunctions(); + break; + default: + throw new IllegalArgumentException("Invalid mode: " + mode); + } + } + + private void decompileSingleFunctionInGUI() throws Exception { + List functions = null; + if (currentProgram != null) { + FunctionManager manager = currentProgram.getFunctionManager(); + if (manager != null) { + Function function = manager.getFunctionContaining(currentAddress); + if (function != null) { + functions = new ArrayList<>(); + functions.add(function); + } + } + } + + if (functions == null) { + String functionNameArg = askString("functionNameArg", "Function name to decompile: "); + functions = getGlobalFunctions(functionNameArg); + } + + File outputDirectory = askDirectory("outputFilePath", "Select output directory"); + File outputFilePath = new File(outputDirectory, "patchestry.json"); + serializeToFile(outputFilePath.toPath(), functions); + } + + private void decompileAllFunctionsInGUI() throws Exception { + File outputDirectory = askDirectory("outputFilePath", "Select output directory"); + File outputFilePath = new File(outputDirectory, "patchestry.json"); + serializeToFile(outputFilePath.toPath(), getAllFunctions()); + } + + // GUI mode execution + private void runGUI() throws Exception { + String mode = askString("mode", "Please enter mode:"); + println("Running in mode: " + mode); + switch (mode.toLowerCase()) { + case "single": + decompileSingleFunctionInGUI(); + break; + case "all": + decompileAllFunctionsInGUI(); + break; + default: + throw new IllegalArgumentException("Invalid mode: " + mode); + } + } + + // Script entry point + @Override + public void run() throws Exception { + try { + if (isRunningHeadless()) { + runHeadless(); + } else { + runGUI(); + } + } catch (Exception e) { + println("Error: " + e.getMessage()); + e.printStackTrace(new PrintWriter(new OutputStreamWriter(System.err))); + throw e; + } + } } diff --git a/scripts/ghidra/decompile-entrypoint.sh b/scripts/ghidra/decompile-entrypoint.sh index 039c57b..3ff30a0 100644 --- a/scripts/ghidra/decompile-entrypoint.sh +++ b/scripts/ghidra/decompile-entrypoint.sh @@ -1,7 +1,6 @@ #!/bin/bash -x # # Copyright (c) 2024, Trail of Bits, Inc. -# All rights reserved. # # This source code is licensed in accordance with the terms specified in # the LICENSE file found in the root directory of this source tree. diff --git a/scripts/ghidra/decompile-headless.dockerfile b/scripts/ghidra/decompile-headless.dockerfile index 8e06b6c..5d37cec 100644 --- a/scripts/ghidra/decompile-headless.dockerfile +++ b/scripts/ghidra/decompile-headless.dockerfile @@ -3,6 +3,8 @@ FROM eclipse-temurin:17 AS base FROM base AS build ENV GHIDRA_VERSION=11.1.2 +ENV GRADLE_VERSION=8.2 +ENV GRADLE_HOME=/opt/gradle ENV GHIDRA_RELEASE_TAG=20240709 ENV GHIDRA_PACKAGE=ghidra_${GHIDRA_VERSION}_PUBLIC_${GHIDRA_RELEASE_TAG} ENV GHIDRA_SHA256=219ec130b901645779948feeb7cc86f131dd2da6c36284cf538c3a7f3d44b588 @@ -12,6 +14,8 @@ RUN apt-get update && apt-get install -y \ wget \ ca-certificates \ unzip \ + gcc \ + g++ \ --no-install-recommends && \ apt-get clean && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives @@ -24,6 +28,18 @@ RUN unzip /tmp/ghidra.zip -d /tmp && \ chmod +x /ghidra/ghidraRun && \ rm -rf /var/tmp/* /tmp/* /ghidra/docs /ghidra/Extensions/Eclipse /ghidra/licenses +# Download and install Gradle +RUN wget https://services.gradle.org/distributions/gradle-${GRADLE_VERSION}-bin.zip -P /tmp \ + && unzip /tmp/gradle-${GRADLE_VERSION}-bin.zip -d /opt/ \ + && ln -s /opt/gradle-${GRADLE_VERSION} ${GRADLE_HOME} \ + && rm /tmp/gradle-${GRADLE_VERSION}-bin.zip + +# Set the PATH for Gradle +ENV PATH="${GRADLE_HOME}/bin:${PATH}" + +RUN chmod +x /ghidra/support/buildNatives && \ + /ghidra/support/buildNatives + RUN apt-get purge -y --auto-remove wget ca-certificates unzip && \ apt-get clean diff --git a/scripts/ghidra/decompile-headless.sh b/scripts/ghidra/decompile-headless.sh index 44ac13e..72e0084 100755 --- a/scripts/ghidra/decompile-headless.sh +++ b/scripts/ghidra/decompile-headless.sh @@ -1,7 +1,6 @@ #!/bin/bash # # Copyright (c) 2024, Trail of Bits, Inc. -# All rights reserved. # # This source code is licensed in accordance with the terms specified in # the LICENSE file found in the root directory of this source tree. diff --git a/scripts/render_json.py b/scripts/render_json.py new file mode 100644 index 0000000..5c7c561 --- /dev/null +++ b/scripts/render_json.py @@ -0,0 +1,235 @@ +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. +import collections +import json +import sys +from typing import Dict, List, Optional + +NEXT_ID = 0 + + +def next_id() -> int: + global NEXT_ID + next_val = NEXT_ID + NEXT_ID += 1 + return next_val + + +ID = collections.defaultdict(next_id) +EMPTY = {} + + +def should_render_line(op: Dict) -> bool: + if "output" in op: + return True + + +def should_render(data: Dict) -> bool: + if "output" in data: + return True + mnemonic = data['mnemonic'] + if mnemonic in ("DECLARE_PARAMETER", "ADDRESS_OF"): + return False + if mnemonic.startswith("DECLARE_"): + return True + if "target" in data: + if "is_noreturn" in data["target"]: + if data["target"]["is_noreturn"]: + return True + if "has_return_value" in data: + return not data["has_return_value"] + return mnemonic in ("BRANCH", "CBRANCH", "BRANCHIND", "RETURN", "STORE") + + +def render_typed_var(var: str, type_key: str, types: Dict[str, Dict]) -> str: + data = types[type_key] + match data["kind"]: + case "pointer": + return render_typed_var("", data["element_type"], types) + " *" + var + case "typedef": + return data["name"] + var + case "void": + return "void" + var + case "integer": + return data["name"] + var + case "float": + return data["name"] + var + case "boolean": + return data["name"] + var + case "enum": + return "enum " + data["name"] + var + case "struct": + return "struct " + data["name"] + var + case "union": + return "union " + data["name"] + var + case "array": + return render_typed_var(var, data["element_type"], types) + "[" + str(data["num_elements"]) + "]" + case _: + return var + + +def render_output(data: Dict[str, str], operations: Dict[str, Dict], vars: Dict[str, Dict]): + assert "kind" in data + if data["kind"] == "global": + print(var_name(vars, data), end=" = ") + else: + assert "operation" in data + data = operations[data["operation"]] + assert data["mnemonic"].startswith("DECLARE_") + print(var_name(vars, data), end=" = ") + + +def render_input(data: Dict, operations: Dict[str, Dict], functions: Dict[str, Dict], vars: Dict[str, Dict], types: Dict[str, Dict]): + match data["kind"]: + case "constant": + print(data["value"], end='') + case "temporary": + source = operations[data["operation"]] + if "name" in source: + print(var_name(vars, source), end='') + else: + print('(', end='') + render_op(source, operations, functions, vars, types, True) + print(')', end='') + case "local": + print(var_name(vars, operations[data["operation"]]), end='') + case "parameter": + print(var_name(vars, operations[data["operation"]]), end='') + case "global": + print(var_name(vars, data), end='') + case "function": + print(functions[data["function"]]["name"], end='') + case _: + print("?INPUT?", end='') + + +def var_name(vars: Dict[str, Dict], data: Dict) -> str: + if data["kind"] == "global": + return vars[data["global"]]["name"] + + name = data["name"] + if data["kind"] == "temporary": + name += "_" + str(ID[data["address"]]) + return name + + +def render_op(data: Dict, operations: Dict[str, Dict], functions: Dict[str, Dict], vars: Dict[str, Dict], types: Dict[str, Dict], inline=False): + if not should_render(data): + if not inline: + return + + mnemonic = data['mnemonic'] + if mnemonic.startswith("DECLARE_"): + if inline: + print(data["name"], end='') + else: + decl = render_typed_var(" " + var_name(vars, data), data["type"], types) + print(f"{decl}
", end='') + return + + is_call = mnemonic in ("CALL", "CALLIND") + if "output" in data: + render_output(data["output"], operations, vars) + + if is_call: + render_input(data["target"], operations, functions, vars, types) + print("(") + else: + print(mnemonic, end=" ") + + sep = "" + + if "inputs" in data: + for input in data["inputs"]: + print("", end=sep) + render_input(input, operations, functions, vars, types) + sep = ", " + + if "condition" in data: # For a CBRANCH + render_input(data["condition"], operations, functions, vars, types) + + if is_call: + print(")") + + if not inline: + print("
", end='') + + +def render_block(key: int, data: Dict, operations: Dict[str, Dict], functions: Dict[str, Dict], vars: Dict[str, Dict], types: Dict[str, Dict]): + print(f"b{key} [label=<
", end='') + last_op: Optional[Dict] = None + for op_key in data["ordered_operations"]: + last_op = operations[op_key] + render_op(last_op, operations, functions, vars, types) + + print(f"
>];") + + if last_op is None: + return + + if "target_block" in last_op: + target_key = ID[last_op["target_block"]] + print(f"b{key} -> b{target_key};") + + if "not_taken_block" in last_op: + target_key = ID[last_op["not_taken_block"]] + print(f"b{key} -> b{target_key} [color=\"red\"];") + + if "taken_block" in last_op: + target_key = ID[last_op["taken_block"]] + print(f"b{key} -> b{target_key} [color=\"green\"];") + + +def render_function(func_key: str, functions: Dict[str, Dict], vars: Dict[str, Dict], types: Dict[str, Dict]): + data = functions[func_key] + key = ID[func_key] + basic_blocks: Dict = data.get("basic_blocks", EMPTY) + + # Local the entry block. + entry_block_name: str = data.get("entry_block", "") + entry_block: Dict = basic_blocks.get(entry_block_name, EMPTY) + + # Merge all block operations into one dictionary for convenient lookup. + operations: Dict[str, Dict] = {} + for block in basic_blocks.values(): + for op_key, op_data in block["operations"].items(): + operations[op_key] = op_data + + # Extract parameter names from the entry block. + param_types: List[str] = data["type"]["parameter_types"] + param_names: List[str] = [""] * len(param_types) + + for op in entry_block.get("operations", EMPTY).values(): + if op["mnemonic"] == "DECLARE_PARAMETER": + param_names[op["index"]] = op["name"] + + func_name: str = data["name"] + + param_str: str = ", ".join(render_typed_var(n, t, types) for n, t in zip(param_names, param_types)) + print(f"f{key} [label=<
{func_name}({param_str})
{func_key}
>];") + + if not entry_block_name: + return + + entry_block_key = ID[entry_block_name] + print(f"f{key} -> b{entry_block_key};") + + for block_key, block in basic_blocks.items(): + render_block(ID[block_key], block, operations, functions, vars, types) + + +def render_functions(data: Dict): + print("digraph {") + print("node [shape=none fontname=Courier];") + functions: Dict[str, Dict] = data["functions"] + vars: Dict[str, Dict] = data["globals"] + for func_key in functions.keys(): + render_function(func_key, functions, vars, data["types"]) + print("}") + + +if __name__ == "__main__": + with open(sys.argv[1], "r") as json_file: + render_functions(json.loads(json_file.read())) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 7ffa41c..4d41b17 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,6 +1,7 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. cmake_minimum_required(VERSION 3.25) diff --git a/test/pcode-translate/function.json b/test/pcode-translate/function.json index a763a09..bb7369c 100644 --- a/test/pcode-translate/function.json +++ b/test/pcode-translate/function.json @@ -1,8 +1,10 @@ // RUN: bash %strip-json-comments %s | %pcode-translate --deserialize-pcode | %file-check %s { - // CHECK: pc.func @function - "name": "function", - "basic_blocks": [ + "functions": [ + { + // CHECK: pc.func @function + "name": "function", + "basic_blocks": [ { // CHECK: pc.block @fisrt_block "label": "fisrt_block", @@ -13,5 +15,7 @@ "label": "second_block", "instructions": [] } + ] + } ] } diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 2b37ef1..02b8237 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -1,5 +1,7 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. add_subdirectory(pcode-translate) +add_subdirectory(pcode-lifter) diff --git a/tools/pcode-lifter/CMakeLists.txt b/tools/pcode-lifter/CMakeLists.txt new file mode 100644 index 0000000..3e9ec84 --- /dev/null +++ b/tools/pcode-lifter/CMakeLists.txt @@ -0,0 +1,24 @@ +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in +# the LICENSE file found in the root directory of this source tree. + +set(LLVM_LINK_COMPONENTS + Support + clangFrontend +) + + +add_executable(pcode-lifter + main.cpp +) + +llvm_update_compile_flags(pcode-lifter) +target_link_libraries(pcode-lifter + PRIVATE + patchestry::ghidra + patchestry::ast + clangFrontend +) + +mlir_check_link_libraries(pcode-lifter) \ No newline at end of file diff --git a/tools/pcode-lifter/main.cpp b/tools/pcode-lifter/main.cpp new file mode 100644 index 0000000..d09fb3f --- /dev/null +++ b/tools/pcode-lifter/main.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2024, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include +#include + +#include "clang/Basic/DiagnosticOptions.h" +#include "clang/Basic/FileManager.h" +#include "clang/Basic/TargetInfo.h" +#include "clang/Basic/TargetOptions.h" +#include "clang/Frontend/CompilerInstance.h" +#include "clang/Frontend/CompilerInvocation.h" +#include "clang/Frontend/FrontendOptions.h" +#include "clang/Lex/PreprocessorOptions.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/VirtualFileSystem.h" +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +const llvm::cl::opt< std::string > input_filename( + llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::Required +); + +const llvm::cl::opt< bool > + verbose("v", llvm::cl::desc("Enable debug logs"), llvm::cl::init(false)); + +const llvm::cl::opt< bool > pprint( + "pretty-print", llvm::cl::desc("Pretty print translation unit"), llvm::cl::init(false) +); + +const llvm::cl::opt< std::string > output_filename( + "output", llvm::cl::desc("Specify output filename"), llvm::cl::value_desc("filename"), + llvm::cl::init("/tmp/output.c") +); + +int main(int argc, char **argv) { + llvm::cl::ParseCommandLineOptions( + argc, argv, "pcode-lifter to lift high pcode into clang ast\n" + ); + + llvm::ErrorOr< std::unique_ptr< llvm::MemoryBuffer > > file_or_err = + llvm::MemoryBuffer::getFile(input_filename); + + if (std::error_code error_code = file_or_err.getError()) { + llvm::errs() << "Error reading json file : " << error_code.message() << "\n"; + return EXIT_FAILURE; + } + + std::unique_ptr< llvm::MemoryBuffer > buffer = std::move(file_or_err.get()); + auto json = llvm::json::parse(buffer->getBuffer()); + if (!json) { + llvm::errs() << "Failed to parse pcode JSON: " << json.takeError(); + return EXIT_FAILURE; + } + + auto program = patchestry::ghidra::JsonParser().deserialize_program(*json->getAsObject()); + if (!program.has_value()) { + llvm::errs() << "Failed to process json object" << json.takeError(); + return EXIT_FAILURE; + } + + clang::CompilerInstance ci; + ci.createDiagnostics(); + if (!ci.hasDiagnostics()) { + llvm::errs() << "Failed to initialize diagnostics.\n"; + return EXIT_FAILURE; + } + + clang::CompilerInvocation &invocation = ci.getInvocation(); + clang::TargetOptions &inv_target_opts = invocation.getTargetOpts(); + inv_target_opts.Triple = llvm::sys::getDefaultTargetTriple(); + + std::shared_ptr< clang::TargetOptions > target_options = + std::make_shared< clang::TargetOptions >(); + target_options->Triple = llvm::sys::getDefaultTargetTriple(); + ci.setTarget(clang::TargetInfo::CreateTargetInfo(ci.getDiagnostics(), target_options)); + + ci.getFrontendOpts().ProgramAction = clang::frontend::ParseSyntaxOnly; + ci.getLangOpts().C99 = true; + // Setup file manager and source manager + ci.createFileManager(); + ci.createSourceManager(ci.getFileManager()); + + auto &sm = ci.getSourceManager(); + std::string file_data = "/patchestry"; + llvm::ErrorOr< clang::FileEntryRef > file_entry_ref_or_err = + ci.getFileManager().getVirtualFileRef("/tmp/patchestry", file_data.size(), 0); + clang::FileID file_id = sm.createFileID( + *file_entry_ref_or_err, clang::SourceLocation(), clang::SrcMgr::C_User, 0 + ); + + sm.setMainFileID(file_id); + + // Create the preprocessor and AST context + ci.createPreprocessor(clang::TU_Complete); + ci.createASTContext(); + + auto &ast_context = ci.getASTContext(); + + std::string outfile = output_filename.getValue(); + std::unique_ptr< patchestry::ast::PcodeASTConsumer > consumer = + std::make_unique< patchestry::ast::PcodeASTConsumer >(ci, program.value(), outfile); + ci.setASTConsumer(std::move(consumer)); + ci.createSema(clang::TU_Complete, nullptr); + + auto &ast_consumer = ci.getASTConsumer(); + ast_consumer.HandleTranslationUnit(ast_context); + + return EXIT_SUCCESS; +} diff --git a/tools/pcode-translate/CMakeLists.txt b/tools/pcode-translate/CMakeLists.txt index 833498c..3d9b778 100644 --- a/tools/pcode-translate/CMakeLists.txt +++ b/tools/pcode-translate/CMakeLists.txt @@ -1,6 +1,7 @@ -# Copyright (c) 2024, Trail of Bits, Inc. All rights reserved. This source code -# is licensed in accordance with the terms specified in the LICENSE file found -# in the root directory of this source tree. +# Copyright (c) 2024, Trail of Bits, Inc. +# +# This source code is licensed in accordance with the terms specified in the +# LICENSE file found in the root directory of this source tree. set(LLVM_LINK_COMPONENTS Support diff --git a/tools/pcode-translate/main.cpp b/tools/pcode-translate/main.cpp index ce613e1..bc711f6 100644 --- a/tools/pcode-translate/main.cpp +++ b/tools/pcode-translate/main.cpp @@ -1,6 +1,5 @@ /* * Copyright (c) 2024, Trail of Bits, Inc. - * All rights reserved. * * This source code is licensed in accordance with the terms specified in * the LICENSE file found in the root directory of this source tree.