diff --git a/config/config.go b/config/config.go index be67b01e..b538822f 100644 --- a/config/config.go +++ b/config/config.go @@ -14,19 +14,18 @@ import ( // ErrInvalidArgument is the error returned by [ParseFlags] or [FromYAMLFile] if // its parsing result cannot be stored in the value pointed to by the designated passed argument which -// must be a non-nil pointer. +// must be a non-nil struct pointer. var ErrInvalidArgument = stderrors.New("invalid argument") // FromYAMLFile parses the given YAML file and stores the result -// in the value pointed to by v. If v is nil or not a pointer, +// in the value pointed to by v. If v is nil or not a struct pointer, // FromYAMLFile returns an [ErrInvalidArgument] error. func FromYAMLFile(name string, v Validator) error { - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Pointer || rv.IsNil() { - return errors.Wrapf(ErrInvalidArgument, "non-nil pointer expected, got %T", v) + if err := validateNonNilStructPointer(v); err != nil { + return errors.WithStack(err) } - // #nosec G304 -- Potential file inclusion via variable - Its purpose is to load any file name that is passed to it, so doesn't need to validate anything. + // #nosec G304 -- Accept user-controlled input for config file. f, err := os.Open(name) if err != nil { return errors.Wrap(err, "can't open YAML file "+name) @@ -55,11 +54,10 @@ func FromYAMLFile(name string, v Validator) error { type EnvOptions = env.Options // FromEnv parses environment variables and stores the result in the value pointed to by v. -// If v is nil or not a pointer, FromEnv returns an [ErrInvalidArgument] error. +// If v is nil or not a struct pointer, FromEnv returns an [ErrInvalidArgument] error. func FromEnv(v Validator, options EnvOptions) error { - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return errors.Wrapf(ErrInvalidArgument, "non-nil pointer expected, got %T", v) + if err := validateNonNilStructPointer(v); err != nil { + return errors.WithStack(err) } if err := defaults.Set(v); err != nil { @@ -78,7 +76,7 @@ func FromEnv(v Validator, options EnvOptions) error { } // ParseFlags parses CLI flags and stores the result -// in the value pointed to by v. If v is nil or not a pointer, +// in the value pointed to by v. If v is nil or not a struct pointer, // ParseFlags returns an [ErrInvalidArgument] error. // ParseFlags adds a default Help Options group, // which contains the options -h and --help. @@ -87,17 +85,16 @@ func FromEnv(v Validator, options EnvOptions) error { // Note that errors are not printed automatically, // so error handling is the sole responsibility of the caller. func ParseFlags(v any) error { - rv := reflect.ValueOf(v) - if rv.Kind() != reflect.Pointer || rv.IsNil() { - return errors.Wrapf(ErrInvalidArgument, "non-nil pointer expected, got %T", v) + if err := validateNonNilStructPointer(v); err != nil { + return errors.WithStack(err) } parser := flags.NewParser(v, flags.Default^flags.PrintErrors) if _, err := parser.Parse(); err != nil { var flagErr *flags.Error - if errors.As(err, &flagErr) && flagErr.Type == flags.ErrHelp { - fmt.Fprintln(os.Stdout, flagErr) + if errors.As(err, &flagErr) && errors.Is(flagErr.Type, flags.ErrHelp) { + _, _ = fmt.Fprintln(os.Stdout, flagErr) os.Exit(0) } @@ -106,3 +103,14 @@ func ParseFlags(v any) error { return nil } + +// validateNonNilStructPointer checks if the provided value is a non-nil pointer to a struct. +// It returns an error if the value is not a pointer, is nil, or does not point to a struct. +func validateNonNilStructPointer(v any) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() || rv.Elem().Kind() != reflect.Struct { + return errors.Wrapf(ErrInvalidArgument, "non-nil struct pointer expected, got %T", v) + } + + return nil +} diff --git a/config/config_test.go b/config/config_test.go index 2a42d658..c7c857fd 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -220,10 +220,7 @@ func TestFromEnv(t *testing.T) { var config nonStructValidator err := FromEnv(&config, EnvOptions{}) - // Struct pointer assertion is done in the defaults library, - // so we must ensure that the error returned is not one of our own errors. - require.NotErrorIs(t, err, ErrInvalidArgument) - require.NotErrorIs(t, err, errInvalidConfiguration) + require.ErrorIs(t, err, ErrInvalidArgument) }) } @@ -299,11 +296,7 @@ func TestFromYAMLFile(t *testing.T) { var config nonStructValidator err := FromYAMLFile(file.Name(), &config) - require.Error(t, err) - // Struct pointer assertion is done in the defaults library, - // so we must ensure that the error returned is not one of our own errors. - require.NotErrorIs(t, err, ErrInvalidArgument) - require.NotErrorIs(t, err, errInvalidConfiguration) + require.ErrorIs(t, err, ErrInvalidArgument) }) }) @@ -362,6 +355,13 @@ func TestParseFlags(t *testing.T) { require.ErrorIs(t, err, ErrInvalidArgument) }) + t.Run("Non-struct pointer argument", func(t *testing.T) { + var flags int + + err := ParseFlags(&flags) + require.ErrorIs(t, err, ErrInvalidArgument) + }) + t.Run("Exit on help flag", func(t *testing.T) { // This test case checks the behavior of ParseFlags() when the help flag (e.g. -h) is provided. // Since ParseFlags() calls os.Exit() upon encountering the help flag, we need to run this