Skip to content

Commit

Permalink
fix: build
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Dec 4, 2024
1 parent cbbe1e9 commit 702e525
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 19 deletions.
27 changes: 18 additions & 9 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,38 @@ func WithReadOnlyReplica(replica *sql.DB) DBOption {
}

type DB struct {
// Must be a pointer so we copy the state, not the state fields.
*noCopyState

queryHooks []QueryHook

fmter schema.Formatter
stats DBStats
}

// noCopyState contains DB fields that must not be copied on clone(),
// for example, it is forbidden to copy atomic.Pointer.
type noCopyState struct {
*sql.DB
dialect schema.Dialect

replicas []*sql.DB
healthyReplicas atomic.Pointer[[]*sql.DB]
nextReplica atomic.Int64

dialect schema.Dialect
queryHooks []QueryHook

fmter schema.Formatter
flags internal.Flag
closed atomic.Bool

stats DBStats
}

func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
dialect.Init(sqldb)

db := &DB{
DB: sqldb,
dialect: dialect,
fmter: schema.NewFormatter(dialect),
noCopyState: &noCopyState{
DB: sqldb,
dialect: dialect,
},
fmter: schema.NewFormatter(dialect),
}

for _, opt := range opts {
Expand Down
6 changes: 2 additions & 4 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,8 @@ func (q *baseQuery) GetTableName() string {
}

for _, wq := range q.with {
if v, ok := wq.query.(Query); ok {
if model := v.GetModel(); model != nil {
return v.GetTableName()
}
if model := wq.query.GetModel(); model != nil {
return wq.query.GetTableName()
}
}

Expand Down
14 changes: 8 additions & 6 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) {
query := internal.String(queryBytes)

ctx, event := q.db.beforeQuery(ctx, q, query, nil, query, q.model)
rows, err := q.conn.QueryContext(ctx, query)
rows, err := q.resolveConn(q).QueryContext(ctx, query)
q.db.afterQuery(ctx, event, nil, err)
return rows, err
}
Expand Down Expand Up @@ -876,7 +876,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) {
ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model)

var num int
err = q.conn.QueryRowContext(ctx, query).Scan(&num)
err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&num)

q.db.afterQuery(ctx, event, nil, err)

Expand All @@ -894,13 +894,15 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in
return int(n), nil
}
}
if _, ok := q.conn.(*DB); ok {
return q.scanAndCountConc(ctx, dest...)
if q.conn == nil {
return q.scanAndCountConcurrently(ctx, dest...)
}
return q.scanAndCountSeq(ctx, dest...)
}

func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) {
func (q *SelectQuery) scanAndCountConcurrently(
ctx context.Context, dest ...interface{},
) (int, error) {
var count int
var wg sync.WaitGroup
var mu sync.Mutex
Expand Down Expand Up @@ -978,7 +980,7 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) {
ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model)

var exists bool
err = q.conn.QueryRowContext(ctx, query).Scan(&exists)
err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&exists)

q.db.afterQuery(ctx, event, nil, err)

Expand Down

0 comments on commit 702e525

Please sign in to comment.