diff --git a/sync_diff_inspector/diff.go b/sync_diff_inspector/diff.go index a5b2a685e..28aa4e07a 100644 --- a/sync_diff_inspector/diff.go +++ b/sync_diff_inspector/diff.go @@ -50,6 +50,15 @@ const ( checkpointFile = "sync_diff_checkpoints.pb" ) +func getSplitThreshold() int64 { + failpoint.Inject("binsearchSplitThreshold", func(val failpoint.Value) { + if threshold, ok := val.(int); ok { + failpoint.Return(int64(threshold)) + } + }) + return int64(splitter.SplitThreshold) +} + // ChunkDML SQL struct for each chunk type ChunkDML struct { node *checkpoints.Node @@ -455,7 +464,7 @@ func (df *Diff) consume(ctx context.Context, rangeInfo *splitter.RangeInfo) bool state = checkpoints.FailedState // if the chunk's checksum differ, try to do binary check info := rangeInfo - if upCount > splitter.SplitThreshold { + if upCount > getSplitThreshold() { log.Debug("count greater than threshold, start do bingenerate", zap.Any("chunk id", rangeInfo.ChunkRange.Index), zap.Int64("upstream chunk size", upCount)) info, err = df.BinGenerate(ctx, df.workSource, rangeInfo, upCount) if err != nil { @@ -479,7 +488,7 @@ func (df *Diff) consume(ctx context.Context, rangeInfo *splitter.RangeInfo) bool } func (df *Diff) BinGenerate(ctx context.Context, targetSource source.Source, tableRange *splitter.RangeInfo, count int64) (*splitter.RangeInfo, error) { - if count <= splitter.SplitThreshold { + if count <= getSplitThreshold() { return tableRange, nil } tableDiff := targetSource.GetTables()[tableRange.GetTableIndex()] @@ -521,7 +530,7 @@ func (df *Diff) BinGenerate(ctx context.Context, targetSource source.Source, tab } func (df *Diff) binSearch(ctx context.Context, targetSource source.Source, tableRange *splitter.RangeInfo, count int64, tableDiff *common.TableDiff, indexColumns []*model.ColumnInfo) (*splitter.RangeInfo, error) { - if count <= splitter.SplitThreshold { + if count <= getSplitThreshold() { return tableRange, nil } var ( @@ -533,7 +542,8 @@ func (df *Diff) binSearch(ctx context.Context, targetSource source.Source, table chunkLimits, args := tableRange.ChunkRange.ToString(tableDiff.Collation) limitRange := fmt.Sprintf("(%s) AND (%s)", chunkLimits, tableDiff.Range) - midValues, err := utils.GetApproximateMidBySize(ctx, targetSource.GetDB(), tableDiff.Schema, tableDiff.Table, indexColumns, limitRange, args, count) + sourceSchema, sourceTable := targetSource.GetSourceTable(tableRange) + midValues, err := utils.GetApproximateMidBySize(ctx, targetSource.GetDB(), sourceSchema, sourceTable, indexColumns, limitRange, args, count) if err != nil { return nil, errors.Trace(err) } diff --git a/sync_diff_inspector/source/mysql_shard.go b/sync_diff_inspector/source/mysql_shard.go index 4ae410d83..4996c8fcd 100644 --- a/sync_diff_inspector/source/mysql_shard.go +++ b/sync_diff_inspector/source/mysql_shard.go @@ -156,6 +156,15 @@ func (s *MySQLSources) GetTables() []*common.TableDiff { return s.tableDiffs } +// GetSourceTable returns the physical source table mapped from tableDiff. +func (s *MySQLSources) GetSourceTable(tableRange *splitter.RangeInfo) (schema string, table string) { + tableDiff := s.GetTables()[tableRange.GetTableIndex()] + if matchedSources, ok := s.sourceTablesMap[utils.UniqueID(tableDiff.Schema, tableDiff.Table)]; ok && len(matchedSources) > 0 { + return matchedSources[0].OriginSchema, matchedSources[0].OriginTable + } + return tableDiff.Schema, tableDiff.Table +} + func (s *MySQLSources) GenerateFixSQL(t DMLType, upstreamData, downstreamData map[string]*dbutil.ColumnData, tableIndex int) string { switch t { case Insert: diff --git a/sync_diff_inspector/source/source.go b/sync_diff_inspector/source/source.go index 9fd7106cc..4ddfbcc2a 100644 --- a/sync_diff_inspector/source/source.go +++ b/sync_diff_inspector/source/source.go @@ -97,6 +97,10 @@ type Source interface { // GetTables represents the tableDiffs. GetTables() []*common.TableDiff + // GetSourceTable returns the physical source table name mapped from tableDiff. + // If no route is configured, it returns the same schema/table as tableDiff. + GetSourceTable(*splitter.RangeInfo) (schema string, table string) + // GetSourceStructInfo get the source table info from a given target table GetSourceStructInfo(context.Context, int) ([]*model.TableInfo, error) diff --git a/sync_diff_inspector/source/source_test.go b/sync_diff_inspector/source/source_test.go index 5dc614652..72d53bbb7 100644 --- a/sync_diff_inspector/source/source_test.go +++ b/sync_diff_inspector/source/source_test.go @@ -953,3 +953,77 @@ func TestCheckTableMatched(t *testing.T) { require.Equal(t, 1, tables[1].TableLack) require.Equal(t, -1, tables[2].TableLack) } + +func TestTiDBSourceGetSourceTable(t *testing.T) { + tableDiff := &common.TableDiff{ + Schema: "target_schema", + Table: "target_table", + } + sourceTableMap := map[string]*common.TableSource{ + utils.UniqueID("target_schema", "target_table"): { + OriginSchema: "source_schema", + OriginTable: "source_table", + }, + } + tidb := &TiDBSource{ + tableDiffs: []*common.TableDiff{tableDiff}, + sourceTableMap: sourceTableMap, + } + tableRange := &splitter.RangeInfo{ + ChunkRange: chunk.NewChunkRange(nil), + } + tableRange.ChunkRange.Index.TableIndex = 0 + + schema, table := tidb.GetSourceTable(tableRange) + require.Equal(t, "source_schema", schema) + require.Equal(t, "source_table", table) + + tidb.sourceTableMap = map[string]*common.TableSource{ + utils.UniqueID("other_schema", "other_table"): { + OriginSchema: "other_schema", + OriginTable: "other_table", + }, + } + schema, table = tidb.GetSourceTable(tableRange) + require.Equal(t, "target_schema", schema) + require.Equal(t, "target_table", table) +} + +func TestMySQLSourcesGetSourceTable(t *testing.T) { + tableDiff := &common.TableDiff{ + Schema: "target_schema", + Table: "target_table", + } + mysql := &MySQLSources{ + tableDiffs: []*common.TableDiff{tableDiff}, + sourceTablesMap: map[string][]*common.TableShardSource{ + utils.UniqueID("target_schema", "target_table"): { + { + TableSource: common.TableSource{ + OriginSchema: "source_schema_1", + OriginTable: "source_table_1", + }, + }, + { + TableSource: common.TableSource{ + OriginSchema: "source_schema_2", + OriginTable: "source_table_2", + }, + }, + }, + }, + } + tableRange := &splitter.RangeInfo{ + ChunkRange: chunk.NewChunkRange(nil), + } + tableRange.ChunkRange.Index.TableIndex = 0 + + schema, table := mysql.GetSourceTable(tableRange) + require.Equal(t, "source_schema_1", schema) + require.Equal(t, "source_table_1", table) + + mysql.sourceTablesMap = map[string][]*common.TableShardSource{} + schema, table = mysql.GetSourceTable(tableRange) + require.Equal(t, "target_schema", schema) + require.Equal(t, "target_table", table) +} diff --git a/sync_diff_inspector/source/tidb.go b/sync_diff_inspector/source/tidb.go index 6a091862e..81c76f3b0 100644 --- a/sync_diff_inspector/source/tidb.go +++ b/sync_diff_inspector/source/tidb.go @@ -126,7 +126,7 @@ func (s *TiDBSource) GetCountAndMd5(ctx context.Context, tableRange *splitter.Ra table := s.tableDiffs[tableRange.GetTableIndex()] chunk := tableRange.GetChunk() - matchSource := getMatchSource(s.sourceTableMap, table) + sourceSchema, sourceTable := s.GetSourceTable(tableRange) indexHint := "" if s.sqlHint == "auto" && len(chunk.IndexColumnNames) > 0 { // If sqlHint is set to "auto" and there are index column names in the chunk, @@ -140,7 +140,7 @@ func (s *TiDBSource) GetCountAndMd5(ctx context.Context, tableRange *splitter.Ra for _, index := range dbutil.FindAllIndex(tableInfos[0]) { if utils.IsIndexMatchingColumns(index, chunk.IndexColumnNames) { indexHint = fmt.Sprintf("/*+ USE_INDEX(%s, %s) */", - dbutil.TableName(matchSource.OriginSchema, matchSource.OriginTable), + dbutil.TableName(sourceSchema, sourceTable), dbutil.ColumnName(index.Name.O), ) break @@ -150,7 +150,7 @@ func (s *TiDBSource) GetCountAndMd5(ctx context.Context, tableRange *splitter.Ra } count, checksum, err := utils.GetCountAndMd5Checksum( - ctx, s.dbConn, matchSource.OriginSchema, matchSource.OriginTable, table.Info, + ctx, s.dbConn, sourceSchema, sourceTable, table.Info, chunk.Where, indexHint, chunk.Args) cost := time.Since(beginTime) @@ -176,6 +176,16 @@ func (s *TiDBSource) GetTables() []*common.TableDiff { return s.tableDiffs } +// GetSourceTable returns the physical source table mapped from tableDiff. +func (s *TiDBSource) GetSourceTable(tableRange *splitter.RangeInfo) (schema string, table string) { + tableDiff := s.GetTables()[tableRange.GetTableIndex()] + matchSource := getMatchSource(s.sourceTableMap, tableDiff) + if matchSource == nil { + return tableDiff.Schema, tableDiff.Table + } + return matchSource.OriginSchema, matchSource.OriginTable +} + func (s *TiDBSource) GetSourceStructInfo(ctx context.Context, tableIndex int) ([]*model.TableInfo, error) { var err error tableInfos := make([]*model.TableInfo, 1) @@ -207,8 +217,8 @@ func (s *TiDBSource) GetRowsIterator(ctx context.Context, tableRange *splitter.R chunk := tableRange.GetChunk() table := s.tableDiffs[tableRange.GetTableIndex()] - matchedSource := getMatchSource(s.sourceTableMap, table) - rowsQuery, _ := utils.GetTableRowsQueryFormat(matchedSource.OriginSchema, matchedSource.OriginTable, table.Info, table.Collation) + sourceSchema, sourceTable := s.GetSourceTable(tableRange) + rowsQuery, _ := utils.GetTableRowsQueryFormat(sourceSchema, sourceTable, table.Info, table.Collation) query := fmt.Sprintf(rowsQuery, chunk.Where) log.Debug("select data", zap.String("sql", query), zap.Reflect("args", chunk.Args)) diff --git a/tests/sync_diff_inspector/router/config_base.toml b/tests/sync_diff_inspector/router/config_base.toml new file mode 100644 index 000000000..569d2a208 --- /dev/null +++ b/tests/sync_diff_inspector/router/config_base.toml @@ -0,0 +1,39 @@ +# Diff Configuration. + +######################### Global config ######################### + +# how many goroutines are created to check data +check-thread-count = 4 + +# set false if just want compare data by checksum, will skip select data when checksum is not equal. +# set true if want compare all different rows, will slow down the total compare time. +export-fix-sql = true + +# ignore check table's data +check-struct-only = false + +######################### Databases config ######################### +[data-sources] +[data-sources.tidb_up] + host = "127.0.0.1" + port = 4001 + user = "root" + password = "" + route-rules = ["rule1"] + +[data-sources.mysql_down] + host = "127.0.0.1"#MYSQL_HOST + port = 3306#MYSQL_PORT + user = "root" + password = "" + +[routes.rule1] +schema-pattern = "route_up_test" +target-schema = "route_down_test" + +######################### Task config ######################### +[task] + output-dir = "/tmp/tidb_tools_test/sync_diff_inspector/output" + source-instances = ["tidb_up"] + target-instance = "mysql_down" + target-check-tables = ["route_down_test.t_route"] diff --git a/tests/sync_diff_inspector/router/run.sh b/tests/sync_diff_inspector/router/run.sh new file mode 100755 index 000000000..be0d3bbb5 --- /dev/null +++ b/tests/sync_diff_inspector/router/run.sh @@ -0,0 +1,37 @@ +#!/bin/sh + +set -ex + +cd "$(dirname "$0")" + +OUT_DIR=/tmp/tidb_tools_test/sync_diff_inspector/output +rm -rf $OUT_DIR +mkdir -p $OUT_DIR + +# prepare upstream TiDB table and downstream MySQL table with route mapping: +# route_up_test.t_route -> route_down_test.t_route +mysql -uroot -h 127.0.0.1 -P 4001 -e "drop database if exists route_up_test;" +mysql -uroot -h 127.0.0.1 -P 4001 -e "drop database if exists route_down_test;" +mysql -uroot -h 127.0.0.1 -P 4001 -e "create database route_up_test;" +mysql -uroot -h 127.0.0.1 -P 4001 -e "create table route_up_test.t_route(id int primary key, val varchar(20));" +mysql -uroot -h 127.0.0.1 -P 4001 -e "insert into route_up_test.t_route values (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'), (5, 'e'), (6, 'f'), (7, 'g'), (8, 'h'), (9, 'i'), (10, 'j');" + +mysql -uroot -h ${MYSQL_HOST} -P ${MYSQL_PORT} -e "drop database if exists route_down_test;" +mysql -uroot -h ${MYSQL_HOST} -P ${MYSQL_PORT} -e "create database route_down_test;" +mysql -uroot -h ${MYSQL_HOST} -P ${MYSQL_PORT} -e "create table route_down_test.t_route(id int primary key, val varchar(20));" +mysql -uroot -h ${MYSQL_HOST} -P ${MYSQL_PORT} -e "insert into route_down_test.t_route values (1, 'a'), (2, 'b'), (3, 'c'), (4, 'x'), (5, 'e'), (6, 'f'), (7, 'g'), (8, 'h'), (9, 'i'), (10, 'j');" + +sed "s/\"127.0.0.1\"#MYSQL_HOST/\"${MYSQL_HOST}\"/g" ./config_base.toml | sed "s/3306#MYSQL_PORT/${MYSQL_PORT}/g" >./config.toml + +export GO_FAILPOINTS="main/binsearchSplitThreshold=return(4)" +sync_diff_inspector --config=./config.toml -L debug >$OUT_DIR/router.output || true +export GO_FAILPOINTS="" + +check_contains "check failed!!!" $OUT_DIR/sync_diff.log +check_contains "A total of 1 tables have been compared, 0 tables finished, 1 tables failed, 0 tables skipped." $OUT_DIR/router.output +check_contains "+1/-1" $OUT_DIR/summary.txt + +check_contains "get mid by size" $OUT_DIR/sync_diff.log +grep "get mid by size" $OUT_DIR/sync_diff.log >$OUT_DIR/router_mid.log +check_contains "FROM \`route_up_test\`.\`t_route\`" $OUT_DIR/router_mid.log +check_not_contains "FROM \`route_down_test\`.\`t_route\`" $OUT_DIR/router_mid.log