Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WAL message heartbeat #4

Merged
merged 3 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkg/changeset/changeset.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ const (
OperationUpdate Operation = "UPDATE"
OperationDelete Operation = "DELETE"
OperationTruncate Operation = "TRUNCATE"

// OperationHeartbeat represents the changeset generated for heartbeats when we
// send messages to increase the WAL LSN. This is used for updating watermarks only,
// and should not process events.
OperationHeartbeat Operation = "HEARTBEAT"
)

// WatermarkCommitter is an interface that commits a given watermark to backing datastores.
Expand Down
2 changes: 2 additions & 0 deletions pkg/consts/pgconsts/pgconsts.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ const (
Username = "inngest"
SlotName = "inngest_cdc"
PublicationName = "inngest"

MessagesVersion = 14
)
20 changes: 17 additions & 3 deletions pkg/decoder/pg_logical_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,24 @@ import (
"github.com/jackc/pgx/v5/pgtype"
)

func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger) Decoder {
func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger, messages bool) Decoder {
return v1LogicalDecoder{
log: log,
schema: s,
messages: messages,
relations: make(map[uint32]*pglogrepl.RelationMessage),
}
}

// v1LogicalDecoder is the Decoder implementation returned by NewV1LogicalDecoder.
// Its ReplicationPluginArgs method requests "proto_version '1'" from the plugin.
type v1LogicalDecoder struct {
// log is the logger supplied to NewV1LogicalDecoder.
log *slog.Logger

// messages controls whether ReplicationPluginArgs adds "messages 'true'" to the
// plugin arguments.
messages bool
// schema is the schema loader supplied to NewV1LogicalDecoder.
schema *schema.PGXSchemaLoader
// relations stores relation messages keyed by a uint32 — presumably the relation
// OID announced by MessageTypeRelation; confirm against Decode.
relations map[uint32]*pglogrepl.RelationMessage
}

