diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 272a8aa128141..23adc772b7f34 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -357,8 +357,7 @@ public static UTF8String toUpperCase(final UTF8String target) { private static UTF8String toUpperCaseSlow(final UTF8String target) { // Note: In order to achieve the desired behavior, we use the ICU UCharacter class to // convert the string to uppercase, which only accepts a Java strings as input. - // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` - return UTF8String.fromString(UCharacter.toUpperCase(target.toString())); + return UTF8String.fromString(UCharacter.toUpperCase(target.toValidString())); } /** @@ -377,8 +376,7 @@ private static UTF8String toUpperCaseSlow(final UTF8String target, final int col // convert the string to uppercase, which only accepts a Java strings as input. ULocale locale = CollationFactory.fetchCollation(collationId) .collator.getLocale(ULocale.ACTUAL_LOCALE); - // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` - return UTF8String.fromString(UCharacter.toUpperCase(locale, target.toString())); + return UTF8String.fromString(UCharacter.toUpperCase(locale, target.toValidString())); } /** @@ -395,8 +393,7 @@ public static UTF8String toLowerCase(final UTF8String target) { private static UTF8String toLowerCaseSlow(final UTF8String target) { // Note: In order to achieve the desired behavior, we use the ICU UCharacter class to // convert the string to lowercase, which only accepts a Java strings as input. - // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` - return UTF8String.fromString(UCharacter.toLowerCase(target.toString())); + return UTF8String.fromString(UCharacter.toLowerCase(target.toValidString())); } /** @@ -415,8 +412,7 @@ private static UTF8String toLowerCaseSlow(final UTF8String target, final int col // convert the string to lowercase, which only accepts a Java strings as input. ULocale locale = CollationFactory.fetchCollation(collationId) .collator.getLocale(ULocale.ACTUAL_LOCALE); - // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` - return UTF8String.fromString(UCharacter.toLowerCase(locale, target.toString())); + return UTF8String.fromString(UCharacter.toLowerCase(locale, target.toValidString())); } /** @@ -459,8 +455,7 @@ public static UTF8String lowerCaseCodePoints(final UTF8String target) { } private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) { - // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` - String targetString = target.toString(); + String targetString = target.toValidString(); StringBuilder sb = new StringBuilder(); for (int i = 0; i < targetString.length(); ++i) { lowercaseCodePoint(targetString.codePointAt(i), sb); @@ -474,8 +469,7 @@ private static UTF8String lowerCaseCodePointsSlow(final UTF8String target) { public static UTF8String toTitleCase(final UTF8String target) { // Note: In order to achieve the desired behavior, we use the ICU UCharacter class to // convert the string to titlecase, which only accepts a Java strings as input. - // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` - return UTF8String.fromString(UCharacter.toTitleCase(target.toString(), + return UTF8String.fromString(UCharacter.toTitleCase(target.toValidString(), BreakIterator.getWordInstance())); } @@ -485,8 +479,7 @@ public static UTF8String toTitleCase(final UTF8String target) { public static UTF8String toTitleCase(final UTF8String target, final int collationId) { ULocale locale = CollationFactory.fetchCollation(collationId) .collator.getLocale(ULocale.ACTUAL_LOCALE); - // TODO(SPARK-48715): All UTF8String -> String conversions should use `makeValid` - return UTF8String.fromString(UCharacter.toTitleCase(locale, target.toString(), + return UTF8String.fromString(UCharacter.toTitleCase(locale, target.toValidString(), BreakIterator.getWordInstance(locale))); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index a2372d28a6c41..e6bddb12da56b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -21,7 +21,6 @@ import java.io.*; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; import java.util.Arrays; import java.util.function.Function; import java.util.Iterator; @@ -62,6 +61,28 @@ public final class UTF8String implements Comparable, Externalizable, private int numBytes; private volatile int numChars = -1; + /** + * The validity of the UTF8Strings can be cached to avoid repeated validation checks, because + * that operation requires full string scan. Valid strings have no illegal UTF-8 byte sequences. + */ + private enum UTF8StringValidity { + UNKNOWN, IS_VALID, NOT_VALID + } + + /** + * Internal flag to indicate whether the UTF-8 string is valid or not. Initially, the validity + * is UNKNOWN, and will be set to either IS_VALID or NOT_VALID after the first validation check. + */ + private volatile UTF8StringValidity isValid = UTF8StringValidity.UNKNOWN; + + /** + * In case the current UTF-8 string is not valid, the number of bytes of the validated version + * of the current string (after possible replacement) will be stored in this field. This value + * will be equal to `numBytes` if the current string is valid. However, note that this doesn't + * GUARANTEE that the string is valid - only the `isValid` field can provide that information. + */ + private volatile int numBytesValid = -1; + public Object getBaseObject() { return base; } public long getBaseOffset() { return offset; } @@ -121,14 +142,6 @@ public static UTF8String fromBytes(byte[] bytes) { } } - private static UTF8String fromBytes(ArrayList bytes) { - byte[] byteArray = new byte[bytes.size()]; - for (int i = 0; i < bytes.size(); i++) { - byteArray[i] = bytes.get(i); - } - return fromBytes(byteArray); - } - /** * Creates an UTF8String from byte array, which should be encoded in UTF-8. * @@ -313,11 +326,19 @@ private static boolean isValidSecondByte(byte b, byte firstByte) { }; } + /** + * The Unicode replacement character (U+FFFD) is used to replace invalid code points. + */ private static final byte[] UNICODE_REPLACEMENT_CHARACTER = new byte[] { (byte) 0xEF, (byte) 0xBF, (byte) 0xBD }; - private static void appendReplacementCharacter(ArrayList bytes) { - for (byte b : UTF8String.UNICODE_REPLACEMENT_CHARACTER) bytes.add(b); + /** + * Private helper method to insert the Unicode replacement character (U+FFFD) to a byte array. + */ + private static void insertReplacementCharacter(byte[] bytes, int byteIndex) { + for (byte b : UTF8String.UNICODE_REPLACEMENT_CHARACTER) { + bytes[byteIndex++] = b; + } } /** @@ -329,8 +350,19 @@ private static void appendReplacementCharacter(ArrayList bytes) { * @return A new UTF8String that is a valid UTF8 string. */ public UTF8String makeValid() { - ArrayList bytes = new ArrayList<>(); - int byteIndex = 0; + if (isValid()) return this; + return UTF8String.fromBytes(makeValidBytes()); + } + + /** + * Private helper method to create a valid UTF-8 byte sequence from the current UTF8String. + * In order to use this method, the number of bytes of the validated version of the current + * string (after possible replacement) must be evaluated first by calling `getIsValid`. + */ + private byte[] makeValidBytes() { + assert(numBytesValid > 0); + byte[] bytes = new byte[numBytesValid]; + int byteIndex = 0, byteIndexValid = 0; while (byteIndex < numBytes) { // Read the first byte. byte firstByte = getByte(byteIndex); @@ -338,21 +370,28 @@ public UTF8String makeValid() { int codePointLen = Math.min(expectedLen, numBytes - byteIndex); // 0B UTF-8 sequence (invalid first byte). if (codePointLen == 0) { - appendReplacementCharacter(bytes); + insertReplacementCharacter(bytes, byteIndexValid); + byteIndexValid += UNICODE_REPLACEMENT_CHARACTER.length; ++byteIndex; continue; } // 1B UTF-8 sequence (ASCII or truncated). if (codePointLen == 1) { - if (firstByte >= 0) bytes.add(firstByte); - else appendReplacementCharacter(bytes); + if (firstByte >= 0) { + bytes[byteIndexValid++] = firstByte; + } + else { + insertReplacementCharacter(bytes, byteIndexValid); + byteIndexValid += UNICODE_REPLACEMENT_CHARACTER.length; + } ++byteIndex; continue; } // Read the second byte. byte secondByte = getByte(byteIndex + 1); if (!isValidSecondByte(secondByte, firstByte)) { - appendReplacementCharacter(bytes); + insertReplacementCharacter(bytes, byteIndexValid); + byteIndexValid += UNICODE_REPLACEMENT_CHARACTER.length; ++byteIndex; continue; } @@ -366,17 +405,18 @@ public UTF8String makeValid() { } // Invalid UTF-8 sequence (not enough continuation bytes). if (continuationBytes < expectedLen) { - appendReplacementCharacter(bytes); + insertReplacementCharacter(bytes, byteIndexValid); + byteIndexValid += UNICODE_REPLACEMENT_CHARACTER.length; byteIndex += continuationBytes; continue; } // Valid UTF-8 sequence. for (int i = 0; i < codePointLen; ++i) { - bytes.add(getByte(byteIndex + i)); + bytes[byteIndexValid++] = getByte(byteIndex + i); } byteIndex += codePointLen; } - return UTF8String.fromBytes(bytes); + return bytes; } /** @@ -385,37 +425,89 @@ public UTF8String makeValid() { * @return If string represents a valid UTF8 string. */ public boolean isValid() { - int byteIndex = 0; + if (isValid == UTF8StringValidity.UNKNOWN) { + isValid = getIsValid(); + } + return isValid == UTF8StringValidity.IS_VALID; + } + + /** + * Private helper method to calculate whether the current UTF-8 string is valid. Checking + * all code points is a linear time operation, as we need to scan the entire UTF-8 string. + * Hence, this method should generally only be called only once during UTF8String lifetime. + * Unlike `getNumBytesValid`, this method performs early exit as soon as an invalid byte + * sequence is found, and returns a boolean indicating the validity of the current string. + */ + private UTF8StringValidity getIsValid() { + boolean isValid = true; + int byteIndex = 0, byteCount = 0; while (byteIndex < numBytes) { // Read the first byte. byte firstByte = getByte(byteIndex); int expectedLen = bytesOfCodePointInUTF8[firstByte & 0xFF]; int codePointLen = Math.min(expectedLen, numBytes - byteIndex); // 0B UTF-8 sequence (invalid first byte). - if (codePointLen == 0) return false; + if (codePointLen == 0) { + byteCount += UNICODE_REPLACEMENT_CHARACTER.length; + isValid = false; + ++byteIndex; + continue; + } // 1B UTF-8 sequence (ASCII or truncated). if (codePointLen == 1) { if (firstByte >= 0) { - ++byteIndex; - continue; + ++byteCount; + } + else { + byteCount += UNICODE_REPLACEMENT_CHARACTER.length; + isValid = false; } - else return false; + ++byteIndex; + continue; } // Read the second byte. byte secondByte = getByte(byteIndex + 1); - if (!isValidSecondByte(secondByte, firstByte)) return false; + if (!isValidSecondByte(secondByte, firstByte)) { + byteCount += UNICODE_REPLACEMENT_CHARACTER.length; + isValid = false; + ++byteIndex; + continue; + } // Read remaining continuation bytes. int continuationBytes = 2; for (; continuationBytes < codePointLen; ++continuationBytes) { byte nextByte = getByte(byteIndex + continuationBytes); - if (!isValidContinuationByte(nextByte)) return false; + if (!isValidContinuationByte(nextByte)) { + break; + } } // Invalid UTF-8 sequence (not enough continuation bytes). - if (continuationBytes < expectedLen) return false; + if (continuationBytes < expectedLen) { + byteCount += UNICODE_REPLACEMENT_CHARACTER.length; + isValid = false; + byteIndex += continuationBytes; + continue; + } // Valid UTF-8 sequence. + for (int i = 0; i < codePointLen; ++i) { + ++byteCount; + } byteIndex += codePointLen; } - return true; + setNumBytesValid(byteCount); + return isValid ? UTF8StringValidity.IS_VALID : UTF8StringValidity.NOT_VALID; + } + + /** + * The method sets the total number of bytes of the validated version of the current string + * (after possible replacement), which will be equal to `numBytes` if the UTF8String is valid. + * This method should generally only be called once, from the `getIsValid` method. + */ + private void setNumBytesValid(int byteCount) { + if (byteCount < 0) { + throw new IllegalStateException("Error in UTF-8 byte count"); + } + numBytesValid = byteCount; } /** @@ -445,7 +537,7 @@ public Iterator codePointIterator() { } public Iterator codePointIterator(CodePointIteratorType iteratorMode) { - if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID && !isValid()) { + if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID) { return makeValid().codePointIterator(); } return new CodePointIterator(); @@ -487,7 +579,7 @@ public Iterator reverseCodePointIterator() { } public Iterator reverseCodePointIterator(CodePointIteratorType iteratorMode) { - if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID && !isValid()) { + if (iteratorMode == CodePointIteratorType.CODE_POINT_ITERATOR_MAKE_VALID) { return makeValid().reverseCodePointIterator(); } return new ReverseCodePointIterator(); @@ -1734,11 +1826,29 @@ public byte toByteExact() { throw new NumberFormatException("invalid input syntax for type numeric: '" + this + "'"); } + /** + * Returns a string representation of this UTF8String object. The string representation consists + * of the string's characters encoded in UTF-8 and the result of this method is always a valid + * UTF-8 string. However, if the current UTF8String contains illegal UTF-8 byte sequences, the + * method will replace the illegal byte sequences with the Unicode replacement character U+FFFD, + * according to Java specification. Using this method with invalid UTF8Strings is NOT RECOMMENDED. + */ @Override public String toString() { return new String(getBytes(), StandardCharsets.UTF_8); } + /** + * Returns a string representation of this UTF8String object, but uses our custom implementation + * for invalid UTF-8 byte sequence replacement, as per the specification defined in the Unicode + * standard 15, Section 3.9, Paragraph D86, Table 3-7. Hence, the result of this method is + * always a valid UTF-8 string. This is the recommended method to use with invalid UTF8Strings. + */ + public String toValidString() { + if (isValid()) return toString(); + return new String(makeValidBytes(), StandardCharsets.UTF_8); + } + @Override public UTF8String clone() { return fromBytes(getBytes()); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index d084ef098248f..9438484344d62 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -187,7 +187,7 @@ public void testLowerCaseCodePoints() { UTF8String.fromString("\uFFFD\uFFFD"), false); assertLowerCaseCodePoints(UTF8String.fromBytes(new byte[] {(byte) 0xED, (byte) 0xA0, (byte) 0x80, (byte) 0xED, (byte) 0xB0, (byte) 0x80}), - UTF8String.fromString("\uFFFD\uFFFD"), true); + UTF8String.fromString("\uFFFD\uFFFD\uFFFD\uFFFD\uFFFD\uFFFD"), true); // != Java toLowerCase } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index d690da53c7c66..2428d40fe8016 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -43,7 +43,9 @@ private static void checkBasic(String str, int len) { assertEquals(len, s2.numChars()); assertEquals(str, s1.toString()); + assertEquals(str, s1.toValidString()); assertEquals(str, s2.toString()); + assertEquals(str, s2.toValidString()); assertEquals(s1, s2); assertEquals(s1.hashCode(), s2.hashCode()); @@ -886,7 +888,10 @@ private void testMakeValid(String input, String expected) { for (String hex : expected.split(" ")) exp.write(Integer.parseInt(hex.substring(2), 16)); ByteArrayOutputStream inp = new ByteArrayOutputStream(); for (String hex : input.split(" ")) inp.write(Integer.parseInt(hex.substring(2), 16)); - assertEquals(fromBytes(exp.toByteArray()), fromBytes(inp.toByteArray()).makeValid()); + UTF8String expUTF8String = fromBytes(exp.toByteArray()); + UTF8String inpUTF8String = fromBytes(inp.toByteArray()); + assertEquals(expUTF8String, inpUTF8String.makeValid()); + assertEquals(inpUTF8String.toValidString(), inpUTF8String.toString()); } @Test public void makeValid() {