From 51501044374fb92d1bda8ddf243f3a70d3cd38bc Mon Sep 17 00:00:00 2001
From: Axel Wagner <mero@merovius.de>
Date: Fri, 3 Feb 2023 22:19:25 +0100
Subject: [PATCH] testscript: suggest misspelled commands

If a command is not found, we go through the list of defined commands
and check if any of them are sufficiently close to the one used.
"Sufficiently close" is defined by having a Damerau-Levenshtein distance
of 1, which feels like it hits the sweet spot between usefulness and
ease of implementation.

The negation case is still special-cased, as negation is not in the set
of defined commands.

Fixes #190
---
 internal/misspell/misspell.go                 |  68 ++++++++++++
 internal/misspell/misspell_test.go            | 100 ++++++++++++++++++
 .../fuzz/FuzzAlmostEqual/295b316649ae86dd     |   3 +
 .../fuzz/FuzzAlmostEqual/5bd9cd4e8c887808     |   3 +
 .../fuzz/FuzzAlmostEqual/b323cef1fc26e507     |   3 +
 .../fuzz/FuzzAlmostEqual/c6edde4256d6f5eb     |   3 +
 testscript/testdata/testscript_notfound.txt   |  17 +++
 testscript/testscript.go                      |  49 ++++++++-
 8 files changed, 245 insertions(+), 1 deletion(-)
 create mode 100644 internal/misspell/misspell.go
 create mode 100644 internal/misspell/misspell_test.go
 create mode 100644 internal/misspell/testdata/fuzz/FuzzAlmostEqual/295b316649ae86dd
 create mode 100644 internal/misspell/testdata/fuzz/FuzzAlmostEqual/5bd9cd4e8c887808
 create mode 100644 internal/misspell/testdata/fuzz/FuzzAlmostEqual/b323cef1fc26e507
 create mode 100644 internal/misspell/testdata/fuzz/FuzzAlmostEqual/c6edde4256d6f5eb
 create mode 100644 testscript/testdata/testscript_notfound.txt

