diff --git a/go.mod b/go.mod index 76fb230..05b8a06 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/go-resty/resty/v2 v2.11.0 github.com/onsi/ginkgo v1.16.5 github.com/onsi/gomega v1.24.1 + github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 github.com/slok/goresilience v0.2.0 github.com/stretchr/testify v1.8.1 golang.org/x/oauth2 v0.2.0 @@ -25,6 +26,7 @@ require ( github.com/prometheus/common v0.0.0-20181126121408-4724e9255275 // indirect github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a // indirect golang.org/x/net v0.17.0 // indirect + golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go.sum b/go.sum index b077ec4..fa43083 100644 --- a/go.sum +++ b/go.sum @@ -52,6 +52,8 @@ github.com/prometheus/common v0.0.0-20181126121408-4724e9255275 h1:PnBWHBf+6L0jO github.com/prometheus/common v0.0.0-20181126121408-4724e9255275/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro= github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a h1:9a8MnZMP0X2nLJdBg+pBmGgkJlSaKC2KaQmTCk1XDtE= github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= +github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529 h1:18kd+8ZUlt/ARXhljq+14TwAoKa61q6dX8jtwOf6DH8= +github.com/rs/dnscache v0.0.0-20230804202142-fc85eb664529/go.mod h1:qe5TWALJ8/a1Lqznoc5BDHpYX/8HU60Hm2AwRmqzxqA= github.com/slok/goresilience v0.2.0 h1:dagdIiWlhTm7BK/r/LRKz+zvw0SCNk+nHf7obdsbzxQ= github.com/slok/goresilience v0.2.0/go.mod h1:L6IqqHlxWGTrTyq8WwF8kUY8kOIESZAMWr1xkV0zdZA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -94,6 +96,7 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/httpclient.go b/httpclient.go index 9c1e6de..3169571 100644 --- a/httpclient.go +++ b/httpclient.go @@ -2,8 +2,6 @@ package httpclient import ( "context" - "crypto/tls" - "net" "net/http" "net/url" "time" @@ -78,27 +76,6 @@ func (c *HTTPClient) setTransport(transport http.RoundTripper) { c.resty.SetTransport(transport) } -func NewDefaultTransport(transportTimeout time.Duration) http.RoundTripper { - return &Transport{ - RoundTripper: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: transportTimeout, - KeepAlive: 15 * time.Second, - DualStack: true, - }).DialContext, - MaxIdleConns: 100, - MaxIdleConnsPerHost: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - TLSClientConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - ClientSessionCache: tls.NewLRUClientSessionCache(-1), - }, - }, - } -} - // WithDefaultTransport sets a custom connection timeout to http.Transport. // This timeout limits the time spent establishing a TCP connection. // @@ -144,23 +121,21 @@ func WithOAUTHTransport(conf cc.Config, transportTimeout time.Duration) func(*HT // More information about proxy: http.Transport. func WithDefaultTransportWithProxy(proxyURL *url.URL) func(*HTTPClient) { return func(client *HTTPClient) { - transport := &Transport{ - RoundTripper: &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - DialContext: (&net.Dialer{ - KeepAlive: 5 * time.Minute, - DualStack: true, - }).DialContext, - MaxIdleConns: 10, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - TLSClientConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - ClientSessionCache: tls.NewLRUClientSessionCache(-1), - }, - }, - } + transport := NewDefaultTransport(5 * time.Second) + transport.SetProxy(http.ProxyURL(proxyURL)) + client.setTransport(transport) + } +} +// WithDefaultTransportWithDNSCache sets a cache for DNS lookups. +// The TTL of the cache is defined by DNS Server TTL. +// The keepAliveDuration is the time to keep the connection alive. +// +// More information about DNS cache: https://github.com/rs/dnscache. +func WithDefaultTransportWithDNSCache(keepAliveDuration time.Duration) func(*HTTPClient) { + return func(client *HTTPClient) { + transport := NewDefaultTransport(5 * time.Second) + transport.SetDNSCache(keepAliveDuration, 5*time.Minute) client.setTransport(transport) } } diff --git a/transport.go b/transport.go index 3d3543b..90db90f 100644 --- a/transport.go +++ b/transport.go @@ -2,13 +2,91 @@ package httpclient import ( "context" + "crypto/tls" + "net" "net/http" + "net/url" + "time" + + "github.com/rs/dnscache" ) // Transport accepts a custom RoundTripper and acts as a middleware to facilitate logging and // argument passing to external requests. type Transport struct { RoundTripper http.RoundTripper + http.Transport + Proxy func(*http.Request) (*url.URL, error) + Resolver interface{} +} + +func NewDefaultTransport(transportTimeout time.Duration) *Transport { + return &Transport{ + RoundTripper: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: transportTimeout, + KeepAlive: 5 * time.Minute, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + ClientSessionCache: tls.NewLRUClientSessionCache(-1), + }, + }, + } +} + +func (t *Transport) SetProxy(proxy func(*http.Request) (*url.URL, error)) *Transport { + t.Proxy = proxy + return t +} + +func (t *Transport) SetDNSCache(keepAliveDuration time.Duration, refreshCacheTime time.Duration) *Transport { + + r := &dnscache.Resolver{} + options := dnscache.ResolverRefreshOptions{} + options.ClearUnused = true + options.PersistOnFailure = false + r.RefreshWithOptions(options) + + go func() { + t := time.NewTicker(refreshCacheTime) + defer t.Stop() + for range t.C { + r.Refresh(true) + } + }() + + t.DialContext = func(ctx context.Context, network, addr string) (conn net.Conn, err error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + ips, err := r.LookupHost(ctx, host) + if err != nil { + return nil, err + } + + dialer := net.Dialer{ + KeepAlive: keepAliveDuration, + } + + for _, ip := range ips { + conn, err = dialer.DialContext(ctx, network, net.JoinHostPort(ip, port)) + if err == nil { + return conn, nil + } + } + + return nil, err + } + return t } // RoundTrip acts as a middleware performing external requests logging and argument passing to diff --git a/transport_test.go b/transport_test.go new file mode 100644 index 0000000..bd04e5c --- /dev/null +++ b/transport_test.go @@ -0,0 +1,71 @@ +package httpclient_test + +import ( + "context" + "net/http" + "net/url" + "testing" + "time" + + "github.com/globocom/httpclient" + "github.com/stretchr/testify/assert" +) + +func TestHTTPClientTransport(t *testing.T) { + t.Run("TestDefault", TestNewDefaultTransport) + t.Run("TestSetProxy", TestSetProxy) + t.Run("TestDNSCacheBehavior", TestDNSCacheBehavior) +} + +func TestNewDefaultTransport(t *testing.T) { + timeout := 5 * time.Second + transport := httpclient.NewDefaultTransport(timeout) + + assert.NotNil(t, transport) + assert.IsType(t, &http.Transport{}, transport.RoundTripper) + + httpTransport := transport.RoundTripper.(*http.Transport) + assert.Equal(t, 100, httpTransport.MaxIdleConns) + assert.Equal(t, 90*time.Second, httpTransport.IdleConnTimeout) + assert.Equal(t, 10*time.Second, httpTransport.TLSHandshakeTimeout) + assert.Equal(t, 1*time.Second, httpTransport.ExpectContinueTimeout) +} + +func TestSetProxy(t *testing.T) { + transport := httpclient.NewDefaultTransport(5 * time.Second) + + proxyFunc := func(req *http.Request) (*url.URL, error) { + return url.Parse("http://example.com") + } + + transport.SetProxy(proxyFunc) + + assert.NotNil(t, transport.Proxy, "Expected Proxy to be non-nil after setting it") + proxyURL, err := transport.Proxy(&http.Request{}) + assert.NoError(t, err, "Expected no error when calling proxy function") + assert.Equal(t, "http://example.com", proxyURL.String(), "Expected Proxy URL to match set value") +} + +func TestDNSCacheBehavior(t *testing.T) { + transport := httpclient.NewDefaultTransport(5 * time.Minute) + keepAliveDuration := 5 * time.Minute + tr := transport.SetDNSCache(keepAliveDuration, 5*time.Minute) + + ctx := context.Background() + + conn, err := tr.DialContext(ctx, "tcp", "example.com:80") + assert.NoError(t, err, "Expected no error dialing example.com on first attempt") + assert.NotNil(t, conn, "Expected a connection object on first attempt") + if conn != nil { + conn.Close() + } + + // cached DNS + conn, err = tr.DialContext(ctx, "tcp", "example.com:80") + assert.NoError(t, err, "Expected no error dialing example.com on second attempt") + assert.NotNil(t, conn, "Expected a connection object on second attempt") + if conn != nil { + conn.Close() + } + +}