diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go index 14844a0c8a..6d3fd146bd 100644 --- a/internal/oauth/oauth.go +++ b/internal/oauth/oauth.go @@ -15,36 +15,17 @@ package oauth import ( - "net" "net/http" - "time" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" + "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/transport" "go.mongodb.org/atlas/auth" ) const ( - timeout = 5 * time.Second - keepAlive = 30 * time.Second - maxIdleConns = 5 - maxIdleConnsPerHost = 4 - idleConnTimeout = 30 * time.Second - expectContinueTimeout = 1 * time.Second - cloudGovServiceURL = "https://cloud.mongodbgov.com/" + cloudGovServiceURL = "https://cloud.mongodbgov.com/" ) -var defaultTransport = &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: timeout, - KeepAlive: keepAlive, - }).DialContext, - MaxIdleConns: maxIdleConns, - MaxIdleConnsPerHost: maxIdleConnsPerHost, - Proxy: http.ProxyFromEnvironment, - IdleConnTimeout: idleConnTimeout, - ExpectContinueTimeout: expectContinueTimeout, -} - type ServiceGetter interface { Service() string OpsManagerURL() string @@ -58,7 +39,7 @@ const ( func FlowWithConfig(c ServiceGetter) (*auth.Config, error) { client := http.DefaultClient - client.Transport = defaultTransport + client.Transport = transport.Default() id := ClientID if c.Service() == config.CloudGovService { id = GovClientID diff --git a/internal/store/store.go b/internal/store/store.go index a4abd060cd..e7cc97aff6 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -20,14 +20,12 @@ import ( "context" "errors" "fmt" - "net" "net/http" "strings" - "time" - "github.com/mongodb-forks/digest" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/config" "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/log" + "github.com/mongodb/mongodb-atlas-cli/atlascli/internal/transport" atlasClustersPinned "go.mongodb.org/atlas-sdk/v20240530005/admin" atlasv2 "go.mongodb.org/atlas-sdk/v20241113001/admin" atlasauth "go.mongodb.org/atlas/auth" @@ -35,14 +33,7 @@ import ( ) const ( - telemetryTimeout = 1 * time.Second - timeout = 5 * time.Second - keepAlive = 30 * time.Second - maxIdleConns = 5 - maxIdleConnsPerHost = 4 - idleConnTimeout = 30 * time.Second - expectContinueTimeout = 1 * time.Second - cloudGovServiceURL = "https://cloud.mongodbgov.com/" + cloudGovServiceURL = "https://cloud.mongodbgov.com/" ) var errUnsupportedService = errors.New("unsupported service") @@ -62,44 +53,13 @@ type Store struct { ctx context.Context } -var defaultTransport = &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: timeout, - KeepAlive: keepAlive, - }).DialContext, - MaxIdleConns: maxIdleConns, - MaxIdleConnsPerHost: maxIdleConnsPerHost, - Proxy: http.ProxyFromEnvironment, - IdleConnTimeout: idleConnTimeout, - ExpectContinueTimeout: expectContinueTimeout, -} - -var telemetryTransport = &http.Transport{ - DialContext: (&net.Dialer{ - Timeout: telemetryTimeout, - KeepAlive: keepAlive, - }).DialContext, - MaxIdleConns: maxIdleConns, - MaxIdleConnsPerHost: maxIdleConnsPerHost, - Proxy: http.ProxyFromEnvironment, - IdleConnTimeout: idleConnTimeout, - ExpectContinueTimeout: expectContinueTimeout, -} - func (s *Store) httpClient(httpTransport http.RoundTripper) (*http.Client, error) { if s.username != "" && s.password != "" { - t := &digest.Transport{ - Username: s.username, - Password: s.password, - } - t.Transport = httpTransport + t := transport.NewDigestTransport(s.username, s.password, httpTransport) return t.Client() } if s.accessToken != nil { - tr := &Transport{ - token: s.accessToken, - base: httpTransport, - } + tr := transport.NewAccessTokenTransport(s.accessToken, httpTransport) return &http.Client{Transport: tr}, nil } @@ -107,22 +67,12 @@ func (s *Store) httpClient(httpTransport http.RoundTripper) (*http.Client, error return &http.Client{Transport: httpTransport}, nil } -type Transport struct { - token *atlasauth.Token - base http.RoundTripper -} - -func (tr *Transport) RoundTrip(req *http.Request) (*http.Response, error) { - tr.token.SetAuthHeader(req) - return tr.base.RoundTrip(req) -} - func (s *Store) transport() *http.Transport { switch { case s.telemetry: - return telemetryTransport + return transport.Telemetry() default: - return defaultTransport + return transport.Default() } } diff --git a/internal/transport/transport.go b/internal/transport/transport.go new file mode 100644 index 0000000000..a245bbac7d --- /dev/null +++ b/internal/transport/transport.go @@ -0,0 +1,81 @@ +// Copyright 2024 MongoDB Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "net" + "net/http" + "time" + + "github.com/mongodb-forks/digest" + atlasauth "go.mongodb.org/atlas/auth" +) + +const ( + telemetryTimeout = 1 * time.Second + timeout = 5 * time.Second + keepAlive = 30 * time.Second + maxIdleConns = 5 + maxIdleConnsPerHost = 4 + idleConnTimeout = 30 * time.Second + expectContinueTimeout = 1 * time.Second +) + +func Default() *http.Transport { + return newTransport(timeout) +} + +func Telemetry() *http.Transport { + return newTransport(telemetryTimeout) +} + +func newTransport(timeout time.Duration) *http.Transport { + return &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: timeout, + KeepAlive: keepAlive, + }).DialContext, + MaxIdleConns: maxIdleConns, + MaxIdleConnsPerHost: maxIdleConnsPerHost, + Proxy: http.ProxyFromEnvironment, + IdleConnTimeout: idleConnTimeout, + ExpectContinueTimeout: expectContinueTimeout, + } +} + +func NewDigestTransport(username, password string, base http.RoundTripper) *digest.Transport { + return &digest.Transport{ + Username: username, + Password: password, + Transport: base, + } +} + +func NewAccessTokenTransport(token *atlasauth.Token, base http.RoundTripper) http.RoundTripper { + return &tokenTransport{ + token: token, + base: base, + } +} + +type tokenTransport struct { + token *atlasauth.Token + base http.RoundTripper +} + +func (tr *tokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { + tr.token.SetAuthHeader(req) + return tr.base.RoundTrip(req) +}