Skip to content

Commit

Permalink
hooks: replace FooHook with BeforeFoo and AfterFoo
Browse files Browse the repository at this point in the history
  • Loading branch information
yansal committed Oct 6, 2020
1 parent 455c6fb commit 4fb38c1
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 495 deletions.
331 changes: 90 additions & 241 deletions hooks/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,308 +3,157 @@ package hooks
import (
"context"
"database/sql/driver"
"time"
)

// Wrap returns a new Connector wrapping c.
func Wrap(c driver.Connector) *Connector { return &Connector{wrapped: c} }

// A Connector wraps an existing connector.
type Connector struct {
ExecHook func(ctx context.Context, info ExecInfo)
QueryHook func(ctx context.Context, info QueryInfo)

BeginHook func(ctx context.Context, info BeginInfo)
CommitHook func(info CommitInfo)
RollbackHook func(info RollbackInfo)

// TODO: add ConnectHook, PrepareHook, etc.
BeforeConnect func(ctx context.Context) context.Context
AfterConnect func(ctx context.Context, conn driver.Conn, err error)
BeforeExec func(ctx context.Context, query string, args []driver.NamedValue) context.Context
AfterExec func(ctx context.Context, result driver.Result, err error)
BeforeQuery func(ctx context.Context, query string, args []driver.NamedValue) context.Context
AfterQuery func(ctx context.Context, rows driver.Rows, err error)
BeforeBegin func(ctx context.Context, opts driver.TxOptions) context.Context
AfterBegin func(ctx context.Context, tx driver.Tx, err error)
BeforeCommit func() context.Context
AfterCommit func(ctx context.Context, err error)
BeforeRollback func() context.Context
AfterRollback func(ctx context.Context, err error)

wrapped driver.Connector
}

// ExecInfo is the argument of ExecHook and contains information about the executed query.
type ExecInfo struct {
Query string
Args []driver.Value
Duration time.Duration
Err error
}

// QueryInfo is the argument of QueryHook and contains information about the executed query.
type QueryInfo struct {
Query string
Args []driver.Value
Duration time.Duration
Err error
}

// BeginInfo is the argument of BeginHook.
type BeginInfo struct {
Duration time.Duration
Err error
}

// CommitInfo is the argument of CommitHook.
type CommitInfo struct {
Duration time.Duration
Err error
}

// RollbackInfo is the argument of RollbackHook.
type RollbackInfo struct {
Duration time.Duration
Err error
}

// Connect implements database/sql/driver.Connector.
func (connector *Connector) Connect(ctx context.Context) (driver.Conn, error) {
if connector.BeforeConnect != nil {
ctx = connector.BeforeConnect(ctx)
}
c, err := connector.wrapped.Connect(ctx)
if err != nil {
return nil, err
if connector.AfterConnect != nil {
connector.AfterConnect(ctx, c, err)
}
return &conn{
wrapped: c,
execHook: connector.ExecHook,
queryHook: connector.QueryHook,
beginHook: connector.BeginHook,
commitHook: connector.CommitHook,
rollbackHook: connector.RollbackHook,
}, nil
wrapped: c,
BeforeExec: connector.BeforeExec,
AfterExec: connector.AfterExec,
BeforeQuery: connector.BeforeQuery,
AfterQuery: connector.AfterQuery,
BeforeBegin: connector.BeforeBegin,
AfterBegin: connector.AfterBegin,
BeforeCommit: connector.BeforeCommit,
AfterCommit: connector.AfterCommit,
BeforeRollback: connector.BeforeRollback,
AfterRollback: connector.AfterRollback,
}, err
}

// Driver implements database/sql/driver.Connector.
func (connector *Connector) Driver() driver.Driver { return connector.wrapped.Driver() }

