diff --git a/generator/fake.go b/generator/fake.go index e983cea..d017ac1 100644 --- a/generator/fake.go +++ b/generator/fake.go @@ -34,7 +34,7 @@ type Fake struct { TargetAlias string TargetName string TargetPackage string - Imports []Import + Imports Imports Methods []Method Function Method } @@ -55,10 +55,10 @@ func NewFake(fakeMode FakeMode, targetName string, packagePath string, fakeName Name: fakeName, Mode: fakeMode, DestinationPackage: destinationPackage, - Imports: []Import{}, + Imports: newImports(), } - f.AddImport("sync", "sync") + f.Imports.Add("sync", "sync") err := f.loadPackages(workingDir) if err != nil { return nil, err diff --git a/generator/function_loader.go b/generator/function_loader.go index 52218e3..8b9aa21 100644 --- a/generator/function_loader.go +++ b/generator/function_loader.go @@ -15,7 +15,6 @@ func (f *Fake) loadMethodForFunction() error { return errors.New("target does not have an underlying function signature") } f.addTypesForMethod(sig) - importsMap := f.importsMap() - f.Function = methodForSignature(sig, f.TargetName, importsMap) + f.Function = methodForSignature(sig, f.TargetName, f.Imports) return nil } diff --git a/generator/function_template.go b/generator/function_template.go index 7df7b99..0d5c0d9 100644 --- a/generator/function_template.go +++ b/generator/function_template.go @@ -16,8 +16,8 @@ const functionTemplate string = `// Code generated by counterfeiter. DO NOT EDIT package {{.DestinationPackage}} import ( - {{- range .Imports}} - {{.Alias}} "{{.Path}}" + {{- range $index, $import := .Imports.ByAlias}} + {{$import.Alias}} "{{$import.PkgPath}}" {{- end}} ) diff --git a/generator/generator_internals_test.go b/generator/generator_internals_test.go index 925ea57..df3a8ee 100644 --- a/generator/generator_internals_test.go +++ b/generator/generator_internals_test.go @@ -54,12 +54,18 @@ func testGenerator(t *testing.T, when spec.G, it spec.S) { Expect(f.Name).To(Equal("FakeFileInfo")) Expect(f.Mode).To(Equal(InterfaceOrFunction)) Expect(f.DestinationPackage).To(Equal("osfakes")) - Expect(f.Imports).To(HaveLen(3)) - Expect(f.Imports).To(ConsistOf( - Import{Alias: "os", Path: "os"}, - Import{Alias: "sync", Path: "sync"}, - Import{Alias: "time", Path: "time"}, - )) + Expect(f.Imports).To(BeEquivalentTo(Imports{ + ByAlias: map[string]Import{ + "os": {Alias: "os", PkgPath: "os"}, + "sync": {Alias: "sync", PkgPath: "sync"}, + "time": {Alias: "time", PkgPath: "time"}, + }, + ByPkgPath: map[string]Import{ + "os": {Alias: "os", PkgPath: "os"}, + "sync": {Alias: "sync", PkgPath: "sync"}, + "time": {Alias: "time", PkgPath: "time"}, + }, + })) Expect(f.Function).To(BeZero()) Expect(f.Packages).NotTo(BeNil()) Expect(f.Package).NotTo(BeNil()) @@ -79,11 +85,16 @@ func testGenerator(t *testing.T, when spec.G, it spec.S) { Expect(f.Name).To(Equal("FakeHandlerFunc")) Expect(f.Mode).To(Equal(InterfaceOrFunction)) Expect(f.DestinationPackage).To(Equal("httpfakes")) - Expect(f.Imports).To(HaveLen(2)) - Expect(f.Imports).To(ConsistOf( - Import{Alias: "http", Path: "net/http"}, - Import{Alias: "sync", Path: "sync"}, - )) + Expect(f.Imports).To(BeEquivalentTo(Imports{ + ByAlias: map[string]Import{ + "http": {Alias: "http", PkgPath: "net/http"}, + "sync": {Alias: "sync", PkgPath: "sync"}, + }, + ByPkgPath: map[string]Import{ + "net/http": {Alias: "http", PkgPath: "net/http"}, + "sync": {Alias: "sync", PkgPath: "sync"}, + }, + })) Expect(f.Function).NotTo(BeZero()) Expect(f.Packages).NotTo(BeNil()) Expect(f.Package).NotTo(BeNil()) @@ -97,20 +108,29 @@ func testGenerator(t *testing.T, when spec.G, it spec.S) { when("manually constructing a fake", func() { it.Before(func() { - f = &Fake{} + f = &Fake{Imports: newImports()} }) - when("there are imports", func() { + when("duplicate import package names are added", func() { it.Before(func() { - f.AddImport("sync", "sync") - f.AddImport("sync", "github.com/maxbrunsfeld/counterfeiter/fixtures/sync") - f.AddImport("sync", "github.com/maxbrunsfeld/counterfeiter/fixtures/othersync") + f.Imports.Add("sync", "sync") + f.Imports.Add("sync", "github.com/maxbrunsfeld/counterfeiter/fixtures/sync") + f.Imports.Add("sync", "github.com/maxbrunsfeld/counterfeiter/fixtures/othersync") }) - it("always leaves the built-in sync in position 0", func() { - f.sortImports() - Expect(f.Imports[0].Alias).To(Equal("sync")) - Expect(f.Imports[0].Path).To(Equal("sync")) + it("all packages have unique aliases", func() { + Expect(f.Imports).To(BeEquivalentTo(Imports{ + ByAlias: map[string]Import{ + "sync": {Alias: "sync", PkgPath: "sync"}, + "synca": {Alias: "synca", PkgPath: "github.com/maxbrunsfeld/counterfeiter/fixtures/sync"}, + "syncb": {Alias: "syncb", PkgPath: "github.com/maxbrunsfeld/counterfeiter/fixtures/othersync"}, + }, + ByPkgPath: map[string]Import{ + "sync": {Alias: "sync", PkgPath: "sync"}, + "github.com/maxbrunsfeld/counterfeiter/fixtures/sync": {Alias: "synca", PkgPath: "github.com/maxbrunsfeld/counterfeiter/fixtures/sync"}, + "github.com/maxbrunsfeld/counterfeiter/fixtures/othersync": {Alias: "syncb", PkgPath: "github.com/maxbrunsfeld/counterfeiter/fixtures/othersync"}, + }, + })) }) }) @@ -259,7 +279,7 @@ func testGenerator(t *testing.T, when spec.G, it spec.S) { f.loadMethods() Expect(len(f.Methods)).To(BeNumerically(">=", 51)) // yes, this is crazy because go 1.11 added a function Expect(len(f.Methods)).To(BeNumerically("<=", 53)) - Expect(len(f.Imports)).To(Equal(2)) + Expect(len(f.Imports.ByAlias)).To(Equal(2)) }) }) }) @@ -267,133 +287,49 @@ func testGenerator(t *testing.T, when spec.G, it spec.S) { when("working with imports", func() { when("there are no imports", func() { it("returns an empty alias map", func() { - m := f.aliasMap() - Expect(m).To(BeEmpty()) + Expect(f.Imports.ByAlias).To(BeEmpty()) }) it("turns a vendor path into the correct import", func() { - i := f.AddImport("apackage", "github.com/maxbrunsfeld/counterfeiter/fixtures/vendored/vendor/apackage") + i := f.Imports.Add("apackage", "github.com/maxbrunsfeld/counterfeiter/fixtures/vendored/vendor/apackage") Expect(i.Alias).To(Equal("apackage")) - Expect(i.Path).To(Equal("apackage")) + Expect(i.PkgPath).To(Equal("apackage")) - i = f.AddImport("anotherpackage", "vendor/anotherpackage") + i = f.Imports.Add("anotherpackage", "vendor/anotherpackage") Expect(i.Alias).To(Equal("anotherpackage")) - Expect(i.Path).To(Equal("anotherpackage")) + Expect(i.PkgPath).To(Equal("anotherpackage")) }) }) when("there is a single import", func() { it.Before(func() { - f.AddImport("os", "os") + f.Imports.Add("os", "os") }) it("is present in the map", func() { - expected := Import{Alias: "os", Path: "os"} - m := f.aliasMap() - Expect(m).To(HaveLen(1)) - Expect(m).To(HaveKeyWithValue("os", []Import{expected})) - }) - - it("returns the existing imports if there is a path match", func() { - i := f.AddImport("aliasedos", "os") - Expect(i.Alias).To(Equal("os")) - Expect(i.Path).To(Equal("os")) - Expect(f.Imports).To(HaveLen(1)) - Expect(f.Imports[0].Alias).To(Equal("os")) - Expect(f.Imports[0].Path).To(Equal("os")) - }) - }) - - when("there are imports", func() { - it.Before(func() { - f.Imports = []Import{ - Import{ - Alias: "dup_packages", - Path: "github.com/maxbrunsfeld/counterfeiter/fixtures/dup_packages", - }, - Import{ - Alias: "foo", - Path: "github.com/maxbrunsfeld/counterfeiter/fixtures/dup_packages/a/foo", + Expect(f.Imports).To(BeEquivalentTo(Imports{ + ByAlias: map[string]Import{ + "os": {Alias: "os", PkgPath: "os"}, }, - Import{ - Alias: "foo", - Path: "github.com/maxbrunsfeld/counterfeiter/fixtures/dup_packages/b/foo", + ByPkgPath: map[string]Import{ + "os": {Alias: "os", PkgPath: "os"}, }, - Import{ - Alias: "sync", - Path: "sync", - }, - } + })) }) - it("collects duplicates", func() { - m := f.aliasMap() - Expect(m).To(HaveLen(3)) - Expect(m).To(HaveKey("dup_packages")) - Expect(m).To(HaveKey("sync")) - Expect(m).To(HaveKey("foo")) - Expect(m["foo"]).To(ConsistOf( - Import{ - Alias: "foo", - Path: "github.com/maxbrunsfeld/counterfeiter/fixtures/dup_packages/a/foo", + it("returns the existing imports if there is a path match", func() { + i := f.Imports.Add("aliasedos", "os") + Expect(i.Alias).To(Equal("os")) + Expect(i.PkgPath).To(Equal("os")) + Expect(f.Imports).To(BeEquivalentTo(Imports{ + ByAlias: map[string]Import{ + "os": {Alias: "os", PkgPath: "os"}, }, - Import{ - Alias: "foo", - Path: "github.com/maxbrunsfeld/counterfeiter/fixtures/dup_packages/b/foo", + ByPkgPath: map[string]Import{ + "os": {Alias: "os", PkgPath: "os"}, }, - )) - }) - - it("disambiguates aliases", func() { - m := f.aliasMap() - Expect(m).To(HaveLen(3)) - f.disambiguateAliases() - m = f.aliasMap() - Expect(m).To(HaveLen(4)) - Expect(m["fooa"]).To(ConsistOf(Import{ - Alias: "fooa", - Path: "github.com/maxbrunsfeld/counterfeiter/fixtures/dup_packages/b/foo", })) }) - - when("there is a package named sync", func() { - it.Before(func() { - f.Imports = []Import{ - Import{ - Alias: "sync", - Path: "github.com/maxbrunsfeld/counterfeiter/fixtures/othersync", - }, - Import{ - Alias: "sync", - Path: "sync", - }, - Import{ - Alias: "sync", - Path: "github.com/maxbrunsfeld/counterfeiter/fixtures/sync", - }, - } - }) - - it("preserves the stdlib sync alias", func() { - m := f.aliasMap() - Expect(m).To(HaveLen(1)) - f.disambiguateAliases() - m = f.aliasMap() - Expect(m).To(HaveLen(3)) - Expect(m["sync"]).To(ConsistOf(Import{ - Alias: "sync", - Path: "sync", - })) - Expect(m["syncb"]).To(ConsistOf(Import{ - Alias: "syncb", - Path: "github.com/maxbrunsfeld/counterfeiter/fixtures/sync", - })) - Expect(m["synca"]).To(ConsistOf(Import{ - Alias: "synca", - Path: "github.com/maxbrunsfeld/counterfeiter/fixtures/othersync", - })) - }) - }) }) }) }) diff --git a/generator/import.go b/generator/import.go index 26b3e19..15f6700 100644 --- a/generator/import.go +++ b/generator/import.go @@ -1,127 +1,65 @@ package generator import ( - "log" - "sort" + "go/types" "strings" + + "golang.org/x/tools/imports" ) +// Imports indexes imports by package path and alias so that all imports have a +// unique alias, and no package is included twice. +type Imports struct { + ByAlias map[string]Import + ByPkgPath map[string]Import +} + +func newImports() Imports { + return Imports{ + ByAlias: make(map[string]Import), + ByPkgPath: make(map[string]Import), + } +} + // Import is a package import with the associated alias for that package. type Import struct { - Alias string - Path string + Alias string + PkgPath string } // AddImport creates an import with the given alias and path, and adds it to // Fake.Imports. -func (f *Fake) AddImport(alias string, path string) Import { - path = unvendor(strings.TrimSpace(path)) +func (i *Imports) Add(alias string, path string) Import { + // TODO: why is there extra whitespace on these args? + path = imports.VendorlessPath(strings.TrimSpace(path)) alias = strings.TrimSpace(alias) - for i := range f.Imports { - if f.Imports[i].Path == path { - return f.Imports[i] - } - } - log.Printf("Adding import: %s > %s\n", alias, path) - result := Import{ - Alias: alias, - Path: path, - } - f.Imports = append(f.Imports, result) - return result -} - -// SortImports sorts imports alphabetically. -func (f *Fake) sortImports() { - sort.SliceStable(f.Imports, func(i, j int) bool { - if f.Imports[i].Path == "sync" { - return true - } - if f.Imports[j].Path == "sync" { - return false - } - return f.Imports[i].Path < f.Imports[j].Path - }) -} -func unvendor(s string) string { - // Devendorize for use in import statement. - if i := strings.LastIndex(s, "/vendor/"); i >= 0 { - return s[i+len("/vendor/"):] + imp, exists := i.ByPkgPath[path] + if exists { + return imp } - if strings.HasPrefix(s, "vendor/") { - return s[len("vendor/"):] - } - return s -} -func (f *Fake) hasDuplicateAliases() bool { - for _, imports := range f.aliasMap() { - if len(imports) > 1 { - return true - } + imp, exists = i.ByAlias[alias] + if exists { + alias = uniqueAliasForImport(alias, i.ByAlias) } - return false -} -func (f *Fake) printAliases() { - for i := range f.Imports { - log.Printf("- %s > %s\n", f.Imports[i].Alias, f.Imports[i].Path) - } + result := Import{Alias: alias, PkgPath: path} + i.ByPkgPath[path] = result + i.ByAlias[alias] = result + return result } -// disambiguateAliases ensures that all imports are aliased uniquely. -func (f *Fake) disambiguateAliases() { - f.sortImports() - if !f.hasDuplicateAliases() { - return - } - - log.Printf("!!! Duplicate import aliases found,...") - log.Printf("aliases before disambiguation:\n") - f.printAliases() - var byAlias map[string][]Import - for { - byAlias = f.aliasMap() - if !f.hasDuplicateAliases() { - break - } - - for i := range f.Imports { - imports := byAlias[f.Imports[i].Alias] - if len(imports) == 1 { - continue - } - - for j := 0; j < len(imports); j++ { - if imports[j].Path == f.Imports[i].Path && j > 0 { - f.Imports[i].Alias = f.Imports[i].Alias + string('a'+byte(j-1)) - if f.Imports[i].Path == f.TargetPackage { - f.TargetAlias = f.Imports[i].Alias - } - } - } +func uniqueAliasForImport(alias string, imports map[string]Import) string { + for i := 0; ; i++ { + newAlias := alias + string('a'+byte(i)) + if _, exists := imports[newAlias]; !exists { + return newAlias } } - - log.Println("aliases after disambiguation:") - f.printAliases() -} - -func (f *Fake) aliasMap() map[string][]Import { - result := map[string][]Import{} - for i := range f.Imports { - imports := result[f.Imports[i].Alias] - result[f.Imports[i].Alias] = append(imports, f.Imports[i]) - } - return result } -func (f *Fake) importsMap() map[string]Import { - f.disambiguateAliases() - result := map[string]Import{} - for i := range f.Imports { - result[f.Imports[i].Path] = f.Imports[i] - } - return result +// AliasForPackage returns a package alias for the package. +func (i *Imports) AliasForPackage(p *types.Package) string { + return i.ByPkgPath[imports.VendorlessPath(p.Path())].Alias } diff --git a/generator/interface_loader.go b/generator/interface_loader.go index 4a92d6b..f81c252 100644 --- a/generator/interface_loader.go +++ b/generator/interface_loader.go @@ -19,12 +19,12 @@ func (f *Fake) addTypesForMethod(sig *types.Signature) { } } -func methodForSignature(sig *types.Signature, methodName string, importsMap map[string]Import) Method { +func methodForSignature(sig *types.Signature, methodName string, imports Imports) Method { params := []Param{} for i := 0; i < sig.Params().Len(); i++ { param := sig.Params().At(i) isVariadic := i == sig.Params().Len()-1 && sig.Variadic() - typ := typeFor(param.Type(), importsMap) + typ := types.TypeString(param.Type(), imports.AliasForPackage) if isVariadic { typ = "..." + typ[2:] // Change []string to ...string } @@ -41,7 +41,7 @@ func methodForSignature(sig *types.Signature, methodName string, importsMap map[ ret := sig.Results().At(i) r := Return{ Name: fmt.Sprintf("result%v", i+1), - Type: typeFor(ret.Type(), importsMap), + Type: types.TypeString(ret.Type(), imports.AliasForPackage), } returns = append(returns, r) } @@ -96,9 +96,8 @@ func (f *Fake) loadMethods() { f.addTypesForMethod(methods[i].Signature) } - importsMap := f.importsMap() for i := range methods { - method := methodForSignature(methods[i].Signature, methods[i].Func.Name(), importsMap) + method := methodForSignature(methods[i].Signature, methods[i].Func.Name(), f.Imports) f.Methods = append(f.Methods, method) } } diff --git a/generator/interface_template.go b/generator/interface_template.go index e1b19c7..714ed4c 100644 --- a/generator/interface_template.go +++ b/generator/interface_template.go @@ -16,8 +16,8 @@ const interfaceTemplate string = `// Code generated by counterfeiter. DO NOT EDI package {{.DestinationPackage}} import ( - {{- range .Imports}} - {{.Alias}} "{{.Path}}" + {{- range $index, $import := .Imports.ByAlias}} + {{$import.Alias}} "{{$import.PkgPath}}" {{- end}} ) diff --git a/generator/loader.go b/generator/loader.go index 132291c..195a6fa 100644 --- a/generator/loader.go +++ b/generator/loader.go @@ -8,6 +8,7 @@ import ( "strings" "golang.org/x/tools/go/packages" + "golang.org/x/tools/imports" ) func (f *Fake) loadPackages(workingDir string) error { @@ -69,8 +70,8 @@ func (f *Fake) findPackage() error { } f.Target = target f.Package = pkg - f.TargetPackage = unvendor(pkg.PkgPath) - t := f.AddImport(pkg.Name, f.TargetPackage) + f.TargetPackage = imports.VendorlessPath(pkg.PkgPath) + t := f.Imports.Add(pkg.Name, f.TargetPackage) f.TargetAlias = t.Alias if f.Mode != Package { f.TargetName = target.Name() @@ -113,7 +114,7 @@ func (f *Fake) addImportsFor(typ types.Type) { f.addImportsFor(t.Elem()) case *types.Named: if t.Obj() != nil && t.Obj().Pkg() != nil { - f.AddImport(t.Obj().Pkg().Name(), t.Obj().Pkg().Path()) + f.Imports.Add(t.Obj().Pkg().Name(), t.Obj().Pkg().Path()) } case *types.Slice: f.addImportsFor(t.Elem()) @@ -127,16 +128,3 @@ func (f *Fake) addImportsFor(typ types.Type) { log.Printf("!!! WARNING: Missing case for type %s\n", reflect.TypeOf(typ).String()) } } - -func typeFor(typ types.Type, importsMap map[string]Import) string { - if typ == nil { - return "" - } - return types.TypeString(typ, func(p *types.Package) string { - imp, ok := importsMap[unvendor(p.Path())] - if ok { - return imp.Alias - } - return "" - }) -} diff --git a/generator/package_template.go b/generator/package_template.go index bb27f00..e1029e1 100644 --- a/generator/package_template.go +++ b/generator/package_template.go @@ -16,8 +16,8 @@ const packageTemplate string = `// Code generated by counterfeiter. DO NOT EDIT. package {{.DestinationPackage}} import ( - {{- range .Imports}} - {{.Alias}} "{{.Path}}" + {{- range $index, $import := .Imports.ByAlias}} + {{$import.Alias}} "{{$import.PkgPath}}" {{- end}} )