From d94358100a27d0bdef94ac44d65f59da6244503a Mon Sep 17 00:00:00 2001 From: cursey Date: Tue, 14 May 2024 21:04:00 -0700 Subject: [PATCH] Feature/enable disable (#73) --- include/safetyhook/easy.hpp | 17 ++++-- include/safetyhook/inline_hook.hpp | 41 ++++++++++++-- include/safetyhook/mid_hook.hpp | 35 +++++++++--- src/easy.cpp | 8 +-- src/inline_hook.cpp | 88 ++++++++++++++++++++++-------- src/mid_hook.cpp | 30 ++++++++-- test/inline_hook.cpp | 47 ++++++++++++++++ test/mid_hook.cpp | 61 +++++++++++++++++++++ 8 files changed, 276 insertions(+), 51 deletions(-) diff --git a/include/safetyhook/easy.hpp b/include/safetyhook/easy.hpp index bb4bb28..080a2a8 100644 --- a/include/safetyhook/easy.hpp +++ b/include/safetyhook/easy.hpp @@ -12,29 +12,34 @@ namespace safetyhook { /// @brief Easy to use API for creating an InlineHook. /// @param target The address of the function to hook. /// @param destination The address of the destination function. +/// @param flags The flags to use. /// @return The InlineHook object. -[[nodiscard]] InlineHook create_inline(void* target, void* destination); +[[nodiscard]] InlineHook create_inline(void* target, void* destination, InlineHook::Flags flags = InlineHook::Default); /// @brief Easy to use API for creating an InlineHook. /// @param target The address of the function to hook. /// @param destination The address of the destination function. +/// @param flags The flags to use. /// @return The InlineHook object. -[[nodiscard]] InlineHook create_inline(FnPtr auto target, FnPtr auto destination) { - return create_inline(reinterpret_cast(target), reinterpret_cast(destination)); +[[nodiscard]] InlineHook create_inline( + FnPtr auto target, FnPtr auto destination, InlineHook::Flags flags = InlineHook::Default) { + return create_inline(reinterpret_cast(target), reinterpret_cast(destination), flags); } /// @brief Easy to use API for creating a MidHook. /// @param target the address of the function to hook. /// @param destination The destination function. +/// @param flags The flags to use. /// @return The MidHook object. -[[nodiscard]] MidHook create_mid(void* target, MidHookFn destination); +[[nodiscard]] MidHook create_mid(void* target, MidHookFn destination, MidHook::Flags = MidHook::Default); /// @brief Easy to use API for creating a MidHook. /// @param target the address of the function to hook. /// @param destination The destination function. +/// @param flags The flags to use. /// @return The MidHook object. -[[nodiscard]] MidHook create_mid(FnPtr auto target, MidHookFn destination) { - return create_mid(reinterpret_cast(target), destination); +[[nodiscard]] MidHook create_mid(FnPtr auto target, MidHookFn destination, MidHook::Flags flags = MidHook::Default) { + return create_mid(reinterpret_cast(target), destination, flags); } /// @brief Easy to use API for creating a VmtHook. diff --git a/include/safetyhook/inline_hook.hpp b/include/safetyhook/inline_hook.hpp index 0ddac53..cb7e2be 100644 --- a/include/safetyhook/inline_hook.hpp +++ b/include/safetyhook/inline_hook.hpp @@ -87,42 +87,54 @@ class InlineHook final { [[nodiscard]] static Error not_enough_space(uint8_t* ip) { return {.type = NOT_ENOUGH_SPACE, .ip = ip}; } }; + /// @brief Flags for InlineHook. + enum Flags : int { + Default = 0, ///< Default flags. + StartDisabled = 1 << 0, ///< Start the hook disabled. + }; + /// @brief Create an inline hook. /// @param target The address of the function to hook. /// @param destination The destination address. + /// @param flags The flags to use. /// @return The InlineHook or an InlineHook::Error if an error occurred. /// @note This will use the default global Allocator. /// @note If you don't care about error handling, use the easy API (safetyhook::create_inline). - [[nodiscard]] static std::expected create(void* target, void* destination); + [[nodiscard]] static std::expected create( + void* target, void* destination, Flags flags = Default); /// @brief Create an inline hook. /// @param target The address of the function to hook. /// @param destination The destination address. + /// @param flags The flags to use. /// @return The InlineHook or an InlineHook::Error if an error occurred. /// @note This will use the default global Allocator. /// @note If you don't care about error handling, use the easy API (safetyhook::create_inline). - [[nodiscard]] static std::expected create(FnPtr auto target, FnPtr auto destination) { - return create(reinterpret_cast(target), reinterpret_cast(destination)); + [[nodiscard]] static std::expected create( + FnPtr auto target, FnPtr auto destination, Flags flags = Default) { + return create(reinterpret_cast(target), reinterpret_cast(destination), flags); } /// @brief Create an inline hook with a given Allocator. /// @param allocator The allocator to use. /// @param target The address of the function to hook. /// @param destination The destination address. + /// @param flags The flags to use. /// @return The InlineHook or an InlineHook::Error if an error occurred. /// @note If you don't care about error handling, use the easy API (safetyhook::create_inline). [[nodiscard]] static std::expected create( - const std::shared_ptr& allocator, void* target, void* destination); + const std::shared_ptr& allocator, void* target, void* destination, Flags flags = Default); /// @brief Create an inline hook with a given Allocator. /// @param allocator The allocator to use. /// @param target The address of the function to hook. /// @param destination The destination address. + /// @param flags The flags to use. /// @return The InlineHook or an InlineHook::Error if an error occurred. /// @note If you don't care about error handling, use the easy API (safetyhook::create_inline). [[nodiscard]] static std::expected create( - const std::shared_ptr& allocator, FnPtr auto target, FnPtr auto destination) { - return create(allocator, reinterpret_cast(target), reinterpret_cast(destination)); + const std::shared_ptr& allocator, FnPtr auto target, FnPtr auto destination, Flags flags = Default) { + return create(allocator, reinterpret_cast(target), reinterpret_cast(destination), flags); } InlineHook() = default; @@ -285,15 +297,32 @@ class InlineHook final { return original()(args...); } + /// @brief Enable the hook. + [[nodiscard]] std::expected enable(); + + /// @brief Disable the hook. + [[nodiscard]] std::expected disable(); + + /// @brief Check if the hook is enabled. + [[nodiscard]] bool enabled() const { return m_enabled; } + private: friend class MidHook; + enum class Type { + Unset, + E9, + FF, + }; + uint8_t* m_target{}; uint8_t* m_destination{}; Allocation m_trampoline{}; std::vector m_original_bytes{}; uintptr_t m_trampoline_size{}; std::recursive_mutex m_mutex{}; + bool m_enabled{}; + Type m_type{Type::Unset}; std::expected setup( const std::shared_ptr& allocator, uint8_t* target, uint8_t* destination); diff --git a/include/safetyhook/mid_hook.hpp b/include/safetyhook/mid_hook.hpp index a5ad4f8..7b2fb79 100644 --- a/include/safetyhook/mid_hook.hpp +++ b/include/safetyhook/mid_hook.hpp @@ -52,43 +52,55 @@ class MidHook final { } }; + /// @brief Flags for MidHook. + enum Flags : int { + Default = 0, ///< Default flags. + StartDisabled = 1, ///< Start the hook disabled. + }; + /// @brief Creates a new MidHook object. /// @param target The address of the function to hook. /// @param destination_fn The destination function. + /// @param flags The flags to use. /// @return The MidHook object or a MidHook::Error if an error occurred. /// @note This will use the default global Allocator. /// @note If you don't care about error handling, use the easy API (safetyhook::create_mid). - [[nodiscard]] static std::expected create(void* target, MidHookFn destination_fn); + [[nodiscard]] static std::expected create( + void* target, MidHookFn destination_fn, Flags flags = Default); /// @brief Creates a new MidHook object. /// @param target The address of the function to hook. /// @param destination_fn The destination function. + /// @param flags The flags to use. /// @return The MidHook object or a MidHook::Error if an error occurred. /// @note This will use the default global Allocator. /// @note If you don't care about error handling, use the easy API (safetyhook::create_mid). - [[nodiscard]] static std::expected create(FnPtr auto target, MidHookFn destination_fn) { - return create(reinterpret_cast(target), destination_fn); + [[nodiscard]] static std::expected create( + FnPtr auto target, MidHookFn destination_fn, Flags flags = Default) { + return create(reinterpret_cast(target), destination_fn, flags); } /// @brief Creates a new MidHook object with a given Allocator. /// @param allocator The Allocator to use. /// @param target The address of the function to hook. /// @param destination_fn The destination function. + /// @param flags The flags to use. /// @return The MidHook object or a MidHook::Error if an error occurred. /// @note If you don't care about error handling, use the easy API (safetyhook::create_mid). [[nodiscard]] static std::expected create( - const std::shared_ptr& allocator, void* target, MidHookFn destination_fn); + const std::shared_ptr& allocator, void* target, MidHookFn destination_fn, Flags flags = Default); /// @brief Creates a new MidHook object with a given Allocator. /// @tparam T The type of the function to hook. /// @param allocator The Allocator to use. /// @param target The address of the function to hook. /// @param destination_fn The destination function. + /// @param flags The flags to use. /// @return The MidHook object or a MidHook::Error if an error occurred. /// @note If you don't care about error handling, use the easy API (safetyhook::create_mid). - [[nodiscard]] static std::expected create( - const std::shared_ptr& allocator, FnPtr auto target, MidHookFn destination_fn) { - return create(allocator, reinterpret_cast(target), destination_fn); + [[nodiscard]] static std::expected create(const std::shared_ptr& allocator, + FnPtr auto target, MidHookFn destination_fn, Flags flags = Default) { + return create(allocator, reinterpret_cast(target), destination_fn, flags); } MidHook() = default; @@ -123,6 +135,15 @@ class MidHook final { /// @return true if the hook is valid, false otherwise. explicit operator bool() const { return static_cast(m_stub); } + /// @brief Enable the hook. + [[nodiscard]] std::expected enable(); + + /// @brief Disable the hook. + [[nodiscard]] std::expected disable(); + + /// @brief Check if the hook is enabled. + [[nodiscard]] bool enabled() const { return m_hook.enabled(); } + private: InlineHook m_hook{}; uint8_t* m_target{}; diff --git a/src/easy.cpp b/src/easy.cpp index 65efb06..1f92b7c 100644 --- a/src/easy.cpp +++ b/src/easy.cpp @@ -1,16 +1,16 @@ #include "safetyhook/easy.hpp" namespace safetyhook { -InlineHook create_inline(void* target, void* destination) { - if (auto hook = InlineHook::create(target, destination)) { +InlineHook create_inline(void* target, void* destination, InlineHook::Flags flags) { + if (auto hook = InlineHook::create(target, destination, flags)) { return std::move(*hook); } else { return {}; } } -MidHook create_mid(void* target, MidHookFn destination) { - if (auto hook = MidHook::create(target, destination)) { +MidHook create_mid(void* target, MidHookFn destination, MidHook::Flags flags) { + if (auto hook = MidHook::create(target, destination, flags)) { return std::move(*hook); } else { return {}; diff --git a/src/inline_hook.cpp b/src/inline_hook.cpp index cd7d4cb..dea6a81 100644 --- a/src/inline_hook.cpp +++ b/src/inline_hook.cpp @@ -114,12 +114,12 @@ static bool decode(ZydisDecodedInstruction* ix, uint8_t* ip) { return ZYAN_SUCCESS(ZydisDecoderDecodeInstruction(&decoder, nullptr, ip, 15, ix)); } -std::expected InlineHook::create(void* target, void* destination) { - return create(Allocator::global(), target, destination); +std::expected InlineHook::create(void* target, void* destination, Flags flags) { + return create(Allocator::global(), target, destination, flags); } std::expected InlineHook::create( - const std::shared_ptr& allocator, void* target, void* destination) { + const std::shared_ptr& allocator, void* target, void* destination, Flags flags) { InlineHook hook{}; if (const auto setup_result = @@ -128,6 +128,12 @@ std::expected InlineHook::create( return std::unexpected{setup_result.error()}; } + if (!(flags & StartDisabled)) { + if (auto enable_result = hook.enable(); !enable_result) { + return std::unexpected{enable_result.error()}; + } + } + return hook; } @@ -146,10 +152,14 @@ InlineHook& InlineHook::operator=(InlineHook&& other) noexcept { m_trampoline = std::move(other.m_trampoline); m_trampoline_size = other.m_trampoline_size; m_original_bytes = std::move(other.m_original_bytes); + m_enabled = other.m_enabled; + m_type = other.m_type; other.m_target = nullptr; other.m_destination = nullptr; other.m_trampoline_size = 0; + other.m_enabled = false; + other.m_type = Type::Unset; } return *this; @@ -305,20 +315,7 @@ std::expected InlineHook::e9_hook(const std::shared_ptr } #endif - std::optional error; - - // jmp from original to trampoline. - trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &trampoline_epilogue, &error] { - if (auto result = emit_jmp_e9(m_target, reinterpret_cast(&trampoline_epilogue->jmp_to_destination), - m_original_bytes.size()); - !result) { - error = result.error(); - } - }); - - if (error) { - return std::unexpected{*error}; - } + m_type = Type::E9; return {}; } @@ -367,34 +364,77 @@ std::expected InlineHook::ff_hook(const std::shared_ptr return std::unexpected{result.error()}; } + m_type = Type::FF; + + return {}; +} +#endif + +std::expected InlineHook::enable() { + std::scoped_lock lock{m_mutex}; + + if (m_enabled) { + return {}; + } + std::optional error; // jmp from original to trampoline. trap_threads(m_target, m_trampoline.data(), m_original_bytes.size(), [this, &error] { - if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size()); - !result) { - error = result.error(); + if (m_type == Type::E9) { + auto trampoline_epilogue = reinterpret_cast( + m_trampoline.address() + m_trampoline_size - sizeof(TrampolineEpilogueE9)); + + if (auto result = emit_jmp_e9(m_target, + reinterpret_cast(&trampoline_epilogue->jmp_to_destination), m_original_bytes.size()); + !result) { + error = result.error(); + } } + +#if SAFETYHOOK_ARCH_X86_64 + if (m_type == Type::FF) { + if (auto result = emit_jmp_ff(m_target, m_destination, m_target + sizeof(JmpFF), m_original_bytes.size()); + !result) { + error = result.error(); + } + } +#endif }); if (error) { return std::unexpected{*error}; } + m_enabled = true; + return {}; } -#endif -void InlineHook::destroy() { +std::expected InlineHook::disable() { std::scoped_lock lock{m_mutex}; - if (!m_trampoline) { - return; + if (!m_enabled) { + return {}; } trap_threads(m_trampoline.data(), m_target, m_original_bytes.size(), [this] { std::copy(m_original_bytes.begin(), m_original_bytes.end(), m_target); }); + m_enabled = false; + + return {}; +} + +void InlineHook::destroy() { + [[maybe_unused]] auto disable_result = disable(); + + std::scoped_lock lock{m_mutex}; + + if (!m_trampoline) { + return; + } + m_trampoline.free(); } } // namespace safetyhook diff --git a/src/mid_hook.cpp b/src/mid_hook.cpp index 3284ff4..339bef8 100644 --- a/src/mid_hook.cpp +++ b/src/mid_hook.cpp @@ -68,12 +68,12 @@ constexpr std::array asm_data = {0xFF, 0x35, 0xA7, 0x00, 0x00, 0x0 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; #endif -std::expected MidHook::create(void* target, MidHookFn destination) { - return create(Allocator::global(), target, destination); +std::expected MidHook::create(void* target, MidHookFn destination, Flags flags) { + return create(Allocator::global(), target, destination, flags); } std::expected MidHook::create( - const std::shared_ptr& allocator, void* target, MidHookFn destination) { + const std::shared_ptr& allocator, void* target, MidHookFn destination, Flags flags) { MidHook hook{}; if (const auto setup_result = hook.setup(allocator, reinterpret_cast(target), destination); @@ -81,6 +81,12 @@ std::expected MidHook::create( return std::unexpected{setup_result.error()}; } + if (!(flags & StartDisabled)) { + if (auto enable_result = hook.enable(); !enable_result) { + return std::unexpected{enable_result.error()}; + } + } + return hook; } @@ -131,7 +137,7 @@ std::expected MidHook::setup( store(m_stub.data() + 0x59, m_stub.data() + m_stub.size() - 8); #endif - auto hook_result = InlineHook::create(allocator, m_target, m_stub.data()); + auto hook_result = InlineHook::create(allocator, m_target, m_stub.data(), InlineHook::StartDisabled); if (!hook_result) { m_stub.free(); @@ -148,4 +154,20 @@ std::expected MidHook::setup( return {}; } + +std::expected MidHook::enable() { + if (auto enable_result = m_hook.enable(); !enable_result) { + return std::unexpected{Error::bad_inline_hook(enable_result.error())}; + } + + return {}; +} + +std::expected MidHook::disable() { + if (auto disable_result = m_hook.disable(); !disable_result) { + return std::unexpected{Error::bad_inline_hook(disable_result.error())}; + } + + return {}; +} } // namespace safetyhook diff --git a/test/inline_hook.cpp b/test/inline_hook.cpp index 5a45af2..807c052 100644 --- a/test/inline_hook.cpp +++ b/test/inline_hook.cpp @@ -623,4 +623,51 @@ static suite<"inline hook"> inline_hook_tests = [] { expect(fn() == 42_i); }; + + "Function hook can be enable and disabled"_test = [] { + struct Target { + SAFETYHOOK_NOINLINE static int fn(int a) { + volatile int b = a; + return b * 2; + } + }; + + expect(Target::fn(1) == 2_i); + expect(Target::fn(2) == 4_i); + expect(Target::fn(3) == 6_i); + + static SafetyHookInline hook; + + struct Hook { + static int fn(int a) { return hook.call(a + 1); } + }; + + auto hook0_result = SafetyHookInline::create(Target::fn, Hook::fn, SafetyHookInline::StartDisabled); + + expect(hook0_result.has_value()); + + hook = std::move(*hook0_result); + + expect(Target::fn(1) == 2_i); + expect(Target::fn(2) == 4_i); + expect(Target::fn(3) == 6_i); + + expect(hook.enable().has_value()); + + expect(Target::fn(1) == 4_i); + expect(Target::fn(2) == 6_i); + expect(Target::fn(3) == 8_i); + + expect(hook.disable().has_value()); + + expect(Target::fn(1) == 2_i); + expect(Target::fn(2) == 4_i); + expect(Target::fn(3) == 6_i); + + hook.reset(); + + expect(Target::fn(1) == 2_i); + expect(Target::fn(2) == 4_i); + expect(Target::fn(3) == 6_i); + }; }; \ No newline at end of file diff --git a/test/mid_hook.cpp b/test/mid_hook.cpp index 8560f3f..5441d21 100644 --- a/test/mid_hook.cpp +++ b/test/mid_hook.cpp @@ -71,4 +71,65 @@ static suite<"mid hook"> mid_hook_tests = [] { expect(Target::add_42(2.0f) == 2.42_f); }; #endif + + "Mid hook enable and disable"_test = [] { + struct Target { + SAFETYHOOK_NOINLINE static int SAFETYHOOK_FASTCALL add_42(int a) { + volatile int b = a; + return b + 42; + } + }; + + expect(Target::add_42(0) == 42_i); + expect(Target::add_42(1) == 43_i); + expect(Target::add_42(2) == 44_i); + + static SafetyHookMid hook; + + struct Hook { + static void add_42(SafetyHookContext& ctx) { +#if SAFETYHOOK_OS_WINDOWS +#if SAFETYHOOK_ARCH_X86_64 + ctx.rcx = 1337 - 42; +#elif SAFETYHOOK_ARCH_X86_32 + ctx.ecx = 1337 - 42; +#endif +#elif SAFETYHOOK_OS_LINUX +#if SAFETYHOOK_ARCH_X86_64 + ctx.rdi = 1337 - 42; +#elif SAFETYHOOK_ARCH_X86_32 + ctx.edi = 1337 - 42; +#endif +#endif + } + }; + + auto hook_result = SafetyHookMid::create(Target::add_42, Hook::add_42, SafetyHookMid::StartDisabled); + + expect(hook_result.has_value()); + + hook = std::move(*hook_result); + + expect(Target::add_42(0) == 42_i); + expect(Target::add_42(1) == 43_i); + expect(Target::add_42(2) == 44_i); + + expect(hook.enable().has_value()); + + expect(Target::add_42(1) == 1337_i); + expect(Target::add_42(2) == 1337_i); + expect(Target::add_42(3) == 1337_i); + + expect(hook.disable().has_value()); + + expect(Target::add_42(0) == 42_i); + expect(Target::add_42(1) == 43_i); + expect(Target::add_42(2) == 44_i); + + hook.reset(); + + expect(Target::add_42(0) == 42_i); + expect(Target::add_42(1) == 43_i); + expect(Target::add_42(2) == 44_i); + }; }; \ No newline at end of file