diff --git a/viofs/svc/virtiofs.cpp b/viofs/svc/virtiofs.cpp index 2f245b56..bf782960 100644 --- a/viofs/svc/virtiofs.cpp +++ b/viofs/svc/virtiofs.cpp @@ -29,7 +29,7 @@ #pragma warning(push) // '' : structure was padded due to alignment specifier -#pragma warning(disable: 4324) +#pragma warning(disable: 4324) // nonstandard extension used : nameless struct / union #pragma warning(disable: 4201) // potentially uninitialized local variable 'Size' used @@ -46,6 +46,7 @@ #include #include #include +#include #include #include @@ -129,6 +130,8 @@ struct VIRTFS UINT32 OwnerUid{ 0 }; UINT32 OwnerGid{ 0 }; + BOOL IsRunningAsLocalSystem{ false }; + // Maps NodeId to its Nlookup counter. std::map LookupMap{}; @@ -486,39 +489,91 @@ static DWORD FindDeviceInterface(PHANDLE Device) return ERROR_SUCCESS; } -static VOID UpdateLocalUidGid(VIRTFS *VirtFs, DWORD SessionId) +static NTSTATUS FspToolGetTokenInfo(HANDLE Token, + TOKEN_INFORMATION_CLASS TokenInformationClass, PVOID *PInfo) { - PWSTR UserName = NULL; - LPUSER_INFO_3 UserInfo = NULL; - DWORD BytesReturned; - NET_API_STATUS Status; - BOOL Result; + PVOID Info{ NULL }; + DWORD Size; + NTSTATUS Result; - Result = WTSQuerySessionInformation(WTS_CURRENT_SERVER_HANDLE, SessionId, - WTSUserName, &UserName, &BytesReturned); + if (GetTokenInformation(Token, TokenInformationClass, 0, 0, &Size)) + { + Result = STATUS_INVALID_PARAMETER; + goto exit; + } - if (Result == TRUE) + if (ERROR_INSUFFICIENT_BUFFER != GetLastError()) { - Status = NetUserGetInfo(NULL, UserName, 3, (LPBYTE *)&UserInfo); + Result = FspNtStatusFromWin32(GetLastError()); + goto exit; + } - if (Status == NERR_Success) - { - // Use an account from local machine's user DB as the file's - // owner (0x30000 + RID). - VirtFs->LocalUid = UserInfo->usri3_user_id + 0x30000; - VirtFs->LocalGid = UserInfo->usri3_primary_group_id + 0x30000; - } + Info = HeapAlloc(GetProcessHeap(), 0, Size); + if (NULL == Info) + { + Result = STATUS_INSUFFICIENT_RESOURCES; + goto exit; + } + + if (!GetTokenInformation(Token, TokenInformationClass, Info, Size, &Size)) + { + Result = FspNtStatusFromWin32(GetLastError()); + goto exit; + } + + *PInfo = Info; + Result = STATUS_SUCCESS; + +exit: + if (!NT_SUCCESS(Result)) + SafeHeapFree(Info); + + return Result; +} + +static VOID UpdateLocalUidGid(VIRTFS *VirtFs, DWORD SessionId) +{ + NTSTATUS Result; + + HANDLE Token{ NULL }; + TOKEN_USER *Uinfo{ NULL }; + TOKEN_PRIMARY_GROUP *Ginfo{ NULL }; - if (UserInfo != NULL) + if (VirtFs->IsRunningAsLocalSystem) + { + if (!WTSQueryUserToken(SessionId, &Token)) { - NetApiBufferFree(UserInfo); + VirtFs->LocalUid = 0; + VirtFs->LocalGid = 0; + DBG("Failed to open a process token as LocalSystem; Error=%d", GetLastError()); + return; } } - - if (UserName != NULL) + else if (!OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, &Token)) { - WTSFreeMemory(UserName); + VirtFs->LocalUid = 0; + VirtFs->LocalGid = 0; + DBG("Failed to open a process token; Error=%d", GetLastError()); + return; } + + Result = FspToolGetTokenInfo(Token, TokenUser, (PVOID *)&Uinfo); + if (!NT_SUCCESS(Result)) + goto exit; + + Result = FspToolGetTokenInfo(Token, TokenPrimaryGroup, (PVOID *)&Ginfo); + if (!NT_SUCCESS(Result)) + goto exit; + + Result = FspPosixMapSidToUid(Uinfo->User.Sid, &(VirtFs->LocalUid)); + Result = FspPosixMapSidToUid(Ginfo->PrimaryGroup, &(VirtFs->LocalGid)); + + DBG("Local UID=%u, local GID=%u", VirtFs->LocalUid, VirtFs->LocalGid); + +exit: + SafeHeapFree(Uinfo); + SafeHeapFree(Ginfo); + CloseHandle(Token); } static UINT32 PosixUnixModeToAttributes(VIRTFS *VirtFs, uint64_t nodeid, @@ -2535,6 +2590,7 @@ NTSTATUS VIRTFS::SubmitDestroyRequest() NTSTATUS VIRTFS::Start() { + HANDLE Token{ NULL }; NTSTATUS Status; DWORD SessionId; FILETIME FileTime; @@ -2549,8 +2605,24 @@ NTSTATUS VIRTFS::Start() SessionId = WTSGetActiveConsoleSessionId(); if (SessionId != 0xFFFFFFFF) { + // Will fail if we are not the LocalSystem user. + IsRunningAsLocalSystem = WTSQueryUserToken(SessionId, &Token); + + if (IsRunningAsLocalSystem) + { + CloseHandle(Token); + } + else + { + DBG("The service %s was not run as the LocalSytem user", FS_SERVICE_NAME); + } + UpdateLocalUidGid(this, SessionId); } + else + { + DBG("Failed to get SessionID!"); + } GetSystemTimeAsFileTime(&FileTime); @@ -2768,10 +2840,10 @@ static NTSTATUS DebugLogSet(const std::wstring& DebugLogFile) static NTSTATUS SvcStart(FSP_SERVICE* Service, ULONG argc, PWSTR* argv) { std::wstring DebugLogFile{}; - ULONG DebugFlags = 0; + ULONG DebugFlags{ 0 }; std::wstring MountPoint{ L"*" }; VIRTFS *VirtFs; - NTSTATUS Status = STATUS_SUCCESS; + NTSTATUS Status{ STATUS_SUCCESS }; DWORD Error; if (argc > 1)