From d656aab949fcefb10322b822d4e28fde80f8b14a Mon Sep 17 00:00:00 2001 From: cartersz <cartersunsz@163.com> Date: Mon, 18 Apr 2022 18:40:23 +0800 Subject: [PATCH 1/7] support lob --- drainer/translator/mysql.go | 47 ++++++--- drainer/translator/oracle.go | 4 + go.mod | 5 +- go.sum | 10 +- pkg/loader/executor.go | 48 ++++----- pkg/loader/executor_test.go | 8 +- pkg/loader/model.go | 182 ++++++++++------------------------- pkg/loader/model_test.go | 144 +++++++++++---------------- pkg/loader/util.go | 17 +++- 9 files changed, 191 insertions(+), 274 deletions(-) diff --git a/drainer/translator/mysql.go b/drainer/translator/mysql.go index 4fa7ceff5..19993bcc1 100644 --- a/drainer/translator/mysql.go +++ b/drainer/translator/mysql.go @@ -31,6 +31,8 @@ import ( const implicitColID = -1 +var destDBType = "tidb" + func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte) (names []string, args []interface{}, err error) { columns := writableColumns(table) @@ -150,10 +152,11 @@ func TiBinlogToTxn(infoGetter TableInfoGetter, schema string, table string, tiBi } dml := &loader.DML{ - Tp: loader.InsertDMLType, - Database: schema, - Table: table, - Values: make(map[string]interface{}), + Tp: loader.InsertDMLType, + Database: schema, + Table: table, + Values: make(map[string]interface{}), + DestDBType: destDBType, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -166,11 +169,12 @@ func TiBinlogToTxn(infoGetter TableInfoGetter, schema string, table string, tiBi } dml := &loader.DML{ - Tp: loader.UpdateDMLType, - Database: schema, - Table: table, - Values: make(map[string]interface{}), - OldValues: make(map[string]interface{}), + Tp: loader.UpdateDMLType, + Database: schema, + Table: table, + Values: make(map[string]interface{}), + OldValues: make(map[string]interface{}), + DestDBType: destDBType, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -185,10 +189,11 @@ func TiBinlogToTxn(infoGetter TableInfoGetter, schema string, table string, tiBi } dml := &loader.DML{ - Tp: loader.DeleteDMLType, - Database: schema, - Table: table, - Values: make(map[string]interface{}), + Tp: loader.DeleteDMLType, + Database: schema, + Table: table, + Values: make(map[string]interface{}), + DestDBType: destDBType, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -251,7 +256,13 @@ func formatData(data types.Datum, ft types.FieldType) (types.Datum, error) { } switch ft.Tp { - case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeNewDate, mysql.TypeTimestamp, mysql.TypeDuration, mysql.TypeNewDecimal, mysql.TypeJSON: + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeNewDate, mysql.TypeTimestamp, mysql.TypeNewDecimal, mysql.TypeJSON: + data = types.NewDatum(fmt.Sprintf("%v", data.GetValue())) + case mysql.TypeDuration: + //only for oracle db + if destDBType == "oracle" { + return data, errors.New("unsupport column type[time]") + } data = types.NewDatum(fmt.Sprintf("%v", data.GetValue())) case mysql.TypeEnum: data = types.NewDatum(data.GetMysqlEnum().Value) @@ -264,6 +275,14 @@ func formatData(data types.Datum, ft types.FieldType) (types.Datum, error) { return types.Datum{}, err } data = types.NewUintDatum(val) + case mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: + //only for oracle db + if destDBType == "oracle" { + stype := types.TypeToStr(ft.Tp, ft.Charset) + if stype == "blob" || stype == "tinyblob" || stype == "mediumblob" || stype == "longblob" { + data = types.NewBytesDatum(data.GetBytes()) + } + } } return data, nil diff --git a/drainer/translator/oracle.go b/drainer/translator/oracle.go index 7d6625557..2dc7ba0de 100644 --- a/drainer/translator/oracle.go +++ b/drainer/translator/oracle.go @@ -14,6 +14,7 @@ import ( // TiBinlogToOracleTxn translate the format to loader.Txn func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string, tiBinlog *tipb.Binlog, pv *tipb.PrewriteValue, shouldSkip bool, tableRouter *router.Table) (txn *loader.Txn, err error) { + destDBType = "oracle" txn = new(loader.Txn) if tiBinlog.DdlJobId > 0 { @@ -77,6 +78,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string Table: downStreamTable, Values: make(map[string]interface{}), UpColumnsInfoMap: tableIDColumnsMap[mut.GetTableId()], + DestDBType: destDBType, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -95,6 +97,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string Values: make(map[string]interface{}), OldValues: make(map[string]interface{}), UpColumnsInfoMap: tableIDColumnsMap[mut.GetTableId()], + DestDBType: destDBType, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -114,6 +117,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string Table: downStreamTable, Values: make(map[string]interface{}), UpColumnsInfoMap: tableIDColumnsMap[mut.GetTableId()], + DestDBType: destDBType, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { diff --git a/go.mod b/go.mod index e360927c0..23a4a558d 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/Shopify/sarama v1.30.0 github.com/dustin/go-humanize v1.0.0 github.com/go-sql-driver/mysql v1.6.0 - github.com/godror/godror v0.29.0 + github.com/godror/godror v0.33.0 github.com/gogo/protobuf v1.3.2 github.com/golang/mock v1.6.0 github.com/golang/protobuf v1.5.2 @@ -71,7 +71,8 @@ require ( github.com/eapache/queue v1.1.0 // indirect github.com/fatih/color v1.13.0 // indirect github.com/form3tech-oss/jwt-go v3.2.5+incompatible // indirect - github.com/go-logfmt/logfmt v0.5.0 // indirect + github.com/go-logfmt/logfmt v0.5.1 // indirect + github.com/go-logr/logr v1.2.3 // indirect github.com/go-ole/go-ole v1.2.6 // indirect github.com/godror/knownpb v0.1.0 // indirect github.com/golang/glog v1.0.0 // indirect diff --git a/go.sum b/go.sum index 540b47ad5..86a38bcd1 100644 --- a/go.sum +++ b/go.sum @@ -76,7 +76,6 @@ github.com/Shopify/sarama v1.30.0 h1:TOZL6r37xJBDEMLx4yjB77jxbZYXPaDow08TSK6vIL0 github.com/Shopify/sarama v1.30.0/go.mod h1:zujlQQx1kzHsh4jfV1USnptCQrHAEZ2Hk8fTKCulPVs= github.com/Shopify/toxiproxy/v2 v2.1.6-0.20210914104332-15ea381dcdae h1:ePgznFqEG1v3AjMklnK8H7BSc++FDSo7xfK9K7Af+0Y= github.com/Shopify/toxiproxy/v2 v2.1.6-0.20210914104332-15ea381dcdae/go.mod h1:/cvHQkZ1fst0EmZnA5dFtiQdWCNCFYzb+uE2vqVgvx0= -github.com/UNO-SOFT/knownpb v0.0.2/go.mod h1:p80FhK7Efqtw1I44+KdbwHKT2Fg2KluTHKtkGN8YXfE= github.com/VividCortex/ewma v1.1.1 h1:MnEK4VOv6n0RSY4vtRe3h11qjxL3+t0B8yOL8iMXdcM= github.com/VividCortex/ewma v1.1.1/go.mod h1:2Tkkvm3sRDVXaiyucHiACn4cqf7DpdyLvmxzcbUokwA= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -218,8 +217,11 @@ github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2 github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= -github.com/go-logfmt/logfmt v0.5.0 h1:TrB8swr/68K7m9CcGut2g3UOihhbcbiMAYiuTXdEih4= github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-logfmt/logfmt v0.5.1 h1:otpy5pqBCBZ1ng9RQ0dPu4PN7ba75Y/aA+UpowDyNVA= +github.com/go-logfmt/logfmt v0.5.1/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs= +github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= +github.com/go-logr/logr v1.2.3/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= @@ -228,8 +230,8 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/godror/godror v0.29.0 h1:J5PiWMy7glh4cZnExYk5ryAYx0c972YQUavh/ml+wlM= -github.com/godror/godror v0.29.0/go.mod h1:dwNYusI/Ug2JlbJuVvQQMhzlxVEJeq+MwaXwTYlDyC8= +github.com/godror/godror v0.33.0 h1:ZK1W7GohHVDPoLp/37U9QCSHARnYB4vVxNJya+CyWQ4= +github.com/godror/godror v0.33.0/go.mod h1:qHYnDISFm/h0vM+HDwg0LpyoLvxRKFRSwvhYF7ufjZ8= github.com/godror/knownpb v0.1.0 h1:dJPK8s/I3PQzGGaGcUStL2zIaaICNzKKAK8BzP1uLio= github.com/godror/knownpb v0.1.0/go.mod h1:4nRFbQo1dDuwKnblRXDxrfCFYeT4hjg3GjMqef58eRE= github.com/gogo/protobuf v0.0.0-20171007142547-342cbe0a0415/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= diff --git a/pkg/loader/executor.go b/pkg/loader/executor.go index 226e12d92..571f5231e 100644 --- a/pkg/loader/executor.go +++ b/pkg/loader/executor.go @@ -109,15 +109,7 @@ type tx struct { // wrap of sql.Tx.Exec() func (tx *tx) exec(query string, args ...interface{}) (gosql.Result, error) { start := time.Now() - var ( - res gosql.Result - err error - ) - if len(args) == 0 { - res, err = tx.Tx.Exec(query) - } else { - res, err = tx.Tx.Exec(query, args...) - } + res, err := tx.Tx.Exec(query, args...) if tx.queryHistogramVec != nil { tx.queryHistogramVec.WithLabelValues("exec").Observe(time.Since(start).Seconds()) } @@ -126,11 +118,7 @@ func (tx *tx) exec(query string, args ...interface{}) (gosql.Result, error) { } func (tx *tx) autoRollbackExec(query string, args ...interface{}) (res gosql.Result, err error) { - if len(args) == 0 { - res, err = tx.exec(query) - } else { - res, err = tx.exec(query, args...) - } + res, err = tx.exec(query, args...) if err != nil { log.Error("Exec fail, will rollback", zap.String("query", query), zap.Reflect("args", args), zap.Error(err)) if rbErr := tx.Rollback(); rbErr != nil { @@ -225,10 +213,10 @@ func (e *executor) bulkReplace(inserts []*DML) error { var builder strings.Builder - cols := "(" + buildColumnList(info.columns) + ")" + cols := "(" + buildColumnList(info.columns, e.destDBType) + ")" builder.WriteString("REPLACE INTO " + inserts[0].TableName() + cols + " VALUES ") - holder := fmt.Sprintf("(%s)", holderString(len(info.columns))) + holder := fmt.Sprintf("(%s)", holderString(len(info.columns), e.destDBType)) for i := 0; i < len(inserts); i++ { if i > 0 { builder.WriteByte(',') @@ -265,8 +253,8 @@ func (e *executor) oracleBulkOperation(dmls []*DML) error { return errors.Trace(err) } for _, dml := range dmls { - sql := dml.oracleSQL() - _, err = tx.autoRollbackExec(sql) + sql, args := dml.sql() + _, err = tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } @@ -463,43 +451,43 @@ func (e *executor) singleOracleExec(dmls []*DML, safeMode bool) error { for _, dml := range dmls { if safeMode && dml.Tp == UpdateDMLType { //delete old row - sql := dml.oracleDeleteSQL() + sql, args := dml.deleteSQL() log.Debug("safeMode and UpdateDMLType", zap.String("delete old", sql)) - _, err := tx.autoRollbackExec(sql) + _, err := tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } //delete new row - sql = dml.oracleDeleteNewValueSQL() + sql, args = dml.oracleDeleteNewValueSQL() log.Debug("safeMode and UpdateDMLType", zap.String("delete new old", sql)) - _, err = tx.autoRollbackExec(sql) + _, err = tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } //insert new row - sql = dml.oracleInsertSQL() + sql, args = dml.insertSQL() log.Debug("safeMode and UpdateDMLType", zap.String("insert new old", sql)) - _, err = tx.autoRollbackExec(sql) + _, err = tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } } else if safeMode && dml.Tp == InsertDMLType { - sql := dml.oracleDeleteSQL() + sql, args := dml.deleteSQL() log.Debug("safeMode and InsertDMLType", zap.String("delete sql", sql)) - _, err := tx.autoRollbackExec(sql) + _, err := tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } - sql = dml.oracleInsertSQL() + sql, args = dml.insertSQL() log.Debug("safeMode and InsertDMLType", zap.String("insert sql", sql)) - _, err = tx.autoRollbackExec(sql) + _, err = tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } } else { - sql := dml.oracleSQL() + sql, args := dml.sql() log.Debug("normal sql with no safeMode", zap.String("sql", sql)) - _, err := tx.autoRollbackExec(sql) + _, err := tx.autoRollbackExec(sql, args...) if err != nil { return errors.Trace(err) } diff --git a/pkg/loader/executor_test.go b/pkg/loader/executor_test.go index 677884d73..c99f89b2a 100644 --- a/pkg/loader/executor_test.go +++ b/pkg/loader/executor_test.go @@ -17,11 +17,12 @@ import ( "context" "database/sql" "fmt" + "regexp" + "sync/atomic" + "github.com/pingcap/tidb/parser/model" tmysql "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/types" - "regexp" - "sync/atomic" sqlmock "github.com/DATA-DOG/go-sqlmock" "github.com/go-sql-driver/mysql" @@ -247,6 +248,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { "age": { FieldType: types.FieldType{Tp: tmysql.TypeInt24}}, }, + DestDBType: "oracle", } delSQL := "DELETE FROM unicorn.users.*" insertSQL := "INSERT INTO unicorn.users.*" @@ -319,6 +321,7 @@ func (s *singleExecSuite) TestOracleSafeInsert(c *C) { "age": { FieldType: types.FieldType{Tp: tmysql.TypeInt24}}, }, + DestDBType: "oracle", } delSQL := "DELETE FROM unicorn.users.*" insertSQL := "INSERT INTO unicorn.users.*" @@ -378,6 +381,7 @@ func (s *singleExecSuite) TestOracleSafeDelete(c *C) { "age": { FieldType: types.FieldType{Tp: tmysql.TypeInt24}}, }, + DestDBType: "oracle", } delSQL := "DELETE FROM unicorn.users.*" diff --git a/pkg/loader/model.go b/pkg/loader/model.go index 06aa86546..d45255a02 100644 --- a/pkg/loader/model.go +++ b/pkg/loader/model.go @@ -19,10 +19,8 @@ import ( "strconv" "strings" - "github.com/pingcap/tidb/parser/model" - "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/log" + "github.com/pingcap/tidb/parser/model" "go.uber.org/zap" ) @@ -50,6 +48,8 @@ type DML struct { info *tableInfo UpColumnsInfoMap map[string]*model.ColumnInfo + + DestDBType string } // DDL holds the ddl info @@ -167,94 +167,75 @@ func (dml *DML) oldPrimaryKeyValues() []interface{} { // TableName returns the fully qualified name of the DML's table func (dml *DML) TableName() string { + if dml.DestDBType == "oracle" { + return fmt.Sprintf("%s.%s", dml.Database, dml.Table) + } return quoteSchema(dml.Database, dml.Table) } -// OracleTableName returns the fully qualified name of the DML's table in oracle db -func (dml *DML) OracleTableName() string { - return fmt.Sprintf("%s.%s", dml.Database, dml.Table) -} - func (dml *DML) updateSQL() (sql string, args []interface{}) { builder := new(strings.Builder) fmt.Fprintf(builder, "UPDATE %s SET ", dml.TableName()) - + colName := "" + oracleHolderPos := 1 for _, name := range dml.columnNames() { if len(args) > 0 { builder.WriteByte(',') } arg := dml.Values[name] - fmt.Fprintf(builder, "%s = ?", quoteName(name)) - args = append(args, arg) - } - - builder.WriteString(" WHERE ") - - whereArgs := dml.buildWhere(builder) - args = append(args, whereArgs...) - - builder.WriteString(" LIMIT 1") - sql = builder.String() - return -} - -func (dml *DML) oracleUpdateSQL() (sql string) { - builder := new(strings.Builder) - - fmt.Fprintf(builder, "UPDATE %s SET ", dml.OracleTableName()) - - for i, name := range dml.columnNames() { - if i > 0 { - builder.WriteByte(',') + if dml.DestDBType == "oracle" { + colName = escapeName(name) + } else { + colName = quoteName(name) } - value := dml.Values[name] - if value == nil { - fmt.Fprintf(builder, "%s = NULL", escapeName(name)) + if dml.DestDBType == "oracle" { + fmt.Fprintf(builder, "%s = :%d", colName, oracleHolderPos) + oracleHolderPos++ } else { - fmt.Fprintf(builder, "%s = %s", escapeName(name), genOracleValue(dml.UpColumnsInfoMap[name], value)) + fmt.Fprintf(builder, "%s = ?", colName) } + args = append(args, arg) } builder.WriteString(" WHERE ") - dml.buildOracleWhere(builder) - builder.WriteString(" AND rownum <=1") - + whereArgs := dml.buildWhere(builder, oracleHolderPos) + args = append(args, whereArgs...) + if dml.DestDBType == "oracle" { + builder.WriteString(" AND rownum <=1") + } else { + builder.WriteString(" LIMIT 1") + } sql = builder.String() return } -func (dml *DML) buildWhere(builder *strings.Builder) (args []interface{}) { +func (dml *DML) buildWhere(builder *strings.Builder, oracleHolderPos int) (args []interface{}) { wnames, wargs := dml.whereSlice() - for i := 0; i < len(wnames); i++ { + for i, pOracleHolderPos := 0, oracleHolderPos; i < len(wnames); i++ { if i > 0 { builder.WriteString(" AND ") } if wargs[i] == nil { - builder.WriteString(quoteName(wnames[i]) + " IS NULL") + if dml.DestDBType == "oracle" { + builder.WriteString(escapeName(wnames[i]) + " IS NULL") + } else { + builder.WriteString(quoteName(wnames[i]) + " IS NULL") + } } else { - builder.WriteString(quoteName(wnames[i]) + " = ?") + if dml.DestDBType == "oracle" { + builder.WriteString(fmt.Sprintf("%s = :%d", escapeName(wnames[i]), pOracleHolderPos)) + pOracleHolderPos++ + } else { + builder.WriteString(quoteName(wnames[i]) + " = ?") + } args = append(args, wargs[i]) } } return } -func (dml *DML) buildOracleWhere(builder *strings.Builder) { - colNames, colValues := dml.whereSlice() - for i := 0; i < len(colNames); i++ { - if i > 0 { - builder.WriteString(" AND ") - } - if colValues[i] == nil { - builder.WriteString(escapeName(colNames[i]) + " IS NULL") - } else { - builder.WriteString(fmt.Sprintf("%s = %s", escapeName(colNames[i]), genOracleValue(dml.UpColumnsInfoMap[colNames[i]], colValues[i]))) - } - } -} - func (dml *DML) whereValues(names []string) (values []interface{}) { valueMap := dml.Values if dml.Tp == UpdateDMLType { @@ -293,26 +274,21 @@ func (dml *DML) deleteSQL() (sql string, args []interface{}) { builder := new(strings.Builder) fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.TableName()) - args = dml.buildWhere(builder) - builder.WriteString(" LIMIT 1") - - sql = builder.String() - return -} + args = dml.buildWhere(builder, 1) -func (dml *DML) oracleDeleteSQL() (sql string) { - builder := new(strings.Builder) + if dml.DestDBType == "oracle" { + builder.WriteString(" AND rownum <=1") + } else { + builder.WriteString(" LIMIT 1") + } - fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.OracleTableName()) - dml.buildOracleWhere(builder) - builder.WriteString(" AND rownum <=1") sql = builder.String() return } -func (dml *DML) oracleDeleteNewValueSQL() (sql string) { +func (dml *DML) oracleDeleteNewValueSQL() (sql string, args []interface{}) { builder := new(strings.Builder) - fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.OracleTableName()) + fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.TableName()) valueMap := dml.Values colNames := make([]string, 0) @@ -343,14 +319,16 @@ func (dml *DML) oracleDeleteNewValueSQL() (sql string) { } } - for i := 0; i < len(colNames); i++ { + for i, oracleHolderPos := 0, 1; i < len(colNames); i++ { if i > 0 { builder.WriteString(" AND ") } if colValues[i] == nil { builder.WriteString(escapeName(colNames[i]) + " IS NULL") } else { - builder.WriteString(fmt.Sprintf("%s = %s", colNames[i], genOracleValue(dml.UpColumnsInfoMap[colNames[i]], colValues[i]))) + builder.WriteString(fmt.Sprintf("%s = :%d", colNames[i], oracleHolderPos)) + oracleHolderPos++ + args = append(args, colValues[i]) } } builder.WriteString(" AND rownum <=1") @@ -371,7 +349,7 @@ func (dml *DML) columnNames() []string { func (dml *DML) replaceSQL() (sql string, args []interface{}) { names := dml.columnNames() - sql = fmt.Sprintf("REPLACE INTO %s(%s) VALUES(%s)", dml.TableName(), buildColumnList(names), holderString(len(names))) + sql = fmt.Sprintf("REPLACE INTO %s(%s) VALUES(%s)", dml.TableName(), buildColumnList(names, dml.DestDBType), holderString(len(names), dml.DestDBType)) for _, name := range names { v := dml.Values[name] args = append(args, v) @@ -385,23 +363,6 @@ func (dml *DML) insertSQL() (sql string, args []interface{}) { return } -func (dml *DML) oracleInsertSQL() (sql string) { - builder := new(strings.Builder) - columns, values := dml.buildOracleInsertColAndValue() - fmt.Fprintf(builder, "INSERT INTO %s (%s) VALUES (%s)", dml.OracleTableName(), columns, values) - sql = builder.String() - return -} - -func (dml *DML) buildOracleInsertColAndValue() (string, string) { - names := dml.columnNames() - values := make([]string, 0, len(dml.Values)) - for _, name := range names { - values = append(values, genOracleValue(dml.UpColumnsInfoMap[name], dml.Values[name])) - } - return strings.Join(names, ", "), strings.Join(values, ", ") -} - func (dml *DML) sql() (sql string, args []interface{}) { switch dml.Tp { case InsertDMLType: @@ -417,21 +378,6 @@ func (dml *DML) sql() (sql string, args []interface{}) { return } -func (dml *DML) oracleSQL() (sql string) { - switch dml.Tp { - case InsertDMLType: - return dml.oracleInsertSQL() - case UpdateDMLType: - return dml.oracleUpdateSQL() - case DeleteDMLType: - return dml.oracleDeleteSQL() - } - - log.Debug("get sql for dml", zap.Reflect("dml", dml), zap.String("sql", sql)) - - return -} - func formatKey(values []interface{}) string { builder := new(strings.Builder) for i, v := range values { @@ -498,31 +444,3 @@ func getKeys(dml *DML) (keys []string) { return } - -func genOracleValue(column *model.ColumnInfo, value interface{}) string { - if value == nil { - return "NULL" - } - switch column.Tp { - case mysql.TypeDate: - return fmt.Sprintf("TO_DATE('%v', 'yyyy-mm-dd')", value) - case mysql.TypeDatetime: - if column.Decimal == 0 { - return fmt.Sprintf("TO_DATE('%v', 'yyyy-mm-dd hh24:mi:ss')", value) - } - return fmt.Sprintf("TO_TIMESTAMP('%v', 'yyyy-mm-dd hh24:mi:ss.ff%d')", value, column.Decimal) - case mysql.TypeTimestamp: - return fmt.Sprintf("TO_TIMESTAMP('%s', 'yyyy-mm-dd hh24:mi:ss.ff%d')", value, column.Decimal) - case mysql.TypeDuration: - return fmt.Sprintf("TO_DATE('%s', 'hh24:mi:ss')", value) - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeInt24, - mysql.TypeYear, mysql.TypeFloat, mysql.TypeDouble, mysql.TypeNewDecimal: - return fmt.Sprintf("%v", value) - default: - return fmt.Sprintf("'%s'", processOracleQuoteStringValue(fmt.Sprintf("%v", value))) - } -} - -func processOracleQuoteStringValue(data string) string { - return strings.ReplaceAll(data, "'", "''") -} diff --git a/pkg/loader/model_test.go b/pkg/loader/model_test.go index 3747d3b57..214ac44ec 100644 --- a/pkg/loader/model_test.go +++ b/pkg/loader/model_test.go @@ -47,6 +47,7 @@ func getDML(key bool, tp DMLType) *DML { dml.Database = "test" dml.Table = "test" dml.Tp = tp + dml.DestDBType = "tidb" return dml } @@ -76,7 +77,7 @@ func (d *dmlSuite) testWhere(c *check.C, tp DMLType) { c.Assert(args, check.DeepEquals, []interface{}{1}) builder := new(strings.Builder) - args = dml.buildWhere(builder) + args = dml.buildWhere(builder, 0) c.Assert(args, check.DeepEquals, []interface{}{1}) c.Assert(strings.Count(builder.String(), "?"), check.Equals, len(args)) @@ -94,14 +95,14 @@ func (d *dmlSuite) testWhere(c *check.C, tp DMLType) { c.Assert(args, check.DeepEquals, []interface{}{1, 1}) builder.Reset() - args = dml.buildWhere(builder) + args = dml.buildWhere(builder, 0) c.Assert(args, check.DeepEquals, []interface{}{1, 1}) c.Assert(strings.Count(builder.String(), "?"), check.Equals, len(args)) // set a1 to NULL value values["a1"] = nil builder.Reset() - args = dml.buildWhere(builder) + args = dml.buildWhere(builder, 0) c.Assert(args, check.DeepEquals, []interface{}{1}) c.Assert(strings.Count(builder.String(), "?"), check.Equals, len(args)) } @@ -188,6 +189,7 @@ func (s *SQLSuite) TestInsertSQL(c *check.C) { info: &tableInfo{ columns: []string{"name", "age"}, }, + DestDBType: "tidb", } sql, args := dml.sql() c.Assert(sql, check.Equals, "INSERT INTO `test`.`hello`(`age`,`name`) VALUES(?,?)") @@ -208,6 +210,7 @@ func (s *SQLSuite) TestDeleteSQL(c *check.C) { info: &tableInfo{ columns: []string{"name", "age"}, }, + DestDBType: "tidb", } sql, args := dml.sql() c.Assert( @@ -232,6 +235,7 @@ func (s *SQLSuite) TestUpdateSQL(c *check.C) { info: &tableInfo{ columns: []string{"name"}, }, + DestDBType: "tidb", } sql, args := dml.sql() c.Assert( @@ -289,11 +293,17 @@ func (s *SQLSuite) TestOracleUpdateSQL(c *check.C) { "NAME": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: "oracle", } - sql := dml.oracleSQL() + sql, args := dml.sql() c.Assert( sql, check.Equals, - "UPDATE db.tbl SET ID = 123,NAME = 'pc' WHERE ID = 123 AND NAME = 'pingcap' AND rownum <=1") + "UPDATE db.tbl SET ID = :1,NAME = :2 WHERE ID = :3 AND NAME = :4 AND rownum <=1") + c.Assert(args, check.HasLen, 4) + c.Assert(args[0], check.Equals, 123) + c.Assert(args[1], check.Equals, "pc") + c.Assert(args[2], check.Equals, 123) + c.Assert(args[3], check.Equals, "pingcap") } func (s *SQLSuite) TestOracleUpdateSQLPrimaryKey(c *check.C) { @@ -328,11 +338,16 @@ func (s *SQLSuite) TestOracleUpdateSQLPrimaryKey(c *check.C) { "NAME": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: "oracle", } - sql := dml.oracleSQL() + sql, args := dml.sql() c.Assert( sql, check.Equals, - "UPDATE db.tbl SET ID = 123,NAME = 'pc' WHERE ID = 123 AND rownum <=1") + "UPDATE db.tbl SET ID = :1,NAME = :2 WHERE ID = :3 AND rownum <=1") + c.Assert(args, check.HasLen, 3) + c.Assert(args[0], check.Equals, 123) + c.Assert(args[1], check.Equals, "pc") + c.Assert(args[2], check.Equals, 123) } func (s *SQLSuite) TestOracleDeleteSQL(c *check.C) { @@ -353,11 +368,15 @@ func (s *SQLSuite) TestOracleDeleteSQL(c *check.C) { "NAME": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: "oracle", } - sql := dml.oracleSQL() + sql, args := dml.sql() c.Assert( sql, check.Equals, - "DELETE FROM db.tbl WHERE ID = 123 AND NAME = 'pc' AND rownum <=1") + "DELETE FROM db.tbl WHERE ID = :1 AND NAME = :2 AND rownum <=1") + c.Assert(args, check.HasLen, 2) + c.Assert(args[0], check.Equals, 123) + c.Assert(args[1], check.Equals, "pc") } func (s *SQLSuite) TestOracleInsertSQL(c *check.C) { @@ -381,78 +400,16 @@ func (s *SQLSuite) TestOracleInsertSQL(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: "oracle", } - sql := dml.oracleSQL() + sql, args := dml.sql() c.Assert( sql, check.Equals, - "INSERT INTO db.tbl (C2, ID, NAME) VALUES (NULL, 123, 'pc')") -} - -func (s *SQLSuite) TestGenOracleValue(c *check.C) { - columnInfo := model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeDate}, - } - colVaue := "2021-09-13" - val := genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, - "TO_DATE('2021-09-13', 'yyyy-mm-dd')") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeDatetime, Decimal: 0}, - } - colVaue = "2021-09-13 10:10:23" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, - "TO_DATE('2021-09-13 10:10:23', 'yyyy-mm-dd hh24:mi:ss')") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeDatetime, Decimal: 6}, - } - colVaue = "2021-09-13 10:10:23.123456" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, - "TO_TIMESTAMP('2021-09-13 10:10:23.123456', 'yyyy-mm-dd hh24:mi:ss.ff6')") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeTimestamp, Decimal: 5}, - } - colVaue = "2021-09-13 10:10:23.12345" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, - "TO_TIMESTAMP('2021-09-13 10:10:23.12345', 'yyyy-mm-dd hh24:mi:ss.ff5')") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeYear}, - } - colVaue = "2021" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, "2021") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeVarchar}, - } - colVaue = "2021" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, "'2021'") - - columnInfo = model.ColumnInfo{ - FieldType: types.FieldType{Tp: mysql.TypeDuration}, - } - colVaue = "23:11:59" - val = genOracleValue(&columnInfo, colVaue) - c.Assert( - val, check.Equals, "TO_DATE('23:11:59', 'hh24:mi:ss')") - - var colVaue2 interface{} - val = genOracleValue(&columnInfo, colVaue2) - c.Assert( - val, check.Equals, "NULL") + "INSERT INTO db.tbl(C2,ID,NAME) VALUES(:1,:2,:3)") + c.Assert(args, check.HasLen, 3) + c.Assert(args[0], check.Equals, nil) + c.Assert(args[1], check.Equals, 123) + c.Assert(args[2], check.Equals, "pc") } func (s *SQLSuite) TestOracleDeleteNewValueSQLWithOneUK(c *check.C) { @@ -482,12 +439,15 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithOneUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: "oracle", } - sql := dml.oracleDeleteNewValueSQL() + sql, args := dml.oracleDeleteNewValueSQL() c.Assert( sql, check.Equals, - "DELETE FROM db.tbl WHERE ID = 123 AND rownum <=1") + "DELETE FROM db.tbl WHERE ID = :1 AND rownum <=1") + c.Assert(args, check.HasLen, 1) + c.Assert(args[0], check.Equals, 123) // column in UK have nil value, so fall back to all columns dml = DML{ @@ -516,11 +476,15 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithOneUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: "oracle", } - sql = dml.oracleDeleteNewValueSQL() + sql, args = dml.oracleDeleteNewValueSQL() c.Assert( sql, check.Equals, - "DELETE FROM db.tbl WHERE C2 IS NULL AND ID = 123 AND NAME = 'pc' AND rownum <=1") + "DELETE FROM db.tbl WHERE C2 IS NULL AND ID = :1 AND NAME = :2 AND rownum <=1") + c.Assert(args, check.HasLen, 2) + c.Assert(args[0], check.Equals, 123) + c.Assert(args[1], check.Equals, "pc") } func (s *SQLSuite) TestOracleDeleteNewValueSQLWithMultiUK(c *check.C) { @@ -557,12 +521,15 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithMultiUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: "oracle", } - sql := dml.oracleDeleteNewValueSQL() + sql, args := dml.oracleDeleteNewValueSQL() c.Assert( sql, check.Equals, - "DELETE FROM db.tbl WHERE ID2 = '456' AND rownum <=1") + "DELETE FROM db.tbl WHERE ID2 = :1 AND rownum <=1") + c.Assert(args, check.HasLen, 1) + c.Assert(args[0], check.Equals, "456") } func (s *SQLSuite) TestOracleDeleteNewValueSQLWithNoUK(c *check.C) { @@ -589,10 +556,15 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithNoUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, + DestDBType: "oracle", } - sql := dml.oracleDeleteNewValueSQL() + sql, args := dml.oracleDeleteNewValueSQL() c.Assert( sql, check.Equals, - "DELETE FROM db.tbl WHERE C2 IS NULL AND ID = 123 AND ID2 = '456' AND NAME = 'pc' AND rownum <=1") + "DELETE FROM db.tbl WHERE C2 IS NULL AND ID = :1 AND ID2 = :2 AND NAME = :3 AND rownum <=1") + c.Assert(args, check.HasLen, 3) + c.Assert(args[0], check.Equals, 123) + c.Assert(args[1], check.Equals, "456") + c.Assert(args[2], check.Equals, "pc") } diff --git a/pkg/loader/util.go b/pkg/loader/util.go index fc54c5224..5e9036460 100644 --- a/pkg/loader/util.go +++ b/pkg/loader/util.go @@ -230,6 +230,7 @@ func CreateOracleDB(user string, password string, host string, port int, service Timezone: loc, }, } + oraDSN.OnInitStmts = []string{"ALTER SESSION SET NLS_DATE_FORMAT='YYYY-MM-DD HH24.MI.SS' NLS_TIMESTAMP_FORMAT='YYYY-MM-DD HH24.MI.SS.FF' NLS_TIMESTAMP_TZ_FORMAT='YYYY-MM-DD HH24.MI.SS.FF TZR' NLS_TIME_FORMAT='HH24.MI.SS.FF' NLS_TIME_TZ_FORMAT='HH24.MI.SS.FF TZR'"} sqlDB := gosql.OpenDB(godror.NewConnector(oraDSN)) err = sqlDB.Ping() if err != nil { @@ -250,13 +251,17 @@ func escapeName(name string) string { return strings.Replace(name, "`", "``", -1) } -func holderString(n int) string { +func holderString(n int, destDBType string) string { builder := new(strings.Builder) for i := 0; i < n; i++ { if i > 0 { builder.WriteString(",") } - builder.WriteString("?") + if destDBType == "oracle" { + builder.WriteString(":" + strconv.Itoa(i+1)) + } else { + builder.WriteString("?") + } } return builder.String() } @@ -277,13 +282,17 @@ func splitDMLs(dmls []*DML, size int) (res [][]*DML) { return } -func buildColumnList(names []string) string { +func buildColumnList(names []string, destDBType string) string { var b strings.Builder for i, name := range names { if i > 0 { b.WriteString(",") } - b.WriteString(quoteName(name)) + if destDBType == "oracle" { + b.WriteString(escapeName(name)) + } else { + b.WriteString(quoteName(name)) + } } From 6481ff77c26c251258cbad3fcdb13b6dc26310af Mon Sep 17 00:00:00 2001 From: cartersz <cartersunsz@163.com> Date: Mon, 25 Apr 2022 18:33:45 +0800 Subject: [PATCH 2/7] fix comment --- drainer/sync/mysql.go | 9 ++++++--- drainer/translator/mysql.go | 21 +++++++++++++-------- drainer/translator/oracle.go | 2 +- pkg/loader/executor.go | 10 +++++----- pkg/loader/executor_test.go | 24 ++++++++++++------------ pkg/loader/load.go | 14 +++++++------- pkg/loader/load_test.go | 6 +++--- pkg/loader/model.go | 29 ++++++++++++++--------------- pkg/loader/model_test.go | 24 ++++++++++++------------ pkg/loader/util.go | 8 ++++---- 10 files changed, 77 insertions(+), 70 deletions(-) diff --git a/drainer/sync/mysql.go b/drainer/sync/mysql.go index 6186a390f..ce4369096 100644 --- a/drainer/sync/mysql.go +++ b/drainer/sync/mysql.go @@ -57,10 +57,13 @@ func CreateLoader( enableDispatch bool, enableCausility bool, ) (ld loader.Loader, err error) { - + destDBTypeInt := loader.MysqlDB + if destDBType == "oracle" { + destDBTypeInt = loader.OracleDB + } var opts []loader.Option - opts = append(opts, loader.DestinationDBType(destDBType), loader.WorkerCount(worker), loader.BatchSize(batchSize), - loader.SaveAppliedTS(destDBType == "tidb" || destDBType == "oracle"), loader.SetloopBackSyncInfo(info)) + opts = append(opts, loader.DestinationDBType(destDBTypeInt), loader.WorkerCount(worker), loader.BatchSize(batchSize), + loader.SaveAppliedTS(destDBTypeInt == loader.MysqlDB || destDBTypeInt == loader.OracleDB), loader.SetloopBackSyncInfo(info)) if queryHistogramVec != nil { opts = append(opts, loader.Metrics(&loader.MetricsGroup{ QueryHistogramVec: queryHistogramVec, diff --git a/drainer/translator/mysql.go b/drainer/translator/mysql.go index 19993bcc1..c079da965 100644 --- a/drainer/translator/mysql.go +++ b/drainer/translator/mysql.go @@ -31,7 +31,7 @@ import ( const implicitColID = -1 -var destDBType = "tidb" +var destDBType = loader.MysqlDB func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte) (names []string, args []interface{}, err error) { columns := writableColumns(table) @@ -260,8 +260,8 @@ func formatData(data types.Datum, ft types.FieldType) (types.Datum, error) { data = types.NewDatum(fmt.Sprintf("%v", data.GetValue())) case mysql.TypeDuration: //only for oracle db - if destDBType == "oracle" { - return data, errors.New("unsupport column type[time]") + if destDBType == loader.OracleDB { + return data, errors.New("unsupported column type[time]") } data = types.NewDatum(fmt.Sprintf("%v", data.GetValue())) case mysql.TypeEnum: @@ -277,13 +277,18 @@ func formatData(data types.Datum, ft types.FieldType) (types.Datum, error) { data = types.NewUintDatum(val) case mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: //only for oracle db - if destDBType == "oracle" { - stype := types.TypeToStr(ft.Tp, ft.Charset) - if stype == "blob" || stype == "tinyblob" || stype == "mediumblob" || stype == "longblob" { - data = types.NewBytesDatum(data.GetBytes()) - } + if destDBType == loader.OracleDB && isBlob(ft) { + data = types.NewBytesDatum(data.GetBytes()) } } return data, nil } + +func isBlob(ft types.FieldType) bool { + stype := types.TypeToStr(ft.Tp, ft.Charset) + if stype == "blob" || stype == "tinyblob" || stype == "mediumblob" || stype == "longblob" { + return true + } + return false +} diff --git a/drainer/translator/oracle.go b/drainer/translator/oracle.go index 2dc7ba0de..5cf2e2a20 100644 --- a/drainer/translator/oracle.go +++ b/drainer/translator/oracle.go @@ -14,7 +14,7 @@ import ( // TiBinlogToOracleTxn translate the format to loader.Txn func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string, tiBinlog *tipb.Binlog, pv *tipb.PrewriteValue, shouldSkip bool, tableRouter *router.Table) (txn *loader.Txn, err error) { - destDBType = "oracle" + destDBType = loader.OracleDB txn = new(loader.Txn) if tiBinlog.DdlJobId > 0 { diff --git a/pkg/loader/executor.go b/pkg/loader/executor.go index 571f5231e..3e660e984 100644 --- a/pkg/loader/executor.go +++ b/pkg/loader/executor.go @@ -43,7 +43,7 @@ var ( type executor struct { db *gosql.DB - destDBType string + destDBType int batchSize int workerCount int info *loopbacksync.LoopBackSync @@ -70,7 +70,7 @@ func (e *executor) withRefreshTableInfo(fn func(schema string, table string) (in return e } -func (e *executor) withDestDBType(destDBType string) *executor { +func (e *executor) withDestDBType(destDBType int) *executor { e.destDBType = destDBType return e } @@ -284,7 +284,7 @@ func (e *executor) execTableBatch(ctx context.Context, dmls []*DML) error { if allDeletes, ok := types[DeleteDMLType]; ok { bulkDelete := e.bulkDelete - if e.destDBType == "oracle" { + if e.destDBType == OracleDB { bulkDelete = e.oracleBulkOperation } if err := e.splitExecDML(ctx, allDeletes, bulkDelete); err != nil { @@ -294,7 +294,7 @@ func (e *executor) execTableBatch(ctx context.Context, dmls []*DML) error { if allInserts, ok := types[InsertDMLType]; ok { bulkInsert := e.bulkReplace - if e.destDBType == "oracle" { + if e.destDBType == OracleDB { bulkInsert = e.oracleBulkOperation } if err := e.splitExecDML(ctx, allInserts, bulkInsert); err != nil { @@ -304,7 +304,7 @@ func (e *executor) execTableBatch(ctx context.Context, dmls []*DML) error { if allUpdates, ok := types[UpdateDMLType]; ok { bulkUpdate := e.bulkReplace - if e.destDBType == "oracle" { + if e.destDBType == OracleDB { bulkUpdate = e.oracleBulkOperation } if err := e.splitExecDML(ctx, allUpdates, bulkUpdate); err != nil { diff --git a/pkg/loader/executor_test.go b/pkg/loader/executor_test.go index c99f89b2a..5d7d1c99d 100644 --- a/pkg/loader/executor_test.go +++ b/pkg/loader/executor_test.go @@ -248,7 +248,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { "age": { FieldType: types.FieldType{Tp: tmysql.TypeInt24}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } delSQL := "DELETE FROM unicorn.users.*" insertSQL := "INSERT INTO unicorn.users.*" @@ -257,7 +257,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnError(errors.New("del")) e := newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err := e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, ErrorMatches, "del") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -268,7 +268,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnResult(sqlmock.NewResult(1, 1)) s.dbMock.ExpectExec(delSQL).WillReturnError(errors.New("del")) e = newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err = e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, ErrorMatches, "del") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -281,7 +281,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { s.dbMock.ExpectExec(insertSQL).WillReturnError(errors.New("insert")) e = newExecutor(s.db) err = e.singleOracleExec([]*DML{&dml}, true) - e.destDBType = "oracle" + e.destDBType = OracleDB c.Assert(err, ErrorMatches, "insert") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) s.resetMock(c) @@ -293,7 +293,7 @@ func (s *singleExecSuite) TestOracleSafeUpdate(c *C) { s.dbMock.ExpectExec(insertSQL).WillReturnResult(sqlmock.NewResult(1, 1)) s.dbMock.ExpectCommit() e = newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err = e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, IsNil) c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -321,7 +321,7 @@ func (s *singleExecSuite) TestOracleSafeInsert(c *C) { "age": { FieldType: types.FieldType{Tp: tmysql.TypeInt24}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } delSQL := "DELETE FROM unicorn.users.*" insertSQL := "INSERT INTO unicorn.users.*" @@ -331,7 +331,7 @@ func (s *singleExecSuite) TestOracleSafeInsert(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnError(errors.New("del")) e := newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err := e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, ErrorMatches, "del") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -342,7 +342,7 @@ func (s *singleExecSuite) TestOracleSafeInsert(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnResult(sqlmock.NewResult(1, 1)) s.dbMock.ExpectExec(insertSQL).WillReturnError(errors.New("insert")) e = newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err = e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, ErrorMatches, "insert") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -354,7 +354,7 @@ func (s *singleExecSuite) TestOracleSafeInsert(c *C) { s.dbMock.ExpectExec(insertSQL).WillReturnResult(sqlmock.NewResult(1, 1)) s.dbMock.ExpectCommit() e = newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err = e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, IsNil) c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -381,7 +381,7 @@ func (s *singleExecSuite) TestOracleSafeDelete(c *C) { "age": { FieldType: types.FieldType{Tp: tmysql.TypeInt24}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } delSQL := "DELETE FROM unicorn.users.*" @@ -390,7 +390,7 @@ func (s *singleExecSuite) TestOracleSafeDelete(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnError(errors.New("del")) e := newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err := e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, ErrorMatches, "del") c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) @@ -401,7 +401,7 @@ func (s *singleExecSuite) TestOracleSafeDelete(c *C) { s.dbMock.ExpectExec(delSQL).WillReturnResult(sqlmock.NewResult(1, 1)) s.dbMock.ExpectCommit() e = newExecutor(s.db) - e.destDBType = "oracle" + e.destDBType = OracleDB err = e.singleOracleExec([]*DML{&dml}, true) c.Assert(err, IsNil) c.Assert(s.dbMock.ExpectationsWereMet(), IsNil) diff --git a/pkg/loader/load.go b/pkg/loader/load.go index 3f10a693a..bb5f1a9f0 100644 --- a/pkg/loader/load.go +++ b/pkg/loader/load.go @@ -68,7 +68,7 @@ type loaderImpl struct { // like column name, pk & uk db *gosql.DB //downStream db type, mysql,tidb,oracle - destDBType string + destDBType int // only set for test getTableInfoFromDB func(db *gosql.DB, schema string, table string) (info *tableInfo, err error) opts options @@ -130,7 +130,7 @@ type options struct { enableDispatch bool enableCausality bool merge bool - destDBType string + destDBType int } var defaultLoaderOptions = options{ @@ -143,7 +143,7 @@ var defaultLoaderOptions = options{ enableDispatch: true, enableCausality: true, merge: false, - destDBType: "tidb", + destDBType: MysqlDB, } // A Option sets options such batch size, worker count etc. @@ -195,7 +195,7 @@ func Merge(v bool) Option { } //DestinationDBType set destDBType option. -func DestinationDBType(t string) Option { +func DestinationDBType(t int) Option { return func(o *options) { o.destDBType = t } @@ -259,7 +259,7 @@ func NewLoader(db *gosql.DB, opt ...Option) (Loader, error) { ctx: ctx, cancel: cancel, } - if opts.destDBType == "oracle" { + if opts.destDBType == OracleDB { s.getTableInfoFromDB = getOracleTableInfo fGetAppliedTS = getOracleAppliedTS } @@ -395,7 +395,7 @@ func (s *loaderImpl) execDDL(ddl *DDL) error { if ddl.ShouldSkip { return nil } - if s.destDBType == "oracle" { + if s.destDBType == OracleDB { return s.processOracleDDL(ddl) } return s.processMysqlDDL(ddl) @@ -750,7 +750,7 @@ func filterGeneratedCols(dml *DML) { func (s *loaderImpl) getExecutor() *executor { e := newExecutor(s.db).withBatchSize(s.batchSize).withDestDBType(s.destDBType) - if s.destDBType == "oracle" { + if s.destDBType == OracleDB { e.fTryRefreshTableErr = tryRefreshTableOracleErr e.fSingleExec = e.singleOracleExec } diff --git a/pkg/loader/load_test.go b/pkg/loader/load_test.go index e05f51588..10e75eca2 100644 --- a/pkg/loader/load_test.go +++ b/pkg/loader/load_test.go @@ -350,7 +350,7 @@ func (s *execDDLSuite) TestShouldExecInTransaction(c *check.C) { mock.ExpectExec("CREATE TABLE `t` \\(`id` INT\\)").WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectCommit() - loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: "mysql"} + loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: MysqlDB} ddl := DDL{SQL: "CREATE TABLE `t` (`id` INT)"} err = loader.execDDL(&ddl) @@ -365,7 +365,7 @@ func (s *execDDLSuite) TestOracleTruncateDDL(c *check.C) { mock.ExpectExec("BEGIN test.do_truncate\\('test.t1',''\\);END;").WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectCommit() - loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: "oracle"} + loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: OracleDB} ddl := DDL{SQL: "truncate table t1", Database: "test", Table: "t1"} err = loader.execDDL(&ddl) @@ -389,7 +389,7 @@ func (s *execDDLSuite) TestShouldUseDatabase(c *check.C) { mock.ExpectExec("CREATE TABLE `t` \\(`id` INT\\)").WillReturnResult(sqlmock.NewResult(0, 0)) mock.ExpectCommit() - loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: "mysql"} + loader := &loaderImpl{db: db, ctx: context.Background(), destDBType: MysqlDB} ddl := DDL{SQL: "CREATE TABLE `t` (`id` INT)", Database: "test_db"} err = loader.execDDL(&ddl) diff --git a/pkg/loader/model.go b/pkg/loader/model.go index d45255a02..553b99622 100644 --- a/pkg/loader/model.go +++ b/pkg/loader/model.go @@ -35,6 +35,11 @@ const ( DeleteDMLType DMLType = 3 ) +const ( + MysqlDB = iota + OracleDB +) + // DML holds the dml info type DML struct { Database string @@ -49,7 +54,7 @@ type DML struct { UpColumnsInfoMap map[string]*model.ColumnInfo - DestDBType string + DestDBType int } // DDL holds the ddl info @@ -167,7 +172,7 @@ func (dml *DML) oldPrimaryKeyValues() []interface{} { // TableName returns the fully qualified name of the DML's table func (dml *DML) TableName() string { - if dml.DestDBType == "oracle" { + if dml.DestDBType == OracleDB { return fmt.Sprintf("%s.%s", dml.Database, dml.Table) } return quoteSchema(dml.Database, dml.Table) @@ -177,23 +182,17 @@ func (dml *DML) updateSQL() (sql string, args []interface{}) { builder := new(strings.Builder) fmt.Fprintf(builder, "UPDATE %s SET ", dml.TableName()) - colName := "" oracleHolderPos := 1 for _, name := range dml.columnNames() { if len(args) > 0 { builder.WriteByte(',') } arg := dml.Values[name] - if dml.DestDBType == "oracle" { - colName = escapeName(name) - } else { - colName = quoteName(name) - } - if dml.DestDBType == "oracle" { - fmt.Fprintf(builder, "%s = :%d", colName, oracleHolderPos) + if dml.DestDBType == OracleDB { + fmt.Fprintf(builder, "%s = :%d", escapeName(name), oracleHolderPos) oracleHolderPos++ } else { - fmt.Fprintf(builder, "%s = ?", colName) + fmt.Fprintf(builder, "%s = ?", quoteName(name)) } args = append(args, arg) } @@ -202,7 +201,7 @@ func (dml *DML) updateSQL() (sql string, args []interface{}) { whereArgs := dml.buildWhere(builder, oracleHolderPos) args = append(args, whereArgs...) - if dml.DestDBType == "oracle" { + if dml.DestDBType == OracleDB { builder.WriteString(" AND rownum <=1") } else { builder.WriteString(" LIMIT 1") @@ -218,13 +217,13 @@ func (dml *DML) buildWhere(builder *strings.Builder, oracleHolderPos int) (args builder.WriteString(" AND ") } if wargs[i] == nil { - if dml.DestDBType == "oracle" { + if dml.DestDBType == OracleDB { builder.WriteString(escapeName(wnames[i]) + " IS NULL") } else { builder.WriteString(quoteName(wnames[i]) + " IS NULL") } } else { - if dml.DestDBType == "oracle" { + if dml.DestDBType == OracleDB { builder.WriteString(fmt.Sprintf("%s = :%d", escapeName(wnames[i]), pOracleHolderPos)) pOracleHolderPos++ } else { @@ -276,7 +275,7 @@ func (dml *DML) deleteSQL() (sql string, args []interface{}) { fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.TableName()) args = dml.buildWhere(builder, 1) - if dml.DestDBType == "oracle" { + if dml.DestDBType == OracleDB { builder.WriteString(" AND rownum <=1") } else { builder.WriteString(" LIMIT 1") diff --git a/pkg/loader/model_test.go b/pkg/loader/model_test.go index 214ac44ec..3ee01aa6c 100644 --- a/pkg/loader/model_test.go +++ b/pkg/loader/model_test.go @@ -47,7 +47,7 @@ func getDML(key bool, tp DMLType) *DML { dml.Database = "test" dml.Table = "test" dml.Tp = tp - dml.DestDBType = "tidb" + dml.DestDBType = MysqlDB return dml } @@ -189,7 +189,7 @@ func (s *SQLSuite) TestInsertSQL(c *check.C) { info: &tableInfo{ columns: []string{"name", "age"}, }, - DestDBType: "tidb", + DestDBType: MysqlDB, } sql, args := dml.sql() c.Assert(sql, check.Equals, "INSERT INTO `test`.`hello`(`age`,`name`) VALUES(?,?)") @@ -210,7 +210,7 @@ func (s *SQLSuite) TestDeleteSQL(c *check.C) { info: &tableInfo{ columns: []string{"name", "age"}, }, - DestDBType: "tidb", + DestDBType: MysqlDB, } sql, args := dml.sql() c.Assert( @@ -235,7 +235,7 @@ func (s *SQLSuite) TestUpdateSQL(c *check.C) { info: &tableInfo{ columns: []string{"name"}, }, - DestDBType: "tidb", + DestDBType: MysqlDB, } sql, args := dml.sql() c.Assert( @@ -293,7 +293,7 @@ func (s *SQLSuite) TestOracleUpdateSQL(c *check.C) { "NAME": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } sql, args := dml.sql() c.Assert( @@ -338,7 +338,7 @@ func (s *SQLSuite) TestOracleUpdateSQLPrimaryKey(c *check.C) { "NAME": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } sql, args := dml.sql() c.Assert( @@ -368,7 +368,7 @@ func (s *SQLSuite) TestOracleDeleteSQL(c *check.C) { "NAME": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } sql, args := dml.sql() c.Assert( @@ -400,7 +400,7 @@ func (s *SQLSuite) TestOracleInsertSQL(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } sql, args := dml.sql() c.Assert( @@ -439,7 +439,7 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithOneUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } sql, args := dml.oracleDeleteNewValueSQL() @@ -476,7 +476,7 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithOneUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } sql, args = dml.oracleDeleteNewValueSQL() c.Assert( @@ -521,7 +521,7 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithMultiUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } sql, args := dml.oracleDeleteNewValueSQL() @@ -556,7 +556,7 @@ func (s *SQLSuite) TestOracleDeleteNewValueSQLWithNoUK(c *check.C) { "C2": { FieldType: types.FieldType{Tp: mysql.TypeVarString}}, }, - DestDBType: "oracle", + DestDBType: OracleDB, } sql, args := dml.oracleDeleteNewValueSQL() diff --git a/pkg/loader/util.go b/pkg/loader/util.go index 5e9036460..02479fd71 100644 --- a/pkg/loader/util.go +++ b/pkg/loader/util.go @@ -251,13 +251,13 @@ func escapeName(name string) string { return strings.Replace(name, "`", "``", -1) } -func holderString(n int, destDBType string) string { +func holderString(n int, destDBType int) string { builder := new(strings.Builder) for i := 0; i < n; i++ { if i > 0 { builder.WriteString(",") } - if destDBType == "oracle" { + if destDBType == OracleDB { builder.WriteString(":" + strconv.Itoa(i+1)) } else { builder.WriteString("?") @@ -282,13 +282,13 @@ func splitDMLs(dmls []*DML, size int) (res [][]*DML) { return } -func buildColumnList(names []string, destDBType string) string { +func buildColumnList(names []string, destDBType int) string { var b strings.Builder for i, name := range names { if i > 0 { b.WriteString(",") } - if destDBType == "oracle" { + if destDBType == OracleDB { b.WriteString(escapeName(name)) } else { b.WriteString(quoteName(name)) From cf892d9425c36e2a363a92e9390961bc7b108a48 Mon Sep 17 00:00:00 2001 From: cartersz <cartersunsz@163.com> Date: Tue, 26 Apr 2022 14:27:53 +0800 Subject: [PATCH 3/7] add comment for MysqlDB/OracleDB --- pkg/loader/model.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/loader/model.go b/pkg/loader/model.go index 553b99622..25693f63d 100644 --- a/pkg/loader/model.go +++ b/pkg/loader/model.go @@ -35,6 +35,7 @@ const ( DeleteDMLType DMLType = 3 ) +// Destination database type can be Mysql/Tidb or Oracle const ( MysqlDB = iota OracleDB From 34bb7258c720b6b3714671614fd2fb17e6d8d67c Mon Sep 17 00:00:00 2001 From: cartersz <cartersunsz@163.com> Date: Thu, 28 Apr 2022 18:26:33 +0800 Subject: [PATCH 4/7] fix comment 2 --- drainer/sync/mysql.go | 4 +++- drainer/translator/mysql.go | 34 ++++++++++++++++------------------ drainer/translator/oracle.go | 13 ++++++------- drainer/translator/pb.go | 10 ++++++---- pkg/loader/model.go | 3 ++- 5 files changed, 33 insertions(+), 31 deletions(-) diff --git a/drainer/sync/mysql.go b/drainer/sync/mysql.go index ce4369096..d36ab5c95 100644 --- a/drainer/sync/mysql.go +++ b/drainer/sync/mysql.go @@ -57,9 +57,11 @@ func CreateLoader( enableDispatch bool, enableCausility bool, ) (ld loader.Loader, err error) { - destDBTypeInt := loader.MysqlDB + destDBTypeInt := loader.DBTypeUnknown if destDBType == "oracle" { destDBTypeInt = loader.OracleDB + } else if destDBType == "tidb" || destDBType == "mysql" { + destDBTypeInt = loader.MysqlDB } var opts []loader.Option opts = append(opts, loader.DestinationDBType(destDBTypeInt), loader.WorkerCount(worker), loader.BatchSize(batchSize), diff --git a/drainer/translator/mysql.go b/drainer/translator/mysql.go index c079da965..e71bed7d5 100644 --- a/drainer/translator/mysql.go +++ b/drainer/translator/mysql.go @@ -31,9 +31,7 @@ import ( const implicitColID = -1 -var destDBType = loader.MysqlDB - -func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte) (names []string, args []interface{}, err error) { +func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte, destDBType int) (names []string, args []interface{}, err error) { columns := writableColumns(table) columnValues, err := insertRowToDatums(table, row) @@ -48,7 +46,7 @@ func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte) (nam val = getDefaultOrZeroValue(ptable, col) } - value, err := formatData(val, col.FieldType) + value, err := formatData(val, col.FieldType, destDBType) if err != nil { return nil, nil, errors.Trace(err) } @@ -60,7 +58,7 @@ func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte) (nam return names, args, nil } -func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canAppendDefaultValue bool) (names []string, values []interface{}, oldValues []interface{}, err error) { +func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canAppendDefaultValue bool, destDBType int) (names []string, values []interface{}, oldValues []interface{}, err error) { columns := writableColumns(table) updtDecoder := newUpdateDecoder(ptable, table, canAppendDefaultValue) @@ -71,12 +69,12 @@ func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canA return nil, nil, nil, errors.Annotatef(err, "table `%s`.`%s`", schema, table.Name) } - _, oldValues, err = generateColumnAndValue(columns, oldColumnValues) + _, oldValues, err = generateColumnAndValue(columns, oldColumnValues, destDBType) if err != nil { return nil, nil, nil, errors.Trace(err) } - updateColumns, values, err = generateColumnAndValue(columns, newColumnValues) + updateColumns, values, err = generateColumnAndValue(columns, newColumnValues, destDBType) if err != nil { return nil, nil, nil, errors.Trace(err) } @@ -86,7 +84,7 @@ func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canA return } -func genDBDelete(schema string, table *model.TableInfo, row []byte) (names []string, values []interface{}, err error) { +func genDBDelete(schema string, table *model.TableInfo, row []byte, destDBType int) (names []string, values []interface{}, err error) { columns := table.Columns colsTypeMap := util.ToColumnTypeMap(columns) @@ -95,7 +93,7 @@ func genDBDelete(schema string, table *model.TableInfo, row []byte) (names []str return nil, nil, errors.Trace(err) } - columns, values, err = generateColumnAndValue(columns, columnValues) + columns, values, err = generateColumnAndValue(columns, columnValues, destDBType) if err != nil { return nil, nil, errors.Trace(err) } @@ -146,7 +144,7 @@ func TiBinlogToTxn(infoGetter TableInfoGetter, schema string, table string, tiBi switch mutType { case tipb.MutationType_Insert: - names, args, err := genDBInsert(schema, pinfo, info, row) + names, args, err := genDBInsert(schema, pinfo, info, row, loader.MysqlDB) if err != nil { return nil, errors.Annotate(err, "gen insert fail") } @@ -156,14 +154,14 @@ func TiBinlogToTxn(infoGetter TableInfoGetter, schema string, table string, tiBi Database: schema, Table: table, Values: make(map[string]interface{}), - DestDBType: destDBType, + DestDBType: loader.MysqlDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { dml.Values[name] = args[i] } case tipb.MutationType_Update: - names, args, oldArgs, err := genDBUpdate(schema, pinfo, info, row, canAppendDefaultValue) + names, args, oldArgs, err := genDBUpdate(schema, pinfo, info, row, canAppendDefaultValue, loader.MysqlDB) if err != nil { return nil, errors.Annotate(err, "gen update fail") } @@ -174,7 +172,7 @@ func TiBinlogToTxn(infoGetter TableInfoGetter, schema string, table string, tiBi Table: table, Values: make(map[string]interface{}), OldValues: make(map[string]interface{}), - DestDBType: destDBType, + DestDBType: loader.MysqlDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -183,7 +181,7 @@ func TiBinlogToTxn(infoGetter TableInfoGetter, schema string, table string, tiBi } case tipb.MutationType_DeleteRow: - names, args, err := genDBDelete(schema, info, row) + names, args, err := genDBDelete(schema, info, row, loader.MysqlDB) if err != nil { return nil, errors.Annotate(err, "gen delete fail") } @@ -193,7 +191,7 @@ func TiBinlogToTxn(infoGetter TableInfoGetter, schema string, table string, tiBi Database: schema, Table: table, Values: make(map[string]interface{}), - DestDBType: destDBType, + DestDBType: loader.MysqlDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -230,7 +228,7 @@ func genColumnNameList(columns []*model.ColumnInfo) (names []string) { return } -func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64]types.Datum) ([]*model.ColumnInfo, []interface{}, error) { +func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64]types.Datum, destDBType int) ([]*model.ColumnInfo, []interface{}, error) { var newColumn []*model.ColumnInfo var newColumnsValues []interface{} @@ -238,7 +236,7 @@ func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64] val, ok := columnValues[col.ID] if ok { newColumn = append(newColumn, col) - value, err := formatData(val, col.FieldType) + value, err := formatData(val, col.FieldType, destDBType) if err != nil { return nil, nil, errors.Trace(err) } @@ -250,7 +248,7 @@ func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64] return newColumn, newColumnsValues, nil } -func formatData(data types.Datum, ft types.FieldType) (types.Datum, error) { +func formatData(data types.Datum, ft types.FieldType, destDBType int) (types.Datum, error) { if data.GetValue() == nil { return data, nil } diff --git a/drainer/translator/oracle.go b/drainer/translator/oracle.go index 5cf2e2a20..cd5a239d0 100644 --- a/drainer/translator/oracle.go +++ b/drainer/translator/oracle.go @@ -14,7 +14,6 @@ import ( // TiBinlogToOracleTxn translate the format to loader.Txn func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string, tiBinlog *tipb.Binlog, pv *tipb.PrewriteValue, shouldSkip bool, tableRouter *router.Table) (txn *loader.Txn, err error) { - destDBType = loader.OracleDB txn = new(loader.Txn) if tiBinlog.DdlJobId > 0 { @@ -67,7 +66,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string switch mutType { case tipb.MutationType_Insert: - names, args, err := genDBInsert(schema, pinfo, info, row) + names, args, err := genDBInsert(schema, pinfo, info, row, loader.OracleDB) if err != nil { return nil, errors.Annotate(err, "gen insert fail") } @@ -78,14 +77,14 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string Table: downStreamTable, Values: make(map[string]interface{}), UpColumnsInfoMap: tableIDColumnsMap[mut.GetTableId()], - DestDBType: destDBType, + DestDBType: loader.OracleDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { dml.Values[strings.ToUpper(name)] = args[i] } case tipb.MutationType_Update: - names, args, oldArgs, err := genDBUpdate(schema, pinfo, info, row, canAppendDefaultValue) + names, args, oldArgs, err := genDBUpdate(schema, pinfo, info, row, canAppendDefaultValue, loader.OracleDB) if err != nil { return nil, errors.Annotate(err, "gen update fail") } @@ -97,7 +96,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string Values: make(map[string]interface{}), OldValues: make(map[string]interface{}), UpColumnsInfoMap: tableIDColumnsMap[mut.GetTableId()], - DestDBType: destDBType, + DestDBType: loader.OracleDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { @@ -106,7 +105,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string } case tipb.MutationType_DeleteRow: - names, args, err := genDBDelete(schema, info, row) + names, args, err := genDBDelete(schema, info, row, loader.OracleDB) if err != nil { return nil, errors.Annotate(err, "gen delete fail") } @@ -117,7 +116,7 @@ func TiBinlogToOracleTxn(infoGetter TableInfoGetter, schema string, table string Table: downStreamTable, Values: make(map[string]interface{}), UpColumnsInfoMap: tableIDColumnsMap[mut.GetTableId()], - DestDBType: destDBType, + DestDBType: loader.OracleDB, } txn.DMLs = append(txn.DMLs, dml) for i, name := range names { diff --git a/drainer/translator/pb.go b/drainer/translator/pb.go index de0d19f15..043b3a53a 100644 --- a/drainer/translator/pb.go +++ b/drainer/translator/pb.go @@ -19,6 +19,8 @@ import ( "strings" "time" + "github.com/pingcap/tidb-binlog/pkg/loader" + //nolint "github.com/golang/protobuf/proto" "github.com/pingcap/errors" @@ -137,7 +139,7 @@ func genInsert(schema string, ptable, table *model.TableInfo, row []byte) (event val = getDefaultOrZeroValue(ptable, col) } - value, err := formatData(val, col.FieldType) + value, err := formatData(val, col.FieldType, loader.DBTypeUnknown) if err != nil { return nil, errors.Trace(err) } @@ -173,11 +175,11 @@ func genUpdate(schema string, ptable, table *model.TableInfo, row []byte, canApp for _, col := range columns { val, ok := newColumnValues[col.ID] if ok { - oldValue, err := formatData(oldColumnValues[col.ID], col.FieldType) + oldValue, err := formatData(oldColumnValues[col.ID], col.FieldType, loader.DBTypeUnknown) if err != nil { return nil, errors.Trace(err) } - newValue, err := formatData(val, col.FieldType) + newValue, err := formatData(val, col.FieldType, loader.DBTypeUnknown) if err != nil { return nil, errors.Trace(err) } @@ -217,7 +219,7 @@ func genDelete(schema string, table *model.TableInfo, row []byte) (event *pb.Eve for _, col := range columns { val, ok := columnValues[col.ID] if ok { - value, err := formatData(val, col.FieldType) + value, err := formatData(val, col.FieldType, loader.DBTypeUnknown) if err != nil { return nil, errors.Trace(err) } diff --git a/pkg/loader/model.go b/pkg/loader/model.go index 25693f63d..5248ca95f 100644 --- a/pkg/loader/model.go +++ b/pkg/loader/model.go @@ -37,7 +37,8 @@ const ( // Destination database type can be Mysql/Tidb or Oracle const ( - MysqlDB = iota + DBTypeUnknown = iota + MysqlDB OracleDB ) From 205bf01bcba3120b580de19f413c4056ef3f601e Mon Sep 17 00:00:00 2001 From: cartersz <cartersunsz@163.com> Date: Thu, 28 Apr 2022 19:36:08 +0800 Subject: [PATCH 5/7] fix comment 3 --- drainer/translator/mysql.go | 10 +++++----- pkg/loader/executor.go | 4 ++-- pkg/loader/load.go | 6 +++--- pkg/loader/model.go | 7 +++++-- pkg/loader/util.go | 4 ++-- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/drainer/translator/mysql.go b/drainer/translator/mysql.go index e71bed7d5..aa2402380 100644 --- a/drainer/translator/mysql.go +++ b/drainer/translator/mysql.go @@ -31,7 +31,7 @@ import ( const implicitColID = -1 -func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte, destDBType int) (names []string, args []interface{}, err error) { +func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte, destDBType loader.DBType) (names []string, args []interface{}, err error) { columns := writableColumns(table) columnValues, err := insertRowToDatums(table, row) @@ -58,7 +58,7 @@ func genDBInsert(schema string, ptable, table *model.TableInfo, row []byte, dest return names, args, nil } -func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canAppendDefaultValue bool, destDBType int) (names []string, values []interface{}, oldValues []interface{}, err error) { +func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canAppendDefaultValue bool, destDBType loader.DBType) (names []string, values []interface{}, oldValues []interface{}, err error) { columns := writableColumns(table) updtDecoder := newUpdateDecoder(ptable, table, canAppendDefaultValue) @@ -84,7 +84,7 @@ func genDBUpdate(schema string, ptable, table *model.TableInfo, row []byte, canA return } -func genDBDelete(schema string, table *model.TableInfo, row []byte, destDBType int) (names []string, values []interface{}, err error) { +func genDBDelete(schema string, table *model.TableInfo, row []byte, destDBType loader.DBType) (names []string, values []interface{}, err error) { columns := table.Columns colsTypeMap := util.ToColumnTypeMap(columns) @@ -228,7 +228,7 @@ func genColumnNameList(columns []*model.ColumnInfo) (names []string) { return } -func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64]types.Datum, destDBType int) ([]*model.ColumnInfo, []interface{}, error) { +func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64]types.Datum, destDBType loader.DBType) ([]*model.ColumnInfo, []interface{}, error) { var newColumn []*model.ColumnInfo var newColumnsValues []interface{} @@ -248,7 +248,7 @@ func generateColumnAndValue(columns []*model.ColumnInfo, columnValues map[int64] return newColumn, newColumnsValues, nil } -func formatData(data types.Datum, ft types.FieldType, destDBType int) (types.Datum, error) { +func formatData(data types.Datum, ft types.FieldType, destDBType loader.DBType) (types.Datum, error) { if data.GetValue() == nil { return data, nil } diff --git a/pkg/loader/executor.go b/pkg/loader/executor.go index 3e660e984..1656a7977 100644 --- a/pkg/loader/executor.go +++ b/pkg/loader/executor.go @@ -43,7 +43,7 @@ var ( type executor struct { db *gosql.DB - destDBType int + destDBType DBType batchSize int workerCount int info *loopbacksync.LoopBackSync @@ -70,7 +70,7 @@ func (e *executor) withRefreshTableInfo(fn func(schema string, table string) (in return e } -func (e *executor) withDestDBType(destDBType int) *executor { +func (e *executor) withDestDBType(destDBType DBType) *executor { e.destDBType = destDBType return e } diff --git a/pkg/loader/load.go b/pkg/loader/load.go index bb5f1a9f0..6417c5629 100644 --- a/pkg/loader/load.go +++ b/pkg/loader/load.go @@ -68,7 +68,7 @@ type loaderImpl struct { // like column name, pk & uk db *gosql.DB //downStream db type, mysql,tidb,oracle - destDBType int + destDBType DBType // only set for test getTableInfoFromDB func(db *gosql.DB, schema string, table string) (info *tableInfo, err error) opts options @@ -130,7 +130,7 @@ type options struct { enableDispatch bool enableCausality bool merge bool - destDBType int + destDBType DBType } var defaultLoaderOptions = options{ @@ -195,7 +195,7 @@ func Merge(v bool) Option { } //DestinationDBType set destDBType option. -func DestinationDBType(t int) Option { +func DestinationDBType(t DBType) Option { return func(o *options) { o.destDBType = t } diff --git a/pkg/loader/model.go b/pkg/loader/model.go index 5248ca95f..e1d09a351 100644 --- a/pkg/loader/model.go +++ b/pkg/loader/model.go @@ -36,8 +36,11 @@ const ( ) // Destination database type can be Mysql/Tidb or Oracle +type DBType int + +// DBType types const ( - DBTypeUnknown = iota + DBTypeUnknown DBType = iota MysqlDB OracleDB ) @@ -56,7 +59,7 @@ type DML struct { UpColumnsInfoMap map[string]*model.ColumnInfo - DestDBType int + DestDBType DBType } // DDL holds the ddl info diff --git a/pkg/loader/util.go b/pkg/loader/util.go index 02479fd71..c920f1268 100644 --- a/pkg/loader/util.go +++ b/pkg/loader/util.go @@ -251,7 +251,7 @@ func escapeName(name string) string { return strings.Replace(name, "`", "``", -1) } -func holderString(n int, destDBType int) string { +func holderString(n int, destDBType DBType) string { builder := new(strings.Builder) for i := 0; i < n; i++ { if i > 0 { @@ -282,7 +282,7 @@ func splitDMLs(dmls []*DML, size int) (res [][]*DML) { return } -func buildColumnList(names []string, destDBType int) string { +func buildColumnList(names []string, destDBType DBType) string { var b strings.Builder for i, name := range names { if i > 0 { From 61479ca23bd0210b1e825e3c0d8971d1f776b8fa Mon Sep 17 00:00:00 2001 From: cartersz <cartersunsz@163.com> Date: Thu, 28 Apr 2022 19:44:51 +0800 Subject: [PATCH 6/7] fix comment on DBType --- pkg/loader/model.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/loader/model.go b/pkg/loader/model.go index e1d09a351..f9a96389c 100644 --- a/pkg/loader/model.go +++ b/pkg/loader/model.go @@ -35,7 +35,7 @@ const ( DeleteDMLType DMLType = 3 ) -// Destination database type can be Mysql/Tidb or Oracle +// DBType can be Mysql/Tidb or Oracle type DBType int // DBType types From 4e59a23c5abd0a5d9573f711a0bd686bcc9f8d5e Mon Sep 17 00:00:00 2001 From: cartersz <cartersunsz@163.com> Date: Thu, 5 May 2022 16:14:12 +0800 Subject: [PATCH 7/7] fix comment 4 --- drainer/sync/mysql.go | 11 +--- drainer/translator/mysql.go | 5 +- drainer/translator/pb.go | 3 +- pkg/loader/load.go | 12 +++- pkg/loader/model.go | 117 ++++++++++++++++++++++++++---------- pkg/loader/model_test.go | 8 +-- pkg/loader/util.go | 22 +++++-- 7 files changed, 124 insertions(+), 54 deletions(-) diff --git a/drainer/sync/mysql.go b/drainer/sync/mysql.go index d36ab5c95..6186a390f 100644 --- a/drainer/sync/mysql.go +++ b/drainer/sync/mysql.go @@ -57,15 +57,10 @@ func CreateLoader( enableDispatch bool, enableCausility bool, ) (ld loader.Loader, err error) { - destDBTypeInt := loader.DBTypeUnknown - if destDBType == "oracle" { - destDBTypeInt = loader.OracleDB - } else if destDBType == "tidb" || destDBType == "mysql" { - destDBTypeInt = loader.MysqlDB - } + var opts []loader.Option - opts = append(opts, loader.DestinationDBType(destDBTypeInt), loader.WorkerCount(worker), loader.BatchSize(batchSize), - loader.SaveAppliedTS(destDBTypeInt == loader.MysqlDB || destDBTypeInt == loader.OracleDB), loader.SetloopBackSyncInfo(info)) + opts = append(opts, loader.DestinationDBType(destDBType), loader.WorkerCount(worker), loader.BatchSize(batchSize), + loader.SaveAppliedTS(destDBType == "tidb" || destDBType == "oracle"), loader.SetloopBackSyncInfo(info)) if queryHistogramVec != nil { opts = append(opts, loader.Metrics(&loader.MetricsGroup{ QueryHistogramVec: queryHistogramVec, diff --git a/drainer/translator/mysql.go b/drainer/translator/mysql.go index aa2402380..387fe48f7 100644 --- a/drainer/translator/mysql.go +++ b/drainer/translator/mysql.go @@ -259,7 +259,7 @@ func formatData(data types.Datum, ft types.FieldType, destDBType loader.DBType) case mysql.TypeDuration: //only for oracle db if destDBType == loader.OracleDB { - return data, errors.New("unsupported column type[time]") + return types.Datum{}, errors.New("unsupported column type[time]") } data = types.NewDatum(fmt.Sprintf("%v", data.GetValue())) case mysql.TypeEnum: @@ -285,7 +285,8 @@ func formatData(data types.Datum, ft types.FieldType, destDBType loader.DBType) func isBlob(ft types.FieldType) bool { stype := types.TypeToStr(ft.Tp, ft.Charset) - if stype == "blob" || stype == "tinyblob" || stype == "mediumblob" || stype == "longblob" { + switch stype { + case "blob", "tinyblob", "mediumblob", "longblob": return true } return false diff --git a/drainer/translator/pb.go b/drainer/translator/pb.go index 043b3a53a..9e78040c3 100644 --- a/drainer/translator/pb.go +++ b/drainer/translator/pb.go @@ -19,8 +19,6 @@ import ( "strings" "time" - "github.com/pingcap/tidb-binlog/pkg/loader" - //nolint "github.com/golang/protobuf/proto" "github.com/pingcap/errors" @@ -31,6 +29,7 @@ import ( "github.com/pingcap/tidb/util/codec" tipb "github.com/pingcap/tipb/go-binlog" + "github.com/pingcap/tidb-binlog/pkg/loader" "github.com/pingcap/tidb-binlog/pkg/util" pb "github.com/pingcap/tidb-binlog/proto/binlog" ) diff --git a/pkg/loader/load.go b/pkg/loader/load.go index 6417c5629..6ea416ce0 100644 --- a/pkg/loader/load.go +++ b/pkg/loader/load.go @@ -195,9 +195,17 @@ func Merge(v bool) Option { } //DestinationDBType set destDBType option. -func DestinationDBType(t DBType) Option { +func DestinationDBType(t string) Option { + destDBType := DBTypeUnknown + if t == "oracle" { + destDBType = OracleDB + } else if t == "tidb" { + destDBType = TiDB + } else if t == "mysql" { + destDBType = MysqlDB + } return func(o *options) { - o.destDBType = t + o.destDBType = destDBType } } diff --git a/pkg/loader/model.go b/pkg/loader/model.go index f9a96389c..474124366 100644 --- a/pkg/loader/model.go +++ b/pkg/loader/model.go @@ -42,6 +42,7 @@ type DBType int const ( DBTypeUnknown DBType = iota MysqlDB + TiDB OracleDB ) @@ -184,56 +185,93 @@ func (dml *DML) TableName() string { } func (dml *DML) updateSQL() (sql string, args []interface{}) { + if dml.DestDBType == OracleDB { + return dml.updateOracleSQL() + } + return dml.updateTiDBSQL() +} + +func (dml *DML) updateTiDBSQL() (sql string, args []interface{}) { builder := new(strings.Builder) fmt.Fprintf(builder, "UPDATE %s SET ", dml.TableName()) - oracleHolderPos := 1 for _, name := range dml.columnNames() { if len(args) > 0 { builder.WriteByte(',') } arg := dml.Values[name] - if dml.DestDBType == OracleDB { - fmt.Fprintf(builder, "%s = :%d", escapeName(name), oracleHolderPos) - oracleHolderPos++ - } else { - fmt.Fprintf(builder, "%s = ?", quoteName(name)) - } + fmt.Fprintf(builder, "%s = ?", quoteName(name)) args = append(args, arg) } builder.WriteString(" WHERE ") - whereArgs := dml.buildWhere(builder, oracleHolderPos) + whereArgs := dml.buildTiDBWhere(builder) args = append(args, whereArgs...) - if dml.DestDBType == OracleDB { - builder.WriteString(" AND rownum <=1") - } else { - builder.WriteString(" LIMIT 1") + builder.WriteString(" LIMIT 1") + sql = builder.String() + return +} + +func (dml *DML) updateOracleSQL() (sql string, args []interface{}) { + builder := new(strings.Builder) + + fmt.Fprintf(builder, "UPDATE %s SET ", dml.TableName()) + oracleHolderPos := 1 + for _, name := range dml.columnNames() { + if len(args) > 0 { + builder.WriteByte(',') + } + arg := dml.Values[name] + fmt.Fprintf(builder, "%s = :%d", escapeName(name), oracleHolderPos) + oracleHolderPos++ + args = append(args, arg) } + + builder.WriteString(" WHERE ") + + whereArgs := dml.buildOracleWhere(builder, oracleHolderPos) + args = append(args, whereArgs...) + builder.WriteString(" AND rownum <=1") sql = builder.String() return } func (dml *DML) buildWhere(builder *strings.Builder, oracleHolderPos int) (args []interface{}) { + if dml.DestDBType == OracleDB { + dml.buildOracleWhere(builder, oracleHolderPos) + } + return dml.buildTiDBWhere(builder) +} + +func (dml *DML) buildTiDBWhere(builder *strings.Builder) (args []interface{}) { wnames, wargs := dml.whereSlice() - for i, pOracleHolderPos := 0, oracleHolderPos; i < len(wnames); i++ { + for i := 0; i < len(wnames); i++ { if i > 0 { builder.WriteString(" AND ") } if wargs[i] == nil { - if dml.DestDBType == OracleDB { - builder.WriteString(escapeName(wnames[i]) + " IS NULL") - } else { - builder.WriteString(quoteName(wnames[i]) + " IS NULL") - } + builder.WriteString(quoteName(wnames[i]) + " IS NULL") } else { - if dml.DestDBType == OracleDB { - builder.WriteString(fmt.Sprintf("%s = :%d", escapeName(wnames[i]), pOracleHolderPos)) - pOracleHolderPos++ - } else { - builder.WriteString(quoteName(wnames[i]) + " = ?") - } + builder.WriteString(quoteName(wnames[i]) + " = ?") + args = append(args, wargs[i]) + } + } + return +} + +func (dml *DML) buildOracleWhere(builder *strings.Builder, oracleHolderPos int) (args []interface{}) { + wnames, wargs := dml.whereSlice() + pOracleHolderPos := oracleHolderPos + for i := 0; i < len(wnames); i++ { + if i > 0 { + builder.WriteString(" AND ") + } + if wargs[i] == nil { + builder.WriteString(escapeName(wnames[i]) + " IS NULL") + } else { + builder.WriteString(fmt.Sprintf("%s = :%d", escapeName(wnames[i]), pOracleHolderPos)) + pOracleHolderPos++ args = append(args, wargs[i]) } } @@ -275,16 +313,31 @@ func (dml *DML) whereSlice() (colNames []string, args []interface{}) { } func (dml *DML) deleteSQL() (sql string, args []interface{}) { + if dml.DestDBType == OracleDB { + return dml.deleteOracleSQL() + } + return dml.deleteTiDBSQL() +} + +func (dml *DML) deleteTiDBSQL() (sql string, args []interface{}) { builder := new(strings.Builder) fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.TableName()) - args = dml.buildWhere(builder, 1) + args = dml.buildTiDBWhere(builder) - if dml.DestDBType == OracleDB { - builder.WriteString(" AND rownum <=1") - } else { - builder.WriteString(" LIMIT 1") - } + builder.WriteString(" LIMIT 1") + + sql = builder.String() + return +} + +func (dml *DML) deleteOracleSQL() (sql string, args []interface{}) { + builder := new(strings.Builder) + + fmt.Fprintf(builder, "DELETE FROM %s WHERE ", dml.TableName()) + args = dml.buildOracleWhere(builder, 1) + + builder.WriteString(" AND rownum <=1") sql = builder.String() return @@ -322,8 +375,8 @@ func (dml *DML) oracleDeleteNewValueSQL() (sql string, args []interface{}) { colValues = append(colValues, valueMap[col]) } } - - for i, oracleHolderPos := 0, 1; i < len(colNames); i++ { + oracleHolderPos := 1 + for i := 0; i < len(colNames); i++ { if i > 0 { builder.WriteString(" AND ") } diff --git a/pkg/loader/model_test.go b/pkg/loader/model_test.go index 3ee01aa6c..100b3d613 100644 --- a/pkg/loader/model_test.go +++ b/pkg/loader/model_test.go @@ -47,7 +47,7 @@ func getDML(key bool, tp DMLType) *DML { dml.Database = "test" dml.Table = "test" dml.Tp = tp - dml.DestDBType = MysqlDB + dml.DestDBType = TiDB return dml } @@ -189,7 +189,7 @@ func (s *SQLSuite) TestInsertSQL(c *check.C) { info: &tableInfo{ columns: []string{"name", "age"}, }, - DestDBType: MysqlDB, + DestDBType: TiDB, } sql, args := dml.sql() c.Assert(sql, check.Equals, "INSERT INTO `test`.`hello`(`age`,`name`) VALUES(?,?)") @@ -210,7 +210,7 @@ func (s *SQLSuite) TestDeleteSQL(c *check.C) { info: &tableInfo{ columns: []string{"name", "age"}, }, - DestDBType: MysqlDB, + DestDBType: TiDB, } sql, args := dml.sql() c.Assert( @@ -235,7 +235,7 @@ func (s *SQLSuite) TestUpdateSQL(c *check.C) { info: &tableInfo{ columns: []string{"name"}, }, - DestDBType: MysqlDB, + DestDBType: TiDB, } sql, args := dml.sql() c.Assert( diff --git a/pkg/loader/util.go b/pkg/loader/util.go index c920f1268..b2b1c5a4f 100644 --- a/pkg/loader/util.go +++ b/pkg/loader/util.go @@ -252,16 +252,30 @@ func escapeName(name string) string { } func holderString(n int, destDBType DBType) string { + if destDBType == OracleDB { + return holderStringOracle(n) + } + return holderStringTiDB(n) +} + +func holderStringTiDB(n int) string { builder := new(strings.Builder) for i := 0; i < n; i++ { if i > 0 { builder.WriteString(",") } - if destDBType == OracleDB { - builder.WriteString(":" + strconv.Itoa(i+1)) - } else { - builder.WriteString("?") + builder.WriteString("?") + } + return builder.String() +} + +func holderStringOracle(n int) string { + builder := new(strings.Builder) + for i := 0; i < n; i++ { + if i > 0 { + builder.WriteString(",") } + builder.WriteString(":" + strconv.Itoa(i+1)) } return builder.String() }