Skip to content

Commit

Permalink
add connection keepalive (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
vadv authored Sep 26, 2019
1 parent a8e5e77 commit df99bbc
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 4 deletions.
40 changes: 40 additions & 0 deletions gatherer/internal/connection/driver.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package connection

import (
"database/sql"
"database/sql/driver"
"net"
"time"

"github.com/lib/pq"
)

// KeepAliveDuration is the duration between keepalives for all new Postgres
// connections.
var KeepAliveDuration = 5 * time.Second

func init() {
sql.Register("gatherer-pq", &enhancedDriver{})
}

// enhancedDriver is a wrapper over lib/pq to mimic jackc/pgx's keepalive
// policy. This avoids an issue where the NAT kills an "idle" connection while
// it is waiting on a long-running query.
type enhancedDriver struct{}

// Open returns a new SQL driver connection with our custom settings.
func (d *enhancedDriver) Open(name string) (driver.Conn, error) {
return pq.DialOpen(&dialer{}, name)
}

type dialer struct{}

func (d dialer) Dial(ntw, addr string) (net.Conn, error) {
customDialer := net.Dialer{KeepAlive: KeepAliveDuration}
return customDialer.Dial(ntw, addr)
}

func (d dialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
customDialer := net.Dialer{Timeout: timeout, KeepAlive: KeepAliveDuration}
return customDialer.Dial(ntw, addr)
}
19 changes: 15 additions & 4 deletions gatherer/internal/connection/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,24 @@ package connection
import (
"database/sql"
"sync"
"sync/atomic"
)

var (
connectionPool *connPool
poolDatabasesOpen *int32 // for testing
maxOpenConns uint
)

func init() {
maxOpenConns = 5
connectionPool = &connPool{
mutex: sync.Mutex{},
pool: make(map[string]*sql.DB),
}
maxOpenConns = uint(5)
)
zero := int32(0)
poolDatabasesOpen = &zero
}

// SetMaxOpenConns set max open connections
func SetMaxOpenConns(i uint) {
Expand All @@ -20,6 +29,7 @@ func SetMaxOpenConns(i uint) {
defer connectionPool.mutex.Unlock()
for _, db := range connectionPool.pool {
db.SetMaxOpenConns(int(maxOpenConns))
db.SetMaxIdleConns(int(maxOpenConns))
}
}

Expand All @@ -29,11 +39,12 @@ type connPool struct {
}

func newPostgresConnection(connectionString string) (*sql.DB, error) {
db, err := sql.Open(`postgres`, connectionString)
atomic.AddInt32(poolDatabasesOpen, 1)
db, err := sql.Open(`gatherer-pq`, connectionString)
if err != nil {
return nil, err
}
db.SetMaxIdleConns(1)
db.SetMaxIdleConns(int(maxOpenConns))
db.SetMaxOpenConns(int(maxOpenConns))
return db, err
}
Expand Down
82 changes: 82 additions & 0 deletions gatherer/internal/connection/pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package connection

import (
"database/sql"
"sync"
"testing"

lua "github.com/yuin/gopher-lua"
)

const doAvailableConnections = `
connection:query('select 1')
for _, conn in pairs(connection:available_connections()) do
conn:query('select 1')
end
`

func TestNew(t *testing.T) {
wait := sync.WaitGroup{}
count := 100
wait.Add(count)
countOfDatabases := getCountOfDatabases(t)
SetMaxOpenConns(1)
for i := 0; i < count; i++ {
go func() {
defer wait.Done()
state := lua.NewState()
Preload(state)
params := make(map[string]string)
params[`fallback_application_name`] = `test`
params[`connect_timeout`] = `5`
New(state, `connection`,
"/tmp", "gatherer", "gatherer", "", 5432, params)
if err := state.DoString(doAvailableConnections); err != nil {
t.Fatalf("do: %s\n", err.Error())
}
}()
}
wait.Wait()
if *poolDatabasesOpen != int32(countOfDatabases) {
t.Fatalf("open: %d count: %d\n", *poolDatabasesOpen, countOfDatabases)
}
if len(connectionPool.pool) != countOfDatabases {
t.Fatalf("pool: %#v\n", connectionPool.pool)
}
if connections := getCountOfApplicationNameTest(t); connections != countOfDatabases {
t.Fatalf("databases: %d connections: %d\n", countOfDatabases, connections)
}
}

func getCountOfDatabases(t *testing.T) int {
db, err := sql.Open(`postgres`, `host=/tmp dbname=gatherer user=gatherer port=5432`)
if err != nil {
t.Fatalf("open: %s\n", err.Error())
}
row := db.QueryRow(`select
count(d.datname)
from
pg_catalog.pg_database d
where has_database_privilege(d.datname, 'connect') and not d.datistemplate
`)
defer db.Close()
var result int
if errScan := row.Scan(&result); errScan != nil {
t.Fatalf("scan: %s\n", errScan.Error())
}
return result
}

func getCountOfApplicationNameTest(t *testing.T) int {
db, err := sql.Open(`postgres`, `host=/tmp dbname=gatherer user=gatherer port=5432`)
if err != nil {
t.Fatalf("open: %s\n", err.Error())
}
row := db.QueryRow(`select count(*) from pg_stat_activity where application_name = 'test'`)
defer db.Close()
var result int
if errScan := row.Scan(&result); errScan != nil {
t.Fatalf("scan: %s\n", errScan.Error())
}
return result
}

0 comments on commit df99bbc

Please sign in to comment.