Skip to content

Commit

Permalink
On Windows, do not degenerate working directory through restarts caus…
Browse files Browse the repository at this point in the history
…ing command execution failure when APPDATA is a drive name-based path which maps to a UNC share
  • Loading branch information
MMulthaupt committed Apr 21, 2021
1 parent 57f9b20 commit 1ce3d7a
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 14 deletions.
14 changes: 1 addition & 13 deletions cmd/launcher/gui/gui_windows.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gui

import (
"unicode/utf16"
"unsafe"

"github.com/setlog/trivrost/pkg/system"
Expand Down Expand Up @@ -145,7 +144,7 @@ func applyWindowStyle(handle uintptr) {
}

func loadIcons() {
binaryPath := goStringToConstantUTF16WinApiString(system.GetBinaryPath())
binaryPath := C.LPCWSTR(system.StringToUTF16UnmanagedString(system.GetBinaryPath()))
extractedIconCount := C.loadIcons(binaryPath)
didLoadIcons = true
C.free(unsafe.Pointer(binaryPath))
Expand All @@ -160,17 +159,6 @@ func loadIcons() {
}
}

func goStringToConstantUTF16WinApiString(s string) C.LPCWSTR {
utf16String := utf16.Encode([]rune(s))
utf16StringPointer := (*uint16)(C.calloc(C.size_t(len(utf16String)+1), C.size_t(unsafe.Sizeof(uint16(0)))))
currentCharPointer := utf16StringPointer
for _, c := range utf16String {
*currentCharPointer = c
currentCharPointer = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(currentCharPointer)) + unsafe.Sizeof(uint16(0))))
}
return (C.LPCWSTR)(unsafe.Pointer(utf16StringPointer))
}

