diff --git a/src/driver/hypervisor.cpp b/src/driver/hypervisor.cpp index e93e88d..1cfef86 100644 --- a/src/driver/hypervisor.cpp +++ b/src/driver/hypervisor.cpp @@ -12,19 +12,13 @@ #define DPL_USER 3 #define DPL_SYSTEM 0 -typedef struct _KPROCESS -{ - DISPATCHER_HEADER Header; - LIST_ENTRY ProfileListHead; - ULONG DirectoryTableBase; - // ... -} KPROCESS, *PKPROCESS; - typedef struct _EPROCESS { - KPROCESS Pcb; - // ... -} EPROCESS, *PEPROCESS; + DISPATCHER_HEADER Header; + LIST_ENTRY ProfileListHead; + ULONG_PTR DirectoryTableBase; + UCHAR Data[1]; +} EPROCESS, * PEPROCESS; namespace { @@ -500,7 +494,7 @@ void inject_invalid_opcode(vmx::guest_context& guest_context) cr3 get_current_process_cr3() { cr3 guest_cr3{}; - guest_cr3.flags = PsGetCurrentProcess()->Pcb.DirectoryTableBase; + guest_cr3.flags = PsGetCurrentProcess()->DirectoryTableBase; return guest_cr3; } @@ -519,6 +513,52 @@ bool is_mem_equal(const uint8_t* ptr, const uint8_t (&array)[Length]) return true; } +enum class syscall_state +{ + is_sysret, + is_syscall, + none, +}; + +syscall_state get_syscall_state(const vmx::guest_context& guest_context) +{ + cr3 orignal_cr3{}; + orignal_cr3.flags = __readcr3(); + + const auto _ = utils::finally([&] + { + __writecr3(orignal_cr3.flags); + }); + + constexpr auto PCID_NONE = 0x000; + constexpr auto PCID_MASK = 0x003; + + const auto guest_cr3 = read_vmx(VMCS_GUEST_CR3); + if ((guest_cr3 & PCID_MASK) != PCID_NONE) + { + const auto process_cr3 = get_current_process_cr3(); + __writecr3(process_cr3.flags); + } + + // TODO: Check for potential page fault + const auto* rip = reinterpret_cast(guest_context.guest_rip); + + constexpr uint8_t syscall_bytes[] = { 0x0F, 0x05 }; + constexpr uint8_t sysret_bytes[] = {0x48, 0x0F, 0x07}; + + if (is_mem_equal(rip, syscall_bytes)) + { + return syscall_state::is_syscall; + } + + if (is_mem_equal(rip, sysret_bytes)) + { + return syscall_state::is_sysret; + } + + return syscall_state::none; +} + void vmx_handle_exception(vmx::guest_context& guest_context) { vmexit_interrupt_information interrupt{}; @@ -533,26 +573,14 @@ void vmx_handle_exception(vmx::guest_context& guest_context) if (interrupt.vector == invalid_opcode) { - auto* rip = reinterpret_cast(guest_context.guest_rip); - - cr3 orignal_cr3{}; - orignal_cr3.flags = __readcr3(); - - const auto guest_cr3 = get_current_process_cr3(); - - __writecr3(guest_cr3.flags); - - // TODO: Check for potential page fault - - constexpr uint8_t sysret_bytes[] = {0x48, 0x05, 0x07}; - constexpr uint8_t syscall_bytes[] = {0x0F, 0x05}; + const auto state = get_syscall_state(guest_context); - if (is_mem_equal(rip, syscall_bytes)) + if (state == syscall_state::is_syscall) { guest_context.increment_rip = false; rflags rflags{}; - rflags.flags = read_vmx(VMCS_GUEST_RFLAGS); + rflags.flags = read_vmx(VMCS_GUEST_RFLAGS); const auto instruction_length = read_vmx(VMCS_VMEXIT_INSTRUCTION_LENGTH); @@ -587,7 +615,7 @@ void vmx_handle_exception(vmx::guest_context& guest_context) __vmx_vmwrite(VMCS_GUEST_SS_ACCESS_RIGHTS, gdt_entry.access_rights.flags); __vmx_vmwrite(VMCS_GUEST_SS_BASE, gdt_entry.base); } - else if (is_mem_equal(rip, sysret_bytes)) + else if (state == syscall_state::is_sysret) { guest_context.increment_rip = false; @@ -952,7 +980,8 @@ void setup_vmcs_for_cpu(vmx::state& vm_state) __vmx_vmwrite(VMCS_GUEST_DEBUGCTL, state->debug_control); __vmx_vmwrite(VMCS_GUEST_DR7, state->kernel_dr7); - const auto stack_pointer = reinterpret_cast(vm_state.stack_buffer) + KERNEL_STACK_SIZE - sizeof(CONTEXT); + const auto stack_pointer = reinterpret_cast(vm_state.stack_buffer) + KERNEL_STACK_SIZE - sizeof( + CONTEXT); __vmx_vmwrite(VMCS_GUEST_RSP, stack_pointer); __vmx_vmwrite(VMCS_GUEST_RIP, reinterpret_cast(vm_launch));