From ed8345a6a8d69f16d7506e57d27684715fc346fb Mon Sep 17 00:00:00 2001 From: Jan-Hendrik Boll Date: Mon, 19 Sep 2022 09:29:11 +0200 Subject: [PATCH 1/2] Refactor aws client Added an interface Client, that abstracts AWS client calls. This indirection helps using gomock to create tests. See ec2_test.go as an example for this. As of this commit only the client usage of ec2.go and parts of route53.go have been updated. The rest of the codebase still requires this refactoring. --- go.mod | 1 + go.sum | 23 +++ pkg/awsclient/awsclient.go | 79 ++++++++++ .../mock/zz_generated.mock_client.go | 138 ++++++++++++++++++ pkg/ec2.go | 32 ++-- pkg/ec2_test.go | 47 ++++++ pkg/route53.go | 8 +- 7 files changed, 312 insertions(+), 16 deletions(-) create mode 100644 pkg/awsclient/awsclient.go create mode 100644 pkg/awsclient/mock/zz_generated.mock_client.go create mode 100644 pkg/ec2_test.go diff --git a/go.mod b/go.mod index 6d6a76a..bd09c73 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.18 require ( github.com/aws/aws-sdk-go v1.44.100 github.com/go-kit/kit v0.9.0 + github.com/golang/mock v1.6.0 github.com/prometheus/client_golang v1.4.1 github.com/prometheus/client_model v0.2.0 github.com/prometheus/common v0.9.1 diff --git a/go.sum b/go.sum index 1c9995f..1b5a4ca 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= +github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -87,27 +89,48 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd h1:O7DYs+zxREGLKzKoMQrtrEacpb0ZVXA5rIwylE2Xchk= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/alecthomas/kingpin.v2 v2.2.6 h1:jMFz6MfLP0/4fUyZle81rXUoxOBFi19VUFKVDOQfozc= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/awsclient/awsclient.go b/pkg/awsclient/awsclient.go new file mode 100644 index 0000000..635578b --- /dev/null +++ b/pkg/awsclient/awsclient.go @@ -0,0 +1,79 @@ +// inspired by https://github.com/openshift/aws-account-operator/blob/master/pkg/awsclient/client.go + +package awsclient + +import ( + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/servicequotas" + "github.com/aws/aws-sdk-go/service/servicequotas/servicequotasiface" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" +) + +//go:generate mockgen -source=./awsclient.go -destination=./mock/zz_generated.mock_client.go -package=mock + +// Client is a wrapper object for actual AWS SDK clients to allow for easier testing. +type Client interface { + //EC2 + DescribeTransitGatewaysWithContext(ctx aws.Context, input *ec2.DescribeTransitGatewaysInput, opts ...request.Option) (*ec2.DescribeTransitGatewaysOutput, error) + + // Service Quota + GetServiceQuota(*servicequotas.GetServiceQuotaInput) (*servicequotas.GetServiceQuotaOutput, error) + GetServiceQuotaWithContext(ctx aws.Context, input *servicequotas.GetServiceQuotaInput, opts ...request.Option) (*servicequotas.GetServiceQuotaOutput, error) + RequestServiceQuotaIncrease(*servicequotas.RequestServiceQuotaIncreaseInput) (*servicequotas.RequestServiceQuotaIncreaseOutput, error) + ListRequestedServiceQuotaChangeHistory(*servicequotas.ListRequestedServiceQuotaChangeHistoryInput) (*servicequotas.ListRequestedServiceQuotaChangeHistoryOutput, error) + ListRequestedServiceQuotaChangeHistoryByQuota(*servicequotas.ListRequestedServiceQuotaChangeHistoryByQuotaInput) (*servicequotas.ListRequestedServiceQuotaChangeHistoryByQuotaOutput, error) +} + +type awsClient struct { + ec2Client ec2iface.EC2API + serviceQuotasClient servicequotasiface.ServiceQuotasAPI +} + +// NewAwsClientInput input for new aws client +type NewAwsClientInput struct { + AwsCredsSecretIDKey string + AwsCredsSecretAccessKey string + AwsToken string + AwsRegion string + SecretName string + NameSpace string +} + +func (c *awsClient) DescribeTransitGatewaysWithContext(ctx aws.Context, input *ec2.DescribeTransitGatewaysInput, opts ...request.Option) (*ec2.DescribeTransitGatewaysOutput, error) { + return c.ec2Client.DescribeTransitGatewaysWithContext(ctx, input, opts...) +} + +func (c *awsClient) DeleteSubnet(input *ec2.DeleteSubnetInput) (*ec2.DeleteSubnetOutput, error) { + return c.ec2Client.DeleteSubnet(input) +} + +func (c *awsClient) GetServiceQuota(input *servicequotas.GetServiceQuotaInput) (*servicequotas.GetServiceQuotaOutput, error) { + return c.serviceQuotasClient.GetServiceQuota(input) +} + +func (c *awsClient) GetServiceQuotaWithContext(ctx aws.Context, input *servicequotas.GetServiceQuotaInput, opts ...request.Option) (*servicequotas.GetServiceQuotaOutput, error) { + return c.serviceQuotasClient.GetServiceQuotaWithContext(ctx, input, opts...) +} + +func (c *awsClient) RequestServiceQuotaIncrease(input *servicequotas.RequestServiceQuotaIncreaseInput) (*servicequotas.RequestServiceQuotaIncreaseOutput, error) { + return c.serviceQuotasClient.RequestServiceQuotaIncrease(input) +} + +func (c *awsClient) ListRequestedServiceQuotaChangeHistory(input *servicequotas.ListRequestedServiceQuotaChangeHistoryInput) (*servicequotas.ListRequestedServiceQuotaChangeHistoryOutput, error) { + return c.serviceQuotasClient.ListRequestedServiceQuotaChangeHistory(input) +} + +func (c *awsClient) ListRequestedServiceQuotaChangeHistoryByQuota(input *servicequotas.ListRequestedServiceQuotaChangeHistoryByQuotaInput) (*servicequotas.ListRequestedServiceQuotaChangeHistoryByQuotaOutput, error) { + return c.serviceQuotasClient.ListRequestedServiceQuotaChangeHistoryByQuota(input) +} + +func NewClientFromSession(sess *session.Session) Client { + return &awsClient{ + ec2Client: ec2.New(sess), + serviceQuotasClient: servicequotas.New(sess), + } +} diff --git a/pkg/awsclient/mock/zz_generated.mock_client.go b/pkg/awsclient/mock/zz_generated.mock_client.go new file mode 100644 index 0000000..5366980 --- /dev/null +++ b/pkg/awsclient/mock/zz_generated.mock_client.go @@ -0,0 +1,138 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./awsclient.go + +// Package mock is a generated GoMock package. +package mock + +import ( + reflect "reflect" + + aws "github.com/aws/aws-sdk-go/aws" + request "github.com/aws/aws-sdk-go/aws/request" + ec2 "github.com/aws/aws-sdk-go/service/ec2" + servicequotas "github.com/aws/aws-sdk-go/service/servicequotas" + gomock "github.com/golang/mock/gomock" +) + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// DescribeTransitGatewaysWithContext mocks base method. +func (m *MockClient) DescribeTransitGatewaysWithContext(ctx aws.Context, input *ec2.DescribeTransitGatewaysInput, opts ...request.Option) (*ec2.DescribeTransitGatewaysOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, input} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DescribeTransitGatewaysWithContext", varargs...) + ret0, _ := ret[0].(*ec2.DescribeTransitGatewaysOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DescribeTransitGatewaysWithContext indicates an expected call of DescribeTransitGatewaysWithContext. +func (mr *MockClientMockRecorder) DescribeTransitGatewaysWithContext(ctx, input interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, input}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DescribeTransitGatewaysWithContext", reflect.TypeOf((*MockClient)(nil).DescribeTransitGatewaysWithContext), varargs...) +} + +// GetServiceQuota mocks base method. +func (m *MockClient) GetServiceQuota(arg0 *servicequotas.GetServiceQuotaInput) (*servicequotas.GetServiceQuotaOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetServiceQuota", arg0) + ret0, _ := ret[0].(*servicequotas.GetServiceQuotaOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceQuota indicates an expected call of GetServiceQuota. +func (mr *MockClientMockRecorder) GetServiceQuota(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceQuota", reflect.TypeOf((*MockClient)(nil).GetServiceQuota), arg0) +} + +// GetServiceQuotaWithContext mocks base method. +func (m *MockClient) GetServiceQuotaWithContext(ctx aws.Context, input *servicequotas.GetServiceQuotaInput, opts ...request.Option) (*servicequotas.GetServiceQuotaOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, input} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "GetServiceQuotaWithContext", varargs...) + ret0, _ := ret[0].(*servicequotas.GetServiceQuotaOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetServiceQuotaWithContext indicates an expected call of GetServiceQuotaWithContext. +func (mr *MockClientMockRecorder) GetServiceQuotaWithContext(ctx, input interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, input}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServiceQuotaWithContext", reflect.TypeOf((*MockClient)(nil).GetServiceQuotaWithContext), varargs...) +} + +// ListRequestedServiceQuotaChangeHistory mocks base method. +func (m *MockClient) ListRequestedServiceQuotaChangeHistory(arg0 *servicequotas.ListRequestedServiceQuotaChangeHistoryInput) (*servicequotas.ListRequestedServiceQuotaChangeHistoryOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListRequestedServiceQuotaChangeHistory", arg0) + ret0, _ := ret[0].(*servicequotas.ListRequestedServiceQuotaChangeHistoryOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListRequestedServiceQuotaChangeHistory indicates an expected call of ListRequestedServiceQuotaChangeHistory. +func (mr *MockClientMockRecorder) ListRequestedServiceQuotaChangeHistory(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRequestedServiceQuotaChangeHistory", reflect.TypeOf((*MockClient)(nil).ListRequestedServiceQuotaChangeHistory), arg0) +} + +// ListRequestedServiceQuotaChangeHistoryByQuota mocks base method. +func (m *MockClient) ListRequestedServiceQuotaChangeHistoryByQuota(arg0 *servicequotas.ListRequestedServiceQuotaChangeHistoryByQuotaInput) (*servicequotas.ListRequestedServiceQuotaChangeHistoryByQuotaOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListRequestedServiceQuotaChangeHistoryByQuota", arg0) + ret0, _ := ret[0].(*servicequotas.ListRequestedServiceQuotaChangeHistoryByQuotaOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListRequestedServiceQuotaChangeHistoryByQuota indicates an expected call of ListRequestedServiceQuotaChangeHistoryByQuota. +func (mr *MockClientMockRecorder) ListRequestedServiceQuotaChangeHistoryByQuota(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListRequestedServiceQuotaChangeHistoryByQuota", reflect.TypeOf((*MockClient)(nil).ListRequestedServiceQuotaChangeHistoryByQuota), arg0) +} + +// RequestServiceQuotaIncrease mocks base method. +func (m *MockClient) RequestServiceQuotaIncrease(arg0 *servicequotas.RequestServiceQuotaIncreaseInput) (*servicequotas.RequestServiceQuotaIncreaseOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RequestServiceQuotaIncrease", arg0) + ret0, _ := ret[0].(*servicequotas.RequestServiceQuotaIncreaseOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RequestServiceQuotaIncrease indicates an expected call of RequestServiceQuotaIncrease. +func (mr *MockClientMockRecorder) RequestServiceQuotaIncrease(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestServiceQuotaIncrease", reflect.TypeOf((*MockClient)(nil).RequestServiceQuotaIncrease), arg0) +} diff --git a/pkg/ec2.go b/pkg/ec2.go index fc740e6..831d424 100644 --- a/pkg/ec2.go +++ b/pkg/ec2.go @@ -5,6 +5,7 @@ import ( "sync" "time" + "github.com/app-sre/aws-resource-exporter/pkg/awsclient" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ec2" @@ -70,17 +71,17 @@ func (e *EC2Exporter) CollectLoop() { func (e *EC2Exporter) collectInRegion(sess *session.Session, logger log.Logger, wg *sync.WaitGroup, ctx context.Context) { defer wg.Done() - ec2Svc := ec2.New(sess) - serviceQuotaSvc := servicequotas.New(sess) - quota, err := getQuotaValueWithContext(serviceQuotaSvc, ec2ServiceCode, transitGatewayPerAccountQuotaCode, ctx) + aws := awsclient.NewClientFromSession(sess) + + quota, err := getQuotaValueWithContext(aws, ec2ServiceCode, transitGatewayPerAccountQuotaCode, ctx) if err != nil { level.Error(logger).Log("msg", "Could not retrieve Transit Gateway quota", "error", err.Error()) AwsExporterMetrics.IncrementErrors() return } - gateways, err := getAllTransitGatewaysWithContext(ec2Svc, ctx) + gateways, err := getAllTransitGatewaysWithContext(aws, ctx) if err != nil { level.Error(logger).Log("msg", "Could not retrieve Transit Gateway quota", "error", err.Error()) AwsExporterMetrics.IncrementErrors() @@ -96,13 +97,23 @@ func (e *EC2Exporter) Describe(ch chan<- *prometheus.Desc) { ch <- TransitGatewaysUsage } -func getAllTransitGatewaysWithContext(client *ec2.EC2, ctx context.Context) ([]*ec2.TransitGateway, error) { - results := []*ec2.TransitGateway{} - describeGatewaysInput := &ec2.DescribeTransitGatewaysInput{ +func createDescribeTransitGatewayInput() *ec2.DescribeTransitGatewaysInput { + return &ec2.DescribeTransitGatewaysInput{ DryRun: aws.Bool(false), MaxResults: aws.Int64(1000), } +} +func createGetServiceQuotaInput(serviceCode, quotaCode string) *servicequotas.GetServiceQuotaInput { + return &servicequotas.GetServiceQuotaInput{ + ServiceCode: aws.String(serviceCode), + QuotaCode: aws.String(quotaCode), + } +} + +func getAllTransitGatewaysWithContext(client awsclient.Client, ctx context.Context) ([]*ec2.TransitGateway, error) { + results := []*ec2.TransitGateway{} + describeGatewaysInput := createDescribeTransitGatewayInput() describeGatewaysOutput, err := client.DescribeTransitGatewaysWithContext(ctx, describeGatewaysInput) if err != nil { @@ -122,11 +133,8 @@ func getAllTransitGatewaysWithContext(client *ec2.EC2, ctx context.Context) ([]* return results, nil } -func getQuotaValueWithContext(client *servicequotas.ServiceQuotas, serviceCode string, quotaCode string, ctx context.Context) (float64, error) { - sqOutput, err := client.GetServiceQuotaWithContext(ctx, &servicequotas.GetServiceQuotaInput{ - QuotaCode: aws.String(quotaCode), - ServiceCode: aws.String(serviceCode), - }) +func getQuotaValueWithContext(client awsclient.Client, serviceCode string, quotaCode string, ctx context.Context) (float64, error) { + sqOutput, err := client.GetServiceQuotaWithContext(ctx, createGetServiceQuotaInput(serviceCode, quotaCode)) if err != nil { return 0, err diff --git a/pkg/ec2_test.go b/pkg/ec2_test.go new file mode 100644 index 0000000..f497a10 --- /dev/null +++ b/pkg/ec2_test.go @@ -0,0 +1,47 @@ +package pkg + +import ( + "context" + "testing" + + "github.com/app-sre/aws-resource-exporter/pkg/awsclient/mock" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/servicequotas" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +func TestGetAllTransitGatewaysWithContext(t *testing.T) { + ctx := context.TODO() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mock.NewMockClient(ctrl) + + mockClient.EXPECT().DescribeTransitGatewaysWithContext(ctx, createDescribeTransitGatewayInput()). + Return(&ec2.DescribeTransitGatewaysOutput{ + TransitGateways: []*ec2.TransitGateway{&ec2.TransitGateway{}}, + }, nil) + + gateways, err := getAllTransitGatewaysWithContext(mockClient, ctx) + assert.Nil(t, err) + assert.Len(t, gateways, 1) +} + +func TestGetQuotaValueWithContext(t *testing.T) { + ctx := context.TODO() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mock.NewMockClient(ctrl) + + mockClient.EXPECT().GetServiceQuotaWithContext(ctx, + createGetServiceQuotaInput(ec2ServiceCode, transitGatewayPerAccountQuotaCode)).Return( + &servicequotas.GetServiceQuotaOutput{Quota: &servicequotas.ServiceQuota{Value: aws.Float64(1337.0)}}, nil, + ) + + quotaValue, err := getQuotaValueWithContext(mockClient, ec2ServiceCode, transitGatewayPerAccountQuotaCode, ctx) + assert.Nil(t, err) + assert.Equal(t, quotaValue, 1337.0) +} diff --git a/pkg/route53.go b/pkg/route53.go index eccad15..8c0749c 100644 --- a/pkg/route53.go +++ b/pkg/route53.go @@ -7,11 +7,11 @@ import ( "sync" "time" + "github.com/app-sre/aws-resource-exporter/pkg/awsclient" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/route53" - "github.com/aws/aws-sdk-go/service/servicequotas" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "github.com/prometheus/client_golang/prometheus" @@ -98,7 +98,7 @@ func (e *Route53Exporter) getRecordsPerHostedZoneMetrics(client *route53.Route53 return errs } -func (e *Route53Exporter) getHostedZonesPerAccountMetrics(client *servicequotas.ServiceQuotas, hostedZones []*route53.HostedZone, ctx context.Context) error { +func (e *Route53Exporter) getHostedZonesPerAccountMetrics(client awsclient.Client, hostedZones []*route53.HostedZone, ctx context.Context) error { quota, err := getQuotaValueWithContext(client, route53ServiceCode, hostedZonesQuotaCode, ctx) if err != nil { return err @@ -112,7 +112,7 @@ func (e *Route53Exporter) getHostedZonesPerAccountMetrics(client *servicequotas. // CollectLoop runs indefinitely to collect the route53 metrics in a cache. Metrics are only written into the cache once all have been collected to ensure that we don't have a partial collect. func (e *Route53Exporter) CollectLoop() { route53Svc := route53.New(e.sess) - serviceQuotaSvc := servicequotas.New(e.sess) + awsclient := awsclient.NewClientFromSession(e.sess) for { ctx, ctxCancelFunc := context.WithTimeout(context.Background(), e.timeout) @@ -127,7 +127,7 @@ func (e *Route53Exporter) CollectLoop() { AwsExporterMetrics.IncrementErrors() } - err = e.getHostedZonesPerAccountMetrics(serviceQuotaSvc, hostedZones, ctx) + err = e.getHostedZonesPerAccountMetrics(awsclient, hostedZones, ctx) if err != nil { level.Error(e.logger).Log("msg", "Could not get limits for hosted zone", "error", err.Error()) AwsExporterMetrics.IncrementErrors() From f4a29340f6218e9b31f8706b4c0ec54dffaac65d Mon Sep 17 00:00:00 2001 From: Jan-Hendrik Boll Date: Tue, 20 Sep 2022 13:32:39 +0200 Subject: [PATCH 2/2] Review changes --- pkg/awsclient/awsclient.go | 10 ---------- pkg/ec2.go | 2 +- pkg/ec2_test.go | 4 ++-- pkg/route53.go | 4 ++-- 4 files changed, 5 insertions(+), 15 deletions(-) diff --git a/pkg/awsclient/awsclient.go b/pkg/awsclient/awsclient.go index 635578b..4b86369 100644 --- a/pkg/awsclient/awsclient.go +++ b/pkg/awsclient/awsclient.go @@ -33,16 +33,6 @@ type awsClient struct { serviceQuotasClient servicequotasiface.ServiceQuotasAPI } -// NewAwsClientInput input for new aws client -type NewAwsClientInput struct { - AwsCredsSecretIDKey string - AwsCredsSecretAccessKey string - AwsToken string - AwsRegion string - SecretName string - NameSpace string -} - func (c *awsClient) DescribeTransitGatewaysWithContext(ctx aws.Context, input *ec2.DescribeTransitGatewaysInput, opts ...request.Option) (*ec2.DescribeTransitGatewaysOutput, error) { return c.ec2Client.DescribeTransitGatewaysWithContext(ctx, input, opts...) } diff --git a/pkg/ec2.go b/pkg/ec2.go index 831d424..99cb69c 100644 --- a/pkg/ec2.go +++ b/pkg/ec2.go @@ -120,7 +120,7 @@ func getAllTransitGatewaysWithContext(client awsclient.Client, ctx context.Conte return nil, err } results = append(results, describeGatewaysOutput.TransitGateways...) - + // TODO: replace with aws-go-sdk pagination method for describeGatewaysOutput.NextToken != nil { describeGatewaysInput.SetNextToken(*describeGatewaysOutput.NextToken) describeGatewaysOutput, err := client.DescribeTransitGatewaysWithContext(ctx, describeGatewaysInput) diff --git a/pkg/ec2_test.go b/pkg/ec2_test.go index f497a10..bf57eca 100644 --- a/pkg/ec2_test.go +++ b/pkg/ec2_test.go @@ -38,10 +38,10 @@ func TestGetQuotaValueWithContext(t *testing.T) { mockClient.EXPECT().GetServiceQuotaWithContext(ctx, createGetServiceQuotaInput(ec2ServiceCode, transitGatewayPerAccountQuotaCode)).Return( - &servicequotas.GetServiceQuotaOutput{Quota: &servicequotas.ServiceQuota{Value: aws.Float64(1337.0)}}, nil, + &servicequotas.GetServiceQuotaOutput{Quota: &servicequotas.ServiceQuota{Value: aws.Float64(123.0)}}, nil, ) quotaValue, err := getQuotaValueWithContext(mockClient, ec2ServiceCode, transitGatewayPerAccountQuotaCode, ctx) assert.Nil(t, err) - assert.Equal(t, quotaValue, 1337.0) + assert.Equal(t, quotaValue, 123.0) } diff --git a/pkg/route53.go b/pkg/route53.go index 8c0749c..12f4c02 100644 --- a/pkg/route53.go +++ b/pkg/route53.go @@ -112,7 +112,7 @@ func (e *Route53Exporter) getHostedZonesPerAccountMetrics(client awsclient.Clien // CollectLoop runs indefinitely to collect the route53 metrics in a cache. Metrics are only written into the cache once all have been collected to ensure that we don't have a partial collect. func (e *Route53Exporter) CollectLoop() { route53Svc := route53.New(e.sess) - awsclient := awsclient.NewClientFromSession(e.sess) + client := awsclient.NewClientFromSession(e.sess) for { ctx, ctxCancelFunc := context.WithTimeout(context.Background(), e.timeout) @@ -127,7 +127,7 @@ func (e *Route53Exporter) CollectLoop() { AwsExporterMetrics.IncrementErrors() } - err = e.getHostedZonesPerAccountMetrics(awsclient, hostedZones, ctx) + err = e.getHostedZonesPerAccountMetrics(client, hostedZones, ctx) if err != nil { level.Error(e.logger).Log("msg", "Could not get limits for hosted zone", "error", err.Error()) AwsExporterMetrics.IncrementErrors()