Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 76 additions & 22 deletions sync_diff_inspector/splitter/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ import (

// RandomIterator is used to random iterate a table
type RandomIterator struct {
table *common.TableDiff
chunkSize int64
chunks []*chunk.Range
nextChunk uint
table *common.TableDiff
chunkSize int64
chunkBase *chunk.Range
splitColumns []string
randomValues [][]string
chunkCount int
nextChunk uint

dbConn *sql.DB
}
Expand Down Expand Up @@ -128,7 +131,7 @@ NEXTINDEX:
return &RandomIterator{
table: table,
chunkSize: 0,
chunks: nil,
chunkBase: nil,
nextChunk: 0,
dbConn: dbConn,
}, nil
Expand Down Expand Up @@ -174,36 +177,75 @@ NEXTINDEX:
bucketChunkCnt = chunkCnt
}

chunks, err := splitRangeByRandom(ctx, dbConn, chunkRange, chunkCnt, table.Schema, table.Table, fields, table.Range, table.Collation)
randomValues, err := getRandomSplitValues(ctx, dbConn, chunkRange, chunkCnt, table.Schema, table.Table, fields, table.Range, table.Collation)
if err != nil {
return nil, errors.Trace(err)
}
chunk.InitChunks(chunks, chunk.Random, 0, 0, beginIndex, table.Collation, table.Range, bucketChunkCnt)
chunkCount := len(randomValues) + 1
chunkRange.Index = &chunk.CID{
BucketIndexLeft: 0,
BucketIndexRight: 0,
ChunkIndex: beginIndex,
ChunkCnt: bucketChunkCnt,
}
chunkRange.Type = chunk.Random

failpoint.Inject("ignore-last-n-chunk-in-bucket", func(v failpoint.Value) {
log.Info("failpoint ignore-last-n-chunk-in-bucket injected (random splitter)", zap.Int("n", v.(int)))
if len(chunks) <= 1+v.(int) {
if chunkCount <= 1+v.(int) {
failpoint.Return(nil, nil)
}
chunks = chunks[:(len(chunks) - v.(int))]
chunkCount -= v.(int)
})
Comment on lines +184 to 199
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The ChunkCnt in the chunk's CID is not being set to the actual number of chunks. It's using bucketChunkCnt, which is an initial estimate and may be incorrect if getRandomSplitValues returns fewer values than requested, or if the failpoint ignore-last-n-chunk-in-bucket is triggered. This can lead to incorrect behavior when resuming from a checkpoint.

The CID should be initialized after the final chunkCount is determined (i.e., after the failpoint injection). Also, when it's not a resumed run (startRange == nil), bucketChunkCnt should be updated to the actual chunkCount.

 	chunkCount := len(randomValues) + 1
 	failpoint.Inject("ignore-last-n-chunk-in-bucket", func(v failpoint.Value) {
 		log.Info("failpoint ignore-last-n-chunk-in-bucket injected (random splitter)", zap.Int("n", v.(int)))
 		if chunkCount <= 1+v.(int) {
 			failpoint.Return(nil, nil)
 		}
 		chunkCount -= v.(int)
 		if len(randomValues) >= chunkCount {
 			randomValues = randomValues[:chunkCount-1]
 		}
 	})

 	if startRange == nil {
 		bucketChunkCnt = chunkCount
 	}
 	chunkRange.Index = &chunk.CID{
 		BucketIndexLeft:  0,
 		BucketIndexRight: 0,
 		ChunkIndex:       beginIndex,
 		ChunkCnt:         bucketChunkCnt,
 	}
 	chunkRange.Type = chunk.Random


progress.StartTable(progressID, len(chunks), true)
progress.StartTable(progressID, chunkCount, true)
splitColumns := make([]string, 0, len(fields))
for _, field := range fields {
splitColumns = append(splitColumns, field.Name.O)
}
return &RandomIterator{
table: table,
chunkSize: chunkSize,
chunks: chunks,
nextChunk: 0,
dbConn: dbConn,
table: table,
chunkSize: chunkSize,
chunkBase: chunkRange,
splitColumns: splitColumns,
randomValues: randomValues,
chunkCount: chunkCount,
nextChunk: 0,
dbConn: dbConn,
}, nil
}

