Skip to content

Commit

Permalink
Add auto setup pg package
Browse files Browse the repository at this point in the history
tonyhb committed Sep 16, 2024
1 parent e7ee5e7 commit dbbd684
Showing 8 changed files with 317 additions and 49 deletions.
49 changes: 12 additions & 37 deletions internal/test/pg_init.go
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@ import (
"strings"
"testing"

"github.com/inngest/dbcap/pkg/replicator/pg/pgsetup"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/stretchr/testify/require"
@@ -50,16 +51,17 @@ func StartPG(t *testing.T, ctx context.Context, opts StartPGOpts) (tc.Container,
require.NoError(t, err)
}

if !opts.DisableCreateRoles {
// Create the replication user, grants, and publication.
err := prepareRoles(ctx, conn)
require.NoError(t, err)
}
if !opts.DisableCreateSlot {
// Create the replication slot.
err := createReplicationSlot(ctx, conn)
require.NoError(t, err)
}
connCfg, err := pgx.ParseConfig(connString(t, c))
require.NoError(t, err, "Failed to parse config")

sr, err := pgsetup.Setup(ctx, pgsetup.SetupOpts{
AdminConfig: *connCfg,
Password: "password",
DisableCreateUser: opts.DisableCreateRoles,
DisableCreateRoles: opts.DisableCreateRoles,
DisableCreateSlot: opts.DisableCreateSlot,
})
require.NoError(t, err, "Setup results: %#v", sr.Results())

err = createTables(ctx, conn)
require.NoError(t, err)
@@ -98,33 +100,6 @@ func connOpts(t *testing.T, c tc.Container) pgx.ConnConfig {
return *cfg
}

// prepareRoles runs a single multi-statement batch on c that creates the
// "inngest" replication user, grants it read access to the public schema
// (including tables created later), and creates the "inngest" publication
// covering all tables. It returns the error reported when the batch's result
// reader is closed.
func prepareRoles(ctx context.Context, c *pgconn.PgConn) error {
	const batch = `
CREATE USER inngest WITH REPLICATION PASSWORD 'password';
GRANT USAGE ON SCHEMA public TO inngest;
GRANT SELECT ON ALL TABLES IN SCHEMA public TO inngest;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO inngest;
CREATE PUBLICATION inngest FOR ALL TABLES;
`
	// Closing the reader drains every statement's result and surfaces the
	// first failure, if any.
	return c.Exec(ctx, batch).Close()
}

// createReplicationSlot creates the 'inngest_cdc' logical replication slot
// using the pgoutput plugin. It returns the error reported when the result
// reader is closed.
func createReplicationSlot(ctx context.Context, c *pgconn.PgConn) error {
	const query = `
-- pgoutput logical repl plugin
SELECT pg_create_logical_replication_slot('inngest_cdc', 'pgoutput');
`
	// Closing the reader drains the result and surfaces any failure.
	return c.Exec(ctx, query).Close()
}

func createTables(ctx context.Context, c *pgconn.PgConn) error {
stmt := `
CREATE EXTENSION IF NOT EXISTS "pgcrypto" WITH SCHEMA public;
1 change: 1 addition & 0 deletions pkg/consts/pgconsts/pgconsts.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package pgconsts

// Fixed names of the Postgres objects used for replication.
const (
// Username is the name of the Postgres role used for replication.
Username = "inngest"
// SlotName is the name of the logical replication slot.
SlotName = "inngest_cdc"
// PublicationName is the name of the publication.
PublicationName = "inngest"
)
17 changes: 11 additions & 6 deletions pkg/replicator/pg.go → pkg/replicator/pg/pg.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package replicator
package pg

import (
"context"
@@ -14,6 +14,7 @@ import (
"github.com/inngest/dbcap/pkg/changeset"
"github.com/inngest/dbcap/pkg/consts/pgconsts"
"github.com/inngest/dbcap/pkg/decoder"
"github.com/inngest/dbcap/pkg/replicator"
"github.com/inngest/dbcap/pkg/schema"
"github.com/jackc/pglogrepl"
"github.com/jackc/pgx/v5"
@@ -25,6 +26,10 @@ var (
ReadTimeout = time.Second * 5
CommitInterval = time.Second * 5

ErrInvalidCredentials = fmt.Errorf("TODO")

ErrConnectionTimeout = fmt.Errorf("TODO")

ErrLogicalReplicationNotSetUp = fmt.Errorf("ERR_PG_001: Your database does not have logical replication configured. You must set the WAL level to 'logical' to stream events.")

ErrReplicationSlotNotFound = fmt.Errorf("ERR_PG_002: The replication slot 'inngest_cdc' doesn't exist in your database. Please create the logical replication slot to stream events.")
@@ -34,7 +39,7 @@ var (

// PostgresReplicator is a Replicator with added postgres functionality.
type PostgresReplicator interface {
Replicator
replicator.Replicator

// ReplicationSlot returns the replication slot data or an error.
//
@@ -53,11 +58,11 @@ type PostgresOpts struct {
Config pgx.ConnConfig
// WatermarkSaver saves the current watermark to local storage. This should be paired with a
// WatermarkLoader to load offsets when the replicator restarts.
WatermarkSaver WatermarkSaver
WatermarkSaver replicator.WatermarkSaver
// WatermarkLoader, if specified, loads watermarks for the given connection to start replication
// from a given offset. If this isn't specified, replication will start from the latest point in
// the Postgres server's WAL.
WatermarkLoader WatermarkLoader
WatermarkLoader replicator.WatermarkLoader
// Log, if specified, is the stdlib logger used to log debug and warning messages during
// replication.
Log *slog.Logger
@@ -140,9 +145,9 @@ func (p *pg) Close(ctx context.Context) error {
return nil
}

func (p *pg) TestConnection(ctx context.Context) error {
func (p *pg) TestConnection(ctx context.Context) (replicator.ConnectionResult, error) {
_, err := p.ReplicationSlot(ctx)
return err
return nil, err
}

func (p *pg) ReplicationSlot(ctx context.Context) (ReplicationSlot, error) {
2 changes: 1 addition & 1 deletion pkg/replicator/pg_test.go → pkg/replicator/pg/pg_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package replicator
package pg

import (
"context"
271 changes: 271 additions & 0 deletions pkg/replicator/pg/pgsetup/pgsetup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
package pgsetup

import (
"context"
"database/sql"
"errors"
"fmt"

"github.com/inngest/dbcap/pkg/consts/pgconsts"
"github.com/inngest/dbcap/pkg/replicator"
"github.com/jackc/pgx/v5"
)

// Sentinel errors for Postgres replication setup. Each message begins with an
// ERR_PG_* code.
var (
// ErrLogicalReplicationNotSetUp is recorded by checkWAL when `SHOW wal_level`
// reports a value other than "logical".
ErrLogicalReplicationNotSetUp = fmt.Errorf("ERR_PG_001: Your database does not have logical replication configured. You must set the WAL level to 'logical' to stream events.")
// ErrReplicationSlotNotFound is recorded by checkReplicationSlot when
// pg_replication_slots has no row for the slot name.
ErrReplicationSlotNotFound = fmt.Errorf("ERR_PG_002: The replication slot 'inngest_cdc' doesn't exist in your database. Please create the logical replication slot to stream events.")
// ErrReplicationAlreadyRunning reports that replication is already streaming
// events. NOTE(review): none of the setup or check steps in this package
// appear to return it — confirm it is still needed here.
ErrReplicationAlreadyRunning = fmt.Errorf("ERR_PG_901: Replication is already streaming events")
)

// TestConnResult records the outcome of each Postgres setup step. Steps lists
// the step names in order and Results maps each name to the matching field.
type TestConnResult struct {
// LogicalReplication is reported under "logical_replication_enabled".
LogicalReplication replicator.ConnectionStepResult
// UserCreated is reported under "user_created".
UserCreated replicator.ConnectionStepResult
// RolesGranted is reported under "roles_granted".
RolesGranted replicator.ConnectionStepResult
// SlotCreated is reported under "replication_slot_created".
SlotCreated replicator.ConnectionStepResult
// PublicationCreated is reported under "publication_created".
PublicationCreated replicator.ConnectionStepResult
}

// Steps returns the names of the Postgres setup steps in the order they are
// expected to be completed. Every name is a key of the map returned by Results.
func (c TestConnResult) Steps() []string {
	ordered := make([]string, 0, 5)
	ordered = append(ordered,
		"logical_replication_enabled",
		"user_created",
		"roles_granted",
		"replication_slot_created",
		"publication_created",
	)
	return ordered
}

// Results returns the per-step outcome keyed by the step names listed in Steps.
func (c TestConnResult) Results() map[string]replicator.ConnectionStepResult {
	byStep := make(map[string]replicator.ConnectionStepResult, 5)
	byStep["logical_replication_enabled"] = c.LogicalReplication
	byStep["user_created"] = c.UserCreated
	byStep["roles_granted"] = c.RolesGranted
	byStep["replication_slot_created"] = c.SlotCreated
	byStep["publication_created"] = c.PublicationCreated
	return byStep
}

// SetupOpts configures Setup and Check.
type SetupOpts struct {
// AdminConfig is the connection config used to connect to the database when
// running the setup and check statements.
AdminConfig pgx.ConnConfig
// Password represents the password for the replication user.
Password string

// DisableCreateUser skips creating the replication user during Setup.
DisableCreateUser bool
// DisableCreateRoles skips granting privileges to the replication user during Setup.
DisableCreateRoles bool
// DisableCreateSlot skips creating the replication slot during Setup.
DisableCreateSlot bool
// DisableCreatePublication skips creating the publication during Setup.
DisableCreatePublication bool
}

// Setup connects to Postgres using opts.AdminConfig and runs every setup step
// that is not disabled in opts: creating the replication user, granting it
// privileges, creating the replication slot, and creating the publication.
//
// It returns the per-step results together with the first error encountered.
// If the connection itself cannot be established, an empty TestConnResult and
// the connection error are returned. The connection is closed before Setup
// returns.
func Setup(ctx context.Context, opts SetupOpts) (replicator.ConnectionResult, error) {
	conn, err := pgx.ConnectConfig(ctx, &opts.AdminConfig)
	if err != nil {
		return TestConnResult{}, err
	}
	// The connection only exists for the duration of this call; close it so
	// repeated calls don't leak server connections. A close failure is not
	// actionable by the caller, so it is deliberately ignored.
	defer conn.Close(ctx)

	s := &setup{
		opts: opts,
		c:    conn,
	}
	return s.Setup(ctx)
}

// Check connects to Postgres using opts.AdminConfig and verifies, in order,
// the WAL level, the replication user, its roles, the replication slot, and
// the publication, without creating anything.
//
// It returns the per-step results together with the first error encountered.
// If the connection itself cannot be established, an empty TestConnResult and
// the connection error are returned. The connection is closed before Check
// returns.
func Check(ctx context.Context, opts SetupOpts) (replicator.ConnectionResult, error) {
	conn, err := pgx.ConnectConfig(ctx, &opts.AdminConfig)
	if err != nil {
		return TestConnResult{}, err
	}
	// The connection only exists for the duration of this call; close it so
	// repeated calls don't leak server connections. A close failure is not
	// actionable by the caller, so it is deliberately ignored.
	defer conn.Close(ctx)

	s := &setup{
		opts: opts,
		c:    conn,
	}
	return s.Check(ctx)
}

// setup holds the state shared by the individual check and create steps.
type setup struct {
// opts are the options supplied by the caller of Setup or Check.
opts SetupOpts
// c is the connection every step runs its statements on.
c *pgx.Conn

// res accumulates the outcome of each step as it runs.
res TestConnResult
}

// Check runs every verification step in order: WAL level, user, roles,
// replication slot, and publication. It stops at the first failing step and
// returns the results gathered so far together with that step's error.
func (s *setup) Check(ctx context.Context) (replicator.ConnectionResult, error) {
	checks := [...]func(context.Context) error{
		s.checkWAL,
		s.checkUser,
		s.checkRoles,
		s.checkReplicationSlot,
		s.checkPublication,
	}
	for _, check := range checks {
		err := check(ctx)
		if err == nil {
			continue
		}
		// Short circuit: later steps are not attempted after a failure.
		return s.res, err
	}
	return s.res, nil
}

// Setup runs each creation step that is not disabled in s.opts, in order:
// user, roles, replication slot, publication. It stops at the first failing
// step and returns the results gathered so far together with that step's error.
func (s *setup) Setup(ctx context.Context) (replicator.ConnectionResult, error) {
	steps := []struct {
		disabled bool
		run      func(context.Context) error
	}{
		{s.opts.DisableCreateUser, s.createUser},
		{s.opts.DisableCreateRoles, s.createRoles},
		{s.opts.DisableCreateSlot, s.createReplicationSlot},
		{s.opts.DisableCreatePublication, s.createPublication},
	}
	for _, step := range steps {
		if step.disabled {
			continue
		}
		if err := step.run(ctx); err != nil {
			// Short circuit: later steps are not attempted after a failure.
			return s.res, err
		}
	}
	return s.res, nil
}

// checkWAL reads the server's wal_level setting and marks the
// LogicalReplication step complete when it is "logical". Otherwise it records
// and returns either the wrapped query error or ErrLogicalReplicationNotSetUp.
func (s *setup) checkWAL(ctx context.Context) error {
	var level string
	if err := s.c.QueryRow(ctx, "SHOW wal_level").Scan(&level); err != nil {
		wrapped := fmt.Errorf("Error checking WAL mode: %w", err)
		s.res.LogicalReplication.Error = wrapped
		return wrapped
	}

	if level == "logical" {
		s.res.LogicalReplication.Complete = true
		return nil
	}

	s.res.LogicalReplication.Error = ErrLogicalReplicationNotSetUp
	return ErrLogicalReplicationNotSetUp
}

// checkUser checks whether the replication user (pgconsts.Username) exists in
// pg_roles.
//
// On success the UserCreated step is marked complete and nil is returned. If
// the role is missing, or the lookup itself fails (for example a network or
// permission error), the error is recorded on the UserCreated step and
// returned; a failed lookup is never reported as "user exists".
func (s *setup) checkUser(ctx context.Context) error {
	var found int
	err := s.c.QueryRow(ctx,
		"SELECT 1 FROM pg_roles WHERE rolname = $1",
		pgconsts.Username,
	).Scan(&found)

	switch {
	case err == nil:
		s.res.UserCreated.Complete = true
		return nil
	case errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows):
		s.res.UserCreated.Error = fmt.Errorf("User '%s' does not exist", pgconsts.Username)
	default:
		// Any other failure means we could not determine whether the user
		// exists; surface it rather than silently marking the step complete.
		s.res.UserCreated.Error = fmt.Errorf("Error checking user '%s': %w", pgconsts.Username, err)
	}
	return s.res.UserCreated.Error
}

// createUser ensures the replication user (pgconsts.Username) exists, creating
// it with the REPLICATION attribute and s.opts.Password when checkUser does
// not find it.
//
// On success the UserCreated step is marked complete and any "does not exist"
// error recorded by the preceding check is cleared. On failure the wrapped
// error is recorded on the UserCreated step and returned.
func (s *setup) createUser(ctx context.Context) error {
	if err := s.checkUser(ctx); err == nil {
		// The user already exists; don't need to add.
		return nil
	}

	// CREATE USER is a utility statement and cannot take bind parameters, so
	// the identifier and the password literal are quoted explicitly instead
	// of being interpolated raw.
	stmt := fmt.Sprintf(
		"CREATE USER %s WITH REPLICATION PASSWORD %s;",
		pgx.Identifier{pgconsts.Username}.Sanitize(),
		quoteSQLLiteral(s.opts.Password),
	)
	if _, err := s.c.Exec(ctx, stmt); err != nil {
		s.res.UserCreated = replicator.ConnectionStepResult{
			Error: fmt.Errorf("Error creating user '%s': %w", pgconsts.Username, err),
		}
		return s.res.UserCreated.Error
	}

	// Replace the "does not exist" error left by checkUser: the user now exists.
	s.res.UserCreated = replicator.ConnectionStepResult{Complete: true}
	return nil
}

// quoteSQLLiteral returns v as a single-quoted SQL string literal, doubling
// any embedded single quotes. It relies on standard_conforming_strings being
// on (the Postgres default), in which case backslashes are not escapes.
func quoteSQLLiteral(v string) string {
	quoted := make([]byte, 0, len(v)+2)
	quoted = append(quoted, '\'')
	for i := 0; i < len(v); i++ {
		if v[i] == '\'' {
			quoted = append(quoted, '\'')
		}
		quoted = append(quoted, v[i])
	}
	quoted = append(quoted, '\'')
	return string(quoted)
}

// checkRoles is intended to verify that the Inngest user has the necessary
// roles. It is currently a stub: it runs no query, does not touch
// s.res.RolesGranted, and always returns nil, so Check never fails on (or
// marks complete) the roles step.
func (s *setup) checkRoles(ctx context.Context) error {
// Stub: role verification is not implemented yet, so this step always passes.
return nil
}

// createRoles grants the replication user usage on the public schema and
// SELECT on all of its current and future tables. On success the RolesGranted
// step is marked complete; on failure the wrapped error is recorded on that
// step and returned.
func (s *setup) createRoles(ctx context.Context) error {
	user := pgconsts.Username
	grants := fmt.Sprintf(`
GRANT USAGE ON SCHEMA public TO %s;
GRANT SELECT ON ALL TABLES IN SCHEMA public TO %s;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT SELECT ON TABLES TO %s;
`, user, user, user)

	if _, err := s.c.Exec(ctx, grants); err != nil {
		s.res.RolesGranted.Error = fmt.Errorf("Error granting roles for user '%s': %w", pgconsts.Username, err)
		return s.res.RolesGranted.Error
	}

	s.res.RolesGranted.Complete = true
	return nil
}

// checkReplicationSlot checks whether the replication slot
// (pgconsts.SlotName) exists in pg_replication_slots.
//
// On success the SlotCreated step is marked complete and nil is returned. If
// the slot is missing, ErrReplicationSlotNotFound is recorded and returned. If
// the lookup itself fails, the wrapped error is recorded and returned; a
// failed lookup is never reported as "slot exists".
func (s *setup) checkReplicationSlot(ctx context.Context) error {
	var found int
	err := s.c.QueryRow(ctx,
		"SELECT 1 FROM pg_replication_slots WHERE slot_name = $1",
		pgconsts.SlotName,
	).Scan(&found)

	switch {
	case err == nil:
		s.res.SlotCreated.Complete = true
		return nil
	case errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows):
		s.res.SlotCreated.Error = ErrReplicationSlotNotFound
	default:
		// Any other failure means we could not determine whether the slot
		// exists; surface it rather than silently marking the step complete.
		s.res.SlotCreated.Error = fmt.Errorf("Error checking replication slot '%s': %w", pgconsts.SlotName, err)
	}
	return s.res.SlotCreated.Error
}

// createReplicationSlot ensures the logical replication slot
// (pgconsts.SlotName) exists, creating it with the pgoutput plugin when
// checkReplicationSlot does not find it.
//
// On success the SlotCreated step is marked complete and any "not found" error
// recorded by the preceding check is cleared. On failure the wrapped error is
// recorded on the SlotCreated step and returned.
func (s *setup) createReplicationSlot(ctx context.Context) error {
	if err := s.checkReplicationSlot(ctx); err == nil {
		// The slot already exists; nothing to create.
		return nil
	}

	// Use the shared constant as a bind parameter so the slot that is created
	// is always the one that is checked for and named in error messages.
	_, err := s.c.Exec(ctx,
		"SELECT pg_create_logical_replication_slot($1, 'pgoutput')",
		pgconsts.SlotName,
	)
	if err != nil {
		s.res.SlotCreated = replicator.ConnectionStepResult{
			Error: fmt.Errorf("Error creating replication slot '%s': %w", pgconsts.SlotName, err),
		}
		return s.res.SlotCreated.Error
	}

	// Replace the "not found" error left by the check: the slot now exists.
	s.res.SlotCreated = replicator.ConnectionStepResult{Complete: true}
	return nil
}

// checkPublication checks whether the publication (pgconsts.PublicationName)
// exists in pg_publication.
//
// On success the PublicationCreated step is marked complete and nil is
// returned. If the publication is missing, or the lookup itself fails, the
// error is recorded on the PublicationCreated step and returned; a failed
// lookup is never reported as "publication exists".
func (s *setup) checkPublication(ctx context.Context) error {
	var found int
	err := s.c.QueryRow(ctx,
		"SELECT 1 FROM pg_publication WHERE pubname = $1",
		pgconsts.PublicationName,
	).Scan(&found)

	switch {
	case err == nil:
		s.res.PublicationCreated.Complete = true
		return nil
	case errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows):
		s.res.PublicationCreated.Error = fmt.Errorf("The publication '%s' doesn't exist in your database", pgconsts.PublicationName)
	default:
		// Any other failure means we could not determine whether the
		// publication exists; surface it rather than marking the step complete.
		s.res.PublicationCreated.Error = fmt.Errorf("Error checking publication '%s': %w", pgconsts.PublicationName, err)
	}
	return s.res.PublicationCreated.Error
}

// createPublication ensures the publication (pgconsts.PublicationName) exists,
// creating it FOR ALL TABLES when checkPublication does not find it.
//
// On success the PublicationCreated step is marked complete and any "doesn't
// exist" error recorded by the preceding check is cleared. On failure the
// wrapped error is recorded on the PublicationCreated step and returned.
func (s *setup) createPublication(ctx context.Context) error {
	if err := s.checkPublication(ctx); err == nil {
		// The publication already exists; nothing to create.
		return nil
	}

	stmt := fmt.Sprintf(`CREATE PUBLICATION %s FOR ALL TABLES;`, pgconsts.PublicationName)
	if _, err := s.c.Exec(ctx, stmt); err != nil {
		s.res.PublicationCreated = replicator.ConnectionStepResult{
			Error: fmt.Errorf("Error creating publication '%s': %w", pgconsts.PublicationName, err),
		}
		return s.res.PublicationCreated.Error
	}

	// Replace the "doesn't exist" error left by the check: the publication
	// now exists.
	s.res.PublicationCreated = replicator.ConnectionStepResult{Complete: true}
	return nil
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package replicator
package pg

import (
"sync/atomic"
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package replicator
package pg

import (
"testing"
22 changes: 19 additions & 3 deletions pkg/replicator/replicator.go
Original file line number Diff line number Diff line change
@@ -33,9 +33,25 @@ type Replicator interface {
// to cancelling the context passed into Pull.
Stop()

// TestConnection tests the replicator connection, returning an error if the
// connection or DB configuration is invalid.
TestConnection(ctx context.Context) error
// TestConnection tests the replicator connection, returning connection information
// and any errors with the setup.
TestConnection(ctx context.Context) (ConnectionResult, error)

changeset.WatermarkCommitter
}

// ConnectionResult describes the outcome of testing or setting up a replicator
// connection as an ordered list of named steps plus a result for each step.
type ConnectionResult interface {
// Steps returns the sequential steps necessary to set up a replicator. For example,
// the Postgres replicator may return {"credentials", "user_created", "replication_slot_created"...}
// and so on for each step required to connect.
Steps() []string

// Results returns a map keyed by the step names listed in Steps, reporting whether each
// step has been completed and any error recorded for that step.
Results() map[string]ConnectionStepResult
}

// ConnectionStepResult is the outcome of a single connection or setup step.
type ConnectionStepResult struct {
// Error is the error recorded for the step, or nil if none was recorded.
// NOTE(review): encoding/json marshals an error through its concrete type's
// exported fields, so errors built with fmt.Errorf likely serialize as {}
// rather than their message — confirm whether a custom MarshalJSON is needed.
Error error `json:"error"`
// Complete reports whether the step has been completed.
Complete bool `json:"complete"`
}

0 comments on commit dbbd684

Please sign in to comment.