From fb9262709086f4bbc504c8d09136c1693c93de31 Mon Sep 17 00:00:00 2001 From: Erik Kristensen Date: Fri, 8 Mar 2024 11:43:04 -0700 Subject: [PATCH] fix(sagemaker-domain): bug with mock service vs real service --- resources/sagemaker-domain.go | 15 +++++++++------ resources/sagemaker-domain_test.go | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/resources/sagemaker-domain.go b/resources/sagemaker-domain.go index 526ef82b..e4d20edd 100644 --- a/resources/sagemaker-domain.go +++ b/resources/sagemaker-domain.go @@ -29,15 +29,18 @@ func init() { } type SageMakerDomainLister struct { - svc sagemakeriface.SageMakerAPI + mockSvc sagemakeriface.SageMakerAPI } func (l *SageMakerDomainLister) List(_ context.Context, o interface{}) ([]resource.Resource, error) { opts := o.(*nuke.ListerOpts) // Note: this allows us to override svc in tests with a mock - if l.svc == nil { - l.svc = sagemaker.New(opts.Session) + var svc sagemakeriface.SageMakerAPI + if l.mockSvc != nil { + svc = l.mockSvc + } else { + svc = sagemaker.New(opts.Session) } resources := make([]resource.Resource, 0) @@ -47,7 +50,7 @@ func (l *SageMakerDomainLister) List(_ context.Context, o interface{}) ([]resour } for { - resp, err := l.svc.ListDomains(params) + resp, err := svc.ListDomains(params) if err != nil { return nil, err } @@ -57,7 +60,7 @@ func (l *SageMakerDomainLister) List(_ context.Context, o interface{}) ([]resour tagParams := &sagemaker.ListTagsInput{ ResourceArn: domain.DomainArn, } - tagOutput, err := l.svc.ListTags(tagParams) + tagOutput, err := svc.ListTags(tagParams) if err != nil { logrus.WithError(err).Errorf("unable to get tags for SageMakerDomain: %s", ptr.ToString(domain.DomainId)) } @@ -66,7 +69,7 @@ func (l *SageMakerDomainLister) List(_ context.Context, o interface{}) ([]resour } resources = append(resources, &SageMakerDomain{ - svc: l.svc, + svc: svc, domainID: domain.DomainId, creationTime: domain.CreationTime, tags: tags, diff --git a/resources/sagemaker-domain_test.go b/resources/sagemaker-domain_test.go index 6a71f20a..90dea84b 100644 --- a/resources/sagemaker-domain_test.go +++ b/resources/sagemaker-domain_test.go @@ -26,7 +26,7 @@ func TestSageMakerDomain_List(t *testing.T) { mockSageMaker := mock_sagemakeriface.NewMockSageMakerAPI(ctrl) sagemakerDomainLister := SageMakerDomainLister{ - svc: mockSageMaker, + mockSvc: mockSageMaker, } sagemakerDomain := SageMakerDomain{