Skip to content

Commit

Permalink
ThreadFreezer: Replace with POC thread trapping implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
cursey committed Feb 10, 2024
1 parent 3103d4b commit ca071f3
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 169 deletions.
10 changes: 1 addition & 9 deletions include/safetyhook/thread_freezer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,7 @@ using ThreadId = uint32_t;
using ThreadHandle = void*;
using ThreadContext = void*;

/// @brief Executes a function while all other threads are frozen. Also allows for visiting each frozen thread and
/// modifying it's context.
/// @param run_fn The function to run while all other threads are frozen.
/// @param visit_fn The function that will be called for each frozen thread.
/// @note The visit function will be called in the order that the threads were frozen.
/// @note The visit function will be called before the run function.
/// @note Keep the logic inside run_fn and visit_fn as simple as possible to avoid deadlocks.
void execute_while_frozen(const std::function<void()>& run_fn,
const std::function<void(ThreadId, ThreadHandle, ThreadContext)>& visit_fn = {});
void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function<void()>& run_fn);

/// @brief Will modify the context of a thread's IP to point to a new address if its IP is at the old address.
/// @param ctx The thread context to modify.
Expand Down
13 changes: 13 additions & 0 deletions include/safetyhook/utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,17 @@ class UnprotectMemory {
};

[[nodiscard]] std::optional<UnprotectMemory> unprotect(uint8_t* address, size_t size);

template <typename T> constexpr T align_up(T address, size_t align) {
const auto unaligned_address = (uintptr_t)address;
const auto aligned_address = (unaligned_address + align - 1) & ~(align - 1);
return (T)aligned_address;
}

template <typename T> constexpr T align_down(T address, size_t align) {
const auto unaligned_address = (uintptr_t)address;
const auto aligned_address = unaligned_address & ~(align - 1);
return (T)aligned_address;
}

} // namespace safetyhook
14 changes: 2 additions & 12 deletions src/allocator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,11 @@
#error "Windows.h not found"
#endif

#include "safetyhook/utility.hpp"

#include "safetyhook/allocator.hpp"

namespace safetyhook {
template <typename T> constexpr T align_up(T address, size_t align) {
const auto unaligned_address = (uintptr_t)address;
const auto aligned_address = (unaligned_address + align - 1) & ~(align - 1);
return (T)aligned_address;
}

template <typename T> constexpr T align_down(T address, size_t align) {
const auto unaligned_address = (uintptr_t)address;
const auto aligned_address = unaligned_address & ~(align - 1);
return (T)aligned_address;
}

Allocation::Allocation(Allocation&& other) noexcept {
*this = std::move(other);
}
Expand Down
63 changes: 15 additions & 48 deletions src/inline_hook.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,6 @@ static auto make_jmp_ff(uint8_t* src, uint8_t* dst, uint8_t* data) {
return std::unexpected{InlineHook::Error::not_enough_space(dst)};
}

auto um = unprotect(src, size);

if (!um) {
return std::unexpected{InlineHook::Error::failed_to_unprotect(src)};
}

if (size > sizeof(JmpFF)) {
std::fill_n(src, size, static_cast<uint8_t>(0x90));
}
Expand All @@ -102,12 +96,6 @@ constexpr auto make_jmp_e9(uint8_t* src, uint8_t* dst) {
return std::unexpected{InlineHook::Error::not_enough_space(dst)};
}

auto um = unprotect(src, size);

if (!um) {
return std::unexpected{InlineHook::Error::failed_to_unprotect(src)};
}

if (size > sizeof(JmpE9)) {
std::fill_n(src, size, static_cast<uint8_t>(0x90));
}
Expand Down Expand Up @@ -328,19 +316,13 @@ std::expected<void, InlineHook::Error> InlineHook::e9_hook(const std::shared_ptr
std::optional<Error> error;

// jmp from original to trampoline.
execute_while_frozen(
[this, &trampoline_epilogue, &error] {
if (auto result = emit_jmp_e9(m_target,
reinterpret_cast<uint8_t*>(&trampoline_epilogue->jmp_to_destination), m_original_bytes.size());
!result) {
error = result.error();
}
},
[this](auto, auto, auto ctx) {
for (size_t i = 0; i < m_original_bytes.size(); ++i) {
fix_ip(ctx, m_target + i, m_trampoline.data() + i);
}
});
trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &trampoline_epilogue, &error] {
if (auto result = emit_jmp_e9(m_target, reinterpret_cast<uint8_t*>(&trampoline_epilogue->jmp_to_destination),
m_original_bytes.size());
!result) {
error = result.error();
}
});

