Skip to content

Commit

Permalink
[SPARK-48715][SQL] Integrate UTF8String validation into collation-awa…
Browse files Browse the repository at this point in the history
…re string function implementations

### What changes were proposed in this pull request?
Use our own invalid UTF-8 byte sequence replacement logic in UTF8String, before all `.toString()` method calls.

### Why are the changes needed?
Avoid relying on Java to perform invalid UTF-8 byte sequence replacement, and ensure consistent results.

### Does this PR introduce _any_ user-facing change?
Yes, collation aware string function implementations will now rely on our own invalid UTF-8 string replacement implementation, instead of Java's.

### How was this patch tested?
Existing tests, with some changes in `UTF8StringSuite` and `CollationSupportSuite`.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#47131 from uros-db/make-valid.

Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
uros-db authored and cloud-fan committed Jul 4, 2024
1 parent f438674 commit bf25f0a
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,7 @@ public static UTF8String toUpperCase(final UTF8String target) {
private static UTF8String toUpperCaseSlow(final UTF8String target) {
// Note: In order to achieve the desired behavior, we use the ICU UCharacter class to
// convert the string to uppercase, which only accepts a Java strings as input.
// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
return UTF8String.fromString(UCharacter.toUpperCase(target.toString()));
return UTF8String.fromString(UCharacter.toUpperCase(target.toValidString()));
}

/**
Expand All @@ -377,8 +376,7 @@ private static UTF8String toUpperCaseSlow(final UTF8String target, final int col
// convert the string to uppercase, which only accepts a Java strings as input.
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
return UTF8String.fromString(UCharacter.toUpperCase(locale, target.toString()));
return UTF8String.fromString(UCharacter.toUpperCase(locale, target.toValidString()));
}

/**
Expand All @@ -395,8 +393,7 @@ public static UTF8String toLowerCase(final UTF8String target) {
private static UTF8String toLowerCaseSlow(final UTF8String target) {
// Note: In order to achieve the desired behavior, we use the ICU UCharacter class to
// convert the string to lowercase, which only accepts a Java strings as input.
// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
return UTF8String.fromString(UCharacter.toLowerCase(target.toString()));
return UTF8String.fromString(UCharacter.toLowerCase(target.toValidString()));
}

/**
Expand All @@ -415,8 +412,7 @@ private static UTF8String toLowerCaseSlow(final UTF8String target, final int col
// convert the string to lowercase, which only accepts a Java strings as input.
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
return UTF8String.fromString(UCharacter.toLowerCase(locale, target.toString()));
return UTF8String.fromString(UCharacter.toLowerCase(locale, target.toValidString()));
}

/**
Expand Down Expand Up @@ -459,8 +455,7 @@ public static UTF8String lowerCaseCodePoints(final UTF8String target) {
}

private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) {
// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
String targetString = target.toString();
String targetString = target.toValidString();
StringBuilder sb = new StringBuilder();
for (int i = 0; i < targetString.length(); ++i) {
lowercaseCodePoint(targetString.codePointAt(i), sb);
Expand All @@ -474,8 +469,7 @@ private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) {
public static UTF8String toTitleCase(final UTF8String target) {
// Note: In order to achieve the desired behavior, we use the ICU UCharacter class to
// convert the string to titlecase, which only accepts a Java strings as input.
// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
return UTF8String.fromString(UCharacter.toTitleCase(target.toString(),
return UTF8String.fromString(UCharacter.toTitleCase(target.toValidString(),
BreakIterator.getWordInstance()));
}

Expand All @@ -485,8 +479,7 @@ public static UTF8String toTitleCase(final UTF8String target) {
public static UTF8String toTitleCase(final UTF8String target, final int collationId) {
ULocale locale = CollationFactory.fetchCollation(collationId)
.collator.getLocale(ULocale.ACTUAL_LOCALE);
// TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid`
return UTF8String.fromString(UCharacter.toTitleCase(locale, target.toString(),
return UTF8String.fromString(UCharacter.toTitleCase(locale, target.toValidString(),
BreakIterator.getWordInstance(locale)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.function.Function;
import java.util.Iterator;
Expand Down Expand Up @@ -62,6 +61,28 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
private int numBytes;
private volatile int numChars = -1;

/**
* The validity of the UTF8Strings can be cached to avoid repeated validation checks, because
* that operation requires full string scan. Valid strings have no illegal UTF-8 byte sequences.
*/
private enum UTF8StringValidity {
UNKNOWN, IS_VALID, NOT_VALID
}

/**
* Internal flag to indicate whether the UTF-8 string is valid or not. Initially, the validity
* is UNKNOWN, and will be set to either IS_VALID or NOT_VALID after the first validation check.
*/
private volatile UTF8StringValidity isValid = UTF8StringValidity.UNKNOWN;

/**
* In case the current UTF-8 string is not valid, the number of bytes of the validated version
* of the current string (after possible replacement) will be stored in this field. This value
* will be equal to `numBytes` if the current string is valid. However, note that this doesn't
* GUARANTEE that the string is valid - only the `isValid` field can provide that information.
*/
private volatile int numBytesValid = -1;

public Object getBaseObject() { return base; }
public long getBaseOffset() { return offset; }

Expand Down Expand Up @@ -121,14 +142,6 @@ public static UTF8String fromBytes(byte[] bytes) {
}
}

