diff --git a/chotki.go b/chotki.go index 0f6a64f..4a58a8a 100644 --- a/chotki.go +++ b/chotki.go @@ -548,10 +548,7 @@ func (cho *Chotki) Drain(ctx context.Context, recs protocol.Records) (err error) case 'B': // bye dear cho.log.InfoCtx(ctx, "received session end", "id", id.String()) cho.syncs.Delete(id) - case 'A': - cho.log.InfoCtx(ctx, "received ping") - case 'Z': - cho.log.InfoCtx(ctx, "received pong") + case 'P': // ping noop default: return fmt.Errorf("unsupported packet type %c", lit) } diff --git a/chotki_test.go b/chotki_test.go index 0aa36f8..7a01687 100644 --- a/chotki_test.go +++ b/chotki_test.go @@ -182,8 +182,6 @@ func TestChotki_SyncLivePingsOk(t *testing.T) { assert.Equal(t, SendLive, synca.GetFeedState()) assert.Equal(t, SendLive, syncb.GetFeedState()) - assert.Equal(t, SendLive, synca.GetDrainState()) - assert.Equal(t, SendLive, syncb.GetDrainState()) cancel() // wait until everything stopped time.Sleep(time.Millisecond * 100) diff --git a/sync.go b/sync.go index 8af6c2f..c5e7ac2 100644 --- a/sync.go +++ b/sync.go @@ -53,6 +53,9 @@ func (m *SyncMode) Unzip(raw []byte) error { return nil } +const PingVal = "ping" +const PongVal = "pong" + type SyncState int const ( @@ -203,7 +206,7 @@ func (sync *Syncer) Feed(ctx context.Context) (recs protocol.Records, err error) } case SendPing: recs = protocol.Records{ - protocol.Record('A', rdx.Stlv("ping")), + protocol.Record('P', rdx.Stlv(PingVal)), } sync.SetFeedState(ctx, SendLive) sync.pingStage.Store(int32(Inactive)) @@ -214,7 +217,7 @@ func (sync *Syncer) Feed(ctx context.Context) (recs protocol.Records, err error) }) case SendPong: recs = protocol.Records{ - protocol.Record('Z', rdx.Stlv("pong")), + protocol.Record('P', rdx.Stlv(PongVal)), } sync.pingStage.Store(int32(Inactive)) sync.SetFeedState(ctx, SendLive) @@ -429,11 +432,29 @@ func (sync *Syncer) resetPingTimer() { }) } +func (sync *Syncer) processPings(recs protocol.Records) { + for _, rec := range recs { + if protocol.Lit(rec) == 'P' { + body, _ := protocol.Take('P', rec) + switch rdx.Snative(body) { + case PingVal: + sync.log.Info("ping received", sync.withDefaultArgs()...) + // go to pong state next time + sync.pingStage.Store(int32(Pong)) + case PongVal: + sync.log.Info("pong received", sync.withDefaultArgs()...) + } + } + } +} + func (sync *Syncer) Drain(ctx context.Context, recs protocol.Records) (err error) { if len(recs) == 0 { return nil } + sync.processPings(recs) + switch sync.drainState { case SendHandshake: if len(recs) == 0 { @@ -462,9 +483,6 @@ func (sync *Syncer) Drain(ctx context.Context, recs protocol.Records) (err error } else { sync.SetDrainState(ctx, SendLive) } - if lit == 'A' { - sync.pingStage.Store(int32(Pong)) - } } if sync.Mode&SyncLive != 0 { sync.resetPingTimer() @@ -480,9 +498,6 @@ func (sync *Syncer) Drain(ctx context.Context, recs protocol.Records) (err error if lit == 'B' { sync.SetDrainState(ctx, SendNone) } - if lit == 'A' { - sync.pingStage.Store(int32(Pong)) - } err = sync.Host.Drain(sync.log.WithDefaultArgs(ctx, sync.withDefaultArgs()...), recs) if err == nil { sync.Host.Broadcast(sync.log.WithDefaultArgs(ctx, sync.withDefaultArgs()...), recs, sync.Name)