diff --git a/dgw.go b/dgw.go index 4f7f8c9..db2d08d 100644 --- a/dgw.go +++ b/dgw.go @@ -6,8 +6,10 @@ import ( "context" "database/sql" "fmt" + "github.com/lib/pq" "go/format" "io/ioutil" + "regexp" "sort" "strings" "text/template" @@ -37,6 +39,34 @@ func OpenDB(connStr string) (*sql.DB, error) { return conn, nil } +const pgLoadEnumDef = ` +SELECT n.nspname AS schema, + pg_catalog.format_type ( t.oid, NULL ) AS name, + ARRAY( SELECT e.enumlabel + FROM pg_catalog.pg_enum e + WHERE e.enumtypid = t.oid + ORDER BY e.oid ) + AS elements +FROM pg_catalog.pg_type t + LEFT JOIN pg_catalog.pg_namespace n + ON n.oid = t.typnamespace +WHERE ( t.typrelid = 0 + OR ( SELECT c.relkind = 'c' + FROM pg_catalog.pg_class c + WHERE c.oid = t.typrelid + ) + ) + AND NOT EXISTS + ( SELECT 1 + FROM pg_catalog.pg_type el + WHERE el.oid = t.typelem + AND el.typarray = t.oid + ) + AND n.nspname = $1 + AND pg_catalog.pg_type_is_visible ( t.oid ) +ORDER BY 1, 2; +` + const queryInterface = ` // Queryer database/sql compatible query interface type Queryer interface { @@ -106,6 +136,28 @@ type TypeMap struct { DBTypes []string `toml:"db_types"` NotNullGoType string `toml:"notnull_go_type"` NullableGoType string `toml:"nullable_go_type"` + + compiled bool + rePatterns []*regexp.Regexp +} + +func (t *TypeMap) Match(s string) bool { + if !t.compiled { + for _, v := range t.DBTypes { + if strings.HasPrefix(v, "re/") { + t.rePatterns = append(t.rePatterns, regexp.MustCompile(v[3:])) + } + } + } + if contains(s, t.DBTypes) { + return true + } + for _, v := range t.rePatterns { + if v.MatchString(s) { + return true + } + } + return false } // AutoKeyMap auto generating key config @@ -114,7 +166,7 @@ type AutoKeyMap struct { } // PgTypeMapConfig go/db type map struct toml config -type PgTypeMapConfig map[string]TypeMap +type PgTypeMapConfig map[string]*TypeMap // PgTable postgres table type PgTable struct { @@ -186,6 +238,49 @@ func PgLoadTypeMapFromFile(filePath string) (*PgTypeMapConfig, error) { return &conf, nil } +type PgEnum struct { + Schema string + Name string + Values []string +} + +type EnumValue struct { + Type *EnumType + Name string + Value string +} + +type EnumType struct { + Name string + Enum *PgEnum + Comment string + Values []EnumValue +} + +func PgLoadEnumDef(db Queryer, schema string) ([]*PgEnum, error) { + enumDefs, err := db.Query(pgLoadEnumDef, schema) + if err != nil { + return nil, errors.Wrap(err, "failed to load enum def") + } + + enums := []*PgEnum{} + for enumDefs.Next() { + e := &PgEnum{} + var vals pq.StringArray + err := enumDefs.Scan( + &e.Schema, + &e.Name, + &vals, + ) + e.Values = vals + if err != nil { + return nil, errors.Wrap(err, "failed to scan") + } + enums = append(enums, e) + } + return enums, nil +} + // PgLoadColumnDef load Postgres column definition func PgLoadColumnDef(db Queryer, schema string, table string) ([]*PgColumn, error) { colDefs, err := db.Query(pgLoadColumnDef, schema, table) @@ -256,11 +351,10 @@ func contains(v string, l []string) bool { } // PgConvertType converts type -func PgConvertType(col *PgColumn, typeCfg *PgTypeMapConfig) string { - cfg := map[string]TypeMap(*typeCfg) - typ := cfg["default"].NotNullGoType - for _, v := range cfg { - if contains(col.DataType, v.DBTypes) { +func PgConvertType(col *PgColumn, typeCfg PgTypeMapConfig) string { + typ := typeCfg["default"].NotNullGoType + for _, v := range typeCfg { + if v.Match(col.DataType) { if col.NotNull { return v.NotNullGoType } @@ -271,7 +365,7 @@ func PgConvertType(col *PgColumn, typeCfg *PgTypeMapConfig) string { } // PgColToField converts pg column to go struct field -func PgColToField(col *PgColumn, typeCfg *PgTypeMapConfig) (*StructField, error) { +func PgColToField(col *PgColumn, typeCfg PgTypeMapConfig) (*StructField, error) { stfType := PgConvertType(col, typeCfg) stf := &StructField{ Name: varfmt.PublicVarName(col.Name), @@ -282,7 +376,7 @@ func PgColToField(col *PgColumn, typeCfg *PgTypeMapConfig) (*StructField, error) } // PgTableToStruct converts table def to go struct -func PgTableToStruct(t *PgTable, typeCfg *PgTypeMapConfig, keyConfig *AutoKeyMap) (*Struct, error) { +func PgTableToStruct(t *PgTable, typeCfg PgTypeMapConfig, keyConfig *AutoKeyMap) (*Struct, error) { t.setPrimaryKeyInfo(keyConfig) s := &Struct{ Name: varfmt.PublicVarName(t.Name), @@ -292,7 +386,7 @@ func PgTableToStruct(t *PgTable, typeCfg *PgTypeMapConfig, keyConfig *AutoKeyMap for _, c := range t.Columns { f, err := PgColToField(c, typeCfg) if err != nil { - return nil, errors.Wrap(err, "faield to convert col to field") + return nil, errors.Wrap(err, "failed to convert col to field") } fs = append(fs, f) } @@ -301,7 +395,7 @@ func PgTableToStruct(t *PgTable, typeCfg *PgTypeMapConfig, keyConfig *AutoKeyMap } // PgExecuteDefaultTmpl execute struct template with *Struct -func PgExecuteDefaultTmpl(st *StructTmpl, path string) ([]byte, error) { +func PgExecuteDefaultTmpl(st interface{}, path string) ([]byte, error) { var src []byte d, err := Asset(path) if err != nil { @@ -323,7 +417,7 @@ func PgExecuteDefaultTmpl(st *StructTmpl, path string) ([]byte, error) { } // PgExecuteCustomTmpl execute custom template -func PgExecuteCustomTmpl(st *StructTmpl, customTmpl string) ([]byte, error) { +func PgExecuteCustomTmpl(st interface{}, customTmpl string) ([]byte, error) { var src []byte tpl, err := template.New("struct").Funcs(tmplFuncMap).Parse(customTmpl) if err != nil { @@ -340,34 +434,100 @@ func PgExecuteCustomTmpl(st *StructTmpl, customTmpl string) ([]byte, error) { return src, nil } +func getPgTypeMapConfig(typeMapPath string) (PgTypeMapConfig, error) { + cfg := make(PgTypeMapConfig) + if typeMapPath == "" { + if _, err := toml.Decode(typeMap, &cfg); err != nil { + return nil, errors.Wrap(err, "failed to read type map") + } + } else { + if _, err := toml.DecodeFile(typeMapPath, &cfg); err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("failed to decode type map file %s", typeMapPath)) + } + } + return cfg, nil +} + +func PgEnumToType(e *PgEnum, typeCfg PgTypeMapConfig, keyConfig *AutoKeyMap) (*EnumType, error) { + en := &EnumType{ + Name: varfmt.PublicVarName(e.Name), + Enum: e, + } + for _, v := range e.Values { + en.Values = append(en.Values, EnumValue{ + Type: en, + Name: en.Name + "_" + varfmt.PublicVarName(v), + Value: v, + }) + } + if _,ok := typeCfg[e.Name]; !ok { + typeCfg[e.Name] = &TypeMap{ + DBTypes: []string{e.Name}, + NotNullGoType: en.Name, + NullableGoType: "Null"+en.Name, + + compiled: true, + rePatterns: nil, + } + } + + return en, nil +} + +func PgCreateEnums(db Queryer, schema string, cfg PgTypeMapConfig, customTmpl string) ([]byte, error) { + var src []byte + + enums, err := PgLoadEnumDef(db, schema) + if err != nil { + return src, errors.Wrap(err, "failed to load enum definitions") + } + + for _, pgEnum := range enums { + enum, err := PgEnumToType(pgEnum, cfg, autoGenKeyCfg) + if err != nil { + return src, errors.Wrap(err, "failed to convert enum definition to type") + } + + if customTmpl != "" { + tmpl, err := ioutil.ReadFile(customTmpl) + if err != nil { + return nil, err + } + s, err := PgExecuteCustomTmpl(enum, string(tmpl)) + if err != nil { + return nil, errors.Wrap(err, "PgExecuteCustomTmpl failed") + } + src = append(src, s...) + } else { + s, err := PgExecuteDefaultTmpl(enum, "template/enum.tmpl") + if err != nil { + return src, errors.Wrap(err, "failed to execute template") + } + src = append(src, s...) + } + } + return src, nil +} + // PgCreateStruct creates struct from given schema func PgCreateStruct( - db Queryer, schema, typeMapPath, pkgName, customTmpl string, exTbls []string) ([]byte, error) { + db Queryer, schema string, cfg PgTypeMapConfig, pkgName, customTmpl string, exTbls []string) ([]byte, error) { var src []byte pkgDef := []byte(fmt.Sprintf("package %s\n\n", pkgName)) src = append(src, pkgDef...) tbls, err := PgLoadTableDef(db, schema) if err != nil { - return src, errors.Wrap(err, "faield to load table definitions") - } - cfg := &PgTypeMapConfig{} - if typeMapPath == "" { - if _, err := toml.Decode(typeMap, cfg); err != nil { - return src, errors.Wrap(err, "faield to read type map") - } - } else { - if _, err := toml.DecodeFile(typeMapPath, cfg); err != nil { - return src, errors.Wrap(err, fmt.Sprintf("failed to decode type map file %s", typeMapPath)) - } + return src, errors.Wrap(err, "failed to load table definitions") } + for _, tbl := range tbls { if contains(tbl.Name, exTbls) { continue } st, err := PgTableToStruct(tbl, cfg, autoGenKeyCfg) if err != nil { - return src, errors.Wrap(err, "faield to convert table definition to struct") + return src, errors.Wrap(err, "failed to convert table definition to struct") } if customTmpl != "" { tmpl, err := ioutil.ReadFile(customTmpl) @@ -382,11 +542,11 @@ func PgCreateStruct( } else { s, err := PgExecuteDefaultTmpl(&StructTmpl{Struct: st}, "template/struct.tmpl") if err != nil { - return src, errors.Wrap(err, "faield to execute template") + return src, errors.Wrap(err, "failed to execute template") } m, err := PgExecuteDefaultTmpl(&StructTmpl{Struct: st}, "template/method.tmpl") if err != nil { - return src, errors.Wrap(err, "faield to execute template") + return src, errors.Wrap(err, "failed to execute template") } src = append(src, s...) src = append(src, m...) diff --git a/dgw_test.go b/dgw_test.go index 160f7ec..e795842 100644 --- a/dgw_test.go +++ b/dgw_test.go @@ -40,7 +40,7 @@ func testSetupStruct(t *testing.T, conn *sql.DB) []*Struct { var sts []*Struct for _, tbl := range tbls { - st, err := PgTableToStruct(tbl, &defaultTypeMapCfg, autoGenKeyCfg) + st, err := PgTableToStruct(tbl, defaultTypeMapCfg, autoGenKeyCfg) if err != nil { t.Fatal(err) } @@ -90,7 +90,7 @@ func TestPgColToField(t *testing.T) { } for _, c := range cols { - f, err := PgColToField(c, &defaultTypeMapCfg) + f, err := PgColToField(c, defaultTypeMapCfg) if err != nil { t.Fatal(err) } @@ -120,7 +120,7 @@ func TestPgTableToStruct(t *testing.T) { } for _, tbl := range tbls { - st, err := PgTableToStruct(tbl, &defaultTypeMapCfg, autoGenKeyCfg) + st, err := PgTableToStruct(tbl, defaultTypeMapCfg, autoGenKeyCfg) if err != nil { t.Fatal(err) } @@ -142,7 +142,7 @@ func TestPgTableToMethod(t *testing.T) { t.Fatal(err) } for _, tbl := range tbls { - st, err := PgTableToStruct(tbl, &defaultTypeMapCfg, autoGenKeyCfg) + st, err := PgTableToStruct(tbl, defaultTypeMapCfg, autoGenKeyCfg) if err != nil { t.Fatal(err) } @@ -173,7 +173,7 @@ func TestPgExecuteCustomTemplate(t *testing.T) { t.Fatal(err) } for _, tbl := range tbls { - st, err := PgTableToStruct(tbl, &defaultTypeMapCfg, autoGenKeyCfg) + st, err := PgTableToStruct(tbl, defaultTypeMapCfg, autoGenKeyCfg) if err != nil { t.Fatal(err) } diff --git a/main.go b/main.go index 452034c..ed4fcbf 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ var ( typeMapFilePath = kingpin.Flag("typemap", "column type and go type map file path").Short('t').String() exTbls = kingpin.Flag("exclude", "table names to exclude").Short('x').Strings() customTmpl = kingpin.Flag("template", "custom template path").String() + customEnumTmpl = kingpin.Flag("enum-template", "custom enum template").String() outFile = kingpin.Flag("output", "output file path").Short('o').String() noQueryInterface = kingpin.Flag("no-interface", "output without Queryer interface").Bool() ) @@ -30,11 +31,23 @@ func main() { log.Fatal(err) } - st, err := PgCreateStruct(conn, *schema, *typeMapFilePath, *pkgName, *customTmpl, *exTbls) + cfg, err := getPgTypeMapConfig(*typeMapFilePath) if err != nil { log.Fatal(err) } + en, err := PgCreateEnums(conn, *schema, cfg, *customEnumTmpl) + if err != nil { + log.Fatal(err) + } + + st, err := PgCreateStruct(conn, *schema, cfg, *pkgName, *customTmpl, *exTbls) + if err != nil { + log.Fatal(err) + } + + st = append(st, en...) + var src []byte if *noQueryInterface { src = st diff --git a/template/enum.tmpl b/template/enum.tmpl new file mode 100644 index 0000000..a27e0f6 --- /dev/null +++ b/template/enum.tmpl @@ -0,0 +1,24 @@ +type {{.Name}} string + +type Null{{.Name}} struct { + Value {{.Name}} + Valid bool +} + +const ( + {{range .Values}} + {{.Name}} ({{.Type.Name}}) = "{{.Value}}"{{end}} +) + +var {{.Name}}Values = []{{.Name}} { + {{range .Values}} {{.Name}}, {{end}} +} + +func Valid{{.Name}}(s string) bool { + for _,v := range {{.Name}}Values { + if string(v) == s { + return true + } + } + return false +}