diff --git a/cmd/mocker/main.go b/cmd/mocker/main.go index 4b53225..14829e2 100644 --- a/cmd/mocker/main.go +++ b/cmd/mocker/main.go @@ -33,11 +33,11 @@ func main() { m, err := mocker.New(src, pkg, iface, prefix, suffix, w) if err != nil { - log.Fatal("failed to instantiate mocker") + log.Fatal("mocker: failed to instantiate") } if err = m.Mock(); err != nil { - log.Fatalf("failed to mock: %v", err) + log.Fatalf("mocker: failed to mock: %v", err) } if out != nil { diff --git a/pkg/mocker/importer.go b/pkg/mocker/importer.go deleted file mode 100644 index ddca3e4..0000000 --- a/pkg/mocker/importer.go +++ /dev/null @@ -1,101 +0,0 @@ -package mocker - -import ( - "fmt" - "go/ast" - "go/parser" - "go/token" - "go/types" - "io/ioutil" - "os" - "path" - "path/filepath" - - "github.com/pkg/errors" -) - -type importer struct { - src string - base types.Importer - pkgs map[string]*types.Package -} - -func (i *importer) Import(path string) (*types.Package, error) { - var err error - if path == "" || path[0] == '.' { - path, err = filepath.Abs(filepath.Clean(path)) - if err != nil { - return nil, errors.Wrap(err, "importer: failed to get path") - } - } - if pkg, ok := i.pkgs[path]; ok { - return pkg, nil - } - pkg, err := i.pkg(path) - if err != nil { - return nil, errors.Wrap(err, "importer: failed to read pkg") - } - i.pkgs[path] = pkg - return pkg, nil -} - -func (i *importer) pkg(pkg string) (*types.Package, error) { - paths := []string{ - filepath.Join(i.src, "vendor", pkg), - filepath.Join(os.Getenv("GOPATH"), "src", pkg), - filepath.Join(os.Getenv("GOROOT"), "src", pkg), - } - var fpath string - var errs []error - for _, p := range paths { - abs, err := filepath.Abs(p) - if err != nil { - errs = append(errs, errors.Wrap(err, "importer: failed to get abs path")) - continue - } - if fi, err := os.Stat(abs); err != nil { - errs = append(errs, errors.Wrap(err, "importer: failed stat'ing path")) - continue - } else if !fi.IsDir() { - errs = append(errs, errors.Wrap(err, "importer: path not dir")) - continue - } - fpath = abs - } - if len(errs) == 3 { - return nil, fmt.Errorf("importer: failed to find pkg in vendor, GOPATH, or GOROOT:\n\t%v", errs) - } - f, err := ioutil.ReadDir(fpath) - if err != nil { - return nil, errors.Wrap(err, "importer: failed to read pkg dir") - } - fset := token.NewFileSet() - var files []*ast.File - for _, fi := range f { - if fi.IsDir() { - continue - } - n := fi.Name() - if path.Ext(n) != ".go" { - continue - } - p := path.Join(fpath, n) - src, err := ioutil.ReadFile(p) - if err != nil { - return nil, errors.Wrap(err, "importer: failed to read file") - } - f, err := parser.ParseFile(fset, p, src, 0) - if err != nil { - return nil, errors.Wrap(err, "importer: failed to parse file") - } - files = append(files, f) - } - cfg := types.Config{Importer: i} - p, err := cfg.Check(pkg, fset, files, nil) - if err != nil { - if p, err = i.base.Import(pkg); err != nil { - return nil, errors.Wrap(err, "importer: failed to import pkg") - } - } - return p, nil -} diff --git a/pkg/mocker/mocker.go b/pkg/mocker/mocker.go index 2cee4ee..045f60e 100644 --- a/pkg/mocker/mocker.go +++ b/pkg/mocker/mocker.go @@ -3,18 +3,18 @@ package mocker import ( "bytes" "fmt" - "go/ast" "go/format" - goimporter "go/importer" "go/parser" "go/token" "go/types" "io" "os" + "path/filepath" "strings" "text/template" "github.com/pkg/errors" + "golang.org/x/tools/go/loader" ) type mocker struct { @@ -51,57 +51,31 @@ func (m *mocker) Mock() error { } tmpl, err := template.New("mocker").Funcs(tmplFns).Parse(tmpl) if err != nil { - return errors.Wrap(err, "mocker: failed to parse template") + return errors.Wrap(err, "failed to parse template") } f := file{Pkg: *m.pkg, Imports: []iimport{{Path: "sync"}}} - for _, pkg := range pkgs { - i := 0 - files := make([]*ast.File, len(pkg.Files)) - for _, f := range pkg.Files { - files[i] = f - i++ - } - cfg := types.Config{Importer: &importer{src: *m.src, pkgs: make(map[string]*types.Package), base: goimporter.Default()}} - tpkg, err := cfg.Check(*m.src, fset, files, nil) - if err != nil { - return errors.Wrap(err, "mocker: failed to type check pkg") + + pkgInfo, err := m.pkgInfo(*m.src) + if err != nil { + return errors.Wrap(err, "failed to get pkg info") + } + for _, n := range *m.iface { + ifaceobj := pkgInfo.Pkg.Scope().Lookup(n) + if ifaceobj == nil { + return fmt.Errorf("failed to find interface: %s", n) } - for _, f := range files { - for _, d := range f.Decls { - gd, ok := d.(*ast.GenDecl) - if !ok { - continue - } - for _, s := range gd.Specs { - is, ok := s.(*ast.ImportSpec) - if !ok { - continue - } - if is.Name != nil { - i := iimport{Name: is.Name.Name, Path: strings.Replace(is.Path.Value, `"`, "", -1)} - m.imports.named[i.Path] = i - } - } - } + if !types.IsInterface(ifaceobj.Type()) { + return errors.Wrap(err, fmt.Sprintf("%s (%s) is not an interface", n, ifaceobj.Type().String())) } - for _, i := range *m.iface { - ifaceobj := tpkg.Scope().Lookup(i) - if ifaceobj == nil { - return fmt.Errorf("mocker: failed to find interface %s", i) - } - if !types.IsInterface(ifaceobj.Type()) { - return fmt.Errorf("mocker: not an interface %s", i) - } - tiface := ifaceobj.Type().Underlying().(*types.Interface).Complete() - iface := iface{Name: i, Suffix: *m.suffix, Prefix: *m.prefix} - for i := 0; i < tiface.NumMethods(); i++ { - met := tiface.Method(i) - sig := met.Type().(*types.Signature) - m := method{Name: met.Name(), Params: m.params(sig, sig.Params(), "in%d"), Returns: m.params(sig, sig.Results(), "out%d")} - iface.Methods = append(iface.Methods, m) - } - f.Ifaces = append(f.Ifaces, iface) + iiface := ifaceobj.Type().Underlying().(*types.Interface).Complete() + iface := iface{Name: n, Suffix: *m.suffix, Prefix: *m.prefix} + for i := 0; i < iiface.NumMethods(); i++ { + met := iiface.Method(i) + sig := met.Type().(*types.Signature) + m := method{Name: met.Name(), Params: m.params(sig, sig.Params(), "in%d"), Returns: m.params(sig, sig.Results(), "out%d")} + iface.Methods = append(iface.Methods, m) } + f.Ifaces = append(f.Ifaces, iface) } for p, n := range m.imports.named { if _, ok := m.imports.all[p]; ok { @@ -113,14 +87,14 @@ func (m *mocker) Mock() error { } var buf bytes.Buffer if err := tmpl.Execute(&buf, f); err != nil { - return errors.Wrap(err, "mocker: failed to execute template") + return errors.Wrap(err, "failed to execute template") } fmted, err := format.Source(buf.Bytes()) if err != nil { - return errors.Wrap(err, "mocker: failed to format file") + return errors.Wrap(err, "failed to format file") } if _, err := m.w.Write(fmted); err != nil { - return errors.Wrap(err, "mocker: failed to write file") + return errors.Wrap(err, "failed to write file") } return nil } @@ -247,3 +221,32 @@ func (m *mocker) params(sig *types.Signature, tuple *types.Tuple, format string) } return params } + +func (m *mocker) pkgInfo(src string) (*loader.PackageInfo, error) { + abs, err := filepath.Abs(src) + if err != nil { + return nil, errors.Wrap(err, "faild to get abs src path") + } + pkgPath := m.strip(abs) + conf := loader.Config{ + ParserMode: parser.SpuriousErrors, + Cwd: src, + } + conf.Import(pkgPath) + loader, err := conf.Load() + if err != nil { + return nil, errors.Wrap(err, "failed to load program") + } + pkgInfo := loader.Package(pkgPath) + if pkgInfo == nil { + return nil, errors.New("unable to load package") + } + return pkgInfo, nil +} + +func (m *mocker) strip(pkg string) string { + for _, path := range strings.Split(os.Getenv("GOPATH"), string(filepath.ListSeparator)) { + pkg = strings.TrimPrefix(pkg, filepath.Join(path, "src")+"/") + } + return pkg +}