Skip to content

Commit

Permalink
hooks: pass tx.ctx to Commit and Rollback hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
yansal committed Oct 8, 2020
1 parent 4fb38c1 commit 52617c6
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions hooks/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ type Connector struct {
AfterQuery func(ctx context.Context, rows driver.Rows, err error)
BeforeBegin func(ctx context.Context, opts driver.TxOptions) context.Context
AfterBegin func(ctx context.Context, tx driver.Tx, err error)
BeforeCommit func() context.Context
BeforeCommit func(ctx context.Context) context.Context
AfterCommit func(ctx context.Context, err error)
BeforeRollback func() context.Context
BeforeRollback func(ctx context.Context) context.Context
AfterRollback func(ctx context.Context, err error)

wrapped driver.Connector
Expand Down Expand Up @@ -62,9 +62,9 @@ type conn struct {
AfterQuery func(ctx context.Context, rows driver.Rows, err error)
BeforeBegin func(ctx context.Context, opts driver.TxOptions) context.Context
AfterBegin func(ctx context.Context, tx driver.Tx, err error)
BeforeCommit func() context.Context
BeforeCommit func(ctx context.Context) context.Context
AfterCommit func(ctx context.Context, err error)
BeforeRollback func() context.Context
BeforeRollback func(ctx context.Context) context.Context
AfterRollback func(ctx context.Context, err error)
}

Expand Down Expand Up @@ -118,6 +118,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
}
return &tx{
wrapped: t,
ctx: ctx,
BeforeCommit: c.BeforeCommit,
AfterCommit: c.AfterCommit,
BeforeRollback: c.BeforeRollback,
Expand All @@ -127,17 +128,18 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e

type tx struct {
wrapped driver.Tx
ctx context.Context

BeforeCommit func() context.Context
BeforeCommit func(ctx context.Context) context.Context
AfterCommit func(ctx context.Context, err error)
BeforeRollback func() context.Context
BeforeRollback func(ctx context.Context) context.Context
AfterRollback func(ctx context.Context, err error)
}

func (tx *tx) Commit() error {
ctx := context.Background()
ctx := tx.ctx
if tx.BeforeCommit != nil {
ctx = tx.BeforeCommit()
ctx = tx.BeforeCommit(ctx)
}
err := tx.wrapped.Commit()
if tx.AfterCommit != nil {
Expand All @@ -147,9 +149,9 @@ func (tx *tx) Commit() error {
}

func (tx *tx) Rollback() error {
ctx := context.Background()
ctx := tx.ctx
if tx.BeforeRollback != nil {
ctx = tx.BeforeRollback()
ctx = tx.BeforeRollback(ctx)
}
err := tx.wrapped.Rollback()
if tx.AfterRollback != nil {
Expand Down

0 comments on commit 52617c6

Please sign in to comment.