diff --git a/README.md b/README.md index f902f1e0d..e9c04a4cb 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Join the chat at https://gitter.im/yandex/pandora](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/yandex/pandora?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) [![Build Status](https://travis-ci.org/yandex/pandora.svg)](https://travis-ci.org/yandex/pandora) -[![Coverage Status](https://coveralls.io/repos/yandex/pandora/badge.svg?branch=master&service=github)](https://coveralls.io/github/yandex/pandora?branch=master) +[![Coverage Status](https://coveralls.io/repos/yandex/pandora/badge.svg?branch=develop&service=github)](https://coveralls.io/github/yandex/pandora?branch=develop) A load generator in Go language. diff --git a/components/example/import/import_suite_test.go b/components/example/import/import_suite_test.go new file mode 100644 index 000000000..06002b770 --- /dev/null +++ b/components/example/import/import_suite_test.go @@ -0,0 +1,20 @@ +package example + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/yandex/pandora/lib/testutil" +) + +func TestImport(t *testing.T) { + testutil.RunSuite(t, "Import Suite") +} + +var _ = Describe("import", func() { + It("not panics", func() { + Expect(Import).NotTo(Panic()) + }) +}) diff --git a/components/phttp/client.go b/components/phttp/client.go index 79c8c86f1..7ee02d409 100644 --- a/components/phttp/client.go +++ b/components/phttp/client.go @@ -6,7 +6,6 @@ package phttp import ( - "context" "crypto/tls" "net" "net/http" @@ -17,6 +16,7 @@ import ( "github.com/pkg/errors" "github.com/yandex/pandora/core/config" + "github.com/yandex/pandora/lib/netutil" ) //go:generate mockery -name=Client -case=underscore -inpkg -testonly @@ -43,6 +43,8 @@ func NewDefaultClientConfig() ClientConfig { // DialerConfig can be mapped on net.Dialer. // Set net.Dialer for details. type DialerConfig struct { + DNSCache bool `config:"dns-cache" map:"-"` + Timeout time.Duration `config:"timeout"` DualStack bool `config:"dual-stack"` @@ -54,28 +56,20 @@ type DialerConfig struct { func NewDefaultDialerConfig() DialerConfig { return DialerConfig{ + DNSCache: true, + DualStack: true, Timeout: 3 * time.Second, KeepAlive: 120 * time.Second, } } -type Dialer interface { - DialContext(ctx context.Context, network, address string) (net.Conn, error) -} - -type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error) - -func (f DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return f(ctx, network, address) -} - -func NewDialer(conf DialerConfig) *net.Dialer { +func NewDialer(conf DialerConfig) netutil.Dialer { d := &net.Dialer{} - err := config.Map(d, conf) - if err != nil { - zap.L().Panic("Dialer config map fail", zap.Error(err)) + config.Map(d, conf) + if !conf.DNSCache { + return d } - return d + return netutil.NewDNSCachingDialer(d, netutil.DefaultDNSCache) } // TransportConfig can be mapped on http.Transport. @@ -101,21 +95,18 @@ func NewDefaultTransportConfig() TransportConfig { } } -func NewTransport(conf TransportConfig, dial DialerFunc) *http.Transport { +func NewTransport(conf TransportConfig, dial netutil.DialerFunc) *http.Transport { tr := &http.Transport{} tr.TLSClientConfig = &tls.Config{ InsecureSkipVerify: true, // We should not spend time for this stuff. NextProtos: []string{"http/1.1"}, // Disable HTTP/2. Use HTTP/2 transport explicitly, if needed. } - err := config.Map(tr, conf) - if err != nil { - zap.L().Panic("Transport config map fail", zap.Error(err)) - } + config.Map(tr, conf) tr.DialContext = dial return tr } -func NewHTTP2Transport(conf TransportConfig, dial DialerFunc) *http.Transport { +func NewHTTP2Transport(conf TransportConfig, dial netutil.DialerFunc) *http.Transport { tr := NewTransport(conf, dial) err := http2.ConfigureTransport(tr) if err != nil { diff --git a/components/phttp/connect.go b/components/phttp/connect.go index 4ea48b289..a05a3d8ec 100644 --- a/components/phttp/connect.go +++ b/components/phttp/connect.go @@ -15,6 +15,7 @@ import ( "net/url" "github.com/pkg/errors" + "github.com/yandex/pandora/lib/netutil" ) type ConnectGunConfig struct { @@ -78,7 +79,7 @@ func newConnectClient(conf ConnectGunConfig) Client { return newClient(transport, conf.Client.Redirect) } -func newConnectDialFunc(target string, connectSSL bool, dialer Dialer) DialerFunc { +func newConnectDialFunc(target string, connectSSL bool, dialer netutil.Dialer) netutil.DialerFunc { return func(ctx context.Context, network, address string) (conn net.Conn, err error) { // TODO(skipor): make connect sample. // TODO(skipor): make httptrace callbacks called correctly. diff --git a/components/phttp/http.go b/components/phttp/http.go index 1b33c313a..f043a6d86 100644 --- a/components/phttp/http.go +++ b/components/phttp/http.go @@ -36,7 +36,7 @@ func NewHTTPGun(conf HTTPGunConfig) *HTTPGun { // NewHTTP2Gun return simple HTTP/2 gun that can shoot sequentially through one connection. func NewHTTP2Gun(conf HTTP2GunConfig) (*HTTPGun, error) { if !conf.Gun.SSL { - // Open issue on github if you need this feature. + // Open issue on github if you really need this feature. return nil, errors.New("HTTP/2.0 over TCP is not supported. Please leave SSL option true by default.") } transport := NewHTTP2Transport(conf.Client.Transport, NewDialer(conf.Client.Dialer).DialContext) diff --git a/components/phttp/import/import.go b/components/phttp/import/import.go index 5ea87f375..9733cf5bf 100644 --- a/components/phttp/import/import.go +++ b/components/phttp/import/import.go @@ -6,7 +6,10 @@ package phttp import ( + "net" + "github.com/spf13/afero" + "go.uber.org/zap" . "github.com/yandex/pandora/components/phttp" "github.com/yandex/pandora/components/phttp/ammo/simple/jsonline" @@ -14,6 +17,7 @@ import ( "github.com/yandex/pandora/components/phttp/ammo/simple/uri" "github.com/yandex/pandora/core" "github.com/yandex/pandora/core/register" + "github.com/yandex/pandora/lib/netutil" ) func Import(fs afero.Fs) { @@ -29,16 +33,56 @@ func Import(fs afero.Fs) { return raw.NewProvider(fs, conf) }) - register.Gun("http", func(conf HTTPGunConfig) core.Gun { - return WrapGun(NewHTTPGun(conf)) + register.Gun("http", func(conf HTTPGunConfig) func() core.Gun { + preResolveTargetAddr(&conf.Client, &conf.Gun.Target) + return func() core.Gun { return WrapGun(NewHTTPGun(conf)) } }, NewDefaultHTTPGunConfig) - register.Gun("http2", func(conf HTTP2GunConfig) (core.Gun, error) { - gun, err := NewHTTP2Gun(conf) - return WrapGun(gun), err + register.Gun("http2", func(conf HTTP2GunConfig) func() (core.Gun, error) { + preResolveTargetAddr(&conf.Client, &conf.Gun.Target) + return func() (core.Gun, error) { + gun, err := NewHTTP2Gun(conf) + return WrapGun(gun), err + } }, NewDefaultHTTP2GunConfig) - register.Gun("connect", func(conf ConnectGunConfig) core.Gun { - return WrapGun(NewConnectGun(conf)) + register.Gun("connect", func(conf ConnectGunConfig) func() core.Gun { + preResolveTargetAddr(&conf.Client, &conf.Target) + return func() core.Gun { + return WrapGun(NewConnectGun(conf)) + } }, NewDefaultConnectGunConfig) } + +// DNS resolve optimisation. +// When DNSCache turned off - do nothing extra, host will be resolved on every shoot. +// When using resolved target, don't use DNS caching logic - it is useless. +// If we can resolve accessible target addr - use it as target, not use caching. +// Otherwise just use DNS cache - we should not fail shooting, we should try to +// connect on every shoot. DNS cache will save resolved addr after first successful connect. +func preResolveTargetAddr(clientConf *ClientConfig, target *string) (err error) { + if !clientConf.Dialer.DNSCache { + return + } + if endpointIsResolved(*target) { + clientConf.Dialer.DNSCache = false + return + } + resolved, err := netutil.LookupReachable(*target) + if err != nil { + zap.L().Warn("DNS target pre resolve failed", + zap.String("target", *target), zap.Error(err)) + return + } + clientConf.Dialer.DNSCache = false + *target = resolved + return +} + +func endpointIsResolved(endpoint string) bool { + host, _, err := net.SplitHostPort(endpoint) + if err != nil { + return false + } + return net.ParseIP(host) != nil +} diff --git a/components/phttp/import/import_suite_test.go b/components/phttp/import/import_suite_test.go new file mode 100644 index 000000000..0c7f0f60d --- /dev/null +++ b/components/phttp/import/import_suite_test.go @@ -0,0 +1,72 @@ +package phttp + +import ( + "net" + "strconv" + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + "github.com/spf13/afero" + + . "github.com/yandex/pandora/components/phttp" + "github.com/yandex/pandora/lib/testutil" +) + +func TestImport(t *testing.T) { + testutil.RunSuite(t, "phttp Import Suite") +} + +var _ = Describe("import", func() { + It("not panics", func() { + Expect(func() { + Import(afero.NewOsFs()) + }).NotTo(Panic()) + }) +}) + +var _ = Describe("preResolveTargetAddr", func() { + It("host target", func() { + conf := &ClientConfig{} + conf.Dialer.DNSCache = true + + listener, err := net.ListenTCP("tcp4", nil) + defer listener.Close() + Expect(err).NotTo(HaveOccurred()) + + port := strconv.Itoa(listener.Addr().(*net.TCPAddr).Port) + target := "localhost:" + port + expectedResolved := "127.0.0.1:" + port + + err = preResolveTargetAddr(conf, &target) + Expect(err).NotTo(HaveOccurred()) + Expect(conf.Dialer.DNSCache).To(BeFalse()) + + Expect(target).To(Equal(expectedResolved)) + }) + + It("ip target", func() { + conf := &ClientConfig{} + conf.Dialer.DNSCache = true + + const addr = "127.0.0.1:80" + target := addr + err := preResolveTargetAddr(conf, &target) + Expect(err).NotTo(HaveOccurred()) + Expect(conf.Dialer.DNSCache).To(BeFalse()) + Expect(target).To(Equal(addr)) + }) + + It("failed", func() { + conf := &ClientConfig{} + conf.Dialer.DNSCache = true + + const addr = "localhost:54321" + target := addr + err := preResolveTargetAddr(conf, &target) + Expect(err).To(HaveOccurred()) + Expect(conf.Dialer.DNSCache).To(BeTrue()) + Expect(target).To(Equal(addr)) + }) + +}) diff --git a/core/config/config.go b/core/config/config.go index 2cd507f89..3f2203fc5 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -54,7 +54,7 @@ func AddKindHook(hook KindHook) (_ struct{}) { // Example: you need to configure only some subset fields of struct Multi, // in such case you can from this subset of fields struct Single, decode config // into it, and map it on Multi. -func Map(dst, src interface{}) error { +func Map(dst, src interface{}) { conf := &mapstructure.DecoderConfig{ ErrorUnused: true, ZeroFields: true, @@ -62,11 +62,14 @@ func Map(dst, src interface{}) error { } d, err := mapstructure.NewDecoder(conf) if err != nil { - return err + panic(err) } s := structs.New(src) s.TagName = "map" - return d.Decode(s.Map()) + err = d.Decode(s.Map()) + if err != nil { + panic(err) + } } func newDecoderConfig(result interface{}) *mapstructure.DecoderConfig { diff --git a/core/config/config_test.go b/core/config/config_test.go index 3ec92160b..a602c6bc1 100644 --- a/core/config/config_test.go +++ b/core/config/config_test.go @@ -151,13 +151,11 @@ type SingleString struct { func TestMapFlat(t *testing.T) { a := &MultiStrings{} - err := Map(a, &SingleString{B: "b"}) - require.NoError(t, err) + Map(a, &SingleString{B: "b"}) assert.Equal(t, &MultiStrings{B: "b"}, a) a = &MultiStrings{A: "a", B: "not b"} - err = Map(a, &SingleString{B: "b"}) - require.NoError(t, err) + Map(a, &SingleString{B: "b"}) assert.Equal(t, &MultiStrings{A: "a", B: "b"}, a) } @@ -170,8 +168,7 @@ func TestMapRecursive(t *testing.T) { MultiStrings } n := &N{MultiStrings: MultiStrings{B: "b"}, A: "a"} - err := Map(n, &M{MultiStrings: MultiStrings{A: "a"}}) - require.NoError(t, err) + Map(n, &M{MultiStrings: MultiStrings{A: "a"}}) assert.Equal(t, &N{A: "a", MultiStrings: MultiStrings{A: "a"}}, n) } @@ -184,8 +181,7 @@ func TestMapTagged(t *testing.T) { SomeOtherFieldName MultiStrings `map:"MultiStrings"` } n := &N{MultiStrings: MultiStrings{B: "b"}, A: "a"} - err := Map(n, &M{SomeOtherFieldName: MultiStrings{A: "a"}}) - require.NoError(t, err) + Map(n, &M{SomeOtherFieldName: MultiStrings{A: "a"}}) assert.Equal(t, &N{A: "a", MultiStrings: MultiStrings{A: "a"}}, n) } diff --git a/core/import/import_suite_test.go b/core/import/import_suite_test.go index cb41a2952..77a71b008 100644 --- a/core/import/import_suite_test.go +++ b/core/import/import_suite_test.go @@ -6,7 +6,6 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" - "github.com/onsi/gomega/format" "github.com/spf13/afero" "github.com/yandex/pandora/core" @@ -16,13 +15,8 @@ import ( ) func TestImport(t *testing.T) { - format.UseStringerRepresentation = true - RegisterFailHandler(Fail) - - testutil.ReplaceGlobalLogger() Import(afero.NewOsFs()) - - RunSpecs(t, "Import Suite") + testutil.RunSuite(t, "Import Suite") } var _ = Describe("plugin decode", func() { diff --git a/core/plugin/constructor.go b/core/plugin/constructor.go new file mode 100644 index 000000000..399befd7d --- /dev/null +++ b/core/plugin/constructor.go @@ -0,0 +1,192 @@ +// Copyright (c) 2017 Yandex LLC. All rights reserved. +// Use of this source code is governed by a MPL 2.0 +// license that can be found in the LICENSE file. +// Author: Vladimir Skipor + +package plugin + +import ( + "fmt" + "reflect" +) + +// implConstructor interface representing ability to create some plugin interface +// implementations and is't factory functions. Usually it wraps some newPlugin or newFactory function, and +// use is as implementation creator. +// implConstructor expects, that caller pass correct maybeConf value, that can be +// passed to underlying implementation creator. +type implConstructor interface { + // NewPlugin constructs plugin implementation. + NewPlugin(maybeConf []reflect.Value) (plugin interface{}, err error) + // getMaybeConf may be nil, if no config required. + // Underlying implementation creator may require new config for every instance create. + // If so, then getMaybeConf will be called on every instance create. Otherwise, only once. + NewFactory(factoryType reflect.Type, getMaybeConf func() ([]reflect.Value, error)) (pluginFactory interface{}, err error) +} + +func newImplConstructor(pluginType reflect.Type, constructor interface{}) implConstructor { + constructorType := reflect.TypeOf(constructor) + expect(constructorType.Kind() == reflect.Func, "plugin constructor should be func") + expect(constructorType.NumOut() >= 1, + "plugin constructor should return plugin implementation as first output parameter") + if constructorType.Out(0).Kind() == reflect.Func { + return newFactoryConstructor(pluginType, constructor) + } + return newPluginConstructor(pluginType, constructor) +} + +func newPluginConstructor(pluginType reflect.Type, newPlugin interface{}) *pluginConstructor { + expectPluginConstructor(pluginType, reflect.TypeOf(newPlugin), true) + return &pluginConstructor{pluginType, reflect.ValueOf(newPlugin)} +} + +// pluginConstructor use newPlugin func([config ]) ( [, error]) +// to construct plugin implementations +type pluginConstructor struct { + pluginType reflect.Type + newPlugin reflect.Value +} + +func (c *pluginConstructor) NewPlugin(maybeConf []reflect.Value) (plugin interface{}, err error) { + out := c.newPlugin.Call(maybeConf) + plugin = out[0].Interface() + if len(out) > 1 { + err, _ = out[1].Interface().(error) + } + return +} + +func (c *pluginConstructor) NewFactory(factoryType reflect.Type, getMaybeConf func() ([]reflect.Value, error)) (interface{}, error) { + if c.newPlugin.Type() == factoryType { + return c.newPlugin.Interface(), nil + } + return reflect.MakeFunc(factoryType, func(in []reflect.Value) []reflect.Value { + var maybeConf []reflect.Value + if getMaybeConf != nil { + var err error + maybeConf, err = getMaybeConf() + if err != nil { + switch factoryType.NumOut() { + case 1: + panic(err) + case 2: + return []reflect.Value{reflect.Zero(c.pluginType), reflect.ValueOf(&err).Elem()} + default: + panic(fmt.Sprintf(" out params num expeced to be 1 or 2, but have: %v", factoryType.NumOut())) + } + } + } + out := c.newPlugin.Call(maybeConf) + return convertFactoryOutParams(c.pluginType, factoryType.NumOut(), out) + }).Interface(), nil +} + +// factoryConstructor use newFactory func([config ]) (func() ([, error])[, error) +// to construct plugin implementations. +type factoryConstructor struct { + pluginType reflect.Type + newFactory reflect.Value +} + +func newFactoryConstructor(pluginType reflect.Type, newFactory interface{}) *factoryConstructor { + newFactoryType := reflect.TypeOf(newFactory) + expect(newFactoryType.Kind() == reflect.Func, "factory constructor should be func") + expect(newFactoryType.NumIn() <= 1, "factory constructor should accept config or nothing") + + expect(1 <= newFactoryType.NumOut() && newFactoryType.NumOut() <= 2, + "factory constructor should return factory, and optionally error") + if newFactoryType.NumOut() == 2 { + expect(newFactoryType.Out(1) == errorType, "factory constructor should have no second return value, or it should be error") + } + factoryType := newFactoryType.Out(0) + expectPluginConstructor(pluginType, factoryType, false) + return &factoryConstructor{pluginType, reflect.ValueOf(newFactory)} +} + +func (c *factoryConstructor) NewPlugin(maybeConf []reflect.Value) (plugin interface{}, err error) { + factory, err := c.callNewFactory(maybeConf) + if err != nil { + return nil, err + } + out := factory.Call(nil) + plugin = out[0].Interface() + if len(out) > 1 { + err, _ = out[1].Interface().(error) + } + return +} + +func (c *factoryConstructor) NewFactory(factoryType reflect.Type, getMaybeConf func() ([]reflect.Value, error)) (interface{}, error) { + var maybeConf []reflect.Value + if getMaybeConf != nil { + var err error + maybeConf, err = getMaybeConf() + if err != nil { + return nil, err + } + } + factory, err := c.callNewFactory(maybeConf) + if err != nil { + return nil, err + } + if factory.Type() == factoryType { + return factory.Interface(), nil + } + return reflect.MakeFunc(factoryType, func(in []reflect.Value) []reflect.Value { + out := factory.Call(nil) + return convertFactoryOutParams(c.pluginType, factoryType.NumOut(), out) + }).Interface(), nil +} + +func (c *factoryConstructor) callNewFactory(maybeConf []reflect.Value) (factory reflect.Value, err error) { + factoryAndMaybeErr := c.newFactory.Call(maybeConf) + if len(factoryAndMaybeErr) > 1 { + err, _ = factoryAndMaybeErr[1].Interface().(error) + } + return factoryAndMaybeErr[0], err +} + +// expectPluginConstructor checks type expectations common for newPlugin, and factory, returned from newFactory. +func expectPluginConstructor(pluginType, factoryType reflect.Type, configAllowed bool) { + expect(factoryType.Kind() == reflect.Func, "plugin constructor should be func") + if configAllowed { + expect(factoryType.NumIn() <= 1, "plugin constructor should accept config or nothing") + } else { + expect(factoryType.NumIn() == 0, "plugin constructor returned from newFactory, shouldn't accept any arguments") + } + expect(1 <= factoryType.NumOut() && factoryType.NumOut() <= 2, + "plugin constructor should return plugin implementation, and optionally error") + pluginImplType := factoryType.Out(0) + expect(pluginImplType.Implements(pluginType), "plugin constructor should implement plugin interface") + if factoryType.NumOut() == 2 { + expect(factoryType.Out(1) == errorType, "plugin constructor should have no second return value, or it should be error") + } +} + +// convertFactoryOutParams converts output params of some factory (newPlugin) call to required. +func convertFactoryOutParams(pluginType reflect.Type, numOut int, out []reflect.Value) []reflect.Value { + switch numOut { + case 1, 2: + // OK. + default: + panic(fmt.Sprintf("unexpeced out params num: %v; 1 or 2 expected", numOut)) + } + if out[0].Type() != pluginType { + // Not plugin, but its implementation. + impl := out[0] + out[0] = reflect.New(pluginType).Elem() + out[0].Set(impl) + } + if len(out) < numOut { + // Registered factory returns no error, but we should. + out = append(out, reflect.Zero(errorType)) + } + if numOut < len(out) { + // Registered factory returns error, but we should not. + if !out[1].IsNil() { + panic(out[1].Interface()) + } + out = out[:1] + } + return out +} diff --git a/core/plugin/constructor_test.go b/core/plugin/constructor_test.go new file mode 100644 index 000000000..f51b94401 --- /dev/null +++ b/core/plugin/constructor_test.go @@ -0,0 +1,360 @@ +// Copyright (c) 2017 Yandex LLC. All rights reserved. +// Use of this source code is governed by a MPL 2.0 +// license that can be found in the LICENSE file. +// Author: Vladimir Skipor + +package plugin + +import ( + "errors" + "fmt" + "reflect" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" +) + +var _ = Describe("plugin constructor", func() { + DescribeTable("expectations failed", + func(newPlugin interface{}) { + defer recoverExpectationFail() + newPluginConstructor(ptestType(), newPlugin) + }, + Entry("not func", + errors.New("that is not constructor")), + Entry("not implements", + func() struct{} { panic("") }), + Entry("too many args", + func(_, _ ptestConfig) ptestPlugin { panic("") }), + Entry("too many return valued", + func() (_ ptestPlugin, _, _ error) { panic("") }), + Entry("second return value is not error", + func() (_, _ ptestPlugin) { panic("") }), + ) + + Context("new plugin", func() { + newPlugin := func(newPlugin interface{}, maybeConf []reflect.Value) (interface{}, error) { + testee := newPluginConstructor(ptestType(), newPlugin) + return testee.NewPlugin(maybeConf) + } + + It("", func() { + plugin, err := newPlugin(ptestNew, nil) + Expect(err).NotTo(HaveOccurred()) + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("more that plugin", func() { + plugin, err := newPlugin(ptestNewMoreThan, nil) + Expect(err).NotTo(HaveOccurred()) + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("config", func() { + plugin, err := newPlugin(ptestNewConf, confToMaybe(ptestDefaultConf())) + Expect(err).NotTo(HaveOccurred()) + ptestExpectConfigValue(plugin, ptestDefaultValue) + }) + + It("failed", func() { + plugin, err := newPlugin(ptestNewErrFailing, nil) + Expect(err).To(Equal(ptestCreateFailedErr)) + Expect(plugin).To(BeNil()) + }) + }) + + Context("new factory", func() { + newFactoryOK := func(newPlugin interface{}, factoryType reflect.Type, getMaybeConf func() ([]reflect.Value, error)) interface{} { + testee := newPluginConstructor(ptestType(), newPlugin) + factory, err := testee.NewFactory(factoryType, getMaybeConf) + Expect(err).NotTo(HaveOccurred()) + return factory + } + + It("same type - no wrap", func() { + factory := newFactoryOK(ptestNew, ptestNewType(), nil) + expectSameFunc(factory, ptestNew) + }) + + It(" new impl", func() { + factory := newFactoryOK(ptestNewImpl, ptestNewType(), nil) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + plugin := f() + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("more than", func() { + factory := newFactoryOK(ptestNewMoreThan, ptestNewType(), nil) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + plugin := f() + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("add err", func() { + factory := newFactoryOK(ptestNew, ptestNewErrType(), nil) + f, ok := factory.(func() (ptestPlugin, error)) + Expect(ok).To(BeTrue()) + plugin, err := f() + Expect(err).NotTo(HaveOccurred()) + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("trim nil err", func() { + factory := newFactoryOK(ptestNewErr, ptestNewType(), nil) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + plugin := f() + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("config", func() { + factory := newFactoryOK(ptestNewConf, ptestNewType(), confToGetMaybe(ptestDefaultConf())) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + plugin := f() + ptestExpectConfigValue(plugin, ptestDefaultValue) + }) + + It("new factory, get config failed", func() { + factory := newFactoryOK(ptestNewConf, ptestNewErrType(), errToGetMaybe(ptestConfigurationFailedErr)) + f, ok := factory.(func() (ptestPlugin, error)) + Expect(ok).To(BeTrue()) + plugin, err := f() + Expect(err).To(Equal(ptestConfigurationFailedErr)) + Expect(plugin).To(BeNil()) + }) + + It("no err, get config failed, throw panic", func() { + factory := newFactoryOK(ptestNewConf, ptestNewType(), errToGetMaybe(ptestConfigurationFailedErr)) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + func() { + defer func() { + r := recover() + Expect(r).To(Equal(ptestConfigurationFailedErr)) + }() + f() + }() + }) + + It("panic on trim non nil err", func() { + factory := newFactoryOK(ptestNewErrFailing, ptestNewType(), nil) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + func() { + defer func() { + r := recover() + Expect(r).To(Equal(ptestCreateFailedErr)) + }() + f() + }() + }) + + }) +}) + +var _ = Describe("factory constructor", func() { + DescribeTable("expectations failed", + func(newPlugin interface{}) { + defer recoverExpectationFail() + newFactoryConstructor(ptestType(), newPlugin) + }, + Entry("not func", + errors.New("that is not constructor")), + Entry("returned not func", + func() error { panic("") }), + Entry("too many args", + func(_, _ ptestConfig) func() ptestPlugin { panic("") }), + Entry("too many return valued", + func() (func() ptestPlugin, error, error) { panic("") }), + Entry("second return value is not error", + func() (func() ptestPlugin, ptestPlugin) { panic("") }), + Entry("factory accepts conf", + func() func(config ptestConfig) ptestPlugin { panic("") }), + Entry("not implements", + func() func() struct{} { panic("") }), + Entry("factory too many args", + func() func(_, _ ptestConfig) ptestPlugin { panic("") }), + Entry("factory too many return valued", + func() func() (_ ptestPlugin, _, _ error) { panic("") }), + Entry("factory second return value is not error", + func() func() (_, _ ptestPlugin) { panic("") }), + ) + + Context("new plugin", func() { + newPlugin := func(newFactory interface{}, maybeConf []reflect.Value) (interface{}, error) { + testee := newFactoryConstructor(ptestType(), newFactory) + return testee.NewPlugin(maybeConf) + } + + It("", func() { + plugin, err := newPlugin(ptestNewFactory, nil) + Expect(err).NotTo(HaveOccurred()) + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("impl", func() { + plugin, err := newPlugin(ptestNewFactoryImpl, nil) + Expect(err).NotTo(HaveOccurred()) + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("impl more than", func() { + plugin, err := newPlugin(ptestNewFactoryMoreThan, nil) + Expect(err).NotTo(HaveOccurred()) + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("config", func() { + plugin, err := newPlugin(ptestNewFactoryConf, confToMaybe(ptestDefaultConf())) + Expect(err).NotTo(HaveOccurred()) + ptestExpectConfigValue(plugin, ptestDefaultValue) + }) + + It("failed", func() { + plugin, err := newPlugin(ptestNewFactoryErrFailing, nil) + Expect(err).To(Equal(ptestCreateFailedErr)) + Expect(plugin).To(BeNil()) + }) + + It("factory failed", func() { + plugin, err := newPlugin(ptestNewFactoryFactoryErrFailing, nil) + Expect(err).To(Equal(ptestCreateFailedErr)) + Expect(plugin).To(BeNil()) + }) + }) + + Context("new factory", func() { + newFactory := func(newFactory interface{}, factoryType reflect.Type, getMaybeConf func() ([]reflect.Value, error)) (interface{}, error) { + testee := newFactoryConstructor(ptestType(), newFactory) + return testee.NewFactory(factoryType, getMaybeConf) + } + newFactoryOK := func(newF interface{}, factoryType reflect.Type, getMaybeConf func() ([]reflect.Value, error)) interface{} { + factory, err := newFactory(newF, factoryType, getMaybeConf) + Expect(err).NotTo(HaveOccurred()) + return factory + } + + It("no err, same type - no wrap", func() { + factory := newFactoryOK(ptestNewFactory, ptestNewType(), nil) + expectSameFunc(factory, ptestNew) + }) + + It("has err, same type - no wrap", func() { + factory := newFactoryOK(ptestNewFactoryFactoryErr, ptestNewErrType(), nil) + expectSameFunc(factory, ptestNewErr) + }) + + It("from new impl", func() { + factory := newFactoryOK(ptestNewFactoryImpl, ptestNewType(), nil) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + plugin := f() + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("from new impl", func() { + factory := newFactoryOK(ptestNewFactoryMoreThan, ptestNewType(), nil) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + plugin := f() + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("add err", func() { + factory := newFactoryOK(ptestNewFactory, ptestNewErrType(), nil) + f, ok := factory.(func() (ptestPlugin, error)) + Expect(ok).To(BeTrue()) + plugin, err := f() + Expect(err).NotTo(HaveOccurred()) + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("factory construction not failed", func() { + factory := newFactoryOK(ptestNewFactoryErr, ptestNewType(), nil) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + plugin := f() + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("trim nil err", func() { + factory := newFactoryOK(ptestNewFactoryFactoryErr, ptestNewType(), nil) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + plugin := f() + ptestExpectConfigValue(plugin, ptestInitValue) + }) + + It("config", func() { + factory := newFactoryOK(ptestNewFactoryConf, ptestNewType(), confToGetMaybe(ptestDefaultConf())) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + plugin := f() + ptestExpectConfigValue(plugin, ptestDefaultValue) + }) + + It("get config failed", func() { + factory, err := newFactory(ptestNewFactoryConf, ptestNewErrType(), errToGetMaybe(ptestConfigurationFailedErr)) + Expect(err).To(Equal(ptestConfigurationFailedErr)) + Expect(factory).To(BeNil()) + }) + + It("factory create failed", func() { + factory, err := newFactory(ptestNewFactoryErrFailing, ptestNewErrType(), nil) + Expect(err).To(Equal(ptestCreateFailedErr)) + Expect(factory).To(BeNil()) + }) + + It("plugin create failed", func() { + factory := newFactoryOK(ptestNewFactoryFactoryErrFailing, ptestNewErrType(), nil) + f, ok := factory.(func() (ptestPlugin, error)) + Expect(ok).To(BeTrue()) + plugin, err := f() + Expect(err).To(Equal(ptestCreateFailedErr)) + Expect(plugin).To(BeNil()) + }) + + It("panic on trim non nil err", func() { + factory := newFactoryOK(ptestNewFactoryFactoryErrFailing, ptestNewType(), nil) + f, ok := factory.(func() ptestPlugin) + Expect(ok).To(BeTrue()) + func() { + defer func() { + r := recover() + Expect(r).To(Equal(ptestCreateFailedErr)) + }() + f() + }() + }) + + }) +}) + +func confToMaybe(conf interface{}) []reflect.Value { + if conf != nil { + return []reflect.Value{reflect.ValueOf(conf)} + } + return nil +} + +func confToGetMaybe(conf interface{}) func() ([]reflect.Value, error) { + return func() ([]reflect.Value, error) { + return confToMaybe(conf), nil + } +} + +func errToGetMaybe(err error) func() ([]reflect.Value, error) { + return func() ([]reflect.Value, error) { + return nil, err + } +} + +func expectSameFunc(f1, f2 interface{}) { + s1 := fmt.Sprint(f1) + s2 := fmt.Sprint(f2) + Expect(s1).To(Equal(s2)) +} diff --git a/core/plugin/doc.go b/core/plugin/doc.go new file mode 100644 index 000000000..59b90d8df --- /dev/null +++ b/core/plugin/doc.go @@ -0,0 +1,30 @@ +// Copyright (c) 2017 Yandex LLC. All rights reserved. +// Use of this source code is governed by a MPL 2.0 +// license that can be found in the LICENSE file. +// Author: Vladimir Skipor + +// Package plugin provides a generic inversion of control model for making +// extensible Go packages, libraries, and applications. Like +// github.com/progrium/go-extpoints, but reflect based: doesn't require code +// generation, but have more overhead; provide more flexibility, but less type +// safety. It allows to register constructor for some plugin interface, and create +// new plugin instances or plugin instance factories. +// Main feature is flexible plugin configuration: plugin factory can +// accept config struct, that could be filled by passed hook. Config default +// values could be provided by registering default config factory. +// Such flexibility can be used to decode structured text (json/yaml/etc) into +// struct. +// +// Type expectations. +// Here and bellow we mean by some type expectations. +// [some type signature part] means that this part of type signature is optional. +// +// Plugin type, let's label it as , should be interface. +// Registered plugin constructor should be one of: or . +// should have type func([config ]) ([, error]). +// should have type func([config [, error])[, error]). +// should be assignable to . +// That is, methods should be subset methods. In other words, should be +// some implementation, or interface, that contains methods as subset. +// type should be struct or struct pointer. +package plugin diff --git a/core/plugin/plugin.go b/core/plugin/plugin.go index 2882c7e0a..dcaa2ddfc 100644 --- a/core/plugin/plugin.go +++ b/core/plugin/plugin.go @@ -3,84 +3,44 @@ // license that can be found in the LICENSE file. // Author: Vladimir Skipor -// Package plugin provides a generic inversion of control model for making -// extensible Go packages, libraries, and applications. Like -// github.com/progrium/go-extpoints, but reflect based: doesn't require code -// generation, but have more overhead; provide more flexibility, but less type -// safety. It allows to register factory for some plugin interface, and create -// new plugin instances by registered factory. -// Main feature is flexible plugin configuration: plugin factory can -// accept config struct, that could be filled by passed hook. Config default -// values could be provided by registering default config factory. -// Such flexibility can be used to decode structured text (json/yaml/etc) into -// struct with plugin interface fields. -// -// Type expectations. -// Plugin factory type should be: -// func([config ]) ( [, error]) -// where configType kind is struct or struct pointer, and pluginImpl implements -// plugin interface. Plugin factory will never receive nil config, even there -// are no registered default config factory, or default config is nil. Config -// will be pointer to zero config in such case. -// If plugin factory receive config argument, default config factory can be -// registered. Default config factory type should be: is func() . -// Default config factory is optional. If no default config factory has been -// registered, than plugin factory will receive zero config (zero struct or -// pointer to zero struct). -// -// Note, that plugin interface type could be taken as reflect.TypeOf((*PluginInterface)(nil)).Elem(). package plugin import ( "fmt" "reflect" - - "github.com/pkg/errors" ) -// Register registers plugin factory and optional default config factory, -// for given plugin interface type and plugin name. -// See package doc for type expectations details. -// Register designed to be called in package init func, so it panics if type -// expectations were failed. Register is thread unsafe. +// DefaultRegistry returns default Registry used for package Registry like functions. +func DefaultRegistry() *Registry { + return defaultRegistry +} + +// Register is DefaultRegistry().Register shortcut. func Register( pluginType reflect.Type, name string, newPluginImpl interface{}, newDefaultConfigOptional ...interface{}, ) { - defaultRegistry.Register(pluginType, name, newPluginImpl, newDefaultConfigOptional...) + DefaultRegistry().Register(pluginType, name, newPluginImpl, newDefaultConfigOptional...) } -// Lookup returns true if any plugin factory has been registered for given -// type. +// Lookup is DefaultRegistry().Lookup shortcut. func Lookup(pluginType reflect.Type) bool { - return defaultRegistry.Lookup(pluginType) + return DefaultRegistry().Lookup(pluginType) } -func LookupFactory(pluginType reflect.Type) bool { - return defaultRegistry.LookupFactory(pluginType) +// LookupFactory is DefaultRegistry().LookupFactory shortcut. +func LookupFactory(factoryType reflect.Type) bool { + return DefaultRegistry().LookupFactory(factoryType) } -// New creates plugin by registered plugin factory. Returns error if creation -// failed or no plugin were registered for given type and name. -// Passed fillConf called on created config before calling plugin factory. -// fillConf argument is always valid struct pointer, even if plugin factory -// receives no config: fillConf is called on empty struct pointer in such case. -// fillConf error fails plugin creation. -// New is thread safe, if there is no concurrent Register calls. +// New is DefaultRegistry().New shortcut. func New(pluginType reflect.Type, name string, fillConfOptional ...func(conf interface{}) error) (plugin interface{}, err error) { return defaultRegistry.New(pluginType, name, fillConfOptional...) } -// TODO (skipor): add support for `func() PluginInterface` factories, that -// panics on error. -// TODO (skipor): add NewSharedConfigsFactory that decodes config once and use -// it to create all plugin instances. - -// NewFactory behaves like New, but creates factory func() (PluginInterface, error), that on call -// creates New plugin by registered factory. -// New config is created filled for every factory call. +// NewFactory is DefaultRegistry().NewFactory shortcut. func NewFactory(factoryType reflect.Type, name string, fillConfOptional ...func(conf interface{}) error) (factory interface{}, err error) { return defaultRegistry.NewFactory(factoryType, name, fillConfOptional...) } @@ -96,226 +56,38 @@ func PtrType(ptr interface{}) reflect.Type { return t.Elem() } -func IsFactoryType(t reflect.Type) bool { - return t.Kind() == reflect.Func && - t.NumIn() == 0 && - t.NumOut() == 2 && - t.Out(0).Kind() == reflect.Interface && - t.Out(1) == errorType -} - -func FactoryPluginType(factory reflect.Type) (plugin reflect.Type, ok bool) { - if IsFactoryType(factory) { - return factory.Out(0), true +// FactoryPluginType returns (SomeInterface, true) if factoryType looks like func() (SomeInterface[, error]) +// or (nil, false) otherwise. +func FactoryPluginType(factoryType reflect.Type) (plugin reflect.Type, ok bool) { + if isFactoryType(factoryType) { + return factoryType.Out(0), true } return } -type nameRegistryEntry struct { - // newPluginImpl type is func([config ]) ( [, error]), - // where configType kind is struct or struct pointer. - newPluginImpl reflect.Value - // newDefaultConfig type is func() . Zero if newPluginImpl accepts no arguments. - newDefaultConfig reflect.Value -} - -type nameRegistry map[string]nameRegistryEntry - -func newNameRegistry() nameRegistry { return make(nameRegistry) } - -type typeRegistry map[reflect.Type]nameRegistry - -var defaultRegistry = newTypeRegistry() - -func newTypeRegistry() typeRegistry { return make(typeRegistry) } - -func (r typeRegistry) Register( - pluginType reflect.Type, // plugin interface type - name string, - newPluginImpl interface{}, - newDefaultConfigOptional ...interface{}, -) { - expect(pluginType.Kind() == reflect.Interface, "plugin type should be interface, but have: %T", pluginType) - expect(name != "", "empty name") - pluginReg := r[pluginType] - if pluginReg == nil { - pluginReg = newNameRegistry() - r[pluginType] = pluginReg - } - _, ok := pluginReg[name] - expect(!ok, "plugin %s with name %q had been already registered", pluginType, name) - pluginReg[name] = newNameRegistryEntry(pluginType, newPluginImpl, newDefaultConfigOptional...) -} - -func (r typeRegistry) Lookup(pluginType reflect.Type) bool { - _, ok := r[pluginType] - return ok -} - -func (r typeRegistry) LookupFactory(factoryType reflect.Type) bool { - return IsFactoryType(factoryType) && r.Lookup(factoryType.Out(0)) -} - -func (r typeRegistry) New(pluginType reflect.Type, name string, fillConfOptional ...func(conf interface{}) error) (plugin interface{}, err error) { - expect(pluginType.Kind() == reflect.Interface, "plugin type should be interface, but have: %T", pluginType) - expect(name != "", "empty name") - fillConf := getFillConf(fillConfOptional) - registered, err := r.get(pluginType, name) - if err != nil { - return - } - confOptional, fillAddr := registered.NewDefaultConfig() - if fillConf != nil { - err = fillConf(fillAddr) - if err != nil { - return - } - } - return registered.NewPlugin(confOptional) -} - -func (r typeRegistry) NewFactory(factoryType reflect.Type, name string, fillConfOptional ...func(conf interface{}) error) (factory interface{}, err error) { - expect(IsFactoryType(factoryType), "plugin factory type should be like `func() (PluginInterface, error)`, but have: %T", factoryType) - expect(name != "", "empty name") - fillConf := getFillConf(fillConfOptional) - pluginType := factoryType.Out(0) - registered, err := r.get(pluginType, name) - if err != nil { - return - } - factory = reflect.MakeFunc(factoryType, func(in []reflect.Value) (out []reflect.Value) { - conf, fillAddr := registered.NewDefaultConfig() - if fillConf != nil { - // Check that config is correct. - err := fillConf(fillAddr) - if err != nil { - return []reflect.Value{reflect.Zero(pluginType), reflect.ValueOf(&err).Elem()} - } - } - out = registered.newPluginImpl.Call(conf) - if out[0].Type() != pluginType { - // Not plugin, but its implementation. - impl := out[0] - out[0] = reflect.New(pluginType).Elem() - out[0].Set(impl) - } - - if len(out) < 2 { - // Registered newPluginImpl can return no error, but we should. - out = append(out, reflect.Zero(errorType)) - } - return - }).Interface() - return -} - -func getFillConf(fillConfOptional []func(conf interface{}) error) func(interface{}) error { - expect(len(fillConfOptional) <= 1, "only fill config parameter could be passed") - if len(fillConfOptional) == 0 { - return nil - } - return fillConfOptional[0] -} - -func (e nameRegistryEntry) NewDefaultConfig() (confOptional []reflect.Value, fillAddr interface{}) { - if e.newPluginImpl.Type().NumIn() == 0 { - var emptyStruct struct{} - fillAddr = &emptyStruct // No fields to fill. - return +// isFactoryType returns true, if type looks like func() (SomeInterface[, error]) +func isFactoryType(t reflect.Type) bool { + hasProperParamsNum := t.Kind() == reflect.Func && + t.NumIn() == 0 && + (t.NumOut() == 1 || t.NumOut() == 2) + if !hasProperParamsNum { + return false } - conf := e.newDefaultConfig.Call(nil)[0] - switch conf.Kind() { - case reflect.Struct: - // Config can be filled only by pointer. - if !conf.CanAddr() { - // Can't address to pass pointer into decoder. Let's make New addressable! - newArg := reflect.New(conf.Type()).Elem() - newArg.Set(conf) - conf = newArg - } - fillAddr = conf.Addr().Interface() - case reflect.Ptr: - if conf.IsNil() { - // Can't fill nil config. Init with zero. - conf = reflect.New(conf.Type().Elem()) - } - fillAddr = conf.Interface() - default: - panic("unexpected type " + conf.String()) + if t.Out(0).Kind() != reflect.Interface { + return false } - confOptional = []reflect.Value{conf} - return -} - -func (e nameRegistryEntry) NewPlugin(confOptional []reflect.Value) (plugin interface{}, err error) { - out := e.newPluginImpl.Call(confOptional) - plugin = out[0].Interface() - if len(out) > 1 { - err, _ = out[1].Interface().(error) + if t.NumOut() == 1 { + return true } - return + return t.Out(1) == errorType } -func newNameRegistryEntry(pluginType reflect.Type, newPluginImpl interface{}, newDefaultConfigOptional ...interface{}) nameRegistryEntry { - newPluginImplType := reflect.TypeOf(newPluginImpl) - expect(newPluginImplType.Kind() == reflect.Func, "newPluginImpl should be func") - expect(newPluginImplType.NumIn() <= 1, "newPluginImple should accept config or nothing") - expect(1 <= newPluginImplType.NumOut() && newPluginImplType.NumOut() <= 2, - "newPluginImple should return plugin implementation, and optionally error") - pluginImplType := newPluginImplType.Out(0) - expect(pluginImplType.Implements(pluginType), "pluginImpl should implement plugin interface") - if newPluginImplType.NumOut() == 2 { - expect(newPluginImplType.Out(1) == errorType, "pluginImpl should have no second return value, or it should be error") - } - - if newPluginImplType.NumIn() == 0 { - expect(len(newDefaultConfigOptional) == 0, "newPluginImpl accept no config, but newDefaultConfig passed") - return nameRegistryEntry{ - newPluginImpl: reflect.ValueOf(newPluginImpl), - } - } - - expect(len(newDefaultConfigOptional) <= 1, "only one default config newPluginImpl could be passed") - configType := newPluginImplType.In(0) - expect(configType.Kind() == reflect.Struct || - configType.Kind() == reflect.Ptr && configType.Elem().Kind() == reflect.Struct, - "unexpected config kind: %s; should be struct or struct pointer or map") - - newDefaultConfigType := reflect.FuncOf(nil, []reflect.Type{configType}, false) - var newDefaultConfig interface{} - if len(newDefaultConfigOptional) != 0 { - newDefaultConfig = newDefaultConfigOptional[0] - expect(reflect.TypeOf(newDefaultConfig) == newDefaultConfigType, - "newDefaultConfig should be func that accepst nothing, and returns newPluiginImpl argument, but have type %T", newDefaultConfig) - } else { - newDefaultConfig = reflect.MakeFunc(newDefaultConfigType, - func(_ []reflect.Value) (results []reflect.Value) { - return []reflect.Value{reflect.Zero(configType)} - }).Interface() - } - return nameRegistryEntry{ - newPluginImpl: reflect.ValueOf(newPluginImpl), - newDefaultConfig: reflect.ValueOf(newDefaultConfig), - } -} +var defaultRegistry = NewRegistry() -func (r typeRegistry) get(pluginType reflect.Type, name string) (factory nameRegistryEntry, err error) { - pluginReg, ok := r[pluginType] - if !ok { - err = errors.Errorf("no plugins for type %s has been registered", pluginType) - return - } - factory, ok = pluginReg[name] - if !ok { - err = errors.Errorf("no plugins of type %s has been registered for name %s", pluginType, name) - } - return -} +var errorType = reflect.TypeOf((*error)(nil)).Elem() func expect(b bool, msg string, args ...interface{}) { if !b { panic(fmt.Sprintf("expectation failed: "+msg, args...)) } } - -var errorType = reflect.TypeOf((*error)(nil)).Elem() diff --git a/core/plugin/plugin_suite_test.go b/core/plugin/plugin_suite_test.go new file mode 100644 index 000000000..6eda14d0a --- /dev/null +++ b/core/plugin/plugin_suite_test.go @@ -0,0 +1,21 @@ +// Copyright (c) 2017 Yandex LLC. All rights reserved. +// Use of this source code is governed by a MPL 2.0 +// license that can be found in the LICENSE file. +// Author: Vladimir Skipor + +package plugin + +import ( + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/yandex/pandora/lib/testutil" +) + +func TestPlugin(t *testing.T) { + RegisterFailHandler(Fail) + testutil.ReplaceGlobalLogger() + RunSpecs(t, "Plugin Suite") +} diff --git a/core/plugin/plugin_test.go b/core/plugin/plugin_test.go index 27228b81f..a3c9e37b5 100644 --- a/core/plugin/plugin_test.go +++ b/core/plugin/plugin_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2016 Yandex LLC. All rights reserved. +// Copyright (c) 2017 Yandex LLC. All rights reserved. // Use of this source code is governed by a MPL 2.0 // license that can be found in the LICENSE file. // Author: Vladimir Skipor @@ -6,281 +6,43 @@ package plugin import ( - stderrors "errors" - "fmt" - "io" - "reflect" - "testing" - - "github.com/mitchellh/mapstructure" - "github.com/pkg/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestRegisterValid(t *testing.T) { - testCases := []struct { - description string - newPluginImpl interface{} - newDefaultConfigOptional []interface{} - }{ - {"return impl", func() *testPluginImpl { return nil }, nil}, - {"return interface", func() TestPlugin { return nil }, nil}, - {"super interface", func() interface { - io.Writer - TestPlugin - } { - return nil - }, nil}, - {"struct config", func(pluginImplConfig) TestPlugin { return nil }, nil}, - {"struct ptr config", func(*pluginImplConfig) TestPlugin { return nil }, nil}, - {"default config", func(*pluginImplConfig) TestPlugin { return nil }, []interface{}{func() *pluginImplConfig { return nil }}}, - } - for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - assert.NotPanics(t, func() { - newTypeRegistry().testRegister(tc.newPluginImpl, tc.newDefaultConfigOptional...) - }) - }) - } -} - -func TestRegisterInvalid(t *testing.T) { - testCases := []struct { - description string - newPluginImpl interface{} - newDefaultConfigOptional []interface{} - }{ - {"return not impl", func() testPluginImpl { panic("") }, nil}, - {"invalid config type", func(int) TestPlugin { return nil }, nil}, - {"invalid config ptr type", func(*int) TestPlugin { return nil }, nil}, - {"to many args", func(_, _ pluginImplConfig) TestPlugin { return nil }, nil}, - {"default without config", func() TestPlugin { return nil }, []interface{}{func() *pluginImplConfig { return nil }}}, - {"extra deafult config", func(*pluginImplConfig) TestPlugin { return nil }, []interface{}{func() *pluginImplConfig { return nil }, 0}}, - {"invalid default config", func(pluginImplConfig) TestPlugin { return nil }, []interface{}{func() *pluginImplConfig { return nil }}}, - {"default config accepts args", func(*pluginImplConfig) TestPlugin { return nil }, []interface{}{func(int) *pluginImplConfig { return nil }}}, - } - for _, tc := range testCases { - t.Run(tc.description, func(t *testing.T) { - defer assertExpectationFailed(t) - newTypeRegistry().testRegister(tc.newPluginImpl, tc.newDefaultConfigOptional...) - }) - } -} - -func TestRegisterNameCollisionPanics(t *testing.T) { - r := newTypeRegistry() - r.testRegister(newPlugin) - defer assertExpectationFailed(t) - r.testRegister(newPlugin) -} - -func TestLookup(t *testing.T) { - r := newTypeRegistry() - r.testRegister(newPlugin) - assert.True(t, r.Lookup(pluginType())) - assert.False(t, r.Lookup(reflect.TypeOf(0))) - assert.False(t, r.Lookup(reflect.TypeOf(&testPluginImpl{}))) - assert.False(t, r.Lookup(reflect.TypeOf((*io.Writer)(nil)).Elem())) -} - -func TestNew(t *testing.T) { - var r typeRegistry - - type New func(r typeRegistry, fillConfOptional ...func(conf interface{}) error) (interface{}, error) - var testNew New - testNewOk := func(fillConfOptional ...func(conf interface{}) error) (pluginVal string) { - plugin, err := testNew(r, fillConfOptional...) - require.NoError(t, err) - return plugin.(*testPluginImpl).Value - } - - tests := []struct { - desc string - fn func(t *testing.T) - }{ - {"no conf", func(t *testing.T) { - r.testRegister(newPlugin) - assert.Equal(t, testNewOk(), testInitValue) - }}, - {"nil error", func(t *testing.T) { - r.testRegister(func() (TestPlugin, error) { - return newPlugin(), nil - }) - assert.Equal(t, testNewOk(), testInitValue) - }}, - {"non-nil error", func(t *testing.T) { - expectedErr := stderrors.New("fill conf err") - r.testRegister(func() (TestPlugin, error) { - return nil, expectedErr - }) - _, err := testNew(r) - require.Error(t, err) - err = errors.Cause(err) - assert.Equal(t, expectedErr, err) - }}, - {"no conf, fill conf error", func(t *testing.T) { - r.testRegister(newPlugin) - expectedErr := stderrors.New("fill conf err") - _, err := testNew(r, func(_ interface{}) error { return expectedErr }) - assert.Equal(t, expectedErr, err) - }}, - {"no default", func(t *testing.T) { - r.testRegister(func(c pluginImplConfig) *testPluginImpl { return &testPluginImpl{c.Value} }) - assert.Equal(t, testNewOk(), "") - }}, - {"default", func(t *testing.T) { - r.testRegister(newPluginConf, newPluginDefaultConf) - assert.Equal(t, testNewOk(), testDefaultValue) - }}, - {"fill conf default", func(t *testing.T) { - r.testRegister(newPluginConf, newPluginDefaultConf) - assert.Equal(t, "conf", testNewOk(fillConf)) - }}, - {"fill conf no default", func(t *testing.T) { - r.testRegister(newPluginConf) - assert.Equal(t, "conf", testNewOk(fillConf)) - }}, - {"fill ptr conf no default", func(t *testing.T) { - r.testRegister(newPluginPtrConf) - assert.Equal(t, "conf", testNewOk(fillConf)) - }}, - {"no default ptr conf not nil", func(t *testing.T) { - r.testRegister(newPluginPtrConf) - assert.Equal(t, "", testNewOk()) - }}, - {"nil default, conf not nil", func(t *testing.T) { - r.testRegister(newPluginPtrConf, func() *pluginImplConfig { return nil }) - assert.Equal(t, "", testNewOk()) - }}, - {"fill nil default", func(t *testing.T) { - r.testRegister(newPluginPtrConf, func() *pluginImplConfig { return nil }) - assert.Equal(t, "conf", testNewOk(fillConf)) - }}, - {"more than one fill conf panics", func(t *testing.T) { - r.testRegister(newPluginPtrConf) - defer assertExpectationFailed(t) - testNew(r, fillConf, fillConf) - }}, - } - - for _, suite := range []struct { - new New - desc string - }{ - {typeRegistry.testNew, "New"}, - {typeRegistry.testNewFactory, "NewFactory"}, - } { - testNew = suite.new - for _, test := range tests { - r = newTypeRegistry() - t.Run(fmt.Sprintf("%s %s", suite.desc, test.desc), test.fn) - } - } - -} - -// Test typical usage. -func TestMapstructureDecode(t *testing.T) { - r := newTypeRegistry() - const nameKey = "type" - - var hook mapstructure.DecodeHookFunc - decode := func(input, result interface{}) error { - decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - DecodeHook: hook, - ErrorUnused: true, - Result: result, - }) - if err != nil { - return err - } - return decoder.Decode(input) - } - hook = mapstructure.ComposeDecodeHookFunc( - func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { - if !r.Lookup(to) { - return data, nil - } - // NOTE: could be map[interface{}]interface{} here. - input := data.(map[string]interface{}) - // NOTE: should be case insensitive behaviour. - pluginName := input[nameKey].(string) - delete(input, nameKey) - return r.New(to, pluginName, func(conf interface{}) error { - // NOTE: should error, if conf has "type" field. - return decode(input, conf) - }) - }) - - r.Register(pluginType(), "my-plugin", newPluginConf, newPluginDefaultConf) - input := map[string]interface{}{ - "plugin": map[string]interface{}{ - nameKey: "my-plugin", - "value": testConfValue, - }, - } - type Config struct { - Plugin TestPlugin - } - var conf Config - err := decode(input, &conf) - require.NoError(t, err) - assert.Equal(t, testConfValue, conf.Plugin.(*testPluginImpl).Value) -} - -const ( - testPluginName = "test_name" - testConfValue = "conf" - testDefaultValue = "default" - testInitValue = "init" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" ) -func (r typeRegistry) testRegister(newPluginImpl interface{}, newDefaultConfigOptional ...interface{}) { - r.Register(pluginType(), testPluginName, newPluginImpl, newDefaultConfigOptional...) -} - -func (r typeRegistry) testNew(fillConfOptional ...func(conf interface{}) error) (plugin interface{}, err error) { - return r.New(pluginType(), testPluginName, fillConfOptional...) -} - -func (r typeRegistry) testNewFactory(fillConfOptional ...func(conf interface{}) error) (plugin interface{}, err error) { - factory, err := r.NewFactory(pluginFactoryType(), testPluginName, fillConfOptional...) - if err != nil { - return - } - typedFactory := factory.(func() (TestPlugin, error)) - return typedFactory() -} - -type TestPlugin interface { - DoSomething() -} - -func pluginType() reflect.Type { return reflect.TypeOf((*TestPlugin)(nil)).Elem() } -func pluginFactoryType() reflect.Type { return reflect.TypeOf(func() (TestPlugin, error) { panic("") }) } -func newPlugin() *testPluginImpl { return &testPluginImpl{Value: testInitValue} } - -type testPluginImpl struct{ Value string } - -func (p *testPluginImpl) DoSomething() {} - -var _ TestPlugin = (*testPluginImpl)(nil) - -type pluginImplConfig struct{ Value string } - -func newPluginConf(c pluginImplConfig) *testPluginImpl { return &testPluginImpl{c.Value} } -func newPluginDefaultConf() pluginImplConfig { return pluginImplConfig{testDefaultValue} } -func newPluginPtrConf(c *pluginImplConfig) *testPluginImpl { - return &testPluginImpl{c.Value} -} - -func fillConf(conf interface{}) error { - return mapstructure.Decode(map[string]interface{}{"Value": "conf"}, conf) -} - -func assertExpectationFailed(t *testing.T) { - r := recover() - require.NotNil(t, r) - assert.Contains(t, r, "expectation failed") -} +var _ = Describe("Default registry", func() { + BeforeEach(func() { + Register(ptestType(), ptestPluginName, ptestNewImpl) + }) + AfterEach(func() { + defaultRegistry = NewRegistry() + }) + It("lookup", func() { + Expect(Lookup(ptestType())).To(BeTrue()) + }) + It("lookup factory", func() { + Expect(LookupFactory(ptestNewErrType())).To(BeTrue()) + }) + It("new", func() { + plugin, err := New(ptestType(), ptestPluginName) + Expect(err).NotTo(HaveOccurred()) + Expect(plugin).NotTo(BeNil()) + }) + It("new factory", func() { + pluginFactory, err := NewFactory(ptestNewErrType(), ptestPluginName) + Expect(err).NotTo(HaveOccurred()) + Expect(pluginFactory).NotTo(BeNil()) + }) +}) + +var _ = Describe("type helpers", func() { + It("ptr type", func() { + var plugin ptestPlugin + Expect(PtrType(&plugin)).To(Equal(ptestType())) + }) + It("factory plugin type ok", func() { + factoryPlugin, ok := FactoryPluginType(ptestNewErrType()) + Expect(ok).To(BeTrue()) + Expect(factoryPlugin).To(Equal(ptestType())) + }) +}) diff --git a/core/plugin/ptest_test.go b/core/plugin/ptest_test.go new file mode 100644 index 000000000..cfae16f84 --- /dev/null +++ b/core/plugin/ptest_test.go @@ -0,0 +1,109 @@ +// Copyright (c) 2017 Yandex LLC. All rights reserved. +// Use of this source code is governed by a MPL 2.0 +// license that can be found in the LICENSE file. +// Author: Vladimir Skipor + +package plugin + +import ( + "reflect" + + . "github.com/onsi/gomega" + "github.com/pkg/errors" + + "github.com/yandex/pandora/core/config" +) + +// ptest contains examples and utils for testing plugin pkg + +const ( + ptestPluginName = "ptest_name" + + ptestInitValue = "ptest_INITIAL" + ptestDefaultValue = "ptest_DEFAULT_CONFIG" + ptestFilledValue = "ptest_FILLED" +) + +func (r *Registry) ptestRegister(constructor interface{}, newDefaultConfigOptional ...interface{}) { + r.Register(ptestType(), ptestPluginName, constructor, newDefaultConfigOptional...) +} +func (r *Registry) ptestNew(fillConfOptional ...func(conf interface{}) error) (plugin interface{}, err error) { + return r.New(ptestType(), ptestPluginName, fillConfOptional...) +} +func (r *Registry) ptestNewFactory(fillConfOptional ...func(conf interface{}) error) (plugin interface{}, err error) { + factory, err := r.NewFactory(ptestNewErrType(), ptestPluginName, fillConfOptional...) + if err != nil { + return + } + typedFactory := factory.(func() (ptestPlugin, error)) + return typedFactory() +} + +var ( + ptestCreateFailedErr = errors.New("test plugin create failed") + ptestConfigurationFailedErr = errors.New("test plugin configuration failed") +) + +type ptestPlugin interface { + DoSomething() +} +type ptestMoreThanPlugin interface { + ptestPlugin + DoSomethingElse() +} +type ptestImpl struct{ Value string } +type ptestConfig struct{ Value string } + +func (p *ptestImpl) DoSomething() {} +func (p *ptestImpl) DoSomethingElse() {} + +func ptestNew() ptestPlugin { return ptestNewImpl() } +func ptestNewMoreThan() ptestMoreThanPlugin { return ptestNewImpl() } +func ptestNewImpl() *ptestImpl { return &ptestImpl{Value: ptestInitValue} } +func ptestNewConf(c ptestConfig) ptestPlugin { return &ptestImpl{c.Value} } +func ptestNewPtrConf(c *ptestConfig) ptestPlugin { return &ptestImpl{c.Value} } +func ptestNewErr() (ptestPlugin, error) { return &ptestImpl{Value: ptestInitValue}, nil } +func ptestNewErrFailing() (ptestPlugin, error) { return nil, ptestCreateFailedErr } + +func ptestNewFactory() func() ptestPlugin { return ptestNew } +func ptestNewFactoryMoreThan() func() ptestMoreThanPlugin { return ptestNewMoreThan } +func ptestNewFactoryImpl() func() *ptestImpl { return ptestNewImpl } +func ptestNewFactoryConf(c ptestConfig) func() ptestPlugin { + return func() ptestPlugin { + return ptestNewConf(c) + } +} +func ptestNewFactoryPtrConf(c *ptestConfig) func() ptestPlugin { + return func() ptestPlugin { + return ptestNewPtrConf(c) + } +} +func ptestNewFactoryErr() (func() ptestPlugin, error) { return ptestNew, nil } +func ptestNewFactoryErrFailing() (func() ptestPlugin, error) { return nil, ptestCreateFailedErr } +func ptestNewFactoryFactoryErr() func() (ptestPlugin, error) { return ptestNewErr } +func ptestNewFactoryFactoryErrFailing() func() (ptestPlugin, error) { return ptestNewErrFailing } + +func ptestDefaultConf() ptestConfig { return ptestConfig{ptestDefaultValue} } +func ptestNewDefaultPtrConf() *ptestConfig { return &ptestConfig{ptestDefaultValue} } + +func ptestType() reflect.Type { return PtrType((*ptestPlugin)(nil)) } +func ptestNewErrType() reflect.Type { return reflect.TypeOf(ptestNewErr) } +func ptestNewType() reflect.Type { return reflect.TypeOf(ptestNew) } + +func ptestFillConf(conf interface{}) error { + return config.Decode(map[string]interface{}{"Value": ptestFilledValue}, conf) +} + +func ptestExpectConfigValue(conf interface{}, val string) { + conf.(ptestConfChecker).expectConfValue(val) +} + +type ptestConfChecker interface { + expectConfValue(string) +} + +var _ ptestConfChecker = ptestConfig{} +var _ ptestConfChecker = &ptestImpl{} + +func (c ptestConfig) expectConfValue(val string) { Expect(c.Value).To(Equal(val)) } +func (p *ptestImpl) expectConfValue(val string) { Expect(p.Value).To(Equal(val)) } diff --git a/core/plugin/registry.go b/core/plugin/registry.go new file mode 100644 index 000000000..96c14022d --- /dev/null +++ b/core/plugin/registry.go @@ -0,0 +1,246 @@ +// Copyright (c) 2017 Yandex LLC. All rights reserved. +// Use of this source code is governed by a MPL 2.0 +// license that can be found in the LICENSE file. +// Author: Vladimir Skipor + +package plugin + +import ( + "reflect" + + "github.com/pkg/errors" +) + +func NewRegistry() *Registry { + return &Registry{make(map[reflect.Type]nameRegistry)} +} + +type Registry struct { + typeToNameReg map[reflect.Type]nameRegistry +} + +func newNameRegistry() nameRegistry { return make(nameRegistry) } + +type nameRegistry map[string]nameRegistryEntry + +type nameRegistryEntry struct { + constructor implConstructor + defaultConfig defaultConfigContainer +} + +// Register registers plugin constructor and optional default config factory, +// for given plugin interface type and plugin name. +// See package doc for type expectations details. +// Register designed to be called in package init func, so it panics if something go wrong. +// Panics if type expectations are violated. +// Panics if some constructor have been already registered for this (pluginType, name) pair. +// Register is thread unsafe. +// +// If constructor receive config argument, default config factory can be +// registered. Default config factory type should be: is func() . +// Default config factory is optional. If no default config factory has been +// registered, than plugin factory will receive zero config (zero struct or +// pointer to zero struct). +// Registered constructor will never receive nil config, even there +// are no registered default config factory, or default config is nil. Config +// will be pointer to zero config in such case. +func (r *Registry) Register( + pluginType reflect.Type, + name string, + constructor interface{}, + newDefaultConfigOptional ...interface{}, // default config factory, or nothing. +) { + expect(pluginType.Kind() == reflect.Interface, "plugin type should be interface, but have: %T", pluginType) + expect(name != "", "empty name") + nameReg := r.typeToNameReg[pluginType] + if nameReg == nil { + nameReg = newNameRegistry() + r.typeToNameReg[pluginType] = nameReg + } + _, ok := nameReg[name] + expect(!ok, "plugin %s with name %q had been already registered", pluginType, name) + newDefaultConfig := getNewDefaultConfig(newDefaultConfigOptional) + nameReg[name] = newNameRegistryEntry(pluginType, constructor, newDefaultConfig) +} + +// Lookup returns true if any plugin constructor has been registered for given +// type. +func (r *Registry) Lookup(pluginType reflect.Type) bool { + _, ok := r.typeToNameReg[pluginType] + return ok +} + +// LookupFactory returns true if factoryType looks like func() (SomeInterface[, error]) +// and any plugin constructor has been registered for SomeInterface. +// That is, you may create instance of this factoryType using this registry. +func (r *Registry) LookupFactory(factoryType reflect.Type) bool { + return isFactoryType(factoryType) && r.Lookup(factoryType.Out(0)) +} + +// New creates plugin using registered plugin constructor. Returns error if creation +// failed or no plugin were registered for given type and name. +// Passed fillConf called on created config before calling plugin factory. +// fillConf argument is always valid struct pointer, even if plugin factory +// receives no config: fillConf is called on empty struct pointer in such case. +// fillConf error fails plugin creation. +// New is thread safe, if there is no concurrent Register calls. +func (r *Registry) New(pluginType reflect.Type, name string, fillConfOptional ...func(conf interface{}) error) (plugin interface{}, err error) { + expect(pluginType.Kind() == reflect.Interface, "plugin type should be interface, but have: %T", pluginType) + expect(name != "", "empty name") + fillConf := getFillConf(fillConfOptional) + registered, err := r.get(pluginType, name) + if err != nil { + return + } + conf, err := registered.defaultConfig.Get(fillConf) + if err != nil { + return nil, err + } + return registered.constructor.NewPlugin(conf) +} + +// NewFactory behaves like New, but creates factory func() (PluginInterface[, error]), that on call +// creates New plugin by registered factory. +// If registered constructor is config is created filled for every factory call, +// if . + newValue reflect.Value +} + +func newDefaultConfigContainer(constructorType reflect.Type, newDefaultConfig interface{}) defaultConfigContainer { + if constructorType.NumIn() == 0 { + expect(newDefaultConfig == nil, "constructor accept no config, but newDefaultConfig passed") + return defaultConfigContainer{} + } + expect(constructorType.NumIn() == 1, "constructor should accept zero or one argument") + configType := constructorType.In(0) + expect(configType.Kind() == reflect.Struct || + configType.Kind() == reflect.Ptr && configType.Elem().Kind() == reflect.Struct, + "unexpected config kind: %s; should be struct or struct pointer") + newDefaultConfigType := reflect.FuncOf(nil, []reflect.Type{configType}, false) + if newDefaultConfig == nil { + value := reflect.MakeFunc(newDefaultConfigType, + func(_ []reflect.Value) (results []reflect.Value) { + // OPTIMIZE: create addressable. + return []reflect.Value{reflect.Zero(configType)} + }) + return defaultConfigContainer{value} + } + value := reflect.ValueOf(newDefaultConfig) + expect(value.Type() == newDefaultConfigType, + "newDefaultConfig should be func that accepts nothing, and returns constructor argument, but have type %T", newDefaultConfig) + return defaultConfigContainer{value} +} + +// In reflect pkg []Value used to call functions. It's easier to return it, that convert from pointer when needed. +func (e defaultConfigContainer) Get(fillConf func(fillAddr interface{}) error) (maybeConf []reflect.Value, err error) { + var fillAddr interface{} + if e.configRequired() { + maybeConf, fillAddr = e.new() + } else { + var emptyStruct struct{} + fillAddr = &emptyStruct // No fields to fill. + } + if fillConf != nil { + err = fillConf(fillAddr) + if err != nil { + return nil, err + } + } + return +} + +func (e defaultConfigContainer) new() (maybeConf []reflect.Value, fillAddr interface{}) { + if !e.configRequired() { + panic("try to create config when not required") + } + conf := e.newValue.Call(nil)[0] + switch conf.Kind() { + case reflect.Struct: + // Config can be filled only by pointer. + if !conf.CanAddr() { + // Can't address to pass pointer into decoder. Let's make New addressable! + newArg := reflect.New(conf.Type()).Elem() + newArg.Set(conf) + conf = newArg + } + fillAddr = conf.Addr().Interface() + case reflect.Ptr: + if conf.IsNil() { + // Can't fill nil config. Init with zero. + conf = reflect.New(conf.Type().Elem()) + } + fillAddr = conf.Interface() + default: + panic("unexpected type " + conf.String()) + } + maybeConf = []reflect.Value{conf} + return +} + +func (e defaultConfigContainer) configRequired() bool { + return e.newValue.IsValid() +} + +func getFillConf(fillConfOptional []func(conf interface{}) error) func(interface{}) error { + expect(len(fillConfOptional) <= 1, "only fill config parameter could be passed") + if len(fillConfOptional) == 0 { + return nil + } + return fillConfOptional[0] +} + +func getNewDefaultConfig(newDefaultConfigOptional []interface{}) interface{} { + expect(len(newDefaultConfigOptional) <= 1, "too many arguments passed") + if len(newDefaultConfigOptional) == 0 { + return nil + } + return newDefaultConfigOptional[0] +} diff --git a/core/plugin/registry_test.go b/core/plugin/registry_test.go new file mode 100644 index 000000000..5eb1da323 --- /dev/null +++ b/core/plugin/registry_test.go @@ -0,0 +1,301 @@ +// Copyright (c) 2016 Yandex LLC. All rights reserved. +// Use of this source code is governed by a MPL 2.0 +// license that can be found in the LICENSE file. +// Author: Vladimir Skipor + +package plugin + +import ( + "io" + "reflect" + + "github.com/mitchellh/mapstructure" + . "github.com/onsi/ginkgo" + . "github.com/onsi/ginkgo/extensions/table" + . "github.com/onsi/gomega" + "github.com/pkg/errors" +) + +var _ = Describe("new default config container", func() { + DescribeTable("expectation fail", + func(constructor interface{}, newDefaultConfigOptional ...interface{}) { + newDefaultConfig := getNewDefaultConfig(newDefaultConfigOptional) + defer recoverExpectationFail() + newDefaultConfigContainer(reflect.TypeOf(constructor), newDefaultConfig) + }, + Entry("invalid type", + func(int) ptestPlugin { return nil }), + Entry("invalid ptr type", + func(*int) ptestPlugin { return nil }), + Entry("to many args", + func(_, _ ptestConfig) ptestPlugin { return nil }), + Entry("default without config", + func() ptestPlugin { return nil }, func() *ptestConfig { return nil }), + Entry("invalid default config", + func(ptestConfig) ptestPlugin { return nil }, func() *ptestConfig { return nil }), + Entry("default config accepts args", + func(*ptestConfig) ptestPlugin { return nil }, func(int) *ptestConfig { return nil }), + ) + + DescribeTable("expectation ok", + func(constructor interface{}, newDefaultConfigOptional ...interface{}) { + newDefaultConfig := getNewDefaultConfig(newDefaultConfigOptional) + container := newDefaultConfigContainer(reflect.TypeOf(constructor), newDefaultConfig) + conf, err := container.Get(ptestFillConf) + Expect(err).NotTo(HaveOccurred()) + Expect(conf).To(HaveLen(1)) + ptestExpectConfigValue(conf[0].Interface(), ptestFilledValue) + }, + Entry("no default config", + ptestNewConf), + Entry("no default ptr config", + ptestNewPtrConf), + Entry("default config", + ptestNewConf, ptestDefaultConf), + Entry("default ptr config", + ptestNewPtrConf, ptestNewDefaultPtrConf), + ) + + It("fill no config failed", func() { + container := newDefaultConfigContainer(ptestNewErrType(), nil) + _, err := container.Get(ptestFillConf) + Expect(err).To(HaveOccurred()) + }) +}) + +var _ = Describe("registry", func() { + It("register name collision panics", func() { + r := NewRegistry() + r.ptestRegister(ptestNewImpl) + defer recoverExpectationFail() + r.ptestRegister(ptestNewImpl) + }) + + It("lookup", func() { + r := NewRegistry() + r.ptestRegister(ptestNewImpl) + Expect(r.Lookup(ptestType())).To(BeTrue()) + Expect(r.Lookup(reflect.TypeOf(0))).To(BeFalse()) + Expect(r.Lookup(reflect.TypeOf(&ptestImpl{}))).To(BeFalse()) + Expect(r.Lookup(reflect.TypeOf((*io.Writer)(nil)).Elem())).To(BeFalse()) + }) + + It("lookup factory", func() { + r := NewRegistry() + r.ptestRegister(ptestNewImpl) + Expect(r.LookupFactory(ptestNewType())).To(BeTrue()) + Expect(r.LookupFactory(ptestNewErrType())).To(BeTrue()) + + Expect(r.LookupFactory(reflect.TypeOf(0))).To(BeFalse()) + Expect(r.LookupFactory(reflect.TypeOf(&ptestImpl{}))).To(BeFalse()) + Expect(r.LookupFactory(reflect.TypeOf((*io.Writer)(nil)).Elem())).To(BeFalse()) + }) + +}) + +var _ = Describe("new", func() { + type New func(r *Registry, fillConfOptional ...func(conf interface{}) error) (interface{}, error) + var ( + r *Registry + testNew New + testNewOk = func(fillConfOptional ...func(conf interface{}) error) (pluginVal string) { + plugin, err := testNew(r, fillConfOptional...) + Expect(err).NotTo(HaveOccurred()) + return plugin.(*ptestImpl).Value + } + ) + BeforeEach(func() { r = NewRegistry() }) + runTestCases := func() { + Context("plugin constructor", func() { + It("no conf", func() { + r.ptestRegister(ptestNewImpl) + Expect(testNewOk()).To(Equal(ptestInitValue)) + }) + It("nil error", func() { + r.ptestRegister(ptestNewErr) + Expect(testNewOk()).To(Equal(ptestInitValue)) + }) + It("non-nil error", func() { + r.ptestRegister(ptestNewErrFailing) + _, err := testNew(r) + Expect(err).To(HaveOccurred()) + err = errors.Cause(err) + Expect(ptestCreateFailedErr).To(Equal(err)) + }) + It("no conf, fill conf error", func() { + r.ptestRegister(ptestNewImpl) + expectedErr := errors.New("fill conf err") + _, err := testNew(r, func(_ interface{}) error { return expectedErr }) + Expect(expectedErr).To(Equal(err)) + }) + It("no default", func() { + r.ptestRegister(ptestNewConf) + Expect(testNewOk()).To(Equal("")) + }) + It("default", func() { + r.ptestRegister(ptestNewConf, ptestDefaultConf) + Expect(testNewOk()).To(Equal(ptestDefaultValue)) + }) + It("fill conf default", func() { + r.ptestRegister(ptestNewConf, ptestDefaultConf) + Expect(testNewOk(ptestFillConf)).To(Equal(ptestFilledValue)) + }) + It("fill conf no default", func() { + r.ptestRegister(ptestNewConf) + Expect(testNewOk(ptestFillConf)).To(Equal(ptestFilledValue)) + }) + It("fill ptr conf no default", func() { + r.ptestRegister(ptestNewPtrConf) + Expect(testNewOk(ptestFillConf)).To(Equal(ptestFilledValue)) + }) + It("no default ptr conf not nil", func() { + r.ptestRegister(ptestNewPtrConf) + Expect("").To(Equal(testNewOk())) + }) + It("nil default, conf not nil", func() { + r.ptestRegister(ptestNewPtrConf, func() *ptestConfig { return nil }) + Expect(testNewOk()).To(Equal("")) + }) + It("fill nil default", func() { + r.ptestRegister(ptestNewPtrConf, func() *ptestConfig { return nil }) + Expect(testNewOk(ptestFillConf)).To(Equal(ptestFilledValue)) + }) + It("more than one fill conf panics", func() { + r.ptestRegister(ptestNewPtrConf) + defer recoverExpectationFail() + testNew(r, ptestFillConf, ptestFillConf) + }) + }) + + Context("factory constructor", func() { + It("no conf", func() { + r.ptestRegister(ptestNewFactory) + Expect(testNewOk()).To(Equal(ptestInitValue)) + }) + It("nil error", func() { + r.ptestRegister(func() (ptestPlugin, error) { + return ptestNewImpl(), nil + }) + Expect(testNewOk()).To(Equal(ptestInitValue)) + }) + It("non-nil error", func() { + r.ptestRegister(ptestNewFactoryFactoryErrFailing) + _, err := testNew(r) + Expect(err).To(HaveOccurred()) + err = errors.Cause(err) + Expect(ptestCreateFailedErr).To(Equal(err)) + }) + It("no conf, fill conf error", func() { + r.ptestRegister(ptestNewFactory) + expectedErr := errors.New("fill conf err") + _, err := testNew(r, func(_ interface{}) error { return expectedErr }) + Expect(expectedErr).To(Equal(err)) + }) + It("no default", func() { + r.ptestRegister(ptestNewFactoryConf) + Expect(testNewOk()).To(Equal("")) + }) + It("default", func() { + r.ptestRegister(ptestNewFactoryConf, ptestDefaultConf) + Expect(testNewOk()).To(Equal(ptestDefaultValue)) + }) + It("fill conf default", func() { + r.ptestRegister(ptestNewFactoryConf, ptestDefaultConf) + Expect(testNewOk(ptestFillConf)).To(Equal(ptestFilledValue)) + }) + It("fill conf no default", func() { + r.ptestRegister(ptestNewFactoryConf) + Expect(testNewOk(ptestFillConf)).To(Equal(ptestFilledValue)) + }) + It("fill ptr conf no default", func() { + r.ptestRegister(ptestNewFactoryPtrConf) + Expect(testNewOk(ptestFillConf)).To(Equal(ptestFilledValue)) + }) + It("no default ptr conf not nil", func() { + r.ptestRegister(ptestNewFactoryPtrConf) + Expect("").To(Equal(testNewOk())) + }) + It("nil default, conf not nil", func() { + r.ptestRegister(ptestNewFactoryPtrConf, func() *ptestConfig { return nil }) + Expect(testNewOk()).To(Equal("")) + }) + It("fill nil default", func() { + r.ptestRegister(ptestNewFactoryPtrConf, func() *ptestConfig { return nil }) + Expect(testNewOk(ptestFillConf)).To(Equal(ptestFilledValue)) + }) + It("more than one fill conf panics", func() { + r.ptestRegister(ptestNewFactoryPtrConf) + defer recoverExpectationFail() + testNew(r, ptestFillConf, ptestFillConf) + }) + }) + } + Context("use New", func() { + BeforeEach(func() { testNew = (*Registry).ptestNew }) + runTestCases() + + }) + Context("use NewFactory", func() { + BeforeEach(func() { testNew = (*Registry).ptestNewFactory }) + runTestCases() + }) + +}) + +var _ = Describe("decode", func() { + It("ok", func() { + r := NewRegistry() + const nameKey = "type" + + var hook mapstructure.DecodeHookFunc + decode := func(input, result interface{}) error { + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: hook, + ErrorUnused: true, + Result: result, + }) + if err != nil { + return err + } + return decoder.Decode(input) + } + hook = mapstructure.ComposeDecodeHookFunc( + func(from reflect.Type, to reflect.Type, data interface{}) (interface{}, error) { + if !r.Lookup(to) { + return data, nil + } + // NOTE: could be map[interface{}]interface{} here. + input := data.(map[string]interface{}) + // NOTE: should be case insensitive behaviour. + pluginName := input[nameKey].(string) + delete(input, nameKey) + return r.New(to, pluginName, func(conf interface{}) error { + // NOTE: should error, if conf has "type" field. + return decode(input, conf) + }) + }) + + r.Register(ptestType(), "my-plugin", ptestNewConf, ptestDefaultConf) + input := map[string]interface{}{ + "plugin": map[string]interface{}{ + nameKey: "my-plugin", + "value": ptestFilledValue, + }, + } + type Config struct { + Plugin ptestPlugin + } + var conf Config + err := decode(input, &conf) + Expect(err).NotTo(HaveOccurred()) + actualValue := conf.Plugin.(*ptestImpl).Value + Expect(actualValue).To(Equal(ptestFilledValue)) + }) + +}) + +func recoverExpectationFail() { + r := recover() + Expect(r).NotTo(BeNil()) + Expect(r).To(ContainSubstring("expectation failed")) +} diff --git a/glide.lock b/glide.lock index 9bf34c575..0735d8eef 100644 --- a/glide.lock +++ b/glide.lock @@ -1,10 +1,10 @@ -hash: b06d35c760b8a60202cc52695b1dd3d3ac96acdbbb9a8cd52c0cba9a05880733 -updated: 2017-08-25T10:25:48.281437192+03:00 +hash: e9c64f4ab9d84f20f3ea357eb84c21becc99b134765591aec67efddbae7d46f0 +updated: 2017-09-30T22:54:19.407523708+03:00 imports: - name: github.com/amahi/spdy version: 31da8b754faf6833fa95192c69bdb10518e9fb7b - name: github.com/asaskevich/govalidator - version: 15028e809df8c71964e8efa6c11e81d5c0262302 + version: 6fcd5b427f532a5d13738b27415e00a49e36ceef - name: github.com/c2h5oh/datasize version: 54516c931ae99c3c74637b9ea2390cf9a6327f26 - name: github.com/davecgh/go-spew @@ -20,7 +20,7 @@ imports: - name: github.com/fsnotify/fsnotify version: 4da3e2cfbabc9f751898f250b49f2439785783a1 - name: github.com/hashicorp/hcl - version: 392dba7d905ed5d04a5794ba89f558b27e2ba1ca + version: 68e816d1c783414e79bc65b3994d9ab6b0a722ab subpackages: - hcl/ast - hcl/parser @@ -31,12 +31,12 @@ imports: - json/scanner - json/token - name: github.com/magiconair/properties - version: be5ece7dd465ab0765a9682137865547526d1dfb + version: 8d7837e64d3c1ee4e54a880c5a920ab4316fc90a - name: github.com/mitchellh/mapstructure version: e88fb6b4946b282e0d5196ac35b09d256e09e9d2 repo: https://github.com/skipor/mapstructure - name: github.com/onsi/ginkgo - version: 8382b23d18dbaaff8e5f7e83784c53ebb8ec2f47 + version: 11459a886d9cd66b319dac7ef1e917ee221372c9 subpackages: - config - extensions/table @@ -57,7 +57,7 @@ imports: - reporters/stenographer/support/go-isatty - types - name: github.com/onsi/gomega - version: c893efa28eb45626cdaa76c9f653b62488858837 + version: dcabb60a477c2b6f456df65037cb6708210fbb02 subpackages: - format - gbytes @@ -75,30 +75,30 @@ imports: - matchers/support/goraph/util - types - name: github.com/pelletier/go-toml - version: 9c1b4e331f1e3d98e72600677699fbe212cd6d16 + version: 16398bac157da96aa88f98a2df640c7f32af1da2 - name: github.com/pkg/errors - version: c605e284fe17294bda444b34710735b29d1a9d90 + version: 2b3a18b5f0fb6b4f9190549597d3f962c02bc5eb - name: github.com/pmezard/go-difflib version: d8ed2627bdf02c080bf22230dbb337003b7aba2d subpackages: - difflib - name: github.com/pquerna/ffjson - version: 9a5203b7a07166f217f5f8177d5b16177acad3b2 + version: 619064c2092f8ed31957bf7af846f1af0529b737 subpackages: - fflib/v1 - fflib/v1/internal - name: github.com/spf13/afero - version: 9be650865eab0c12963d8753212f4f9c66cdcf12 + version: ee1bd8ee15a1306d1f9201acc41ef39cd9f99a1b subpackages: - mem - name: github.com/spf13/cast version: acbeb36b902d72a7a4c18e8f3241075e7ab763e4 - name: github.com/spf13/jwalterweatherman - version: 0efa5202c04663c757d84f90f5219c1250baf94f + version: 12bd96e66386c1960ab0f74ced1362f66f552f7b - name: github.com/spf13/pflag - version: e57e3eeb33f795204c1ca35f56c44f83227c6e66 + version: 7aff26db30c1be810f9de5038ec5ef96ac41fd7c - name: github.com/spf13/viper - version: 25b30aa063fc18e48662b86996252eabdcf2f0c7 + version: d9cca5ef33035202efb1586825bdbb15ff9ec3ba - name: github.com/stretchr/objx version: cbeaeb16a013161a98496fad62933b1d21786672 - name: github.com/stretchr/testify @@ -108,13 +108,13 @@ imports: - mock - require - name: github.com/uber-go/atomic - version: 70bd1261d36be490ebd22a62b385a3c5d23b6240 + version: 54f72d32435d760d5604f17a82e2435b28dc4ba5 - name: go.uber.org/atomic - version: 70bd1261d36be490ebd22a62b385a3c5d23b6240 + version: 4e336646b2ef9fc6e47be8e21594178f98e5ebcf - name: go.uber.org/multierr version: 3c4937480c32f4c13a875a1829af76c98ca3d40a - name: go.uber.org/zap - version: 416e66ad83ebde35df0e09f02e65ac149e193b0e + version: 35aad584952c3e7020db7b839f6b102de6271f89 subpackages: - buffer - internal/bufferpool @@ -122,7 +122,7 @@ imports: - internal/exit - zapcore - name: golang.org/x/net - version: 57efc9c3d9f91fb3277f8da1cff370539c4d3dc5 + version: 0a9397675ba34b2845f758fe3cd68828369c6517 subpackages: - html - html/atom @@ -132,11 +132,11 @@ imports: - idna - lex/httplex - name: golang.org/x/sys - version: 07c182904dbd53199946ba614a412c61d3c548f5 + version: 314a259e304ff91bd6985da2a7149bbf91237993 subpackages: - unix - name: golang.org/x/text - version: ac87088df8ef557f1e32cd00ed0b6fbc3f7ddafb + version: 1cbadb444a806fd9430d14ad08967ed91da4fa0a subpackages: - encoding - encoding/charmap diff --git a/lib/netutil/dial.go b/lib/netutil/dial.go new file mode 100644 index 000000000..44c94b0e1 --- /dev/null +++ b/lib/netutil/dial.go @@ -0,0 +1,115 @@ +// Copyright (c) 2017 Yandex LLC. All rights reserved. +// Use of this source code is governed by a MPL 2.0 +// license that can be found in the LICENSE file. +// Author: Vladimir Skipor + +package netutil + +import ( + "context" + "net" + "sync" + + "github.com/pkg/errors" +) + +//go:generate mockery -name=Dialer -case=underscore -outpkg=netmock + +type Dialer interface { + DialContext(ctx context.Context, net, addr string) (net.Conn, error) +} + +var _ Dialer = &net.Dialer{} + +type DialerFunc func(ctx context.Context, network, address string) (net.Conn, error) + +func (f DialerFunc) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return f(ctx, network, address) +} + +// NewDNSCachingDialer returns dialer with primitive DNS caching logic +// that remembers remote address on first try, and use it in future. +func NewDNSCachingDialer(dialer Dialer, cache DNSCache) DialerFunc { + return func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + resolved, ok := cache.Get(addr) + if ok { + return dialer.DialContext(ctx, network, resolved) + } + conn, err = dialer.DialContext(ctx, network, addr) + if err != nil { + return + } + remoteAddr := conn.RemoteAddr().(*net.TCPAddr) + _, port, err := net.SplitHostPort(addr) + if err != nil { + conn.Close() + return nil, errors.Wrap(err, "invalid address, but successful dial - should not happen") + } + cache.Add(addr, net.JoinHostPort(remoteAddr.IP.String(), port)) + return + } +} + +var DefaultDNSCache = &SimpleDNSCache{} + +// LookupReachable tries to resolve addr via connecting to it. +// This method has much more overhead, but get guaranteed reachable resolved addr. +// Example: host is resolved to IPv4 and IPv6, but IPv4 is not working on machine. +// LookupAccessible will return IPv6 in that case. +func LookupReachable(addr string) (string, error) { + d := net.Dialer{DualStack: true} + conn, err := d.Dial("tcp", addr) + if err != nil { + return "", err + } + defer conn.Close() + _, port, err := net.SplitHostPort(addr) + if err != nil { + return "", err + } + remoteAddr := conn.RemoteAddr().(*net.TCPAddr) + return net.JoinHostPort(remoteAddr.IP.String(), port), nil +} + +// WarmDNSCache tries connect to addr, and adds conn remote ip + addr port to cache. +func WarmDNSCache(c DNSCache, addr string) error { + var d net.Dialer + conn, err := NewDNSCachingDialer(&d, c).DialContext(context.Background(), "tcp", addr) + if err != nil { + return err + } + conn.Close() + return nil +} + +//go:generate mockery -name=DNSCache -case=underscore -outpkg=netmock + +type DNSCache interface { + Get(addr string) (string, bool) + Add(addr, resolved string) +} + +type SimpleDNSCache struct { + rw sync.RWMutex + hostToAddr map[string]string +} + +func (c *SimpleDNSCache) Get(addr string) (resolved string, ok bool) { + c.rw.RLock() + if c.hostToAddr == nil { + c.rw.RUnlock() + return + } + resolved, ok = c.hostToAddr[addr] + c.rw.RUnlock() + return +} + +func (c *SimpleDNSCache) Add(addr, resolved string) { + c.rw.Lock() + if c.hostToAddr == nil { + c.hostToAddr = make(map[string]string) + } + c.hostToAddr[addr] = resolved + c.rw.Unlock() +} diff --git a/lib/netutil/mocks/conn.go b/lib/netutil/mocks/conn.go new file mode 100644 index 000000000..7f10c79c0 --- /dev/null +++ b/lib/netutil/mocks/conn.go @@ -0,0 +1,142 @@ +// Code generated by mockery v1.0.0 +package netmock + +import "github.com/stretchr/testify/mock" +import "net" + +import "time" + +// Conn is an autogenerated mock type for the Conn type +type Conn struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *Conn) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// LocalAddr provides a mock function with given fields: +func (_m *Conn) LocalAddr() net.Addr { + ret := _m.Called() + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Addr) + } + } + + return r0 +} + +// Read provides a mock function with given fields: b +func (_m *Conn) Read(b []byte) (int, error) { + ret := _m.Called(b) + + var r0 int + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RemoteAddr provides a mock function with given fields: +func (_m *Conn) RemoteAddr() net.Addr { + ret := _m.Called() + + var r0 net.Addr + if rf, ok := ret.Get(0).(func() net.Addr); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Addr) + } + } + + return r0 +} + +// SetDeadline provides a mock function with given fields: t +func (_m *Conn) SetDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetReadDeadline provides a mock function with given fields: t +func (_m *Conn) SetReadDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetWriteDeadline provides a mock function with given fields: t +func (_m *Conn) SetWriteDeadline(t time.Time) error { + ret := _m.Called(t) + + var r0 error + if rf, ok := ret.Get(0).(func(time.Time) error); ok { + r0 = rf(t) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Write provides a mock function with given fields: b +func (_m *Conn) Write(b []byte) (int, error) { + ret := _m.Called(b) + + var r0 int + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(b) + } else { + r0 = ret.Get(0).(int) + } + + var r1 error + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(b) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/lib/netutil/mocks/dialer.go b/lib/netutil/mocks/dialer.go new file mode 100644 index 000000000..27e1ea755 --- /dev/null +++ b/lib/netutil/mocks/dialer.go @@ -0,0 +1,34 @@ +// Code generated by mockery v1.0.0 +package netmock + +import "context" +import "github.com/stretchr/testify/mock" +import "net" + +// Dialer is an autogenerated mock type for the Dialer type +type Dialer struct { + mock.Mock +} + +// DialContext provides a mock function with given fields: ctx, _a1, addr +func (_m *Dialer) DialContext(ctx context.Context, _a1 string, addr string) (net.Conn, error) { + ret := _m.Called(ctx, _a1, addr) + + var r0 net.Conn + if rf, ok := ret.Get(0).(func(context.Context, string, string) net.Conn); ok { + r0 = rf(ctx, _a1, addr) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(net.Conn) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, _a1, addr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} diff --git a/lib/netutil/mocks/dns_cache.go b/lib/netutil/mocks/dns_cache.go new file mode 100644 index 000000000..ca9979966 --- /dev/null +++ b/lib/netutil/mocks/dns_cache.go @@ -0,0 +1,35 @@ +// Code generated by mockery v1.0.0 +package netmock + +import "github.com/stretchr/testify/mock" + +// DNSCache is an autogenerated mock type for the DNSCache type +type DNSCache struct { + mock.Mock +} + +// Add provides a mock function with given fields: addr, resolved +func (_m *DNSCache) Add(addr string, resolved string) { + _m.Called(addr, resolved) +} + +// Get provides a mock function with given fields: addr +func (_m *DNSCache) Get(addr string) (string, bool) { + ret := _m.Called(addr) + + var r0 string + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(addr) + } else { + r0 = ret.Get(0).(string) + } + + var r1 bool + if rf, ok := ret.Get(1).(func(string) bool); ok { + r1 = rf(addr) + } else { + r1 = ret.Get(1).(bool) + } + + return r0, r1 +} diff --git a/lib/netutil/netutil_suite_test.go b/lib/netutil/netutil_suite_test.go new file mode 100644 index 000000000..f36a31599 --- /dev/null +++ b/lib/netutil/netutil_suite_test.go @@ -0,0 +1,107 @@ +package netutil + +import ( + "context" + "net" + "strconv" + "testing" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" + + "github.com/pkg/errors" + "github.com/yandex/pandora/lib/netutil/mocks" + "github.com/yandex/pandora/lib/testutil" +) + +func TestNetutil(t *testing.T) { + testutil.RunSuite(t, "Netutil Suite") +} + +var _ = Describe("DNS", func() { + + It("lookup reachable", func() { + listener, err := net.ListenTCP("tcp4", nil) + defer listener.Close() + Expect(err).NotTo(HaveOccurred()) + + port := strconv.Itoa(listener.Addr().(*net.TCPAddr).Port) + addr := "localhost:" + port + expectedResolved := "127.0.0.1:" + port + + resolved, err := LookupReachable(addr) + Expect(err).NotTo(HaveOccurred()) + Expect(resolved).To(Equal(expectedResolved)) + }) + + const ( + addr = "localhost:8888" + resolved = "[::1]:8888" + ) + + It("cache", func() { + cache := &SimpleDNSCache{} + got, ok := cache.Get(addr) + Expect(ok).To(BeFalse()) + Expect(got).To(BeEmpty()) + + cache.Add(addr, resolved) + got, ok = cache.Get(addr) + Expect(ok).To(BeTrue()) + Expect(got).To(Equal(resolved)) + }) + + It("Dialer cache miss", func() { + ctx := context.Background() + mockConn := &netmock.Conn{} + mockConn.On("RemoteAddr").Return(&net.TCPAddr{ + IP: net.IPv6loopback, + Port: 8888, + }) + cache := &netmock.DNSCache{} + cache.On("Get", addr).Return("", false) + cache.On("Add", addr, resolved) + dialer := &netmock.Dialer{} + dialer.On("DialContext", ctx, "tcp", addr).Return(mockConn, nil) + + testee := NewDNSCachingDialer(dialer, cache) + conn, err := testee.DialContext(ctx, "tcp", addr) + Expect(err).NotTo(HaveOccurred()) + Expect(conn).To(Equal(mockConn)) + + testutil.AssertExpectations(mockConn, cache, dialer) + }) + + It("Dialer cache hit", func() { + ctx := context.Background() + mockConn := &netmock.Conn{} + cache := &netmock.DNSCache{} + cache.On("Get", addr).Return(resolved, true) + dialer := &netmock.Dialer{} + dialer.On("DialContext", ctx, "tcp", resolved).Return(mockConn, nil) + + testee := NewDNSCachingDialer(dialer, cache) + conn, err := testee.DialContext(ctx, "tcp", addr) + Expect(err).NotTo(HaveOccurred()) + Expect(conn).To(Equal(mockConn)) + + testutil.AssertExpectations(mockConn, cache, dialer) + }) + + It("Dialer cache miss err", func() { + ctx := context.Background() + expectedErr := errors.New("dial failed") + cache := &netmock.DNSCache{} + cache.On("Get", addr).Return("", false) + dialer := &netmock.Dialer{} + dialer.On("DialContext", ctx, "tcp", addr).Return(nil, expectedErr) + + testee := NewDNSCachingDialer(dialer, cache) + conn, err := testee.DialContext(ctx, "tcp", addr) + Expect(err).To(Equal(expectedErr)) + Expect(conn).To(BeNil()) + + testutil.AssertExpectations(cache, dialer) + }) + +}) diff --git a/lib/testutil/ginkgo.go b/lib/testutil/ginkgo.go index b60339da6..888227f33 100644 --- a/lib/testutil/ginkgo.go +++ b/lib/testutil/ginkgo.go @@ -7,15 +7,24 @@ package testutil import ( "strings" + "testing" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" + "github.com/onsi/gomega/format" "github.com/spf13/viper" "github.com/stretchr/testify/mock" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) +func RunSuite(t *testing.T, description string) { + format.UseStringerRepresentation = true + ReplaceGlobalLogger() + RegisterFailHandler(Fail) + RunSpecs(t, description) +} + func ReplaceGlobalLogger() *zap.Logger { log := NewLogger() zap.ReplaceGlobals(log) diff --git a/script/coverage.sh b/script/coverage.sh index 1a7fdf5c7..a40d4ae37 100755 --- a/script/coverage.sh +++ b/script/coverage.sh @@ -14,7 +14,7 @@ _cd_into_top_level() { } _generate_coverage_files() { - for dir in $(find . -maxdepth 10 -not -path './.git*' -not -path '*/vendor/*' -type d); do + for dir in $(find . -maxdepth 10 -not -path './.git*' -not -path '*/vendor/*' -not -path '*/mocks/*' -type d); do if ls $dir/*.go &>/dev/null ; then go test -covermode=count -coverprofile=$dir/profile.coverprofile $dir || fail=1 fi