Skip to content

Commit

Permalink
Added transaction scoper
Browse files Browse the repository at this point in the history
  • Loading branch information
UnAfraid committed Oct 20, 2023
1 parent 3692966 commit 095fb05
Show file tree
Hide file tree
Showing 11 changed files with 480 additions and 356 deletions.
10 changes: 6 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/UnAfraid/wg-ui/pkg/config"
"github.com/UnAfraid/wg-ui/pkg/datastore"
"github.com/UnAfraid/wg-ui/pkg/datastore/bbolt"
"github.com/UnAfraid/wg-ui/pkg/dbx"
"github.com/UnAfraid/wg-ui/pkg/manage"
"github.com/UnAfraid/wg-ui/pkg/peer"
"github.com/UnAfraid/wg-ui/pkg/server"
Expand Down Expand Up @@ -92,16 +93,17 @@ func main() {
return
}

transactionScoper := dbx.NewBBoltTransactionScope(db)
subscriptionImpl := subscription.NewInMemorySubscription()

serverRepository := bbolt.NewServerRepository(db)
serverService := server.NewService(serverRepository, subscriptionImpl)
serverService := server.NewService(serverRepository, transactionScoper, subscriptionImpl)

peerRepository := bbolt.NewPeerRepository(db)
peerService := peer.NewService(peerRepository, serverService, subscriptionImpl)
peerService := peer.NewService(peerRepository, transactionScoper, serverService, subscriptionImpl)

