From b046e123dc69821f2c375161e0adef3c6d9c9db4 Mon Sep 17 00:00:00 2001 From: cursey Date: Tue, 14 May 2024 22:57:47 -0700 Subject: [PATCH] Test/Thread trapping (#63) Co-authored-by: bottiger1 <55270538+bottiger1@users.noreply.github.com> --- CMakeLists.txt | 2 +- amalgamate.py | 39 +++++- cmake.toml | 2 +- include/safetyhook/easy.hpp | 17 ++- include/safetyhook/inline_hook.hpp | 41 +++++- include/safetyhook/mid_hook.hpp | 35 ++++-- include/safetyhook/os.hpp | 12 +- src/allocator.cpp | 2 + src/easy.cpp | 8 +- src/inline_hook.cpp | 119 +++++++++--------- src/mid_hook.cpp | 30 ++++- src/os.linux.cpp | 8 +- src/os.windows.cpp | 194 +++++++++++++++++------------ src/vmt_hook.cpp | 40 +++--- test/inline_hook.cpp | 47 +++++++ test/mid_hook.cpp | 61 +++++++++ 16 files changed, 454 insertions(+), 203 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4807db5..df17dce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -190,7 +190,7 @@ if(SAFETYHOOK_AMALGAMATE) # amalgamate add_custom_command( OUTPUT ${AMALGAMATED_FILE} ${AMALGAMATED_HEADER} DEPENDS ${HEADER_FILES} ${SOURCE_FILES} ${AMALGAMATE_SCRIPT} - COMMAND ${Python3_EXECUTABLE} ${AMALGAMATE_SCRIPT} ${AMALGAMATED_FILE} ${AMALGAMATED_HEADER} + COMMAND ${Python3_EXECUTABLE} ${AMALGAMATE_SCRIPT} MAIN_DEPENDENCY ${AMALGAMATE_SCRIPT} COMMENT "Amalgamating" ) diff --git a/amalgamate.py b/amalgamate.py index 2f2b51c..2a705cf 100644 --- a/amalgamate.py +++ b/amalgamate.py @@ -4,9 +4,12 @@ from typing import List, Set from glob import glob from shutil import rmtree +from textwrap import dedent import os import re +import sys +import argparse SAFETYHOOK_ROOT = Path(__file__).resolve().parent PUBLIC_INCLUDE_PATHS = [ @@ -18,6 +21,10 @@ OUTPUT_DIR = SAFETYHOOK_ROOT / 'amalgamated-dist' FILE_HEADER = ['// DO NOT EDIT. This file is auto-generated by `amalgamate.py`.', ''] +parser = argparse.ArgumentParser(description='bundles cpp and hpp files together') +parser.add_argument('--polyfill', action='store_true', + help='replace std::except with a polyfill so it can be compiled on C++20 or older. https://raw.githubusercontent.com/TartanLlama/expected/master/include/tl/expected.hpp') + # Python versions before 3.10 don't have the root_dir argument for glob, so we # crudely emulate it here. @@ -155,7 +162,25 @@ def merge_sources(*, source_dir: Path, covered_headers: Set[Path]): return output +def do_polyfill(content): + return content.replace('#include ', + dedent(''' + #if __has_include("tl/expected.hpp") + #include "tl/expected.hpp" + #elif __has_include("expected.hpp") + #include "expected.hpp" + #else + #error "No polyfill found" + #endif + ''')) \ + .replace('std::expected', 'tl::expected') \ + .replace('std::unexpected', 'tl::unexpected') + + def main(): + args = parser.parse_args() + polyfill = args.polyfill is True + if OUTPUT_DIR.exists(): print('Output directory exists. Deleting.') rmtree(OUTPUT_DIR) @@ -164,20 +189,26 @@ def main(): covered_headers = set() with open(OUTPUT_DIR / 'safetyhook.hpp', 'w') as f: - f.write('\n'.join(FILE_HEADER + merge_headers( + content = '\n'.join(FILE_HEADER + merge_headers( header='safetyhook.hpp', search_paths=PUBLIC_INCLUDE_PATHS, covered_headers=covered_headers, stack=[], - ))) + )) + if polyfill: + content = do_polyfill(content) + f.write(content) print(covered_headers) with open(OUTPUT_DIR / 'safetyhook.cpp', 'w') as f: - f.write('\n'.join(FILE_HEADER + merge_sources( + content = '\n'.join(FILE_HEADER + merge_sources( source_dir=SAFETYHOOK_ROOT / 'src', covered_headers=covered_headers, - ))) + )) + if polyfill: + content = do_polyfill(content) + f.write(content) if __name__ == '__main__': diff --git a/cmake.toml b/cmake.toml index 70abb24..df17b9f 100644 --- a/cmake.toml +++ b/cmake.toml @@ -97,7 +97,7 @@ set(AMALGAMATE_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/amalgamate.py) add_custom_command( OUTPUT ${AMALGAMATED_FILE} ${AMALGAMATED_HEADER} DEPENDS ${HEADER_FILES} ${SOURCE_FILES} ${AMALGAMATE_SCRIPT} - COMMAND ${Python3_EXECUTABLE} ${AMALGAMATE_SCRIPT} ${AMALGAMATED_FILE} ${AMALGAMATED_HEADER} + COMMAND ${Python3_EXECUTABLE} ${AMALGAMATE_SCRIPT} MAIN_DEPENDENCY ${AMALGAMATE_SCRIPT} COMMENT "Amalgamating" ) diff --git a/include/safetyhook/easy.hpp b/include/safetyhook/easy.hpp index bb4bb28..080a2a8 100644 --- a/include/safetyhook/easy.hpp +++ b/include/safetyhook/easy.hpp @@ -12,29 +12,34 @@ namespace safetyhook { /// @brief Easy to use API for creating an InlineHook. /// @param target The address of the function to hook. /// @param destination The address of the destination function. +/// @param flags The flags to use. /// @return The InlineHook object. -[[nodiscard]] InlineHook create_inline(void* target, void* destination); +[[nodiscard]] InlineHook create_inline(void* target, void* destination, InlineHook::Flags flags = InlineHook::Default); /// @brief Easy to use API for creating an InlineHook. /// @param target The address of the function to hook. /// @param destination The address of the destination function. +/// @param flags The flags to use. /// @return The InlineHook object. -[[nodiscard]] InlineHook create_inline(FnPtr auto target, FnPtr auto destination) { - return create_inline(reinterpret_cast(target), reinterpret_cast(destination)); +[[nodiscard]] InlineHook create_inline( + FnPtr auto target, FnPtr auto destination, InlineHook::Flags flags = InlineHook::Default) { + return create_inline(reinterpret_cast(target), reinterpret_cast(destination), flags); } /// @brief Easy to use API for creating a MidHook. /// @param target the address of the function to hook. /// @param destination The destination function. +/// @param flags The flags to use. /// @return The MidHook object. -[[nodiscard]] MidHook create_mid(void* target, MidHookFn destination); +[[nodiscard]] MidHook create_mid(void* target, MidHookFn destination, MidHook::Flags = MidHook::Default); /// @brief Easy to use API for creating a MidHook. /// @param target the address of the function to hook. /// @param destination The destination function. +/// @param flags The flags to use. /// @return The MidHook object. -[[nodiscard]] MidHook create_mid(FnPtr auto target, MidHookFn destination) { - return create_mid(reinterpret_cast(target), destination); +[[nodiscard]] MidHook create_mid(FnPtr auto target, MidHookFn destination, MidHook::Flags flags = MidHook::Default) { + return create_mid(reinterpret_cast(target), destination, flags); } /// @brief Easy to use API for creating a VmtHook. diff --git a/include/safetyhook/inline_hook.hpp b/include/safetyhook/inline_hook.hpp index 0ddac53..cb7e2be 100644 --- a/include/safetyhook/inline_hook.hpp +++ b/include/safetyhook/inline_hook.hpp @@ -87,42 +87,54 @@ class InlineHook final { [[nodiscard]] static Error not_enough_space(uint8_t* ip) { return {.type = NOT_ENOUGH_SPACE, .ip = ip}; } }; + /// @brief Flags for InlineHook. + enum Flags : int { + Default = 0, ///< Default flags. + StartDisabled = 1 << 0, ///< Start the hook disabled. + }; + /// @brief Create an inline hook. /// @param target The address of the function to hook. /// @param destination The destination address. + /// @param flags The flags to use. /// @return The InlineHook or an InlineHook::Error if an error occurred. /// @note This will use the default global Allocator. /// @note If you don't care about error handling, use the easy API (safetyhook::create_inline). - [[nodiscard]] static std::expected create(void* target, void* destination); + [[nodiscard]] static std::expected create( + void* target, void* destination, Flags flags = Default); /// @brief Create an inline hook. /// @param target The address of the function to hook. /// @param destination The destination address. + /// @param flags The flags to use. /// @return The InlineHook or an InlineHook::Error if an error occurred. /// @note This will use the default global Allocator. /// @note If you don't care about error handling, use the easy API (safetyhook::create_inline). - [[nodiscard]] static std::expected create(FnPtr auto target, FnPtr auto destination) { - return create(reinterpret_cast(target), reinterpret_cast(destination)); + [[nodiscard]] static std::expected create( + FnPtr auto target, FnPtr auto destination, Flags flags = Default) { + return create(reinterpret_cast(target), reinterpret_cast(destination), flags); } /// @brief Create an inline hook with a given Allocator. /// @param allocator The allocator to use. /// @param target The address of the function to hook. /// @param destination The destination address. + /// @param flags The flags to use. /// @return The InlineHook or an InlineHook::Error if an error occurred. /// @note If you don't care about error handling, use the easy API (safetyhook::create_inline). [[nodiscard]] static std::expected create( - const std::shared_ptr& allocator, void* target, void* destination); + const std::shared_ptr& allocator, void* target, void* destination, Flags flags = Default); /// @brief Create an inline hook with a given Allocator. /// @param allocator The allocator to use. /// @param target The address of the function to hook. /// @param destination The destination address. + /// @param flags The flags to use. /// @return The InlineHook or an InlineHook::Error if an error occurred. /// @note If you don't care about error handling, use the easy API (safetyhook::create_inline). [[nodiscard]] static std::expected create( - const std::shared_ptr& allocator, FnPtr auto target, FnPtr auto destination) { - return create(allocator, reinterpret_cast(target), reinterpret_cast(destination)); + const std::shared_ptr& allocator, FnPtr auto target, FnPtr auto destination, Flags flags = Default) { + return create(allocator, reinterpret_cast(target), reinterpret_cast(destination), flags); } InlineHook() = default; @@ -285,15 +297,32 @@ class InlineHook final { return original()(args...); } + /// @brief Enable the hook. + [[nodiscard]] std::expected enable(); + + /// @brief Disable the hook. + [[nodiscard]] std::expected disable(); + + /// @brief Check if the hook is enabled. + [[nodiscard]] bool enabled() const { return m_enabled; } + private: friend class MidHook; + enum class Type { + Unset, + E9, + FF, + }; + uint8_t* m_target{}; uint8_t* m_destination{}; Allocation m_trampoline{}; std::vector m_original_bytes{}; uintptr_t m_trampoline_size{}; std::recursive_mutex m_mutex{}; + bool m_enabled{}; + Type m_type{Type::Unset}; std::expected setup( const std::shared_ptr& allocator, uint8_t* target, uint8_t* destination); diff --git a/include/safetyhook/mid_hook.hpp b/include/safetyhook/mid_hook.hpp index a5ad4f8..7b2fb79 100644 --- a/include/safetyhook/mid_hook.hpp +++ b/include/safetyhook/mid_hook.hpp @@ -52,43 +52,55 @@ class MidHook final { } }; + /// @brief Flags for MidHook. + enum Flags : int { + Default = 0, ///< Default flags. + StartDisabled = 1, ///< Start the hook disabled. + }; + /// @brief Creates a new MidHook object. /// @param target The address of the function to hook. /// @param destination_fn The destination function. + /// @param flags The flags to use. /// @return The MidHook object or a MidHook::Error if an error occurred. /// @note This will use the default global Allocator. /// @note If you don't care about error handling, use the easy API (safetyhook::create_mid). - [[nodiscard]] static std::expected create(void* target, MidHookFn destination_fn); + [[nodiscard]] static std::expected create( + void* target, MidHookFn destination_fn, Flags flags = Default); /// @brief Creates a new MidHook object. /// @param target The address of the function to hook. /// @param destination_fn The destination function. + /// @param flags The flags to use. /// @return The MidHook object or a MidHook::Error if an error occurred. /// @note This will use the default global Allocator. /// @note If you don't care about error handling, use the easy API (safetyhook::create_mid). - [[nodiscard]] static std::expected create(FnPtr auto target, MidHookFn destination_fn) { - return create(reinterpret_cast(target), destination_fn); + [[nodiscard]] static std::expected create( + FnPtr auto target, MidHookFn destination_fn, Flags flags = Default) { + return create(reinterpret_cast(target), destination_fn, flags); } /// @brief Creates a new MidHook object with a given Allocator. /// @param allocator The Allocator to use. /// @param target The address of the function to hook. /// @param destination_fn The destination function. + /// @param flags The flags to use. /// @return The MidHook object or a MidHook::Error if an error occurred. /// @note If you don't care about error handling, use the easy API (safetyhook::create_mid). [[nodiscard]] static std::expected create( - const std::shared_ptr& allocator, void* target, MidHookFn destination_fn); + const std::shared_ptr& allocator, void* target, MidHookFn destination_fn, Flags flags = Default); /// @brief Creates a new MidHook object with a given Allocator. /// @tparam T The type of the function to hook. /// @param allocator The Allocator to use. /// @param target The address of the function to hook. /// @param destination_fn The destination function. + /// @param flags The flags to use. /// @return The MidHook object or a MidHook::Error if an error occurred. /// @note If you don't care about error handling, use the easy API (safetyhook::create_mid). - [[nodiscard]] static std::expected create( - const std::shared_ptr& allocator, FnPtr auto target, MidHookFn destination_fn) { - return create(allocator, reinterpret_cast(target), destination_fn); + [[nodiscard]] static std::expected create(const std::shared_ptr& allocator, + FnPtr auto target, MidHookFn destination_fn, Flags flags = Default) { + return create(allocator, reinterpret_cast(target), destination_fn, flags); } MidHook() = default; @@ -123,6 +135,15 @@ class MidHook final { /// @return true if the hook is valid, false otherwise. explicit operator bool() const { return static_cast(m_stub); } + /// @brief Enable the hook. + [[nodiscard]] std::expected enable(); + + /// @brief Disable the hook. + [[nodiscard]] std::expected disable(); + + /// @brief Check if the hook is enabled. + [[nodiscard]] bool enabled() const { return m_hook.enabled(); } + private: InlineHook m_hook{}; uint8_t* m_target{}; diff --git a/include/safetyhook/os.hpp b/include/safetyhook/os.hpp index 5906df9..abbfc5c 100644 --- a/include/safetyhook/os.hpp +++ b/include/safetyhook/os.hpp @@ -63,19 +63,9 @@ struct SystemInfo { SystemInfo system_info(); -using ThreadId = uint32_t; -using ThreadHandle = void*; using ThreadContext = void*; -/// @brief Executes a function while all other threads are frozen. Also allows for visiting each frozen thread and -/// modifying it's context. -/// @param run_fn The function to run while all other threads are frozen. -/// @param visit_fn The function that will be called for each frozen thread. -/// @note The visit function will be called in the order that the threads were frozen. -/// @note The visit function will be called before the run function. -/// @note Keep the logic inside run_fn and visit_fn as simple as possible to avoid deadlocks. -void execute_while_frozen(const std::function& run_fn, - const std::function& visit_fn = {}); +void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function& run_fn); /// @brief Will modify the context of a thread's IP to point to a new address if its IP is at the old address. /// @param ctx The thread context to modify. diff --git a/src/allocator.cpp b/src/allocator.cpp index eaa3d84..4a46f1c 100644 --- a/src/allocator.cpp +++ b/src/allocator.cpp @@ -5,6 +5,8 @@ #include "safetyhook/os.hpp" #include "safetyhook/utility.hpp" +#include "safetyhook/utility.hpp" + #include "safetyhook/allocator.hpp" namespace safetyhook { diff --git a/src/easy.cpp b/src/easy.cpp index 65efb06..1f92b7c 100644 --- a/src/easy.cpp +++ b/src/easy.cpp @@ -1,16 +1,16 @@ #include "safetyhook/easy.hpp" namespace safetyhook { -InlineHook create_inline(void* target, void* destination) { - if (auto hook = InlineHook::create(target, destination)) { +InlineHook create_inline(void* target, void* destination, InlineHook::Flags flags) { + if (auto hook = InlineHook::create(target, destination, flags)) { return std::move(*hook); } else { return {}; } } -MidHook create_mid(void* target, MidHookFn destination) { - if (auto hook = MidHook::create(target, destination)) { +MidHook create_mid(void* target, MidHookFn destination, MidHook::Flags flags) { + if (auto hook = MidHook::create(target, destination, flags)) { return std::move(*hook); } else { return {}; diff --git a/src/inline_hook.cpp b/src/inline_hook.cpp index 76657e8..dea6a81 100644 --- a/src/inline_hook.cpp +++ b/src/inline_hook.cpp @@ -64,12 +64,6 @@ static auto make_jmp_ff(uint8_t* src, uint8_t* dst, uint8_t* data) { return std::unexpected{InlineHook::Error::not_enough_space(dst)}; } - auto um = unprotect(src, size); - - if (!um) { - return std::unexpected{InlineHook::Error::failed_to_unprotect(src)}; - } - if (size > sizeof(JmpFF)) { std::fill_n(src, size, static_cast(0x90)); } @@ -94,12 +88,6 @@ constexpr auto make_jmp_e9(uint8_t* src, uint8_t* dst) { return std::unexpected{InlineHook::Error::not_enough_space(dst)}; } - auto um = unprotect(src, size); - - if (!um) { - return std::unexpected{InlineHook::Error::failed_to_unprotect(src)}; - } - if (size > sizeof(JmpE9)) { std::fill_n(src, size, static_cast(0x90)); } @@ -126,12 +114,12 @@ static bool decode(ZydisDecodedInstruction* ix, uint8_t* ip) { return ZYAN_SUCCESS(ZydisDecoderDecodeInstruction(&decoder, nullptr, ip, 15, ix)); } -std::expected InlineHook::create(void* target, void* destination) { - return create(Allocator::global(), target, destination); +std::expected InlineHook::create(void* target, void* destination, Flags flags) { + return create(Allocator::global(), target, destination, flags); } std::expected InlineHook::create( - const std::shared_ptr& allocator, void* target, void* destination) { + const std::shared_ptr& allocator, void* target, void* destination, Flags flags) { InlineHook hook{}; if (const auto setup_result = @@ -140,6 +128,12 @@ std::expected InlineHook::create( return std::unexpected{setup_result.error()}; } + if (!(flags & StartDisabled)) { + if (auto enable_result = hook.enable(); !enable_result) { + return std::unexpected{enable_result.error()}; + } + } + return hook; } @@ -158,10 +152,14 @@ InlineHook& InlineHook::operator=(InlineHook&& other) noexcept { m_trampoline = std::move(other.m_trampoline); m_trampoline_size = other.m_trampoline_size; m_original_bytes = std::move(other.m_original_bytes); + m_enabled = other.m_enabled; + m_type = other.m_type; other.m_target = nullptr; other.m_destination = nullptr; other.m_trampoline_size = 0; + other.m_enabled = false; + other.m_type = Type::Unset; } return *this; @@ -317,26 +315,7 @@ std::expected InlineHook::e9_hook(const std::shared_ptr } #endif - std::optional error; - - // jmp from original to trampoline. - execute_while_frozen( - [this, &trampoline_epilogue, &error] { - if (auto result = emit_jmp_e9(m_target, - reinterpret_cast(&trampoline_epilogue->jmp_to_destination), m_original_bytes.size()); - !result) { - error = result.error(); - } - }, - [this](auto, auto, auto ctx) { - for (size_t i = 0; i < m_original_bytes.size(); ++i) { - fix_ip(ctx, m_target + i, m_trampoline.data() + i); - } - }); - - if (error) { - return std::unexpected{*error}; - } + m_type = Type::E9; return {}; } @@ -385,49 +364,77 @@ std::expected InlineHook::ff_hook(const std::shared_ptr return std::unexpected{result.error()}; } + m_type = Type::FF; + + return {}; +} +#endif + +std::expected InlineHook::enable() { + std::scoped_lock lock{m_mutex}; + + if (m_enabled) { + return {}; + } + std::optional error; // jmp from original to trampoline. - execute_while_frozen( - [this, &error] { - if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size()); + trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &error] { + if (m_type == Type::E9) { + auto trampoline_epilogue = reinterpret_cast( + m_trampoline.address() + m_trampoline_size - sizeof(TrampolineEpilogueE9)); + + if (auto result = emit_jmp_e9(m_target, + reinterpret_cast(&trampoline_epilogue->jmp_to_destination), m_original_bytes.size()); !result) { error = result.error(); } - }, - [this](auto, auto, auto ctx) { - for (size_t i = 0; i < m_original_bytes.size(); ++i) { - fix_ip(ctx, m_target + i, m_trampoline.data() + i); + } + +#if SAFETYHOOK_ARCH_X86_64 + if (m_type == Type::FF) { + if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size()); + !result) { + error = result.error(); } - }); + } +#endif + }); if (error) { return std::unexpected{*error}; } + m_enabled = true; + + return {}; +} + +std::expected InlineHook::disable() { + std::scoped_lock lock{m_mutex}; + + if (!m_enabled) { + return {}; + } + + trap_threads(m_trampoline.data(), m_target, m_original_bytes.size(), + [this] { std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target); }); + + m_enabled = false; + return {}; } -#endif void InlineHook::destroy() { + [[maybe_unused]] auto disable_result = disable(); + std::scoped_lock lock{m_mutex}; if (!m_trampoline) { return; } - execute_while_frozen( - [this] { - if (auto um = unprotect(m_target, m_original_bytes.size())) { - std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target); - } - }, - [this](auto, auto, auto ctx) { - for (size_t i = 0; i < m_original_bytes.size(); ++i) { - fix_ip(ctx, m_trampoline.data() + i, m_target + i); - } - }); - m_trampoline.free(); } } // namespace safetyhook diff --git a/src/mid_hook.cpp b/src/mid_hook.cpp index 3284ff4..339bef8 100644 --- a/src/mid_hook.cpp +++ b/src/mid_hook.cpp @@ -68,12 +68,12 @@ constexpr std::array asm_data = {0xFF, 0x35, 0xA7, 0x00, 0x00, 0x0 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; #endif -std::expected MidHook::create(void* target, MidHookFn destination) { - return create(Allocator::global(), target, destination); +std::expected MidHook::create(void* target, MidHookFn destination, Flags flags) { + return create(Allocator::global(), target, destination, flags); } std::expected MidHook::create( - const std::shared_ptr& allocator, void* target, MidHookFn destination) { + const std::shared_ptr& allocator, void* target, MidHookFn destination, Flags flags) { MidHook hook{}; if (const auto setup_result = hook.setup(allocator, reinterpret_cast(target), destination); @@ -81,6 +81,12 @@ std::expected MidHook::create( return std::unexpected{setup_result.error()}; } + if (!(flags & StartDisabled)) { + if (auto enable_result = hook.enable(); !enable_result) { + return std::unexpected{enable_result.error()}; + } + } + return hook; } @@ -131,7 +137,7 @@ std::expected MidHook::setup( store(m_stub.data() + 0x59, m_stub.data() + m_stub.size() - 8); #endif - auto hook_result = InlineHook::create(allocator, m_target, m_stub.data()); + auto hook_result = InlineHook::create(allocator, m_target, m_stub.data(), InlineHook::StartDisabled); if (!hook_result) { m_stub.free(); @@ -148,4 +154,20 @@ std::expected MidHook::setup( return {}; } + +std::expected MidHook::enable() { + if (auto enable_result = m_hook.enable(); !enable_result) { + return std::unexpected{Error::bad_inline_hook(enable_result.error())}; + } + + return {}; +} + +std::expected MidHook::disable() { + if (auto disable_result = m_hook.disable(); !disable_result) { + return std::unexpected{Error::bad_inline_hook(disable_result.error())}; + } + + return {}; +} } // namespace safetyhook diff --git a/src/os.linux.cpp b/src/os.linux.cpp index 17ce09d..f6dcee1 100644 --- a/src/os.linux.cpp +++ b/src/os.linux.cpp @@ -185,9 +185,13 @@ SystemInfo system_info() { }; } -void execute_while_frozen(const std::function& run_fn, - [[maybe_unused]] const std::function& visit_fn) { +void trap_threads([[maybe_unused]] uint8_t* from, [[maybe_unused]] uint8_t* to, [[maybe_unused]] size_t len, + const std::function& run_fn) { + auto from_protect = vm_protect(from, len, VM_ACCESS_RWX).value_or(0); + auto to_protect = vm_protect(to, len, VM_ACCESS_RWX).value_or(0); run_fn(); + vm_protect(to, len, to_protect); + vm_protect(from, len, from_protect); } void fix_ip([[maybe_unused]] ThreadContext ctx, [[maybe_unused]] uint8_t* old_ip, [[maybe_unused]] uint8_t* new_ip) { diff --git a/src/os.windows.cpp b/src/os.windows.cpp index 8fd4f5b..a7a4975 100644 --- a/src/os.windows.cpp +++ b/src/os.windows.cpp @@ -1,4 +1,9 @@ +#include +#include +#include + #include "safetyhook/common.hpp" +#include "safetyhook/utility.hpp" #if SAFETYHOOK_OS_WINDOWS @@ -11,19 +16,8 @@ #error "Windows.h not found" #endif -#include - #include "safetyhook/os.hpp" -#pragma comment(lib, "ntdll") - -extern "C" { -NTSTATUS -NTAPI -NtGetNextThread(HANDLE ProcessHandle, HANDLE ThreadHandle, ACCESS_MASK DesiredAccess, ULONG HandleAttributes, - ULONG Flags, PHANDLE NewThreadHandle); -} - namespace safetyhook { std::expected vm_allocate(uint8_t* address, size_t size, VmAccess access) { DWORD protect = 0; @@ -158,104 +152,144 @@ SystemInfo system_info() { return info; } -void execute_while_frozen( - const std::function& run_fn, const std::function& visit_fn) { - // Freeze all threads. - int num_threads_frozen; - auto first_run = true; +struct TrapInfo { + uint8_t* from_page_start; + uint8_t* from_page_end; + uint8_t* from; + uint8_t* to_page_start; + uint8_t* to_page_end; + uint8_t* to; + size_t len; +}; + +class TrapManager final { +public: + static std::mutex mutex; + static std::unique_ptr instance; + + TrapManager() { m_trap_veh = AddVectoredExceptionHandler(1, trap_handler); } + ~TrapManager() { + if (m_trap_veh != nullptr) { + RemoveVectoredExceptionHandler(m_trap_veh); + } + } - do { - num_threads_frozen = 0; - HANDLE thread{}; + TrapInfo* find_trap(uint8_t* address) { + auto search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) { + return address >= trap.second.from && address < trap.second.from + trap.second.len; + }); - while (true) { - HANDLE next_thread{}; - const auto status = NtGetNextThread(GetCurrentProcess(), thread, - THREAD_QUERY_LIMITED_INFORMATION | THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_SET_CONTEXT, 0, - 0, &next_thread); + if (search == m_traps.end()) { + return nullptr; + } - if (thread != nullptr) { - CloseHandle(thread); - } + return &search->second; + } - if (!NT_SUCCESS(status)) { - break; - } + TrapInfo* find_trap_page(uint8_t* address) { + auto search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) { + return address >= trap.second.from_page_start && address < trap.second.from_page_end; + }); - thread = next_thread; + if (search != m_traps.end()) { + return &search->second; + } - const auto thread_id = GetThreadId(thread); + search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) { + return address >= trap.second.to_page_start && address < trap.second.to_page_end; + }); - if (thread_id == 0 || thread_id == GetCurrentThreadId()) { - continue; - } + if (search != m_traps.end()) { + return &search->second; + } - const auto suspend_count = SuspendThread(thread); + return nullptr; + } - if (suspend_count == static_cast(-1)) { - continue; - } + void add_trap(uint8_t* from, uint8_t* to, size_t len) { + m_traps.insert_or_assign(from, TrapInfo{.from_page_start = align_down(from, 0x1000), + .from_page_end = align_up(from + len, 0x1000), + .from = from, + .to_page_start = align_down(to, 0x1000), + .to_page_end = align_up(to + len, 0x1000), + .to = to, + .len = len}); + } - // Check if the thread was already frozen. Only resume if the thread was already frozen, and it wasn't the - // first run of this freeze loop to account for threads that may have already been frozen for other reasons. - if (suspend_count != 0 && !first_run) { - ResumeThread(thread); - continue; - } +private: + std::map m_traps; + PVOID m_trap_veh{}; - CONTEXT thread_ctx{}; + static LONG CALLBACK trap_handler(PEXCEPTION_POINTERS exp) { + auto exception_code = exp->ExceptionRecord->ExceptionCode; - thread_ctx.ContextFlags = CONTEXT_FULL; + if (exception_code != EXCEPTION_ACCESS_VIOLATION) { + return EXCEPTION_CONTINUE_SEARCH; + } - if (GetThreadContext(thread, &thread_ctx) == FALSE) { - continue; - } + std::scoped_lock lock{mutex}; + auto* faulting_address = reinterpret_cast(exp->ExceptionRecord->ExceptionInformation[1]); + auto* trap = instance->find_trap(faulting_address); - if (visit_fn) { - visit_fn(static_cast(thread_id), static_cast(thread), - static_cast(&thread_ctx)); + if (trap == nullptr) { + if (instance->find_trap_page(faulting_address) != nullptr) { + return EXCEPTION_CONTINUE_EXECUTION; + } else { + return EXCEPTION_CONTINUE_SEARCH; } + } - SetThreadContext(thread, &thread_ctx); + auto* ctx = exp->ContextRecord; - ++num_threads_frozen; + for (size_t i = 0; i < trap->len; i++) { + fix_ip(ctx, trap->from + i, trap->to + i); } - first_run = false; - } while (num_threads_frozen != 0); - - // Run the function. - if (run_fn) { - run_fn(); + return EXCEPTION_CONTINUE_EXECUTION; } +}; - // Resume all threads. - HANDLE thread{}; +std::mutex TrapManager::mutex; +std::unique_ptr TrapManager::instance; - while (true) { - HANDLE next_thread{}; - const auto status = NtGetNextThread(GetCurrentProcess(), thread, - THREAD_QUERY_LIMITED_INFORMATION | THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_SET_CONTEXT, 0, 0, - &next_thread); +void find_me() { +} - if (thread != nullptr) { - CloseHandle(thread); - } +void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function& run_fn) { + MEMORY_BASIC_INFORMATION find_me_mbi{}; + MEMORY_BASIC_INFORMATION from_mbi{}; + MEMORY_BASIC_INFORMATION to_mbi{}; - if (!NT_SUCCESS(status)) { - break; - } + VirtualQuery(reinterpret_cast(find_me), &find_me_mbi, sizeof(find_me_mbi)); + VirtualQuery(from, &from_mbi, sizeof(from_mbi)); + VirtualQuery(to, &to_mbi, sizeof(to_mbi)); + + auto new_protect = PAGE_READWRITE; - thread = next_thread; + if (from_mbi.AllocationBase == find_me_mbi.AllocationBase || to_mbi.AllocationBase == find_me_mbi.AllocationBase) { + new_protect = PAGE_EXECUTE_READWRITE; + } - const auto thread_id = GetThreadId(thread); + std::scoped_lock lock{TrapManager::mutex}; - if (thread_id == 0 || thread_id == GetCurrentThreadId()) { - continue; - } + if (TrapManager::instance == nullptr) { + TrapManager::instance = std::make_unique(); + } + + TrapManager::instance->add_trap(from, to, len); + + DWORD from_protect; + DWORD to_protect; - ResumeThread(thread); + VirtualProtect(from, len, new_protect, &from_protect); + VirtualProtect(to, len, new_protect, &to_protect); + + if (run_fn) { + run_fn(); } + + VirtualProtect(to, len, to_protect, &to_protect); + VirtualProtect(from, len, from_protect, &from_protect); } void fix_ip(ThreadContext thread_ctx, uint8_t* old_ip, uint8_t* new_ip) { diff --git a/src/vmt_hook.cpp b/src/vmt_hook.cpp index 7831703..5d7b01b 100644 --- a/src/vmt_hook.cpp +++ b/src/vmt_hook.cpp @@ -104,17 +104,17 @@ void VmtHook::remove(void* object) { const auto original_vmt = search->second; - execute_while_frozen([&] { - if (!vm_is_writable(reinterpret_cast(object), sizeof(void*))) { - return; - } + if (!vm_is_writable(reinterpret_cast(object), sizeof(void*))) { + m_objects.erase(search); + return; + } - if (*reinterpret_cast(object) != &m_new_vmt[1]) { - return; - } + if (*reinterpret_cast(object) != &m_new_vmt[1]) { + m_objects.erase(search); + return; + } - *reinterpret_cast(object) = original_vmt; - }); + *reinterpret_cast(object) = original_vmt; m_objects.erase(search); } @@ -124,22 +124,20 @@ void VmtHook::reset() { } void VmtHook::destroy() { - execute_while_frozen([this] { - for (const auto [object, original_vmt] : m_objects) { - if (!vm_is_writable(reinterpret_cast(object), sizeof(void*))) { - return; - } - - if (*reinterpret_cast(object) != &m_new_vmt[1]) { - return; - } + for (const auto [object, original_vmt] : m_objects) { + if (!vm_is_writable(reinterpret_cast(object), sizeof(void*))) { + continue; + } - *reinterpret_cast(object) = original_vmt; + if (*reinterpret_cast(object) != &m_new_vmt[1]) { + continue; } - }); + + *reinterpret_cast(object) = original_vmt; + } m_objects.clear(); m_new_vmt_allocation.reset(); m_new_vmt = nullptr; } -} // namespace safetyhook \ No newline at end of file +} // namespace safetyhook diff --git a/test/inline_hook.cpp b/test/inline_hook.cpp index 5a45af2..807c052 100644 --- a/test/inline_hook.cpp +++ b/test/inline_hook.cpp @@ -623,4 +623,51 @@ static suite<"inline hook"> inline_hook_tests = [] { expect(fn() == 42_i); }; + + "Function hook can be enable and disabled"_test = [] { + struct Target { + SAFETYHOOK_NOINLINE static int fn(int a) { + volatile int b = a; + return b * 2; + } + }; + + expect(Target::fn(1) == 2_i); + expect(Target::fn(2) == 4_i); + expect(Target::fn(3) == 6_i); + + static SafetyHookInline hook; + + struct Hook { + static int fn(int a) { return hook.call(a + 1); } + }; + + auto hook0_result = SafetyHookInline::create(Target::fn, Hook::fn, SafetyHookInline::StartDisabled); + + expect(hook0_result.has_value()); + + hook = std::move(*hook0_result); + + expect(Target::fn(1) == 2_i); + expect(Target::fn(2) == 4_i); + expect(Target::fn(3) == 6_i); + + expect(hook.enable().has_value()); + + expect(Target::fn(1) == 4_i); + expect(Target::fn(2) == 6_i); + expect(Target::fn(3) == 8_i); + + expect(hook.disable().has_value()); + + expect(Target::fn(1) == 2_i); + expect(Target::fn(2) == 4_i); + expect(Target::fn(3) == 6_i); + + hook.reset(); + + expect(Target::fn(1) == 2_i); + expect(Target::fn(2) == 4_i); + expect(Target::fn(3) == 6_i); + }; }; \ No newline at end of file diff --git a/test/mid_hook.cpp b/test/mid_hook.cpp index 8560f3f..5441d21 100644 --- a/test/mid_hook.cpp +++ b/test/mid_hook.cpp @@ -71,4 +71,65 @@ static suite<"mid hook"> mid_hook_tests = [] { expect(Target::add_42(2.0f) == 2.42_f); }; #endif + + "Mid hook enable and disable"_test = [] { + struct Target { + SAFETYHOOK_NOINLINE static int SAFETYHOOK_FASTCALL add_42(int a) { + volatile int b = a; + return b + 42; + } + }; + + expect(Target::add_42(0) == 42_i); + expect(Target::add_42(1) == 43_i); + expect(Target::add_42(2) == 44_i); + + static SafetyHookMid hook; + + struct Hook { + static void add_42(SafetyHookContext& ctx) { +#if SAFETYHOOK_OS_WINDOWS +#if SAFETYHOOK_ARCH_X86_64 + ctx.rcx = 1337 - 42; +#elif SAFETYHOOK_ARCH_X86_32 + ctx.ecx = 1337 - 42; +#endif +#elif SAFETYHOOK_OS_LINUX +#if SAFETYHOOK_ARCH_X86_64 + ctx.rdi = 1337 - 42; +#elif SAFETYHOOK_ARCH_X86_32 + ctx.edi = 1337 - 42; +#endif +#endif + } + }; + + auto hook_result = SafetyHookMid::create(Target::add_42, Hook::add_42, SafetyHookMid::StartDisabled); + + expect(hook_result.has_value()); + + hook = std::move(*hook_result); + + expect(Target::add_42(0) == 42_i); + expect(Target::add_42(1) == 43_i); + expect(Target::add_42(2) == 44_i); + + expect(hook.enable().has_value()); + + expect(Target::add_42(1) == 1337_i); + expect(Target::add_42(2) == 1337_i); + expect(Target::add_42(3) == 1337_i); + + expect(hook.disable().has_value()); + + expect(Target::add_42(0) == 42_i); + expect(Target::add_42(1) == 43_i); + expect(Target::add_42(2) == 44_i); + + hook.reset(); + + expect(Target::add_42(0) == 42_i); + expect(Target::add_42(1) == 43_i); + expect(Target::add_42(2) == 44_i); + }; }; \ No newline at end of file