diff --git a/go.mod b/go.mod index 2e9babe96f..b35abd5d3f 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( github.com/hashicorp/cap v0.6.0 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/golang-lru/v2 v2.0.7 + github.com/iancoleman/strcase v0.3.0 github.com/jackc/pgx/v5 v5.5.5 github.com/libsql/libsql-client-go v0.0.0-20230917132930-48c310b27e7b github.com/magefile/mage v1.15.0 diff --git a/go.sum b/go.sum index 0cb67132b6..6d46072be4 100644 --- a/go.sum +++ b/go.sum @@ -384,6 +384,8 @@ github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= +github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= diff --git a/internal/config/config_test.go b/internal/config/config_test.go index f1e94ee3e7..902d2f18a3 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -15,6 +15,7 @@ import ( "testing" "time" + "github.com/iancoleman/strcase" "github.com/santhosh-tekuri/jsonschema/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1475,3 +1476,95 @@ func TestGetConfigFile(t *testing.T) { }) } } + +var ( + // add any struct tags to match their camelCase equivalents here. + camelCaseMatchers = map[string]string{ + "requireTLS": "requireTLS", + "discoveryURL": "discoveryURL", + } +) + +func TestStructTags(t *testing.T) { + configType := reflect.TypeOf(Config{}) + configTags := getStructTags(configType) + + for k, v := range camelCaseMatchers { + strcase.ConfigureAcronym(k, v) + } + + // Validate the struct tags for the Config struct. + // recursively validate the struct tags for all sub-structs. + validateStructTags(t, configTags, configType) +} + +func validateStructTags(t *testing.T, tags map[string]map[string]string, tType reflect.Type) { + tName := tType.Name() + for fieldName, fieldTags := range tags { + fieldType, ok := tType.FieldByName(fieldName) + require.True(t, ok, "field %s not found in type %s", fieldName, tName) + + // Validate the `json` struct tag. + jsonTag, ok := fieldTags["json"] + if ok { + require.True(t, isCamelCase(jsonTag), "json tag for field '%s.%s' should be camelCase but is '%s'", tName, fieldName, jsonTag) + } + + // Validate the `mapstructure` struct tag. + mapstructureTag, ok := fieldTags["mapstructure"] + if ok { + require.True(t, isSnakeCase(mapstructureTag), "mapstructure tag for field '%s.%s' should be snake_case but is '%s'", tName, fieldName, mapstructureTag) + } + + // Validate the `yaml` struct tag. + yamlTag, ok := fieldTags["yaml"] + if ok { + require.True(t, isSnakeCase(yamlTag), "yaml tag for field '%s.%s' should be snake_case but is '%s'", tName, fieldName, yamlTag) + } + + // recursively validate the struct tags for all sub-structs. + if fieldType.Type.Kind() == reflect.Struct { + validateStructTags(t, getStructTags(fieldType.Type), fieldType.Type) + } + } +} + +func isCamelCase(s string) bool { + return s == strcase.ToLowerCamel(s) +} + +func isSnakeCase(s string) bool { + return s == strcase.ToSnake(s) +} + +func getStructTags(t reflect.Type) map[string]map[string]string { + tags := make(map[string]map[string]string) + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // Get the field name. + fieldName := field.Name + + // Get the field tags. + fieldTags := make(map[string]string) + for _, tag := range []string{"json", "mapstructure", "yaml"} { + tagValue := field.Tag.Get(tag) + if tagValue == "-" { + fieldTags[tag] = "skip" + continue + } + values := strings.Split(tagValue, ",") + if len(values) > 1 { + tagValue = values[0] + } + if tagValue != "" { + fieldTags[tag] = tagValue + } + } + + tags[fieldName] = fieldTags + } + + return tags +}