userRepository := bbolt.NewUserRepository(db)
userService, err := user.NewService(userRepository, subscriptionImpl, conf.Initial.Email, conf.Initial.Password)
userService, err := user.NewService(userRepository, transactionScoper, subscriptionImpl, conf.Initial.Email, conf.Initial.Password)
if err != nil {
logrus.
WithError(err).
Expand All @@ -120,7 +122,7 @@ func main() {

authService := auth.NewService(jwt.SigningMethodHS256, jwtSecretBytes, jwtSecretBytes, conf.JwtDuration)

manageService := manage.NewService(userService, serverService, peerService, wgService)
manageService := manage.NewService(transactionScoper, userService, serverService, peerService, wgService)

router := api.NewRouter(
conf,
Expand Down
36 changes: 9 additions & 27 deletions pkg/datastore/bbolt/helper.go
Original file line number Diff line number Diff line change
@@ -1,45 +1,27 @@
package bbolt

import (
"context"

"go.etcd.io/bbolt"
)

func dbView[T any](db *bbolt.DB, bucketName string, createBucketIfNotExists bool, callback func(*bbolt.Tx, *bbolt.Bucket) (T, error)) (result T, err error) {
err = db.View(func(tx *bbolt.Tx) error {
var bucket *bbolt.Bucket
if createBucketIfNotExists {
bucket, err = tx.CreateBucketIfNotExists([]byte(bucketName))
if err != nil {
return err
}
} else {
bucket = tx.Bucket([]byte(bucketName))
if bucket == nil {
return nil
}
}
result, err = callback(tx, bucket)
return err
})
return result, err
}
"github.com/UnAfraid/wg-ui/pkg/dbx"
)

func dbUpdate[T any](db *bbolt.DB, bucketName string, createBucketIfNotExists bool, callback func(*bbolt.Tx, *bbolt.Bucket) (T, error)) (result T, err error) {
err = db.Update(func(tx *bbolt.Tx) error {
func dbTx[T any](ctx context.Context, db *bbolt.DB, bucketName string, createBucketIfNotExists bool, callback func(*bbolt.Tx, *bbolt.Bucket) (T, error)) (T, error) {
return dbx.InBBoltTransactionScopeWithResult(ctx, db, func(ctx context.Context, tx *bbolt.Tx) (result T, err error) {
var bucket *bbolt.Bucket
if createBucketIfNotExists {
bucket, err = tx.CreateBucketIfNotExists([]byte(bucketName))
if err != nil {
return err
return result, err
}
} else {
bucket = tx.Bucket([]byte(bucketName))
if bucket == nil {
return nil
return result, nil
}
}
result, err = callback(tx, bucket)
return err
return callback(tx, bucket)
})
return result, err
}
20 changes: 10 additions & 10 deletions pkg/datastore/bbolt/peer_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ func NewPeerRepository(db *bbolt.DB) peer.Repository {
}
}

func (r *peerRepository) FindOne(_ context.Context, options *peer.FindOneOptions) (*peer.Peer, error) {
return dbView(r.db, peerBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*peer.Peer, error) {
func (r *peerRepository) FindOne(ctx context.Context, options *peer.FindOneOptions) (*peer.Peer, error) {
return dbTx(ctx, r.db, peerBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*peer.Peer, error) {
if idOption := options.IdOption; idOption != nil {
jsonState := bucket.Get([]byte(idOption.Id))
if jsonState == nil {
Expand Down Expand Up @@ -61,8 +61,8 @@ func (r *peerRepository) FindOne(_ context.Context, options *peer.FindOneOptions
})
}

func (r *peerRepository) FindAll(_ context.Context, options *peer.FindOptions) ([]*peer.Peer, error) {
return dbView(r.db, peerBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) ([]*peer.Peer, error) {
func (r *peerRepository) FindAll(ctx context.Context, options *peer.FindOptions) ([]*peer.Peer, error) {
return dbTx(ctx, r.db, peerBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) ([]*peer.Peer, error) {
var peers []*peer.Peer
var peersCount int
var searchList searchindex.SearchList[*peer.Peer]
Expand Down Expand Up @@ -139,8 +139,8 @@ func (r *peerRepository) FindAll(_ context.Context, options *peer.FindOptions) (
})
}

func (r *peerRepository) Create(_ context.Context, p *peer.Peer) (*peer.Peer, error) {
return dbUpdate(r.db, peerBucket, true, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*peer.Peer, error) {
func (r *peerRepository) Create(ctx context.Context, p *peer.Peer) (*peer.Peer, error) {
return dbTx(ctx, r.db, peerBucket, true, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*peer.Peer, error) {
id := []byte(p.Id)
if bucket.Get(id) != nil {
return nil, peer.ErrPeerIdAlreadyExists
Expand All @@ -155,8 +155,8 @@ func (r *peerRepository) Create(_ context.Context, p *peer.Peer) (*peer.Peer, er
})
}

func (r *peerRepository) Update(_ context.Context, p *peer.Peer, fieldMask *peer.UpdateFieldMask) (*peer.Peer, error) {
return dbUpdate(r.db, peerBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*peer.Peer, error) {
func (r *peerRepository) Update(ctx context.Context, p *peer.Peer, fieldMask *peer.UpdateFieldMask) (*peer.Peer, error) {
return dbTx(ctx, r.db, peerBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*peer.Peer, error) {
id := []byte(p.Id)
jsonState := bucket.Get(id)
if jsonState == nil {
Expand Down Expand Up @@ -219,8 +219,8 @@ func (r *peerRepository) Update(_ context.Context, p *peer.Peer, fieldMask *peer
})
}

func (r *peerRepository) Delete(_ context.Context, peerId string, deleteUserId string) (*peer.Peer, error) {
return dbUpdate(r.db, peerBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*peer.Peer, error) {
func (r *peerRepository) Delete(ctx context.Context, peerId string, deleteUserId string) (*peer.Peer, error) {
return dbTx(ctx, r.db, peerBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*peer.Peer, error) {
id := []byte(peerId)
jsonState := bucket.Get(id)
if jsonState == nil {
Expand Down
20 changes: 10 additions & 10 deletions pkg/datastore/bbolt/server_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ func NewServerRepository(db *bbolt.DB) server.Repository {
}
}

func (r *serverRepository) FindOne(_ context.Context, options *server.FindOneOptions) (*server.Server, error) {
return dbView(r.db, serverBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*server.Server, error) {
func (r *serverRepository) FindOne(ctx context.Context, options *server.FindOneOptions) (*server.Server, error) {
return dbTx(ctx, r.db, serverBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*server.Server, error) {
if idOption := options.IdOption; idOption != nil {
jsonState := bucket.Get([]byte(idOption.Id))
if jsonState == nil {
Expand Down Expand Up @@ -60,8 +60,8 @@ func (r *serverRepository) FindOne(_ context.Context, options *server.FindOneOpt
})
}

func (r *serverRepository) FindAll(_ context.Context, options *server.FindOptions) ([]*server.Server, error) {
return dbView(r.db, serverBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) ([]*server.Server, error) {
func (r *serverRepository) FindAll(ctx context.Context, options *server.FindOptions) ([]*server.Server, error) {
return dbTx(ctx, r.db, serverBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) ([]*server.Server, error) {
var servers []*server.Server
var serversCount int
var searchList searchindex.SearchList[*server.Server]
Expand Down Expand Up @@ -138,8 +138,8 @@ func (r *serverRepository) FindAll(_ context.Context, options *server.FindOption
})
}

func (r *serverRepository) Create(_ context.Context, s *server.Server) (*server.Server, error) {
return dbUpdate(r.db, serverBucket, true, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*server.Server, error) {
func (r *serverRepository) Create(ctx context.Context, s *server.Server) (*server.Server, error) {
return dbTx(ctx, r.db, serverBucket, true, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*server.Server, error) {
id := []byte(s.Id)
if bucket.Get(id) != nil {
return nil, server.ErrServerIdAlreadyExists
Expand All @@ -154,8 +154,8 @@ func (r *serverRepository) Create(_ context.Context, s *server.Server) (*server.
})
}

func (r *serverRepository) Update(_ context.Context, s *server.Server, fieldMask *server.UpdateFieldMask) (*server.Server, error) {
return dbUpdate(r.db, serverBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*server.Server, error) {
func (r *serverRepository) Update(ctx context.Context, s *server.Server, fieldMask *server.UpdateFieldMask) (*server.Server, error) {
return dbTx(ctx, r.db, serverBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*server.Server, error) {
id := []byte(s.Id)
jsonState := bucket.Get(id)
if jsonState == nil {
Expand Down Expand Up @@ -230,8 +230,8 @@ func (r *serverRepository) Update(_ context.Context, s *server.Server, fieldMask
})
}

func (r *serverRepository) Delete(_ context.Context, serverId string, deleteUserId string) (*server.Server, error) {
return dbUpdate(r.db, serverBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*server.Server, error) {
func (r *serverRepository) Delete(ctx context.Context, serverId string, deleteUserId string) (*server.Server, error) {
return dbTx(ctx, r.db, serverBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*server.Server, error) {
id := []byte(serverId)
jsonState := bucket.Get(id)
if jsonState == nil {
Expand Down
20 changes: 10 additions & 10 deletions pkg/datastore/bbolt/user_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func NewUserRepository(db *bbolt.DB) user.Repository {
}
}

func (r *userRepository) FindOne(_ context.Context, options *user.FindOneOptions) (*user.User, error) {
return dbView(r.db, userBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*user.User, error) {
func (r *userRepository) FindOne(ctx context.Context, options *user.FindOneOptions) (*user.User, error) {
return dbTx(ctx, r.db, userBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*user.User, error) {
if idOption := options.IdOption; idOption != nil {
jsonState := bucket.Get([]byte(idOption.Id))
if jsonState == nil {
Expand Down Expand Up @@ -63,8 +63,8 @@ func (r *userRepository) FindOne(_ context.Context, options *user.FindOneOptions
})
}

func (r *userRepository) FindAll(_ context.Context, options *user.FindOptions) ([]*user.User, error) {
return dbView(r.db, userBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) ([]*user.User, error) {
func (r *userRepository) FindAll(ctx context.Context, options *user.FindOptions) ([]*user.User, error) {
return dbTx(ctx, r.db, userBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) ([]*user.User, error) {
var users []*user.User
var usersCount int
var searchList searchindex.SearchList[*user.User]
Expand Down Expand Up @@ -112,8 +112,8 @@ func (r *userRepository) FindAll(_ context.Context, options *user.FindOptions) (
})
}

func (r *userRepository) Create(_ context.Context, u *user.User) (*user.User, error) {
return dbUpdate(r.db, userBucket, true, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*user.User, error) {
func (r *userRepository) Create(ctx context.Context, u *user.User) (*user.User, error) {
return dbTx(ctx, r.db, userBucket, true, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*user.User, error) {
id := []byte(u.Id)
if bucket.Get(id) != nil {
return nil, user.ErrUserIdAlreadyExists
Expand All @@ -128,8 +128,8 @@ func (r *userRepository) Create(_ context.Context, u *user.User) (*user.User, er
})
}

func (r *userRepository) Update(_ context.Context, u *user.User, fieldMask *user.UpdateFieldMask) (*user.User, error) {
return dbUpdate(r.db, userBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*user.User, error) {
func (r *userRepository) Update(ctx context.Context, u *user.User, fieldMask *user.UpdateFieldMask) (*user.User, error) {
return dbTx(ctx, r.db, userBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*user.User, error) {
id := []byte(u.Id)
jsonState := bucket.Get(id)
if jsonState == nil {
Expand Down Expand Up @@ -160,8 +160,8 @@ func (r *userRepository) Update(_ context.Context, u *user.User, fieldMask *user
})
}

func (r *userRepository) Delete(_ context.Context, userId string) (*user.User, error) {
return dbUpdate(r.db, userBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*user.User, error) {
func (r *userRepository) Delete(ctx context.Context, userId string) (*user.User, error) {
return dbTx(ctx, r.db, userBucket, false, func(tx *bbolt.Tx, bucket *bbolt.Bucket) (*user.User, error) {
id := []byte(userId)
jsonState := bucket.Get(id)
if jsonState == nil {
Expand Down
84 changes: 84 additions & 0 deletions pkg/dbx/bbolt_transaction_scoper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package dbx

import (
"context"
"errors"

"go.etcd.io/bbolt"
)

type contextKey struct{ name string }

var bboltTxKey = contextKey{name: "bboltTxKey"}

type bboltTransactionScope struct {
db *bbolt.DB
}

func NewBBoltTransactionScope(db *bbolt.DB) TransactionScoper {
return &bboltTransactionScope{
db: db,
}
}

func (txScope *bboltTransactionScope) InTransactionScope(ctx context.Context, transactionScope func(ctx context.Context) error) (err error) {
return InBBoltTransactionScope(ctx, txScope.db, func(ctx context.Context, tx *bbolt.Tx) error {
return transactionScope(ctx)
})
}

func InBBoltTransactionScope(ctx context.Context, db *bbolt.DB, transactionScope func(ctx context.Context, tx *bbolt.Tx) error) (retErr error) {
tx, transactionCloser, err := useOrStartBBoltTransaction(ctx, db)
if err != nil {
return err
}

defer func() {
retErr = transactionCloser(retErr)
}()

return transactionScope(context.WithValue(ctx, bboltTxKey, tx), tx)
}

func InBBoltTransactionScopeWithResult[T any](ctx context.Context, db *bbolt.DB, transactionScope func(ctx context.Context, tx *bbolt.Tx) (T, error)) (result T, err error) {
tx, transactionCloser, err := useOrStartBBoltTransaction(ctx, db)
if err != nil {
return result, err
}

defer func() {
err = transactionCloser(err)
}()

return transactionScope(context.WithValue(ctx, bboltTxKey, tx), tx)
}

func useOrStartBBoltTransaction(ctx context.Context, db *bbolt.DB) (*bbolt.Tx, func(err error) error, error) {
tx, ok := ctx.Value(bboltTxKey).(*bbolt.Tx)
if !ok {
tx, err := db.Begin(true)
if err != nil {
return nil, nil, err
}

transactionScope := func(err error) error {
if err != nil {
if txErr := tx.Rollback(); txErr != nil {
err = errors.Join(err, txErr)
}
} else {
if txErr := tx.Commit(); txErr != nil {
err = txErr
}
}
return err
}

return tx, transactionScope, nil
} else {
transactionCloser := func(err error) error {
return err
}
return tx, transactionCloser, nil
}
}
17 changes: 17 additions & 0 deletions pkg/dbx/transaction_scoper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package dbx

import (
"context"
)

type TransactionScoper interface {
InTransactionScope(ctx context.Context, transactionScope func(ctx context.Context) error) error
}

func InTransactionScopeWithResult[T any](ctx context.Context, transactionScoper TransactionScoper, transactionScope func(ctx context.Context) (T, error)) (result T, err error) {
err = transactionScoper.InTransactionScope(ctx, func(ctx context.Context) error {
result, err = transactionScope(ctx)
return err
})
return result, err
}
Loading

0 comments on commit 095fb05

Please sign in to comment.