if (error) {
return std::unexpected{*error};
Expand Down Expand Up @@ -396,18 +378,12 @@ std::expected<void, InlineHook::Error> InlineHook::ff_hook(const std::shared_ptr
std::optional<Error> error;

// jmp from original to trampoline.
execute_while_frozen(
[this, &error] {
if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size());
!result) {
error = result.error();
}
},
[this](auto, auto, auto ctx) {
for (size_t i = 0; i < m_original_bytes.size(); ++i) {
fix_ip(ctx, m_target + i, m_trampoline.data() + i);
}
});
trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &error] {
if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size());
!result) {
error = result.error();
}
});

if (error) {
return std::unexpected{*error};
Expand All @@ -424,17 +400,8 @@ void InlineHook::destroy() {
return;
}

execute_while_frozen(
[this] {
if (auto um = unprotect(m_target, m_original_bytes.size())) {
std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target);
}
},
[this](auto, auto, auto ctx) {
for (size_t i = 0; i < m_original_bytes.size(); ++i) {
fix_ip(ctx, m_trampoline.data() + i, m_target + i);
}
});
trap_threads(m_trampoline.data(), m_target, m_original_bytes.size(),
[this] { std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target); });

m_trampoline.free();
}
Expand Down
183 changes: 103 additions & 80 deletions src/thread_freezer.cpp
Original file line number Diff line number Diff line change
@@ -1,122 +1,145 @@
#include <map>
#include <mutex>

#if __has_include(<Windows.h>)
#include <Windows.h>
#elif __has_include(<windows.h>)
#include <windows.h>
#else
#error "Windows.h not found"
#endif
#include <winternl.h>

#include "safetyhook/common.hpp"
#include "safetyhook/utility.hpp"

#include "safetyhook/thread_freezer.hpp"

