Skip to content

Commit

Permalink
Merge pull request #114 from ekristen/fix-sagemaker-domain
Browse files Browse the repository at this point in the history
fix(sagemaker-domain): bug with mock service vs real service
  • Loading branch information
ekristen authored Mar 8, 2024
2 parents 089df08 + fb92627 commit db586ea
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
15 changes: 9 additions & 6 deletions resources/sagemaker-domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,18 @@ func init() {
}

type SageMakerDomainLister struct {
svc sagemakeriface.SageMakerAPI
mockSvc sagemakeriface.SageMakerAPI
}

func (l *SageMakerDomainLister) List(_ context.Context, o interface{}) ([]resource.Resource, error) {
opts := o.(*nuke.ListerOpts)

// Note: this allows us to override svc in tests with a mock
if l.svc == nil {
l.svc = sagemaker.New(opts.Session)
var svc sagemakeriface.SageMakerAPI
if l.mockSvc != nil {
svc = l.mockSvc
} else {
svc = sagemaker.New(opts.Session)
}

resources := make([]resource.Resource, 0)
Expand All @@ -47,7 +50,7 @@ func (l *SageMakerDomainLister) List(_ context.Context, o interface{}) ([]resour
}

for {
resp, err := l.svc.ListDomains(params)
resp, err := svc.ListDomains(params)
if err != nil {
return nil, err
}
Expand All @@ -57,7 +60,7 @@ func (l *SageMakerDomainLister) List(_ context.Context, o interface{}) ([]resour
tagParams := &sagemaker.ListTagsInput{
ResourceArn: domain.DomainArn,
}
tagOutput, err := l.svc.ListTags(tagParams)
tagOutput, err := svc.ListTags(tagParams)
if err != nil {
logrus.WithError(err).Errorf("unable to get tags for SageMakerDomain: %s", ptr.ToString(domain.DomainId))
}
Expand All @@ -66,7 +69,7 @@ func (l *SageMakerDomainLister) List(_ context.Context, o interface{}) ([]resour
}

resources = append(resources, &SageMakerDomain{
svc: l.svc,
svc: svc,
domainID: domain.DomainId,
creationTime: domain.CreationTime,
tags: tags,
Expand Down
2 changes: 1 addition & 1 deletion resources/sagemaker-domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestSageMakerDomain_List(t *testing.T) {
mockSageMaker := mock_sagemakeriface.NewMockSageMakerAPI(ctrl)

sagemakerDomainLister := SageMakerDomainLister{
svc: mockSageMaker,
mockSvc: mockSageMaker,
}

sagemakerDomain := SageMakerDomain{
Expand Down

0 comments on commit db586ea

Please sign in to comment.