From 45cb131d22dcb5133a67e483a9a018336c4e1f3b Mon Sep 17 00:00:00 2001 From: shawn Date: Tue, 8 Oct 2024 21:34:45 +0800 Subject: [PATCH] feat: Improving the usability of templates with subcommands --- tool/cmd/kitex/args/args.go | 4 - tool/cmd/kitex/args/tpl_args.go | 491 ++++++++++ tool/cmd/kitex/main.go | 11 +- .../internal_pkg/generator/custom_template.go | 232 ++++- tool/internal_pkg/generator/generator.go | 9 + tool/internal_pkg/generator/generator_test.go | 12 +- tool/internal_pkg/generator/type.go | 9 + .../pluginmode/thriftgo/plugin.go | 24 + tool/internal_pkg/util/command.go | 244 +++++ tool/internal_pkg/util/command_test.go | 245 +++++ tool/internal_pkg/util/flag.go | 889 ++++++++++++++++++ tool/internal_pkg/util/flag_test.go | 137 +++ tool/internal_pkg/util/util.go | 2 +- 13 files changed, 2300 insertions(+), 9 deletions(-) create mode 100644 tool/cmd/kitex/args/tpl_args.go create mode 100644 tool/internal_pkg/util/command.go create mode 100644 tool/internal_pkg/util/command_test.go create mode 100644 tool/internal_pkg/util/flag.go create mode 100644 tool/internal_pkg/util/flag_test.go diff --git a/tool/cmd/kitex/args/args.go b/tool/cmd/kitex/args/args.go index f7f627225f..fd946d2a6b 100644 --- a/tool/cmd/kitex/args/args.go +++ b/tool/cmd/kitex/args/args.go @@ -314,10 +314,6 @@ func (a *Arguments) BuildCmd(out io.Writer) (*exec.Cmd, error) { Stderr: io.MultiWriter(out, os.Stderr), } - if err != nil { - return nil, err - } - if a.IDLType == "thrift" { os.Setenv(EnvPluginMode, thriftgo.PluginName) cmd.Args = append(cmd.Args, "thriftgo") diff --git a/tool/cmd/kitex/args/tpl_args.go b/tool/cmd/kitex/args/tpl_args.go new file mode 100644 index 0000000000..610a5c4fb3 --- /dev/null +++ b/tool/cmd/kitex/args/tpl_args.go @@ -0,0 +1,491 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package args + +import ( + "fmt" + "github.com/cloudwego/kitex/tool/internal_pkg/tpl" + "io/fs" + "os" + "path/filepath" + "strings" + + "gopkg.in/yaml.v3" + + "github.com/cloudwego/kitex/tool/internal_pkg/generator" + "github.com/cloudwego/kitex/tool/internal_pkg/log" + "github.com/cloudwego/kitex/tool/internal_pkg/util" +) + +// Constants . +const ( + KitexGenPath = "kitex_gen" + DefaultCodec = "thrift" + + BuildFileName = "build.sh" + BootstrapFileName = "bootstrap.sh" + ToolVersionFileName = "kitex_info.yaml" + HandlerFileName = "handler.go" + MainFileName = "main.go" + ClientFileName = "client.go" + ServerFileName = "server.go" + InvokerFileName = "invoker.go" + ServiceFileName = "*service.go" + ExtensionFilename = "extensions.yaml" + + MultipleServicesFileName = "multiple_services.go" +) + +var defaultTemplates = map[string]string{ + BuildFileName: tpl.BuildTpl, + BootstrapFileName: tpl.BootstrapTpl, + ToolVersionFileName: tpl.ToolVersionTpl, + HandlerFileName: tpl.HandlerTpl, + MainFileName: tpl.MainTpl, + ClientFileName: tpl.ClientTpl, + ServerFileName: tpl.ServerTpl, + InvokerFileName: tpl.InvokerTpl, + ServiceFileName: tpl.ServiceTpl, +} + +var multipleServicesTpl = map[string]string{ + MultipleServicesFileName: tpl.MainMultipleServicesTpl, +} + +const ( + DefaultType = "default" + MultipleServicesType = "multiple_services" +) + +type TemplateGenerator func(string) error + +var genTplMap = map[string]TemplateGenerator{ + DefaultType: GenTemplates, + MultipleServicesType: GenMultipleServicesTemplates, +} + +// GenTemplates is the entry for command kitex template, +// it will create the specified path +func GenTemplates(path string) error { + return InitTemplates(path, defaultTemplates) +} + +func GenMultipleServicesTemplates(path string) error { + return InitTemplates(path, multipleServicesTpl) +} + +// InitTemplates creates template files. +func InitTemplates(path string, templates map[string]string) error { + if err := MkdirIfNotExist(path); err != nil { + return err + } + + for name, content := range templates { + var dir string + if name == BootstrapFileName { + dir = filepath.Join(path, "script") + } else { + dir = path + } + if err := MkdirIfNotExist(dir); err != nil { + return err + } + filePath := filepath.Join(dir, fmt.Sprintf("%s.tpl", name)) + if err := createTemplate(filePath, content); err != nil { + return err + } + } + + return nil +} + +// GetTemplateDir returns the category path. +func GetTemplateDir(category string) (string, error) { + home, err := filepath.Abs(".") + if err != nil { + return "", err + } + return filepath.Join(home, category), nil +} + +// MkdirIfNotExist makes directories if the input path is not exists +func MkdirIfNotExist(dir string) error { + if len(dir) == 0 { + return nil + } + + if _, err := os.Stat(dir); os.IsNotExist(err) { + return os.MkdirAll(dir, os.ModePerm) + } + + return nil +} + +func createTemplate(file, content string) error { + if util.Exists(file) { + return nil + } + + f, err := os.Create(file) + if err != nil { + return err + } + defer f.Close() + + _, err = f.WriteString(content) + return err +} + +func (a *Arguments) Init(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + var path string + if InitOutputDir == "" { + path = curpath + } else { + path = InitOutputDir + } + if len(InitTypes) == 0 { + typ := DefaultType + if err := genTplMap[typ](path); err != nil { + return err + } + } else { + for _, typ := range InitTypes { + if _, ok := genTplMap[typ]; !ok { + return fmt.Errorf("invalid type: %s", typ) + } + if err := genTplMap[typ](path); err != nil { + return err + } + } + } + os.Exit(0) + return nil +} + +func (a *Arguments) checkTplArgs() error { + if a.TemplateDir != "" && a.RenderTplDir != "" { + return fmt.Errorf("template render --dir and -template-dir cannot be used at the same time") + } + if a.RenderTplDir != "" && len(a.TemplateFiles) > 0 { + return fmt.Errorf("template render --dir and --file option cannot be specified at the same time") + } + return nil +} + +func (a *Arguments) Root(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(args) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + err = a.checkTplArgs() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + +func (a *Arguments) Template(cmd *util.Command, args []string) error { + if len(args) == 0 { + return util.ErrHelp + } + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(args) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + err = a.checkTplArgs() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + +func parseYAMLFiles(directory string) ([]generator.Template, error) { + var templates []generator.Template + files, err := os.ReadDir(directory) + if err != nil { + return nil, err + } + for _, file := range files { + if filepath.Ext(file.Name()) == ".yaml" { + data, err := os.ReadFile(filepath.Join(directory, file.Name())) + if err != nil { + return nil, err + } + var template generator.Template + err = yaml.Unmarshal(data, &template) + if err != nil { + return nil, err + } + templates = append(templates, template) + } + } + return templates, nil +} + +func createFilesFromTemplates(templates []generator.Template, baseDirectory string) error { + for _, template := range templates { + fullPath := filepath.Join(baseDirectory, fmt.Sprintf("%s.tpl", template.Path)) + dir := filepath.Dir(fullPath) + err := os.MkdirAll(dir, os.ModePerm) + if err != nil { + return err + } + err = os.WriteFile(fullPath, []byte(template.Body), 0o644) + if err != nil { + return err + } + } + return nil +} + +func generateMetadata(templates []generator.Template, outputFile string) error { + var metadata generator.Meta + for _, template := range templates { + meta := generator.Template{ + Path: template.Path, + UpdateBehavior: template.UpdateBehavior, + LoopMethod: template.LoopMethod, + LoopService: template.LoopService, + } + metadata.Templates = append(metadata.Templates, meta) + } + data, err := yaml.Marshal(&metadata) + if err != nil { + return err + } + return os.WriteFile(outputFile, data, 0o644) +} + +func (a *Arguments) Render(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + if len(args) == 0 { + return util.ErrHelp + } + log.Verbose = a.Verbose + + for _, e := range a.extends { + err := e.Check(a) + if err != nil { + return err + } + } + + err = a.checkIDL(args) + if err != nil { + return err + } + err = a.checkServiceName() + if err != nil { + return err + } + err = a.checkTplArgs() + if err != nil { + return err + } + // todo finish protobuf + if a.IDLType != "thrift" { + a.GenPath = generator.KitexGenPath + } + return a.checkPath(curpath) +} + +func (a *Arguments) Clean(cmd *util.Command, args []string) error { + curpath, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("get current path failed: %s", err.Error()) + } + + magicString := "// Kitex template debug file. use template clean to delete it." + err = filepath.WalkDir(curpath, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() { + return nil + } + content, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("read file %s failed: %v", path, err) + } + if strings.Contains(string(content), magicString) { + if err := os.Remove(path); err != nil { + return fmt.Errorf("delete file %s failed: %v", path, err) + } + } + return nil + }) + if err != nil { + return fmt.Errorf("error cleaning debug template files: %v", err) + } + fmt.Println("clean debug template files successfully...") + os.Exit(0) + return nil +} + +var ( + InitOutputDir string // specify the location path of init subcommand + InitTypes []string // specify the type for init subcommand +) + +func (a *Arguments) TemplateArgs(version string) error { + kitexCmd := &util.Command{ + Use: "kitex", + Short: "Kitex command", + RunE: a.Root, + } + templateCmd := &util.Command{ + Use: "template", + Short: "Template command", + RunE: a.Template, + } + initCmd := &util.Command{ + Use: "init", + Short: "Init command", + RunE: a.Init, + } + renderCmd := &util.Command{ + Use: "render", + Short: "Render command", + RunE: a.Render, + } + cleanCmd := &util.Command{ + Use: "clean", + Short: "Clean command", + RunE: a.Clean, + } + kitexCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") + templateCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") + initCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") + initCmd.Flags().StringVarP(&InitOutputDir, "output", "o", ".", "Specify template init path (default current directory)") + initCmd.Flags().StringArrayVar(&InitTypes, "type", []string{}, "Specify template init type") + renderCmd.Flags().StringVar(&a.RenderTplDir, "dir", "", "Use custom template to generate codes.") + renderCmd.Flags().StringVar(&a.ModuleName, "module", "", + "Specify the Go module name to generate go.mod.") + renderCmd.Flags().StringVar(&a.IDLType, "type", "unknown", "Specify the type of IDL: 'thrift' or 'protobuf'.") + renderCmd.Flags().StringVar(&a.GenPath, "gen-path", generator.KitexGenPath, + "Specify a code gen path.") + renderCmd.Flags().StringArrayVar(&a.TemplateFiles, "file", []string{}, "Specify single template path") + renderCmd.Flags().BoolVar(&a.DebugTpl, "debug", false, "turn on debug for template") + renderCmd.Flags().StringVarP(&a.IncludesTpl, "Includes", "I", "", "Add IDL search path and template search path for includes.") + renderCmd.Flags().StringVar(&a.MetaFlags, "meta", "", "Meta data in key=value format, keys separated by ';' values separated by ',' ") + templateCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Template operation + +Usage: + kitex template [command] + +Available Commands: + init Initialize the templates according to the type + render Render the template files + clean Clean the debug templates + `) + }) + initCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Initialize the templates according to the type + +Usage: + kitex template init [flags] + +Flags: + -o, --output string Output directory + -t, --type string The init type of the template + `) + }) + renderCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Render the template files + +Usage: + kitex template render [flags] + +Flags: + --dir string Output directory + --debug bool Turn on the debug mode + --file stringArray Specify multiple files for render + -I, --Includes string Add an template git search path for includes. + --meta string Specify meta data for render + --module string Specify the Go module name to generate go.mod. + -t, --type string The init type of the template + `) + }) + cleanCmd.SetHelpFunc(func(*util.Command, []string) { + fmt.Fprintln(os.Stderr, ` +Clean the debug templates + +Usage: + kitex template clean + `) + }) + templateCmd.AddCommand(initCmd, renderCmd, cleanCmd) + kitexCmd.AddCommand(templateCmd) + if _, err := kitexCmd.ExecuteC(); err != nil { + return err + } + return nil +} diff --git a/tool/cmd/kitex/main.go b/tool/cmd/kitex/main.go index ef6aa53b8e..e2fd8f26ee 100644 --- a/tool/cmd/kitex/main.go +++ b/tool/cmd/kitex/main.go @@ -17,6 +17,7 @@ package main import ( "bytes" "flag" + "fmt" "os" "path/filepath" "strings" @@ -78,8 +79,14 @@ func main() { log.Warn("Get current path failed:", err.Error()) os.Exit(1) } - // run as kitex - err = args.ParseArgs(kitex.Version, curpath, os.Args[1:]) + if len(os.Args) > 1 && os.Args[1] == "template" { + err = args.TemplateArgs(kitex.Version) + } else if len(os.Args) > 1 && !strings.HasPrefix(os.Args[1], "-") { + err = fmt.Errorf("unknown command %q", os.Args[1]) + } else { + // run as kitex + err = args.ParseArgs(kitex.Version, curpath, os.Args[1:]) + } if err != nil { if err.Error() != "flag: help requested" { log.Warn(err.Error()) diff --git a/tool/internal_pkg/generator/custom_template.go b/tool/internal_pkg/generator/custom_template.go index 3f364589db..b76b71eab5 100644 --- a/tool/internal_pkg/generator/custom_template.go +++ b/tool/internal_pkg/generator/custom_template.go @@ -15,6 +15,7 @@ package generator import ( + "errors" "fmt" "os" "path" @@ -206,7 +207,7 @@ func renderFile(pkg *PackageInfo, outputPath string, tpl *Template) (fs []*File, } else { err = cg.commonGenerate(tpl) } - if err == errNoNewMethod { + if errors.Is(err, errNoNewMethod) { err = nil } return cg.fs, err @@ -235,3 +236,232 @@ func readTemplates(dir string) ([]*Template, error) { return ts, nil } + +// parseMeta parses the meta flag and returns a map where the value is a slice of strings +func parseMeta(metaFlags string) (map[string][]string, error) { + meta := make(map[string][]string) + if metaFlags == "" { + return meta, nil + } + + // split for each key=value pairs + pairs := strings.Split(metaFlags, ";") + for _, pair := range pairs { + kv := strings.SplitN(pair, "=", 2) + if len(kv) == 2 { + key := kv[0] + values := strings.Split(kv[1], ",") + meta[key] = values + } else { + return nil, fmt.Errorf("invalid meta format: %s", pair) + } + } + return meta, nil +} + +func parseMiddlewares(middlewares []MiddlewareForResolve) ([]UserDefinedMiddleware, error) { + var mwList []UserDefinedMiddleware + + for _, mw := range middlewares { + content, err := os.ReadFile(mw.Path) + if err != nil { + return nil, fmt.Errorf("failed to read middleware file %s: %v", mw.Path, err) + } + mwList = append(mwList, UserDefinedMiddleware{ + Name: mw.Name, + Content: string(content), + }) + } + return mwList, nil +} + +func (g *generator) GenerateCustomPackageWithTpl(pkg *PackageInfo) (fs []*File, err error) { + g.updatePackageInfo(pkg) + + g.setImports(HandlerFileName, pkg) + pkg.ExtendMeta, err = parseMeta(g.MetaFlags) + if err != nil { + return nil, err + } + if g.Config.IncludesTpl != "" { + inc := g.Config.IncludesTpl + if strings.HasPrefix(inc, "git@") || strings.HasPrefix(inc, "http://") || strings.HasPrefix(inc, "https://") { + localGitPath, errMsg, gitErr := util.RunGitCommand(inc) + if gitErr != nil { + if errMsg == "" { + errMsg = gitErr.Error() + } + return nil, fmt.Errorf("failed to pull IDL from git:%s\nYou can execute 'rm -rf ~/.kitex' to clean the git cache and try again", errMsg) + } + if g.RenderTplDir != "" { + g.RenderTplDir = filepath.Join(localGitPath, g.RenderTplDir) + } else { + g.RenderTplDir = localGitPath + } + if util.Exists(g.RenderTplDir) { + return nil, fmt.Errorf("the render template directory path you specified does not exists int the git path") + } + } + } + var meta *Meta + metaPath := filepath.Join(g.RenderTplDir, KitexRenderMetaFile) + if util.Exists(metaPath) { + meta, err = readMetaFile(metaPath) + if err != nil { + return nil, err + } + middlewares, err := parseMiddlewares(meta.MWs) + if err != nil { + return nil, err + } + pkg.MWs = middlewares + } + tpls, err := readTpls(g.RenderTplDir, g.RenderTplDir, meta) + if err != nil { + return nil, err + } + for _, tpl := range tpls { + newPath := filepath.Join(g.OutputPath, tpl.Path) + dir := filepath.Dir(newPath) + if err := os.MkdirAll(dir, os.ModePerm); err != nil { + return nil, fmt.Errorf("failed to create directory %s: %v", dir, err) + } + if tpl.LoopService && g.CombineService { + svrInfo, cs := pkg.ServiceInfo, pkg.CombineServices + + for i := range cs { + pkg.ServiceInfo = cs[i] + f, err := renderFile(pkg, g.OutputPath, tpl) + if err != nil { + return nil, err + } + fs = append(fs, f...) + } + pkg.ServiceInfo, pkg.CombineServices = svrInfo, cs + } else { + f, err := renderFile(pkg, g.OutputPath, tpl) + if err != nil { + return nil, err + } + fs = append(fs, f...) + } + } + return fs, nil +} + +const KitexRenderMetaFile = "kitex_render_meta.yaml" + +// Meta represents the structure of the kitex_render_meta.yaml file. +type Meta struct { + Templates []Template `yaml:"templates"` + MWs []MiddlewareForResolve `yaml:"middlewares"` +} + +type MiddlewareForResolve struct { + // name of the middleware + Name string `yaml:"name"` + // path of the middleware + Path string `yaml:"path"` +} + +func readMetaFile(metaPath string) (*Meta, error) { + metaData, err := os.ReadFile(metaPath) + if err != nil { + return nil, fmt.Errorf("failed to read meta file from %s: %v", metaPath, err) + } + + var meta Meta + err = yaml.Unmarshal(metaData, &meta) + if err != nil { + return nil, fmt.Errorf("failed to parse yaml file %s: %v", metaPath, err) + } + + return &meta, nil +} + +func getMetadata(meta *Meta, relativePath string) *Template { + for i := range meta.Templates { + if meta.Templates[i].Path == relativePath { + return &meta.Templates[i] + } + } + return &Template{ + UpdateBehavior: &Update{Type: string(skip)}, + } +} + +func readTpls(rootDir, currentDir string, meta *Meta) (ts []*Template, error error) { + defaultMetadata := &Template{ + UpdateBehavior: &Update{Type: string(skip)}, + } + + files, _ := os.ReadDir(currentDir) + for _, f := range files { + // filter dir and non-tpl files + if f.IsDir() { + subDir := filepath.Join(currentDir, f.Name()) + subTemplates, err := readTpls(rootDir, subDir, meta) + if err != nil { + return nil, err + } + ts = append(ts, subTemplates...) + } else if strings.HasSuffix(f.Name(), ".tpl") { + p := filepath.Join(currentDir, f.Name()) + tplData, err := os.ReadFile(p) + if err != nil { + return nil, fmt.Errorf("read file from %s failed, err: %v", p, err.Error()) + } + // Remove the .tpl suffix from the Path and compute relative path + relativePath, err := filepath.Rel(rootDir, p) + if err != nil { + return nil, fmt.Errorf("failed to compute relative path for %s: %v", p, err) + } + trimmedPath := strings.TrimSuffix(relativePath, ".tpl") + // If kitex_render_meta.yaml exists, get the corresponding metadata; otherwise, use the default metadata + var metadata *Template + if meta != nil { + metadata = getMetadata(meta, relativePath) + } else { + metadata = defaultMetadata + } + t := &Template{ + Path: trimmedPath, + Body: string(tplData), + UpdateBehavior: metadata.UpdateBehavior, + LoopMethod: metadata.LoopMethod, + LoopService: metadata.LoopService, + } + ts = append(ts, t) + } + } + + return ts, nil +} + +func (g *generator) RenderWithMultipleFiles(pkg *PackageInfo) (fs []*File, err error) { + for _, file := range g.Config.TemplateFiles { + content, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("read file from %s failed, err: %v", file, err.Error()) + } + var updatedContent string + if g.Config.DebugTpl { + // when --debug is enabled, add a magic string at the top of the template content for distinction. + updatedContent = "// Kitex template debug file. use template clean to delete it.\n\n" + string(content) + } else { + updatedContent = string(content) + } + filename := filepath.Base(strings.TrimSuffix(file, ".tpl")) + tpl := &Template{ + Path: filename, + Body: updatedContent, + UpdateBehavior: &Update{Type: string(skip)}, + } + f, err := renderFile(pkg, g.OutputPath, tpl) + if err != nil { + return nil, err + } + fs = append(fs, f...) + } + return +} diff --git a/tool/internal_pkg/generator/generator.go b/tool/internal_pkg/generator/generator.go index 8291ba123a..b28da9a938 100644 --- a/tool/internal_pkg/generator/generator.go +++ b/tool/internal_pkg/generator/generator.go @@ -99,6 +99,8 @@ type Generator interface { GenerateService(pkg *PackageInfo) ([]*File, error) GenerateMainPackage(pkg *PackageInfo) ([]*File, error) GenerateCustomPackage(pkg *PackageInfo) ([]*File, error) + GenerateCustomPackageWithTpl(pkg *PackageInfo) ([]*File, error) + RenderWithMultipleFiles(pkg *PackageInfo) ([]*File, error) } // Config . @@ -136,6 +138,13 @@ type Config struct { TemplateDir string + // subcommand template + RenderTplDir string // specify the path of template directory for render subcommand + TemplateFiles []string // specify the path of single file or multiple file to render + DebugTpl bool // turn on the debug mode + IncludesTpl string // specify the path of remote template repository for render subcommand + MetaFlags string // Metadata in key=value format, keys separated by ';' values separated by ',' + GenPath string DeepCopyAPI bool diff --git a/tool/internal_pkg/generator/generator_test.go b/tool/internal_pkg/generator/generator_test.go index 4c208583a4..72e120e15b 100644 --- a/tool/internal_pkg/generator/generator_test.go +++ b/tool/internal_pkg/generator/generator_test.go @@ -56,6 +56,11 @@ func TestConfig_Pack(t *testing.T) { RecordCmd string ThriftPluginTimeLimit time.Duration TemplateDir string + RenderTplDir string + TemplateFiles []string + DebugTpl bool + IncludesTpl string + MetaFlags string Protocol string HandlerReturnKeepResp bool } @@ -69,7 +74,7 @@ func TestConfig_Pack(t *testing.T) { { name: "some", fields: fields{Features: []feature{feature(999)}, ThriftPluginTimeLimit: 30 * time.Second}, - wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false", "Rapid=false", "BuiltinTpl="}, + wantRes: []string{"Verbose=false", "GenerateMain=false", "GenerateInvoker=false", "Version=", "NoFastAPI=false", "ModuleName=", "ServiceName=", "Use=", "IDLType=", "Includes=", "ThriftOptions=", "ProtobufOptions=", "Hessian2Options=", "IDL=", "OutputPath=", "PackagePrefix=", "CombineService=false", "CopyIDL=false", "ProtobufPlugins=", "Features=999", "FrugalPretouch=false", "ThriftPluginTimeLimit=30s", "CompilerPath=", "ExtensionFile=", "Record=false", "RecordCmd=", "TemplateDir=", "RenderTplDir=", "TemplateFiles=", "DebugTpl=false", "IncludesTpl=", "MetaFlags=", "GenPath=", "DeepCopyAPI=false", "Protocol=", "HandlerReturnKeepResp=false", "NoDependencyCheck=false", "Rapid=false", "BuiltinTpl="}, }, } for _, tt := range tests { @@ -97,6 +102,11 @@ func TestConfig_Pack(t *testing.T) { FrugalPretouch: tt.fields.FrugalPretouch, ThriftPluginTimeLimit: tt.fields.ThriftPluginTimeLimit, TemplateDir: tt.fields.TemplateDir, + RenderTplDir: tt.fields.RenderTplDir, + TemplateFiles: tt.fields.TemplateFiles, + DebugTpl: tt.fields.DebugTpl, + IncludesTpl: tt.fields.IncludesTpl, + MetaFlags: tt.fields.MetaFlags, Protocol: tt.fields.Protocol, } if gotRes := c.Pack(); !reflect.DeepEqual(gotRes, tt.wantRes) { diff --git a/tool/internal_pkg/generator/type.go b/tool/internal_pkg/generator/type.go index bdbb92ebf3..8b683062ff 100644 --- a/tool/internal_pkg/generator/type.go +++ b/tool/internal_pkg/generator/type.go @@ -52,6 +52,15 @@ type PackageInfo struct { Protocol transport.Protocol IDLName string ServerPkg string + ExtendMeta map[string][]string // key-value metadata for render + MWs []UserDefinedMiddleware +} + +type UserDefinedMiddleware struct { + // the name of the middleware + Name string + // the content of the middleware + Content string } // AddImport . diff --git a/tool/internal_pkg/pluginmode/thriftgo/plugin.go b/tool/internal_pkg/pluginmode/thriftgo/plugin.go index 13afd349bc..2721cc9b26 100644 --- a/tool/internal_pkg/pluginmode/thriftgo/plugin.go +++ b/tool/internal_pkg/pluginmode/thriftgo/plugin.go @@ -119,6 +119,30 @@ func HandleRequest(req *plugin.Request) *plugin.Response { files = append(files, fs...) } + if len(conv.Config.TemplateFiles) > 0 { + if len(conv.Services) == 0 { + return conv.failResp(errors.New("no service defined in the IDL")) + } + conv.Package.ServiceInfo = conv.Services[len(conv.Services)-1] + fs, err := gen.RenderWithMultipleFiles(&conv.Package) + if err != nil { + return conv.failResp(err) + } + files = append(files, fs...) + } + + if conv.Config.RenderTplDir != "" { + if len(conv.Services) == 0 { + return conv.failResp(errors.New("no service defined in the IDL")) + } + conv.Package.ServiceInfo = conv.Services[len(conv.Services)-1] + fs, err := gen.GenerateCustomPackageWithTpl(&conv.Package) + if err != nil { + return conv.failResp(err) + } + files = append(files, fs...) + } + res := &plugin.Response{ Warnings: conv.Warnings, } diff --git a/tool/internal_pkg/util/command.go b/tool/internal_pkg/util/command.go new file mode 100644 index 0000000000..ee1274af14 --- /dev/null +++ b/tool/internal_pkg/util/command.go @@ -0,0 +1,244 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "errors" + "fmt" + "io" + "os" + "strings" +) + +type Command struct { + Use string + Short string + Long string + RunE func(cmd *Command, args []string) error + commands []*Command + parent *Command + flags *FlagSet + // helpFunc is help func defined by user. + helpFunc func(*Command, []string) + // for debug + args []string +} + +// SetArgs sets arguments for the command. It is set to os.Args[1:] by default, if desired, can be overridden +// particularly useful when testing. +func (c *Command) SetArgs(a []string) { + c.args = a +} + +func (c *Command) AddCommand(cmds ...*Command) error { + for i, x := range cmds { + if cmds[i] == c { + return fmt.Errorf("command can't be a child of itself") + } + cmds[i].parent = c + c.commands = append(c.commands, x) + } + return nil +} + +// Flags returns the FlagSet of the Command +func (c *Command) Flags() *FlagSet { + if c.flags == nil { + c.flags = NewFlagSet(c.Use, ContinueOnError) + } + return c.flags +} + +// HasParent determines if the command is a child command. +func (c *Command) HasParent() bool { + return c.parent != nil +} + +// HasSubCommands determines if the command has children commands. +func (c *Command) HasSubCommands() bool { + return len(c.commands) > 0 +} + +func stripFlags(args []string) []string { + commands := make([]string, 0) + for len(args) > 0 { + s := args[0] + args = args[1:] + if strings.HasPrefix(s, "-") { + // handle "-f child child" args + if len(args) <= 1 { + break + } else { + args = args[1:] + continue + } + } else if s != "" && !strings.HasPrefix(s, "-") { + commands = append(commands, s) + } + } + return commands +} + +func (c *Command) findNext(next string) *Command { + for _, cmd := range c.commands { + if cmd.Use == next { + return cmd + } + } + return nil +} + +func nextArgs(args []string, x string) []string { + if len(args) == 0 { + return args + } + for pos := 0; pos < len(args); pos++ { + s := args[pos] + switch { + case strings.HasPrefix(s, "-"): + pos++ + continue + case !strings.HasPrefix(s, "-"): + if s == x { + // cannot use var ret []string cause it return nil + ret := make([]string, 0) + ret = append(ret, args[:pos]...) + ret = append(ret, args[pos+1:]...) + return ret + } + } + } + return args +} + +func validateArgs(cmd *Command, args []string) error { + // no subcommand, always take args + if !cmd.HasSubCommands() { + return nil + } + + // root command with subcommands, do subcommand checking. + if !cmd.HasParent() && len(args) > 0 { + return fmt.Errorf("unknown command %q", args[0]) + } + return nil +} + +// Find the target command given the args and command tree +func (c *Command) Find(args []string) (*Command, []string, error) { + var innerFind func(*Command, []string) (*Command, []string) + + innerFind = func(c *Command, innerArgs []string) (*Command, []string) { + argsWithoutFlags := stripFlags(innerArgs) + if len(argsWithoutFlags) == 0 { + return c, innerArgs + } + nextSubCmd := argsWithoutFlags[0] + + cmd := c.findNext(nextSubCmd) + if cmd != nil { + return innerFind(cmd, nextArgs(innerArgs, nextSubCmd)) + } + return c, innerArgs + } + commandFound, a := innerFind(c, args) + return commandFound, a, validateArgs(commandFound, stripFlags(a)) +} + +// ParseFlags parses persistent flag tree and local flags. +func (c *Command) ParseFlags(args []string) error { + err := c.Flags().Parse(args) + return err +} + +// SetHelpFunc sets help function. Can be defined by Application. +func (c *Command) SetHelpFunc(f func(*Command, []string)) { + c.helpFunc = f +} + +// HelpFunc returns either the function set by SetHelpFunc for this command +// or a parent, or it returns a function with default help behavior. +func (c *Command) HelpFunc() func(*Command, []string) { + if c.helpFunc != nil { + return c.helpFunc + } + if c.HasParent() { + return c.parent.HelpFunc() + } + return nil +} + +// PrintErrln is a convenience method to Println to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErrln(i ...interface{}) { + c.PrintErr(fmt.Sprintln(i...)) +} + +// PrintErr is a convenience method to Print to the defined Err output, fallback to Stderr if not set. +func (c *Command) PrintErr(i ...interface{}) { + fmt.Fprint(c.ErrOrStderr(), i...) +} + +// ErrOrStderr returns output to stderr +func (c *Command) ErrOrStderr() io.Writer { + return c.getErr(os.Stderr) +} + +func (c *Command) getErr(def io.Writer) io.Writer { + if c.HasParent() { + return c.parent.getErr(def) + } + return def +} + +// ExecuteC executes the command. +func (c *Command) ExecuteC() (cmd *Command, err error) { + args := c.args + if c.args == nil { + args = os.Args[1:] + } + cmd, flags, err := c.Find(args) + if err != nil { + return c, err + } + err = cmd.execute(flags) + if err != nil { + // Always show help if requested, even if SilenceErrors is in + // effect + if errors.Is(err, ErrHelp) { + cmd.HelpFunc()(cmd, args) + return cmd, err + } + } + + return cmd, err +} + +func (c *Command) execute(a []string) error { + if c == nil { + return fmt.Errorf("called Execute() on a nil Command") + } + err := c.ParseFlags(a) + if err != nil { + return err + } + argWoFlags := c.Flags().Args() + if c.RunE != nil { + err := c.RunE(c, argWoFlags) + if err != nil { + return err + } + } + return nil +} diff --git a/tool/internal_pkg/util/command_test.go b/tool/internal_pkg/util/command_test.go new file mode 100644 index 0000000000..ae83875b52 --- /dev/null +++ b/tool/internal_pkg/util/command_test.go @@ -0,0 +1,245 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "fmt" + "reflect" + "strings" + "testing" +) + +func emptyRun(*Command, []string) error { return nil } + +func executeCommand(root *Command, args ...string) (err error) { + _, err = executeCommandC(root, args...) + return err +} + +func executeCommandC(root *Command, args ...string) (c *Command, err error) { + root.SetArgs(args) + c, err = root.ExecuteC() + return c, err +} + +const onetwo = "one two" + +func TestSingleCommand(t *testing.T) { + rootCmd := &Command{ + Use: "root", + RunE: func(_ *Command, args []string) error { return nil }, + } + aCmd := &Command{Use: "a", RunE: emptyRun} + bCmd := &Command{Use: "b", RunE: emptyRun} + rootCmd.AddCommand(aCmd, bCmd) + + _ = executeCommand(rootCmd, "one", "two") +} + +func TestChildCommand(t *testing.T) { + var child1CmdArgs []string + rootCmd := &Command{Use: "root", RunE: emptyRun} + child1Cmd := &Command{ + Use: "child1", + RunE: func(_ *Command, args []string) error { child1CmdArgs = args; return nil }, + } + child2Cmd := &Command{Use: "child2", RunE: emptyRun} + rootCmd.AddCommand(child1Cmd, child2Cmd) + + err := executeCommand(rootCmd, "child1", "one", "two") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + got := strings.Join(child1CmdArgs, " ") + if got != onetwo { + t.Errorf("child1CmdArgs expected: %q, got: %q", onetwo, got) + } +} + +func TestCallCommandWithoutSubcommands(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + err := executeCommand(rootCmd) + if err != nil { + t.Errorf("Calling command without subcommands should not have error: %v", err) + } +} + +func TestRootExecuteUnknownCommand(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + rootCmd.AddCommand(&Command{Use: "child", RunE: emptyRun}) + + _ = executeCommand(rootCmd, "unknown") +} + +func TestSubcommandExecuteC(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + childCmd := &Command{Use: "child", RunE: emptyRun} + rootCmd.AddCommand(childCmd) + + _, err := executeCommandC(rootCmd, "child") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } +} + +func TestFind(t *testing.T) { + var foo, bar string + root := &Command{ + Use: "root", + } + root.Flags().StringVarP(&foo, "foo", "f", "", "") + root.Flags().StringVarP(&bar, "bar", "b", "something", "") + + child := &Command{ + Use: "child", + } + root.AddCommand(child) + + testCases := []struct { + args []string + expectedFoundArgs []string + }{ + { + []string{"child"}, + []string{}, + }, + { + []string{"child", "child"}, + []string{"child"}, + }, + { + []string{"child", "foo", "child", "bar", "child", "baz", "child"}, + []string{"foo", "child", "bar", "child", "baz", "child"}, + }, + { + []string{"-f", "child", "child"}, + []string{"-f", "child"}, + }, + { + []string{"child", "-f", "child"}, + []string{"-f", "child"}, + }, + { + []string{"-b", "child", "child"}, + []string{"-b", "child"}, + }, + { + []string{"child", "-b", "child"}, + []string{"-b", "child"}, + }, + { + []string{"child", "-b"}, + []string{"-b"}, + }, + { + []string{"-b", "-f", "child", "child"}, + []string{"-b", "-f", "child"}, + }, + { + []string{"-f", "child", "-b", "something", "child"}, + []string{"-f", "child", "-b", "something"}, + }, + { + []string{"-f", "child", "child", "-b"}, + []string{"-f", "child", "-b"}, + }, + { + []string{"-f=child", "-b=something", "child"}, + []string{"-f=child", "-b=something"}, + }, + { + []string{"--foo", "child", "--bar", "something", "child"}, + []string{"--foo", "child", "--bar", "something"}, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc.args), func(t *testing.T) { + cmd, foundArgs, err := root.Find(tc.args) + if err != nil { + t.Fatal(err) + } + + if cmd != child { + t.Fatal("Expected cmd to be child, but it was not") + } + + if !reflect.DeepEqual(tc.expectedFoundArgs, foundArgs) { + t.Fatalf("Wrong args\nExpected: %v\nGot: %v", tc.expectedFoundArgs, foundArgs) + } + }) + } +} + +func TestFlagLong(t *testing.T) { + var cArgs []string + c := &Command{ + Use: "c", + RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, + } + + var stringFlagValue string + c.Flags().StringVar(&stringFlagValue, "sf", "", "") + + err := executeCommand(c, "--sf=abc", "one", "--", "two") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if stringFlagValue != "abc" { + t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) + } + + got := strings.Join(cArgs, " ") + if got != onetwo { + t.Errorf("rootCmdArgs expected: %q, got: %q", onetwo, got) + } +} + +func TestFlagShort(t *testing.T) { + var cArgs []string + c := &Command{ + Use: "c", + RunE: func(_ *Command, args []string) error { cArgs = args; return nil }, + } + + var stringFlagValue string + c.Flags().StringVarP(&stringFlagValue, "sf", "s", "", "") + + err := executeCommand(c, "-sabc", "one", "two") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if stringFlagValue != "abc" { + t.Errorf("Expected stringFlagValue: %q, got %q", "abc", stringFlagValue) + } + + got := strings.Join(cArgs, " ") + if got != onetwo { + t.Errorf("rootCmdArgs expected: %q, got: %q", onetwo, got) + } +} + +func TestChildFlag(t *testing.T) { + rootCmd := &Command{Use: "root", RunE: emptyRun} + childCmd := &Command{Use: "child", RunE: emptyRun} + rootCmd.AddCommand(childCmd) + err := executeCommand(rootCmd, "child") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } +} diff --git a/tool/internal_pkg/util/flag.go b/tool/internal_pkg/util/flag.go new file mode 100644 index 0000000000..77d8f8069a --- /dev/null +++ b/tool/internal_pkg/util/flag.go @@ -0,0 +1,889 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "bytes" + "encoding/csv" + "errors" + goflag "flag" + "fmt" + "io" + "os" + "sort" + "strconv" + "strings" +) + +// ErrHelp is the error returned if the flag -help is invoked but no such flag is defined. +var ErrHelp = errors.New("flag: help requested") + +// ErrorHandling defines how to handle flag parsing errors. +type ErrorHandling int + +const ( + // ContinueOnError will return an err from Parse() if an error is found + ContinueOnError ErrorHandling = iota + // ExitOnError will call os.Exit(2) if an error is found when parsing + ExitOnError + // PanicOnError will panic() if an error is found when parsing flags + PanicOnError +) + +// ParseErrorsWhitelist defines the parsing errors that can be ignored +type ParseErrorsWhitelist struct { + // UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags + UnknownFlags bool +} + +// NormalizedName is a flag name that has been normalized according to rules +// for the FlagSet (e.g. making '-' and '_' equivalent). +type NormalizedName string + +// A FlagSet represents a set of defined flags. +type FlagSet struct { + // Usage is the function called when an error occurs while parsing flags. + // The field is a function (not a method) that may be changed to point to + // a custom error handler. + Usage func() + + // SortFlags is used to indicate, if user wants to have sorted flags in + // help/usage messages. + SortFlags bool + + // ParseErrorsWhitelist is used to configure a whitelist of errors + ParseErrorsWhitelist ParseErrorsWhitelist + + name string + parsed bool + actual map[NormalizedName]*Flag + orderedActual []*Flag + formal map[NormalizedName]*Flag + orderedFormal []*Flag + sortedFormal []*Flag + shorthands map[byte]*Flag + args []string // arguments after flags + argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no -- + errorHandling ErrorHandling + output io.Writer // nil means stderr; use out() accessor + interspersed bool // allow interspersed option/non-option args + normalizeNameFunc func(f *FlagSet, name string) NormalizedName + + addedGoFlagSets []*goflag.FlagSet +} + +// A Flag represents the state of a flag. +type Flag struct { + Name string // name as it appears on command line + Shorthand string // one-letter abbreviated flag + Usage string // help message + Value Value // value as set + DefValue string // default value (as text); for usage message + Changed bool // If the user set the value (or if left to default) + NoOptDefVal string // default value (as text); if the flag is on the command line without any options + Deprecated string // If this flag is deprecated, this string is the new or now thing to use + ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use +} + +// Value is the interface to the dynamic value stored in a flag. +// (The default value is represented as a string.) +type Value interface { + String() string + Set(string) error + Type() string +} + +// SliceValue is a secondary interface to all flags which hold a list +// of values. This allows full control over the value of list flags, +// and avoids complicated marshalling and unmarshalling to csv. +type SliceValue interface { + // Append adds the specified value to the end of the flag value list. + Append(string) error + // Replace will fully overwrite any data currently in the flag value list. + Replace([]string) error + // GetSlice returns the flag value list as an array of strings. + GetSlice() []string +} + +// sortFlags returns the flags as a slice in lexicographical sorted order. +func sortFlags(flags map[NormalizedName]*Flag) []*Flag { + list := make(sort.StringSlice, len(flags)) + i := 0 + for k := range flags { + list[i] = string(k) + i++ + } + list.Sort() + result := make([]*Flag, len(list)) + for i, name := range list { + result[i] = flags[NormalizedName(name)] + } + return result +} + +// GetNormalizeFunc returns the previously set NormalizeFunc of a function which +// does no translation, if not set previously. +func (f *FlagSet) GetNormalizeFunc() func(f *FlagSet, name string) NormalizedName { + if f.normalizeNameFunc != nil { + return f.normalizeNameFunc + } + return func(f *FlagSet, name string) NormalizedName { return NormalizedName(name) } +} + +func (f *FlagSet) normalizeFlagName(name string) NormalizedName { + n := f.GetNormalizeFunc() + return n(f, name) +} + +// Lookup returns the Flag structure of the named flag, returning nil if none exists. +func (f *FlagSet) Lookup(name string) *Flag { + return f.lookup(f.normalizeFlagName(name)) +} + +// lookup returns the Flag structure of the named flag, returning nil if none exists. +func (f *FlagSet) lookup(name NormalizedName) *Flag { + return f.formal[name] +} + +func (f *FlagSet) out() io.Writer { + if f.output == nil { + return os.Stderr + } + return f.output +} + +// SetOutput sets the destination for usage and error messages. +// If output is nil, os.Stderr is used. +func (f *FlagSet) SetOutput(output io.Writer) { + f.output = output +} + +// Set sets the value of the named flag. +func (f *FlagSet) Set(name, value string) error { + normalName := f.normalizeFlagName(name) + flag, ok := f.formal[normalName] + if !ok { + return fmt.Errorf("no such flag -%v", name) + } + + err := flag.Value.Set(value) + if err != nil { + var flagName string + if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { + flagName = fmt.Sprintf("-%s, --%s", flag.Shorthand, flag.Name) + } else { + flagName = fmt.Sprintf("--%s", flag.Name) + } + return fmt.Errorf("invalid argument %q for %q flag: %v", value, flagName, err) + } + + if !flag.Changed { + if f.actual == nil { + f.actual = make(map[NormalizedName]*Flag) + } + f.actual[normalName] = flag + f.orderedActual = append(f.orderedActual, flag) + + flag.Changed = true + } + + if flag.Deprecated != "" { + fmt.Fprintf(f.out(), "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated) + } + return nil +} + +func (f *FlagSet) VisitAll(fn func(*Flag)) { + if len(f.formal) == 0 { + return + } + + var flags []*Flag + if f.SortFlags { + if len(f.formal) != len(f.sortedFormal) { + f.sortedFormal = sortFlags(f.formal) + } + flags = f.sortedFormal + } else { + flags = f.orderedFormal + } + + for _, flag := range flags { + fn(flag) + } +} + +func UnquoteUsage(flag *Flag) (name, usage string) { + // Look for a back-quoted name, but avoid the strings package. + usage = flag.Usage + for i := 0; i < len(usage); i++ { + if usage[i] == '`' { + for j := i + 1; j < len(usage); j++ { + if usage[j] == '`' { + name = usage[i+1 : j] + usage = usage[:i] + name + usage[j+1:] + return name, usage + } + } + break // Only one back quote; use type name. + } + } + + name = flag.Value.Type() + switch name { + case "bool": + name = "" + case "float64": + name = "float" + case "int64": + name = "int" + case "uint64": + name = "uint" + case "stringSlice": + name = "strings" + case "intSlice": + name = "ints" + case "uintSlice": + name = "uints" + case "boolSlice": + name = "bools" + } + + return +} + +func (f *FlagSet) FlagUsagesWrapped(cols int) string { + buf := new(bytes.Buffer) + + lines := make([]string, 0, len(f.formal)) + + maxlen := 0 + f.VisitAll(func(flag *Flag) { + line := "" + if flag.Shorthand != "" && flag.ShorthandDeprecated == "" { + line = fmt.Sprintf(" -%s, --%s", flag.Shorthand, flag.Name) + } else { + line = fmt.Sprintf(" --%s", flag.Name) + } + + varname, usage := UnquoteUsage(flag) + if varname != "" { + line += " " + varname + } + if flag.NoOptDefVal != "" { + switch flag.Value.Type() { + case "string": + line += fmt.Sprintf("[=\"%s\"]", flag.NoOptDefVal) + case "bool": + if flag.NoOptDefVal != "true" { + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + case "count": + if flag.NoOptDefVal != "+1" { + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + default: + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + } + + // This special character will be replaced with spacing once the + // correct alignment is calculated + line += "\x00" + if len(line) > maxlen { + maxlen = len(line) + } + + line += usage + if len(flag.Deprecated) != 0 { + line += fmt.Sprintf(" (DEPRECATED: %s)", flag.Deprecated) + } + + lines = append(lines, line) + }) + + for _, line := range lines { + sidx := strings.Index(line, "\x00") + spacing := strings.Repeat(" ", maxlen-sidx) + // maxlen + 2 comes from + 1 for the \x00 and + 1 for the (deliberate) off-by-one in maxlen-sidx + fmt.Fprintln(buf, line[:sidx], spacing, wrap(maxlen+2, cols, line[sidx+1:])) + } + + return buf.String() +} + +func wrap(i, w int, s string) string { + if w == 0 { + return strings.Replace(s, "\n", "\n"+strings.Repeat(" ", i), -1) + } + + // space between indent i and end of line width w into which + // we should wrap the text. + wrap := w - i + + var r, l string + + // Not enough space for sensible wrapping. Wrap as a block on + // the next line instead. + if wrap < 24 { + i = 16 + wrap = w - i + r += "\n" + strings.Repeat(" ", i) + } + // If still not enough space then don't even try to wrap. + if wrap < 24 { + return strings.Replace(s, "\n", r, -1) + } + + // Try to avoid short orphan words on the final line, by + // allowing wrapN to go a bit over if that would fit in the + // remainder of the line. + slop := 5 + wrap = wrap - slop + + // Handle first line, which is indented by the caller (or the + // special case above) + l, s = wrapN(wrap, slop, s) + r = r + strings.Replace(l, "\n", "\n"+strings.Repeat(" ", i), -1) + + // Now wrap the rest + for s != "" { + var t string + + t, s = wrapN(wrap, slop, s) + r = r + "\n" + strings.Repeat(" ", i) + strings.Replace(t, "\n", "\n"+strings.Repeat(" ", i), -1) + } + + return r +} + +func wrapN(i, slop int, s string) (string, string) { + if i+slop > len(s) { + return s, "" + } + + w := strings.LastIndexAny(s[:i], " \t\n") + if w <= 0 { + return s, "" + } + nlPos := strings.LastIndex(s[:i], "\n") + if nlPos > 0 && nlPos < w { + return s[:nlPos], s[nlPos+1:] + } + return s[:w], s[w+1:] +} + +func (f *FlagSet) FlagUsages() string { + return f.FlagUsagesWrapped(0) +} + +func (f *FlagSet) PrintDefaults() { + usages := f.FlagUsages() + fmt.Fprint(f.out(), usages) +} + +// defaultUsage is the default function to print a usage message. +func defaultUsage(f *FlagSet) { + fmt.Fprintf(f.out(), "Usage of %s:\n", f.name) + f.PrintDefaults() +} + +// Args returns the non-flag arguments. +func (f *FlagSet) Args() []string { return f.args } + +// VarPF is like VarP, but returns the flag created +func (f *FlagSet) VarPF(value Value, name, shorthand, usage string) (*Flag, error) { + // Remember the default value as a string; it won't change. + flag := &Flag{ + Name: name, + Shorthand: shorthand, + Usage: usage, + Value: value, + DefValue: value.String(), + } + err := f.AddFlag(flag) + if err != nil { + return nil, err + } + return flag, nil +} + +// VarP is like Var, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) VarP(value Value, name, shorthand, usage string) { + f.VarPF(value, name, shorthand, usage) +} + +// AddFlag will add the flag to the FlagSet +func (f *FlagSet) AddFlag(flag *Flag) error { + normalizedFlagName := f.normalizeFlagName(flag.Name) + + _, alreadyThere := f.formal[normalizedFlagName] + if alreadyThere { + msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name) + fmt.Fprintln(f.out(), msg) + return fmt.Errorf("%s flag redefined: %s", f.name, flag.Name) + } + if f.formal == nil { + f.formal = make(map[NormalizedName]*Flag) + } + + flag.Name = string(normalizedFlagName) + f.formal[normalizedFlagName] = flag + f.orderedFormal = append(f.orderedFormal, flag) + + if flag.Shorthand == "" { + return nil + } + if len(flag.Shorthand) > 1 { + fmt.Fprintf(f.out(), "%q shorthand is more than one ASCII character", flag.Shorthand) + return fmt.Errorf("%q shorthand is more than one ASCII character", flag.Shorthand) + } + if f.shorthands == nil { + f.shorthands = make(map[byte]*Flag) + } + c := flag.Shorthand[0] + used, alreadyThere := f.shorthands[c] + if alreadyThere { + fmt.Fprintf(f.out(), "unable to redefine %q shorthand in %q flagset: it's already used for %q flag", c, f.name, used.Name) + return fmt.Errorf("unable to redefine %q shorthand in %q flagset: it's already used for %q flag", c, f.name, used.Name) + } + f.shorthands[c] = flag + return nil +} + +// failf prints to standard error a formatted error and usage message and +// returns the error. +func (f *FlagSet) failf(format string, a ...interface{}) error { + err := fmt.Errorf(format, a...) + if f.errorHandling != ContinueOnError { + fmt.Fprintln(f.out(), err) + f.usage() + } + return err +} + +// --unknown (args will be empty) +// --unknown --next-flag ... (args will be --next-flag ...) +// --unknown arg ... (args will be arg ...) +func stripUnknownFlagValue(args []string) []string { + if len(args) == 0 { + //--unknown + return args + } + + first := args[0] + if len(first) > 0 && first[0] == '-' { + //--unknown --next-flag ... + return args + } + + //--unknown arg ... (args will be arg ...) + if len(args) > 1 { + return args[1:] + } + return nil +} + +func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) { + a = args + name := s[2:] + if len(name) == 0 || name[0] == '-' || name[0] == '=' { + err = f.failf("bad flag syntax: %s", s) + return + } + + split := strings.SplitN(name, "=", 2) + name = split[0] + flag, exists := f.formal[f.normalizeFlagName(name)] + + if !exists { + switch { + case name == "help": + f.usage() + return a, ErrHelp + case f.ParseErrorsWhitelist.UnknownFlags: + // --unknown=unknownval arg ... + // we do not want to lose arg in this case + if len(split) >= 2 { + return a, nil + } + + return stripUnknownFlagValue(a), nil + default: + err = f.failf("unknown flag: --%s", name) + return + } + } + + var value string + if len(split) == 2 { + // '--flag=arg' + value = split[1] + } else if flag.NoOptDefVal != "" { + // '--flag' (arg was optional) + value = flag.NoOptDefVal + } else if len(a) > 0 { + // '--flag arg' + value = a[0] + a = a[1:] + } else { + // '--flag' (arg was required) + err = f.failf("flag needs an argument: %s", s) + return + } + + err = fn(flag, value) + if err != nil { + f.failf(err.Error()) + } + return +} + +func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parseFunc) (outShorts string, outArgs []string, err error) { + outArgs = args + + if strings.HasPrefix(shorthands, "test.") { + return + } + + outShorts = shorthands[1:] + c := shorthands[0] + + flag, exists := f.shorthands[c] + if !exists { + switch { + case c == 'h': + f.usage() + err = ErrHelp + return + case f.ParseErrorsWhitelist.UnknownFlags: + // '-f=arg arg ...' + // we do not want to lose arg in this case + if len(shorthands) > 2 && shorthands[1] == '=' { + outShorts = "" + return + } + + outArgs = stripUnknownFlagValue(outArgs) + return + default: + err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands) + return + } + } + + var value string + if len(shorthands) > 2 && shorthands[1] == '=' { + // '-f=arg' + value = shorthands[2:] + outShorts = "" + } else if flag.NoOptDefVal != "" { + // '-f' (arg was optional) + value = flag.NoOptDefVal + } else if len(shorthands) > 1 { + // '-farg' + value = shorthands[1:] + outShorts = "" + } else if len(args) > 0 { + // '-f arg' + value = args[0] + outArgs = args[1:] + } else { + // '-f' (arg was required) + err = f.failf("flag needs an argument: %q in -%s", c, shorthands) + return + } + + if flag.ShorthandDeprecated != "" { + fmt.Fprintf(f.out(), "Flag shorthand -%s has been deprecated, %s\n", flag.Shorthand, flag.ShorthandDeprecated) + } + + err = fn(flag, value) + if err != nil { + f.failf(err.Error()) + } + return +} + +func (f *FlagSet) parseShortArg(s string, args []string, fn parseFunc) (a []string, err error) { + a = args + shorthands := s[1:] + + // "shorthands" can be a series of shorthand letters of flags (e.g. "-vvv"). + for len(shorthands) > 0 { + shorthands, a, err = f.parseSingleShortArg(shorthands, args, fn) + if err != nil { + return + } + } + + return +} + +func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) { + for len(args) > 0 { + s := args[0] + args = args[1:] + if len(s) == 0 || s[0] != '-' || len(s) == 1 { + if !f.interspersed { + f.args = append(f.args, s) + f.args = append(f.args, args...) + return nil + } + f.args = append(f.args, s) + continue + } + + if s[1] == '-' { + if len(s) == 2 { // "--" terminates the flags + f.argsLenAtDash = len(f.args) + f.args = append(f.args, args...) + break + } + args, err = f.parseLongArg(s, args, fn) + } else { + args, err = f.parseShortArg(s, args, fn) + } + if err != nil { + return + } + } + return +} + +// Parse parses flag definitions from the argument list, which should not +// include the command name. Must be called after all flags in the FlagSet +// are defined and before flags are accessed by the program. +// The return value will be ErrHelp if -help was set but not defined. +func (f *FlagSet) Parse(arguments []string) error { + if f.addedGoFlagSets != nil { + for _, goFlagSet := range f.addedGoFlagSets { + goFlagSet.Parse(nil) + } + } + f.parsed = true + + f.args = make([]string, 0, len(arguments)) + + set := func(flag *Flag, value string) error { + return f.Set(flag.Name, value) + } + + err := f.parseArgs(arguments, set) + if err != nil { + switch f.errorHandling { + case ContinueOnError: + return err + case ExitOnError: + fmt.Println(err) + os.Exit(2) + case PanicOnError: + panic(err) + } + } + return nil +} + +type parseFunc func(flag *Flag, value string) error + +// Parsed reports whether f.Parse has been called. +func (f *FlagSet) Parsed() bool { + return f.parsed +} + +var CommandLine = NewFlagSet(os.Args[0], ExitOnError) + +// NewFlagSet returns a new, empty flag set with the specified name, +// error handling property and SortFlags set to true. +func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { + f := &FlagSet{ + name: name, + errorHandling: errorHandling, + argsLenAtDash: -1, + interspersed: true, + SortFlags: true, + } + return f +} + +// PrintDefaults prints to standard error the default values of all defined command-line flags. +func PrintDefaults() { + CommandLine.PrintDefaults() +} + +var Usage = func() { + fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) + PrintDefaults() +} + +// usage calls the Usage method for the flag set, or the usage function if +// the flag set is CommandLine. +func (f *FlagSet) usage() { + if f == CommandLine { + Usage() + } else if f.Usage == nil { + defaultUsage(f) + } else { + f.Usage() + } +} + +// -- string Value +type stringValue string + +func newStringValue(val string, p *string) *stringValue { + *p = val + return (*stringValue)(p) +} + +func (s *stringValue) Set(val string) error { + *s = stringValue(val) + return nil +} + +func (s *stringValue) String() string { return string(*s) } + +func (s *stringValue) Type() string { + return "string" +} + +// StringVar defines a string flag with specified name, default value, and usage string. +func (f *FlagSet) StringVar(p *string, name, value, usage string) { + f.VarP(newStringValue(value, p), name, "", usage) +} + +// StringVarP is like StringVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) StringVarP(p *string, name, shorthand, value, usage string) { + f.VarP(newStringValue(value, p), name, shorthand, usage) +} + +// String defines a string flag with specified name, default value, and usage string. +// The return value is the address of a string variable that stores the value of the flag. +func (f *FlagSet) String(name, value, usage string) *string { + p := new(string) + f.StringVarP(p, name, "", value, usage) + return p +} + +// StringP is like String, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) StringP(name, shorthand, value, usage string) *string { + p := new(string) + f.StringVarP(p, name, shorthand, value, usage) + return p +} + +// -- bool Value +type boolValue bool + +func newBoolValue(val bool, p *bool) *boolValue { + *p = val + return (*boolValue)(p) +} + +func (b *boolValue) Set(s string) error { + v, err := strconv.ParseBool(s) + *b = boolValue(v) + return err +} + +func (b *boolValue) String() string { return strconv.FormatBool(bool(*b)) } + +func (b *boolValue) Type() string { + return "bool" +} + +func (b *boolValue) IsBoolFlag() bool { return true } + +// BoolVar defines a bool flag with specified name, default value, and usage string. +// The argument p points to a bool variable in which to store the value of the flag. +func (f *FlagSet) BoolVar(p *bool, name string, value bool, usage string) error { + err := f.BoolVarP(p, name, "", value, usage) + if err != nil { + return err + } + return nil +} + +// BoolVarP is like BoolVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) BoolVarP(p *bool, name, shorthand string, value bool, usage string) error { + flag, err := f.VarPF(newBoolValue(value, p), name, shorthand, usage) + if err != nil { + return err + } + flag.NoOptDefVal = "true" + return nil +} + +// Bool defines a bool flag with specified name, default value, and usage string. +// The return value is the address of a bool variable that stores the value of the flag. +func (f *FlagSet) Bool(name string, value bool, usage string) *bool { + return f.BoolP(name, "", value, usage) +} + +// BoolP is like Bool, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) BoolP(name, shorthand string, value bool, usage string) *bool { + p := new(bool) + f.BoolVarP(p, name, shorthand, value, usage) + return p +} + +// -- stringArray Value +type stringArrayValue struct { + value *[]string + changed bool +} + +func newStringArrayValue(val []string, p *[]string) *stringArrayValue { + ssv := new(stringArrayValue) + ssv.value = p + *ssv.value = val + return ssv +} + +func (s *stringArrayValue) Set(val string) error { + if !s.changed { + *s.value = []string{val} + s.changed = true + } else { + *s.value = append(*s.value, val) + } + return nil +} + +func writeAsCSV(vals []string) (string, error) { + b := &bytes.Buffer{} + w := csv.NewWriter(b) + err := w.Write(vals) + if err != nil { + return "", err + } + w.Flush() + return strings.TrimSuffix(b.String(), "\n"), nil +} + +func (s *stringArrayValue) String() string { + str, _ := writeAsCSV(*s.value) + return "[" + str + "]" +} + +func (s *stringArrayValue) Type() string { + return "stringArray" +} + +// StringArrayVar defines a string flag with specified name, default value, and usage string. +// The argument p points to a []string variable in which to store the values of the multiple flags. +func (f *FlagSet) StringArrayVar(p *[]string, name string, value []string, usage string) { + f.VarP(newStringArrayValue(value, p), name, "", usage) +} diff --git a/tool/internal_pkg/util/flag_test.go b/tool/internal_pkg/util/flag_test.go new file mode 100644 index 0000000000..916f906df5 --- /dev/null +++ b/tool/internal_pkg/util/flag_test.go @@ -0,0 +1,137 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "io" + "os" + "testing" +) + +func ResetForTesting() { + CommandLine = &FlagSet{ + name: os.Args[0], + errorHandling: ContinueOnError, + output: io.Discard, + } +} + +// GetCommandLine returns the default FlagSet. +func GetCommandLine() *FlagSet { + return CommandLine +} + +func TestParse(t *testing.T) { + ResetForTesting() + testParse(GetCommandLine(), t) +} + +func testParse(f *FlagSet, t *testing.T) { + if f.Parsed() { + t.Error("f.Parse() = true before Parse") + } + boolFlag := f.Bool("bool", false, "bool value") + bool2Flag := f.Bool("bool2", false, "bool2 value") + bool3Flag := f.Bool("bool3", false, "bool3 value") + stringFlag := f.String("string", "0", "string value") + extra := "one-extra-argument" + args := []string{ + "--bool", + "--bool2=true", + "--bool3=false", + "--string=hello", + extra, + } + if err := f.Parse(args); err != nil { + t.Fatal(err) + } + if !f.Parsed() { + t.Error("f.Parse() = false after Parse") + } + if *boolFlag != true { + t.Error("bool flag should be true, is ", *boolFlag) + } + if *bool2Flag != true { + t.Error("bool2 flag should be true, is ", *bool2Flag) + } + if *bool3Flag != false { + t.Error("bool3 flag should be false, is ", *bool2Flag) + } + if *stringFlag != "hello" { + t.Error("string flag should be `hello`, is ", *stringFlag) + } + if len(f.Args()) != 1 { + t.Error("expected one argument, got", len(f.Args())) + } else if f.Args()[0] != extra { + t.Errorf("expected argument %q got %q", extra, f.Args()[0]) + } +} + +func TestShorthand(t *testing.T) { + f := NewFlagSet("shorthand", ContinueOnError) + if f.Parsed() { + t.Error("f.Parse() = true before Parse") + } + boolaFlag := f.BoolP("boola", "a", false, "bool value") + boolbFlag := f.BoolP("boolb", "b", false, "bool2 value") + boolcFlag := f.BoolP("boolc", "c", false, "bool3 value") + booldFlag := f.BoolP("boold", "d", false, "bool4 value") + stringaFlag := f.StringP("stringa", "s", "0", "string value") + stringzFlag := f.StringP("stringz", "z", "0", "string value") + extra := "interspersed-argument" + notaflag := "--i-look-like-a-flag" + args := []string{ + "-ab", + extra, + "-cs", + "hello", + "-z=something", + "-d=true", + "--", + notaflag, + } + f.SetOutput(io.Discard) + if err := f.Parse(args); err != nil { + t.Error("expected no error, got ", err) + } + if !f.Parsed() { + t.Error("f.Parse() = false after Parse") + } + if *boolaFlag != true { + t.Error("boola flag should be true, is ", *boolaFlag) + } + if *boolbFlag != true { + t.Error("boolb flag should be true, is ", *boolbFlag) + } + if *boolcFlag != true { + t.Error("boolc flag should be true, is ", *boolcFlag) + } + if *booldFlag != true { + t.Error("boold flag should be true, is ", *booldFlag) + } + if *stringaFlag != "hello" { + t.Error("stringa flag should be `hello`, is ", *stringaFlag) + } + if *stringzFlag != "something" { + t.Error("stringz flag should be `something`, is ", *stringzFlag) + } + if len(f.Args()) != 2 { + t.Error("expected one argument, got", len(f.Args())) + } else if f.Args()[0] != extra { + t.Errorf("expected argument %q got %q", extra, f.Args()[0]) + } else if f.Args()[1] != notaflag { + t.Errorf("expected argument %q got %q", notaflag, f.Args()[1]) + } +} diff --git a/tool/internal_pkg/util/util.go b/tool/internal_pkg/util/util.go index 196f5b29bc..99549d1733 100644 --- a/tool/internal_pkg/util/util.go +++ b/tool/internal_pkg/util/util.go @@ -88,7 +88,7 @@ func Exists(path string) bool { return !fi.IsDir() } -// LowerFirst converts the first letter to upper case for the given string. +// LowerFirst converts the first letter to lower case for the given string. func LowerFirst(s string) string { rs := []rune(s) rs[0] = unicode.ToLower(rs[0])