From c30e84ef0d584210f4fa9571db60dce5e6fff212 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 18 Jun 2026 13:23:55 +0800 Subject: [PATCH] [SPARK-57520][SQL] Fix UTF8String.codePointFrom and copyUTF8String reading past the end of a truncated trailing UTF-8 sequence `codePointFrom` read the declared number of continuation bytes (from `numBytesForFirstByte`) without checking they exist, and `copyUTF8String` copied `end - start + 1` bytes without clamping to what remains. When a string ends in a truncated multi-byte sequence (a leader byte whose width exceeds the remaining bytes), both read past the end of the backing memory. `trimLeft`/`trimRight` build their search character through `copyUTF8String`, so they over-read too. `codePointFrom` now reads continuation bytes through a helper that returns 0 past the end, and `copyUTF8String` clamps the copy length to the bytes that remain. Once `copyUTF8String` stops over-reading, `trimRight` needs a matching accounting fix: it decremented `trimEnd` by the declared character width, which overshoots for a truncated trailing character, so it now uses the actual (clamped) byte count, as `trimLeft` already does. Well-formed UTF-8 is unaffected. Follow-up of SPARK-57507. --- .../apache/spark/unsafe/types/UTF8String.java | 32 ++++++++--- .../spark/unsafe/types/UTF8StringSuite.java | 55 +++++++++++++++++++ 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index ab57e308ee4d7..13375aa00462a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -723,7 +723,8 @@ public int getChar(int charIndex) { /** * Returns the code point starting from the byte at position `byteIndex`. - * If byte index is invalid, throws exception. 
+ * If byte index is invalid, throws exception. If the sequence is truncated (the leader byte + * declares more bytes than remain), the missing continuation bytes are treated as 0. */ public int codePointFrom(int byteIndex) { Objects.checkIndex(byteIndex, numBytes); @@ -733,18 +734,28 @@ public int codePointFrom(int byteIndex) { case 1 -> b & 0x7F; case 2 -> - ((b & 0x1F) << 6) | (getByte(byteIndex + 1) & 0x3F); + ((b & 0x1F) << 6) | continuationByte(byteIndex + 1); case 3 -> - ((b & 0x0F) << 12) | ((getByte(byteIndex + 1) & 0x3F) << 6) | - (getByte(byteIndex + 2) & 0x3F); + ((b & 0x0F) << 12) | (continuationByte(byteIndex + 1) << 6) | + continuationByte(byteIndex + 2); case 4 -> - ((b & 0x07) << 18) | ((getByte(byteIndex + 1) & 0x3F) << 12) | - ((getByte(byteIndex + 2) & 0x3F) << 6) | (getByte(byteIndex + 3) & 0x3F); + ((b & 0x07) << 18) | (continuationByte(byteIndex + 1) << 12) | + (continuationByte(byteIndex + 2) << 6) | continuationByte(byteIndex + 3); default -> throw new IllegalStateException("Error in UTF-8 code point"); }; } + /** + * Returns the low 6 bits of the UTF-8 continuation byte at `byteIndex`, or 0 when `byteIndex` + * is past the end of the string. The bounds check stops a truncated trailing multi-byte + * sequence (a leader byte whose declared width exceeds the bytes that remain) from reading + * past the end of the backing memory. + */ + private int continuationByte(int byteIndex) { + return byteIndex < numBytes ? getByte(byteIndex) & 0x3F : 0; + } + public boolean matchAt(final UTF8String s, int pos) { if (s.numBytes + pos > numBytes || pos < 0) { return false; @@ -941,7 +952,10 @@ public int findInSet(UTF8String match) { * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. 
*/ public UTF8String copyUTF8String(int start, int end) { - int len = end - start + 1; + // Clamp to the bytes that actually remain so an out-of-range `end` (for example, derived + // from a truncated trailing multi-byte sequence) can't copy past the end of the backing + // memory. + int len = Math.min(end - start + 1, numBytes - start); byte[] newBytes = new byte[len]; copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); return UTF8String.fromBytes(newBytes); @@ -1134,7 +1148,9 @@ public UTF8String trimRight(UTF8String trimString) { stringCharPos[numChars - 1], stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); if (trimString.find(searchChar, 0) >= 0) { - trimEnd -= stringCharLen[numChars - 1]; + // Step back by the bytes the character actually occupies. A truncated trailing character + // is shorter than the width its leader byte declares, so use the (clamped) search char. + trimEnd -= searchChar.numBytes; } else { break; } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 26b96155377e8..9b368e33a391a 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -1204,6 +1204,61 @@ public void testCodePointFrom() { assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(-1)); assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(str.length())); assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(str.length() + 1)); + + // Truncated trailing multi-byte sequence: the leader declares more bytes than remain. + // codePointFrom should decode only the bytes present (missing continuation bytes count as + // 0) and not read past the end. Each backing array has extra trailing bytes, so an + // over-read regression would show up in the result.
+ // 2-byte leader 0xCE with no continuation byte present. + assertEquals((0xCE & 0x1F) << 6, + fromBytes(new byte[] {(byte) 0xCE, 0x42}, 0, 1).codePointFrom(0)); + // 3-byte leader 0xE4 with no continuation bytes present. + assertEquals((0xE4 & 0x0F) << 12, + fromBytes(new byte[] {(byte) 0xE4, 0x42, 0x43}, 0, 1).codePointFrom(0)); + // 3-byte sequence 0xE4 0xB8 with the final continuation byte missing. + assertEquals(((0xE4 & 0x0F) << 12) | ((0xB8 & 0x3F) << 6), + fromBytes(new byte[] {(byte) 0xE4, (byte) 0xB8, 0x42}, 0, 2).codePointFrom(0)); + // 4-byte leader 0xF1 with no continuation bytes present. + assertEquals((0xF1 & 0x07) << 18, + fromBytes(new byte[] {(byte) 0xF1, 0x42, 0x43, 0x44}, 0, 1).codePointFrom(0)); + // 4-byte leader 0xF1 with two continuation bytes present and only the last one missing, + // so just the final read crosses the end. + assertEquals(((0xF1 & 0x07) << 18) | ((0x9F & 0x3F) << 12) | ((0x8F & 0x3F) << 6), + fromBytes(new byte[] {(byte) 0xF1, (byte) 0x9F, (byte) 0x8F, 0x42}, 0, 3).codePointFrom(0)); + } + + @Test + public void copyUTF8StringClampsToRemainingBytes() { + // Here `end` runs one byte past the string, as it would for a truncated trailing sequence. + // copyUTF8String should clamp to the available bytes; the extra backing byte would show up + // in the result if it over-read. + byte[] backing = new byte[] {0x41, 0x42, 0x43}; + UTF8String s = fromBytes(backing, 0, 2); // views "AB" + // `end` (2) is one past the last valid byte index (1); only the two real bytes are copied. + assertEquals(fromString("AB"), s.copyUTF8String(0, 2)); + // Same with a non-zero start, so the clamp uses `numBytes - start`, not `numBytes`. + assertEquals(fromString("B"), s.copyUTF8String(1, 2)); + // In-bounds copies are unaffected.
+ assertEquals(fromString("AB"), s.copyUTF8String(0, 1)); + assertEquals(fromString("B"), s.copyUTF8String(1, 1)); + } + + @Test + public void trimTruncatedTrailingSequence() { + // trimLeft/trimRight build the search character through copyUTF8String, so an over-read would + // make it longer than the bytes that remain. The backing arrays carry an extra trailing byte + // to make any over-read deterministic. + // A lone truncated 2-byte leader (0xC2): the clamped search char is just the leader, which + // matches the 1-byte trim set and is trimmed away. + UTF8String lone = fromBytes(new byte[] {(byte) 0xC2, 0x42}, 0, 1); + UTF8String trim2 = fromBytes(new byte[] {(byte) 0xC2}); + assertEquals(EMPTY_UTF8, lone.trimLeft(trim2)); + assertEquals(EMPTY_UTF8, lone.trimRight(trim2)); + // 'A' followed by a truncated 3-byte leader (0xE4). Trimming the leader from the right must + // keep 'A': the trailing character occupies only one real byte, so only that byte is removed. + UTF8String prefixed = fromBytes(new byte[] {0x41, (byte) 0xE4, 0x42}, 0, 2); + UTF8String trim3 = fromBytes(new byte[] {(byte) 0xE4}); + assertEquals(fromString("A"), prefixed.trimRight(trim3)); } @Test