private static UTF8String fromBytes(ArrayList<Byte> bytes) {
byte[] byteArray = new byte[bytes.size()];
for (int i = 0; i < bytes.size(); i++) {
byteArray[i] = bytes.get(i);
}
return fromBytes(byteArray);
}

/**
* Creates an UTF8String from byte array, which should be encoded in UTF-8.
*
Expand Down Expand Up @@ -313,11 +326,19 @@ private static boolean isValidSecondByte(byte b, byte firstByte) {
};
}

/**
* The Unicode replacement character (U+FFFD) is used to replace invalid code points.
*/
private static final byte[] UNICODE_REPLACEMENT_CHARACTER =
new byte[] { (byte) 0xEF, (byte) 0xBF, (byte) 0xBD };

private static void appendReplacementCharacter(ArrayList<Byte> bytes) {
for (byte b : UTF8String.UNICODE_REPLACEMENT_CHARACTER) bytes.add(b);
/**
* Private helper method to insert the Unicode replacement character (U+FFFD) to a byte array.
*/
private static void insertReplacementCharacter(byte[] bytes, int byteIndex) {
for (byte b : UTF8String.UNICODE_REPLACEMENT_CHARACTER) {
bytes[byteIndex++] = b;
}
}

/**
Expand All @@ -329,30 +350,48 @@ private static void appendReplacementCharacter(ArrayList<Byte> bytes) {
* @return A new UTF8String that is a valid UTF8 string.
*/
public UTF8String makeValid() {
ArrayList<Byte> bytes = new ArrayList<>();
int byteIndex = 0;
if (isValid()) return this;
return UTF8String.fromBytes(makeValidBytes());
}

/**
* Private helper method to create a valid UTF-8 byte sequence from the current UTF8String.
* In order to use this method, the number of bytes of the validated version of the current
* string (after possible replacement) must be evaluated first by calling `getIsValid`.
*/
private byte[] makeValidBytes() {
assert(numBytesValid > 0);
byte[] bytes = new byte[numBytesValid];
int byteIndex = 0, byteIndexValid = 0;
while (byteIndex < numBytes) {
// Read the first byte.
byte firstByte = getByte(byteIndex);
int expectedLen = bytesOfCodePointInUTF8[firstByte & 0xFF];
int codePointLen = Math.min(expectedLen, numBytes - byteIndex);
// 0B UTF-8 sequence (invalid first byte).
if (codePointLen == 0) {
appendReplacementCharacter(bytes);
insertReplacementCharacter(bytes, byteIndexValid);
byteIndexValid += UNICODE_REPLACEMENT_CHARACTER.length;
++byteIndex;
continue;
}
// 1B UTF-8 sequence (ASCII or truncated).
if (codePointLen == 1) {
if (firstByte >= 0) bytes.add(firstByte);
else appendReplacementCharacter(bytes);
if (firstByte >= 0) {
bytes[byteIndexValid++] = firstByte;
}
else {
insertReplacementCharacter(bytes, byteIndexValid);
byteIndexValid += UNICODE_REPLACEMENT_CHARACTER.length;
}
++byteIndex;
continue;
}
// Read the second byte.
byte secondByte = getByte(byteIndex + 1);
if (!isValidSecondByte(secondByte, firstByte)) {
appendReplacementCharacter(bytes);
insertReplacementCharacter(bytes, byteIndexValid);
byteIndexValid += UNICODE_REPLACEMENT_CHARACTER.length;
++byteIndex;
continue;
}
Expand All @@ -366,17 +405,18 @@ public UTF8String makeValid() {
}
// Invalid UTF-8 sequence (not enough continuation bytes).
if (continuationBytes < expectedLen) {
appendReplacementCharacter(bytes);
insertReplacementCharacter(bytes, byteIndexValid);
byteIndexValid += UNICODE_REPLACEMENT_CHARACTER.length;
byteIndex += continuationBytes;
continue;
}
// Valid UTF-8 sequence.
for (int i = 0; i < codePointLen; ++i) {
bytes.add(getByte(byteIndex + i));
bytes[byteIndexValid++] = getByte(byteIndex + i);
}
byteIndex += codePointLen;
}
return UTF8String.fromBytes(bytes);
return bytes;
}

/**
Expand All @@ -385,37 +425,89 @@ public UTF8String makeValid() {
* @return If string represents a valid UTF8 string.
*/
public boolean isValid() {
int byteIndex = 0;
if (isValid == UTF8StringValidity.UNKNOWN) {
isValid = getIsValid();
}
return isValid == UTF8StringValidity.IS_VALID;
}

/**
* Private helper method to calculate whether the current UTF-8 string is valid. Checking
* all code points is a linear time operation, as we need to scan the entire UTF-8 string.
* Hence, this method should generally only be called only once during UTF8String lifetime.
* Unlike `getNumBytesValid`, this method performs early exit as soon as an invalid byte
* sequence is found, and returns a boolean indicating the validity of the current string.
*/
private UTF8StringValidity getIsValid() {
boolean isValid = true;
int byteIndex = 0, byteCount = 0;
while (byteIndex < numBytes) {
// Read the first byte.
byte firstByte = getByte(byteIndex);
int expectedLen = bytesOfCodePointInUTF8[firstByte & 0xFF];
int codePointLen = Math.min(expectedLen, numBytes - byteIndex);
// 0B UTF-8 sequence (invalid first byte).
if (codePointLen == 0) return false;
if (codePointLen == 0) {
byteCount += UNICODE_REPLACEMENT_CHARACTER.length;
isValid = false;
++byteIndex;
continue;
}
// 1B UTF-8 sequence (ASCII or truncated).
if (codePointLen == 1) {
if (firstByte >= 0) {
++byteIndex;
continue;
++byteCount;
}
else {
byteCount += UNICODE_REPLACEMENT_CHARACTER.length;
isValid = false;
}
else return false;
++byteIndex;
continue;
}
// Read the second byte.
byte secondByte = getByte(byteIndex + 1);
if (!isValidSecondByte(secondByte, firstByte)) return false;
if (!isValidSecondByte(secondByte, firstByte)) {
byteCount += UNICODE_REPLACEMENT_CHARACTER.length;
isValid = false;
++byteIndex;
continue;
}
// Read remaining continuation bytes.
int continuationBytes = 2;
for (; continuationBytes < codePointLen; ++continuationBytes) {
byte nextByte = getByte(byteIndex + continuationBytes);
if (!isValidContinuationByte(nextByte)) return false;
if (!isValidContinuationByte(nextByte)) {
break;
}
}
// Invalid UTF-8 sequence (not enough continuation bytes).
if (continuationBytes < expectedLen) return false;
if (continuationBytes < expectedLen) {
byteCount += UNICODE_REPLACEMENT_CHARACTER.length;
isValid = false;
byteIndex += continuationBytes;
continue;
}
// Valid UTF-8 sequence.
for (int i = 0; i < codePointLen; ++i) {
++byteCount;
}
byteIndex += codePointLen;
}
return true;
setNumBytesValid(byteCount);
return isValid ? UTF8StringValidity.IS_VALID : UTF8StringValidity.NOT_VALID;
}

/**
* The method sets the total number of bytes of the validated version of the current string
* (after possible replacement), which will be equal to `numBytes` if the UTF8String is valid.
* This method should generally only be called once, from the `getIsValid` method.
*/
private void setNumBytesValid(int byteCount) {
if (byteCount < 0) {
throw new IllegalStateException("Error in UTF-8 byte count");
}
numBytesValid = byteCount;
}

/**
Expand Down Expand Up @@ -445,7 +537,7 @@ public Iterator<Integer> codePointIterator() {
}

public Iterator<Integer> codePointIterator(CodePointIteratorType iteratorMode) {
if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID && !isValid()) {
if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID) {
return makeValid().codePointIterator();
}
return new CodePointIterator();
Expand Down Expand Up @@ -487,7 +579,7 @@ public Iterator<Integer> reverseCodePointIterator() {
}

public Iterator<Integer> reverseCodePointIterator(CodePointIteratorType iteratorMode) {
if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID && !isValid()) {
if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID) {
return makeValid().reverseCodePointIterator();
}
return new ReverseCodePointIterator();
Expand Down Expand Up @@ -1734,11 +1826,29 @@ public byte toByteExact() {
throw new NumberFormatException("invalid input syntax for type numeric: '" + this + "'");
}

/**
* Returns a string representation of this UTF8String object. The string representation consists
* of the string's characters encoded in UTF-8 and the result of this method is always a valid
* UTF-8 string. However, if the current UTF8String contains illegal UTF-8 byte sequences, the
* method will replace the illegal byte sequences with the Unicode replacement character U+FFFD,
* according to Java specification. Using this method with invalid UTF8Strings is NOT RECOMMENDED.
*/
@Override
public String toString() {
return new String(getBytes(), StandardCharsets.UTF_8);
}

/**
* Returns a string representation of this UTF8String object, but uses our custom implementation
* for invalid UTF-8 byte sequence replacement, as per the specification defined in the Unicode
* standard 15, Section 3.9, Paragraph D86, Table 3-7. Hence, the result of this method is
* always a valid UTF-8 string. This is the recommended method to use with invalid UTF8Strings.
*/
public String toValidString() {
if (isValid()) return toString();
return new String(makeValidBytes(), StandardCharsets.UTF_8);
}

@Override
public UTF8String clone() {
return fromBytes(getBytes());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public void testLowerCaseCodePoints() {
UTF8String.fromString("\uFFFD\uFFFD"), false);
assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[]
{(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}),
UTF8String.fromString("\uFFFD\uFFFD"), true);
UTF8String.fromString("\uFFFD\uFFFD\uFFFD\uFFFD\uFFFD\uFFFD"), true); // != Java toLowerCase
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ private static void checkBasic(String str, int len) {
assertEquals(len, s2.numChars());

assertEquals(str, s1.toString());
assertEquals(str, s1.toValidString());
assertEquals(str, s2.toString());
assertEquals(str, s2.toValidString());
assertEquals(s1, s2);

assertEquals(s1.hashCode(), s2.hashCode());
Expand Down Expand Up @@ -886,7 +888,10 @@ private void testMakeValid(String input, String expected) {
for (String hex : expected.split(" ")) exp.write(Integer.parseInt(hex.substring(2), 16));
ByteArrayOutputStream inp = new ByteArrayOutputStream();
for (String hex : input.split(" ")) inp.write(Integer.parseInt(hex.substring(2), 16));
assertEquals(fromBytes(exp.toByteArray()), fromBytes(inp.toByteArray()).makeValid());
UTF8String expUTF8String = fromBytes(exp.toByteArray());
UTF8String inpUTF8String = fromBytes(inp.toByteArray());
assertEquals(expUTF8String, inpUTF8String.makeValid());
assertEquals(inpUTF8String.toValidString(), inpUTF8String.toString());
}
@Test
public void makeValid() {
Expand Down

0 comments on commit bf25f0a

Please sign in to comment.