diff --git a/hooks/hooks.go b/hooks/hooks.go index e135b20..4e2796f 100644 --- a/hooks/hooks.go +++ b/hooks/hooks.go @@ -18,9 +18,9 @@ type Connector struct { AfterQuery func(ctx context.Context, rows driver.Rows, err error) BeforeBegin func(ctx context.Context, opts driver.TxOptions) context.Context AfterBegin func(ctx context.Context, tx driver.Tx, err error) - BeforeCommit func() context.Context + BeforeCommit func(ctx context.Context) context.Context AfterCommit func(ctx context.Context, err error) - BeforeRollback func() context.Context + BeforeRollback func(ctx context.Context) context.Context AfterRollback func(ctx context.Context, err error) wrapped driver.Connector @@ -62,9 +62,9 @@ type conn struct { AfterQuery func(ctx context.Context, rows driver.Rows, err error) BeforeBegin func(ctx context.Context, opts driver.TxOptions) context.Context AfterBegin func(ctx context.Context, tx driver.Tx, err error) - BeforeCommit func() context.Context + BeforeCommit func(ctx context.Context) context.Context AfterCommit func(ctx context.Context, err error) - BeforeRollback func() context.Context + BeforeRollback func(ctx context.Context) context.Context AfterRollback func(ctx context.Context, err error) } @@ -118,6 +118,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e } return &tx{ wrapped: t, + ctx: ctx, BeforeCommit: c.BeforeCommit, AfterCommit: c.AfterCommit, BeforeRollback: c.BeforeRollback, @@ -127,17 +128,18 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e type tx struct { wrapped driver.Tx + ctx context.Context - BeforeCommit func() context.Context + BeforeCommit func(ctx context.Context) context.Context AfterCommit func(ctx context.Context, err error) - BeforeRollback func() context.Context + BeforeRollback func(ctx context.Context) context.Context AfterRollback func(ctx context.Context, err error) } func (tx *tx) Commit() error { - ctx := context.Background() + ctx := tx.ctx if tx.BeforeCommit != nil { - ctx = tx.BeforeCommit() + ctx = tx.BeforeCommit(ctx) } err := tx.wrapped.Commit() if tx.AfterCommit != nil { @@ -147,9 +149,9 @@ func (tx *tx) Commit() error { } func (tx *tx) Rollback() error { - ctx := context.Background() + ctx := tx.ctx if tx.BeforeRollback != nil { - ctx = tx.BeforeRollback() + ctx = tx.BeforeRollback(ctx) } err := tx.wrapped.Rollback() if tx.AfterRollback != nil {