diff --git a/backend/flags/flags.go b/backend/flags/flags.go index 68f3e66..a255115 100644 --- a/backend/flags/flags.go +++ b/backend/flags/flags.go @@ -4,6 +4,7 @@ import ( "context" "flag" "fmt" + "os" "reflect" "time" @@ -12,11 +13,15 @@ import ( ) // Backend that loads configuration from the command line flags. -type Backend struct{} +type Backend struct { + flags *flag.FlagSet +} // NewBackend creates a flags backend. func NewBackend() *Backend { - return new(Backend) + return &Backend{ + flags: flag.CommandLine, + } } // LoadStruct takes a struct config, define flags based on it and parse the command line args. @@ -34,80 +39,90 @@ func (b *Backend) LoadStruct(ctx context.Context, cfg *confita.StructConfig) err switch { case f.Value.Type().String() == "time.Duration": var val time.Duration - flag.DurationVar(&val, f.Key, time.Duration(f.Default.Int()), f.Description) + b.flags.DurationVar(&val, f.Key, time.Duration(f.Default.Int()), f.Description) if f.Short != "" { - flag.DurationVar(&val, f.Short, time.Duration(f.Default.Int()), shortDesc(f.Description)) + b.flags.DurationVar(&val, f.Short, time.Duration(f.Default.Int()), shortDesc(f.Description)) } // this function must be executed after the flag.Parse call. defer func() { // if the user has set the flag, save the value in the field. - if isFlagSet(f) { + if b.isFlagSet(f) { f.Value.SetInt(int64(val)) } }() case k == reflect.Bool: var val bool - flag.BoolVar(&val, f.Key, f.Default.Bool(), f.Description) + b.flags.BoolVar(&val, f.Key, f.Default.Bool(), f.Description) if f.Short != "" { - flag.BoolVar(&val, f.Short, f.Default.Bool(), shortDesc(f.Description)) + b.flags.BoolVar(&val, f.Short, f.Default.Bool(), shortDesc(f.Description)) } defer func() { - if isFlagSet(f) { + if b.isFlagSet(f) { f.Value.SetBool(val) } }() case k >= reflect.Int && k <= reflect.Int64: var val int - flag.IntVar(&val, f.Key, int(f.Default.Int()), f.Description) + b.flags.IntVar(&val, f.Key, int(f.Default.Int()), f.Description) if f.Short != "" { - flag.IntVar(&val, f.Short, int(f.Default.Int()), shortDesc(f.Description)) + b.flags.IntVar(&val, f.Short, int(f.Default.Int()), shortDesc(f.Description)) } defer func() { - if isFlagSet(f) { + if b.isFlagSet(f) { f.Value.SetInt(int64(val)) } }() case k >= reflect.Uint && k <= reflect.Uint64: var val uint64 - flag.Uint64Var(&val, f.Key, f.Default.Uint(), f.Description) + b.flags.Uint64Var(&val, f.Key, f.Default.Uint(), f.Description) if f.Short != "" { - flag.Uint64Var(&val, f.Short, f.Default.Uint(), shortDesc(f.Description)) + b.flags.Uint64Var(&val, f.Short, f.Default.Uint(), shortDesc(f.Description)) } defer func() { - if isFlagSet(f) { + if b.isFlagSet(f) { f.Value.SetUint(val) } }() case k >= reflect.Float32 && k <= reflect.Float64: var val float64 - flag.Float64Var(&val, f.Key, f.Default.Float(), f.Description) + b.flags.Float64Var(&val, f.Key, f.Default.Float(), f.Description) if f.Short != "" { - flag.Float64Var(&val, f.Short, f.Default.Float(), shortDesc(f.Description)) + b.flags.Float64Var(&val, f.Short, f.Default.Float(), shortDesc(f.Description)) } defer func() { - if isFlagSet(f) { + if b.isFlagSet(f) { f.Value.SetFloat(val) } }() case k == reflect.String: var val string - flag.StringVar(&val, f.Key, f.Default.String(), f.Description) + b.flags.StringVar(&val, f.Key, f.Default.String(), f.Description) if f.Short != "" { - flag.StringVar(&val, f.Short, f.Default.String(), shortDesc(f.Description)) + b.flags.StringVar(&val, f.Short, f.Default.String(), shortDesc(f.Description)) } defer func() { - if isFlagSet(f) { + if b.isFlagSet(f) { f.Value.SetString(val) } }() default: - flag.Var(&flagValue{f}, f.Key, f.Description) + b.flags.Var(&flagValue{f}, f.Key, f.Description) } } - flag.Parse() + // Note: in the usual case, when b.flags is flag.CommandLine, this will exit + // rather than returning an error. + return b.flags.Parse(os.Args[1:]) +} - return nil +func (b *Backend) isFlagSet(config *confita.FieldConfig) bool { + ok := false + b.flags.Visit(func(f *flag.Flag) { + if f.Name == config.Key || f.Name == config.Short { + ok = true + } + }) + return ok } type flagValue struct { @@ -139,11 +154,3 @@ func (b *Backend) Name() string { func shortDesc(description string) string { return fmt.Sprintf("%s (short)", description) } - -func isFlagSet(config *confita.FieldConfig) bool { - flagset := make(map[*confita.FieldConfig]bool) - flag.Visit(func(f *flag.Flag) { flagset[config] = true }) - - _, ok := flagset[config] - return ok -} diff --git a/backend/flags/flags_test.go b/backend/flags/flags_test.go index ea23f78..c4c3ea1 100644 --- a/backend/flags/flags_test.go +++ b/backend/flags/flags_test.go @@ -1,12 +1,9 @@ package flags import ( - "bytes" "context" - "encoding/json" - "fmt" + "flag" "os" - "os/exec" "testing" "time" @@ -15,158 +12,160 @@ import ( "github.com/stretchr/testify/require" ) -type Config struct { - A string `config:"a"` - Adef string `config:"a-def,short=ad"` - B bool `config:"b"` - Bdef bool `config:"b-def,short=bd"` - C time.Duration `config:"c"` - Cdef time.Duration `config:"c-def,short=cd"` - D int `config:"d"` - Ddef int `config:"d-def,short=dd"` - E uint `config:"e"` - Edef uint `config:"e-def,short=ed"` - F float32 `config:"f"` - Fdef float32 `config:"f-def,short=fd"` -} - -func runHelper(t *testing.T, args ...string) *Config { +func runHelper(t *testing.T, cfg interface{}, args ...string) { t.Helper() - var output bytes.Buffer - - cs := []string{"-test.run=TestHelperProcess", "--"} - cs = append(cs, args...) - cmd := exec.Command(os.Args[0], cs...) - cmd.Stderr = &output - cmd.Env = []string{"GO_HELPER_PROCESS=1"} - err := cmd.Run() - require.NoError(t, err) - - var cfg Config - - err = json.NewDecoder(&output).Decode(&cfg) + flags := flag.NewFlagSet("test", flag.ContinueOnError) + os.Args = append([]string{"a.out"}, args...) + err := confita.NewLoader(&Backend{flags}).Load(context.Background(), cfg) require.NoError(t, err) - - return &cfg } func TestFlags(t *testing.T) { t.Run("Use defaults", func(t *testing.T) { - cfg := runHelper(t, "-a=hello", "-b=true", "-c=10s", "-d=-100", "-e=1", "-f=100.01") - require.Equal(t, "hello", cfg.A) - require.Equal(t, true, cfg.B) - require.Equal(t, 10*time.Second, cfg.C) - require.Equal(t, -100, cfg.D) - require.Equal(t, uint(1), cfg.E) - require.Equal(t, float32(100.01), cfg.F) + type config struct { + A string `config:"a"` + B bool `config:"b"` + C time.Duration `config:"c"` + D int `config:"d"` + E uint `config:"e"` + F float32 `config:"f"` + } + var cfg config + runHelper(t, &cfg, "-a=hello", "-b=true", "-c=10s", "-d=-100", "-e=1", "-f=100.01") + require.Equal(t, config{ + A: "hello", + B: true, + C: 10 * time.Second, + D: -100, + E: 1, + F: 100.01, + }, cfg) }) t.Run("Override defaults", func(t *testing.T) { - cfg := runHelper(t, "-a-def=bye", "-b-def=false", "-c-def=15s", "-d-def=-200", "-e-def=400", "-f-def=2.33") - require.Equal(t, "bye", cfg.Adef) - require.Equal(t, false, cfg.Bdef) - require.Equal(t, 15*time.Second, cfg.Cdef) - require.Equal(t, -200, cfg.Ddef) - require.Equal(t, uint(400), cfg.Edef) - require.Equal(t, float32(2.33), cfg.Fdef) + type config struct { + Adef string `config:"a-def,short=ad"` + Bdef bool `config:"b-def,short=bd"` + Cdef time.Duration `config:"c-def,short=cd"` + Ddef int `config:"d-def,short=dd"` + Edef uint `config:"e-def,short=ed"` + Fdef float32 `config:"f-def,short=fd"` + } + cfg := &config{ + Adef: "hello", + Bdef: true, + Cdef: 10 * time.Second, + Ddef: -100, + } + runHelper(t, cfg, "-a-def=bye", "-b-def=false", "-c-def=15s", "-d-def=-200", "-e-def=400", "-f-def=2.33") + + require.Equal(t, &config{ + Adef: "bye", + Bdef: false, + Cdef: 15 * time.Second, + Ddef: -200, + Edef: 400, + Fdef: 2.33, + }, cfg) }) } func TestFlagsShort(t *testing.T) { - cfg := runHelper(t, "-ad=hello", "-bd=true", "-cd=20s", "-dd=500", "-ed=700", "-fd=333.33") - require.Equal(t, "hello", cfg.Adef) - require.Equal(t, true, cfg.Bdef) - require.Equal(t, 20*time.Second, cfg.Cdef) - require.Equal(t, 500, cfg.Ddef) - require.Equal(t, uint(700), cfg.Edef) - require.Equal(t, float32(333.33), cfg.Fdef) -} - -func TestFlagsMixed(t *testing.T) { - cfg := runHelper(t, "-ad=hello", "-b-def=true", "-cd=20s", "-d-def=500", "-ed=600", "-f-def=42.42") - require.Equal(t, "hello", cfg.Adef) - require.Equal(t, true, cfg.Bdef) - require.Equal(t, 20*time.Second, cfg.Cdef) - require.Equal(t, 500, cfg.Ddef) - require.Equal(t, uint(600), cfg.Edef) - require.Equal(t, float32(42.42), cfg.Fdef) -} - -func TestHelperProcess(t *testing.T) { - if os.Getenv("GO_HELPER_PROCESS") != "1" { - return - } - - args := os.Args - for len(args) > 0 { - if args[0] == "--" { - args = args[1:] - break - } - args = args[1:] + type config struct { + Adef string `config:"a-def,short=ad"` + Bdef bool `config:"b-def,short=bd"` + Cdef time.Duration `config:"c-def,short=cd"` + Ddef int `config:"d-def,short=dd"` + Edef uint `config:"e-def,short=ed"` + Fdef float32 `config:"f-def,short=fd"` } - if len(args) == 0 { - fmt.Fprintf(os.Stderr, "No args\n") - os.Exit(2) - } - - os.Args = append(os.Args[:1], args...) - - cfg := Config{ + cfg := &config{ Adef: "hello", Bdef: true, Cdef: 10 * time.Second, Ddef: -100, } - - err := confita.NewLoader(NewBackend()).Load(context.Background(), &cfg) - require.NoError(t, err) - err = json.NewEncoder(os.Stderr).Encode(&cfg) - require.NoError(t, err) - os.Exit(0) + runHelper(t, cfg, "-ad=hello", "-bd=true", "-cd=20s", "-dd=500", "-ed=700", "-fd=333.33") + require.Equal(t, &config{ + Adef: "hello", + Bdef: true, + Cdef: 20 * time.Second, + Ddef: 500, + Edef: 700, + Fdef: 333.33, + }, cfg) } -type store map[string]string - -func (s store) Get(ctx context.Context, key string) ([]byte, error) { - data, ok := s[key] - if !ok { - return nil, backend.ErrNotFound +func TestFlagsMixed(t *testing.T) { + type config struct { + Adef string `config:"a-def,short=ad"` + Bdef bool `config:"b-def,short=bd"` + Cdef time.Duration `config:"c-def,short=cd"` + Ddef int `config:"d-def,short=dd"` + Edef uint `config:"e-def,short=ed"` + Fdef float32 `config:"f-def,short=fd"` } - - return []byte(data), nil -} - -func (store) Name() string { - return "store" + cfg := &config{ + Adef: "hello", + Bdef: true, + Cdef: 10 * time.Second, + Ddef: -100, + } + runHelper(t, cfg, "-ad=hello", "-b-def=true", "-cd=20s", "-d-def=500", "-ed=600", "-f-def=42.42") + require.Equal(t, &config{ + Adef: "hello", + Bdef: true, + Cdef: 20 * time.Second, + Ddef: 500, + Edef: 600, + Fdef: 42.42, + }, cfg) } func TestWithAnotherBackend(t *testing.T) { - s := struct { + type config struct { String string `config:"string,required"` Bool bool `config:"bool,required"` Int int `config:"int,required"` Uint uint `config:"uint,required"` Float float64 `config:"float,required"` Duration time.Duration `config:"duration,required"` - }{} + } + + var cfg config st := store{ - "string": "string", "bool": "true", - "int": "42", "uint": "42", "float": "42.42", "duration": "1ns", } - err := confita.NewLoader(st, NewBackend()).Load(context.Background(), &s) + flags := flag.NewFlagSet("test", flag.ContinueOnError) + os.Args = append([]string{"a.out"}, "-int=42", "-string=string", "-float=99.5") + err := confita.NewLoader(st, &Backend{flags}).Load(context.Background(), &cfg) require.NoError(t, err) - require.Equal(t, "string", s.String) - require.Equal(t, true, s.Bool) - require.Equal(t, 42, s.Int) - require.EqualValues(t, 42, s.Uint) - require.Equal(t, 42.42, s.Float) - require.Equal(t, time.Duration(1), s.Duration) + + require.Equal(t, "string", cfg.String) + require.Equal(t, true, cfg.Bool) + require.Equal(t, 42, cfg.Int) + require.EqualValues(t, 42, cfg.Uint) + require.Equal(t, 99.5, cfg.Float) + require.Equal(t, time.Duration(1), cfg.Duration) +} + +type store map[string]string + +func (s store) Get(ctx context.Context, key string) ([]byte, error) { + data, ok := s[key] + if !ok { + return nil, backend.ErrNotFound + } + + return []byte(data), nil +} + +func (store) Name() string { + return "store" }