diff --git a/fixtures/dup_packages/go.mod b/fixtures/dup_packages/go.mod index e5fc549..b49ab2a 100644 --- a/fixtures/dup_packages/go.mod +++ b/fixtures/dup_packages/go.mod @@ -1,3 +1,3 @@ module github.com/maxbrunsfeld/counterfeiter/v6/fixtures/dup_packages -go 1.12 +go 1.22 diff --git a/fixtures/type_aliases/extra/m.go b/fixtures/type_aliases/extra/m.go new file mode 100644 index 0000000..bc0623b --- /dev/null +++ b/fixtures/type_aliases/extra/m.go @@ -0,0 +1,5 @@ +package extra // import "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/type_aliases/extra" + +import "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/type_aliases/extra/primitive" + +type M = primitive.M diff --git a/fixtures/type_aliases/extra/primitive/m.go b/fixtures/type_aliases/extra/primitive/m.go new file mode 100644 index 0000000..9cebc01 --- /dev/null +++ b/fixtures/type_aliases/extra/primitive/m.go @@ -0,0 +1,3 @@ +package primitive // import "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/type_aliases/primitive" + +type M map[string]interface{} diff --git a/fixtures/type_aliases/go.mod b/fixtures/type_aliases/go.mod new file mode 100644 index 0000000..b9bd99b --- /dev/null +++ b/fixtures/type_aliases/go.mod @@ -0,0 +1,3 @@ +module github.com/maxbrunsfeld/counterfeiter/v6/fixtures/type_aliases + +go 1.22 diff --git a/fixtures/type_aliases/interface.go b/fixtures/type_aliases/interface.go new file mode 100644 index 0000000..d75bd9e --- /dev/null +++ b/fixtures/type_aliases/interface.go @@ -0,0 +1,14 @@ +package type_aliases // import "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/type_aliases" + +import ( + "context" + + "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/type_aliases/extra" +) + +//go:generate go run github.com/maxbrunsfeld/counterfeiter/v6 -generate + +//counterfeiter:generate . WithAliasedType +type WithAliasedType interface { + FindExample(ctx context.Context, filter extra.M) ([]string, error) +} diff --git a/generator/loader.go b/generator/loader.go index 4a8695b..4ffb873 100644 --- a/generator/loader.go +++ b/generator/loader.go @@ -130,14 +130,10 @@ func (f *Fake) addImportsFor(typ types.Type) { f.addImportsFor(t.Elem()) case *types.Chan: f.addImportsFor(t.Elem()) + case *types.Alias: + f.addImportsForNamedType(t) case *types.Named: - if t.Obj() != nil && t.Obj().Pkg() != nil { - typeArgs := t.TypeArgs() - for i := 0; i < typeArgs.Len(); i++ { - f.addImportsFor(typeArgs.At(i)) - } - f.Imports.Add(t.Obj().Pkg().Name(), t.Obj().Pkg().Path()) - } + f.addImportsForNamedType(t) case *types.Slice: f.addImportsFor(t.Elem()) case *types.Array: @@ -154,3 +150,16 @@ func (f *Fake) addImportsFor(typ types.Type) { log.Printf("!!! WARNING: Missing case for type %s\n", reflect.TypeOf(typ).String()) } } + +func (f *Fake) addImportsForNamedType(t interface { + Obj() *types.TypeName + TypeArgs() *types.TypeList +}) { + if t.Obj() != nil && t.Obj().Pkg() != nil { + typeArgs := t.TypeArgs() + for i := 0; i < typeArgs.Len(); i++ { + f.addImportsFor(typeArgs.At(i)) + } + f.Imports.Add(t.Obj().Pkg().Name(), t.Obj().Pkg().Path()) + } +} diff --git a/go.mod b/go.mod index 7806358..15dcc4b 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,6 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -go 1.22.0 +go 1.23 toolchain go1.23.1 diff --git a/integration/roundtrip_test.go b/integration/roundtrip_test.go index 45939f3..5dafaa0 100644 --- a/integration/roundtrip_test.go +++ b/integration/roundtrip_test.go @@ -121,6 +121,25 @@ func runTests(t *testing.T, when spec.G, it spec.S) { }) }) + when("generating interfaces using type aliases", func() { + it.Before(func() { + relativeDir = filepath.Join(relativeDir, "type_aliases") + copyDirFunc() + }) + it("imports the aliased type, not the underlying type", func() { + cache := &generator.FakeCache{} + pkgPath := "github.com/maxbrunsfeld/counterfeiter/v6/fixtures/type_aliases" + interfaceName := "WithAliasedType" + fakePackageName := "type_aliasesfakes" + f, err := generator.NewFake(generator.InterfaceOrFunction, interfaceName, pkgPath, "Fake"+interfaceName, fakePackageName, "", baseDir, cache) + Expect(err).NotTo(HaveOccurred()) + b, err := f.Generate(false) + Expect(err).NotTo(HaveOccurred()) + Expect(string(b)).NotTo(ContainSubstring("primitive")) + Expect(string(b)).To(ContainSubstring(`"github.com/maxbrunsfeld/counterfeiter/v6/fixtures/type_aliases/extra"`)) + }) + }) + when(name, func() { t := func(interfaceName string, filename string, subDir string, files ...string) { when("working with "+filename, func() {