type conn struct {
wrapped driver.Conn
execHook func(ctx context.Context, info ExecInfo)
queryHook func(ctx context.Context, info QueryInfo)
beginHook func(ctx context.Context, info BeginInfo)
commitHook func(info CommitInfo)
rollbackHook func(info RollbackInfo)
wrapped driver.Conn

BeforeExec func(ctx context.Context, query string, args []driver.NamedValue) context.Context
AfterExec func(ctx context.Context, result driver.Result, err error)
BeforeQuery func(ctx context.Context, query string, args []driver.NamedValue) context.Context
AfterQuery func(ctx context.Context, rows driver.Rows, err error)
BeforeBegin func(ctx context.Context, opts driver.TxOptions) context.Context
AfterBegin func(ctx context.Context, tx driver.Tx, err error)
BeforeCommit func() context.Context
AfterCommit func(ctx context.Context, err error)
BeforeRollback func() context.Context
AfterRollback func(ctx context.Context, err error)
}

func (c *conn) Begin() (driver.Tx, error) {
start := time.Now()
t, err := c.wrapped.Begin()
if c.beginHook != nil {
c.beginHook(context.Background(), BeginInfo{
Duration: time.Since(start),
Err: err,
})
}
return &tx{
wrapped: t,
commitHook: c.commitHook,
rollbackHook: c.rollbackHook,
}, nil
return c.wrapped.Begin()
}

func (c *conn) Close() error {
return c.wrapped.Close()
}

func (c *conn) Prepare(query string) (driver.Stmt, error) {
s, err := c.wrapped.Prepare(query)
if err != nil {
return nil, err
}
return &stmt{
wrapped: s,
query: query,
execHook: c.execHook,
queryHook: c.queryHook,
}, nil
return c.wrapped.Prepare(query)
}

var (
_ driver.Execer = &conn{}
_ driver.ExecerContext = &conn{}
_ driver.Queryer = &conn{}
_ driver.QueryerContext = &conn{}

// _ driver.ConnBeginTx = &conn{}
// _ driver.ConnPrepareContext = &conn{}
// _ driver.NamedValueChecker = &conn{}
// _ driver.Pinger = &conn{}
// _ driver.SessionResetter = &conn{}
_ driver.ConnBeginTx = &conn{}
)

func (c *conn) Exec(query string, args []driver.Value) (driver.Result, error) {
start := time.Now()
res, err := c.wrapped.(driver.Execer).Exec(query, args)
if c.execHook != nil {
c.execHook(context.Background(), ExecInfo{
Query: query,
Args: args,
Duration: time.Since(start),
Err: err,
})
}
return res, err
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
start := time.Now()
res, err := c.wrapped.(driver.ExecerContext).ExecContext(ctx, query, args)
if c.execHook != nil {
values := make([]driver.Value, 0, len(args))
for i := range args {
values = append(values, args[i].Value)
}
c.execHook(ctx, ExecInfo{
Query: query,
Args: values,
Duration: time.Since(start),
Err: err,
})
if c.BeforeExec != nil {
ctx = c.BeforeExec(ctx, query, args)
}
return res, err
}

func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
start := time.Now()
res, err := c.wrapped.(driver.Queryer).Query(query, args)
if c.queryHook != nil {
c.queryHook(context.Background(), QueryInfo{
Query: query,
Args: args,
Duration: time.Since(start),
Err: err,
})
result, err := c.wrapped.(driver.ExecerContext).ExecContext(ctx, query, args)
if c.AfterExec != nil {
c.AfterExec(ctx, result, err)
}
return res, err
return result, err
}

func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
start := time.Now()
if c.BeforeQuery != nil {
ctx = c.BeforeQuery(ctx, query, args)
}
rows, err := c.wrapped.(driver.QueryerContext).QueryContext(ctx, query, args)
if c.queryHook != nil {
values := make([]driver.Value, 0, len(args))
for i := range args {
values = append(values, args[i].Value)
}
c.queryHook(ctx, QueryInfo{
Query: query,
Args: values,
Duration: time.Since(start),
Err: err,
})
if c.AfterQuery != nil {
c.AfterQuery(ctx, rows, err)
}
return rows, err
}

func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if c.BeforeBegin != nil {
ctx = c.BeforeBegin(ctx, opts)
}
t, err := c.wrapped.(driver.ConnBeginTx).BeginTx(ctx, opts)
if c.AfterBegin != nil {
c.AfterBegin(ctx, t, err)
}
return &tx{
wrapped: t,
BeforeCommit: c.BeforeCommit,
AfterCommit: c.AfterCommit,
BeforeRollback: c.BeforeRollback,
AfterRollback: c.AfterRollback,
}, err
}

