Skip to content

Commit

Permalink
move over to db.SQL from pgx
Browse files Browse the repository at this point in the history
  • Loading branch information
binaek committed Oct 18, 2023
1 parent cf99844 commit 2e463a1
Show file tree
Hide file tree
Showing 21 changed files with 574 additions and 372 deletions.
85 changes: 42 additions & 43 deletions db_client/db_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@ package db_client

import (
"context"
"database/sql"
"fmt"
"github.com/jackc/pgx/v5/pgconn"
"github.com/turbot/pipe-fittings/db_common"
"github.com/turbot/steampipe-plugin-sdk/v5/sperr"
"log"
"strings"
"sync"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/pgconn"
"github.com/spf13/viper"
"github.com/turbot/pipe-fittings/constants"
"github.com/turbot/pipe-fittings/db_client/backend"
"github.com/turbot/pipe-fittings/db_common"
"github.com/turbot/pipe-fittings/serversettings"
"github.com/turbot/pipe-fittings/utils"
"github.com/turbot/steampipe-plugin-sdk/v5/sperr"
"golang.org/x/sync/semaphore"
)

Expand All @@ -24,10 +24,10 @@ type DbClient struct {
connectionString string

// connection userPool for user initiated queries
userPool *pgxpool.Pool
userPool *sql.DB

// connection used to run system/plumbing queries (connection state, server settings)
managementPool *pgxpool.Pool
managementPool *sql.DB

// the settings of the server that this client is connected to
serverSettings *db_common.ServerSettings
Expand Down Expand Up @@ -57,23 +57,22 @@ type DbClient struct {
// (cached to avoid concurrent access error on viper)
showTimingFlag bool
// disable timing - set whilst in process of querying the timing
disableTiming bool
onConnectionCallback DbConnectionCallback
disableTiming bool

// the backend type of the dbclient backend
backend backend.DBClientBackendType

// a reader which can be used to read rows from a pgx.Rows object
rowReader backend.RowReader
}

func NewDbClient(ctx context.Context, connectionString string, onConnectionCallback DbConnectionCallback, opts ...ClientOption) (_ *DbClient, err error) {
func NewDbClient(ctx context.Context, connectionString string, opts ...ClientOption) (_ *DbClient, err error) {
utils.LogTime("db_client.NewDbClient start")
defer utils.LogTime("db_client.NewDbClient end")

wg := &sync.WaitGroup{}
// wrap onConnectionCallback to use wait group
var wrappedOnConnectionCallback DbConnectionCallback
if onConnectionCallback != nil {
wrappedOnConnectionCallback = func(ctx context.Context, conn *pgx.Conn) error {
wg.Add(1)
defer wg.Done()
return onConnectionCallback(ctx, conn)
}
backendType, err := backend.GetBackendFromConnectionString(ctx, connectionString)
if err != nil {
return nil, err
}

client := &DbClient{
Expand All @@ -82,9 +81,9 @@ func NewDbClient(ctx context.Context, connectionString string, onConnectionCallb
parallelSessionInitLock: semaphore.NewWeighted(constants.MaxParallelClientInits),
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
// store the callback
onConnectionCallback: wrappedOnConnectionCallback,
connectionString: connectionString,
connectionString: connectionString,
backend: backendType,
rowReader: backend.RowReaderFactory(backendType),
}

defer func() {
Expand All @@ -104,19 +103,19 @@ func NewDbClient(ctx context.Context, connectionString string, onConnectionCallb
}

// load up the server settings
if err := client.loadServerSettings(ctx); err != nil {
return nil, err
}
// if err := client.loadServerSettings(ctx); err != nil {
// return nil, err
// }

// set user search path
if err := client.LoadUserSearchPath(ctx); err != nil {
return nil, err
}
// // set user search path
// if err := client.LoadUserSearchPath(ctx); err != nil {
// return nil, err
// }

// populate customSearchPath
if err := client.SetRequiredSessionSearchPath(ctx); err != nil {
return nil, err
}
// // populate customSearchPath
// if err := client.SetRequiredSessionSearchPath(ctx); err != nil {
// return nil, err
// }

return client, nil
}
Expand Down Expand Up @@ -197,11 +196,11 @@ func (c *DbClient) Close(context.Context) error {
// connections backed by distinct plugins and then fanning back out.
func (c *DbClient) GetSchemaFromDB(ctx context.Context) (*db_common.SchemaMetadata, error) {
log.Printf("[INFO] DbClient GetSchemaFromDB")
mgmtConn, err := c.managementPool.Acquire(ctx)
if err != nil {
return nil, err
}
defer mgmtConn.Release()
// mgmtConn, err := c.managementPool.Acquire(ctx)
// if err != nil {
// return nil, err
// }
// defer mgmtConn.Release()

// TODO KAI not needed for powerpipe
return nil, sperr.New("not supported in Powerpipe")
Expand Down Expand Up @@ -251,21 +250,21 @@ func (c *DbClient) GetSchemaFromDB(ctx context.Context) (*db_common.SchemaMetada
//return metadata, nil
}

func (c *DbClient) GetSchemaFromDBLegacy(ctx context.Context, conn *pgxpool.Conn) (*db_common.SchemaMetadata, error) {
func (c *DbClient) GetSchemaFromDBLegacy(ctx context.Context, conn *sql.Conn) (*db_common.SchemaMetadata, error) {
// build a query to retrieve these schemas
query := c.buildSchemasQueryLegacy()

// build schema metadata from query result
return db_common.LoadSchemaMetadata(ctx, conn.Conn(), query)
return db_common.LoadSchemaMetadata(ctx, conn, query)
}

// refreshDbClient terminates the current connection and opens up a new connection to the service.
// Unimplemented (sql.DB does not have a mechanism to reset pools) - refreshDbClient terminates the current connection and opens up a new connection to the service.
func (c *DbClient) ResetPools(ctx context.Context) {
log.Println("[TRACE] db_client.ResetPools start")
defer log.Println("[TRACE] db_client.ResetPools end")

c.userPool.Reset()
c.managementPool.Reset()
// c.userPool.Reset()
// c.managementPool.Reset()
}

func (c *DbClient) buildSchemasQuery(schemas ...string) string {
Expand Down
123 changes: 50 additions & 73 deletions db_client/db_client_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ package db_client

import (
"context"
"github.com/turbot/pipe-fittings/db_common"
"database/sql"
"time"

"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/turbot/pipe-fittings/db_common"

"github.com/spf13/viper"
"github.com/turbot/go-kit/helpers"
"github.com/turbot/pipe-fittings/constants"
"github.com/turbot/pipe-fittings/constants/runtime"
"github.com/turbot/pipe-fittings/utils"
)

Expand All @@ -19,110 +17,89 @@ const (
MaxConnIdleTime = 1 * time.Minute
)

type DbConnectionCallback func(context.Context, *pgx.Conn) error
func getDriverNameFromConnectionString(connStr string) string {
if isPostgresConnectionString(connStr) {
return "pgx"
} else if isSqliteConnectionString(connStr) {
return "sqlite3"
} else {
return ""
}
}

type DbConnectionCallback func(context.Context, *sql.Conn) error

func (c *DbClient) establishConnectionPool(ctx context.Context, overrides clientConfig) error {
utils.LogTime("db_client.establishConnectionPool start")
defer utils.LogTime("db_client.establishConnectionPool end")

config, err := pgxpool.ParseConfig(c.connectionString)
pool, err := establishConnectionPool(ctx, c.connectionString)
if err != nil {
return err
}

locals := []string{
"127.0.0.1",
"::1",
"localhost",
}

// when connected to a service which is running a plugin compiled with SDK pre-v5, the plugin
// will not have the ability to turn off caching (feature introduced in SDKv5)
//
// the 'isLocalService' is used to set the client end cache to 'false' if caching is turned off in the local service
//
// this is a temporary workaround to make sure
// that we can turn off caching for plugins compiled with SDK pre-V5
// worst case scenario is that we don't switch off the cache for pre-V5 plugins
// refer to: https://github.com/turbot/steampipe/blob/f7f983a552a07e50e526fcadf2ccbfdb7b247cc0/pkg/db/db_client/db_client_session.go#L66
if helpers.StringSliceContains(locals, config.ConnConfig.Host) {
c.isLocalService = true
}

// MinConns should default to 0, but when not set, it actually get very high values (e.g. 80217984)
// this leads to a huge number of connections getting created
// TODO BINAEK dig into this and figure out why this is happening.
// We need to be sure that it is not an issue with service management
config.MinConns = 0
config.MaxConns = int32(db_common.MaxDbConnections())
config.MaxConnLifetime = MaxConnLifeTime
config.MaxConnIdleTime = MaxConnIdleTime
if c.onConnectionCallback != nil {
config.AfterConnect = c.onConnectionCallback
}
// set an app name so that we can track database connections from this Steampipe execution
// this is used to determine whether the database can safely be closed
config.ConnConfig.Config.RuntimeParams = map[string]string{
constants.RuntimeParamsKeyApplicationName: runtime.ClientConnectionAppName,
}
// TODO - how do we apply the AfterConnect hook here?
// the after connect hook used to create and populate the introspection tables

// apply any overrides
// this is used to set the pool size and lifetimes of the connections from up top
overrides.userPoolSettings.apply(config)

// this returns connection pool
dbPool, err := pgxpool.NewWithConfig(context.Background(), config)
if err != nil {
return err
}
overrides.userPoolSettings.apply(pool)

err = db_common.WaitForPool(
ctx,
dbPool,
pool,
db_common.WithRetryInterval(constants.DBConnectionRetryBackoff),
db_common.WithTimeout(time.Duration(viper.GetInt(constants.ArgDatabaseStartTimeout))*time.Second),
)
if err != nil {
return err
}
c.userPool = dbPool
c.userPool = pool

return c.establishManagementConnectionPool(ctx, config, overrides)
return c.establishManagementConnectionPool(ctx, overrides)
}

// establishSystemConnectionPool creates a connection pool to use to execute
// system-initiated queries (loading of connection state etc.)
// unlike establishConnectionPool, which is run first to create the user-query pool
// this doesn't wait for the pool to completely start, as establishConnectionPool will have established and verified a connection with the service
func (c *DbClient) establishManagementConnectionPool(ctx context.Context, config *pgxpool.Config, overrides clientConfig) error {
utils.LogTime("db_client.establishSystemConnectionPool start")
defer utils.LogTime("db_client.establishSystemConnectionPool end")
func (c *DbClient) establishManagementConnectionPool(ctx context.Context, overrides clientConfig) error {
utils.LogTime("db_client.establishManagementConnectionPool start")
defer utils.LogTime("db_client.establishManagementConnectionPool end")

// create a config from the config of the user pool
copiedConfig := createManagementPoolConfig(config, overrides)

// this returns connection pool
dbPool, err := pgxpool.NewWithConfig(context.Background(), copiedConfig)
pool, err := establishConnectionPool(ctx, c.connectionString)
if err != nil {
return err
}
c.managementPool = dbPool
return nil
}

func createManagementPoolConfig(config *pgxpool.Config, overrides clientConfig) *pgxpool.Config {
// create a copy - we will be modifying this
copiedConfig := config.Copy()
// apply any overrides
// this is used to set the pool size and lifetimes of the connections from up top
overrides.managementPoolSettings.apply(pool)

// update the app name of the connection
copiedConfig.ConnConfig.Config.RuntimeParams = map[string]string{
constants.RuntimeParamsKeyApplicationName: runtime.ClientSystemConnectionAppName,
err = db_common.WaitForPool(
ctx,
pool,
db_common.WithRetryInterval(constants.DBConnectionRetryBackoff),
db_common.WithTimeout(time.Duration(viper.GetInt(constants.ArgDatabaseStartTimeout))*time.Second),
)
if err != nil {
return err
}
c.managementPool = pool

// remove the afterConnect hook - we don't need the session data in management connections
copiedConfig.AfterConnect = nil
return nil
}

overrides.managementPoolSettings.apply(copiedConfig)
func establishConnectionPool(ctx context.Context, connectionString string) (*sql.DB, error) {
driverName := getDriverNameFromConnectionString(connectionString)
connectionString = getUseableConnectionString(driverName, connectionString)

return copiedConfig
pool, err := sql.Open(driverName, connectionString)
if err != nil {
return nil, err
}
pool.SetConnMaxIdleTime(MaxConnIdleTime)
pool.SetConnMaxLifetime(MaxConnLifeTime)
pool.SetMaxOpenConns(db_common.MaxDbConnections())
return pool, nil
}
Loading

0 comments on commit 2e463a1

Please sign in to comment.