diff --git a/cli/cmd/cmd.go b/cli/cmd/cmd.go index 506af8e..45f2bc6 100644 --- a/cli/cmd/cmd.go +++ b/cli/cmd/cmd.go @@ -5,6 +5,7 @@ import ( "os" "strings" + "github.com/gocomply/xsd2go/pkg/xsd" "github.com/gocomply/xsd2go/pkg/xsd2go" "github.com/urfave/cli" ) @@ -37,10 +38,28 @@ var convert = cli.Command{ 1) } } + + for _, override := range c.StringSlice("type-override") { + if !strings.Contains(override, "=") { + return cli.NewExitError( + fmt.Sprintf( + "Invalid type-override: '%s', expecting form of TYPE=GOTYPE or TYPE=GOTYPE:GOIMPORT", + override, + ), + 1, + ) + } + } + return nil }, Action: func(c *cli.Context) error { xsdFile, goModule, outputDir := c.Args()[0], c.Args()[1], c.Args()[2] + + for _, typeOverride := range c.StringSlice("type-override") { + xsd.AddStaticTypeOverride(typeOverride) + } + err := xsd2go.Convert(xsdFile, goModule, outputDir, c.StringSlice("xmlns-override")) if err != nil { return cli.NewExitError(err, 1) @@ -52,5 +71,9 @@ var convert = cli.Command{ Name: "xmlns-override", Usage: "Allows to explicitly set gopackage name for given XMLNS. Example: --xmlns-override='http://www.w3.org/2000/09/xmldsig#=xml_signatures'", }, + cli.StringSliceFlag{ + Name: "type-override", + Usage: "Allows to explicitly override a static simple type mapping. Example: --type-override='decimal=string' or --type-override='decimal=decimal:github.com/ericlagergren/decimal", + }, }, } diff --git a/pkg/xsd/schema.go b/pkg/xsd/schema.go index 4799aac..4b6e1c3 100644 --- a/pkg/xsd/schema.go +++ b/pkg/xsd/schema.go @@ -84,7 +84,6 @@ func (sch *Schema) findReferencedElement(ref reference) *Element { } if innerSchema != sch { sch.registerImportedModule(innerSchema) - } return innerSchema.GetElement(ref.Name()) } @@ -257,6 +256,9 @@ func (sch *Schema) GoImportsNeeded() []string { for _, importedMod := range sch.importedModules { imports = append(imports, fmt.Sprintf("%s/%s", sch.ModulesPath, importedMod.GoPackageName())) } + for _, importedMod := range GetStaticTypeImports() { + imports = append(imports, importedMod) + } sort.Strings(imports) return imports } @@ -297,8 +299,7 @@ type Import struct { func (i *Import) load(ws *Workspace, baseDir string) (err error) { if i.SchemaLocation != "" { - i.ImportedSchema, err = - ws.loadXsd(filepath.Join(baseDir, i.SchemaLocation), true) + i.ImportedSchema, err = ws.loadXsd(filepath.Join(baseDir, i.SchemaLocation), true) } return } @@ -312,8 +313,7 @@ type Include struct { func (i *Include) load(ws *Workspace, baseDir string) (err error) { if i.SchemaLocation != "" { - i.IncludedSchema, err = - ws.loadXsd(filepath.Join(baseDir, i.SchemaLocation), false) + i.IncludedSchema, err = ws.loadXsd(filepath.Join(baseDir, i.SchemaLocation), false) } return } diff --git a/pkg/xsd/types.go b/pkg/xsd/types.go index 59a78b2..1f55d6d 100644 --- a/pkg/xsd/types.go +++ b/pkg/xsd/types.go @@ -2,6 +2,8 @@ package xsd import ( "encoding/xml" + "sort" + "strings" "github.com/iancoleman/strcase" ) @@ -287,9 +289,43 @@ var staticTypes = map[string]staticType{ "hexBinary": "string", } +var ( + staticTypeImports = map[string]string{} + staticTypeUsed = map[string]struct{}{} +) + +func AddStaticTypeOverride(override string) { + parts := strings.SplitN(override, "=", 2) + typeParts := strings.SplitN(parts[1], ":", 2) + + typeName := parts[0] + + staticTypes[typeName] = staticType(typeParts[0]) + + if len(typeParts) == 2 { + staticTypeImports[typeName] = typeParts[1] + } +} + +func GetStaticTypeImports() []string { + imports := []string{} + + for name, mod := range staticTypeImports { + if _, found := staticTypeUsed[name]; found { + imports = append(imports, mod) + } + } + + sort.Strings(imports) + + return imports +} + func StaticType(name string) staticType { typ, found := staticTypes[name] if found { + staticTypeUsed[name] = struct{}{} + return typ } panic("Type xsd:" + name + " not implemented")