type tx struct {
wrapped driver.Tx
commitHook func(info CommitInfo)
rollbackHook func(info RollbackInfo)
wrapped driver.Tx

BeforeCommit func() context.Context
AfterCommit func(ctx context.Context, err error)
BeforeRollback func() context.Context
AfterRollback func(ctx context.Context, err error)
}

func (tx *tx) Commit() error {
start := time.Now()
ctx := context.Background()
if tx.BeforeCommit != nil {
ctx = tx.BeforeCommit()
}
err := tx.wrapped.Commit()
if tx.commitHook != nil {
tx.commitHook(CommitInfo{Duration: time.Since(start), Err: err})
if tx.AfterCommit != nil {
tx.AfterCommit(ctx, err)
}
return err
}

func (tx *tx) Rollback() error {
start := time.Now()
ctx := context.Background()
if tx.BeforeRollback != nil {
ctx = tx.BeforeRollback()
}
err := tx.wrapped.Rollback()
if tx.rollbackHook != nil {
tx.rollbackHook(RollbackInfo{Duration: time.Since(start), Err: err})
if tx.AfterRollback != nil {
tx.AfterRollback(ctx, err)
}
return err
}

type stmt struct {
wrapped driver.Stmt
query string
execHook func(ctx context.Context, info ExecInfo)
queryHook func(ctx context.Context, info QueryInfo)
}

func (s *stmt) Close() error { return s.wrapped.Close() }
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
start := time.Now()
res, err := s.wrapped.Exec(args)
if s.execHook != nil {
s.execHook(context.Background(), ExecInfo{
Query: s.query,
Args: args,
Duration: time.Since(start),
Err: err,
})
}
return res, err
}
func (s *stmt) NumInput() int { return s.wrapped.NumInput() }
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
start := time.Now()
res, err := s.wrapped.Query(args)
if s.queryHook != nil {
s.queryHook(context.Background(), QueryInfo{
Query: s.query,
Args: args,
Duration: time.Since(start),
Err: err,
})
}
return res, err
}

var (
_ driver.StmtExecContext = &stmt{}
_ driver.StmtQueryContext = &stmt{}

// _ driver.ColumnConverter = &stmt{}
// _ driver.NamedValueChecker = &stmt{}
)

func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
start := time.Now()
res, err := s.wrapped.(driver.StmtExecContext).ExecContext(ctx, args)
if s.execHook != nil {
values := make([]driver.Value, 0, len(args))
for i := range args {
values = append(values, args[i].Value)
}
s.execHook(ctx, ExecInfo{
Query: s.query,
Args: values,
Duration: time.Since(start),
Err: err,
})
}
return res, err
}

func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
start := time.Now()
var (
rows driver.Rows
err error
)

values := make([]driver.Value, 0, len(args))
for i := range args {
values = append(values, args[i].Value)
}
sqc, ok := s.wrapped.(driver.StmtQueryContext)
if ok {
rows, err = sqc.QueryContext(ctx, args)
} else {
rows, err = s.wrapped.Query(values)
}
if s.queryHook != nil {
s.queryHook(ctx, QueryInfo{
Query: s.query,
Args: values,
Duration: time.Since(start),
Err: err,
})
}
return rows, err
}
Loading

0 comments on commit 4fb38c1

Please sign in to comment.