#pragma comment(lib, "ntdll")
namespace safetyhook {
struct TrapInfo {
uint8_t* page_start;
uint8_t* page_end;
uint8_t* from;
uint8_t* to;
size_t len;
};

class TrapManager {
public:
static std::mutex mutex;
static std::unique_ptr<TrapManager> instance;

TrapManager() { m_trap_veh = AddVectoredExceptionHandler(1, trap_handler); }
~TrapManager() {
if (m_trap_veh != nullptr) {
RemoveVectoredExceptionHandler(m_trap_veh);
}
}

extern "C" {
NTSTATUS
NTAPI
NtGetNextThread(HANDLE ProcessHandle, HANDLE ThreadHandle, ACCESS_MASK DesiredAccess, ULONG HandleAttributes,
ULONG Flags, PHANDLE NewThreadHandle);
}
TrapInfo* find_trap(uint8_t* address) {
auto search = std::find_if(m_traps.begin(), m_traps.end(), [address](auto& trap) {
return address >= trap.second.from && address < trap.second.from + trap.second.len;
});

namespace safetyhook {
void execute_while_frozen(
const std::function<void()>& run_fn, const std::function<void(ThreadId, ThreadHandle, ThreadContext)>& visit_fn) {
// Freeze all threads.
int num_threads_frozen;
auto first_run = true;

do {
num_threads_frozen = 0;
HANDLE thread{};

while (true) {
HANDLE next_thread{};
const auto status = NtGetNextThread(GetCurrentProcess(), thread,
THREAD_QUERY_LIMITED_INFORMATION | THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_SET_CONTEXT, 0,
0, &next_thread);

if (thread != nullptr) {
CloseHandle(thread);
}
if (search == m_traps.end()) {
return nullptr;
}

if (!NT_SUCCESS(status)) {
break;
}
return &search->second;
}

thread = next_thread;
TrapInfo* find_trap_page(uint8_t* address) {
auto search = std::find_if(m_traps.begin(), m_traps.end(),
[address](auto& trap) { return address >= trap.second.page_start && address < trap.second.page_end; });

const auto thread_id = GetThreadId(thread);
if (search == m_traps.end()) {
return nullptr;
}

if (thread_id == 0 || thread_id == GetCurrentThreadId()) {
continue;
}
return &search->second;
}

const auto suspend_count = SuspendThread(thread);
void add_trap(uint8_t* from, uint8_t* to, size_t len) {
m_traps.insert_or_assign(from, TrapInfo{.page_start = align_down(from, 0x1000),
.page_end = align_up(from + len, 0x1000),
.from = from,
.to = to,
.len = len});
}

if (suspend_count == static_cast<DWORD>(-1)) {
continue;
}
private:
std::map<uint8_t*, TrapInfo> m_traps;
PVOID m_trap_veh{};

// Check if the thread was already frozen. Only resume if the thread was already frozen, and it wasn't the
// first run of this freeze loop to account for threads that may have already been frozen for other reasons.
if (suspend_count != 0 && !first_run) {
ResumeThread(thread);
continue;
}
static LONG CALLBACK trap_handler(PEXCEPTION_POINTERS exp) {
auto exception_code = exp->ExceptionRecord->ExceptionCode;

CONTEXT thread_ctx{};
if (exception_code != EXCEPTION_ACCESS_VIOLATION) {
return EXCEPTION_CONTINUE_SEARCH;
}

thread_ctx.ContextFlags = CONTEXT_FULL;
std::scoped_lock lock{mutex};
auto* faulting_address = reinterpret_cast<uint8_t*>(exp->ExceptionRecord->ExceptionInformation[1]);
auto* trap = instance->find_trap(faulting_address);

if (GetThreadContext(thread, &thread_ctx) == FALSE) {
continue;
if (trap == nullptr) {
if (instance->find_trap_page(faulting_address) != nullptr) {
return EXCEPTION_CONTINUE_EXECUTION;
} else {
return EXCEPTION_CONTINUE_SEARCH;
}
}

if (visit_fn) {
visit_fn(static_cast<ThreadId>(thread_id), static_cast<ThreadHandle>(thread),
static_cast<ThreadContext>(&thread_ctx));
}
auto* ctx = exp->ContextRecord;

++num_threads_frozen;
for (size_t i = 0; i < trap->len; i++) {
fix_ip(ctx, trap->from + i, trap->to + i);
}

first_run = false;
} while (num_threads_frozen != 0);

// Run the function.
if (run_fn) {
run_fn();
return EXCEPTION_CONTINUE_EXECUTION;
}
};

// Resume all threads.
HANDLE thread{};
std::mutex TrapManager::mutex;
std::unique_ptr<TrapManager> TrapManager::instance;

while (true) {
HANDLE next_thread{};
const auto status = NtGetNextThread(GetCurrentProcess(), thread,
THREAD_QUERY_LIMITED_INFORMATION | THREAD_SUSPEND_RESUME | THREAD_GET_CONTEXT | THREAD_SET_CONTEXT, 0, 0,
&next_thread);
void find_me() {
}

if (thread != nullptr) {
CloseHandle(thread);
}
void trap_threads(uint8_t* from, uint8_t* to, size_t len, const std::function<void()>& run_fn) {
MEMORY_BASIC_INFORMATION find_me_mbi{};
MEMORY_BASIC_INFORMATION from_mbi{};
MEMORY_BASIC_INFORMATION to_mbi{};

if (!NT_SUCCESS(status)) {
break;
}
VirtualQuery(reinterpret_cast<void*>(find_me), &find_me_mbi, sizeof(find_me_mbi));
VirtualQuery(from, &from_mbi, sizeof(from_mbi));
VirtualQuery(to, &to_mbi, sizeof(to_mbi));

thread = next_thread;
auto new_protect = PAGE_READWRITE;

const auto thread_id = GetThreadId(thread);
if (from_mbi.AllocationBase == find_me_mbi.AllocationBase || to_mbi.AllocationBase == find_me_mbi.AllocationBase) {
new_protect = PAGE_EXECUTE_READWRITE;
}

if (thread_id == 0 || thread_id == GetCurrentThreadId()) {
continue;
}
std::scoped_lock lock{TrapManager::mutex};

ResumeThread(thread);
if (TrapManager::instance == nullptr) {
TrapManager::instance = std::make_unique<TrapManager>();
}

TrapManager::instance->add_trap(from, to, len);

DWORD from_protect;
DWORD to_protect;

VirtualProtect(from, len, new_protect, &from_protect);
VirtualProtect(to, len, new_protect, &to_protect);

if (run_fn) {
run_fn();
}

VirtualProtect(to, len, to_protect, &to_protect);
VirtualProtect(from, len, from_protect, &from_protect);
}

void fix_ip(ThreadContext thread_ctx, uint8_t* old_ip, uint8_t* new_ip) {
Expand Down
Loading

0 comments on commit ca071f3

Please sign in to comment.