Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
aymanbagabas committed Feb 14, 2024
1 parent 42ffa6f commit 7f3fb42
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 151 deletions.
221 changes: 176 additions & 45 deletions exp/term/input/ansi/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@ package ansi
import (
"bufio"
"fmt"
"io"
"os"
"unicode/utf8"

"github.com/charmbracelet/x/exp/term/ansi"
"github.com/charmbracelet/x/exp/term/input"
)

// ErrUnsupportedReader is returned when the reader is not a *bufio.Reader.
var ErrUnsupportedReader = fmt.Errorf("unsupported reader")

// Flags to control the behavior of the driver.
const (
Fctrlsp = 1 << iota // treat NUL as ctrl+space, otherwise ctrl+@
Expand All @@ -25,172 +30,298 @@ const (

// driver represents a terminal ANSI input driver.
type driver struct {
r *bufio.Reader
table map[string]input.Key
rd *bufio.Reader
term string
flags int
}

var _ input.Driver = &driver{}

// NewDriver returns a new ANSI input driver.
// This driver uses ANSI control codes compatible with VT100/VT200 terminals.
func NewDriver(r *bufio.Reader, flags int) input.Driver {
func NewDriver(r io.Reader, term string, flags int) input.Driver {
if r == nil {
r = os.Stdin
}
if term == "" {
term = os.Getenv("TERM")
}
d := &driver{
r: bufio.NewReaderSize(r, 256),
rd: bufio.NewReaderSize(r, 256),
flags: flags,
term: term,
}
// Populate the key sequences table.
d.registerKeys(flags)
return d
}

// ReadInput implements input.Driver.
func (d *driver) ReadInput() (int, input.Event, error) {
n, e, err := d.PeekInput()
func (d *driver) ReadInput() ([]input.Event, error) {
nb, ne, err := d.peekInput()
if err != nil {
return n, e, err
return nil, err
}

// Consume the event
p := make([]byte, n)
if _, err := d.r.Read(p); err != nil {
return 0, nil, err
if _, err := d.rd.Discard(nb); err != nil {
return nil, err
}

return n, e, nil
return ne, nil
}

const esc = string(ansi.ESC)

// PeekInput implements input.Driver.
func (d *driver) PeekInput() (int, input.Event, error) {
p, err := d.r.Peek(1)
// FIXME: This method is not implemented correctly.
// Should block when n is not filled from a stream or return an error.
func (d *driver) PeekInput() ([]input.Event, error) {
_, ne, err := d.peekInput()
if err != nil {
return nil, err
}

return ne, err
}

func (d *driver) peekInput() (int, []input.Event, error) {
ev := make([]input.Event, 0)
p, err := d.rd.Peek(1)
if err != nil {
return 0, nil, err
}

// The number of bytes buffered.
nb := d.r.Buffered()
bufferedBytes := d.rd.Buffered()
// Peek more bytes if needed.
if nb > len(p) {
p, err = d.r.Peek(nb)
if bufferedBytes > len(p) {
p, err = d.rd.Peek(bufferedBytes)
if err != nil {
return 0, nil, err
}
}

for {
peekedBytes := 0
i := 0 // index of the current byte

addEvent := func(n int, e input.Event) {
peekedBytes += n
i += n
ev = append(ev, e)
}

for i < len(p) {
var alt bool
i := 0
b := p[i]

begin:
switch b {
case ansi.ESC:
if nb == 1 {
return 1, input.KeyEvent(d.table[esc]), nil
if bufferedBytes == 1 {
// Special case for Esc
addEvent(1, input.KeyEvent(d.table[esc]))
continue
}

if i+1 >= len(p) {
// Not enough bytes to peek
break
}

i++ // i = 1
i++ // we know there's at least one more byte
peekedBytes++
switch p[i] {
case 'O': // Esc-prefixed SS3
return d.parseSs3(i+1, p)
nb, e, err := d.parseSs3(i, p, alt)
if err != nil {
return peekedBytes, ev, err
}

addEvent(nb, e)
continue
case 'P': // Esc-prefixed DCS
case '[': // Esc-prefixed CSI
return d.parseCsi(i+1, p)
nb, e, err := d.parseCsi(i, p, alt)
if err != nil {
return peekedBytes, ev, err
}

addEvent(nb, e)
continue
case ']': // Esc-prefixed OSC
nb, e, err := d.parseOsc(i, p, alt)
if err != nil {
return peekedBytes, ev, err
}

addEvent(nb, e)
continue
}

alt = true
b = p[i]

goto begin
case ansi.SS3:
return d.parseSs3(i+1, p)
nb, e, err := d.parseSs3(i, p, alt)
if err != nil {
return peekedBytes, ev, err
}

addEvent(nb, e)
continue
case ansi.DCS:
case ansi.CSI:
return d.parseCsi(i+1, p)
nb, e, err := d.parseCsi(i, p, alt)
if err != nil {
return peekedBytes, ev, err
}

addEvent(nb, e)
continue
case ansi.OSC:
nb, e, err := d.parseOsc(i, p, alt)
if err != nil {
return peekedBytes, ev, err
}

addEvent(nb, e)
continue
}

// Single byte control code or printable ASCII/UTF-8
if b <= ansi.US || b == ansi.DEL || b == ansi.SP {
k := input.KeyEvent(d.table[string(b)])
l := 1
nb := 1
if alt {
k.Mod |= input.Alt
l++
}
return l, k, nil
addEvent(nb, k)
continue
} else if utf8.RuneStart(b) { // Printable ASCII/UTF-8
ul := utf8ByteLen(b)
if ul == -1 || ul > nb {
return 0, nil, fmt.Errorf("invalid UTF-8 sequence: %x", p)
nb := utf8ByteLen(b)
if nb == -1 || nb > bufferedBytes {
return peekedBytes, ev, fmt.Errorf("invalid UTF-8 sequence: %x", p)
}

r := rune(b)
if ul > 1 {
r, _ = utf8.DecodeRune(p[i : i+ul])
if nb > 1 {
r, _ = utf8.DecodeRune(p[i : i+nb])
}

k := input.KeyEvent{Rune: r}
if alt {
k.Mod |= input.Alt
ul++
}

return ul, k, nil
addEvent(nb, k)
continue
}

return nb, nil, input.ErrUnknownEvent
}

return peekedBytes, ev, nil
}

func (d *driver) parseCsi(i int, p []byte) (int, input.Event, error) {
start := i
func (d *driver) parseCsi(i int, p []byte, alt bool) (n int, e input.Event, err error) {
if p[i] == '[' {
n++
}

i++
seq := "\x1b["

// Scan parameter bytes in the range 0x30-0x3F
for ; p[i] >= 0x30 && p[i] <= 0x3F; i++ {
n++
seq += string(p[i])
}
// Scan intermediate bytes in the range 0x20-0x2F
for ; p[i] >= 0x20 && p[i] <= 0x2F; i++ {
n++
seq += string(p[i])
}
// Scan final byte in the range 0x40-0x7E
if p[i] < 0x40 || p[i] > 0x7E {
return i, nil, fmt.Errorf("%w: invalid CSI sequence: %q", input.ErrUnknownEvent, p[start:i+1])
return n, nil, fmt.Errorf("%w: invalid CSI sequence: %q", input.ErrUnknownEvent, seq[2:])
}
n++
seq += string(p[i])

// Handle X10 mouse
if seq == "\x1b[M" && i+3 < len(p) {
btn := int(p[i+1] - 32)
x := int(p[i+2] - 32)
y := int(p[i+3] - 32)
return n + 3, input.MouseEvent{X: x, Y: y, Btn: input.Button(btn)}, nil
}

k, ok := d.table[seq]
if ok {
return i + 1, input.KeyEvent(k), nil
if alt {
k.Mod |= input.Alt
}
return n, input.KeyEvent(k), nil
}

return i + 1, nil, fmt.Errorf("%w: unknown CSI sequence: %q (%q)", input.ErrUnknownEvent, seq, p[start:i+1])
return n, csiSequence(seq), nil
}

// parseSs3 parses a SS3 sequence.
// See https://vt100.net/docs/vt220-rm/chapter4.html#S4.4.4.2
func (d *driver) parseSs3(i int, p []byte) (int, input.Event, error) {
func (d *driver) parseSs3(i int, p []byte, alt bool) (n int, e input.Event, err error) {
if p[i] == 'O' {
n++
}

i++
seq := "\x1bO"

// Scan a GL character
// A GL character is a single byte in the range 0x21-0x7E
// See https://vt100.net/docs/vt220-rm/chapter2.html#S2.3.2
if p[i] < 0x21 || p[i] > 0x7E {
return i, nil, fmt.Errorf("%w: invalid SS3 sequence: %q", input.ErrUnknownEvent, p[i])
return n, nil, fmt.Errorf("%w: invalid SS3 sequence: %q", input.ErrUnknownEvent, p[i])
}
n++
seq += string(p[i])

k, ok := d.table[seq]
if ok {
return i + 1, input.KeyEvent(k), nil
if alt {
k.Mod |= input.Alt
}
return n, input.KeyEvent(k), nil
}

return n, ss3Sequence(seq), nil
}

func (d *driver) parseOsc(i int, p []byte, _ bool) (n int, e input.Event, err error) {
if p[i] == ']' {
n++
}

i++
seq := "\x1b]"

// Scan a OSC sequence
// An OSC sequence is terminated by a BEL, ESC, or ST character
for ; p[i] != ansi.BEL && p[i] != ansi.ESC && p[i] != ansi.ST; i++ {
n++
seq += string(p[i])
}
n++
seq += string(p[i])

// Check 7-bit ST (string terminator) character
if len(p) > i+1 && p[i] == ansi.ESC && p[i+1] == '\\' {
seq += string(p[i+1])
n++
}

return i + 1, nil, fmt.Errorf("%w: unknown SS3 sequence: %q", input.ErrUnknownEvent, seq)
return n, oscSequence(seq), nil
}

func utf8ByteLen(b byte) int {
Expand Down
Loading

0 comments on commit 7f3fb42

Please sign in to comment.