func setProgressState(s progressState) {
C.setProgressBarState(C.ULONG_PTR(panelDownloadStatus.barTotalProgress.Handle()), C.int(s))
}
2 changes: 1 addition & 1 deletion cmd/launcher/launcher/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func IsInstanceInstalledInSystemMode() bool {

// IsInstanceInstalledForCurrentUser returns true iff the launcher's desired path under user files is occupied by the program running this code.
func IsInstanceInstalledForCurrentUser() bool {
return system.GetProgramPath() == getTargetProgramPath()
return system.FilepathsEquivalent(system.GetProgramPath(), getTargetProgramPath())
}

// IsInstallationOutdated returns true if the time the installed launcher binary was built
Expand Down
4 changes: 4 additions & 0 deletions pkg/system/api_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,7 @@ func showLocalFileInFileManager(path string) error {
func isProcessRunning(p *os.Process) bool {
return p.Signal(unix.Signal(0)) == nil
}

func universalPathName(p string) (string, error) {
return p, nil
}
131 changes: 131 additions & 0 deletions pkg/system/api_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ import (
"os/exec"
"runtime"
"strings"
"unicode/utf16"
"unsafe"

"golang.org/x/sys/windows"
)

// #cgo LDFLAGS: -lMpr
//#include <windows.h>
//#include <winnetwk.h>
import "C"

func mustDetectArchitecture() {
Expand Down Expand Up @@ -59,3 +63,130 @@ func isProcessRunning(p *os.Process) bool {
result := C.GetExitCodeProcess(handle, &lpExitCode)
return (result != 0) && (lpExitCode == C.STILL_ACTIVE)
}

func universalPathName(p string) (string, error) {
s, lpBufferSize, err := universalPathNameWithBufferSize(p, 1000)
if err != nil && err.(*universalNameRetrievalError).ErrorType() == errorMoreData {
s, _, err = universalPathNameWithBufferSize(s, lpBufferSize)
}
if err != nil {
return p, err
}
return s, err
}

func universalPathNameWithBufferSize(p string, lpBufferSizeUse C.DWORD) (universalPath string, lpBufferSize C.DWORD, err error) {
cp := C.LPCWSTR(StringToUTF16UnmanagedString(p))
defer C.free(unsafe.Pointer(cp))

// The possible data written to infoStruct (we request a UNIVERSAL_NAME_INFO below) not only consists of the struct, but also of the data (strings)
// pointed to by pointer-members within the struct. That's why this allocation needs to be much larger than just large enough to hold the struct itself.
infoStruct := C.LPVOID(C.calloc(C.size_t(lpBufferSizeUse), 1))
defer C.free(unsafe.Pointer(infoStruct))

lpBufferSize = lpBufferSizeUse
errorCode := C.WNetGetUniversalNameW(cp, C.UNIVERSAL_NAME_INFO_LEVEL, infoStruct, &lpBufferSize)
err = getErrorOfWNetGetUniversalNameW(errorCode)
if err == nil {
lpUniversalName := unsafe.Pointer(*(*C.LPWSTR)(infoStruct))
universalPath = UTF16StringToString(lpUniversalName)
}
return universalPath, lpBufferSize, err
}

func getErrorOfWNetGetUniversalNameW(returnCode C.DWORD) error {
if returnCode == C.NO_ERROR {
return nil
}
if returnCode == C.ERROR_BAD_DEVICE {
return &universalNameRetrievalError{errorType: errorBadDevice,
message: `the string pointed to by the lpLocalPath parameter is invalid`}
}
if returnCode == C.ERROR_CONNECTION_UNAVAIL {
return &universalNameRetrievalError{errorType: errorConnectionUnavailable,
message: `there is no current connection to the remote device, but there is a remembered (persistent) connection to it`}
}
if returnCode == C.ERROR_EXTENDED_ERROR {
errorMessage, providerName, err := getLastWNetError()
if err != nil {
return &universalNameRetrievalError{errorType: errorExtendedError,
message: `a network-specific error occurred; getting extended error information failed: ` + err.Error()}
}
return &universalNameRetrievalError{errorType: errorExtendedError,
message: `a network-specific error occurred; Network provider "` + providerName + `" reports: ` + errorMessage}
}
if returnCode == C.ERROR_MORE_DATA {
return &universalNameRetrievalError{errorType: errorMoreData,
message: `despite trying to query with the requested buffer size, the buffer pointed to by the lpBuffer parameter was too small`}
}
if returnCode == C.ERROR_NOT_SUPPORTED {
return &universalNameRetrievalError{errorType: errorNotSupported,
message: `the dwInfoLevel parameter is set to UNIVERSAL_NAME_INFO_LEVEL, but the network provider does not support UNC names. (None of the network providers support this function)`}
}
if returnCode == C.ERROR_NO_NET_OR_BAD_PATH {
return &universalNameRetrievalError{errorType: errorNoNetOrBadPath,
message: `none of the network providers recognize the local name as having a connection. However, the network is not available for at least one provider to whom the connection may belong`}
}
if returnCode == C.ERROR_NO_NETWORK {
return &universalNameRetrievalError{errorType: errorNoNetwork,
message: `the network is unavailable`}
}
if returnCode == C.ERROR_NOT_CONNECTED {
return &universalNameRetrievalError{errorType: errorNotConnected,
message: `the device specified by the path is not redirected`}
}
return &universalNameRetrievalError{errorType: errorUndocumented,
message: fmt.Sprintf(`undocumented error code %d`, returnCode)}
}

func getLastWNetError() (errorMessage, providerName string, err error) {
var lpError C.DWORD

const errorBufferSize = 5000
const nErrorBufSize C.DWORD = errorBufferSize
lpErrorBuf := (C.LPWSTR)(C.calloc(C.size_t(errorBufferSize+1), C.size_t(unsafe.Sizeof(uint16(0)))))
defer C.free(unsafe.Pointer(lpErrorBuf))

const nameBufferSize = 1000
const nNameBufSize C.DWORD = nameBufferSize
lpNameBuf := (C.LPWSTR)(C.calloc(C.size_t(nameBufferSize+1), C.size_t(unsafe.Sizeof(uint16(0)))))
defer C.free(unsafe.Pointer(lpNameBuf))

returnCode := C.WNetGetLastErrorW(&lpError, lpErrorBuf, nErrorBufSize, lpNameBuf, nNameBufSize)
if returnCode == C.NO_ERROR {
return UTF16StringToString(unsafe.Pointer(lpErrorBuf)), UTF16StringToString(unsafe.Pointer(lpNameBuf)), nil
}
if returnCode == C.ERROR_INVALID_ADDRESS {
return "", "", fmt.Errorf("could not get last WNet error: ERROR_INVALID_ADDRESS")
}
return "", "", fmt.Errorf("could not get last WNet error: undocumented extended error code %d", returnCode)
}

// StringToUTF16UnmanagedString returns an unmanaged, null-terminated UTF16 string for given string.
// The caller is responsible for freeing the returned pointer.
func StringToUTF16UnmanagedString(s string) unsafe.Pointer {
utf16String := utf16.Encode([]rune(s))
utf16StringPointer := (*uint16)(C.calloc(C.size_t(len(utf16String)+1), C.size_t(unsafe.Sizeof(uint16(0)))))
currentCharPointer := utf16StringPointer
for _, c := range utf16String {
*currentCharPointer = c
currentCharPointer = (*uint16)(unsafe.Pointer(uintptr(unsafe.Pointer(currentCharPointer)) + unsafe.Sizeof(uint16(0))))
}
return unsafe.Pointer(utf16StringPointer)
}

// UTF16StringToString returns a string for a given null-terminated UTF16 string.
// This function does not call free on the parameter.
func UTF16StringToString(lpwString unsafe.Pointer) string {
ptr := (*uint16)(lpwString)
data := make([]uint16, 0, 0)
for {
if *ptr == 0 {
break
}
data = append(data, *ptr)
ptr = (*uint16)(unsafe.Pointer(((uintptr)(unsafe.Pointer(ptr))) + unsafe.Sizeof(uint16(0))))
}
s := utf16.Decode(data)
return string(s)
}
22 changes: 22 additions & 0 deletions pkg/system/file_system_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,25 @@ func CleanUpFileOperation(file *os.File, returnError *error) {
}
}
}

// FilepathsEquivalent returns true if the filepaths a and b are semantically equivalent (exceptions may exist).
func FilepathsEquivalent(a, b string) bool {
a = filepath.Clean(a)
b = filepath.Clean(b)
if a == b {
return true
}
aResolved, aErr := universalPathName(a)
if aErr != nil && aErr.(*universalNameRetrievalError).ErrorType() != errorNotConnected {
log.Warnf(`could not determine UNC path for filepath "%s": %v\n`, a, aErr)
}
bResolved, bErr := universalPathName(b)
if bErr != nil && bErr.(*universalNameRetrievalError).ErrorType() != errorNotConnected {
log.Warnf(`could not determine UNC path for filepath "%s": %v\n`, b, bErr)
}
aResolved = filepath.Clean(aResolved)
bResolved = filepath.Clean(bResolved)
return (a == bResolved && bErr == nil) ||
(b == aResolved && aErr == nil) ||
(aResolved == bResolved && aErr == nil && bErr == nil)
}
30 changes: 30 additions & 0 deletions pkg/system/universal_path_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package system

type universalNameRetrievalErrorType int

const errorBadDevice universalNameRetrievalErrorType = 1
const errorConnectionUnavailable universalNameRetrievalErrorType = 2
const errorExtendedError universalNameRetrievalErrorType = 3
const errorMoreData universalNameRetrievalErrorType = 4
const errorNotSupported universalNameRetrievalErrorType = 5
const errorNoNetOrBadPath universalNameRetrievalErrorType = 6
const errorNoNetwork universalNameRetrievalErrorType = 7
const errorNotConnected universalNameRetrievalErrorType = 8
const errorUndocumented universalNameRetrievalErrorType = 9

type universalNameRetrievalError struct {
message string
errorType universalNameRetrievalErrorType
}

func (err *universalNameRetrievalError) Error() string {
if err == nil {
return "<nil>"
}
return err.message
}

// ErrorType returns the corresponsing WINAPI error type of the WNetGetUniversalNameW function call which generated the error.
func (err *universalNameRetrievalError) ErrorType() universalNameRetrievalErrorType {
return err.errorType
}

0 comments on commit 1ce3d7a

Please sign in to comment.