diff --git a/pkg/auth/api/api.go b/pkg/auth/api/api.go index f730c913b1..d5d4f69635 100644 --- a/pkg/auth/api/api.go +++ b/pkg/auth/api/api.go @@ -90,6 +90,9 @@ type authService struct { signer token.Signer config *auth.OAuthConfig mysqlClient mysql.Client + organizationStorage envstotage.OrganizationStorage + projectStorage envstotage.ProjectStorage + environmentStorage envstotage.EnvironmentStorage accountClient accountclient.Client verifier token.Verifier googleAuthenticator auth.Authenticator @@ -113,13 +116,16 @@ func NewAuthService( } logger := options.logger.Named("api") service := &authService{ - issuer: issuer, - audience: audience, - signer: signer, - config: config, - mysqlClient: mysqlClient, - accountClient: accountClient, - verifier: verifier, + issuer: issuer, + audience: audience, + signer: signer, + config: config, + mysqlClient: mysqlClient, + organizationStorage: envstotage.NewOrganizationStorage(mysqlClient), + environmentStorage: envstotage.NewEnvironmentStorage(mysqlClient), + projectStorage: envstotage.NewProjectStorage(mysqlClient), + accountClient: accountClient, + verifier: verifier, googleAuthenticator: google.NewAuthenticator( &config.GoogleConfig, logger, ), @@ -800,19 +806,14 @@ func (s *authService) PrepareDemoUser() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error("Create mysql tx error", zap.Error(err)) - return - } now := time.Now() - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { + var err error + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { // Create a demo organization if not exists - organizationStorage := envstotage.NewOrganizationStorage(tx) - _, err = organizationStorage.GetOrganization(ctx, config.OrganizationId) + _, err = s.organizationStorage.GetOrganization(contextWithTx, config.OrganizationId) if err != nil { if errors.Is(err, envstotage.ErrOrganizationNotFound) { - err = organizationStorage.CreateOrganization(ctx, &envdomain.Organization{ + err = s.organizationStorage.CreateOrganization(contextWithTx, &envdomain.Organization{ Organization: &envproto.Organization{ Id: config.OrganizationId, Name: "Demo organization", @@ -833,11 +834,10 @@ func (s *authService) PrepareDemoUser() { } } // Create a demo project if not exists - projectStorage := envstotage.NewProjectStorage(tx) - _, err = projectStorage.GetProject(ctx, config.ProjectId) + _, err = s.projectStorage.GetProject(contextWithTx, config.ProjectId) if err != nil { if errors.Is(err, envstotage.ErrProjectNotFound) { - err = projectStorage.CreateProject(ctx, &envdomain.Project{ + err = s.projectStorage.CreateProject(contextWithTx, &envdomain.Project{ Project: &envproto.Project{ Id: config.ProjectId, Description: "This project is for demo users", @@ -857,11 +857,10 @@ func (s *authService) PrepareDemoUser() { } } // Create a demo environment if not exists - environmentStorage := envstotage.NewEnvironmentStorage(tx) - _, err = environmentStorage.GetEnvironmentV2(ctx, config.EnvironmentId) + _, err = s.environmentStorage.GetEnvironmentV2(contextWithTx, config.EnvironmentId) if err != nil { if errors.Is(err, envstotage.ErrEnvironmentNotFound) { - err = environmentStorage.CreateEnvironmentV2(ctx, &envdomain.EnvironmentV2{ + err = s.environmentStorage.CreateEnvironmentV2(contextWithTx, &envdomain.EnvironmentV2{ EnvironmentV2: &envproto.EnvironmentV2{ Id: config.EnvironmentId, Name: "Demo", diff --git a/pkg/environment/api/api.go b/pkg/environment/api/api.go index 677577e454..364d9635b2 100644 --- a/pkg/environment/api/api.go +++ b/pkg/environment/api/api.go @@ -24,6 +24,7 @@ import ( "google.golang.org/grpc/status" accountclient "github.com/bucketeer-io/bucketeer/pkg/account/client" + v2 "github.com/bucketeer-io/bucketeer/pkg/environment/storage/v2" "github.com/bucketeer-io/bucketeer/pkg/locale" "github.com/bucketeer-io/bucketeer/pkg/log" "github.com/bucketeer-io/bucketeer/pkg/pubsub/publisher" @@ -47,11 +48,14 @@ func WithLogger(l *zap.Logger) Option { } type EnvironmentService struct { - accountClient accountclient.Client - mysqlClient mysql.Client - publisher publisher.Publisher - opts *options - logger *zap.Logger + accountClient accountclient.Client + mysqlClient mysql.Client + projectStorage v2.ProjectStorage + orgStorage v2.OrganizationStorage + environmentStorage v2.EnvironmentStorage + publisher publisher.Publisher + opts *options + logger *zap.Logger } func NewEnvironmentService( @@ -67,11 +71,14 @@ func NewEnvironmentService( opt(dopts) } return &EnvironmentService{ - accountClient: ac, - mysqlClient: mysqlClient, - publisher: publisher, - opts: dopts, - logger: dopts.logger.Named("api"), + accountClient: ac, + mysqlClient: mysqlClient, + projectStorage: v2.NewProjectStorage(mysqlClient), + orgStorage: v2.NewOrganizationStorage(mysqlClient), + environmentStorage: v2.NewEnvironmentStorage(mysqlClient), + publisher: publisher, + opts: dopts, + logger: dopts.logger.Named("api"), } } diff --git a/pkg/environment/api/api_test.go b/pkg/environment/api/api_test.go index 450cff9b60..c25e6ea261 100644 --- a/pkg/environment/api/api_test.go +++ b/pkg/environment/api/api_test.go @@ -25,6 +25,7 @@ import ( "go.uber.org/zap" acmock "github.com/bucketeer-io/bucketeer/pkg/account/client/mock" + storagemock "github.com/bucketeer-io/bucketeer/pkg/environment/storage/v2/mock" "github.com/bucketeer-io/bucketeer/pkg/log" publishermock "github.com/bucketeer-io/bucketeer/pkg/pubsub/publisher/mock" "github.com/bucketeer-io/bucketeer/pkg/rpc" @@ -77,9 +78,12 @@ func newEnvironmentService(t *testing.T, mockController *gomock.Controller, s st logger, err := log.NewLogger() require.NoError(t, err) return &EnvironmentService{ - accountClient: acmock.NewMockClient(mockController), - mysqlClient: mysqlmock.NewMockClient(mockController), - publisher: publishermock.NewMockPublisher(mockController), - logger: logger.Named("api"), + accountClient: acmock.NewMockClient(mockController), + mysqlClient: mysqlmock.NewMockClient(mockController), + orgStorage: storagemock.NewMockOrganizationStorage(mockController), + projectStorage: storagemock.NewMockProjectStorage(mockController), + environmentStorage: storagemock.NewMockEnvironmentStorage(mockController), + publisher: publishermock.NewMockPublisher(mockController), + logger: logger.Named("api"), } } diff --git a/pkg/environment/api/environment_v2.go b/pkg/environment/api/environment_v2.go index f2b2b86c49..61359e8802 100644 --- a/pkg/environment/api/environment_v2.go +++ b/pkg/environment/api/environment_v2.go @@ -53,8 +53,7 @@ func (s *EnvironmentService) GetEnvironmentV2( if err := validateGetEnvironmentV2Request(req, localizer); err != nil { return nil, err } - environmentStorage := v2es.NewEnvironmentStorage(s.mysqlClient) - environment, err := environmentStorage.GetEnvironmentV2(ctx, req.Id) + environment, err := s.environmentStorage.GetEnvironmentV2(ctx, req.Id) if err != nil { if err == v2es.ErrEnvironmentNotFound { dt, err := statusEnvironmentNotFound.WithDetails(&errdetails.LocalizedMessage{ @@ -149,8 +148,7 @@ func (s *EnvironmentService) ListEnvironmentsV2( } return nil, dt.Err() } - environmentStorage := v2es.NewEnvironmentStorage(s.mysqlClient) - environments, nextCursor, totalCount, err := environmentStorage.ListEnvironmentsV2( + environments, nextCursor, totalCount, err := s.environmentStorage.ListEnvironmentsV2( ctx, whereParts, orders, @@ -287,7 +285,6 @@ func (s *EnvironmentService) createEnvironmentV2NoCommand( } err = s.mysqlClient.RunInTransactionV2(ctx, func(ctxWithTx context.Context, tx mysql.Transaction) error { - environmentStorage := v2es.NewEnvironmentStorage(s.mysqlClient) e, err := domainevent.NewAdminEvent( editor, eventproto.Event_ENVIRONMENT, @@ -313,7 +310,7 @@ func (s *EnvironmentService) createEnvironmentV2NoCommand( if err := s.publisher.Publish(ctx, e); err != nil { return err } - return environmentStorage.CreateEnvironmentV2(ctxWithTx, newEnvironment) + return s.environmentStorage.CreateEnvironmentV2(ctxWithTx, newEnvironment) }) if err != nil { if errors.Is(err, v2es.ErrEnvironmentAlreadyExists) { @@ -470,25 +467,7 @@ func (s *EnvironmentService) createEnvironmentV2( editor *eventproto.Editor, localizer locale.Localizer, ) error { - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return statusInternal.Err() - } - return dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - environmentStorage := v2es.NewEnvironmentStorage(tx) + err := s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { handler, err := command.NewEnvironmentV2CommandHandler(editor, environment, s.publisher) if err != nil { return err @@ -496,7 +475,7 @@ func (s *EnvironmentService) createEnvironmentV2( if err := handler.Handle(ctx, cmd); err != nil { return err } - return environmentStorage.CreateEnvironmentV2(ctx, environment) + return s.environmentStorage.CreateEnvironmentV2(contextWithTx, environment) }) if err != nil { if errors.Is(err, v2es.ErrEnvironmentAlreadyExists) { @@ -560,8 +539,7 @@ func (s *EnvironmentService) updateEnvironmentV2NoCommand( } err := s.mysqlClient.RunInTransactionV2(ctx, func(ctxWithTx context.Context, tx mysql.Transaction) error { - environmentStorage := v2es.NewEnvironmentStorage(s.mysqlClient) - environment, err := environmentStorage.GetEnvironmentV2(ctxWithTx, req.Id) + environment, err := s.environmentStorage.GetEnvironmentV2(ctxWithTx, req.Id) if err != nil { return err } @@ -589,7 +567,7 @@ func (s *EnvironmentService) updateEnvironmentV2NoCommand( if err := s.publisher.Publish(ctx, event); err != nil { return err } - return environmentStorage.UpdateEnvironmentV2(ctxWithTx, updated) + return s.environmentStorage.UpdateEnvironmentV2(ctxWithTx, updated) }) if err != nil { if errors.Is(err, v2es.ErrEnvironmentNotFound) || errors.Is(err, v2es.ErrEnvironmentUnexpectedAffectedRows) { @@ -625,26 +603,8 @@ func (s *EnvironmentService) updateEnvironmentV2( editor *eventproto.Editor, localizer locale.Localizer, ) error { - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return statusInternal.Err() - } - return dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - environmentStorage := v2es.NewEnvironmentStorage(tx) - environment, err := environmentStorage.GetEnvironmentV2(ctx, envId) + err := s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + environment, err := s.environmentStorage.GetEnvironmentV2(contextWithTx, envId) if err != nil { return err } @@ -657,7 +617,7 @@ func (s *EnvironmentService) updateEnvironmentV2( return err } } - return environmentStorage.UpdateEnvironmentV2(ctx, environment) + return s.environmentStorage.UpdateEnvironmentV2(contextWithTx, environment) }) if err != nil { if errors.Is(err, v2es.ErrEnvironmentNotFound) || errors.Is(err, v2es.ErrEnvironmentUnexpectedAffectedRows) { @@ -772,26 +732,8 @@ func (s *EnvironmentService) ArchiveEnvironmentV2( if err := validateArchiveEnvironmentV2Request(req, localizer); err != nil { return nil, err } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - environmentStorage := v2es.NewEnvironmentStorage(tx) - environment, err := environmentStorage.GetEnvironmentV2(ctx, req.Id) + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + environment, err := s.environmentStorage.GetEnvironmentV2(contextWithTx, req.Id) if err != nil { return err } @@ -802,7 +744,7 @@ func (s *EnvironmentService) ArchiveEnvironmentV2( if err := handler.Handle(ctx, req.Command); err != nil { return err } - return environmentStorage.UpdateEnvironmentV2(ctx, environment) + return s.environmentStorage.UpdateEnvironmentV2(contextWithTx, environment) }) if err != nil { if err == v2es.ErrEnvironmentNotFound || err == v2es.ErrEnvironmentUnexpectedAffectedRows { @@ -861,26 +803,8 @@ func (s *EnvironmentService) UnarchiveEnvironmentV2( if err := validateUnarchiveEnvironmentV2Request(req, localizer); err != nil { return nil, err } - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - environmentStorage := v2es.NewEnvironmentStorage(tx) - environment, err := environmentStorage.GetEnvironmentV2(ctx, req.Id) + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + environment, err := s.environmentStorage.GetEnvironmentV2(contextWithTx, req.Id) if err != nil { return err } @@ -891,7 +815,7 @@ func (s *EnvironmentService) UnarchiveEnvironmentV2( if err := handler.Handle(ctx, req.Command); err != nil { return err } - return environmentStorage.UpdateEnvironmentV2(ctx, environment) + return s.environmentStorage.UpdateEnvironmentV2(contextWithTx, environment) }) if err != nil { if err == v2es.ErrEnvironmentNotFound || err == v2es.ErrEnvironmentUnexpectedAffectedRows { diff --git a/pkg/environment/api/environment_v2_test.go b/pkg/environment/api/environment_v2_test.go index fe58fc47a3..8b962d78b9 100644 --- a/pkg/environment/api/environment_v2_test.go +++ b/pkg/environment/api/environment_v2_test.go @@ -15,6 +15,7 @@ package api import ( + "context" "errors" "strings" "testing" @@ -29,7 +30,9 @@ import ( "github.com/bucketeer-io/bucketeer/pkg/environment/domain" v2es "github.com/bucketeer-io/bucketeer/pkg/environment/storage/v2" + storagemock "github.com/bucketeer-io/bucketeer/pkg/environment/storage/v2/mock" "github.com/bucketeer-io/bucketeer/pkg/locale" + publishermock "github.com/bucketeer-io/bucketeer/pkg/pubsub/publisher/mock" "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql" mysqlmock "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql/mock" proto "github.com/bucketeer-io/bucketeer/proto/environment" @@ -63,11 +66,9 @@ func TestGetEnvironmentV2(t *testing.T) { { desc: "err: ErrEnvironmentNotFound", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().GetEnvironmentV2( + gomock.Any(), gomock.Any(), + ).Return(nil, v2es.ErrEnvironmentNotFound) }, id: "id-0", expectedErr: createError(statusEnvironmentNotFound, localizer.MustLocalize(locale.NotFoundError)), @@ -75,11 +76,9 @@ func TestGetEnvironmentV2(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(errors.New("error")) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().GetEnvironmentV2( + gomock.Any(), gomock.Any(), + ).Return(nil, errors.New("error")) }, id: "id-1", expectedErr: createError(statusInternal, localizer.MustLocalize(locale.InternalServerError)), @@ -87,11 +86,9 @@ func TestGetEnvironmentV2(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().GetEnvironmentV2( + gomock.Any(), gomock.Any(), + ).Return(&domain.EnvironmentV2{}, nil) }, id: "id-3", expectedErr: nil, @@ -152,9 +149,9 @@ func TestListEnvironmentsV2(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil, errors.New("error")) + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().ListEnvironmentsV2( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(nil, 0, int64(0), errors.New("error")) }, input: &proto.ListEnvironmentsV2Request{}, expected: nil, @@ -163,18 +160,9 @@ func TestListEnvironmentsV2(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().ListEnvironmentsV2( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.EnvironmentV2{}, 0, int64(0), nil) }, input: &proto.ListEnvironmentsV2Request{PageSize: 2, Cursor: ""}, expected: &proto.ListEnvironmentsV2Response{Environments: []*proto.EnvironmentV2{}, Cursor: "0"}, @@ -334,11 +322,9 @@ func TestCreateEnvironmentV2(t *testing.T) { { desc: "err: ErrProjectNotFound", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(nil, v2es.ErrProjectNotFound) }, req: &proto.CreateEnvironmentV2Request{ Command: &proto.CreateEnvironmentV2Command{Name: "name", UrlCode: "url-code", ProjectId: "project-id"}, @@ -348,14 +334,13 @@ func TestCreateEnvironmentV2(t *testing.T) { { desc: "err: ErrEnvironmentAlreadyExists", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "project-id"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrEnvironmentAlreadyExists) }, req: &proto.CreateEnvironmentV2Request{ @@ -366,11 +351,14 @@ func TestCreateEnvironmentV2(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(errors.New("error")) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "project-id"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Return(errors.New("error")) }, req: &proto.CreateEnvironmentV2Request{ Command: &proto.CreateEnvironmentV2Command{Name: "name", UrlCode: "url-code", ProjectId: "project-id"}, @@ -380,14 +368,19 @@ func TestCreateEnvironmentV2(t *testing.T) { { desc: "success: require comment is true", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "project-id"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().CreateEnvironmentV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.CreateEnvironmentV2Request{ @@ -404,14 +397,19 @@ func TestCreateEnvironmentV2(t *testing.T) { { desc: "success: require comment is false", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "project-id"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().CreateEnvironmentV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.CreateEnvironmentV2Request{ @@ -595,11 +593,9 @@ func TestCreateEnvironmentV2NoCommand(t *testing.T) { { desc: "err: ErrProjectNotFound", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(nil, v2es.ErrProjectNotFound) }, req: &proto.CreateEnvironmentV2Request{ Name: "name", @@ -611,11 +607,11 @@ func TestCreateEnvironmentV2NoCommand(t *testing.T) { { desc: "err: ErrEnvironmentAlreadyExists", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "project-id"}, + }, nil) s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( gomock.Any(), gomock.Any(), ).Return(v2es.ErrEnvironmentAlreadyExists) @@ -630,11 +626,9 @@ func TestCreateEnvironmentV2NoCommand(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(errors.New("error")) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(nil, errors.New("error")) }, req: &proto.CreateEnvironmentV2Request{ Name: "name", @@ -646,13 +640,19 @@ func TestCreateEnvironmentV2NoCommand(t *testing.T) { { desc: "success: require comment is true", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "project-id"}, + }, nil) s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().CreateEnvironmentV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.CreateEnvironmentV2Request{ @@ -667,13 +667,19 @@ func TestCreateEnvironmentV2NoCommand(t *testing.T) { { desc: "success: require comment is false", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "project-id"}, + }, nil) s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().CreateEnvironmentV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.CreateEnvironmentV2Request{ @@ -763,9 +769,8 @@ func TestUpdateEnvironmentV2(t *testing.T) { { desc: "err: ErrEnvironmentNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrEnvironmentNotFound) }, req: &proto.UpdateEnvironmentV2Request{ @@ -777,9 +782,8 @@ func TestUpdateEnvironmentV2(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.UpdateEnvironmentV2Request{ @@ -791,9 +795,19 @@ func TestUpdateEnvironmentV2(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().GetEnvironmentV2( + gomock.Any(), gomock.Any(), + ).Return(&domain.EnvironmentV2{ + EnvironmentV2: &proto.EnvironmentV2{}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().UpdateEnvironmentV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.UpdateEnvironmentV2Request{ @@ -950,9 +964,8 @@ func TestArchiveEnvironmentV2(t *testing.T) { { desc: "err: ErrEnvironmentNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrEnvironmentNotFound) }, req: &proto.ArchiveEnvironmentV2Request{Id: "id01", Command: &proto.ArchiveEnvironmentV2Command{}}, @@ -964,9 +977,8 @@ func TestArchiveEnvironmentV2(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.ArchiveEnvironmentV2Request{Id: "id02", Command: &proto.ArchiveEnvironmentV2Command{}}, @@ -975,9 +987,19 @@ func TestArchiveEnvironmentV2(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().GetEnvironmentV2( + gomock.Any(), gomock.Any(), + ).Return(&domain.EnvironmentV2{ + EnvironmentV2: &proto.EnvironmentV2{}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().UpdateEnvironmentV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.ArchiveEnvironmentV2Request{Id: "id01", Command: &proto.ArchiveEnvironmentV2Command{}}, @@ -1024,9 +1046,8 @@ func TestUnarchiveEnvironmentV2(t *testing.T) { { desc: "err: ErrEnvironmentNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrEnvironmentNotFound) }, req: &proto.UnarchiveEnvironmentV2Request{Id: "id01", Command: &proto.UnarchiveEnvironmentV2Command{}}, @@ -1038,9 +1059,8 @@ func TestUnarchiveEnvironmentV2(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.UnarchiveEnvironmentV2Request{Id: "id02", Command: &proto.UnarchiveEnvironmentV2Command{}}, @@ -1049,9 +1069,8 @@ func TestUnarchiveEnvironmentV2(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.UnarchiveEnvironmentV2Request{Id: "id01", Command: &proto.UnarchiveEnvironmentV2Command{}}, diff --git a/pkg/environment/api/organization.go b/pkg/environment/api/organization.go index 4033e05b1d..e7854eb1b2 100644 --- a/pkg/environment/api/organization.go +++ b/pkg/environment/api/organization.go @@ -92,8 +92,7 @@ func (s *EnvironmentService) getOrganization( id string, localizer locale.Localizer, ) (*domain.Organization, error) { - orgStorage := v2es.NewOrganizationStorage(s.mysqlClient) - org, err := orgStorage.GetOrganization(ctx, id) + org, err := s.orgStorage.GetOrganization(ctx, id) if err != nil { if errors.Is(err, v2es.ErrOrganizationNotFound) { dt, err := statusOrganizationNotFound.WithDetails(&errdetails.LocalizedMessage{ @@ -168,8 +167,7 @@ func (s *EnvironmentService) ListOrganizations( } return nil, dt.Err() } - orgStorage := v2es.NewOrganizationStorage(s.mysqlClient) - organizations, nextCursor, totalCount, err := orgStorage.ListOrganizations(ctx, whereParts, orders, limit, offset) + organizations, nextCursor, totalCount, err := s.orgStorage.ListOrganizations(ctx, whereParts, orders, limit, offset) if err != nil { s.logger.Error( "failed to list organizations", @@ -371,29 +369,10 @@ func (s *EnvironmentService) createOrganizationNoCommand( return nil, statusInternal.Err() } var envRoles []*accountproto.AccountV2_EnvironmentRole - // Begin the SQL transaction - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return nil, statusInternal.Err() - } - return nil, dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - orgStorage := v2es.NewOrganizationStorage(tx) + err = s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { // Check if there is already a system admin organization if organization.Organization.SystemAdmin { - org, err := orgStorage.GetSystemAdminOrganization(ctx) + org, err := s.orgStorage.GetSystemAdminOrganization(contextWithTx) if err != nil { return err } @@ -401,13 +380,12 @@ func (s *EnvironmentService) createOrganizationNoCommand( return v2es.ErrOrganizationAlreadyExists } } - if err := orgStorage.CreateOrganization(ctx, organization); err != nil { + if err := s.orgStorage.CreateOrganization(contextWithTx, organization); err != nil { return err } // Create a default project project, err := s.createDefaultProject( - ctx, - tx, + contextWithTx, organization.Id, organization.OwnerEmail, ) @@ -422,8 +400,7 @@ func (s *EnvironmentService) createOrganizationNoCommand( } // Create default environments envRoles, err = s.createDefaultEnvironments( - ctx, - tx, + contextWithTx, organization.Id, project, ) @@ -574,27 +551,9 @@ func (s *EnvironmentService) createOrganization( editor *eventproto.Editor, localizer locale.Localizer, ) error { - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return statusInternal.Err() - } - return dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - orgStorage := v2es.NewOrganizationStorage(tx) + err := s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { if organization.Organization.SystemAdmin { - org, err := orgStorage.GetSystemAdminOrganization(ctx) + org, err := s.orgStorage.GetSystemAdminOrganization(contextWithTx) if err != nil { return err } @@ -609,7 +568,7 @@ func (s *EnvironmentService) createOrganization( if err := handler.Handle(ctx, cmd); err != nil { return err } - return orgStorage.CreateOrganization(ctx, organization) + return s.orgStorage.CreateOrganization(contextWithTx, organization) }) if err != nil { if errors.Is(err, v2es.ErrOrganizationAlreadyExists) { @@ -646,7 +605,6 @@ func (s *EnvironmentService) createOrganization( // To create it we need the project, so we can also create the environment. func (s *EnvironmentService) createDefaultProject( ctx context.Context, - tx mysql.Transaction, organizationID, email string, ) (*domain.Project, error) { project, err := domain.NewProject( @@ -660,8 +618,7 @@ func (s *EnvironmentService) createDefaultProject( if err != nil { return nil, err } - projectStorage := v2es.NewProjectStorage(tx) - if err := projectStorage.CreateProject(ctx, project); err != nil { + if err := s.projectStorage.CreateProject(ctx, project); err != nil { return nil, err } return project, nil @@ -673,7 +630,6 @@ func (s *EnvironmentService) createDefaultProject( // and to create it we need the organization and environment roles. func (s *EnvironmentService) createDefaultEnvironments( ctx context.Context, - tx mysql.Transaction, organizationID string, project *domain.Project, ) ([]*accountproto.AccountV2_EnvironmentRole, error) { @@ -696,8 +652,7 @@ func (s *EnvironmentService) createDefaultEnvironments( if err != nil { return nil, err } - environmentStorage := v2es.NewEnvironmentStorage(tx) - if err := environmentStorage.CreateEnvironmentV2(ctx, env); err != nil { + if err := s.environmentStorage.CreateEnvironmentV2(ctx, env); err != nil { return nil, err } envRoles = append(envRoles, &accountproto.AccountV2_EnvironmentRole{ @@ -1000,28 +955,10 @@ func (s *EnvironmentService) updateOrganization( localizer locale.Localizer, commands ...command.Command, ) error { - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return statusInternal.Err() - } - return dt.Err() - } var prevOwnerEmail string var newOwnerEmail string - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - orgStorage := v2es.NewOrganizationStorage(tx) - organization, err := orgStorage.GetOrganization(ctx, id) + err := s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + organization, err := s.orgStorage.GetOrganization(contextWithTx, id) if err != nil { return err } @@ -1039,7 +976,7 @@ func (s *EnvironmentService) updateOrganization( if prevOwnerEmail != organization.OwnerEmail { newOwnerEmail = organization.OwnerEmail } - return orgStorage.UpdateOrganization(ctx, organization) + return s.orgStorage.UpdateOrganization(contextWithTx, organization) }) if err != nil { return s.reportUpdateOrganizationError(ctx, err, localizer) diff --git a/pkg/environment/api/organization_test.go b/pkg/environment/api/organization_test.go index 5efb766649..b0cf6b1440 100644 --- a/pkg/environment/api/organization_test.go +++ b/pkg/environment/api/organization_test.go @@ -1,6 +1,7 @@ package api import ( + "context" "errors" "strings" "testing" @@ -17,6 +18,7 @@ import ( "github.com/bucketeer-io/bucketeer/pkg/environment/domain" v2es "github.com/bucketeer-io/bucketeer/pkg/environment/storage/v2" + storagemock "github.com/bucketeer-io/bucketeer/pkg/environment/storage/v2/mock" "github.com/bucketeer-io/bucketeer/pkg/locale" "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql" mysqlmock "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql/mock" @@ -57,11 +59,9 @@ func TestGetOrganizationMySQL(t *testing.T) { { desc: "err: ErrOrganizationNotFound", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().GetOrganization( + gomock.Any(), gomock.Any(), + ).Return(nil, v2es.ErrOrganizationNotFound) }, id: "err-id-0", expectedErr: createError(statusOrganizationNotFound, localizer.MustLocalize(locale.NotFoundError)), @@ -69,11 +69,9 @@ func TestGetOrganizationMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(errors.New("error")) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().GetOrganization( + gomock.Any(), gomock.Any(), + ).Return(nil, errors.New("error")) }, id: "err-id-1", expectedErr: createError(statusInternal, localizer.MustLocalize(locale.InternalServerError)), @@ -81,11 +79,11 @@ func TestGetOrganizationMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().GetOrganization( + gomock.Any(), gomock.Any(), + ).Return(&domain.Organization{ + Organization: &proto.Organization{Id: "success-id-0"}, + }, nil) }, id: "success-id-0", expectedErr: nil, @@ -143,9 +141,9 @@ func TestListOrganizationsMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil, errors.New("error")) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().ListOrganizations( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(nil, 0, int64(0), errors.New("error")) }, input: &proto.ListOrganizationsRequest{}, expected: nil, @@ -154,18 +152,9 @@ func TestListOrganizationsMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().ListOrganizations( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.Organization{}, 0, int64(0), nil) }, input: &proto.ListOrganizationsRequest{PageSize: 2, Cursor: ""}, expected: &proto.ListOrganizationsResponse{Organizations: []*proto.Organization{}, Cursor: "0", TotalCount: 0}, @@ -280,9 +269,8 @@ func TestCreateOrganizationMySQL(t *testing.T) { { desc: "err: ErrOrganizationAlreadyExists: duplicate id", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrOrganizationAlreadyExists) }, req: &proto.CreateOrganizationRequest{ @@ -293,9 +281,8 @@ func TestCreateOrganizationMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.CreateOrganizationRequest{ @@ -306,9 +293,14 @@ func TestCreateOrganizationMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().CreateOrganization( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.CreateOrganizationRequest{ @@ -326,9 +318,14 @@ func TestCreateOrganizationMySQL(t *testing.T) { { desc: "success trial", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().CreateOrganization( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.CreateOrganizationRequest{ @@ -422,9 +419,8 @@ func TestUpdateOrganizationMySQL(t *testing.T) { { desc: "err: ErrOrganizationNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrOrganizationNotFound) }, req: &proto.UpdateOrganizationRequest{ @@ -436,9 +432,8 @@ func TestUpdateOrganizationMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.UpdateOrganizationRequest{ @@ -450,9 +445,19 @@ func TestUpdateOrganizationMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().GetOrganization( + gomock.Any(), gomock.Any(), + ).Return(&domain.Organization{ + Organization: &proto.Organization{Id: "success-id-0"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().UpdateOrganization( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.UpdateOrganizationRequest{ @@ -625,9 +630,8 @@ func TestEnableOrganizationMySQL(t *testing.T) { { desc: "err: ErrOrganizationNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrOrganizationNotFound) }, req: &proto.EnableOrganizationRequest{ @@ -639,9 +643,8 @@ func TestEnableOrganizationMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.EnableOrganizationRequest{ @@ -653,10 +656,20 @@ func TestEnableOrganizationMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().GetOrganization( + gomock.Any(), gomock.Any(), + ).Return(&domain.Organization{ + Organization: &proto.Organization{Id: "id-1"}, + }, nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().UpdateOrganization( + gomock.Any(), gomock.Any(), ).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) }, req: &proto.EnableOrganizationRequest{ Id: "id-1", @@ -721,9 +734,8 @@ func TestDisableOrganizationMySQL(t *testing.T) { { desc: "err: ErrOrganizationNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrOrganizationNotFound) }, req: &proto.DisableOrganizationRequest{ @@ -735,9 +747,8 @@ func TestDisableOrganizationMySQL(t *testing.T) { { desc: "err: ErrCannotUpdateSystemAdmin", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(domain.ErrCannotDisableSystemAdmin) }, req: &proto.DisableOrganizationRequest{ @@ -749,9 +760,8 @@ func TestDisableOrganizationMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.DisableOrganizationRequest{ @@ -763,10 +773,20 @@ func TestDisableOrganizationMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().GetOrganization( + gomock.Any(), gomock.Any(), + ).Return(&domain.Organization{ + Organization: &proto.Organization{Id: "id-1"}, + }, nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().UpdateOrganization( + gomock.Any(), gomock.Any(), ).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) }, req: &proto.DisableOrganizationRequest{ Id: "id-1", @@ -831,9 +851,8 @@ func TestArchiveOrganizationMySQL(t *testing.T) { { desc: "err: ErrOrganizationNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrOrganizationNotFound) }, req: &proto.ArchiveOrganizationRequest{ @@ -845,9 +864,8 @@ func TestArchiveOrganizationMySQL(t *testing.T) { { desc: "err: ErrCannotUpdateSystemAdmin", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(domain.ErrCannotArchiveSystemAdmin) }, req: &proto.ArchiveOrganizationRequest{ @@ -859,9 +877,8 @@ func TestArchiveOrganizationMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.ArchiveOrganizationRequest{ @@ -873,9 +890,19 @@ func TestArchiveOrganizationMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().GetOrganization( + gomock.Any(), gomock.Any(), + ).Return(&domain.Organization{ + Organization: &proto.Organization{Id: "id-1"}, + }, nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().UpdateOrganization( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.ArchiveOrganizationRequest{ @@ -941,9 +968,8 @@ func TestUnarchiveOrganizationMySQL(t *testing.T) { { desc: "err: ErrOrganizationNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrOrganizationNotFound) }, req: &proto.UnarchiveOrganizationRequest{ @@ -955,9 +981,8 @@ func TestUnarchiveOrganizationMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.UnarchiveOrganizationRequest{ @@ -969,9 +994,19 @@ func TestUnarchiveOrganizationMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().GetOrganization( + gomock.Any(), gomock.Any(), + ).Return(&domain.Organization{ + Organization: &proto.Organization{Id: "id-1"}, + }, nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().UpdateOrganization( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.UnarchiveOrganizationRequest{ @@ -1037,9 +1072,8 @@ func TestConvertTrialOrganizationMySQL(t *testing.T) { { desc: "err: ErrOrganizationNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrOrganizationNotFound) }, req: &proto.ConvertTrialOrganizationRequest{ @@ -1051,9 +1085,8 @@ func TestConvertTrialOrganizationMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.ConvertTrialOrganizationRequest{ @@ -1065,10 +1098,20 @@ func TestConvertTrialOrganizationMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().GetOrganization( + gomock.Any(), gomock.Any(), + ).Return(&domain.Organization{ + Organization: &proto.Organization{Id: "id-1"}, + }, nil) + s.orgStorage.(*storagemock.MockOrganizationStorage).EXPECT().UpdateOrganization( + gomock.Any(), gomock.Any(), ).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) }, req: &proto.ConvertTrialOrganizationRequest{ Id: "id-1", diff --git a/pkg/environment/api/project.go b/pkg/environment/api/project.go index 2bb77ccf54..1279e2ed60 100644 --- a/pkg/environment/api/project.go +++ b/pkg/environment/api/project.go @@ -93,8 +93,7 @@ func (s *EnvironmentService) getProject( id string, localizer locale.Localizer, ) (*domain.Project, error) { - projectStorage := v2es.NewProjectStorage(s.mysqlClient) - project, err := projectStorage.GetProject(ctx, id) + project, err := s.projectStorage.GetProject(ctx, id) if err != nil { if err == v2es.ErrProjectNotFound { dt, err := statusProjectNotFound.WithDetails(&errdetails.LocalizedMessage{ @@ -168,8 +167,7 @@ func (s *EnvironmentService) ListProjects( } return nil, dt.Err() } - projectStorage := v2es.NewProjectStorage(s.mysqlClient) - projects, nextCursor, totalCount, err := projectStorage.ListProjects( + projects, nextCursor, totalCount, err := s.projectStorage.ListProjects( ctx, whereParts, orders, @@ -502,25 +500,7 @@ func (s *EnvironmentService) createProject( editor *eventproto.Editor, localizer locale.Localizer, ) error { - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return statusInternal.Err() - } - return dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - projectStorage := v2es.NewProjectStorage(tx) + err := s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { handler, err := command.NewProjectCommandHandler(editor, project, s.publisher) if err != nil { return err @@ -528,7 +508,7 @@ func (s *EnvironmentService) createProject( if err := handler.Handle(ctx, cmd); err != nil { return err } - return projectStorage.CreateProject(ctx, project) + return s.projectStorage.CreateProject(contextWithTx, project) }) if err != nil { if err == v2es.ErrProjectAlreadyExists { @@ -721,8 +701,7 @@ func (s *EnvironmentService) getTrialProjectByEmail( email string, localizer locale.Localizer, ) (*environmentproto.Project, error) { - projectStorage := v2es.NewProjectStorage(s.mysqlClient) - project, err := projectStorage.GetTrialProjectByEmail(ctx, email, false, true) + project, err := s.projectStorage.GetTrialProjectByEmail(ctx, email, false, true) if err != nil { if err == v2es.ErrProjectNotFound { dt, err := statusProjectNotFound.WithDetails(&errdetails.LocalizedMessage{ @@ -1036,26 +1015,8 @@ func (s *EnvironmentService) updateProject( localizer locale.Localizer, commands ...command.Command, ) error { - tx, err := s.mysqlClient.BeginTx(ctx) - if err != nil { - s.logger.Error( - "Failed to begin transaction", - log.FieldsFromImcomingContext(ctx).AddFields( - zap.Error(err), - )..., - ) - dt, err := statusInternal.WithDetails(&errdetails.LocalizedMessage{ - Locale: localizer.GetLocale(), - Message: localizer.MustLocalize(locale.InternalServerError), - }) - if err != nil { - return statusInternal.Err() - } - return dt.Err() - } - err = s.mysqlClient.RunInTransaction(ctx, tx, func() error { - projectStorage := v2es.NewProjectStorage(tx) - project, err := projectStorage.GetProject(ctx, id) + err := s.mysqlClient.RunInTransactionV2(ctx, func(contextWithTx context.Context, _ mysql.Transaction) error { + project, err := s.projectStorage.GetProject(contextWithTx, id) if err != nil { return err } @@ -1068,7 +1029,7 @@ func (s *EnvironmentService) updateProject( return err } } - return projectStorage.UpdateProject(ctx, project) + return s.projectStorage.UpdateProject(contextWithTx, project) }) if err != nil { if err == v2es.ErrProjectNotFound || err == v2es.ErrProjectUnexpectedAffectedRows { @@ -1289,8 +1250,7 @@ func (s *EnvironmentService) ListProjectsV2( } return nil, dt.Err() } - projectStorage := v2es.NewProjectStorage(s.mysqlClient) - projects, nextCursor, totalCount, err := projectStorage.ListProjects( + projects, nextCursor, totalCount, err := s.projectStorage.ListProjects( ctx, whereParts, orders, diff --git a/pkg/environment/api/project_test.go b/pkg/environment/api/project_test.go index 97b2e50fe9..651187884a 100644 --- a/pkg/environment/api/project_test.go +++ b/pkg/environment/api/project_test.go @@ -32,7 +32,9 @@ import ( acmock "github.com/bucketeer-io/bucketeer/pkg/account/client/mock" "github.com/bucketeer-io/bucketeer/pkg/environment/domain" v2es "github.com/bucketeer-io/bucketeer/pkg/environment/storage/v2" + storagemock "github.com/bucketeer-io/bucketeer/pkg/environment/storage/v2/mock" "github.com/bucketeer-io/bucketeer/pkg/locale" + publishermock "github.com/bucketeer-io/bucketeer/pkg/pubsub/publisher/mock" pubmock "github.com/bucketeer-io/bucketeer/pkg/pubsub/publisher/mock" "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql" mysqlmock "github.com/bucketeer-io/bucketeer/pkg/storage/v2/mysql/mock" @@ -76,11 +78,9 @@ func TestGetProjectMySQL(t *testing.T) { { desc: "err: ErrProjectNotFound", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(nil, v2es.ErrProjectNotFound) }, id: "err-id-0", expectedErr: createError(statusProjectNotFound, localizer.MustLocalize(locale.NotFoundError)), @@ -88,11 +88,9 @@ func TestGetProjectMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(errors.New("error")) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(nil, errors.New("error")) }, id: "err-id-1", expectedErr: createError(statusInternal, localizer.MustLocalize(locale.InternalServerError)), @@ -100,11 +98,9 @@ func TestGetProjectMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{}, nil) }, id: "success-id-0", expectedErr: nil, @@ -161,9 +157,9 @@ func TestListProjectsMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil, errors.New("error")) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().ListProjects( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(nil, 0, int64(0), errors.New("error")) }, input: &proto.ListProjectsRequest{}, expected: nil, @@ -172,18 +168,9 @@ func TestListProjectsMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().ListProjects( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*proto.Project{}, 0, int64(0), nil) }, input: &proto.ListProjectsRequest{PageSize: 2, Cursor: "", OrganizationIds: []string{"org-1", "org-2"}}, expected: &proto.ListProjectsResponse{Projects: []*proto.Project{}, Cursor: "0", TotalCount: 0}, @@ -282,9 +269,8 @@ func TestCreateProjectMySQL(t *testing.T) { { desc: "err: ErrProjectAlreadyExists: duplicate id", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrProjectAlreadyExists) }, req: &proto.CreateProjectRequest{ @@ -295,9 +281,8 @@ func TestCreateProjectMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.CreateProjectRequest{ @@ -308,9 +293,16 @@ func TestCreateProjectMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish( + gomock.Any(), gomock.Any(), + ).Return(nil) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().CreateProject( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.CreateProjectRequest{ @@ -700,11 +692,11 @@ func TestCreateTrialProjectMySQL(t *testing.T) { { desc: "err: ErrProjectAlreadyExists: trial exists", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetTrialProjectByEmail( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "id-0"}, + }, nil) }, req: &proto.CreateTrialProjectRequest{ Command: &proto.CreateTrialProjectCommand{Name: "id-0", Email: "test@example.com"}, @@ -714,19 +706,11 @@ func TestCreateTrialProjectMySQL(t *testing.T) { { desc: "err: ErrProjectAlreadyExists: duplicated id", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(v2es.ErrProjectAlreadyExists) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetTrialProjectByEmail( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "id-0"}, + }, nil) }, req: &proto.CreateTrialProjectRequest{ Command: &proto.CreateTrialProjectCommand{Name: "id-0", Email: "test@example.com"}, @@ -736,11 +720,9 @@ func TestCreateTrialProjectMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(errors.New("error")) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetTrialProjectByEmail( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(nil, errors.New("error")) }, req: &proto.CreateTrialProjectRequest{ Command: &proto.CreateTrialProjectCommand{Name: "id-1", Email: "test@example.com"}, @@ -750,17 +732,39 @@ func TestCreateTrialProjectMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).Return(mysql.ErrNoRows) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil).Times(4) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil).Times(4) - s.accountClient.(*acmock.MockClient).EXPECT().CreateAccountV2(gomock.Any(), gomock.Any()).Return( - &accountproto.CreateAccountV2Response{}, nil) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetTrialProjectByEmail( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: nil, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Return(nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish( + gomock.Any(), gomock.Any(), + ).Return(nil).AnyTimes() + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().CreateProject( + gomock.Any(), gomock.Any(), + ).Return(nil) + + // CreateEnvironmentV2 is called for two purposes: development environment and production environment. + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil).Times(2) + s.environmentStorage.(*storagemock.MockEnvironmentStorage).EXPECT().CreateEnvironmentV2( + gomock.Any(), gomock.Any(), + ).Return(nil).Times(2) + + s.accountClient.(*acmock.MockClient).EXPECT().CreateAccountV2( + gomock.Any(), gomock.Any(), + ).Return(&accountproto.CreateAccountV2Response{}, nil) }, req: &proto.CreateTrialProjectRequest{ Command: &proto.CreateTrialProjectCommand{Name: "Project Name_001", Email: "test@example.com"}, @@ -987,9 +991,8 @@ func TestUpdateProjectMySQL(t *testing.T) { { desc: "err: ErrProjectNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrProjectNotFound) }, req: &proto.UpdateProjectRequest{ @@ -1001,9 +1004,8 @@ func TestUpdateProjectMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.UpdateProjectRequest{ @@ -1015,9 +1017,21 @@ func TestUpdateProjectMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "id-1", Description: "old desc"}, + }, nil) + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish( + gomock.Any(), gomock.Any(), + ).Return(nil) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().UpdateProject( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.UpdateProjectRequest{ @@ -1083,9 +1097,8 @@ func TestEnableProjectMySQL(t *testing.T) { { desc: "err: ErrProjectNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrProjectNotFound) }, req: &proto.EnableProjectRequest{ @@ -1097,9 +1110,8 @@ func TestEnableProjectMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.EnableProjectRequest{ @@ -1111,9 +1123,19 @@ func TestEnableProjectMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "id-1"}, + }, nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().UpdateProject( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.EnableProjectRequest{ @@ -1179,9 +1201,8 @@ func TestDisableProjectMySQL(t *testing.T) { { desc: "err: ErrProjectNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrProjectNotFound) }, req: &proto.DisableProjectRequest{ @@ -1193,9 +1214,8 @@ func TestDisableProjectMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.DisableProjectRequest{ @@ -1207,9 +1227,19 @@ func TestDisableProjectMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "id-1"}, + }, nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().UpdateProject( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.DisableProjectRequest{ @@ -1275,9 +1305,8 @@ func TestConvertTrialProjectMySQL(t *testing.T) { { desc: "err: ErrProjectNotFound", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(v2es.ErrProjectNotFound) }, req: &proto.ConvertTrialProjectRequest{ @@ -1289,9 +1318,8 @@ func TestConvertTrialProjectMySQL(t *testing.T) { { desc: "err: ErrInternal", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), ).Return(errors.New("error")) }, req: &proto.ConvertTrialProjectRequest{ @@ -1303,9 +1331,19 @@ func TestConvertTrialProjectMySQL(t *testing.T) { { desc: "success", setup: func(s *EnvironmentService) { - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().BeginTx(gomock.Any()).Return(nil, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransaction( - gomock.Any(), gomock.Any(), gomock.Any(), + s.mysqlClient.(*mysqlmock.MockClient).EXPECT().RunInTransactionV2( + gomock.Any(), gomock.Any(), + ).Do(func(ctx context.Context, fn func(ctx context.Context, tx mysql.Transaction) error) { + _ = fn(ctx, nil) + }).Return(nil) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().GetProject( + gomock.Any(), gomock.Any(), + ).Return(&domain.Project{ + Project: &proto.Project{Id: "id-1"}, + }, nil) + s.publisher.(*publishermock.MockPublisher).EXPECT().Publish(gomock.Any(), gomock.Any()).Return(nil) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().UpdateProject( + gomock.Any(), gomock.Any(), ).Return(nil) }, req: &proto.ConvertTrialProjectRequest{ @@ -1420,24 +1458,9 @@ func TestListProjectsV2(t *testing.T) { OrganizationRole: accountproto.AccountV2_Role_Organization_MEMBER, }, }, nil) - rows := mysqlmock.NewMockRows(mockController) - rows.EXPECT().Close().Return(nil) - rows.EXPECT().Next().Return(true) - rows.EXPECT().Scan(gomock.Any()).Return(nil) - rows.EXPECT().Next().Return(false) - rows.EXPECT().Err().Return(nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(rows, nil) - row := mysqlmock.NewMockRow(mockController) - row.EXPECT().Scan(gomock.Any()).DoAndReturn(func(dest ...interface{}) error { - // Mock the TotalCount - *dest[0].(*int64) = 1 - return nil - }) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryRowContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(row) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().ListProjects( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return([]*environmentproto.Project{{}}, 1, int64(1), nil) }, input: &environmentproto.ListProjectsV2Request{ PageSize: 10, @@ -1502,9 +1525,9 @@ func TestListProjectsV2(t *testing.T) { OrganizationRole: accountproto.AccountV2_Role_Organization_MEMBER, }, }, nil) - s.mysqlClient.(*mysqlmock.MockClient).EXPECT().QueryContext( - gomock.Any(), gomock.Any(), gomock.Any(), - ).Return(nil, errors.New("internal error")) + s.projectStorage.(*storagemock.MockProjectStorage).EXPECT().ListProjects( + gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + ).Return(nil, 0, int64(0), errors.New("internal error")) }, input: &environmentproto.ListProjectsV2Request{ PageSize: 10, diff --git a/pkg/storage/v2/mysql/client.go b/pkg/storage/v2/mysql/client.go index 8b7e73f734..31481c9da8 100644 --- a/pkg/storage/v2/mysql/client.go +++ b/pkg/storage/v2/mysql/client.go @@ -111,6 +111,7 @@ type Client interface { // Transaction is passed because it is required for storage that does not support storage architecture refactoring, // but we plan to remove it once the refactoring is complete. RunInTransactionV2(ctx context.Context, f func(ctx context.Context, tx Transaction) error) error + // Deprecated Qe(ctx context.Context) QueryExecer } @@ -166,6 +167,12 @@ func (c *client) Close() error { func (c *client) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { var err error defer record()(operationExec, &err) + + tx, ok := ctx.Value(transactionKey).(Transaction) + if ok { + return tx.ExecContext(ctx, query, args...) + } + sret, err := c.db.ExecContext(ctx, query, args...) err = convertMySQLError(err) return &result{sret}, err @@ -174,6 +181,12 @@ func (c *client) ExecContext(ctx context.Context, query string, args ...interfac func (c *client) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { var err error defer record()(operationQuery, &err) + + tx, ok := ctx.Value(transactionKey).(Transaction) + if ok { + return tx.QueryContext(ctx, query, args...) + } + srows, err := c.db.QueryContext(ctx, query, args...) return &rows{srows}, err } @@ -181,6 +194,12 @@ func (c *client) QueryContext(ctx context.Context, query string, args ...interfa func (c *client) QueryRowContext(ctx context.Context, query string, args ...interface{}) Row { var err error defer record()(operationQueryRow, &err) + + tx, ok := ctx.Value(transactionKey).(Transaction) + if ok { + return tx.QueryRowContext(ctx, query, args...) + } + r := &row{c.db.QueryRowContext(ctx, query, args...)} err = r.Err() return r @@ -228,6 +247,7 @@ func (c *client) RunInTransactionV2( return err } +// Deprecated func (c *client) Qe(ctx context.Context) QueryExecer { tx, ok := ctx.Value(transactionKey).(Transaction) if ok {