diff --git a/command.go b/command.go index 82d37d5..d0a1627 100644 --- a/command.go +++ b/command.go @@ -84,10 +84,10 @@ func (c *Command) predict(a Args) (options []string, only bool) { // if last completed word is a global flag that we need to complete if predictor, ok := c.GlobalFlags[a.LastCompleted]; ok && predictor != nil { Log("Predicting according to global flag %s", a.LastCompleted) - return predictor.Predict(a), true + return predictAndFilterPrefix(predictor, a), true } - options = append(options, c.GlobalFlags.Predict(a)...) + options = append(options, predictAndFilterPrefix(c.GlobalFlags, a)...) // if a sub command was entered, we won't add the parent command // completions and we return here. @@ -98,13 +98,13 @@ func (c *Command) predict(a Args) (options []string, only bool) { // if last completed word is a command flag that we need to complete if predictor, ok := c.Flags[a.LastCompleted]; ok && predictor != nil { Log("Predicting according to flag %s", a.LastCompleted) - return predictor.Predict(a), true + return predictAndFilterPrefix(predictor, a), true } - options = append(options, c.Sub.Predict(a)...) - options = append(options, c.Flags.Predict(a)...) + options = append(options, predictAndFilterPrefix(c.Sub, a)...) + options = append(options, predictAndFilterPrefix(c.Flags, a)...) if c.Args != nil { - options = append(options, c.Args.Predict(a)...) + options = append(options, predictAndFilterPrefix(c.Args, a)...) } return diff --git a/complete.go b/complete.go index 423cbec..e15b244 100644 --- a/complete.go +++ b/complete.go @@ -6,7 +6,6 @@ import ( "io" "os" "strconv" - "strings" "github.com/posener/complete/cmd" ) @@ -69,15 +68,7 @@ func (c *Complete) Complete() bool { options := c.Command.Predict(a) Log("Options: %s", options) - // filter only options that match the last argument - matches := []string{} - for _, option := range options { - if strings.HasPrefix(option, a.Last) { - matches = append(matches, option) - } - } - Log("Matches: %s", matches) - c.output(matches) + c.output(options) return true } diff --git a/complete_test.go b/complete_test.go index 45fa304..ceddcfb 100644 --- a/complete_test.go +++ b/complete_test.go @@ -13,7 +13,7 @@ import ( func TestCompleter_Complete(t *testing.T) { initTests() - c := Command{ + cmd := Command{ Sub: Commands{ "sub1": { Flags: Flags{ @@ -28,6 +28,18 @@ func TestCompleter_Complete(t *testing.T) { }, Args: PredictFiles("*.md"), }, + "permissiveSub": { + Args: &PrefixFilteringPredictor{ + Predictor: PredictSet("aaa", "bbb", "Aab"), + PrefixFilterFunc: PermissivePrefixFilter, + }, + }, + "caseInsensitiveSub": { + Args: &PrefixFilteringPredictor{ + Predictor: PredictSet("aaa", "bbb", "Aab", "åaa"), + PrefixFilterFunc: CaseInsensitivePrefixFilter, + }, + }, }, Flags: Flags{ "-o": PredictFiles("*.txt"), @@ -37,7 +49,8 @@ func TestCompleter_Complete(t *testing.T) { "-global1": PredictAnything, }, } - cmp := New("cmd", c) + + cmp := New("cmd", cmd) tests := []struct { line string @@ -47,7 +60,27 @@ func TestCompleter_Complete(t *testing.T) { { line: "cmd ", point: -1, - want: []string{"sub1", "sub2"}, + want: []string{"sub1", "sub2", "permissiveSub", "caseInsensitiveSub"}, + }, + { + line: "cmd permissiveSub ", + point: -1, + want: []string{"aaa", "bbb", "Aab"}, + }, + { + line: "cmd permissiveSub a", + point: -1, + want: []string{"aaa", "bbb", "Aab"}, + }, + { + line: "cmd caseInsensitiveSub ", + point: -1, + want: []string{"aaa", "bbb", "Aab", "åaa"}, + }, + { + line: "cmd caseInsensitiveSub a", + point: -1, + want: []string{"aaa", "Aab"}, }, { line: "cmd -", @@ -57,7 +90,7 @@ func TestCompleter_Complete(t *testing.T) { { line: "cmd -h ", point: -1, - want: []string{"sub1", "sub2"}, + want: []string{"sub1", "sub2", "permissiveSub", "caseInsensitiveSub"}, }, { line: "cmd -global1 ", // global1 is known follow flag @@ -142,7 +175,7 @@ func TestCompleter_Complete(t *testing.T) { { line: "cmd -no-such-flag ", point: -1, - want: []string{"sub1", "sub2"}, + want: []string{"sub1", "sub2", "permissiveSub", "caseInsensitiveSub"}, }, { line: "cmd -no-such-flag -", @@ -157,7 +190,7 @@ func TestCompleter_Complete(t *testing.T) { { line: "cmd no-such-command ", point: -1, - want: []string{"sub1", "sub2"}, + want: []string{"sub1", "sub2", "permissiveSub", "caseInsensitiveSub"}, }, { line: "cmd -o ", @@ -212,12 +245,12 @@ func TestCompleter_Complete(t *testing.T) { { line: "cmd -o ./readme.md ", point: -1, - want: []string{"sub1", "sub2"}, + want: []string{"sub1", "sub2", "permissiveSub", "caseInsensitiveSub"}, }, { line: "cmd -o=./readme.md ", point: -1, - want: []string{"sub1", "sub2"}, + want: []string{"sub1", "sub2", "permissiveSub", "caseInsensitiveSub"}, }, { line: "cmd -o sub2 -flag3 ", @@ -256,7 +289,7 @@ func TestCompleter_Complete(t *testing.T) { line: "cmd -o ", // ^ point: 4, - want: []string{"sub1", "sub2"}, + want: []string{"sub1", "sub2", "permissiveSub", "caseInsensitiveSub"}, }, } diff --git a/predict.go b/predict.go index 8207063..9646ee1 100644 --- a/predict.go +++ b/predict.go @@ -14,7 +14,7 @@ func PredictOr(predictors ...Predictor) Predictor { if p == nil { continue } - prediction = append(prediction, p.Predict(a)...) + prediction = append(prediction, predictAndFilterPrefix(p, a)...) } return }) @@ -39,3 +39,18 @@ var PredictNothing Predictor // PredictAnything expects something, but nothing particular, such as a number // or arbitrary name. var PredictAnything = PredictFunc(func(Args) []string { return nil }) + +func predictAndFilterPrefix(p Predictor, a Args) []string { + options := p.Predict(a) + prefixerFunc := defaultPrefixFilter + if prefixFilter, ok := p.(PrefixFilter); ok { + prefixerFunc = prefixFilter.FilterPrefix + } + matches := make([]string, 0, len(options)) + for _, option := range options { + if prefixerFunc(option, a.Last) { + matches = append(matches, option) + } + } + return matches +} diff --git a/predict_test.go b/predict_test.go index c376207..b85926c 100644 --- a/predict_test.go +++ b/predict_test.go @@ -3,11 +3,46 @@ package complete import ( "fmt" "os" + "reflect" "sort" "strings" "testing" ) +func Test_predictAndFilterPrefix(t *testing.T) { + t.Parallel() + initTests() + + t.Run("default prefix filter", func(t *testing.T) { + predictor := PredictSet("a", "ab", "b", "c") + args := Args{ + Last: "a", + } + want := []string{"a", "ab"} + got := predictAndFilterPrefix(predictor, args) + if !reflect.DeepEqual(want, got) { + t.Errorf("unexpected result\nwant: %v\ngot: %v", want, got) + } + }) + + t.Run("permissive filter", func(t *testing.T) { + predictor := PredictSet("a", "ab", "b", "c") + predictor = &PrefixFilteringPredictor{ + Predictor: predictor, + PrefixFilterFunc: PermissivePrefixFilter, + } + + args := Args{ + Last: "a", + } + want := []string{"a", "ab", "b", "c"} + got := predictAndFilterPrefix(predictor, args) + if !reflect.DeepEqual(want, got) { + t.Errorf("unexpected result\nwant: %v\ngot: %v", want, got) + } + }) +} + func TestPredicate(t *testing.T) { t.Parallel() initTests() diff --git a/prefixfilter.go b/prefixfilter.go new file mode 100644 index 0000000..dda345c --- /dev/null +++ b/prefixfilter.go @@ -0,0 +1,48 @@ +package complete + +import ( + "strings" +) + +// PrefixFilter filters a predictor's options based on the prefix +type PrefixFilter interface { + FilterPrefix(str, prefix string) bool +} + +// PrefixFilteringPredictor is a Predictor that also implements PrefixFilter +type PrefixFilteringPredictor struct { + Predictor Predictor + PrefixFilterFunc func(s, prefix string) bool +} + +func (p *PrefixFilteringPredictor) Predict(a Args) []string { + if p.Predictor == nil { + return []string{} + } + return p.Predictor.Predict(a) +} + +func (p *PrefixFilteringPredictor) FilterPrefix(str, prefix string) bool { + if p.PrefixFilterFunc == nil { + return defaultPrefixFilter(str, prefix) + } + return p.PrefixFilterFunc(str, prefix) +} + +// defaultPrefixFilter is the PrefixFilter used when none is set +func defaultPrefixFilter(s, prefix string) bool { + return strings.HasPrefix(s, prefix) +} + +// PermissivePrefixFilter always returns true +func PermissivePrefixFilter(_, _ string) bool { + return true +} + +// CaseInsensitivePrefixFilter ignores case differences between the prefix and tested string +func CaseInsensitivePrefixFilter(s, prefix string) bool { + if len(prefix) > len(s) { + return false + } + return strings.EqualFold(prefix, s[:len(prefix)]) +} diff --git a/prefixfilter_test.go b/prefixfilter_test.go new file mode 100644 index 0000000..27204ab --- /dev/null +++ b/prefixfilter_test.go @@ -0,0 +1,151 @@ +package complete + +import ( + "fmt" + "reflect" + "testing" +) + +func TestPrefixFilteringPredictor_Predict(t *testing.T) { + t.Parallel() + initTests() + + t.Run("defaults to empty list", func(t *testing.T) { + pfp := &PrefixFilteringPredictor{} + got := pfp.Predict(Args{}) + if len(got) != 0 { + t.Fail() + } + }) + + t.Run("passes request to Predictor", func(t *testing.T) { + args := Args{ + All: []string{"a"}, + } + want := []string{"b"} + predictFunc := PredictFunc(func(a Args) []string { + if !reflect.DeepEqual(a, args) { + t.Errorf("unexpected args: %v", a) + } + return want + }) + pfp := &PrefixFilteringPredictor{ + Predictor: predictFunc, + } + got := pfp.Predict(args) + if !reflect.DeepEqual(want, got) { + t.Errorf("unexpected result: %v", got) + } + }) +} + +func TestPrefixFilteringPredictor_FilterPrefix(t *testing.T) { + t.Parallel() + initTests() + + t.Run("default PrefixFilterFunc", func(t *testing.T) { + for _, td := range []struct { + s string + prefix string + want bool + }{ + { + s: "ohm", + prefix: "ohm", + want: true, + }, + { + s: "ohm", + prefix: "", + want: true, + }, + { + s: "ohm", + prefix: "O", + want: false, + }, + { + s: "ohm", + prefix: "q", + want: false, + }, + { + s: "öhm", + prefix: "o", + want: false, + }, + { + s: "ohm", + prefix: "ohmy", + want: false, + }, + } { + t.Run(fmt.Sprintf("%s %s", td.s, td.prefix), func(t *testing.T) { + pfp := &PrefixFilteringPredictor{} + got := pfp.FilterPrefix(td.s, td.prefix) + if td.want != got { + t.Errorf("failed %s\ngot: %v\nwant: %v", t.Name(), got, td.want) + } + }) + } + }) + + t.Run("CaseInsensitivePrefixFilter", func(t *testing.T) { + for _, td := range []struct { + s string + prefix string + want bool + }{ + { + s: "ohm", + prefix: "ohm", + want: true, + }, + { + s: "ohm", + prefix: "", + want: true, + }, + { + s: "ohm", + prefix: "O", + want: true, + }, + { + s: "ohm", + prefix: "q", + want: false, + }, + { + s: "öhm", + prefix: "o", + want: false, + }, + { + s: "ohm", + prefix: "ohmy", + want: false, + }, + } { + t.Run(fmt.Sprintf("%s %s", td.s, td.prefix), func(t *testing.T) { + pfp := &PrefixFilteringPredictor{ + PrefixFilterFunc: CaseInsensitivePrefixFilter, + } + got := pfp.FilterPrefix(td.s, td.prefix) + if td.want != got { + t.Errorf("failed %s\ngot: %v\nwant: %v", t.Name(), got, td.want) + } + }) + } + }) + + t.Run("PermissivePrefixFilter", func(t *testing.T) { + pfp := &PrefixFilteringPredictor{ + PrefixFilterFunc: PermissivePrefixFilter, + } + got := pfp.FilterPrefix("", "") + if !got { + t.Errorf("should have returned true, but didn't") + } + }) +}