diff --git a/Makefile b/Makefile index adcb4116..14ef1b03 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ static-check: @echo "Running go-lint check:" @(env bash $(PWD)/scripts/run_go_lint.sh) -CORE_API := DataHandler MessageManager MetaOp Reader ChannelManager TargetAPI Writer FactoryCreator +CORE_API := DataHandler MessageManager MetaOp Reader ChannelManager TargetAPI Writer FactoryCreator ReplicateStore ReplicateMeta SERVER_API := MetaStore MetaStoreFactory CDCService generate-mockery: diff --git a/core/api/replicate_manager.go b/core/api/replicate_manager.go index 2296d0a3..e67bb55c 100644 --- a/core/api/replicate_manager.go +++ b/core/api/replicate_manager.go @@ -59,6 +59,7 @@ type ReplicateAPIEvent struct { ReplicateInfo *commonpb.ReplicateInfo ReplicateParam ReplicateParam TaskID string + MsgID string Error error } diff --git a/core/api/replicate_meta.go b/core/api/replicate_meta.go new file mode 100644 index 00000000..e4c07fda --- /dev/null +++ b/core/api/replicate_meta.go @@ -0,0 +1,35 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * // + * http://www.apache.org/licenses/LICENSE-2.0 + * // + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package api + +import "context" + +type ReplicateStore interface { + Get(ctx context.Context, key string, withPrefix bool) ([]MetaMsg, error) + Put(ctx context.Context, key string, value MetaMsg) error + Remove(ctx context.Context, key string) error +} + +type ReplicateMeta interface { + UpdateTaskDropCollectionMsg(ctx context.Context, msg TaskDropCollectionMsg) (bool, error) + GetTaskDropCollectionMsg(ctx context.Context, taskID string, msgID string) ([]TaskDropCollectionMsg, error) + UpdateTaskDropPartitionMsg(ctx context.Context, msg TaskDropPartitionMsg) (bool, error) + GetTaskDropPartitionMsg(ctx context.Context, taskID string, msgID string) ([]TaskDropPartitionMsg, error) + RemoveTaskMsg(ctx context.Context, taskID string, msgID string) error +} diff --git a/core/api/task_msg.go b/core/api/task_msg.go new file mode 100644 index 00000000..d7f3c416 --- /dev/null +++ b/core/api/task_msg.go @@ -0,0 +1,145 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * // + * http://www.apache.org/licenses/LICENSE-2.0 + * // + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
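
The two interfaces introduced above in core/api/replicate_meta.go (ReplicateStore and ReplicateMeta) are the contract the rest of this change builds on. As an illustration only — not part of this patch — a map-backed ReplicateStore makes the Get/Put/Remove semantics concrete, including the withPrefix behavior that the etcd-backed store later in this diff implements with clientv3.WithPrefix:

```go
// Illustrative sketch only; not part of this change.
package metasketch

import (
	"context"
	"strings"
	"sync"

	"github.com/zilliztech/milvus-cdc/core/api"
)

// MemoryReplicateStore is a hypothetical in-memory api.ReplicateStore,
// shown only to make the interface contract concrete.
type MemoryReplicateStore struct {
	mu   sync.RWMutex
	data map[string]api.MetaMsg
}

func NewMemoryReplicateStore() *MemoryReplicateStore {
	return &MemoryReplicateStore{data: make(map[string]api.MetaMsg)}
}

func (s *MemoryReplicateStore) Get(ctx context.Context, key string, withPrefix bool) ([]api.MetaMsg, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()
	var result []api.MetaMsg
	for k, v := range s.data {
		// withPrefix mirrors etcd's clientv3.WithPrefix() used by the real store.
		if (withPrefix && strings.HasPrefix(k, key)) || (!withPrefix && k == key) {
			result = append(result, v)
		}
	}
	return result, nil
}

func (s *MemoryReplicateStore) Put(ctx context.Context, key string, value api.MetaMsg) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.data[key] = value
	return nil
}

func (s *MemoryReplicateStore) Remove(ctx context.Context, key string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.data, key)
	return nil
}

var _ api.ReplicateStore = (*MemoryReplicateStore)(nil)
```
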
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package api + +import ( + "encoding/json" + "fmt" + "sort" + + "github.com/cockroachdb/errors" + "github.com/mitchellh/mapstructure" +) + +type MetaMsgType int + +const ( + DropCollectionMetaMsgType MetaMsgType = iota + 1 + DropPartitionMetaMsgType +) + +type BaseTaskMsg struct { + TaskID string `json:"task_id"` + MsgID string `json:"msg_id"` + TargetChannels []string `json:"target_channels"` + ReadyChannels []string `json:"ready_channels"` +} + +func (msg BaseTaskMsg) IsReady() bool { + if len(msg.TargetChannels) != len(msg.ReadyChannels) { + return false + } + sort.Strings(msg.TargetChannels) + sort.Strings(msg.ReadyChannels) + for i := range msg.TargetChannels { + if msg.TargetChannels[i] != msg.ReadyChannels[i] { + return false + } + } + return true +} + +type MetaMsg struct { + Base BaseTaskMsg `json:"base"` + Type MetaMsgType `json:"type"` + Data map[string]interface{} `json:"data"` +} + +func (msg MetaMsg) ToJSON() (string, error) { + bs, err := json.Marshal(msg) + if err != nil { + return "", err + } + return string(bs), nil +} + +type TaskDropCollectionMsg struct { + Base BaseTaskMsg `mapstructure:"-"` + DatabaseName string `mapstructure:"database_name"` + CollectionName string `mapstructure:"collection_name"` + DropTS uint64 `mapstructure:"drop_ts"` +} + +func (msg TaskDropCollectionMsg) ConvertToMetaMsg() (MetaMsg, error) { + var m map[string]interface{} + err := mapstructure.Decode(msg, &m) + if err != nil { + return MetaMsg{}, err + } + return MetaMsg{ + Base: msg.Base, + Type: DropCollectionMetaMsgType, + Data: m, + }, nil +} + +func GetTaskDropCollectionMsg(msg MetaMsg) (TaskDropCollectionMsg, error) { + if msg.Type != DropCollectionMetaMsgType { + return TaskDropCollectionMsg{}, errors.Newf("type %d is not DropCollectionMetaMsg", msg.Type) + } + var m TaskDropCollectionMsg + err := mapstructure.Decode(msg.Data, &m) + if err != nil { + return TaskDropCollectionMsg{}, err + } + m.Base = msg.Base + return m, nil +} + +func GetDropCollectionMsgID(collectionID int64) string { + return fmt.Sprintf("drop-collection-%d", collectionID) +} + +type TaskDropPartitionMsg struct { + Base BaseTaskMsg `mapstructure:"-"` + DatabaseName string `mapstructure:"database_name"` + CollectionName string `mapstructure:"collection_name"` + PartitionName string `mapstructure:"partition_name"` + DropTS uint64 `mapstructure:"drop_ts"` +} + +func (msg TaskDropPartitionMsg) ConvertToMetaMsg() (MetaMsg, error) { + var m map[string]interface{} + err := mapstructure.Decode(msg, &m) + if err != nil { + return MetaMsg{}, err + } + return MetaMsg{ + Base: msg.Base, + Type: DropPartitionMetaMsgType, + Data: m, + }, nil +} + +func GetTaskDropPartitionMsg(msg MetaMsg) (TaskDropPartitionMsg, error) { + if msg.Type != DropPartitionMetaMsgType { + return TaskDropPartitionMsg{}, errors.Newf("type %d is not DropPartitionMetaMsg", msg.Type) + } + var m TaskDropPartitionMsg + err := mapstructure.Decode(msg.Data, &m) + if err != nil { + return TaskDropPartitionMsg{}, err + } + m.Base = msg.Base + return m, nil +} + +func GetDropPartitionMsgID(collectionID int64, partitionID int64) string { + return fmt.Sprintf("drop-partition-%d-%d", collectionID, partitionID) +} diff --git a/core/api/task_msg_test.go b/core/api/task_msg_test.go new file mode 100644 index 00000000..dc721b86 --- /dev/null +++ b/core/api/task_msg_test.go @@ -0,0 +1,94 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * // + * http://www.apache.org/licenses/LICENSE-2.0 + * // + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTaskDropCollectionMsg(t *testing.T) { + dropCollectionMsg := TaskDropCollectionMsg{ + Base: BaseTaskMsg{ + TaskID: "1001", + MsgID: "1002", + TargetChannels: []string{ + "c1", "c2", + }, + ReadyChannels: []string{ + "c1", + }, + }, + CollectionName: "test_collection", + DatabaseName: "test_db", + } + + metaMsg, err := dropCollectionMsg.ConvertToMetaMsg() + if err != nil { + t.Errorf("TaskDropCollectionMsg.ConvertToMetaMsg() failed: %v", err) + } + assert.Equal(t, DropCollectionMetaMsgType, metaMsg.Type) + assert.Equal(t, "1001", metaMsg.Base.TaskID) + assert.Equal(t, "1002", metaMsg.Base.MsgID) + assert.Equal(t, []string{"c1", "c2"}, metaMsg.Base.TargetChannels) + assert.Equal(t, []string{"c1"}, metaMsg.Base.ReadyChannels) + assert.Equal(t, "test_collection", metaMsg.Data["collection_name"]) + assert.Equal(t, "test_db", metaMsg.Data["database_name"]) + assert.Equal(t, 3, len(metaMsg.Data)) + + // test convert to TaskDropCollectionMsg + taskMsg, err := GetTaskDropCollectionMsg(metaMsg) + if err != nil { + t.Errorf("GetTaskDropCollectionMsg() failed: %v", err) + } + assert.Equal(t, "1001", taskMsg.Base.TaskID) + assert.Equal(t, "1002", taskMsg.Base.MsgID) + assert.Equal(t, []string{"c1", "c2"}, taskMsg.Base.TargetChannels) + assert.Equal(t, []string{"c1"}, taskMsg.Base.ReadyChannels) + assert.Equal(t, "test_collection", taskMsg.CollectionName) + assert.Equal(t, "test_db", taskMsg.DatabaseName) +} + +func TestMetaMsgToJson(t *testing.T) { + metaMsg := MetaMsg{ + Base: BaseTaskMsg{ + TaskID: "1001", + MsgID: "1002", + TargetChannels: []string{ + "c1", "c2", + }, + ReadyChannels: []string{ + "c1", + }, + }, + Type: DropCollectionMetaMsgType, + Data: map[string]interface{}{ + "collection_name": "test_collection", + "database_name": "test_db", + }, + } + + jsonStr, err := metaMsg.ToJSON() + if err != nil { + t.Errorf("MetaMsg.ToJSON() failed: %v", err) + } + assert.NotEmpty(t, jsonStr) + t.Logf("jsonStr = %s", jsonStr) +} diff --git a/core/api/writer.go b/core/api/writer.go index e0f70f2b..efb84930 100644 --- a/core/api/writer.go +++ b/core/api/writer.go @@ -30,6 +30,7 @@ type Writer interface { HandleReplicateAPIEvent(ctx context.Context, apiEvent *ReplicateAPIEvent) error HandleReplicateMessage(ctx context.Context, channelName string, msgPack *msgstream.MsgPack) ([]byte, []byte, error) HandleOpMessagePack(ctx context.Context, msgPack *msgstream.MsgPack) ([]byte, error) + RecoveryMetaMsg(ctx context.Context, taskID string) error } type DefaultWriter struct{} @@ -50,3 +51,8 @@ func (d *DefaultWriter) HandleOpMessagePack(ctx context.Context, msgPack *msgstr log.Warn("HandleOpMessagePack is not implemented, please check it") return nil, nil } + +func (d *DefaultWriter) 
RecoveryMetaMsg(ctx context.Context, taskID string) error { + log.Warn("RecoveryMetaMsg is not implemented, please check it") + return nil +} diff --git a/core/go.mod b/core/go.mod index e8f44d94..b8266fe7 100644 --- a/core/go.mod +++ b/core/go.mod @@ -71,6 +71,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect github.com/minio/highwayhash v1.0.2 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/mtibben/percent v0.2.1 // indirect diff --git a/core/go.sum b/core/go.sum index a0cbfafe..6001cf6d 100644 --- a/core/go.sum +++ b/core/go.sum @@ -493,6 +493,8 @@ github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0Qu github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk= diff --git a/core/meta/etcd_store.go b/core/meta/etcd_store.go new file mode 100644 index 00000000..31fb8a17 --- /dev/null +++ b/core/meta/etcd_store.go @@ -0,0 +1,91 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * // + * http://www.apache.org/licenses/LICENSE-2.0 + * // + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
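
The github.com/mitchellh/mapstructure dependency added to core/go.mod above backs the ConvertToMetaMsg / GetTaskDropCollectionMsg round trip defined in task_msg.go. A minimal sketch of that round trip, with invented task/collection names:

```go
package main

import (
	"fmt"

	"github.com/zilliztech/milvus-cdc/core/api"
)

func main() {
	src := api.TaskDropCollectionMsg{
		Base: api.BaseTaskMsg{
			TaskID:         "task-1",
			MsgID:          api.GetDropCollectionMsgID(42), // "drop-collection-42"
			TargetChannels: []string{"ch-0", "ch-1"},
			ReadyChannels:  []string{"ch-0"},
		},
		DatabaseName:   "db",
		CollectionName: "coll",
		DropTS:         100,
	}

	// Typed msg -> generic MetaMsg: the payload becomes a map via mapstructure,
	// while Base is carried separately (it is tagged `mapstructure:"-"`).
	m, err := src.ConvertToMetaMsg()
	if err != nil {
		panic(err)
	}
	js, _ := m.ToJSON() // this JSON string is what the ReplicateStore persists
	fmt.Println(js)

	// Generic MetaMsg -> typed msg again; Base is restored verbatim.
	back, err := api.GetTaskDropCollectionMsg(m)
	if err != nil {
		panic(err)
	}
	fmt.Println(back.CollectionName, back.Base.IsReady()) // "coll false" until ch-1 also reports
}
```
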
+ */ + +package meta + +import ( + "context" + "encoding/json" + "time" + + clientv3 "go.etcd.io/etcd/client/v3" + + "github.com/zilliztech/milvus-cdc/core/api" +) + +type EtcdReplicateStore struct { + client *clientv3.Client + rootPath string +} + +func NewEtcdReplicateStore(endpoints []string, rootPath string) (*EtcdReplicateStore, error) { + cfg := clientv3.Config{ + Endpoints: endpoints, + DialTimeout: 5 * time.Second, + } + client, err := clientv3.New(cfg) + if err != nil { + return nil, err + } + return &EtcdReplicateStore{ + client: client, + rootPath: rootPath, + }, nil +} + +func (s *EtcdReplicateStore) Get(ctx context.Context, key string, withPrefix bool) ([]api.MetaMsg, error) { + var result []api.MetaMsg + + var resp *clientv3.GetResponse + var err error + + if withPrefix { + resp, err = s.client.Get(ctx, s.rootPath+"/"+key, clientv3.WithPrefix()) + } else { + resp, err = s.client.Get(ctx, s.rootPath+"/"+key) + } + + if err != nil { + return nil, err + } + + for _, kv := range resp.Kvs { + var msg api.MetaMsg + if err := json.Unmarshal(kv.Value, &msg); err != nil { + return nil, err + } + result = append(result, msg) + } + + return result, nil +} + +func (s *EtcdReplicateStore) Put(ctx context.Context, key string, value api.MetaMsg) error { + data, err := json.Marshal(value) + if err != nil { + return err + } + + _, err = s.client.Put(ctx, s.rootPath+"/"+key, string(data)) + return err +} + +func (s *EtcdReplicateStore) Remove(ctx context.Context, key string) error { + _, err := s.client.Delete(ctx, s.rootPath+"/"+key) + return err +} diff --git a/core/meta/meta.go b/core/meta/meta.go new file mode 100644 index 00000000..2414c705 --- /dev/null +++ b/core/meta/meta.go @@ -0,0 +1,252 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * // + * http://www.apache.org/licenses/LICENSE-2.0 + * // + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
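
A usage sketch for the EtcdReplicateStore above. The endpoint and root path are placeholders, and the key follows the task_msg/<taskID>/<msgID> layout that GetMetaKey in core/meta/meta.go produces further down in this diff:

```go
package main

import (
	"context"
	"fmt"

	"github.com/zilliztech/milvus-cdc/core/api"
	"github.com/zilliztech/milvus-cdc/core/meta"
)

func main() {
	ctx := context.Background()

	// Placeholder endpoint/root path; any reachable etcd works for the sketch.
	store, err := meta.NewEtcdReplicateStore([]string{"127.0.0.1:2379"}, "cdc-replicate")
	if err != nil {
		panic(err)
	}

	msg := api.MetaMsg{
		Base: api.BaseTaskMsg{TaskID: "task-1", MsgID: "drop-collection-42"},
		Type: api.DropCollectionMetaMsgType,
		Data: map[string]interface{}{"collection_name": "coll"},
	}

	// Keys are stored under "<rootPath>/<key>".
	key := "task_msg/task-1/drop-collection-42"
	if err := store.Put(ctx, key, msg); err != nil {
		panic(err)
	}

	// Prefix read: everything recorded for task-1.
	msgs, err := store.Get(ctx, "task_msg/task-1/", true)
	if err != nil {
		panic(err)
	}
	fmt.Println(len(msgs)) // 1, assuming a clean prefix

	_ = store.Remove(ctx, key)
}
```
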
+ */ + +package meta + +import ( + "context" + "strings" + "sync" + + "github.com/cockroachdb/errors" + "github.com/samber/lo" + + "github.com/zilliztech/milvus-cdc/core/api" +) + +const KeyPrefix = "task_msg" + +type ReplicateMeteImpl struct { + store api.ReplicateStore + metaLock sync.RWMutex + dropCollectionMsgs map[string]map[string]api.TaskDropCollectionMsg + dropPartitionMsgs map[string]map[string]api.TaskDropPartitionMsg +} + +var _ api.ReplicateMeta = (*ReplicateMeteImpl)(nil) + +func NewReplicateMetaImpl(store api.ReplicateStore) (*ReplicateMeteImpl, error) { + impl := &ReplicateMeteImpl{ + store: store, + dropCollectionMsgs: make(map[string]map[string]api.TaskDropCollectionMsg), + dropPartitionMsgs: make(map[string]map[string]api.TaskDropPartitionMsg), + } + err := impl.Reload() + if err != nil { + return nil, err + } + return impl, nil +} + +func (r *ReplicateMeteImpl) Reload() error { + metaMsgs, err := r.store.Get(context.Background(), "", true) + if err != nil { + return err + } + for _, msg := range metaMsgs { + switch msg.Type { + case api.DropCollectionMetaMsgType: + taskDropCollectionMsg, err := api.GetTaskDropCollectionMsg(msg) + if err != nil { + return err + } + if _, ok := r.dropCollectionMsgs[taskDropCollectionMsg.Base.TaskID]; !ok { + r.dropCollectionMsgs[taskDropCollectionMsg.Base.TaskID] = make(map[string]api.TaskDropCollectionMsg) + } + r.dropCollectionMsgs[taskDropCollectionMsg.Base.TaskID][taskDropCollectionMsg.Base.MsgID] = taskDropCollectionMsg + case api.DropPartitionMetaMsgType: + taskDropPartitionMsg, err := api.GetTaskDropPartitionMsg(msg) + if err != nil { + return err + } + if _, ok := r.dropPartitionMsgs[taskDropPartitionMsg.Base.TaskID]; !ok { + r.dropPartitionMsgs[taskDropPartitionMsg.Base.TaskID] = make(map[string]api.TaskDropPartitionMsg) + } + r.dropPartitionMsgs[taskDropPartitionMsg.Base.TaskID][taskDropPartitionMsg.Base.MsgID] = taskDropPartitionMsg + } + } + + return nil +} + +// UpdateTaskDropCollectionMsg bool: true if the msg is ready to be consumed +func (r *ReplicateMeteImpl) UpdateTaskDropCollectionMsg(ctx context.Context, msg api.TaskDropCollectionMsg) (bool, error) { + r.metaLock.Lock() + defer r.metaLock.Unlock() + taskMsgs, ok := r.dropCollectionMsgs[msg.Base.TaskID] + if !ok { + taskMsgs = map[string]api.TaskDropCollectionMsg{ + msg.Base.MsgID: msg, + } + r.dropCollectionMsgs[msg.Base.TaskID] = taskMsgs + metaMsg, err := msg.ConvertToMetaMsg() + if err != nil { + return false, err + } + err = r.store.Put(ctx, GetMetaKey(msg.Base.TaskID, msg.Base.MsgID), metaMsg) + if err != nil { + return false, err + } + return msg.Base.IsReady(), nil + } + var taskMsg api.TaskDropCollectionMsg + if taskMsg, ok = taskMsgs[msg.Base.MsgID]; !ok { + taskMsgs[msg.Base.MsgID] = msg + metaMsg, err := msg.ConvertToMetaMsg() + if err != nil { + return false, err + } + err = r.store.Put(ctx, GetMetaKey(msg.Base.TaskID, msg.Base.MsgID), metaMsg) + if err != nil { + return false, err + } + return msg.Base.IsReady(), nil + } + taskMsg.Base.ReadyChannels = lo.Union[string](taskMsg.Base.ReadyChannels, msg.Base.ReadyChannels) + metaMsg, err := taskMsg.ConvertToMetaMsg() + if err != nil { + return false, err + } + err = r.store.Put(ctx, GetMetaKey(msg.Base.TaskID, msg.Base.MsgID), metaMsg) + if err != nil { + return false, err + } + return taskMsg.Base.IsReady(), nil +} + +func (r *ReplicateMeteImpl) GetTaskDropCollectionMsg(ctx context.Context, taskID string, msgID string) ([]api.TaskDropCollectionMsg, error) { + if taskID == "" { + return nil, errors.New("taskID 
is empty") + } + r.metaLock.RLock() + defer r.metaLock.RUnlock() + if msgID == "" { + taskMsgs, ok := r.dropCollectionMsgs[taskID] + if !ok { + return nil, errors.Errorf("taskID %s not found", taskID) + } + result := make([]api.TaskDropCollectionMsg, 0, len(taskMsgs)) + for _, msg := range taskMsgs { + result = append(result, msg) + } + return result, nil + } + + if taskMsgs, ok := r.dropCollectionMsgs[taskID]; ok { + if msg, ok := taskMsgs[msgID]; ok { + return []api.TaskDropCollectionMsg{msg}, nil + } + } + return nil, errors.Errorf("taskID %s or msgID %s not found", taskID, msgID) +} + +func (r *ReplicateMeteImpl) UpdateTaskDropPartitionMsg(ctx context.Context, msg api.TaskDropPartitionMsg) (bool, error) { + r.metaLock.Lock() + defer r.metaLock.Unlock() + taskMsgs, ok := r.dropPartitionMsgs[msg.Base.TaskID] + if !ok { + taskMsgs = map[string]api.TaskDropPartitionMsg{ + msg.Base.MsgID: msg, + } + r.dropPartitionMsgs[msg.Base.TaskID] = taskMsgs + metaMsg, err := msg.ConvertToMetaMsg() + if err != nil { + return false, err + } + err = r.store.Put(ctx, GetMetaKey(msg.Base.TaskID, msg.Base.MsgID), metaMsg) + if err != nil { + return false, err + } + return msg.Base.IsReady(), nil + } + var taskMsg api.TaskDropPartitionMsg + if taskMsg, ok = taskMsgs[msg.Base.MsgID]; !ok { + taskMsgs[msg.Base.MsgID] = msg + metaMsg, err := msg.ConvertToMetaMsg() + if err != nil { + return false, err + } + err = r.store.Put(ctx, GetMetaKey(msg.Base.TaskID, msg.Base.MsgID), metaMsg) + if err != nil { + return false, err + } + return msg.Base.IsReady(), nil + } + taskMsg.Base.ReadyChannels = lo.Union[string](taskMsg.Base.ReadyChannels, msg.Base.ReadyChannels) + metaMsg, err := taskMsg.ConvertToMetaMsg() + if err != nil { + return false, err + } + err = r.store.Put(ctx, GetMetaKey(msg.Base.TaskID, msg.Base.MsgID), metaMsg) + if err != nil { + return false, err + } + return taskMsg.Base.IsReady(), nil +} + +func (r *ReplicateMeteImpl) GetTaskDropPartitionMsg(ctx context.Context, taskID string, msgID string) ([]api.TaskDropPartitionMsg, error) { + if taskID == "" { + return nil, errors.New("taskID is empty") + } + r.metaLock.RLock() + defer r.metaLock.RUnlock() + if msgID == "" { + taskMsgs, ok := r.dropPartitionMsgs[taskID] + if !ok { + return nil, errors.Errorf("taskID %s not found", taskID) + } + result := make([]api.TaskDropPartitionMsg, 0, len(taskMsgs)) + for _, msg := range taskMsgs { + result = append(result, msg) + } + return result, nil + } + if taskMsgs, ok := r.dropPartitionMsgs[taskID]; ok { + if msg, ok := taskMsgs[msgID]; ok { + return []api.TaskDropPartitionMsg{msg}, nil + } + } + return nil, errors.Errorf("taskID %s or msgID %s not found", taskID, msgID) +} + +func (r *ReplicateMeteImpl) RemoveTaskMsg(ctx context.Context, taskID string, msgID string) error { + key := GetMetaKey(taskID, msgID) + err := r.store.Remove(ctx, key) + if err != nil { + return err + } + r.metaLock.Lock() + defer r.metaLock.Unlock() + if taskMsgs, ok := r.dropCollectionMsgs[taskID]; ok { + delete(taskMsgs, msgID) + } + return nil +} + +func GetMetaKey(taskID, msgID string) string { + return KeyPrefix + "/" + taskID + "/" + msgID +} + +func GetKeyDetail(key string) (taskID, msgID string) { + details := strings.Split(key, "/") + if len(details) < 3 { + return "", "" + } + l := len(details) + return details[l-2], details[l-1] +} diff --git a/core/mocks/replicate_meta.go b/core/mocks/replicate_meta.go new file mode 100644 index 00000000..4e5a2ac2 --- /dev/null +++ b/core/mocks/replicate_meta.go @@ -0,0 +1,300 @@ +// Code 
generated by mockery v2.32.4. DO NOT EDIT. + +package mocks + +import ( + context "context" + + api "github.com/zilliztech/milvus-cdc/core/api" + + mock "github.com/stretchr/testify/mock" +) + +// ReplicateMeta is an autogenerated mock type for the ReplicateMeta type +type ReplicateMeta struct { + mock.Mock +} + +type ReplicateMeta_Expecter struct { + mock *mock.Mock +} + +func (_m *ReplicateMeta) EXPECT() *ReplicateMeta_Expecter { + return &ReplicateMeta_Expecter{mock: &_m.Mock} +} + +// GetTaskDropCollectionMsg provides a mock function with given fields: ctx, taskID, msgID +func (_m *ReplicateMeta) GetTaskDropCollectionMsg(ctx context.Context, taskID string, msgID string) ([]api.TaskDropCollectionMsg, error) { + ret := _m.Called(ctx, taskID, msgID) + + var r0 []api.TaskDropCollectionMsg + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) ([]api.TaskDropCollectionMsg, error)); ok { + return rf(ctx, taskID, msgID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) []api.TaskDropCollectionMsg); ok { + r0 = rf(ctx, taskID, msgID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]api.TaskDropCollectionMsg) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, taskID, msgID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReplicateMeta_GetTaskDropCollectionMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTaskDropCollectionMsg' +type ReplicateMeta_GetTaskDropCollectionMsg_Call struct { + *mock.Call +} + +// GetTaskDropCollectionMsg is a helper method to define mock.On call +// - ctx context.Context +// - taskID string +// - msgID string +func (_e *ReplicateMeta_Expecter) GetTaskDropCollectionMsg(ctx interface{}, taskID interface{}, msgID interface{}) *ReplicateMeta_GetTaskDropCollectionMsg_Call { + return &ReplicateMeta_GetTaskDropCollectionMsg_Call{Call: _e.mock.On("GetTaskDropCollectionMsg", ctx, taskID, msgID)} +} + +func (_c *ReplicateMeta_GetTaskDropCollectionMsg_Call) Run(run func(ctx context.Context, taskID string, msgID string)) *ReplicateMeta_GetTaskDropCollectionMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *ReplicateMeta_GetTaskDropCollectionMsg_Call) Return(_a0 []api.TaskDropCollectionMsg, _a1 error) *ReplicateMeta_GetTaskDropCollectionMsg_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ReplicateMeta_GetTaskDropCollectionMsg_Call) RunAndReturn(run func(context.Context, string, string) ([]api.TaskDropCollectionMsg, error)) *ReplicateMeta_GetTaskDropCollectionMsg_Call { + _c.Call.Return(run) + return _c +} + +// GetTaskDropPartitionMsg provides a mock function with given fields: ctx, taskID, msgID +func (_m *ReplicateMeta) GetTaskDropPartitionMsg(ctx context.Context, taskID string, msgID string) ([]api.TaskDropPartitionMsg, error) { + ret := _m.Called(ctx, taskID, msgID) + + var r0 []api.TaskDropPartitionMsg + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) ([]api.TaskDropPartitionMsg, error)); ok { + return rf(ctx, taskID, msgID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) []api.TaskDropPartitionMsg); ok { + r0 = rf(ctx, taskID, msgID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]api.TaskDropPartitionMsg) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, taskID, msgID) + } else { + 
r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReplicateMeta_GetTaskDropPartitionMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTaskDropPartitionMsg' +type ReplicateMeta_GetTaskDropPartitionMsg_Call struct { + *mock.Call +} + +// GetTaskDropPartitionMsg is a helper method to define mock.On call +// - ctx context.Context +// - taskID string +// - msgID string +func (_e *ReplicateMeta_Expecter) GetTaskDropPartitionMsg(ctx interface{}, taskID interface{}, msgID interface{}) *ReplicateMeta_GetTaskDropPartitionMsg_Call { + return &ReplicateMeta_GetTaskDropPartitionMsg_Call{Call: _e.mock.On("GetTaskDropPartitionMsg", ctx, taskID, msgID)} +} + +func (_c *ReplicateMeta_GetTaskDropPartitionMsg_Call) Run(run func(ctx context.Context, taskID string, msgID string)) *ReplicateMeta_GetTaskDropPartitionMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *ReplicateMeta_GetTaskDropPartitionMsg_Call) Return(_a0 []api.TaskDropPartitionMsg, _a1 error) *ReplicateMeta_GetTaskDropPartitionMsg_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ReplicateMeta_GetTaskDropPartitionMsg_Call) RunAndReturn(run func(context.Context, string, string) ([]api.TaskDropPartitionMsg, error)) *ReplicateMeta_GetTaskDropPartitionMsg_Call { + _c.Call.Return(run) + return _c +} + +// RemoveTaskMsg provides a mock function with given fields: ctx, taskID, msgID +func (_m *ReplicateMeta) RemoveTaskMsg(ctx context.Context, taskID string, msgID string) error { + ret := _m.Called(ctx, taskID, msgID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, taskID, msgID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ReplicateMeta_RemoveTaskMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveTaskMsg' +type ReplicateMeta_RemoveTaskMsg_Call struct { + *mock.Call +} + +// RemoveTaskMsg is a helper method to define mock.On call +// - ctx context.Context +// - taskID string +// - msgID string +func (_e *ReplicateMeta_Expecter) RemoveTaskMsg(ctx interface{}, taskID interface{}, msgID interface{}) *ReplicateMeta_RemoveTaskMsg_Call { + return &ReplicateMeta_RemoveTaskMsg_Call{Call: _e.mock.On("RemoveTaskMsg", ctx, taskID, msgID)} +} + +func (_c *ReplicateMeta_RemoveTaskMsg_Call) Run(run func(ctx context.Context, taskID string, msgID string)) *ReplicateMeta_RemoveTaskMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string)) + }) + return _c +} + +func (_c *ReplicateMeta_RemoveTaskMsg_Call) Return(_a0 error) *ReplicateMeta_RemoveTaskMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *ReplicateMeta_RemoveTaskMsg_Call) RunAndReturn(run func(context.Context, string, string) error) *ReplicateMeta_RemoveTaskMsg_Call { + _c.Call.Return(run) + return _c +} + +// UpdateTaskDropCollectionMsg provides a mock function with given fields: ctx, msg +func (_m *ReplicateMeta) UpdateTaskDropCollectionMsg(ctx context.Context, msg api.TaskDropCollectionMsg) (bool, error) { + ret := _m.Called(ctx, msg) + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, api.TaskDropCollectionMsg) (bool, error)); ok { + return rf(ctx, msg) + } + if rf, ok := ret.Get(0).(func(context.Context, api.TaskDropCollectionMsg) bool); ok { + r0 = rf(ctx, msg) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := 
ret.Get(1).(func(context.Context, api.TaskDropCollectionMsg) error); ok { + r1 = rf(ctx, msg) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReplicateMeta_UpdateTaskDropCollectionMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateTaskDropCollectionMsg' +type ReplicateMeta_UpdateTaskDropCollectionMsg_Call struct { + *mock.Call +} + +// UpdateTaskDropCollectionMsg is a helper method to define mock.On call +// - ctx context.Context +// - msg api.TaskDropCollectionMsg +func (_e *ReplicateMeta_Expecter) UpdateTaskDropCollectionMsg(ctx interface{}, msg interface{}) *ReplicateMeta_UpdateTaskDropCollectionMsg_Call { + return &ReplicateMeta_UpdateTaskDropCollectionMsg_Call{Call: _e.mock.On("UpdateTaskDropCollectionMsg", ctx, msg)} +} + +func (_c *ReplicateMeta_UpdateTaskDropCollectionMsg_Call) Run(run func(ctx context.Context, msg api.TaskDropCollectionMsg)) *ReplicateMeta_UpdateTaskDropCollectionMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(api.TaskDropCollectionMsg)) + }) + return _c +} + +func (_c *ReplicateMeta_UpdateTaskDropCollectionMsg_Call) Return(_a0 bool, _a1 error) *ReplicateMeta_UpdateTaskDropCollectionMsg_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ReplicateMeta_UpdateTaskDropCollectionMsg_Call) RunAndReturn(run func(context.Context, api.TaskDropCollectionMsg) (bool, error)) *ReplicateMeta_UpdateTaskDropCollectionMsg_Call { + _c.Call.Return(run) + return _c +} + +// UpdateTaskDropPartitionMsg provides a mock function with given fields: ctx, msg +func (_m *ReplicateMeta) UpdateTaskDropPartitionMsg(ctx context.Context, msg api.TaskDropPartitionMsg) (bool, error) { + ret := _m.Called(ctx, msg) + + var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, api.TaskDropPartitionMsg) (bool, error)); ok { + return rf(ctx, msg) + } + if rf, ok := ret.Get(0).(func(context.Context, api.TaskDropPartitionMsg) bool); ok { + r0 = rf(ctx, msg) + } else { + r0 = ret.Get(0).(bool) + } + + if rf, ok := ret.Get(1).(func(context.Context, api.TaskDropPartitionMsg) error); ok { + r1 = rf(ctx, msg) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReplicateMeta_UpdateTaskDropPartitionMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateTaskDropPartitionMsg' +type ReplicateMeta_UpdateTaskDropPartitionMsg_Call struct { + *mock.Call +} + +// UpdateTaskDropPartitionMsg is a helper method to define mock.On call +// - ctx context.Context +// - msg api.TaskDropPartitionMsg +func (_e *ReplicateMeta_Expecter) UpdateTaskDropPartitionMsg(ctx interface{}, msg interface{}) *ReplicateMeta_UpdateTaskDropPartitionMsg_Call { + return &ReplicateMeta_UpdateTaskDropPartitionMsg_Call{Call: _e.mock.On("UpdateTaskDropPartitionMsg", ctx, msg)} +} + +func (_c *ReplicateMeta_UpdateTaskDropPartitionMsg_Call) Run(run func(ctx context.Context, msg api.TaskDropPartitionMsg)) *ReplicateMeta_UpdateTaskDropPartitionMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(api.TaskDropPartitionMsg)) + }) + return _c +} + +func (_c *ReplicateMeta_UpdateTaskDropPartitionMsg_Call) Return(_a0 bool, _a1 error) *ReplicateMeta_UpdateTaskDropPartitionMsg_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ReplicateMeta_UpdateTaskDropPartitionMsg_Call) RunAndReturn(run func(context.Context, api.TaskDropPartitionMsg) (bool, error)) *ReplicateMeta_UpdateTaskDropPartitionMsg_Call { + 
_c.Call.Return(run) + return _c +} + +// NewReplicateMeta creates a new instance of ReplicateMeta. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewReplicateMeta(t interface { + mock.TestingT + Cleanup(func()) +}) *ReplicateMeta { + mock := &ReplicateMeta{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/mocks/replicate_store.go b/core/mocks/replicate_store.go new file mode 100644 index 00000000..2300e17e --- /dev/null +++ b/core/mocks/replicate_store.go @@ -0,0 +1,181 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package mocks + +import ( + context "context" + + api "github.com/zilliztech/milvus-cdc/core/api" + + mock "github.com/stretchr/testify/mock" +) + +// ReplicateStore is an autogenerated mock type for the ReplicateStore type +type ReplicateStore struct { + mock.Mock +} + +type ReplicateStore_Expecter struct { + mock *mock.Mock +} + +func (_m *ReplicateStore) EXPECT() *ReplicateStore_Expecter { + return &ReplicateStore_Expecter{mock: &_m.Mock} +} + +// Get provides a mock function with given fields: ctx, key, withPrefix +func (_m *ReplicateStore) Get(ctx context.Context, key string, withPrefix bool) ([]api.MetaMsg, error) { + ret := _m.Called(ctx, key, withPrefix) + + var r0 []api.MetaMsg + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) ([]api.MetaMsg, error)); ok { + return rf(ctx, key, withPrefix) + } + if rf, ok := ret.Get(0).(func(context.Context, string, bool) []api.MetaMsg); ok { + r0 = rf(ctx, key, withPrefix) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]api.MetaMsg) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, bool) error); ok { + r1 = rf(ctx, key, withPrefix) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReplicateStore_Get_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Get' +type ReplicateStore_Get_Call struct { + *mock.Call +} + +// Get is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - withPrefix bool +func (_e *ReplicateStore_Expecter) Get(ctx interface{}, key interface{}, withPrefix interface{}) *ReplicateStore_Get_Call { + return &ReplicateStore_Get_Call{Call: _e.mock.On("Get", ctx, key, withPrefix)} +} + +func (_c *ReplicateStore_Get_Call) Run(run func(ctx context.Context, key string, withPrefix bool)) *ReplicateStore_Get_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(bool)) + }) + return _c +} + +func (_c *ReplicateStore_Get_Call) Return(_a0 []api.MetaMsg, _a1 error) *ReplicateStore_Get_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *ReplicateStore_Get_Call) RunAndReturn(run func(context.Context, string, bool) ([]api.MetaMsg, error)) *ReplicateStore_Get_Call { + _c.Call.Return(run) + return _c +} + +// Put provides a mock function with given fields: ctx, key, value +func (_m *ReplicateStore) Put(ctx context.Context, key string, value api.MetaMsg) error { + ret := _m.Called(ctx, key, value) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, api.MetaMsg) error); ok { + r0 = rf(ctx, key, value) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ReplicateStore_Put_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Put' +type ReplicateStore_Put_Call struct { + *mock.Call 
+} + +// Put is a helper method to define mock.On call +// - ctx context.Context +// - key string +// - value api.MetaMsg +func (_e *ReplicateStore_Expecter) Put(ctx interface{}, key interface{}, value interface{}) *ReplicateStore_Put_Call { + return &ReplicateStore_Put_Call{Call: _e.mock.On("Put", ctx, key, value)} +} + +func (_c *ReplicateStore_Put_Call) Run(run func(ctx context.Context, key string, value api.MetaMsg)) *ReplicateStore_Put_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(api.MetaMsg)) + }) + return _c +} + +func (_c *ReplicateStore_Put_Call) Return(_a0 error) *ReplicateStore_Put_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *ReplicateStore_Put_Call) RunAndReturn(run func(context.Context, string, api.MetaMsg) error) *ReplicateStore_Put_Call { + _c.Call.Return(run) + return _c +} + +// Remove provides a mock function with given fields: ctx, key +func (_m *ReplicateStore) Remove(ctx context.Context, key string) error { + ret := _m.Called(ctx, key) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, key) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ReplicateStore_Remove_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Remove' +type ReplicateStore_Remove_Call struct { + *mock.Call +} + +// Remove is a helper method to define mock.On call +// - ctx context.Context +// - key string +func (_e *ReplicateStore_Expecter) Remove(ctx interface{}, key interface{}) *ReplicateStore_Remove_Call { + return &ReplicateStore_Remove_Call{Call: _e.mock.On("Remove", ctx, key)} +} + +func (_c *ReplicateStore_Remove_Call) Run(run func(ctx context.Context, key string)) *ReplicateStore_Remove_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *ReplicateStore_Remove_Call) Return(_a0 error) *ReplicateStore_Remove_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *ReplicateStore_Remove_Call) RunAndReturn(run func(context.Context, string) error) *ReplicateStore_Remove_Call { + _c.Call.Return(run) + return _c +} + +// NewReplicateStore creates a new instance of ReplicateStore. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewReplicateStore(t interface { + mock.TestingT + Cleanup(func()) +}) *ReplicateStore { + mock := &ReplicateStore{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/mocks/writer.go b/core/mocks/writer.go index 6a7986b5..830d066e 100644 --- a/core/mocks/writer.go +++ b/core/mocks/writer.go @@ -188,6 +188,49 @@ func (_c *Writer_HandleReplicateMessage_Call) RunAndReturn(run func(context.Cont return _c } +// RecoveryMetaMsg provides a mock function with given fields: ctx, taskID +func (_m *Writer) RecoveryMetaMsg(ctx context.Context, taskID string) error { + ret := _m.Called(ctx, taskID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, taskID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Writer_RecoveryMetaMsg_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RecoveryMetaMsg' +type Writer_RecoveryMetaMsg_Call struct { + *mock.Call +} + +// RecoveryMetaMsg is a helper method to define mock.On call +// - ctx context.Context +// - taskID string +func (_e *Writer_Expecter) RecoveryMetaMsg(ctx interface{}, taskID interface{}) *Writer_RecoveryMetaMsg_Call { + return &Writer_RecoveryMetaMsg_Call{Call: _e.mock.On("RecoveryMetaMsg", ctx, taskID)} +} + +func (_c *Writer_RecoveryMetaMsg_Call) Run(run func(ctx context.Context, taskID string)) *Writer_RecoveryMetaMsg_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Writer_RecoveryMetaMsg_Call) Return(_a0 error) *Writer_RecoveryMetaMsg_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Writer_RecoveryMetaMsg_Call) RunAndReturn(run func(context.Context, string) error) *Writer_RecoveryMetaMsg_Call { + _c.Call.Return(run) + return _c +} + // NewWriter creates a new instance of Writer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. 
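
The generated expecter API on the new mocks can be exercised in a test roughly like this; the scenario and key are invented for illustration:

```go
package mocks_test

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"

	"github.com/zilliztech/milvus-cdc/core/api"
	"github.com/zilliztech/milvus-cdc/core/mocks"
)

func TestReplicateStoreMockSketch(t *testing.T) {
	store := mocks.NewReplicateStore(t)

	stored := api.MetaMsg{
		Base: api.BaseTaskMsg{TaskID: "task-1", MsgID: "m-1"},
		Type: api.DropCollectionMetaMsgType,
	}

	// Expect one prefix read for task-1 and return the single stored message.
	store.EXPECT().
		Get(mock.Anything, "task_msg/task-1/", true).
		Return([]api.MetaMsg{stored}, nil).
		Once()

	msgs, err := store.Get(context.Background(), "task_msg/task-1/", true)
	assert.NoError(t, err)
	assert.Len(t, msgs, 1)
	assert.Equal(t, "task-1", msgs[0].Base.TaskID)
}
```
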
func NewWriter(t interface { diff --git a/core/model/reader.go b/core/model/reader.go index b124d42a..4e40a38a 100644 --- a/core/model/reader.go +++ b/core/model/reader.go @@ -33,12 +33,17 @@ type SourceCollectionInfo struct { ShardNum int } -// source collection info in the handler +// HandlerCollectionInfo source collection info in the handler type HandlerCollectionInfo struct { CollectionID int64 PChannel string } +type BarrierSignal struct { + Msg msgstream.TsMsg + VChannel string +} + type TargetCollectionInfo struct { DatabaseName string CollectionID int64 @@ -46,8 +51,8 @@ type TargetCollectionInfo struct { PartitionInfo map[string]int64 PChannel string VChannel string - BarrierChan *OnceWriteChan[uint64] - PartitionBarrierChan map[int64]*OnceWriteChan[uint64] // id is the source partition id + BarrierChan *OnceWriteChan[*BarrierSignal] + PartitionBarrierChan map[int64]*OnceWriteChan[*BarrierSignal] // id is the source partition id Dropped bool DroppedPartition map[int64]struct{} // id is the source partition id } diff --git a/core/reader/data_barrier.go b/core/reader/data_barrier.go index 8884b83b..60434b5b 100644 --- a/core/reader/data_barrier.go +++ b/core/reader/data_barrier.go @@ -18,17 +18,24 @@ package reader +import ( + "github.com/milvus-io/milvus/pkg/mq/msgstream" + + "github.com/zilliztech/milvus-cdc/core/model" +) + type Barrier struct { - Dest int - BarrierChan chan uint64 - CloseChan chan struct{} + Dest int + BarrierSignalChan chan *model.BarrierSignal + CloseChan chan struct{} + UpdateFunc func() } -func NewBarrier(count int, f func(msgTs uint64, b *Barrier)) *Barrier { +func NewBarrier(count int, f func(msgTs uint64, b *Barrier), u func(vchannel string, m msgstream.TsMsg)) *Barrier { barrier := &Barrier{ - Dest: count, - BarrierChan: make(chan uint64, count), - CloseChan: make(chan struct{}), + Dest: count, + BarrierSignalChan: make(chan *model.BarrierSignal, count), + CloseChan: make(chan struct{}), } go func() { @@ -37,7 +44,11 @@ func NewBarrier(count int, f func(msgTs uint64, b *Barrier)) *Barrier { for current < barrier.Dest { select { case <-barrier.CloseChan: - case msgTs = <-barrier.BarrierChan: + case signal := <-barrier.BarrierSignalChan: + if u != nil { + u(signal.VChannel, signal.Msg) + } + msgTs = signal.Msg.BeginTs() current++ } } diff --git a/core/reader/data_barrier_test.go b/core/reader/data_barrier_test.go index 80dce202..56d449bb 100644 --- a/core/reader/data_barrier_test.go +++ b/core/reader/data_barrier_test.go @@ -24,12 +24,15 @@ import ( "github.com/stretchr/testify/assert" + "github.com/milvus-io/milvus/pkg/mq/msgstream" + + "github.com/zilliztech/milvus-cdc/core/model" "github.com/zilliztech/milvus-cdc/core/util" ) func TestNewBarrier(t *testing.T) { t.Run("close", func(t *testing.T) { - b := NewBarrier(2, func(msgTs uint64, b *Barrier) {}) + b := NewBarrier(2, func(msgTs uint64, b *Barrier) {}, nil) close(b.CloseChan) }) @@ -39,9 +42,23 @@ func TestNewBarrier(t *testing.T) { b := NewBarrier(2, func(msgTs uint64, b *Barrier) { assert.EqualValues(t, 2, msgTs) isExecuted.Store(true) - }) - b.BarrierChan <- 2 - b.BarrierChan <- 2 + }, nil) + b.BarrierSignalChan <- &model.BarrierSignal{ + Msg: &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 2, + }, + }, + VChannel: "v1", + } + b.BarrierSignalChan <- &model.BarrierSignal{ + Msg: &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + BeginTimestamp: 2, + }, + }, + VChannel: "v2", + } assert.Eventually(t, isExecuted.Load, time.Second, time.Millisecond*100) }) } diff --git 
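
To make the reworked NewBarrier signature concrete: the third argument is called once per incoming BarrierSignal with the originating vchannel and message, before the signal counts toward the barrier, and the original callback still fires once every channel has reported (as the updated test above checks). A standalone sketch with invented channel names:

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus/pkg/mq/msgstream"

	"github.com/zilliztech/milvus-cdc/core/model"
	"github.com/zilliztech/milvus-cdc/core/reader"
)

func main() {
	done := make(chan struct{})

	// Barrier over two vchannels: f fires once both have reported,
	// u fires for every individual signal.
	b := reader.NewBarrier(2,
		func(msgTs uint64, b *reader.Barrier) {
			fmt.Println("all channels reached ts", msgTs)
			close(done)
		},
		func(vchannel string, m msgstream.TsMsg) {
			fmt.Println("signal from", vchannel, "at", m.BeginTs())
		},
	)

	for _, vc := range []string{"vchan-0", "vchan-1"} {
		b.BarrierSignalChan <- &model.BarrierSignal{
			Msg:      &msgstream.TimeTickMsg{BaseMsg: msgstream.BaseMsg{BeginTimestamp: 7}},
			VChannel: vc,
		}
	}
	<-done
}
```
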
a/core/reader/replicate_channel_manager.go b/core/reader/replicate_channel_manager.go index 891ed801..ceea0633 100644 --- a/core/reader/replicate_channel_manager.go +++ b/core/reader/replicate_channel_manager.go @@ -62,6 +62,7 @@ type replicateChannelManager struct { streamCreator StreamCreator targetClient api.TargetAPI metaOp api.MetaOp + replicateMeta api.ReplicateMeta retryOptions []retry.Option startReadRetryOptions []retry.Option @@ -105,6 +106,7 @@ func NewReplicateChannelManager( client api.TargetAPI, readConfig config.ReaderConfig, metaOp api.MetaOp, + replicateMeta api.ReplicateMeta, msgPackCallback func(string, *msgstream.MsgPack), downstream string, ) (api.ChannelManager, error) { @@ -114,6 +116,7 @@ func NewReplicateChannelManager( streamCreator: NewDisptachClientStreamCreator(factory, dispatchClient), targetClient: client, metaOp: metaOp, + replicateMeta: replicateMeta, retryOptions: util.GetRetryOptions(readConfig.Retry), startReadRetryOptions: util.GetRetryOptions(config.RetrySettings{ RetryTimes: readConfig.Retry.RetryTimes, @@ -354,6 +357,7 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, db *m }, ReplicateParam: api.ReplicateParam{Database: targetInfo.DatabaseName}, TaskID: taskID, + MsgID: api.GetDropCollectionMsgID(info.ID), }: r.droppedCollections.Store(info.ID, struct{}{}) for _, name := range info.PhysicalChannelNames { @@ -363,6 +367,26 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, db *m r.collectionLock.Lock() delete(r.replicateCollections, info.ID) r.collectionLock.Unlock() + }, func(vchannel string, m msgstream.TsMsg) { + dropCollectionMsg, ok := m.(*msgstream.DropCollectionMsg) + if !ok { + log.Panic("the message is not drop collection message", zap.Any("msg", m)) + } + msgID := api.GetDropCollectionMsgID(dropCollectionMsg.CollectionID) + _, err := r.replicateMeta.UpdateTaskDropCollectionMsg(ctx, api.TaskDropCollectionMsg{ + Base: api.BaseTaskMsg{ + TaskID: taskID, + MsgID: msgID, + TargetChannels: info.VirtualChannelNames, + ReadyChannels: []string{vchannel}, + }, + DatabaseName: dropCollectionMsg.DbName, + CollectionName: dropCollectionMsg.CollectionName, + DropTS: dropCollectionMsg.EndTs(), + }) + if err != nil { + log.Panic("failed to update task drop collection msg", zap.Error(err)) + } }) r.replicateCollections[info.ID] = barrier.CloseChan r.collectionLock.Unlock() @@ -384,8 +408,8 @@ func (r *replicateChannelManager) StartReadCollection(ctx context.Context, db *m PartitionInfo: targetInfo.Partitions, PChannel: targetPChannel, VChannel: targetVChannel, - BarrierChan: model.NewOnceWriteChan(barrier.BarrierChan), - PartitionBarrierChan: make(map[int64]*model.OnceWriteChan[uint64]), + BarrierChan: model.NewOnceWriteChan(barrier.BarrierSignalChan), + PartitionBarrierChan: make(map[int64]*model.OnceWriteChan[*model.BarrierSignal]), Dropped: targetInfo.Dropped, DroppedPartition: make(map[int64]struct{}), }) @@ -543,12 +567,34 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, dbInfo *mode }, ReplicateParam: api.ReplicateParam{Database: dbInfo.Name}, TaskID: taskID, + MsgID: api.GetDropPartitionMsgID(collectionID, partitionInfo.PartitionID), }: r.droppedPartitions.Store(partitionInfo.PartitionID, struct{}{}) for _, handler := range handlers { handler.RemovePartitionInfo(collectionID, partitionInfo.PartitionName, partitionInfo.PartitionID) } } + }, func(vchannel string, m msgstream.TsMsg) { + dropPartitionMsg, ok := m.(*msgstream.DropPartitionMsg) + if !ok { + log.Panic("the message 
is not drop partition message", zap.Any("msg", m)) + } + msgID := api.GetDropPartitionMsgID(collectionID, partitionInfo.PartitionID) + _, err := r.replicateMeta.UpdateTaskDropPartitionMsg(ctx, api.TaskDropPartitionMsg{ + Base: api.BaseTaskMsg{ + TaskID: taskID, + MsgID: msgID, + TargetChannels: collectionInfo.VirtualChannelNames, + ReadyChannels: []string{vchannel}, + }, + DatabaseName: dropPartitionMsg.DbName, + CollectionName: dropPartitionMsg.CollectionName, + PartitionName: dropPartitionMsg.PartitionName, + DropTS: dropPartitionMsg.EndTs(), + }) + if err != nil { + log.Panic("failed to update task drop partition msg", zap.Error(err)) + } }) r.partitionLock.Lock() if _, ok := r.replicatePartitions[collectionID]; !ok { @@ -562,7 +608,7 @@ func (r *replicateChannelManager) AddPartition(ctx context.Context, dbInfo *mode r.replicatePartitions[collectionID][partitionInfo.PartitionID] = barrier.CloseChan r.partitionLock.Unlock() for _, handler := range handlers { - err = handler.AddPartitionInfo(taskID, collectionInfo, partitionInfo, barrier.BarrierChan) + err = handler.AddPartitionInfo(taskID, collectionInfo, partitionInfo, barrier.BarrierSignalChan) if err != nil { return err } @@ -1011,7 +1057,7 @@ func (r *replicateChannelHandler) RemoveCollection(collectionID int64) { log.Info("remove collection from handler", zap.Int64("collection_id", collectionID)) } -func (r *replicateChannelHandler) AddPartitionInfo(taskID string, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo, barrierChan chan<- uint64) error { +func (r *replicateChannelHandler) AddPartitionInfo(taskID string, collectionInfo *pb.CollectionInfo, partitionInfo *pb.PartitionInfo, barrierChan chan<- *model.BarrierSignal) error { collectionID := collectionInfo.ID partitionID := partitionInfo.PartitionID collectionName := collectionInfo.Schema.Name @@ -1501,7 +1547,10 @@ func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPa realMsg = msg.(*msgstream.DropCollectionMsg) realMsg.CollectionID = info.CollectionID - info.BarrierChan.Write(msg.EndTs()) + info.BarrierChan.Write(&model.BarrierSignal{ + Msg: msg, + VChannel: info.VChannel, + }) needTsMsg = true r.RemoveCollection(collectionID) case *msgstream.DropPartitionMsg: @@ -1528,7 +1577,7 @@ func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPa log.Warn("invalid drop partition message, empty partition name", zap.Any("msg", msg)) continue } - var partitionBarrierChan *model.OnceWriteChan[uint64] + var partitionBarrierChan *model.OnceWriteChan[*model.BarrierSignal] retryErr := retry.Do(r.replicateCtx, func() error { err = nil r.recordLock.RLock() @@ -1561,7 +1610,10 @@ func (r *replicateChannelHandler) handlePack(forward bool, pack *msgstream.MsgPa zap.Int64("partition_id", partitionID), zap.String("partition_name", realMsg.PartitionName)) continue } - partitionBarrierChan.Write(msg.EndTs()) + partitionBarrierChan.Write(&model.BarrierSignal{ + Msg: msg, + VChannel: info.VChannel, + }) r.RemovePartitionInfo(sourceCollectionID, realMsg.PartitionName, partitionID) } } diff --git a/core/reader/replicate_channel_manager_test.go b/core/reader/replicate_channel_manager_test.go index 04f81fe4..3b5cdeed 100644 --- a/core/reader/replicate_channel_manager_test.go +++ b/core/reader/replicate_channel_manager_test.go @@ -285,9 +285,9 @@ func TestStartReadCollectionForMilvus(t *testing.T) { }, PChannel: "ttest_read_channel", VChannel: "ttest_read_channel_v0", - BarrierChan: model.NewOnceWriteChan(make(chan<- uint64)), - 
PartitionBarrierChan: map[int64]*model.OnceWriteChan[uint64]{ - 1101: model.NewOnceWriteChan(make(chan<- uint64)), + BarrierChan: model.NewOnceWriteChan(make(chan<- *model.BarrierSignal)), + PartitionBarrierChan: map[int64]*model.OnceWriteChan[*model.BarrierSignal]{ + 1101: model.NewOnceWriteChan(make(chan<- *model.BarrierSignal)), }, }) assert.NoError(t, err) @@ -449,9 +449,9 @@ func TestStartReadCollectionForKafka(t *testing.T) { }, PChannel: "kafka_ttest_read_channel", VChannel: "kafka_ttest_read_channel_v0", - BarrierChan: model.NewOnceWriteChan(make(chan<- uint64)), - PartitionBarrierChan: map[int64]*model.OnceWriteChan[uint64]{ - 1101: model.NewOnceWriteChan(make(chan<- uint64)), + BarrierChan: model.NewOnceWriteChan(make(chan<- *model.BarrierSignal)), + PartitionBarrierChan: map[int64]*model.OnceWriteChan[*model.BarrierSignal]{ + 1101: model.NewOnceWriteChan(make(chan<- *model.BarrierSignal)), }, }) assert.NoError(t, err) @@ -702,8 +702,8 @@ func TestReplicateChannelHandler(t *testing.T) { CollectionID: 2, }, &model.TargetCollectionInfo{ CollectionName: "test2", - PartitionBarrierChan: map[int64]*model.OnceWriteChan[uint64]{ - 1001: model.NewOnceWriteChan(make(chan<- uint64)), + PartitionBarrierChan: map[int64]*model.OnceWriteChan[*model.BarrierSignal]{ + 1001: model.NewOnceWriteChan(make(chan<- *model.BarrierSignal)), }, DroppedPartition: make(map[int64]struct{}), }) @@ -716,7 +716,7 @@ func TestReplicateChannelHandler(t *testing.T) { }, &pb.PartitionInfo{ PartitionID: 2001, PartitionName: "p2", - }, make(chan<- uint64)) + }, make(chan<- *model.BarrierSignal)) assert.NoError(t, err) time.Sleep(1500 * time.Millisecond) handler.RemovePartitionInfo(2, "p2", 10002) @@ -763,8 +763,8 @@ func TestReplicateChannelHandler(t *testing.T) { stream.EXPECT().Seek(mock.Anything, mock.Anything, mock.Anything).Return(nil).Once().Twice() stream.EXPECT().Chan().Return(streamChan) - barrierChan := make(chan uint64, 1) - partitionBarrierChan := make(chan uint64, 1) + barrierChan := make(chan *model.BarrierSignal, 1) + partitionBarrierChan := make(chan *model.BarrierSignal, 1) apiEventChan := make(chan *api.ReplicateAPIEvent, 10) handler, err := newReplicateChannelHandler(context.Background(), &model.SourceCollectionInfo{ CollectionID: 1, @@ -780,7 +780,7 @@ func TestReplicateChannelHandler(t *testing.T) { PChannel: "test_q", VChannel: "test_q_v1", BarrierChan: model.NewOnceWriteChan(barrierChan), - PartitionBarrierChan: map[int64]*model.OnceWriteChan[uint64]{}, + PartitionBarrierChan: map[int64]*model.OnceWriteChan[*model.BarrierSignal]{}, DroppedPartition: make(map[int64]struct{}), }, targetClient, &api.DefaultMetaOp{}, apiEventChan, &model.HandlerOpts{ Factory: factory, @@ -818,13 +818,14 @@ func TestReplicateChannelHandler(t *testing.T) { go func() { defer close(done) { + log.Info("receive timetick msg") // timetick pack replicateMsg := <-targetMsgChan pack := replicateMsg.MsgPack // assert pack assert.NotNil(t, pack) assert.EqualValues(t, 1, pack.BeginTs) - assert.EqualValues(t, 2, pack.EndTs) + assert.EqualValues(t, 3, pack.EndTs) assert.Len(t, pack.StartPositions, 1) assert.Len(t, pack.EndPositions, 1) assert.Len(t, pack.Msgs, 1) @@ -832,6 +833,7 @@ func TestReplicateChannelHandler(t *testing.T) { assert.True(t, ok, pack.Msgs[0]) } { + log.Info("receive insert msg") // insert msg replicateMsg := <-targetMsgChan pack := replicateMsg.MsgPack @@ -843,6 +845,7 @@ func TestReplicateChannelHandler(t *testing.T) { } { + log.Info("receive delete msg") // delete msg replicateMsg := <-targetMsgChan pack := 
replicateMsg.MsgPack @@ -862,6 +865,7 @@ func TestReplicateChannelHandler(t *testing.T) { } { + log.Info("receive drop partition msg") // drop partition msg replicateMsg := <-targetMsgChan pack := replicateMsg.MsgPack @@ -869,17 +873,20 @@ func TestReplicateChannelHandler(t *testing.T) { dropMsg := pack.Msgs[0].(*msgstream.DropPartitionMsg) assert.EqualValues(t, 100, dropMsg.CollectionID) assert.EqualValues(t, 100021, dropMsg.PartitionID) - assert.EqualValues(t, 2, <-partitionBarrierChan) + signal := <-partitionBarrierChan + assert.EqualValues(t, 12, signal.Msg.EndTs()) } { + log.Info("receive drop collection msg") // drop collection msg replicateMsg := <-targetMsgChan pack := replicateMsg.MsgPack assert.Len(t, pack.Msgs, 2) dropMsg := pack.Msgs[0].(*msgstream.DropCollectionMsg) assert.EqualValues(t, 100, dropMsg.CollectionID) - assert.EqualValues(t, 2, <-barrierChan) + signal := <-barrierChan + assert.EqualValues(t, 14, signal.Msg.EndTs()) } }() @@ -887,7 +894,7 @@ func TestReplicateChannelHandler(t *testing.T) { log.Info("create collection msg / create partition msg / timetick msg") streamChan <- &msgstream.MsgPack{ BeginTs: 1, - EndTs: 2, + EndTs: 3, StartPositions: []*msgstream.MsgPosition{ { ChannelName: "test_p", @@ -901,7 +908,7 @@ func TestReplicateChannelHandler(t *testing.T) { Msgs: []msgstream.TsMsg{ &msgstream.CreateCollectionMsg{ BaseMsg: msgstream.BaseMsg{ - BeginTimestamp: 1, + BeginTimestamp: 2, EndTimestamp: 2, HashValues: []uint32{0}, }, @@ -913,7 +920,7 @@ func TestReplicateChannelHandler(t *testing.T) { }, &msgstream.CreatePartitionMsg{ BaseMsg: msgstream.BaseMsg{ - BeginTimestamp: 1, + BeginTimestamp: 2, EndTimestamp: 2, HashValues: []uint32{0}, }, @@ -925,7 +932,7 @@ func TestReplicateChannelHandler(t *testing.T) { }, &msgstream.TimeTickMsg{ BaseMsg: msgstream.BaseMsg{ - BeginTimestamp: 1, + BeginTimestamp: 2, EndTimestamp: 2, HashValues: []uint32{0}, }, @@ -937,7 +944,7 @@ func TestReplicateChannelHandler(t *testing.T) { }, &msgstream.TimeTickMsg{ BaseMsg: msgstream.BaseMsg{ - BeginTimestamp: 1, + BeginTimestamp: 2, EndTimestamp: 2, HashValues: []uint32{0}, }, @@ -951,8 +958,8 @@ func TestReplicateChannelHandler(t *testing.T) { } streamChan <- &msgstream.MsgPack{ - BeginTs: 1, - EndTs: 2, + BeginTs: 3, + EndTs: 5, StartPositions: []*msgstream.MsgPosition{ { ChannelName: "test_p", @@ -966,8 +973,8 @@ func TestReplicateChannelHandler(t *testing.T) { Msgs: []msgstream.TsMsg{ &msgstream.TimeTickMsg{ BaseMsg: msgstream.BaseMsg{ - BeginTimestamp: 1, - EndTimestamp: 2, + BeginTimestamp: 4, + EndTimestamp: 4, HashValues: []uint32{0}, }, TimeTickMsg: &msgpb.TimeTickMsg{ @@ -980,8 +987,8 @@ func TestReplicateChannelHandler(t *testing.T) { } streamChan <- &msgstream.MsgPack{ - BeginTs: 1, - EndTs: 2, + BeginTs: 5, + EndTs: 7, StartPositions: []*msgstream.MsgPosition{ { ChannelName: "test_p", @@ -995,8 +1002,8 @@ func TestReplicateChannelHandler(t *testing.T) { Msgs: []msgstream.TsMsg{ &msgstream.TimeTickMsg{ BaseMsg: msgstream.BaseMsg{ - BeginTimestamp: 1, - EndTimestamp: 2, + BeginTimestamp: 6, + EndTimestamp: 6, HashValues: []uint32{0}, }, TimeTickMsg: &msgpb.TimeTickMsg{ @@ -1011,8 +1018,8 @@ func TestReplicateChannelHandler(t *testing.T) { // insert msg log.Info("insert msg") streamChan <- &msgstream.MsgPack{ - BeginTs: 1, - EndTs: 2, + BeginTs: 7, + EndTs: 9, StartPositions: []*msgstream.MsgPosition{ { ChannelName: "test_p", @@ -1026,8 +1033,8 @@ func TestReplicateChannelHandler(t *testing.T) { Msgs: []msgstream.TsMsg{ &msgstream.InsertMsg{ BaseMsg: msgstream.BaseMsg{ - 
BeginTimestamp: 1, - EndTimestamp: 2, + BeginTimestamp: 8, + EndTimestamp: 8, HashValues: []uint32{0}, MsgPosition: &msgstream.MsgPosition{ChannelName: "test_p"}, }, @@ -1047,8 +1054,8 @@ func TestReplicateChannelHandler(t *testing.T) { // delete msg log.Info("delete msg") streamChan <- &msgstream.MsgPack{ - BeginTs: 1, - EndTs: 2, + BeginTs: 9, + EndTs: 11, StartPositions: []*msgstream.MsgPosition{ { ChannelName: "test_p", @@ -1062,8 +1069,8 @@ func TestReplicateChannelHandler(t *testing.T) { Msgs: []msgstream.TsMsg{ &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{ - BeginTimestamp: 1, - EndTimestamp: 1, + BeginTimestamp: 10, + EndTimestamp: 10, HashValues: []uint32{0}, MsgPosition: &msgstream.MsgPosition{ChannelName: "test_p"}, }, @@ -1077,8 +1084,8 @@ func TestReplicateChannelHandler(t *testing.T) { }, &msgstream.DeleteMsg{ BaseMsg: msgstream.BaseMsg{ - BeginTimestamp: 2, - EndTimestamp: 2, + BeginTimestamp: 11, + EndTimestamp: 11, HashValues: []uint32{0}, MsgPosition: &msgstream.MsgPosition{ChannelName: "test_p"}, }, @@ -1098,8 +1105,8 @@ func TestReplicateChannelHandler(t *testing.T) { // drop partition msg log.Info("drop partition msg") streamChan <- &msgstream.MsgPack{ - BeginTs: 1, - EndTs: 2, + BeginTs: 11, + EndTs: 13, StartPositions: []*msgstream.MsgPosition{ { ChannelName: "test_p", @@ -1113,8 +1120,8 @@ func TestReplicateChannelHandler(t *testing.T) { Msgs: []msgstream.TsMsg{ &msgstream.DropPartitionMsg{ BaseMsg: msgstream.BaseMsg{ - BeginTimestamp: 1, - EndTimestamp: 2, + BeginTimestamp: 12, + EndTimestamp: 12, HashValues: []uint32{0}, MsgPosition: &msgstream.MsgPosition{ChannelName: "test_p"}, }, @@ -1133,8 +1140,8 @@ func TestReplicateChannelHandler(t *testing.T) { // drop collection msg log.Info("drop collection msg") streamChan <- &msgstream.MsgPack{ - BeginTs: 1, - EndTs: 2, + BeginTs: 13, + EndTs: 15, StartPositions: []*msgstream.MsgPosition{ { ChannelName: "test_p", @@ -1148,8 +1155,8 @@ func TestReplicateChannelHandler(t *testing.T) { Msgs: []msgstream.TsMsg{ &msgstream.DropCollectionMsg{ BaseMsg: msgstream.BaseMsg{ - BeginTimestamp: 1, - EndTimestamp: 2, + BeginTimestamp: 14, + EndTimestamp: 14, HashValues: []uint32{0}, MsgPosition: &msgstream.MsgPosition{ChannelName: "test_p"}, }, diff --git a/core/writer/channel_writer.go b/core/writer/channel_writer.go index 0f506ad0..507d1b4c 100644 --- a/core/writer/channel_writer.go +++ b/core/writer/channel_writer.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-sdk-go/v2/entity" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/requestutil" @@ -36,6 +37,7 @@ import ( "github.com/zilliztech/milvus-cdc/core/api" "github.com/zilliztech/milvus-cdc/core/config" "github.com/zilliztech/milvus-cdc/core/log" + "github.com/zilliztech/milvus-cdc/core/pb" "github.com/zilliztech/milvus-cdc/core/util" ) @@ -49,6 +51,7 @@ type ( type ChannelWriter struct { dataHandler api.DataHandler messageManager api.MessageManager + replicateMeta api.ReplicateMeta opMessageFuncs map[commonpb.MsgType]opMessageFunc apiEventFuncs map[api.ReplicateAPIEventType]apiEventFunc @@ -64,13 +67,16 @@ type ChannelWriter struct { replicateID string } -func NewChannelWriter(dataHandler api.DataHandler, +func NewChannelWriter( + dataHandler api.DataHandler, + replicateMeta api.ReplicateMeta, writerConfig 
config.WriterConfig, droppedObjs map[string]map[string]uint64, downstream string, ) api.Writer { w := &ChannelWriter{ dataHandler: dataHandler, + replicateMeta: replicateMeta, messageManager: NewReplicateMessageManager(dataHandler, writerConfig.MessageBufferSize), retryOptions: util.GetRetryOptions(writerConfig.Retry), downstream: downstream, @@ -344,6 +350,76 @@ func (c *ChannelWriter) HandleOpMessagePack(ctx context.Context, msgPack *msgstr return endPosition.MsgID, nil } +func (c *ChannelWriter) RecoveryMetaMsg(ctx context.Context, taskID string) error { + dropCollectionMetaMsgs, err := c.replicateMeta.GetTaskDropCollectionMsg(ctx, taskID, "") + if err != nil { + log.Warn("fail to get task drop collection msg", zap.Error(err)) + return err + } + for _, metaMsg := range dropCollectionMetaMsgs { + if !metaMsg.Base.IsReady() { + continue + } + err = c.HandleReplicateAPIEvent(ctx, &api.ReplicateAPIEvent{ + EventType: api.ReplicateDropCollection, + CollectionInfo: &pb.CollectionInfo{ + Schema: &schemapb.CollectionSchema{ + Name: metaMsg.CollectionName, + }, + }, + ReplicateInfo: &commonpb.ReplicateInfo{ + IsReplicate: true, + MsgTimestamp: metaMsg.DropTS, + }, + ReplicateParam: api.ReplicateParam{ + Database: metaMsg.DatabaseName, + }, + TaskID: metaMsg.Base.TaskID, + MsgID: metaMsg.Base.MsgID, + }) + if err != nil { + log.Warn("fail to handle replicate api event", zap.Error(err)) + return err + } + } + + dropPartitionMetaMsgs, err := c.replicateMeta.GetTaskDropPartitionMsg(ctx, taskID, "") + if err != nil { + log.Warn("fail to get task drop partition msg", zap.Error(err)) + return err + } + for _, metaMsg := range dropPartitionMetaMsgs { + if !metaMsg.Base.IsReady() { + continue + } + err = c.HandleReplicateAPIEvent(ctx, &api.ReplicateAPIEvent{ + EventType: api.ReplicateDropPartition, + CollectionInfo: &pb.CollectionInfo{ + Schema: &schemapb.CollectionSchema{ + Name: metaMsg.CollectionName, + }, + }, + PartitionInfo: &pb.PartitionInfo{ + PartitionName: metaMsg.PartitionName, + }, + ReplicateInfo: &commonpb.ReplicateInfo{ + IsReplicate: true, + MsgTimestamp: metaMsg.DropTS, + }, + ReplicateParam: api.ReplicateParam{ + Database: metaMsg.DatabaseName, + }, + TaskID: metaMsg.Base.TaskID, + MsgID: metaMsg.Base.MsgID, + }) + if err != nil { + log.Warn("fail to handle replicate api event", zap.Error(err)) + return err + } + } + return nil +} + // WaitDatabaseReady wait for database ready, return value: skip the op or not, wait timeout or not func (c *ChannelWriter) WaitDatabaseReady(ctx context.Context, databaseName string, msgTs uint64, collectionName string) InfoState { if databaseName == "" || databaseName == util.DefaultDbName { @@ -537,6 +613,11 @@ func (c *ChannelWriter) dropCollection(ctx context.Context, apiEvent *api.Replic } _, dropKey := util.GetCollectionInfoKeys(collectionName, databaseName) c.collectionInfos.Store(dropKey, apiEvent.ReplicateInfo.MsgTimestamp) + err = c.replicateMeta.RemoveTaskMsg(ctx, apiEvent.TaskID, apiEvent.MsgID) + if err != nil { + log.Warn("fail to remove task msg", zap.Error(err)) + return err + } return nil } @@ -600,6 +681,11 @@ func (c *ChannelWriter) dropPartition(ctx context.Context, apiEvent *api.Replica } _, dropKey := util.GetPartitionInfoKeys(partitionName, collectionName, databaseName) c.partitionInfos.Store(dropKey, apiEvent.ReplicateInfo.MsgTimestamp) + err = c.replicateMeta.RemoveTaskMsg(ctx, apiEvent.TaskID, apiEvent.MsgID) + if err != nil { + log.Warn("fail to remove task msg", zap.Error(err)) + return err + } return nil } diff --git 
a/core/writer/channel_writer_test.go b/core/writer/channel_writer_test.go index 6af18032..c2b5153a 100644 --- a/core/writer/channel_writer_test.go +++ b/core/writer/channel_writer_test.go @@ -42,14 +42,21 @@ import ( func GetMockObjs(t *testing.T) (*mocks.DataHandler, api.Writer) { dataHandler := mocks.NewDataHandler(t) messageManager := mocks.NewMessageManager(t) - w := NewChannelWriter(dataHandler, config.WriterConfig{ - MessageBufferSize: 10, - Retry: config.RetrySettings{ - RetryTimes: 2, - InitBackOff: 1, - MaxBackOff: 1, + replicateMeta := mocks.NewReplicateMeta(t) + w := NewChannelWriter( + dataHandler, + replicateMeta, + config.WriterConfig{ + MessageBufferSize: 10, + Retry: config.RetrySettings{ + RetryTimes: 2, + InitBackOff: 1, + MaxBackOff: 1, + }, }, - }, map[string]map[string]uint64{}, "milvus") + map[string]map[string]uint64{}, + "milvus", + ) assert.NotNil(t, w) realWriter := w.(*ChannelWriter) realWriter.messageManager = messageManager diff --git a/server/api/meta_store.go b/server/api/meta_store.go index 8e18b8d1..041ce34f 100644 --- a/server/api/meta_store.go +++ b/server/api/meta_store.go @@ -21,6 +21,7 @@ package api import ( "context" + "github.com/zilliztech/milvus-cdc/core/api" "github.com/zilliztech/milvus-cdc/server/model/meta" ) @@ -37,6 +38,7 @@ type MetaStore[M any] interface { type MetaStoreFactory interface { GetTaskInfoMetaStore(ctx context.Context) MetaStore[*meta.TaskInfo] GetTaskCollectionPositionMetaStore(ctx context.Context) MetaStore[*meta.TaskCollectionPosition] + GetReplicateStore(ctx context.Context) api.ReplicateStore // Txn return commit function and error Txn(ctx context.Context) (any, func(err error) error, error) } diff --git a/server/cdc_impl.go b/server/cdc_impl.go index e9369cd0..9c38bd51 100644 --- a/server/cdc_impl.go +++ b/server/cdc_impl.go @@ -42,6 +42,7 @@ import ( "github.com/zilliztech/milvus-cdc/core/api" "github.com/zilliztech/milvus-cdc/core/config" "github.com/zilliztech/milvus-cdc/core/log" + meta2 "github.com/zilliztech/milvus-cdc/core/meta" coremodel "github.com/zilliztech/milvus-cdc/core/model" "github.com/zilliztech/milvus-cdc/core/pb" cdcreader "github.com/zilliztech/milvus-cdc/core/reader" @@ -207,6 +208,8 @@ func (e *MetaCDC) ReloadTask() { log.Warn("fail to start the task", zap.Any("task_info", taskInfo), zap.Error(err)) _ = e.pauseTaskWithReason(taskInfo.TaskID, "fail to start task, err: "+err.Error(), []meta.TaskState{}) } + // replicateEntity := e.replicateEntityMap.data[uKey] + // replicateEntity. 
} } @@ -852,6 +855,11 @@ func (e *MetaCDC) newReplicateEntity(info *meta.TaskInfo) (*ReplicateEntity, err // default value: 10 bufferSize := e.config.SourceConfig.ReadChanLen ttInterval := e.config.SourceConfig.TimeTickInterval + replicateMeta, err := meta2.NewReplicateMetaImpl(e.metaStoreFactory.GetReplicateStore(ctx)) + if err != nil { + taskLog.Warn("fail to new replicate meta", zap.Error(err)) + return nil, servererror.NewClientError("fail to new replicate meta") + } channelManager, err := cdcreader.NewReplicateChannelManager( msgTTDispatcherClient, streamFactory, @@ -863,7 +871,10 @@ func (e *MetaCDC) newReplicateEntity(info *meta.TaskInfo) (*ReplicateEntity, err SourceChannelNum: e.config.SourceConfig.ChannelNum, TargetChannelNum: info.MilvusConnectParam.ChannelNum, ReplicateID: uKey, - }, metaOp, func(s string, pack *msgstream.MsgPack) { + }, + metaOp, + replicateMeta, + func(s string, pack *msgstream.MsgPack) { replicateMetric(info.TaskID, s, pack, metrics.OPTypeRead) }, downstream) if err != nil { @@ -891,11 +902,17 @@ func (e *MetaCDC) newReplicateEntity(info *meta.TaskInfo) (*ReplicateEntity, err taskLog.Warn("fail to new the data handler", zap.Error(err)) return nil, servererror.NewClientError("fail to new the data handler, task_id: " + info.TaskID) } - writerObj := cdcwriter.NewChannelWriter(dataHandler, config.WriterConfig{ - MessageBufferSize: bufferSize, - Retry: e.config.Retry, - ReplicateID: e.config.ReplicateID, - }, metaOp.GetAllDroppedObj(), downstream) + writerObj := cdcwriter.NewChannelWriter( + dataHandler, + replicateMeta, + config.WriterConfig{ + MessageBufferSize: bufferSize, + Retry: e.config.Retry, + ReplicateID: e.config.ReplicateID, + }, + metaOp.GetAllDroppedObj(), + downstream, + ) e.replicateEntityMap.Lock() defer e.replicateEntityMap.Unlock() // TODO fubang should be fix diff --git a/server/mocks/meta_store_factory.go b/server/mocks/meta_store_factory.go index 9ecf6f3d..2667da79 100644 --- a/server/mocks/meta_store_factory.go +++ b/server/mocks/meta_store_factory.go @@ -5,11 +5,13 @@ package mocks import ( context "context" - api "github.com/zilliztech/milvus-cdc/server/api" + api "github.com/zilliztech/milvus-cdc/core/api" meta "github.com/zilliztech/milvus-cdc/server/model/meta" mock "github.com/stretchr/testify/mock" + + serverapi "github.com/zilliztech/milvus-cdc/server/api" ) // MetaStoreFactory is an autogenerated mock type for the MetaStoreFactory type @@ -25,16 +27,60 @@ func (_m *MetaStoreFactory) EXPECT() *MetaStoreFactory_Expecter { return &MetaStoreFactory_Expecter{mock: &_m.Mock} } +// GetReplicateStore provides a mock function with given fields: ctx +func (_m *MetaStoreFactory) GetReplicateStore(ctx context.Context) api.ReplicateStore { + ret := _m.Called(ctx) + + var r0 api.ReplicateStore + if rf, ok := ret.Get(0).(func(context.Context) api.ReplicateStore); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(api.ReplicateStore) + } + } + + return r0 +} + +// MetaStoreFactory_GetReplicateStore_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetReplicateStore' +type MetaStoreFactory_GetReplicateStore_Call struct { + *mock.Call +} + +// GetReplicateStore is a helper method to define mock.On call +// - ctx context.Context +func (_e *MetaStoreFactory_Expecter) GetReplicateStore(ctx interface{}) *MetaStoreFactory_GetReplicateStore_Call { + return &MetaStoreFactory_GetReplicateStore_Call{Call: _e.mock.On("GetReplicateStore", ctx)} +} + +func (_c 
*MetaStoreFactory_GetReplicateStore_Call) Run(run func(ctx context.Context)) *MetaStoreFactory_GetReplicateStore_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MetaStoreFactory_GetReplicateStore_Call) Return(_a0 api.ReplicateStore) *MetaStoreFactory_GetReplicateStore_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MetaStoreFactory_GetReplicateStore_Call) RunAndReturn(run func(context.Context) api.ReplicateStore) *MetaStoreFactory_GetReplicateStore_Call { + _c.Call.Return(run) + return _c +} + // GetTaskCollectionPositionMetaStore provides a mock function with given fields: ctx -func (_m *MetaStoreFactory) GetTaskCollectionPositionMetaStore(ctx context.Context) api.MetaStore[*meta.TaskCollectionPosition] { +func (_m *MetaStoreFactory) GetTaskCollectionPositionMetaStore(ctx context.Context) serverapi.MetaStore[*meta.TaskCollectionPosition] { ret := _m.Called(ctx) - var r0 api.MetaStore[*meta.TaskCollectionPosition] - if rf, ok := ret.Get(0).(func(context.Context) api.MetaStore[*meta.TaskCollectionPosition]); ok { + var r0 serverapi.MetaStore[*meta.TaskCollectionPosition] + if rf, ok := ret.Get(0).(func(context.Context) serverapi.MetaStore[*meta.TaskCollectionPosition]); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(api.MetaStore[*meta.TaskCollectionPosition]) + r0 = ret.Get(0).(serverapi.MetaStore[*meta.TaskCollectionPosition]) } } @@ -59,26 +105,26 @@ func (_c *MetaStoreFactory_GetTaskCollectionPositionMetaStore_Call) Run(run func return _c } -func (_c *MetaStoreFactory_GetTaskCollectionPositionMetaStore_Call) Return(_a0 api.MetaStore[*meta.TaskCollectionPosition]) *MetaStoreFactory_GetTaskCollectionPositionMetaStore_Call { +func (_c *MetaStoreFactory_GetTaskCollectionPositionMetaStore_Call) Return(_a0 serverapi.MetaStore[*meta.TaskCollectionPosition]) *MetaStoreFactory_GetTaskCollectionPositionMetaStore_Call { _c.Call.Return(_a0) return _c } -func (_c *MetaStoreFactory_GetTaskCollectionPositionMetaStore_Call) RunAndReturn(run func(context.Context) api.MetaStore[*meta.TaskCollectionPosition]) *MetaStoreFactory_GetTaskCollectionPositionMetaStore_Call { +func (_c *MetaStoreFactory_GetTaskCollectionPositionMetaStore_Call) RunAndReturn(run func(context.Context) serverapi.MetaStore[*meta.TaskCollectionPosition]) *MetaStoreFactory_GetTaskCollectionPositionMetaStore_Call { _c.Call.Return(run) return _c } // GetTaskInfoMetaStore provides a mock function with given fields: ctx -func (_m *MetaStoreFactory) GetTaskInfoMetaStore(ctx context.Context) api.MetaStore[*meta.TaskInfo] { +func (_m *MetaStoreFactory) GetTaskInfoMetaStore(ctx context.Context) serverapi.MetaStore[*meta.TaskInfo] { ret := _m.Called(ctx) - var r0 api.MetaStore[*meta.TaskInfo] - if rf, ok := ret.Get(0).(func(context.Context) api.MetaStore[*meta.TaskInfo]); ok { + var r0 serverapi.MetaStore[*meta.TaskInfo] + if rf, ok := ret.Get(0).(func(context.Context) serverapi.MetaStore[*meta.TaskInfo]); ok { r0 = rf(ctx) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(api.MetaStore[*meta.TaskInfo]) + r0 = ret.Get(0).(serverapi.MetaStore[*meta.TaskInfo]) } } @@ -103,12 +149,12 @@ func (_c *MetaStoreFactory_GetTaskInfoMetaStore_Call) Run(run func(ctx context.C return _c } -func (_c *MetaStoreFactory_GetTaskInfoMetaStore_Call) Return(_a0 api.MetaStore[*meta.TaskInfo]) *MetaStoreFactory_GetTaskInfoMetaStore_Call { +func (_c *MetaStoreFactory_GetTaskInfoMetaStore_Call) Return(_a0 serverapi.MetaStore[*meta.TaskInfo]) 
*MetaStoreFactory_GetTaskInfoMetaStore_Call { _c.Call.Return(_a0) return _c } -func (_c *MetaStoreFactory_GetTaskInfoMetaStore_Call) RunAndReturn(run func(context.Context) api.MetaStore[*meta.TaskInfo]) *MetaStoreFactory_GetTaskInfoMetaStore_Call { +func (_c *MetaStoreFactory_GetTaskInfoMetaStore_Call) RunAndReturn(run func(context.Context) serverapi.MetaStore[*meta.TaskInfo]) *MetaStoreFactory_GetTaskInfoMetaStore_Call { _c.Call.Return(run) return _c } diff --git a/server/store/etcd.go b/server/store/etcd.go index 4ebb3d3c..b5e0a489 100644 --- a/server/store/etcd.go +++ b/server/store/etcd.go @@ -27,8 +27,10 @@ import ( clientv3 "go.etcd.io/etcd/client/v3" "go.uber.org/zap" + api2 "github.com/zilliztech/milvus-cdc/core/api" "github.com/zilliztech/milvus-cdc/core/config" "github.com/zilliztech/milvus-cdc/core/log" + meta2 "github.com/zilliztech/milvus-cdc/core/meta" "github.com/zilliztech/milvus-cdc/core/util" "github.com/zilliztech/milvus-cdc/server/api" "github.com/zilliztech/milvus-cdc/server/model/meta" @@ -42,12 +44,13 @@ var ( type EtcdMetaStore struct { log *zap.Logger etcdClient *clientv3.Client + replicateStore api2.ReplicateStore taskInfoStore *TaskInfoEtcdStore taskCollectionPositionStore *TaskCollectionPositionEtcdStore txnMap map[any][]clientv3.Op } -var _ api.MetaStoreFactory = &EtcdMetaStore{} +var _ api.MetaStoreFactory = (*EtcdMetaStore)(nil) func NewEtcdMetaStoreWithAddress(ctx context.Context, endpoints []string, rootPath string) (*EtcdMetaStore, error) { return NewEtcdMetaStore(ctx, config.EtcdServerConfig{ @@ -78,12 +81,18 @@ func NewEtcdMetaStore(ctx context.Context, etcdServerConfig config.EtcdServerCon log.Warn("fail to get task collection position store") return nil, err } + replicateStore, err := meta2.NewEtcdReplicateStore(etcdServerConfig.Address, etcdServerConfig.RootPath) + if err != nil { + log.Warn("fail to get replicate store", zap.Error(err)) + return nil, err + } return &EtcdMetaStore{ log: log, etcdClient: etcdClient, taskInfoStore: taskInfoStore, taskCollectionPositionStore: taskCollectionPositionStore, + replicateStore: replicateStore, txnMap: txnMap, }, nil } @@ -96,6 +105,10 @@ func (e *EtcdMetaStore) GetTaskCollectionPositionMetaStore(ctx context.Context) return e.taskCollectionPositionStore } +func (e *EtcdMetaStore) GetReplicateStore(ctx context.Context) api2.ReplicateStore { + return e.replicateStore +} + func (e *EtcdMetaStore) Txn(ctx context.Context) (any, func(err error) error, error) { txn := e.etcdClient.Txn(ctx) commitFunc := func(err error) error { diff --git a/server/store/mysql.go b/server/store/mysql.go index 434924ee..aa032e84 100644 --- a/server/store/mysql.go +++ b/server/store/mysql.go @@ -29,6 +29,7 @@ import ( "github.com/goccy/go-json" "go.uber.org/zap" + api2 "github.com/zilliztech/milvus-cdc/core/api" "github.com/zilliztech/milvus-cdc/core/log" "github.com/zilliztech/milvus-cdc/core/util" "github.com/zilliztech/milvus-cdc/server/api" @@ -40,6 +41,7 @@ type MySQLMetaStore struct { db *sql.DB taskInfoStore *TaskInfoMysqlStore taskCollectionPositionStore *TaskCollectionPositionMysqlStore + replicateStore api2.ReplicateStore txnMap map[any]func() *sql.Tx } @@ -81,11 +83,16 @@ func (s *MySQLMetaStore) init(ctx context.Context, dataSourceName string, rootPa return err } s.txnMap = txnMap + s.replicateStore, err = NewMySQLReplicateStore(ctx, dataSourceName, rootPath) + if err != nil { + s.log.Warn("fail to create replicate store", zap.Error(err)) + return err + } return nil } -var _ api.MetaStoreFactory = &MySQLMetaStore{} +var _ 
api.MetaStoreFactory = (*MySQLMetaStore)(nil) func (s *MySQLMetaStore) GetTaskInfoMetaStore(ctx context.Context) api.MetaStore[*meta.TaskInfo] { return s.taskInfoStore @@ -95,6 +102,10 @@ func (s *MySQLMetaStore) GetTaskCollectionPositionMetaStore(ctx context.Context) return s.taskCollectionPositionStore } +func (s *MySQLMetaStore) GetReplicateStore(ctx context.Context) api2.ReplicateStore { + return s.replicateStore +} + func (s *MySQLMetaStore) Txn(ctx context.Context) (any, func(err error) error, error) { txObj, err := s.db.BeginTx(ctx, nil) if err != nil { diff --git a/server/store/mysql_replicate_store.go b/server/store/mysql_replicate_store.go new file mode 100644 index 00000000..4edf0b08 --- /dev/null +++ b/server/store/mysql_replicate_store.go @@ -0,0 +1,149 @@ +/* + * Licensed to the LF AI & Data foundation under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * // + * http://www.apache.org/licenses/LICENSE-2.0 + * // + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package store + +import ( + "context" + "database/sql" + "path" + "time" + + "github.com/goccy/go-json" + "go.uber.org/zap" + + "github.com/zilliztech/milvus-cdc/core/api" + "github.com/zilliztech/milvus-cdc/core/log" + "github.com/zilliztech/milvus-cdc/core/util" +) + +type MySQLReplicateStore struct { + log *zap.Logger + db *sql.DB + rootPath string +} + +func NewMySQLReplicateStore(ctx context.Context, dataSourceName string, rootPath string) (*MySQLReplicateStore, error) { + s := &MySQLReplicateStore{} + s.rootPath = rootPath + s.log = log.With(zap.String("meta_store", "mysql")).Logger + db, err := sql.Open("mysql", dataSourceName) + if err != nil { + s.log.Warn("fail to open mysql", zap.Error(err)) + return nil, err + } + s.db = db + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 10*time.Second) + defer cancelFunc() + err = db.PingContext(timeoutCtx) + if err != nil { + s.log.Warn("fail to ping mysql", zap.Error(err)) + return nil, err + } + + // The primary key on task_msg_key also serves the prefix lookups used by Get. + _, err = db.ExecContext(timeoutCtx, ` + CREATE TABLE IF NOT EXISTS task_msg ( + task_msg_key VARCHAR(255) NOT NULL, + task_msg_value JSON NOT NULL, + PRIMARY KEY (task_msg_key) + ) + `) + if err != nil { + s.log.Warn("fail to create table", zap.Error(err)) + return nil, err + } + + return s, nil +} + +func (s *MySQLReplicateStore) Get(ctx context.Context, key string, withPrefix bool) ([]api.MetaMsg, error) { + // key format is {task id}/{msg id} + taskMsgKey := path.Join(s.rootPath, key) + var sqlStr string + var sqlArgs []any + if withPrefix { + sqlStr = "SELECT task_msg_value FROM task_msg WHERE task_msg_key LIKE ?" + sqlArgs = append(sqlArgs, taskMsgKey+"%") + } else { + sqlStr = "SELECT task_msg_value FROM task_msg WHERE task_msg_key = ?" + sqlArgs = append(sqlArgs, taskMsgKey) + } + + rows, err := s.db.QueryContext(ctx, sqlStr, sqlArgs...)
+ if err != nil { + s.log.Warn("fail to get task msg", zap.Error(err)) + return nil, err + } + defer rows.Close() + + var result []api.MetaMsg + for rows.Next() { + var taskMsgValue string + err = rows.Scan(&taskMsgValue) + if err != nil { + s.log.Warn("fail to scan task msg", zap.Error(err)) + return nil, err + } + var msg api.MetaMsg + if err := json.Unmarshal(util.ToBytes(taskMsgValue), &msg); err != nil { + return nil, err + } + result = append(result, msg) + } + + return result, nil +} + +func (s *MySQLReplicateStore) Put(ctx context.Context, key string, value api.MetaMsg) error { + // key format is {task id}/{msg id} + data, err := json.Marshal(value) + if err != nil { + return err + } + sqlStr := "INSERT INTO task_msg (task_msg_key, task_msg_value) VALUES (?, ?) ON DUPLICATE KEY UPDATE task_msg_value = ?" + taskMsgKey := path.Join(s.rootPath, key) + defer func() { + if err != nil { + s.log.Warn("fail to put task msg", zap.Error(err)) + } + }() + _, err = s.db.ExecContext(ctx, sqlStr, taskMsgKey, util.ToString(data), util.ToString(data)) + if err != nil { + return err + } + return nil +} + +func (s *MySQLReplicateStore) Remove(ctx context.Context, key string) error { + // key format is {task id}/{msg id} + sqlStr := "DELETE FROM task_msg WHERE task_msg_key = ?" + taskMsgKey := path.Join(s.rootPath, key) + var err error + defer func() { + if err != nil { + s.log.Warn("fail to delete task msg", zap.Error(err)) + } + }() + _, err = s.db.ExecContext(ctx, sqlStr, taskMsgKey) + if err != nil { + return err + } + return nil +}
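For orientation on how the pieces above fit together: task markers are keyed as {task id}/{msg id} under the store root path, RecoveryMetaMsg replays any marker whose target channels are all ready, and dropCollection/dropPartition remove the marker once the drop has been applied downstream. The standalone sketch below is not part of the patch; the package name, task ID, collection ID, and channel names are made up for illustration. It shows how such a marker could be written and read back through the api.ReplicateStore interface.

// replicate_store_example.go - an illustrative sketch, not part of the patch.
package example

import (
	"context"
	"path"

	"github.com/zilliztech/milvus-cdc/core/api"
)

// putAndListDropCollectionMarker writes a drop-collection marker for a task and
// then lists every pending marker of that task via a prefix Get.
// The task ID, collection ID, and channel names are hypothetical example values.
func putAndListDropCollectionMarker(ctx context.Context, store api.ReplicateStore, taskID string) ([]api.MetaMsg, error) {
	msgID := api.GetDropCollectionMsgID(1001) // produces "drop-collection-1001"
	taskMsg := api.TaskDropCollectionMsg{
		Base: api.BaseTaskMsg{
			TaskID:         taskID,
			MsgID:          msgID,
			TargetChannels: []string{"ch_v0", "ch_v1"}, // channels that must report the drop before IsReady() is true
		},
		DatabaseName:   "default",
		CollectionName: "demo_collection",
		DropTS:         100,
	}
	metaMsg, err := taskMsg.ConvertToMetaMsg()
	if err != nil {
		return nil, err
	}
	// Key layout is {task id}/{msg id}; the concrete store prepends its own root path.
	if err := store.Put(ctx, path.Join(taskID, msgID), metaMsg); err != nil {
		return nil, err
	}
	// A prefix read returns every pending marker for the task, e.g. during recovery.
	return store.Get(ctx, taskID, true)
}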