diff --git a/internal/flag/set.go b/internal/flag/set.go index 8c951a2..d8536f4 100644 --- a/internal/flag/set.go +++ b/internal/flag/set.go @@ -13,19 +13,21 @@ import ( // Set is a set of command line flags. type Set struct { - // TODO(@FollowTheProcess): Figure out a way so that we don't have to store 2 maps - // as it's very memory wasteful - flags map[string]Entry // The actual stored flags, can lookup by name - shorthands map[rune]Entry // The flags by shorthand - args []string // Arguments minus flags or flag values - extra []string // Arguments after "--" was hit + // Note: flags and shorthands are different "views" to the same *Entry, the Entry + // is not duplicated, it's just two maps to the same pointer so we can look up + // using either + + flags map[string]*Entry // The actual stored flags, can lookup by name + shorthands map[rune]*Entry // The flags by shorthand + args []string // Arguments minus flags or flag values + extra []string // Arguments after "--" was hit } // NewSet builds and returns a new set of flags. func NewSet() *Set { return &Set{ - flags: make(map[string]Entry), - shorthands: make(map[rune]Entry), + flags: make(map[string]*Entry), + shorthands: make(map[rune]*Entry), } } @@ -56,7 +58,7 @@ func AddToSet[T Flaggable](set *Set, flag Flag[T]) error { defaultValueNoArg = "1" } - entry := Entry{ + entry := &Entry{ Value: flag, Name: flag.name, Usage: flag.usage, @@ -76,26 +78,26 @@ func AddToSet[T Flaggable](set *Set, flag Flag[T]) error { // Get gets a flag [Entry] from the Set by name and a boolean to indicate // whether it was present. -func (s *Set) Get(name string) (Entry, bool) { +func (s *Set) Get(name string) (*Entry, bool) { if s == nil { - return Entry{}, false + return nil, false } entry, ok := s.flags[name] if !ok { - return Entry{}, false + return nil, false } return entry, true } // GetShort gets a flag [Entry] from the Set by it's shorthand and a boolean to indicate // whether it was present. -func (s *Set) GetShort(short rune) (Entry, bool) { +func (s *Set) GetShort(short rune) (*Entry, bool) { if s == nil { - return Entry{}, false + return nil, false } entry, ok := s.shorthands[short] if !ok { - return Entry{}, false + return nil, false } return entry, true } @@ -206,6 +208,9 @@ func (s *Set) Usage() (string, error) { for _, name := range names { entry := s.flags[name] + if entry == nil { + return "", fmt.Errorf("*Entry stored against key %s was nil", name) // Should never happen + } var shorthand string if entry.Shorthand != NoShortHand { shorthand = fmt.Sprintf("-%s", string(entry.Shorthand)) diff --git a/internal/flag/set_test.go b/internal/flag/set_test.go index 0dafb67..9dc6284 100644 --- a/internal/flag/set_test.go +++ b/internal/flag/set_test.go @@ -36,7 +36,7 @@ func TestParse(t *testing.T) { t.Helper() f, exists := set.Get("something") test.False(t, exists) - test.Equal(t, f, flag.Entry{}) + test.Equal(t, f, nil) test.EqualFunc(t, set.Args(), nil, slices.Equal) }, @@ -53,7 +53,7 @@ func TestParse(t *testing.T) { t.Helper() f, exists := set.Get("something") test.False(t, exists) - test.Equal(t, f, flag.Entry{}) + test.Equal(t, f, nil) test.EqualFunc(t, set.Args(), []string{"some", "args", "here", "no", "flags"}, slices.Equal) }, @@ -70,7 +70,7 @@ func TestParse(t *testing.T) { t.Helper() f, exists := set.Get("something") test.False(t, exists) - test.Equal(t, f, flag.Entry{}) + test.Equal(t, f, nil) test.EqualFunc(t, set.Args(), []string{"some", "args", "here", "no", "flags", "extra", "args"}, slices.Equal) test.EqualFunc(t, set.ExtraArgs(), []string{"extra", "args"}, slices.Equal) @@ -88,7 +88,7 @@ func TestParse(t *testing.T) { t.Helper() f, exists := set.Get("flag") test.False(t, exists) - test.Equal(t, f, flag.Entry{}) + test.Equal(t, f, nil) test.EqualFunc(t, set.Args(), []string{"some", "args", "here"}, slices.Equal) }, @@ -827,6 +827,7 @@ func TestParse(t *testing.T) { // Get by short entry, exists = set.GetShort('c') test.False(t, exists) // Short shouldn't exist + test.Equal(t, entry, nil) }, args: []string{"--count", "1"}, wantErr: false, @@ -856,6 +857,7 @@ func TestParse(t *testing.T) { // Get by short entry, exists = set.GetShort('c') test.False(t, exists) // Short shouldn't exist + test.Equal(t, entry, nil) }, args: []string{"-c", "1"}, wantErr: true, @@ -968,7 +970,7 @@ func TestFlagSet(t *testing.T) { t.Helper() f, exists := set.Get("missing") test.False(t, exists) - test.Equal(t, f, flag.Entry{}) + test.Equal(t, f, nil) }, }, { @@ -981,7 +983,7 @@ func TestFlagSet(t *testing.T) { t.Helper() f, exists := set.GetShort('d') test.False(t, exists) - test.Equal(t, f, flag.Entry{}) + test.Equal(t, f, nil) }, }, { @@ -994,7 +996,7 @@ func TestFlagSet(t *testing.T) { t.Helper() f, exists := set.Get("missing") test.False(t, exists) - test.Equal(t, f, flag.Entry{}) + test.Equal(t, f, nil) }, }, { @@ -1007,7 +1009,7 @@ func TestFlagSet(t *testing.T) { t.Helper() f, exists := set.GetShort('m') test.False(t, exists) - test.Equal(t, f, flag.Entry{}) + test.Equal(t, f, nil) }, }, {