Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(term): add cross-platform console state-aware support #29

Merged
merged 3 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions term/term.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package term

// State contains platform-specific state of a terminal.
type State struct {
state
}

// IsTerminal returns whether the given file descriptor is a terminal.
func IsTerminal(fd uintptr) bool {
return isTerminal(fd)
}

// MakeRaw puts the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd uintptr) (*State, error) {
return makeRaw(fd)
}

// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd uintptr) (*State, error) {
return getState(fd)
}

// SetState sets the given state of the terminal.
func SetState(fd uintptr, state *State) error {
return setState(fd, state)
}

// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd uintptr, oldState *State) error {
return restore(fd, oldState)
}

// GetSize returns the visible dimensions of the given terminal.
//
// These dimensions don't include any scrollback buffer height.
func GetSize(fd uintptr) (width, height int, err error) {
return getSize(fd)
}

// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd uintptr) ([]byte, error) {
return readPassword(fd)
}
39 changes: 39 additions & 0 deletions term/term_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !netbsd && !openbsd && !zos && !windows && !solaris && !plan9
// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!zos,!windows,!solaris,!plan9

package term

import (
"fmt"
"runtime"
)

type state struct{}

func isTerminal(fd uintptr) bool {
return false
}

func makeRaw(fd uintptr) (*State, error) {
return nil, fmt.Errorf("terminal: MakeRaw not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

func getState(fd uintptr) (*State, error) {
return nil, fmt.Errorf("terminal: GetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

func restore(fd uintptr, state *State) error {
return fmt.Errorf("terminal: Restore not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

func getSize(fd uintptr) (width, height int, err error) {
return 0, 0, fmt.Errorf("terminal: GetSize not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

func setState(fd uintptr, state *State) error {
return fmt.Errorf("terminal: SetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

func readPassword(fd uintptr) ([]byte, error) {
return nil, fmt.Errorf("terminal: ReadPassword not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
37 changes: 37 additions & 0 deletions term/term_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package term_test

import (
"os"
"runtime"
"testing"

"github.com/charmbracelet/x/term"
)

func TestIsTerminalTempFile(t *testing.T) {
file, err := os.CreateTemp("", "TestIsTerminalTempFile")
if err != nil {
t.Fatal(err)
}
defer os.Remove(file.Name())
defer file.Close()

if term.IsTerminal(file.Fd()) {
t.Fatalf("IsTerminal unexpectedly returned true for temporary file %s", file.Name())
}
}

func TestIsTerminalTerm(t *testing.T) {
if runtime.GOOS != "linux" {
t.Skipf("unknown terminal path for GOOS %v", runtime.GOOS)
}
file, err := os.OpenFile("/dev/ptmx", os.O_RDWR, 0)
if err != nil {
t.Fatal(err)
}
defer file.Close()

if !term.IsTerminal(file.Fd()) {
t.Fatalf("IsTerminal unexpectedly returned false for terminal file %s", file.Name())
}
}
96 changes: 96 additions & 0 deletions term/term_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris zos

package term

import (
"golang.org/x/sys/unix"
)

type state struct {
unix.Termios
}

func isTerminal(fd uintptr) bool {
_, err := unix.IoctlGetTermios(int(fd), ioctlReadTermios)
return err == nil
}

func makeRaw(fd uintptr) (*State, error) {
termios, err := unix.IoctlGetTermios(int(fd), ioctlReadTermios)
if err != nil {
return nil, err
}

oldState := State{state{Termios: *termios}}

// This attempts to replicate the behaviour documented for cfmakeraw in
// the termios(3) manpage.
termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
termios.Oflag &^= unix.OPOST
termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
termios.Cflag &^= unix.CSIZE | unix.PARENB
termios.Cflag |= unix.CS8
termios.Cc[unix.VMIN] = 1
termios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(int(fd), ioctlWriteTermios, termios); err != nil {
return nil, err
}

return &oldState, nil
}

func setState(fd uintptr, state *State) error {
var termios *unix.Termios
if state != nil {
termios = &state.Termios
}
return unix.IoctlSetTermios(int(fd), ioctlWriteTermios, termios)
}

func getState(fd uintptr) (*State, error) {
termios, err := unix.IoctlGetTermios(int(fd), ioctlReadTermios)
if err != nil {
return nil, err
}

return &State{state{Termios: *termios}}, nil
}

func restore(fd uintptr, state *State) error {
return unix.IoctlSetTermios(int(fd), ioctlWriteTermios, &state.Termios)
}

func getSize(fd uintptr) (width, height int, err error) {
ws, err := unix.IoctlGetWinsize(int(fd), unix.TIOCGWINSZ)
if err != nil {
return 0, 0, err
}
return int(ws.Col), int(ws.Row), nil
}

// passwordReader is an io.Reader that reads from a specific file descriptor.
type passwordReader int

func (r passwordReader) Read(buf []byte) (int, error) {
return unix.Read(int(r), buf)
}

func readPassword(fd uintptr) ([]byte, error) {
termios, err := unix.IoctlGetTermios(int(fd), ioctlReadTermios)
if err != nil {
return nil, err
}

newState := *termios
newState.Lflag &^= unix.ECHO
newState.Lflag |= unix.ICANON | unix.ISIG
newState.Iflag |= unix.ICRNL
if err := unix.IoctlSetTermios(int(fd), ioctlWriteTermios, &newState); err != nil {
return nil, err
}

defer unix.IoctlSetTermios(int(fd), ioctlWriteTermios, termios)

return readPasswordLine(passwordReader(fd))
}
11 changes: 11 additions & 0 deletions term/term_unix_bsd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
// +build darwin dragonfly freebsd netbsd openbsd

package term

import "golang.org/x/sys/unix"

const (
ioctlReadTermios = unix.TIOCGETA
ioctlWriteTermios = unix.TIOCSETA
)
11 changes: 11 additions & 0 deletions term/term_unix_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//go:build aix || linux || solaris || zos
// +build aix linux solaris zos

package term

import "golang.org/x/sys/unix"

const (
ioctlReadTermios = unix.TCGETS
ioctlWriteTermios = unix.TCSETS
)
86 changes: 86 additions & 0 deletions term/term_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
//go:build windows
// +build windows

package term

import (
"os"

"golang.org/x/sys/windows"
)

type state struct {
Mode uint32
}

func isTerminal(fd uintptr) bool {
var st uint32
err := windows.GetConsoleMode(windows.Handle(fd), &st)
return err == nil
}

func makeRaw(fd uintptr) (*State, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
raw := st &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
if err := windows.SetConsoleMode(windows.Handle(fd), raw); err != nil {
return nil, err
}
return &State{state{st}}, nil
}

func setState(fd uintptr, state *State) error {
var mode uint32
if state != nil {
mode = state.Mode
}
return windows.SetConsoleMode(windows.Handle(fd), mode)
}

func getState(fd uintptr) (*State, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
return &State{state{st}}, nil
}

func restore(fd uintptr, state *State) error {
return windows.SetConsoleMode(windows.Handle(fd), state.Mode)
}

func getSize(fd uintptr) (width, height int, err error) {
var info windows.ConsoleScreenBufferInfo
if err := windows.GetConsoleScreenBufferInfo(windows.Handle(fd), &info); err != nil {
return 0, 0, err
}
return int(info.Window.Right - info.Window.Left + 1), int(info.Window.Bottom - info.Window.Top + 1), nil
}

func readPassword(fd uintptr) ([]byte, error) {
var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
old := st

st &^= (windows.ENABLE_ECHO_INPUT | windows.ENABLE_LINE_INPUT)
st |= (windows.ENABLE_PROCESSED_OUTPUT | windows.ENABLE_PROCESSED_INPUT)
if err := windows.SetConsoleMode(windows.Handle(fd), st); err != nil {
return nil, err
}

defer windows.SetConsoleMode(windows.Handle(fd), old)

var h windows.Handle
p, _ := windows.GetCurrentProcess()
if err := windows.DuplicateHandle(p, windows.Handle(fd), p, &h, 0, false, windows.DUPLICATE_SAME_ACCESS); err != nil {
return nil, err
}

f := os.NewFile(uintptr(h), "stdin")
defer f.Close()
return readPasswordLine(f)
}
47 changes: 47 additions & 0 deletions term/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package term

import (
"io"
"runtime"
)

// readPasswordLine reads from reader until it finds \n or io.EOF.
// The slice returned does not include the \n.
// readPasswordLine also ignores any \r it finds.
// Windows uses \r as end of line. So, on Windows, readPasswordLine
// reads until it finds \r and ignores any \n it finds during processing.
func readPasswordLine(reader io.Reader) ([]byte, error) {
var buf [1]byte
var ret []byte

for {
n, err := reader.Read(buf[:])
if n > 0 {
switch buf[0] {
case '\b':
if len(ret) > 0 {
ret = ret[:len(ret)-1]
}
case '\n':
if runtime.GOOS != "windows" {
return ret, nil
}
// otherwise ignore \n
case '\r':
if runtime.GOOS == "windows" {
return ret, nil
}
// otherwise ignore \r
default:
ret = append(ret, buf[0])
}
continue
}
if err != nil {
if err == io.EOF && len(ret) > 0 {
return ret, nil
}
return ret, err
}
}
}
Loading