func (v1LogicalDecoder) ReplicationPluginArgs() []string {
func (v v1LogicalDecoder) ReplicationPluginArgs() []string {
// https://www.postgresql.org/docs/current/protocol-logical-replication.html#PROTOCOL-LOGICAL-REPLICATION-PARAMS
//
// "Proto_version '2'" with "streaming 'true' streams transactions as they're progressing.
Expand All @@ -37,10 +39,17 @@ func (v1LogicalDecoder) ReplicationPluginArgs() []string {
//
// Version 1 only sends DML entries when the transaction commits, ensuring that any event
// generated by Inngest is for a committed transaction.
if v.messages {
return []string{
"proto_version '1'",
fmt.Sprintf("publication_names '%s'", pgconsts.PublicationName),
"messages 'true'", // Doesn't work for <= v13
}
}

return []string{
"proto_version '1'",
fmt.Sprintf("publication_names '%s'", pgconsts.PublicationName),
// "messages 'true'", // Doesn't work for v12 and v13.
}
}

Expand All @@ -49,6 +58,11 @@ func (v v1LogicalDecoder) Decode(in []byte, cs *changeset.Changeset) (bool, erro
msgType := pglogrepl.MessageType(in[0])

switch msgType {
case pglogrepl.MessageTypeMessage:
// This is a heartbeat (or another WAL message). Do nothing but record
// the heartbeat and updated watermark.
cs.Operation = changeset.OperationHeartbeat
return true, nil
case pglogrepl.MessageTypeRelation:
// MessageTypeRelation describes the OIDs for any relation before DML messages are sent. From the docs:
//
Expand Down
103 changes: 85 additions & 18 deletions pkg/replicator/pgreplicator/pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"log/slog"
"os"
"strings"
"sync"
"sync/atomic"
"time"

Expand All @@ -24,8 +25,9 @@ import (
)

var (
ReadTimeout = time.Second * 5
CommitInterval = time.Second * 5
ReadTimeout = time.Second * 5
CommitInterval = time.Second * 5
DefaultHeartbeatTime = time.Minute
)

// PostgresReplicator is a Replicator with added postgres functionality.
Expand Down Expand Up @@ -61,6 +63,12 @@ type Opts struct {

// New returns a new postgres replicator for a single postgres database.
func New(ctx context.Context, opts Opts) (PostgresReplicator, error) {
if opts.Log == nil {
opts.Log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))
}

cfg := opts.Config

// Ensure that we add "replication": "database" as a to the replication
Expand All @@ -84,24 +92,28 @@ func New(ctx context.Context, opts Opts) (PostgresReplicator, error) {
return nil, fmt.Errorf("error connecting to postgres host for schemas: %w", err)
}

// Query for current postgres version.
var version int
row := pgxc.QueryRow(ctx, "SELECT current_setting('server_version_num')::int / 10000;")
if err := row.Scan(&version); err != nil {
opts.Log.Warn("error querying for postgres version", "error", err)
}

sl := schema.NewPGXSchemaLoader(pgxc)
// Refresh all schemas to begin with
if err := sl.Refresh(); err != nil {
return nil, err
}

if opts.Log == nil {
opts.Log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{
Level: slog.LevelInfo,
}))
}

return &pg{
opts: opts,
conn: replConn,
queryConn: pgxc,
decoder: decoder.NewV1LogicalDecoder(sl, opts.Log),
log: opts.Log,
opts: opts,
conn: replConn,
queryConn: pgxc,
queryLock: &sync.Mutex{},
decoder: decoder.NewV1LogicalDecoder(sl, opts.Log, version >= pgconsts.MessagesVersion),
log: opts.Log,
version: version,
heartbeatTime: DefaultHeartbeatTime,
}, nil
}

Expand All @@ -111,8 +123,14 @@ type pg struct {
// conn is the WAL streaming connection. Once replication starts, this
// conn cannot be used for any queries.
conn *pgx.Conn

// queryConn is a conn for querying data.
queryConn *pgx.Conn

// queryLock is used to lock pgx.Conn, as it's a single connection which cannot be used
// in parallel.
queryLock *sync.Mutex

// decoder decodes the binary WAL log
decoder decoder.Decoder
// nextReportTime records the time in which we must next report the current
Expand All @@ -125,6 +143,9 @@ type pg struct {
// log is a stdlib logger for reporting debug and warn logs.
log *slog.Logger

version int
heartbeatTime time.Duration

stopped int32
}

Expand All @@ -140,14 +161,20 @@ func (p *pg) Close(ctx context.Context) error {
}

// ReplicationSlot returns the replication slot data read over the query
// connection. It first checks the server's WAL mode: an error from that lookup
// is returned as-is, and any mode other than "logical" yields
// ErrLogicalReplicationNotSetUp.
func (p *pg) ReplicationSlot(ctx context.Context) (ReplicationSlot, error) {
	// walMode acquires queryLock itself, so it has to run before we take the
	// lock below.
	mode, err := p.walMode(ctx)
	switch {
	case err != nil:
		return ReplicationSlot{}, err
	case mode != "logical":
		return ReplicationSlot{}, ErrLogicalReplicationNotSetUp
	}

	// queryConn is a single connection, so serialise access while reading the
	// slot data.
	p.queryLock.Lock()
	defer p.queryLock.Unlock()
	return ReplicationSlotData(ctx, p.queryConn)
}

Expand Down Expand Up @@ -218,6 +245,29 @@ func (p *pg) Pull(ctx context.Context, cc chan *changeset.Changeset) error {
// the DML.
unwrapper := &txnUnwrapper{cc: cc}

// Periodically emit a non-transactional WAL message so that the LSN advances
// even when no other changes are written. The goroutine exits as soon as ctx
// is cancelled or the replicator is stopped, and releases its ticker on exit.
go func() {
	if p.version < pgconsts.MessagesVersion {
		// doesn't support wal messages; ignore.
		return
	}

	t := time.NewTicker(p.heartbeatTime)
	defer t.Stop()

	for {
		// Wait for either the next tick or cancellation, so that shutdown
		// is not delayed by up to a full heartbeat interval.
		select {
		case <-ctx.Done():
			return
		case <-t.C:
		}

		// select chooses at random when both cases are ready, so re-check
		// before touching the connection. Also honour the stopped flag that
		// the main pull loop uses.
		if ctx.Err() != nil || atomic.LoadInt32(&p.stopped) == 1 {
			return
		}

		// Send a heartbeat once per heartbeatTime interval. queryConn is a
		// single connection, so hold queryLock for the duration of the call.
		p.queryLock.Lock()
		_, err := p.queryConn.Exec(ctx, "SELECT pg_logical_emit_message(false, 'heartbeat', now()::varchar);")
		p.queryLock.Unlock()

		if err != nil {
			p.log.Warn("unable to emit heartbeat", "error", err, "host", p.opts.Config.Host)
		}
	}
}()

for {
if ctx.Err() != nil || atomic.LoadInt32(&p.stopped) == 1 {
// Always call Close automatically.
Expand All @@ -233,6 +283,14 @@ func (p *pg) Pull(ctx context.Context, cc chan *changeset.Changeset) error {
continue
}

if changes.Operation == changeset.OperationHeartbeat {
p.Commit(changes.Watermark)
if err := p.forceNextReport(ctx); err != nil {
p.log.Warn("unable to report lsn on heartbeat", "error", err, "host", p.opts.Config.Host)
}
continue
}

unwrapper.Process(changes)
}
}
Expand All @@ -259,7 +317,9 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) {

if err != nil {
if pgconn.Timeout(err) {
p.forceNextReport()
if err := p.forceNextReport(ctx); err != nil {
p.log.Warn("unable to report lsn on timeout", "error", err, "host", p.opts.Config.Host)
}
// We return nil as we want to keep iterating.
return nil, nil
}
Expand Down Expand Up @@ -291,7 +351,9 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) {
return nil, fmt.Errorf("error parsing replication keepalive: %w", err)
}
if pkm.ReplyRequested {
p.forceNextReport()
if err := p.forceNextReport(ctx); err != nil {
p.log.Warn("unable to report lsn on request", "error", err, "host", p.opts.Config.Host)
}
}
return nil, nil
case pglogrepl.XLogDataByteID:
Expand All @@ -316,6 +378,7 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) {
if err != nil {
return nil, fmt.Errorf("error decoding xlog data: %w", err)
}

if !ok {
return nil, nil
}
Expand Down Expand Up @@ -348,10 +411,11 @@ func (p *pg) committedWatermark() (wm changeset.Watermark) {
}
}

func (p *pg) forceNextReport() {
// forceNextReport resets nextReportTime to the zero time and then immediately
// calls report, returning whatever error report produces. Callers in this file
// log that error as a warning and carry on.
func (p *pg) forceNextReport(ctx context.Context) error {
// Updating the next report time to a zero time always reports the LSN,
// as time.Now() is always after the empty time.
p.nextReportTime = time.Time{}
// NOTE(review): the second argument presumably forces the report regardless of
// the schedule — confirm against report.
return p.report(ctx, true)
}

// report reports the current replication slot's LSN progress to the server. We can optionally
Expand Down Expand Up @@ -384,6 +448,9 @@ func (p *pg) LSN() (lsn pglogrepl.LSN) {
}

func (p *pg) walMode(ctx context.Context) (string, error) {
p.queryLock.Lock()
defer p.queryLock.Unlock()

var mode string
row := p.queryConn.QueryRow(ctx, "SHOW wal_level")
err := row.Scan(&mode)
Expand All @@ -408,8 +475,8 @@ func ReplicationSlotData(ctx context.Context, conn *pgx.Conn) (ReplicationSlot,
row := conn.QueryRow(
ctx,
fmt.Sprintf(`SELECT
active, restart_lsn, confirmed_flush_lsn
FROM pg_replication_slots WHERE slot_name = '%s';`,
active, restart_lsn, confirmed_flush_lsn
FROM pg_replication_slots WHERE slot_name = '%s';`,
pgconsts.SlotName,
),
)
Expand Down
47 changes: 47 additions & 0 deletions pkg/replicator/pgreplicator/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,53 @@ func TestInsert(t *testing.T) {
}
}

// TestLogicalEmitHeartbeat checks that, with no other writes happening, the
// periodic heartbeat still moves the replication slot's confirmed flush LSN
// forward. It runs against each Postgres version listed below.
func TestLogicalEmitHeartbeat(t *testing.T) {
	t.Parallel()
	versions := []int{14, 15, 16}

	for _, v1 := range versions {
		v := v1 // loop capture
		t.Run(fmt.Sprintf("EmitHeartbeat - Postgres %d", v), func(t *testing.T) {
			t.Parallel()

			ctx, cancel := context.WithCancel(context.Background())
			// Always release the context, even if a require below aborts the test.
			defer cancel()

			c, conn := test.StartPG(t, ctx, test.StartPGOpts{Version: v})
			// Cancel before stopping the container (same order as a successful
			// run previously used), and do so on every exit path.
			defer func() {
				cancel()
				_ = c.Stop(ctx, nil)
			}()

			opts := Opts{Config: conn}
			repl, err := New(ctx, opts)
			// Check the error before the type assertion: a failed New returns a
			// nil replicator, and asserting on it would panic instead of
			// reporting the real error.
			require.NoError(t, err)

			// heartbeat fast in tests.
			r := repl.(*pg)
			r.heartbeatTime = 250 * time.Millisecond

			cb := eventwriter.NewCallbackWriter(ctx, 1, time.Millisecond, func(batch []*changeset.Changeset) error {
				return nil
			})
			csChan := cb.Listen(ctx, r)

			go func() {
				err := r.Pull(ctx, csChan)
				require.NoError(t, err)
			}()

			slotA, err := r.ReplicationSlot(ctx)
			require.NoError(t, err)

			// Wait long enough for several 250ms heartbeats to be emitted.
			<-time.After(1100 * time.Millisecond)

			slotB, err := r.ReplicationSlot(ctx)
			require.NoError(t, err)

			require.NotEqual(t, slotA.ConfirmedFlushLSN, slotB.ConfirmedFlushLSN)
			require.True(t, int(slotB.ConfirmedFlushLSN) > int(slotA.ConfirmedFlushLSN))
		})
	}
}

func TestUpdateMany_ReplicaIdentityFull(t *testing.T) {
t.Parallel()
versions := []int{12, 13, 14, 15, 16}
Expand Down
6 changes: 6 additions & 0 deletions pkg/replicator/pgreplicator/txn_unwrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ func (t *txnUnwrapper) Process(cs *changeset.Changeset) {
}

switch cs.Operation {
case changeset.OperationHeartbeat:
// The unwrapper should never receive heartbeats as the replicator should
// handle them and short circuit. However, always transmit them immediately
// for safety in code in case someone changes something in the future.
t.cc <- cs
return
case changeset.OperationBegin:
t.begin = cs
case changeset.OperationCommit:
Expand Down
Loading