Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow retrying a connection on startup #695

Merged
merged 10 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 43 additions & 29 deletions pkg/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"os"
"os/exec"
"os/signal"
"regexp"
"strings"
"syscall"
"time"
Expand Down Expand Up @@ -35,9 +34,6 @@ SECURITY WARNING: You are running Pgweb in read-only mode.
This mode is designed for environments where users could potentially delete or change data.
For proper read-only access please follow PostgreSQL role management documentation.
--------------------------------------------------------------------------------`

regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`)
regexErrAuthFailed = regexp.MustCompile(`authentication failed`)
)

func init() {
Expand Down Expand Up @@ -77,35 +73,20 @@ func initClient() {
}

if command.Opts.Debug {
fmt.Println("Server connection string:", cl.ConnectionString)
fmt.Println("Opening database connection using string:", cl.ConnectionString)
}

fmt.Println("Connecting to server...")
if err := cl.Test(); err != nil {
msg := err.Error()

// Check if we're trying to connect to the default database.
if command.Opts.DbName == "" && command.Opts.URL == "" {
// If database does not exist, allow user to connect from the UI.
if strings.Contains(msg, "database") && strings.Contains(msg, "does not exist") {
fmt.Println("Error:", msg)
return
}

// Do not bail if local server is not running.
if regexErrConnectionRefused.MatchString(msg) {
fmt.Println("Error:", msg)
return
}
retryCount := command.Opts.RetryCount
retryDelay := time.Second * time.Duration(command.Opts.RetryDelay)

// Do not bail if local auth is invalid
if regexErrAuthFailed.MatchString(msg) {
fmt.Println("Error:", msg)
return
}
fmt.Println("Connecting to server...")
abort, err := testClient(cl, int(retryCount), retryDelay)
if err != nil {
if abort {
exitWithMessage(err.Error())
} else {
return
}

exitWithMessage(msg)
}

if !command.Opts.Sessions {
Expand Down Expand Up @@ -280,6 +261,39 @@ func openPage() {
}
}

// testClient attempts to establish a database connection until it succeeds or
// give up after certain number of retries. Retries only available when database
// name or a connection string is provided.
func testClient(cl *client.Client, retryCount int, retryDelay time.Duration) (abort bool, err error) {
usingDefaultDB := command.Opts.DbName == "" && command.Opts.URL == ""

for {
err = cl.Test()
if err == nil {
return false, nil
}

// Continue normal start up if can't connect locally without database details.
if usingDefaultDB {
if errors.Is(err, client.ErrConnectionRefused) ||
errors.Is(err, client.ErrAuthFailed) ||
errors.Is(err, client.ErrDatabaseNotExist) {
return false, err
}
}

// Only retry if can't establish connection to the server.
if errors.Is(err, client.ErrConnectionRefused) && retryCount > 0 {
fmt.Printf("Connection error: %v, retrying in %v (%d remaining)\n", err, retryDelay, retryCount)
retryCount--
<-time.After(retryDelay)
continue
}

return true, err
}
}

func Run() {
initOptions()
initClient()
Expand Down
36 changes: 35 additions & 1 deletion pkg/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log"
neturl "net/url"
"reflect"
"regexp"
"strings"
"time"

Expand All @@ -21,6 +22,18 @@ import (
"github.com/sosedoff/pgweb/pkg/statements"
)

var (
regexErrAuthFailed = regexp.MustCompile(`(authentication failed|role "(.*)" does not exist)`)
regexErrConnectionRefused = regexp.MustCompile(`(connection|actively) refused`)
regexErrDatabaseNotExist = regexp.MustCompile(`database "(.*)" does not exist`)
)

var (
ErrAuthFailed = errors.New("authentication failed")
ErrConnectionRefused = errors.New("connection refused")
ErrDatabaseNotExist = errors.New("database does not exist")
)

type Client struct {
db *sqlx.DB
tunnel *Tunnel
Expand Down Expand Up @@ -179,7 +192,28 @@ func (client *Client) setServerVersion() {
}

func (client *Client) Test() error {
return client.db.Ping()
// NOTE: This is a different timeout defined in CLI OpenTimeout
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

err := client.db.PingContext(ctx)
if err == nil {
return nil
}

errMsg := err.Error()

if regexErrConnectionRefused.MatchString(errMsg) {
return ErrConnectionRefused
}
if regexErrAuthFailed.MatchString(errMsg) {
return ErrAuthFailed
}
if regexErrDatabaseNotExist.MatchString(errMsg) {
return ErrDatabaseNotExist
}

return err
}

func (client *Client) TestWithTimeout(timeout time.Duration) (result error) {
Expand Down
42 changes: 41 additions & 1 deletion pkg/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/sosedoff/pgweb/pkg/command"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var (
Expand Down Expand Up @@ -199,7 +200,46 @@ func testClientIdleTime(t *testing.T) {
}

func testTest(t *testing.T) {
assert.NoError(t, testClient.Test())
examples := []struct {
name string
input string
err error
}{
{
name: "success",
input: fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase),
err: nil,
},
{
name: "connection refused",
input: "postgresql://localhost:5433/dbname",
err: ErrConnectionRefused,
},
{
name: "invalid user",
input: fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", "foo", serverPassword, serverHost, serverPort, serverDatabase),
err: ErrAuthFailed,
},
{
name: "invalid password",
input: fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", serverUser, "foo", serverHost, serverPort, serverDatabase),
err: ErrAuthFailed,
},
{
name: "invalid database",
input: fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, "foo"),
err: ErrDatabaseNotExist,
},
}

for _, ex := range examples {
t.Run(ex.name, func(t *testing.T) {
conn, err := NewFromUrl(ex.input, nil)
require.NoError(t, err)

require.Equal(t, ex.err, conn.Test())
})
}
}

func testInfo(t *testing.T) {
Expand Down
4 changes: 3 additions & 1 deletion pkg/command/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ type Options struct {
SSLRootCert string `long:"ssl-rootcert" description:"SSL certificate authority file"`
SSLCert string `long:"ssl-cert" description:"SSL client certificate file"`
SSLKey string `long:"ssl-key" description:"SSL client certificate key file"`
OpenTimeout int `long:"open-timeout" description:" Maximum wait for connection, in seconds" default:"30"`
OpenTimeout int `long:"open-timeout" description:"Maximum wait time for connection, in seconds" default:"30"`
RetryDelay uint `long:"open-retry-delay" description:"Number of seconds to wait before retrying the connection" default:"3"`
RetryCount uint `long:"open-retry" description:"Number of times to retry establishing connection" default:"0"`
HTTPHost string `long:"bind" description:"HTTP server host" default:"localhost"`
HTTPPort uint `long:"listen" description:"HTTP server listen port" default:"8081"`
AuthUser string `long:"auth-user" description:"HTTP basic auth user"`
Expand Down