func (s *RandomIterator) buildChunk(chunkIndex int) *chunk.Range {
c := s.chunkBase.Clone()
if len(s.randomValues) > 0 {
for i, column := range s.splitColumns {
switch {
case chunkIndex == 0:
c.Update(column, "", s.randomValues[0][i], false, true)
case chunkIndex == len(s.randomValues):
c.Update(column, s.randomValues[chunkIndex-1][i], "", true, false)
default:
c.Update(column, s.randomValues[chunkIndex-1][i], s.randomValues[chunkIndex][i], true, true)
}
}
}

conditions, args := c.ToString(s.table.Collation)
c.Where = fmt.Sprintf("((%s) AND (%s))", conditions, s.table.Range)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The s.table.Range parameter is directly concatenated into the SQL WHERE clause. Since this value originates from user configuration, it poses a risk of SQL injection if the configuration is provided by an untrusted source. While this tool is primarily a CLI utility, it is recommended to use parameterized queries or a SQL parser to validate the range expression to prevent arbitrary SQL execution or boundary bypass.

c.Args = args
c.Index.ChunkIndex += chunkIndex
return c
}

// Next get the next chunk
func (s *RandomIterator) Next() (*chunk.Range, error) {
if uint(len(s.chunks)) <= s.nextChunk {
if s.chunkBase == nil {
return nil, nil
}
c := s.chunks[s.nextChunk]
if uint(s.chunkCount) <= s.nextChunk {
return nil, nil
}
c := s.buildChunk(int(s.nextChunk))
s.nextChunk = s.nextChunk + 1
failpoint.Inject("print-chunk-info", func() {
lowerBounds := make([]string, len(c.Bounds))
Expand Down Expand Up @@ -273,14 +315,10 @@ func splitRangeByRandom(ctx context.Context, db *sql.DB, chunk *chunk.Range, cou
return chunks, nil
}

chunkLimits, args := chunk.ToString(collation)
limitRange := fmt.Sprintf("(%s) AND (%s)", chunkLimits, limits)

randomValues, err := utils.GetRandomValues(ctx, db, schema, table, columns, count-1, limitRange, args, collation)
randomValues, err := getRandomSplitValues(ctx, db, chunk, count, schema, table, columns, limits, collation)
if err != nil {
return nil, errors.Trace(err)
}
log.Debug("get split values by random", zap.Stringer("chunk", chunk), zap.Int("random values num", len(randomValues)))
for i := 0; i <= len(randomValues); i++ {
newChunk := chunk.Copy()

Expand All @@ -302,3 +340,19 @@ func splitRangeByRandom(ctx context.Context, db *sql.DB, chunk *chunk.Range, cou
log.Debug("split range by random", zap.Stringer("origin chunk", chunk), zap.Int("split num", len(chunks)))
return chunks, nil
}

func getRandomSplitValues(ctx context.Context, db *sql.DB, chunk *chunk.Range, count int, schema string, table string, columns []*model.ColumnInfo, limits, collation string) ([][]string, error) {
if count <= 1 {
return nil, nil
}

chunkLimits, args := chunk.ToString(collation)
limitRange := fmt.Sprintf("(%s) AND (%s)", chunkLimits, limits)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The limits parameter (which contains the table range) is directly concatenated into the limitRange string, which is then used to build a SQL query in utils.GetRandomValues. This is a potential SQL injection point. It is recommended to use parameterized queries where possible to ensure that user-provided range filters cannot manipulate the intended query structure.


randomValues, err := utils.GetRandomValues(ctx, db, schema, table, columns, count-1, limitRange, args, collation)
if err != nil {
return nil, errors.Trace(err)
}
log.Debug("get split values by random", zap.Stringer("chunk", chunk), zap.Int("random values num", len(randomValues)))
return randomValues, nil
}