diff --git a/core/reader/replicate_channel_manager.go b/core/reader/replicate_channel_manager.go index 7a22844..4ab6654 100644 --- a/core/reader/replicate_channel_manager.go +++ b/core/reader/replicate_channel_manager.go @@ -564,7 +564,8 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, dbInfo *mode func (r *replicateChannelManager) StopReadCollection(ctx context.Context, info *pb.CollectionInfo) error { for _, channel := range info.GetPhysicalChannelNames() { - r.stopReadChannel(channel, info.ID) + handler := r.stopReadChannel(channel, info.ID) + handler.Close() } r.collectionLock.Lock() closeChan, ok := r.replicateCollections[info.ID] @@ -816,17 +817,17 @@ func (r *replicateChannelManager) isDroppedPartition(partition int64) bool { return ok } -func (r *replicateChannelManager) stopReadChannel(pChannelName string, collectionID int64) { +func (r *replicateChannelManager) stopReadChannel(pChannelName string, collectionID int64) *replicateChannelHandler { r.channelLock.RLock() mapKey := r.getChannelMapKey(collectionID, pChannelName) if mapKey == "" { r.channelLock.RUnlock() - return + return nil } channelHandler, ok := r.channelHandlerMap[mapKey] if !ok { r.channelLock.RUnlock() - return + return nil } r.channelLock.RUnlock() channelHandler.RemoveCollection(collectionID) @@ -834,6 +835,11 @@ func (r *replicateChannelManager) stopReadChannel(pChannelName string, collectio // if channelHandler.IsEmpty() { // channelHandler.Close() // } + return channelHandler +} + +func (r *replicateChannelManager) Close(handler *replicateChannelHandler) { + handler.Close() } type replicateChannelHandler struct { @@ -874,7 +880,12 @@ type replicateChannelHandler struct { } func (r *replicateChannelHandler) AddCollection(sourceInfo *model.SourceCollectionInfo, targetInfo *model.TargetCollectionInfo) { - <-r.startReadChan + select { + case <-r.replicateCtx.Done(): + log.Warn("replicate channel handler closed") + return + case <-r.startReadChan: + } r.collectionSourceSeekPosition(sourceInfo.SeekPosition) collectionID := sourceInfo.CollectionID streamChan, closeStreamFunc, err := r.streamCreator.GetStreamChan(r.replicateCtx, sourceInfo.VChannel, sourceInfo.SeekPosition) @@ -1152,6 +1163,13 @@ func (r *replicateChannelHandler) IsEmpty() bool { func (r *replicateChannelHandler) Close() { // r.stream.Close() + r.recordLock.Lock() + defer r.recordLock.Unlock() + for _, closeStreamFunc := range r.closeStreamFuncs { + if closeStreamFunc != nil { + _ = closeStreamFunc.Close() + } + } } func (r *replicateChannelHandler) getTSManagerChannelKey(channelName string) string { diff --git a/server/cdc_impl.go b/server/cdc_impl.go index cbcc994..0dddffc 100644 --- a/server/cdc_impl.go +++ b/server/cdc_impl.go @@ -173,12 +173,11 @@ func (e *MetaCDC) ReloadTask() { for _, taskInfo := range taskInfos { uKey := getTaskUniqueIDFromInfo(taskInfo) - newCollectionNames := lo.Map(taskInfo.CollectionInfos, func(t model.CollectionInfo, _ int) string { - return t.Name - }) + newCollectionNames := GetCollectionNamesFromTaskInfo(taskInfo) e.collectionNames.data[uKey] = append(e.collectionNames.data[uKey], newCollectionNames...) e.collectionNames.excludeData[uKey] = append(e.collectionNames.excludeData[uKey], taskInfo.ExcludeCollections...) e.collectionNames.excludeData[uKey] = lo.Uniq(e.collectionNames.excludeData[uKey]) + e.collectionNames.extraInfos[uKey] = taskInfo.ExtraInfo e.cdcTasks.Lock() e.cdcTasks.data[taskInfo.TaskID] = taskInfo e.cdcTasks.Unlock() @@ -246,6 +245,40 @@ func getCollectionNameFromFull(fullName string) (string, string) { return names[0], names[1] } +func GetCollectionNamesFromTaskInfo(info *meta.TaskInfo) []string { + var newCollectionNames []string + if len(info.CollectionInfos) > 0 { + newCollectionNames = lo.Map(info.CollectionInfos, func(t model.CollectionInfo, _ int) string { + return getFullCollectionName(t.Name, cdcreader.DefaultDatabase) + }) + } + if len(info.DBCollections) > 0 { + for db, infos := range info.DBCollections { + for _, t := range infos { + newCollectionNames = append(newCollectionNames, getFullCollectionName(t.Name, db)) + } + } + } + return newCollectionNames +} + +func GetCollectionNamesFromReq(req *request.CreateRequest) []string { + var newCollectionNames []string + if len(req.CollectionInfos) > 0 { + newCollectionNames = lo.Map(req.CollectionInfos, func(t model.CollectionInfo, _ int) string { + return getFullCollectionName(t.Name, cdcreader.DefaultDatabase) + }) + } + if len(req.DBCollections) > 0 { + for db, infos := range req.DBCollections { + for _, t := range infos { + newCollectionNames = append(newCollectionNames, getFullCollectionName(t.Name, db)) + } + } + } + return newCollectionNames +} + func matchCollectionName(sampleCollection, targetCollection string) (bool, bool) { db1, collection1 := getCollectionNameFromFull(sampleCollection) db2, collection2 := getCollectionNameFromFull(targetCollection) @@ -321,19 +354,7 @@ func (e *MetaCDC) Create(req *request.CreateRequest) (resp *request.CreateRespon return nil, err } uKey := getTaskUniqueIDFromReq(req) - var newCollectionNames []string - if len(req.CollectionInfos) > 0 { - newCollectionNames = lo.Map(req.CollectionInfos, func(t model.CollectionInfo, _ int) string { - return getFullCollectionName(t.Name, cdcreader.DefaultDatabase) - }) - } - if len(req.DBCollections) > 0 { - for db, infos := range req.DBCollections { - for _, t := range infos { - newCollectionNames = append(newCollectionNames, getFullCollectionName(t.Name, db)) - } - } - } + newCollectionNames := GetCollectionNamesFromReq(req) excludeCollectionNames, err := e.checkDuplicateCollection(uKey, newCollectionNames, req.ExtraInfo) if err != nil { @@ -1185,11 +1206,8 @@ func (e *MetaCDC) delete(taskID string) error { if err != nil { return errors.WithMessage(err, "fail to delete the task meta, task_id: "+taskID) } - var uKey string - milvusURI := GetMilvusURI(info.MilvusConnectParam) - kafkaAddress := GetKafkaAddress(info.KafkaConnectParam) - uKey = milvusURI + kafkaAddress - collectionNames := info.CollectionNames() + uKey := getTaskUniqueIDFromInfo(info) + collectionNames := GetCollectionNamesFromTaskInfo(info) e.collectionNames.Lock() e.collectionNames.excludeData[uKey] = lo.Without(e.collectionNames.excludeData[uKey], info.ExcludeCollections...) e.collectionNames.data[uKey] = lo.Without(e.collectionNames.data[uKey], collectionNames...) diff --git a/server/model/request/base.go b/server/model/request/base.go index bb88b2d..9073b62 100644 --- a/server/model/request/base.go +++ b/server/model/request/base.go @@ -47,12 +47,13 @@ type CDCResponse struct { // Task some info can be showed about the task type Task struct { - TaskID string `json:"task_id" mapstructure:"task_id"` - MilvusConnectParam model.MilvusConnectParam `json:"milvus_connect_param" mapstructure:"milvus_connect_param"` - KafkaConnectParam model.KafkaConnectParam `json:"kafka_connect_param" mapstructure:"kafka_connect_param"` - CollectionInfos []model.CollectionInfo `json:"collection_infos" mapstructure:"collection_infos"` - State string `json:"state" mapstructure:"state"` - LastPauseReason string `json:"reason,omitempty" mapstructure:"reason,omitempty"` + TaskID string `json:"task_id" mapstructure:"task_id"` + MilvusConnectParam model.MilvusConnectParam `json:"milvus_connect_param" mapstructure:"milvus_connect_param"` + KafkaConnectParam model.KafkaConnectParam `json:"kafka_connect_param" mapstructure:"kafka_connect_param"` + CollectionInfos []model.CollectionInfo `json:"collection_infos" mapstructure:"collection_infos"` + DBCollections map[string][]model.CollectionInfo `json:"db_collections" mapstructure:"db_collections"` + State string `json:"state" mapstructure:"state"` + LastPauseReason string `json:"reason,omitempty" mapstructure:"reason,omitempty"` } func GetTask(taskInfo *meta.TaskInfo) Task { @@ -66,6 +67,7 @@ func GetTask(taskInfo *meta.TaskInfo) Task { MilvusConnectParam: taskInfo.MilvusConnectParam, KafkaConnectParam: taskInfo.KafkaConnectParam, CollectionInfos: taskInfo.CollectionInfos, + DBCollections: taskInfo.DBCollections, State: taskInfo.State.String(), LastPauseReason: taskInfo.Reason, }