diff --git a/pkg/changeset/changeset.go b/pkg/changeset/changeset.go index 833f2f4..33caf1c 100644 --- a/pkg/changeset/changeset.go +++ b/pkg/changeset/changeset.go @@ -16,6 +16,11 @@ const ( OperationUpdate Operation = "UPDATE" OperationDelete Operation = "DELETE" OperationTruncate Operation = "TRUNCATE" + + // OperationHeartbeat represents the changeset generated for heartbeats when we + // send messages to increase the WAL LSN. This is used for updating watermarks only, + // and should not process events. + OperationHeartbeat Operation = "HEARTBEAT" ) // WatermarkCommitter is an interface that commits a given watermark to backing datastores. diff --git a/pkg/consts/pgconsts/pgconsts.go b/pkg/consts/pgconsts/pgconsts.go index 14e0ebd..7d0def3 100644 --- a/pkg/consts/pgconsts/pgconsts.go +++ b/pkg/consts/pgconsts/pgconsts.go @@ -4,4 +4,6 @@ const ( Username = "inngest" SlotName = "inngest_cdc" PublicationName = "inngest" + + MessagesVersion = 14 ) diff --git a/pkg/decoder/pg_logical_v1.go b/pkg/decoder/pg_logical_v1.go index b978776..7224b84 100644 --- a/pkg/decoder/pg_logical_v1.go +++ b/pkg/decoder/pg_logical_v1.go @@ -12,10 +12,11 @@ import ( "github.com/jackc/pgx/v5/pgtype" ) -func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger) Decoder { +func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger, messages bool) Decoder { return v1LogicalDecoder{ log: log, schema: s, + messages: messages, relations: make(map[uint32]*pglogrepl.RelationMessage), } } @@ -23,11 +24,12 @@ func NewV1LogicalDecoder(s *schema.PGXSchemaLoader, log *slog.Logger) Decoder { type v1LogicalDecoder struct { log *slog.Logger + messages bool schema *schema.PGXSchemaLoader relations map[uint32]*pglogrepl.RelationMessage } -func (v1LogicalDecoder) ReplicationPluginArgs() []string { +func (v v1LogicalDecoder) ReplicationPluginArgs() []string { // https://www.postgresql.org/docs/current/protocol-logical-replication.html#PROTOCOL-LOGICAL-REPLICATION-PARAMS // // 
"Proto_version '2'" with "streaming 'true' streams transactions as they're progressing. @@ -37,10 +39,17 @@ func (v1LogicalDecoder) ReplicationPluginArgs() []string { // // Version 1 only sends DML entries when the transaction commits, ensuring that any event // generated by Inngest is for a committed transaction. + if v.messages { + return []string{ + "proto_version '1'", + fmt.Sprintf("publication_names '%s'", pgconsts.PublicationName), + "messages 'true'", // Doesn't work for <= v13 + } + } + return []string{ "proto_version '1'", fmt.Sprintf("publication_names '%s'", pgconsts.PublicationName), - // "messages 'true'", // Doesn't work for v12 and v13. } } @@ -49,6 +58,11 @@ func (v v1LogicalDecoder) Decode(in []byte, cs *changeset.Changeset) (bool, erro msgType := pglogrepl.MessageType(in[0]) switch msgType { + case pglogrepl.MessageTypeMessage: + // This is a heartbeat (or another WAL message). Do nothing but record + // the heartbeat and updated watermark. + cs.Operation = changeset.OperationHeartbeat + return true, nil case pglogrepl.MessageTypeRelation: // MessageTypeRelation describes the OIDs for any relation before DML messages are sent. From the docs: // diff --git a/pkg/replicator/pgreplicator/pg.go b/pkg/replicator/pgreplicator/pg.go index 4735a27..d8d13a7 100644 --- a/pkg/replicator/pgreplicator/pg.go +++ b/pkg/replicator/pgreplicator/pg.go @@ -9,6 +9,7 @@ import ( "log/slog" "os" "strings" + "sync" "sync/atomic" "time" @@ -24,8 +25,9 @@ import ( ) var ( - ReadTimeout = time.Second * 5 - CommitInterval = time.Second * 5 + ReadTimeout = time.Second * 5 + CommitInterval = time.Second * 5 + DefaultHeartbeatTime = time.Minute ) // PostgresReplicator is a Replicator with added postgres functionality. @@ -61,6 +63,12 @@ type Opts struct { // New returns a new postgres replicator for a single postgres database. 
func New(ctx context.Context, opts Opts) (PostgresReplicator, error) { + if opts.Log == nil { + opts.Log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelInfo, + })) + } + cfg := opts.Config // Ensure that we add "replication": "database" as a to the replication @@ -84,24 +92,28 @@ func New(ctx context.Context, opts Opts) (PostgresReplicator, error) { return nil, fmt.Errorf("error connecting to postgres host for schemas: %w", err) } + // Query for current postgres version. + var version int + row := pgxc.QueryRow(ctx, "SELECT current_setting('server_version_num')::int / 10000;") + if err := row.Scan(&version); err != nil { + opts.Log.Warn("error querying for postgres version", "error", err) + } + sl := schema.NewPGXSchemaLoader(pgxc) // Refresh all schemas to begin with if err := sl.Refresh(); err != nil { return nil, err } - if opts.Log == nil { - opts.Log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ - Level: slog.LevelInfo, - })) - } - return &pg{ - opts: opts, - conn: replConn, - queryConn: pgxc, - decoder: decoder.NewV1LogicalDecoder(sl, opts.Log), - log: opts.Log, + opts: opts, + conn: replConn, + queryConn: pgxc, + queryLock: &sync.Mutex{}, + decoder: decoder.NewV1LogicalDecoder(sl, opts.Log, version >= pgconsts.MessagesVersion), + log: opts.Log, + version: version, + heartbeatTime: DefaultHeartbeatTime, }, nil } @@ -111,8 +123,14 @@ type pg struct { // conn is the WAL streaming connection. Once replication starts, this // conn cannot be used for any queries. conn *pgx.Conn + // queryCon is a conn for querying data. queryConn *pgx.Conn + + // queryLock is used to lock pgx.Conn, as it's a single connection which cannot be used + // in parallel. + queryLock *sync.Mutex + // decoder decodes the binary WAL log decoder decoder.Decoder // nextReportTime records the time in which we must next report the current @@ -125,6 +143,9 @@ type pg struct { // log is a stdlib logger for reporting debug and warn logs. 
log *slog.Logger + version int + heartbeatTime time.Duration + stopped int32 } @@ -140,14 +161,20 @@ func (p *pg) Close(ctx context.Context) error { } func (p *pg) ReplicationSlot(ctx context.Context) (ReplicationSlot, error) { + mode, err := p.walMode(ctx) if err != nil { return ReplicationSlot{}, err } + if mode != "logical" { return ReplicationSlot{}, ErrLogicalReplicationNotSetUp } + // Lock when querying repl slot data. + p.queryLock.Lock() + defer p.queryLock.Unlock() + return ReplicationSlotData(ctx, p.queryConn) } @@ -218,6 +245,29 @@ func (p *pg) Pull(ctx context.Context, cc chan *changeset.Changeset) error { // the DML. unwrapper := &txnUnwrapper{cc: cc} + go func() { + if p.version < pgconsts.MessagesVersion { + // doesn't support wal messages; ignore. + return + } + + t := time.NewTicker(p.heartbeatTime) + for range t.C { + if ctx.Err() != nil { + return + } + + // Send a heartbeat on every tick of p.heartbeatTime + p.queryLock.Lock() + _, err := p.queryConn.Exec(ctx, "SELECT pg_logical_emit_message(false, 'heartbeat', now()::varchar);") + p.queryLock.Unlock() + + if err != nil { + p.log.Warn("unable to emit heartbeat", "error", err, "host", p.opts.Config.Host) + } + } + }() + for { if ctx.Err() != nil || atomic.LoadInt32(&p.stopped) == 1 { // Always call Close automatically.
@@ -233,6 +283,14 @@ func (p *pg) Pull(ctx context.Context, cc chan *changeset.Changeset) error { continue } + if changes.Operation == changeset.OperationHeartbeat { + p.Commit(changes.Watermark) + if err := p.forceNextReport(ctx); err != nil { + p.log.Warn("unable to report lsn on heartbeat", "error", err, "host", p.opts.Config.Host) + } + continue + } + unwrapper.Process(changes) } } @@ -259,7 +317,9 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) { if err != nil { if pgconn.Timeout(err) { - p.forceNextReport() + if err := p.forceNextReport(ctx); err != nil { + p.log.Warn("unable to report lsn on timeout", "error", err, "host", p.opts.Config.Host) + } // We return nil as we want to keep iterating. return nil, nil } @@ -291,7 +351,9 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) { return nil, fmt.Errorf("error parsing replication keepalive: %w", err) } if pkm.ReplyRequested { - p.forceNextReport() + if err := p.forceNextReport(ctx); err != nil { + p.log.Warn("unable to report lsn on request", "error", err, "host", p.opts.Config.Host) + } } return nil, nil case pglogrepl.XLogDataByteID: @@ -316,6 +378,7 @@ func (p *pg) fetch(ctx context.Context) (*changeset.Changeset, error) { if err != nil { return nil, fmt.Errorf("error decoding xlog data: %w", err) } + if !ok { return nil, nil } @@ -348,10 +411,11 @@ func (p *pg) committedWatermark() (wm changeset.Watermark) { } } -func (p *pg) forceNextReport() { +func (p *pg) forceNextReport(ctx context.Context) error { // Updating the next report time to a zero time always reports the LSN, // as time.Now() is always after the empty time. p.nextReportTime = time.Time{} + return p.report(ctx, true) } // report reports the current replication slot's LSN progress to the server. 
We can optionally @@ -384,6 +448,9 @@ func (p *pg) LSN() (lsn pglogrepl.LSN) { } func (p *pg) walMode(ctx context.Context) (string, error) { + p.queryLock.Lock() + defer p.queryLock.Unlock() + var mode string row := p.queryConn.QueryRow(ctx, "SHOW wal_level") err := row.Scan(&mode) @@ -408,8 +475,8 @@ func ReplicationSlotData(ctx context.Context, conn *pgx.Conn) (ReplicationSlot, row := conn.QueryRow( ctx, fmt.Sprintf(`SELECT - active, restart_lsn, confirmed_flush_lsn - FROM pg_replication_slots WHERE slot_name = '%s';`, + active, restart_lsn, confirmed_flush_lsn + FROM pg_replication_slots WHERE slot_name = '%s';`, pgconsts.SlotName, ), ) diff --git a/pkg/replicator/pgreplicator/pg_test.go b/pkg/replicator/pgreplicator/pg_test.go index 900d071..60004bc 100644 --- a/pkg/replicator/pgreplicator/pg_test.go +++ b/pkg/replicator/pgreplicator/pg_test.go @@ -214,6 +214,53 @@ func TestInsert(t *testing.T) { } } +func TestLogicalEmitHeartbeat(t *testing.T) { + t.Parallel() + versions := []int{14, 15, 16} + + for _, v1 := range versions { + v := v1 // loop capture + t.Run(fmt.Sprintf("EmitHeartbeat - Postgres %d", v), func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + + c, conn := test.StartPG(t, ctx, test.StartPGOpts{Version: v}) + opts := Opts{Config: conn} + repl, err := New(ctx, opts) + + // heartbeat fast in tests. 
+ r := repl.(*pg) + r.heartbeatTime = 250 * time.Millisecond + require.NoError(t, err) + + cb := eventwriter.NewCallbackWriter(ctx, 1, time.Millisecond, func(batch []*changeset.Changeset) error { + return nil + }) + csChan := cb.Listen(ctx, r) + + go func() { + err := r.Pull(ctx, csChan) + require.NoError(t, err) + }() + + slotA, err := r.ReplicationSlot(ctx) + require.NoError(t, err) + + <-time.After(1100 * time.Millisecond) + + slotB, err := r.ReplicationSlot(ctx) + require.NoError(t, err) + + require.NotEqual(t, slotA.ConfirmedFlushLSN, slotB.ConfirmedFlushLSN) + require.True(t, int(slotB.ConfirmedFlushLSN) > int(slotA.ConfirmedFlushLSN)) + + cancel() + _ = c.Stop(ctx, nil) + }) + } +} + func TestUpdateMany_ReplicaIdentityFull(t *testing.T) { t.Parallel() versions := []int{12, 13, 14, 15, 16} diff --git a/pkg/replicator/pgreplicator/txn_unwrapper.go b/pkg/replicator/pgreplicator/txn_unwrapper.go index 60062f5..1906d2d 100644 --- a/pkg/replicator/pgreplicator/txn_unwrapper.go +++ b/pkg/replicator/pgreplicator/txn_unwrapper.go @@ -30,6 +30,12 @@ func (t *txnUnwrapper) Process(cs *changeset.Changeset) { } switch cs.Operation { + case changeset.OperationHeartbeat: + // The unwrapper should never receive heartbeats as the replicator should + // handle them and short circuit. However, always transmit them immediately + // for safety in code in case someone changes something in the future. + t.cc <- cs + return case changeset.OperationBegin: t.begin = cs case changeset.OperationCommit: