From 3f00aa91c11dab5c3f8636cd079f2808f07e5c26 Mon Sep 17 00:00:00 2001
From: Khushboo <68757952+desaikd@users.noreply.github.com>
Date: Tue, 14 Jan 2025 11:26:36 -0800
Subject: [PATCH] Adds support to string coercions (#18)

---
 .../hive/formats/ion/IonDecoderFactory.java   | 36 ++++++++++++++++---
 .../trino/hive/formats/ion/TestIonFormat.java | 36 +++++++++++++++++++
 2 files changed, 67 insertions(+), 5 deletions(-)

diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/ion/IonDecoderFactory.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/ion/IonDecoderFactory.java
index 6ddfc58a64fd..255882bb43ef 100644
--- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/ion/IonDecoderFactory.java
+++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/ion/IonDecoderFactory.java
@@ -16,7 +16,9 @@
 import com.amazon.ion.IonException;
 import com.amazon.ion.IonReader;
 import com.amazon.ion.IonType;
+import com.amazon.ion.IonWriter;
 import com.amazon.ion.Timestamp;
+import com.amazon.ion.system.IonTextWriterBuilder;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import io.airlift.slice.Slices;
@@ -53,6 +55,8 @@
 import io.trino.spi.type.VarcharType;
 import io.trino.spi.type.Varchars;
 
+import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.math.BigDecimal;
 import java.math.BigInteger;
 import java.math.RoundingMode;
@@ -127,8 +131,8 @@ private static BlockDecoder decoderForType(Type type)
             case BooleanType t -> wrapDecoder(boolDecoder, t, IonType.BOOL);
             case DateType t -> wrapDecoder(dateDecoder, t, IonType.TIMESTAMP);
             case TimestampType t -> wrapDecoder(timestampDecoder(t), t, IonType.TIMESTAMP);
-            case VarcharType t -> wrapDecoder(varcharDecoder(t), t, IonType.STRING, IonType.SYMBOL);
-            case CharType t -> wrapDecoder(charDecoder(t), t, IonType.STRING, IonType.SYMBOL);
+            case VarcharType t -> wrapDecoder(varcharDecoder(t), t, IonType.values());
+            case CharType t -> wrapDecoder(charDecoder(t), t, IonType.values());
             case VarbinaryType t -> wrapDecoder(binaryDecoder, t, IonType.BLOB, IonType.CLOB);
             case RowType t -> wrapDecoder(RowDecoder.forFields(t.getFields()), t, IonType.STRUCT);
             case ArrayType t -> wrapDecoder(new ArrayDecoder(decoderForType(t.getElementType())), t, IonType.LIST, IonType.SEXP);
@@ -148,7 +152,7 @@ private static BlockDecoder decoderForType(Type type)
      */
     private static BlockDecoder wrapDecoder(BlockDecoder decoder, Type trinoType, IonType... allowedTypes)
     {
-        Set<IonType> allowedWithNull = new HashSet<>(Arrays.asList(allowedTypes));
+        final Set<IonType> allowedWithNull = new HashSet<>(Arrays.asList(allowedTypes));
         allowedWithNull.add(IonType.NULL);
 
         return (reader, builder) -> {
@@ -360,16 +364,38 @@ private static BlockDecoder decimalDecoder(DecimalType type)
         };
     }
 
+    /**
+     * Returns the reader's current value as text. Strings and symbols yield their text content;
+     * values of any other Ion type are coerced to their Ion text serialization.
+     */
+    private static String getCoercedValue(IonReader ionReader)
+    {
+        IonType valueType = ionReader.getType();
+        if (valueType == IonType.STRING || valueType == IonType.SYMBOL) {
+            return ionReader.stringValue();
+        }
+
+        StringBuilder stringBuilder = new StringBuilder();
+        // Closing the writer finishes it, so the text is complete before the builder is read
+        try (IonWriter writer = IonTextWriterBuilder.standard().build(stringBuilder)) {
+            writer.writeValue(ionReader);
+        }
+        catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+        return stringBuilder.toString();
+    }
+
     private static BlockDecoder varcharDecoder(VarcharType type)
     {
         return (ionReader, blockBuilder) ->
-                type.writeSlice(blockBuilder, Varchars.truncateToLength(Slices.utf8Slice(ionReader.stringValue()), type));
+                type.writeSlice(blockBuilder, Varchars.truncateToLength(Slices.utf8Slice(getCoercedValue(ionReader)), type));
     }
 
     private static BlockDecoder charDecoder(CharType type)
     {
         return (ionReader, blockBuilder) ->
-                type.writeSlice(blockBuilder, Chars.truncateToLengthAndTrimSpaces(Slices.utf8Slice(ionReader.stringValue()), type));
+                type.writeSlice(blockBuilder, Chars.truncateToLengthAndTrimSpaces(Slices.utf8Slice(getCoercedValue(ionReader)), type));
     }
 
     private static final BlockDecoder byteDecoder = (ionReader, blockBuilder) ->
diff --git a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/ion/TestIonFormat.java b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/ion/TestIonFormat.java
index bb9b3539e566..ead606fe3e1b 100644
--- a/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/ion/TestIonFormat.java
+++ b/lib/trino-hive-formats/src/test/java/io/trino/hive/formats/ion/TestIonFormat.java
@@ -186,6 +186,42 @@ public void testCaseInsensitivityOfKeys()
                 List.of(31, "baz"));
     }
 
+    @Test
+    public void testStringCoercions()
+            throws IOException
+    {
+        assertValues(
+                RowType.rowType(
+                        field("foo", VARCHAR)),
+                "{ foo: true }",
+                List.of("true"));
+        assertValues(
+                RowType.rowType(
+                        field("foo", VARCHAR)),
+                "{ foo: 31 }",
+                List.of("31"));
+        assertValues(
+                RowType.rowType(
+                        field("foo", VARCHAR)),
+                "{ foo: 31.50 }",
+                List.of("31.50"));
+        assertValues(
+                RowType.rowType(
+                        field("foo", VARCHAR)),
+                "{ foo: [1, 2, 3] }",
+                List.of("[1,2,3]"));
+        assertValues(
+                RowType.rowType(
+                        field("foo", VARCHAR)),
+                "{ foo: \"bar\" }",
+                List.of("bar"));
+        assertValues(
+                RowType.rowType(
+                        field("foo", VARCHAR)),
+                "{ foo: { nested_foo: 12 } }",
+                List.of("{nested_foo:12}"));
+    }
+
     @Test
     public void testCaseInsensitivityOfDuplicateKeys()
             throws IOException