diff --git a/internal/misspell/misspell.go b/internal/misspell/misspell.go
new file mode 100644
index 00000000..6e1ae95e
--- /dev/null
+++ b/internal/misspell/misspell.go
@@ -0,0 +1,68 @@
+// Package misspell implements utilities for basic spelling correction.
+package misspell
+
+import (
+	"unicode/utf8"
+)
+
+// AlmostEqual reports whether the Damerau-Levenshtein distance between a and b
+// is at most 1. In other words, it reports whether a equals b or turns into b
+// by adding, removing or substituting a single rune, or by swapping two
+// adjacent runes. Invalid runes are considered equal.
+//
+// It runs in O(len(a)+len(b)) time.
+func AlmostEqual(a, b string) bool {
+	for len(a) > 0 && len(b) > 0 {
+		headA, restA := shiftRune(a)
+		headB, restB := shiftRune(b)
+		if headA != headB {
+			// One added, removed or substituted rune must explain the mismatch.
+			switch {
+			case equalValid(a, restB), equalValid(restA, b), equalValid(restA, restB):
+				return true
+			case len(restA) == 0 || len(restB) == 0:
+				// no second rune left on one side, so a swap is impossible
+				return false
+			}
+			// Otherwise only swapping the first two runes can still succeed.
+			nextA, afterA := shiftRune(restA)
+			nextB, afterB := shiftRune(restB)
+			return headA == nextB && nextA == headB && equalValid(afterA, afterB)
+		}
+		a, b = restA, restB
+	}
+	// At least one string is used up; the other may hold one more rune at most.
+	if len(a) == 0 {
+		return len(b) == 0 || singleRune(b)
+	}
+	return singleRune(a)
+}
+
+// singleRune reports whether s holds at most one rune: its first decoded rune spans all of s.
+func singleRune(s string) bool {
+	_, width := utf8.DecodeRuneInString(s)
+	return len(s) == width
+}
+
+// shiftRune decodes the first UTF-8 codepoint of s and returns it together
+// with the remainder of the string. It panics if s is empty.
+func shiftRune(s string) (rune, string) {
+	if s == "" {
+		panic(s)
+	}
+	first, size := utf8.DecodeRuneInString(s)
+	return first, s[size:]
+}
+
+// equalValid reports whether a and b decode to the same runes, with all invalid code points comparing equal.
+func equalValid(a, b string) bool {
+	for a != "" && b != "" {
+		headA, restA := shiftRune(a)
+		headB, restB := shiftRune(b)
+		if headA != headB {
+			return false
+		}
+		a, b = restA, restB
+	}
+	return a == "" && b == ""
+}
diff --git a/internal/misspell/misspell_test.go b/internal/misspell/misspell_test.go
new file mode 100644
index 00000000..690006e7
--- /dev/null
+++ b/internal/misspell/misspell_test.go
@@ -0,0 +1,100 @@
+package misspell
+
+import (
+	"math"
+	"testing"
+)
+
+// TestAlmostEqual runs AlmostEqual over a fixed table of string pairs and
+// checks every pair in both argument orders.
+func TestAlmostEqual(t *testing.T) {
+	t.Parallel()
+
+	for _, tc := range []struct {
+		a, b string
+		want bool
+	}{
+		{"", "", true},
+		{"", "a", true},
+		{"a", "a", true},
+		{"a", "b", true},
+		{"hello", "hell", true},
+		{"hello", "jello", true},
+		{"hello", "helol", true},
+		{"hello", "jelol", false},
+	} {
+		got := AlmostEqual(tc.a, tc.b)
+		if got != tc.want {
+			t.Errorf("AlmostEqual(%q, %q) = %v, want %v", tc.a, tc.b, got, tc.want)
+		}
+		// The result must not depend on the order of the arguments.
+		if got != AlmostEqual(tc.b, tc.a) {
+			t.Errorf("AlmostEqual(%q, %q) == %v != AlmostEqual(%q, %q)", tc.a, tc.b, got, tc.b, tc.a)
+		}
+	}
+}
+
+// FuzzAlmostEqual compares AlmostEqual with editDistance on fuzzed inputs and
+// checks that swapping the arguments does not change the result.
+func FuzzAlmostEqual(f *testing.F) {
+	for _, seed := range [][2]string{
+		{"", ""}, {"", "a"}, {"a", "a"}, {"a", "b"},
+		{"hello", "hell"}, {"hello", "jello"}, {"hello", "helol"}, {"hello", "jelol"},
+	} {
+		f.Add(seed[0], seed[1])
+	}
+	f.Fuzz(func(t *testing.T, a, b string) {
+		const maxLen = 10 // longer strings won't add coverage, but take longer to check
+		if len(a) > maxLen || len(b) > maxLen {
+			return
+		}
+		dist := editDistance([]rune(a), []rune(b))
+		got := AlmostEqual(a, b)
+		if want := dist <= 1; got != want {
+			t.Errorf("AlmostEqual(%q, %q) = %v, editDistance(%q, %q) = %d", a, b, got, a, b, dist)
+		}
+		if got != AlmostEqual(b, a) {
+			t.Errorf("AlmostEqual(%q, %q) == %v != AlmostEqual(%q, %q)", a, b, got, b, a)
+		}
+	})
+}
+
+// editDistance computes the Damerau-Levenshtein distance between a and b by
+// following the recursive definition from Wikipedia almost verbatim. That
+// makes it inefficient, but hopefully "obviously correct" and thus suitable
+// as the reference for the fuzzing test of AlmostEqual.
+func editDistance(a, b []rune) int {
+	if len(a) == 0 && len(b) == 0 {
+		return 0
+	}
+	best := math.MaxInt
+	// Drop the first rune of a.
+	if len(a) > 0 {
+		best = min(best, editDistance(a[1:], b)+1)
+	}
+	// Drop the first rune of b.
+	if len(b) > 0 {
+		best = min(best, editDistance(a, b[1:])+1)
+	}
+	if len(a) > 0 && len(b) > 0 {
+		// cost is 1 if the first runes differ and 0 if they match.
+		cost := 0
+		if a[0] != b[0] {
+			cost = 1
+		}
+		// Keep or substitute the first rune.
+		best = min(best, editDistance(a[1:], b[1:])+cost)
+		// Swap the first two runes.
+		if len(a) > 1 && len(b) > 1 && a[0] == b[1] && a[1] == b[0] {
+			best = min(best, editDistance(a[2:], b[2:])+cost)
+		}
+	}
+	return best
+}
+
+func min(a, b int) int {
+	if b < a {
+		return b
+	}
+	return a // a <= b here, so a is the smaller (or equal) value
+}
diff --git a/internal/misspell/testdata/fuzz/FuzzAlmostEqual/295b316649ae86dd b/internal/misspell/testdata/fuzz/FuzzAlmostEqual/295b316649ae86dd
new file mode 100644
index 00000000..ce1515b2
--- /dev/null
+++ b/internal/misspell/testdata/fuzz/FuzzAlmostEqual/295b316649ae86dd
@@ -0,0 +1,3 @@
+go test fuzz v1
+string("")
+string("00")
diff --git a/internal/misspell/testdata/fuzz/FuzzAlmostEqual/5bd9cd4e8c887808 b/internal/misspell/testdata/fuzz/FuzzAlmostEqual/5bd9cd4e8c887808
new file mode 100644
index 00000000..8aaf0cf7
--- /dev/null
+++ b/internal/misspell/testdata/fuzz/FuzzAlmostEqual/5bd9cd4e8c887808
@@ -0,0 +1,3 @@
+go test fuzz v1
+string("\x980")
+string("0\xb70")
diff --git a/internal/misspell/testdata/fuzz/FuzzAlmostEqual/b323cef1fc26e507 b/internal/misspell/testdata/fuzz/FuzzAlmostEqual/b323cef1fc26e507
new file mode 100644
index 00000000..e2e74e13
--- /dev/null
+++ b/internal/misspell/testdata/fuzz/FuzzAlmostEqual/b323cef1fc26e507
@@ -0,0 +1,3 @@
+go test fuzz v1
+string("OOOOOOOO000")
+string("0000000000000")
diff --git a/internal/misspell/testdata/fuzz/FuzzAlmostEqual/c6edde4256d6f5eb b/internal/misspell/testdata/fuzz/FuzzAlmostEqual/c6edde4256d6f5eb
new file mode 100644
index 00000000..4fcd6936
--- /dev/null
+++ b/internal/misspell/testdata/fuzz/FuzzAlmostEqual/c6edde4256d6f5eb
@@ -0,0 +1,3 @@
+go test fuzz v1
+string("OOOOOOOO000")
+string("0000000000\x1000")
diff --git a/testscript/testdata/testscript_notfound.txt b/testscript/testdata/testscript_notfound.txt
new file mode 100644
index 00000000..4248e681
--- /dev/null
+++ b/testscript/testdata/testscript_notfound.txt
@@ -0,0 +1,17 @@
+# Check that unknown commands output a useful error message
+
+! testscript notfound
+stdout 'unknown command "notexist"'
+
+! testscript negation
+stdout 'unknown command "!exists" \(did you mean "! exists"\?\)'
+
+! testscript misspelled
+stdout 'unknown command "exits" \(did you mean "exists"\?\)'
+
+-- notfound/script.txt --
+notexist
+-- negation/script.txt --
+!exists file
+-- misspelled/script.txt --
+exits file
diff --git a/testscript/testscript.go b/testscript/testscript.go
index 8c68b9c5..dc40a321 100644
--- a/testscript/testscript.go
+++ b/testscript/testscript.go
@@ -22,6 +22,7 @@ import (
 	"path/filepath"
 	"regexp"
 	"runtime"
+	"sort"
 	"strings"
 	"sync/atomic"
 	"syscall"
@@ -29,6 +30,7 @@ import (
 	"time"
 
 	"github.com/rogpeppe/go-internal/imports"
+	"github.com/rogpeppe/go-internal/internal/misspell"
 	"github.com/rogpeppe/go-internal/internal/os/execpath"
 	"github.com/rogpeppe/go-internal/par"
 	"github.com/rogpeppe/go-internal/testenv"
@@ -669,7 +671,16 @@ func (ts *TestScript) runLine(line string) (runOK bool) {
 		cmd = ts.params.Cmds[args[0]]
 	}
 	if cmd == nil {
-		ts.Fatalf("unknown command %q", args[0])
+		// try to find spelling corrections. We arbitrarily limit the number of
+		// corrections, to not be too noisy.
+		switch c := ts.cmdSuggestions(args[0]); len(c) {
+		case 1:
+			ts.Fatalf("unknown command %q (did you mean %q?)", args[0], c[0])
+		case 2, 3, 4:
+			ts.Fatalf("unknown command %q (did you mean one of %q?)", args[0], c)
+		default:
+			ts.Fatalf("unknown command %q", args[0])
+		}
 	}
 	ts.callBuiltinCmd(args[0], func() {
 		cmd(ts, neg, args[1:])
@@ -694,6 +705,42 @@ func (ts *TestScript) callBuiltinCmd(cmd string, runCmd func()) {
 	runCmd()
 }
 
+// cmdSuggestions returns the sorted, distinct known commands that name may be
+// a misspelling of, or nil if none match. `!cmd` for a known cmd gives `! cmd`.
+func (ts *TestScript) cmdSuggestions(name string) []string {
+	if strings.HasPrefix(name, "!") {
+		_, builtin := scriptCmds[name[1:]]
+		_, custom := ts.params.Cmds[name[1:]]
+		if builtin || custom {
+			return []string{"! " + name[1:]}
+		}
+	}
+	var candidates []string
+	for c := range scriptCmds {
+		if misspell.AlmostEqual(name, c) {
+			candidates = append(candidates, c)
+		}
+	}
+	for c := range ts.params.Cmds {
+		if misspell.AlmostEqual(name, c) {
+			candidates = append(candidates, c)
+		}
+	}
+	if len(candidates) == 0 {
+		return nil
+	}
+	// Deduplicate in place: sorting makes duplicates adjacent, so keep c only if it differs from the last kept name.
+	// TODO: Use slices.Compact (and maybe slices.Sort) once we can use Go 1.21
+	sort.Strings(candidates)
+	out := candidates[:1]
+	for _, c := range candidates[1:] {
+		if out[len(out)-1] != c {
+			out = append(out, c)
+		}
+	}
+	return out
+}
+
 func (ts *TestScript) applyScriptUpdates() {
 	if len(ts.scriptUpdates) == 0 {
 		return