From b936061ec6fcb8b33dbf5327c763126fea2607d7 Mon Sep 17 00:00:00 2001 From: Zhen Ye Date: Mon, 10 Feb 2025 10:50:44 +0800 Subject: [PATCH] enhance: keep consistent of memory and meta of broadcaster (#39721) issue: #38399 pr: #39720 Signed-off-by: chyezh --- .../server/broadcaster/broadcast_task.go | 121 +++++++++++------- 1 file changed, 74 insertions(+), 47 deletions(-) diff --git a/internal/streamingcoord/server/broadcaster/broadcast_task.go b/internal/streamingcoord/server/broadcaster/broadcast_task.go index 92c322875ed2c..a73e00f2657d8 100644 --- a/internal/streamingcoord/server/broadcaster/broadcast_task.go +++ b/internal/streamingcoord/server/broadcaster/broadcast_task.go @@ -6,6 +6,7 @@ import ( "github.com/cockroachdb/errors" "go.uber.org/zap" + "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" "github.com/milvus-io/milvus/pkg/log" @@ -18,15 +19,10 @@ import ( func newBroadcastTaskFromProto(proto *streamingpb.BroadcastTask) *broadcastTask { msg := message.NewBroadcastMutableMessageBeforeAppend(proto.Message.Payload, proto.Message.Properties) bh := msg.BroadcastHeader() - ackedCount := 0 - for _, acked := range proto.AckedVchannelBitmap { - ackedCount += int(acked) - } return &broadcastTask{ mu: sync.Mutex{}, header: bh, task: proto, - ackedCount: ackedCount, recoverPersisted: true, // the task is recovered from the recovery info, so it's persisted. } } @@ -43,7 +39,6 @@ func newBroadcastTaskFromBroadcastMessage(msg message.BroadcastMutableMessage) * State: streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_PENDING, AckedVchannelBitmap: make([]byte, len(header.VChannels)), }, - ackedCount: 0, recoverPersisted: false, } } @@ -51,11 +46,9 @@ func newBroadcastTaskFromBroadcastMessage(msg message.BroadcastMutableMessage) * // broadcastTask is the state of the broadcast task. type broadcastTask struct { log.Binder - mu sync.Mutex - header *message.BroadcastHeader - task *streamingpb.BroadcastTask - ackedCount int // the count of the acked vchannels, the idompotenace is promised by task's bitmap. - // always keep same with the positive counter of task's acked_bitmap. + mu sync.Mutex + header *message.BroadcastHeader + task *streamingpb.BroadcastTask recoverPersisted bool // a flag to indicate that the task has been persisted into the recovery info and can be recovered. } @@ -80,10 +73,6 @@ func (b *broadcastTask) PendingBroadcastMessages() []message.MutableMessage { msg := message.NewBroadcastMutableMessageBeforeAppend(b.task.Message.Payload, b.task.Message.Properties) msgs := msg.SplitIntoMutableMessage() - // If there's no vchannel acked, return all the messages directly. - if b.ackedCount == 0 { - return msgs - } // filter out the vchannel that has been acked. pendingMessages := make([]message.MutableMessage, 0, len(msgs)) for i, msg := range msgs { @@ -103,7 +92,7 @@ func (b *broadcastTask) InitializeRecovery(ctx context.Context) error { if b.recoverPersisted { return nil } - if err := b.saveTask(ctx, b.Logger()); err != nil { + if err := b.saveTask(ctx, b.task, b.Logger()); err != nil { return err } b.recoverPersisted = true @@ -115,36 +104,44 @@ func (b *broadcastTask) Ack(ctx context.Context, vchannel string) error { b.mu.Lock() defer b.mu.Unlock() - b.setVChannelAcked(vchannel) - if b.isAllDone() { - // All vchannels are acked, mark the task as done, even if there are still pending messages on working. - // The pending messages is repeated sent operation, can be ignored. - b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE + task, ok := b.copyAndSetVChannelAcked(vchannel) + if !ok { + return nil } + // We should always save the task after acked. // Even if the task mark as done in memory. // Because the task is set as done in memory before save the recovery info. - return b.saveTask(ctx, b.Logger().With(zap.String("ackVChannel", vchannel))) + if err := b.saveTask(ctx, task, b.Logger().With(zap.String("ackVChannel", vchannel))); err != nil { + return err + } + b.task = task + return nil } -// setVChannelAcked sets the vchannel as acked. -func (b *broadcastTask) setVChannelAcked(vchannel string) { - idx, err := b.findIdxOfVChannel(vchannel) +// copyAndSetVChannelAcked copies the task and set the vchannel as acked. +// if the vchannel is already acked, it returns nil and false. +func (b *broadcastTask) copyAndSetVChannelAcked(vchannel string) (*streamingpb.BroadcastTask, bool) { + task := proto.Clone(b.task).(*streamingpb.BroadcastTask) + idx, err := findIdxOfVChannel(vchannel, b.Header().VChannels) if err != nil { panic(err) } - b.task.AckedVchannelBitmap[idx] = 1 - // Check if all vchannels are acked. - ackedCount := 0 - for _, acked := range b.task.AckedVchannelBitmap { - ackedCount += int(acked) + if task.AckedVchannelBitmap[idx] != 0 { + return nil, false + } + task.AckedVchannelBitmap[idx] = 1 + if isAllDone(task) { + // All vchannels are acked, mark the task as done, even if there are still pending messages on working. + // The pending messages is repeated sent operation, can be ignored. + task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE } - b.ackedCount = ackedCount + return task, true } // findIdxOfVChannel finds the index of the vchannel in the broadcast task. -func (b *broadcastTask) findIdxOfVChannel(vchannel string) (int, error) { - for i, channelName := range b.header.VChannels { +func findIdxOfVChannel(vchannel string, vchannels []string) (int, error) { + for i, channelName := range vchannels { if channelName == vchannel { return i, nil } @@ -152,44 +149,74 @@ func (b *broadcastTask) findIdxOfVChannel(vchannel string) (int, error) { return -1, errors.Errorf("unreachable: vchannel is %s not found in the broadcast task", vchannel) } -// isAllDone check if all the vchannels are acked. -func (b *broadcastTask) isAllDone() bool { - return b.ackedCount == len(b.header.VChannels) -} - // BroadcastDone marks the broadcast operation is done. func (b *broadcastTask) BroadcastDone(ctx context.Context) error { b.mu.Lock() defer b.mu.Unlock() - if b.isAllDone() { + task := b.copyAndMarkBroadcastDone() + if err := b.saveTask(ctx, task, b.Logger()); err != nil { + return err + } + b.task = task + return nil +} + +// copyAndMarkBroadcastDone copies the task and mark the broadcast task as done. +func (b *broadcastTask) copyAndMarkBroadcastDone() *streamingpb.BroadcastTask { + task := proto.Clone(b.task).(*streamingpb.BroadcastTask) + if isAllDone(task) { // If all vchannels are acked, mark the task as done. - b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE + task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_DONE } else { // There's no more pending message, mark the task as wait ack. - b.task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK + task.State = streamingpb.BroadcastTaskState_BROADCAST_TASK_STATE_WAIT_ACK } - return b.saveTask(ctx, b.Logger()) + return task } // IsAllAcked returns true if all the vchannels are acked. func (b *broadcastTask) IsAllAcked() bool { b.mu.Lock() defer b.mu.Unlock() - return b.isAllDone() + return isAllDone(b.task) +} + +// isAllDone check if all the vchannels are acked. +func isAllDone(task *streamingpb.BroadcastTask) bool { + for _, acked := range task.AckedVchannelBitmap { + if acked == 0 { + return false + } + } + return true +} + +// ackedCount returns the count of the acked vchannels. +func ackedCount(task *streamingpb.BroadcastTask) int { + count := 0 + for _, acked := range task.AckedVchannelBitmap { + count += int(acked) + } + return count } // IsAcked returns true if any vchannel is acked. func (b *broadcastTask) IsAcked() bool { b.mu.Lock() defer b.mu.Unlock() - return b.ackedCount > 0 + for _, acked := range b.task.AckedVchannelBitmap { + if acked != 0 { + return true + } + } + return false } // saveTask saves the broadcast task recovery info. -func (b *broadcastTask) saveTask(ctx context.Context, logger *log.MLogger) error { - logger = logger.With(zap.String("state", b.task.State.String()), zap.Int("ackedVChannelCount", b.ackedCount)) - if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.header.BroadcastID, b.task); err != nil { +func (b *broadcastTask) saveTask(ctx context.Context, task *streamingpb.BroadcastTask, logger *log.MLogger) error { + logger = logger.With(zap.String("state", task.State.String()), zap.Int("ackedVChannelCount", ackedCount(task))) + if err := resource.Resource().StreamingCatalog().SaveBroadcastTask(ctx, b.header.BroadcastID, task); err != nil { logger.Warn("save broadcast task failed", zap.Error(err)) return err }