From ca071f3a238712b89072628c979ce668248dd3b4 Mon Sep 17 00:00:00 2001 From: cursey Date: Sat, 10 Feb 2024 01:10:38 -0800 Subject: [PATCH] ThreadFreezer: Replace with POC thread trapping implementation --- include/safetyhook/thread_freezer.hpp | 10 +- include/safetyhook/utility.hpp | 13 ++ src/allocator.cpp | 14 +- src/inline_hook.cpp | 63 +++------ src/thread_freezer.cpp | 183 +++++++++++++++----------- src/vmt_hook.cpp | 38 +++--- 6 files changed, 152 insertions(+), 169 deletions(-) diff --git a/include/safetyhook/thread_freezer.hpp b/include/safetyhook/thread_freezer.hpp index 74ea23d..7493a82 100644 --- a/include/safetyhook/thread_freezer.hpp +++ b/include/safetyhook/thread_freezer.hpp @@ -11,15 +11,7 @@ using ThreadId = uint32_t; using ThreadHandle = void*; using ThreadContext = void*; -/// @brief Executes a function while all other threads are frozen. Also allows for visiting each frozen thread and -/// modifying it's context. -/// @param run_fn The function to run while all other threads are frozen. -/// @param visit_fn The function that will be called for each frozen thread. -/// @note The visit function will be called in the order that the threads were frozen. -/// @note The visit function will be called before the run function. -/// @note Keep the logic inside run_fn and visit_fn as simple as possible to avoid deadlocks. -void execute_while_frozen(const std::function& run_fn, - const std::function& visit_fn = {}); +void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function& run_fn); /// @brief Will modify the context of a thread's IP to point to a new address if its IP is at the old address. /// @param ctx The thread context to modify. diff --git a/include/safetyhook/utility.hpp b/include/safetyhook/utility.hpp index 933b338..cfabb0e 100644 --- a/include/safetyhook/utility.hpp +++ b/include/safetyhook/utility.hpp @@ -35,4 +35,17 @@ class UnprotectMemory { }; [[nodiscard]] std::optional unprotect(uint8_t* address, size_t size); + +template constexpr T align_up(T address, size_t align) { + const auto unaligned_address = (uintptr_t)address; + const auto aligned_address = (unaligned_address + align - 1) & ~(align - 1); + return (T)aligned_address; +} + +template constexpr T align_down(T address, size_t align) { + const auto unaligned_address = (uintptr_t)address; + const auto aligned_address = unaligned_address & ~(align - 1); + return (T)aligned_address; +} + } // namespace safetyhook diff --git a/src/allocator.cpp b/src/allocator.cpp index df2d6ce..98ecc21 100644 --- a/src/allocator.cpp +++ b/src/allocator.cpp @@ -11,21 +11,11 @@ #error "Windows.h not found" #endif +#include "safetyhook/utility.hpp" + #include "safetyhook/allocator.hpp" namespace safetyhook { -template constexpr T align_up(T address, size_t align) { - const auto unaligned_address = (uintptr_t)address; - const auto aligned_address = (unaligned_address + align - 1) & ~(align - 1); - return (T)aligned_address; -} - -template constexpr T align_down(T address, size_t align) { - const auto unaligned_address = (uintptr_t)address; - const auto aligned_address = unaligned_address & ~(align - 1); - return (T)aligned_address; -} - Allocation::Allocation(Allocation&& other) noexcept { *this = std::move(other); } diff --git a/src/inline_hook.cpp b/src/inline_hook.cpp index 35de774..c46382c 100644 --- a/src/inline_hook.cpp +++ b/src/inline_hook.cpp @@ -72,12 +72,6 @@ static auto make_jmp_ff(uint8_t* src, uint8_t* dst, uint8_t* data) { return std::unexpected{InlineHook::Error::not_enough_space(dst)}; } - auto um = unprotect(src, size); - - if (!um) { - return std::unexpected{InlineHook::Error::failed_to_unprotect(src)}; - } - if (size > sizeof(JmpFF)) { std::fill_n(src, size, static_cast(0x90)); } @@ -102,12 +96,6 @@ constexpr auto make_jmp_e9(uint8_t* src, uint8_t* dst) { return std::unexpected{InlineHook::Error::not_enough_space(dst)}; } - auto um = unprotect(src, size); - - if (!um) { - return std::unexpected{InlineHook::Error::failed_to_unprotect(src)}; - } - if (size > sizeof(JmpE9)) { std::fill_n(src, size, static_cast(0x90)); } @@ -328,19 +316,13 @@ std::expected InlineHook::e9_hook(const std::shared_ptr std::optional error; // jmp from original to trampoline. - execute_while_frozen( - [this, &trampoline_epilogue, &error] { - if (auto result = emit_jmp_e9(m_target, - reinterpret_cast(&trampoline_epilogue->jmp_to_destination), m_original_bytes.size()); - !result) { - error = result.error(); - } - }, - [this](auto, auto, auto ctx) { - for (size_t i = 0; i < m_original_bytes.size(); ++i) { - fix_ip(ctx, m_target + i, m_trampoline.data() + i); - } - }); + trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &trampoline_epilogue, &error] { + if (auto result = emit_jmp_e9(m_target, reinterpret_cast(&trampoline_epilogue->jmp_to_destination), + m_original_bytes.size()); + !result) { + error = result.error(); + } + }); if (error) { return std::unexpected{*error}; @@ -396,18 +378,12 @@ std::expected InlineHook::ff_hook(const std::shared_ptr std::optional error; // jmp from original to trampoline. - execute_while_frozen( - [this, &error] { - if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size()); - !result) { - error = result.error(); - } - }, - [this](auto, auto, auto ctx) { - for (size_t i = 0; i < m_original_bytes.size(); ++i) { - fix_ip(ctx, m_target + i, m_trampoline.data() + i); - } - }); + trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &error] { + if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size()); + !result) { + error = result.error(); + } + }); if (error) { return std::unexpected{*error}; @@ -424,17 +400,8 @@ void InlineHook::destroy() { return; } - execute_while_frozen( - [this] { - if (auto um = unprotect(m_target, m_original_bytes.size())) { - std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target); - } - }, - [this](auto, auto, auto ctx) { - for (size_t i = 0; i < m_original_bytes.size(); ++i) { - fix_ip(ctx, m_trampoline.data() + i, m_target + i); - } - }); + trap_threads(m_trampoline.data(), m_target, m_original_bytes.size(), + [this] { std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target); }); m_trampoline.free(); } diff --git a/src/thread_freezer.cpp b/src/thread_freezer.cpp index b9b3ae0..5e24389 100644 --- a/src/thread_freezer.cpp +++ b/src/thread_freezer.cpp @@ -1,3 +1,6 @@ +#include +#include + #if __has_include() #include #elif __has_include() @@ -5,118 +8,138 @@ #else #error "Windows.h not found" #endif -#include #include "safetyhook/common.hpp" +#include "safetyhook/utility.hpp" #include "safetyhook/thread_freezer.hpp" -#pragma comment(lib, "ntdll") +namespace safetyhook { +struct TrapInfo { + uint8_t* page_start; + uint8_t* page_end; + uint8_t* from; + uint8_t* to; + size_t len; +}; + +class TrapManager { +public: + static std::mutex mutex; + static std::unique_ptr instance; + + TrapManager() { m_trap_veh = AddVectoredExceptionHandler(1, trap_handler); } + ~TrapManager() { + if (m_trap_veh != nullptr) { + RemoveVectoredExceptionHandler(m_trap_veh); + } + } -extern "C" { -NTSTATUS -NTAPI -NtGetNextThread(HANDLE ProcessHandle, HANDLE ThreadHandle, ACCESS_MASK DesiredAccess, ULONG HandleAttributes, - ULONG Flags, PHANDLE NewThreadHandle); -} + TrapInfo* find_trap(uint8_t* address) { + auto search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) { + return address >= trap.second.from && address < trap.second.from + trap.second.len; + }); -namespace safetyhook { -void execute_while_frozen( - const std::function& run_fn, const std::function& visit_fn) { - // Freeze all threads. - int num_threads_frozen; - auto first_run = true; - - do { - num_threads_frozen = 0; - HANDLE thread{}; - - while (true) { - HANDLE next_thread{}; - const auto status = NtGetNextThread(GetCurrentProcess(), thread, - THREAD_QUERY_LIMITED_INFORMATION | THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_SET_CONTEXT, 0, - 0, &next_thread); - - if (thread != nullptr) { - CloseHandle(thread); - } + if (search == m_traps.end()) { + return nullptr; + } - if (!NT_SUCCESS(status)) { - break; - } + return &search->second; + } - thread = next_thread; + TrapInfo* find_trap_page(uint8_t* address) { + auto search = std::find_if(m_traps.begin(), m_traps.end(), + [address](auto& trap) { return address >= trap.second.page_start && address < trap.second.page_end; }); - const auto thread_id = GetThreadId(thread); + if (search == m_traps.end()) { + return nullptr; + } - if (thread_id == 0 || thread_id == GetCurrentThreadId()) { - continue; - } + return &search->second; + } - const auto suspend_count = SuspendThread(thread); + void add_trap(uint8_t* from, uint8_t* to, size_t len) { + m_traps.insert_or_assign(from, TrapInfo{.page_start = align_down(from, 0x1000), + .page_end = align_up(from + len, 0x1000), + .from = from, + .to = to, + .len = len}); + } - if (suspend_count == static_cast(-1)) { - continue; - } +private: + std::map m_traps; + PVOID m_trap_veh{}; - // Check if the thread was already frozen. Only resume if the thread was already frozen, and it wasn't the - // first run of this freeze loop to account for threads that may have already been frozen for other reasons. - if (suspend_count != 0 && !first_run) { - ResumeThread(thread); - continue; - } + static LONG CALLBACK trap_handler(PEXCEPTION_POINTERS exp) { + auto exception_code = exp->ExceptionRecord->ExceptionCode; - CONTEXT thread_ctx{}; + if (exception_code != EXCEPTION_ACCESS_VIOLATION) { + return EXCEPTION_CONTINUE_SEARCH; + } - thread_ctx.ContextFlags = CONTEXT_FULL; + std::scoped_lock lock{mutex}; + auto* faulting_address = reinterpret_cast(exp->ExceptionRecord->ExceptionInformation[1]); + auto* trap = instance->find_trap(faulting_address); - if (GetThreadContext(thread, &thread_ctx) == FALSE) { - continue; + if (trap == nullptr) { + if (instance->find_trap_page(faulting_address) != nullptr) { + return EXCEPTION_CONTINUE_EXECUTION; + } else { + return EXCEPTION_CONTINUE_SEARCH; } + } - if (visit_fn) { - visit_fn(static_cast(thread_id), static_cast(thread), - static_cast(&thread_ctx)); - } + auto* ctx = exp->ContextRecord; - ++num_threads_frozen; + for (size_t i = 0; i < trap->len; i++) { + fix_ip(ctx, trap->from + i, trap->to + i); } - first_run = false; - } while (num_threads_frozen != 0); - - // Run the function. - if (run_fn) { - run_fn(); + return EXCEPTION_CONTINUE_EXECUTION; } +}; - // Resume all threads. - HANDLE thread{}; +std::mutex TrapManager::mutex; +std::unique_ptr TrapManager::instance; - while (true) { - HANDLE next_thread{}; - const auto status = NtGetNextThread(GetCurrentProcess(), thread, - THREAD_QUERY_LIMITED_INFORMATION | THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_SET_CONTEXT, 0, 0, - &next_thread); +void find_me() { +} - if (thread != nullptr) { - CloseHandle(thread); - } +void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function& run_fn) { + MEMORY_BASIC_INFORMATION find_me_mbi{}; + MEMORY_BASIC_INFORMATION from_mbi{}; + MEMORY_BASIC_INFORMATION to_mbi{}; - if (!NT_SUCCESS(status)) { - break; - } + VirtualQuery(reinterpret_cast(find_me), &find_me_mbi, sizeof(find_me_mbi)); + VirtualQuery(from, &from_mbi, sizeof(from_mbi)); + VirtualQuery(to, &to_mbi, sizeof(to_mbi)); - thread = next_thread; + auto new_protect = PAGE_READWRITE; - const auto thread_id = GetThreadId(thread); + if (from_mbi.AllocationBase == find_me_mbi.AllocationBase || to_mbi.AllocationBase == find_me_mbi.AllocationBase) { + new_protect = PAGE_EXECUTE_READWRITE; + } - if (thread_id == 0 || thread_id == GetCurrentThreadId()) { - continue; - } + std::scoped_lock lock{TrapManager::mutex}; - ResumeThread(thread); + if (TrapManager::instance == nullptr) { + TrapManager::instance = std::make_unique(); } + + TrapManager::instance->add_trap(from, to, len); + + DWORD from_protect; + DWORD to_protect; + + VirtualProtect(from, len, new_protect, &from_protect); + VirtualProtect(to, len, new_protect, &to_protect); + + if (run_fn) { + run_fn(); + } + + VirtualProtect(to, len, to_protect, &to_protect); + VirtualProtect(from, len, from_protect, &from_protect); } void fix_ip(ThreadContext thread_ctx, uint8_t* old_ip, uint8_t* new_ip) { diff --git a/src/vmt_hook.cpp b/src/vmt_hook.cpp index d62d74e..84f9b5c 100644 --- a/src/vmt_hook.cpp +++ b/src/vmt_hook.cpp @@ -112,17 +112,17 @@ void VmtHook::remove(void* object) { const auto original_vmt = search->second; - execute_while_frozen([&] { - if (IsBadWritePtr(object, sizeof(void*))) { - return; - } + if (IsBadWritePtr(object, sizeof(void*))) { + m_objects.erase(search); + return; + } - if (*reinterpret_cast(object) != &m_new_vmt[1]) { - return; - } + if (*reinterpret_cast(object) != &m_new_vmt[1]) { + m_objects.erase(search); + return; + } - *reinterpret_cast(object) = original_vmt; - }); + *reinterpret_cast(object) = original_vmt; m_objects.erase(search); } @@ -132,19 +132,17 @@ void VmtHook::reset() { } void VmtHook::destroy() { - execute_while_frozen([this] { - for (const auto [object, original_vmt] : m_objects) { - if (IsBadWritePtr(object, sizeof(void*))) { - return; - } - - if (*reinterpret_cast(object) != &m_new_vmt[1]) { - return; - } + for (const auto [object, original_vmt] : m_objects) { + if (IsBadWritePtr(object, sizeof(void*))) { + continue; + } - *reinterpret_cast(object) = original_vmt; + if (*reinterpret_cast(object) != &m_new_vmt[1]) { + continue; } - }); + + *reinterpret_cast(object) = original_vmt; + } m_objects.clear(); m_new_vmt_allocation.reset();