diff --git a/Makefile b/Makefile index c6563db8..439e717f 100644 --- a/Makefile +++ b/Makefile @@ -32,6 +32,8 @@ example/user/*.pb.go: example/user/*.proto example/postgres_arrays/*.pb.go: example/postgres_arrays/*.proto buf generate --template example/postgres_arrays/buf.gen.yaml --path example/postgres_arrays +install: + go install -v . gentool: docker build -f docker/Dockerfile -t $(GENTOOL_IMAGE) . diff --git a/plugin/plugin.go b/plugin/plugin.go index 1e7e1dd3..efb51a09 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -1,6 +1,7 @@ package plugin import ( + "errors" "fmt" "sort" "strconv" @@ -403,19 +404,31 @@ func (p *OrmPlugin) addIncludedField(ormable *OrmableType, field *gorm.ExtraFiel } func (p *OrmPlugin) isOrmable(typeName string) bool { - parts := strings.Split(typeName, ".") - _, ok := p.ormableTypes[strings.Trim(parts[len(parts)-1], "[]*")] - return ok + _, err := GetOrmable(p.ormableTypes, typeName) + return err == nil } func (p *OrmPlugin) getOrmable(typeName string) *OrmableType { - parts := strings.Split(typeName, ".") - if ormable, ok := p.ormableTypes[strings.TrimSuffix(strings.Trim(parts[len(parts)-1], "[]*"), "ORM")]; ok { - return ormable - } else { - p.Fail(typeName, "is not ormable.") + orm, err := GetOrmable(p.ormableTypes, typeName) + if err != nil { + p.Fail(typeName, ErrNotOrmable.Error()) return nil } + return orm +} + +var ( + ErrNotOrmable = errors.New("type is not ormable") +) + +func GetOrmable(ormableTypes map[string]*OrmableType, typeName string) (*OrmableType, error) { + parts := strings.Split(typeName, ".") + ormable, ok := ormableTypes[strings.TrimSuffix(strings.Trim(parts[len(parts)-1], "[]*"), "ORM")] + var err error + if !ok { + err = ErrNotOrmable + } + return ormable, err } func (p *OrmPlugin) getSortedFieldNames(fields map[string]*Field) []string { diff --git a/plugin/plugin_test.go b/plugin/plugin_test.go new file mode 100644 index 00000000..4cb63aaf --- /dev/null +++ b/plugin/plugin_test.go @@ -0,0 +1,41 @@ +package plugin + +import ( + "reflect" + "testing" +) + +func TestGetOrmable(t *testing.T) { + tests := []struct { + in string + m map[string]*OrmableType + e *OrmableType + err error + }{ + { + in: "IntPoint", + m: map[string]*OrmableType{ + "IntPoint": NewOrmableType("", "google.protobuf", nil), + }, + e: NewOrmableType("", "google.protobuf", nil), + }, + { + in: "Task", + m: map[string]*OrmableType{ + "Task": NewOrmableType("TaskORM", "google.protobuf", nil), + }, + e: NewOrmableType("TaskORM", "google.protobuf", nil), + }, + } + + for _, tt := range tests { + ot, err := GetOrmable(tt.m, tt.in) + if err != tt.err { + t.Errorf("got: %s wanted: %s", err, tt.err) + } + + if !reflect.DeepEqual(ot, tt.e) { + t.Errorf("got: %+v wanted: %+v", *ot, tt.e) + } + } +}