Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Decoding Tests, Minor Fix to MapDecoder #20

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.IntFunction;

Expand Down Expand Up @@ -166,7 +167,7 @@ private static BlockDecoder decoderForType(Type type)
case VarbinaryType t -> wrapDecoder(binaryDecoder, t, IonType.BLOB, IonType.CLOB);
case RowType t -> wrapDecoder(RowDecoder.forFields(t.getFields()), t, IonType.STRUCT);
case ArrayType t -> wrapDecoder(new ArrayDecoder(decoderForType(t.getElementType())), t, IonType.LIST, IonType.SEXP);
case MapType t -> wrapDecoder(new MapDecoder(t, decoderForType(t.getValueType())), t, IonType.STRUCT);
case MapType t -> wrapDecoder(new MapDecoder(t), t, IonType.STRUCT);
default -> throw new IllegalArgumentException(String.format("Unsupported type: %s", type));
};
}
Expand Down Expand Up @@ -256,24 +257,32 @@ private void decode(IonReader ionReader, IntFunction<BlockBuilder> blockSelector
private static class MapDecoder
implements BlockDecoder
{
private final BiConsumer<String, BlockBuilder> keyConsumer;
private final BlockDecoder valueDecoder;
private final Type keyType;
private final Type valueType;
private final DistinctMapKeys distinctMapKeys;
private BlockBuilder keyBlockBuilder;
private BlockBuilder valueBlockBuilder;

public MapDecoder(MapType mapType, BlockDecoder valueDecoder)
MapDecoder(MapType mapType)
{
this.keyType = mapType.getKeyType();
if (!(keyType instanceof VarcharType _ || keyType instanceof CharType _)) {
throw new UnsupportedOperationException("Unsupported map key type: " + keyType);
}
this.valueType = mapType.getValueType();
this.valueDecoder = valueDecoder;
Type keyType = mapType.getKeyType();
Type valueType = mapType.getValueType();
this.valueDecoder = decoderForType(valueType);
this.distinctMapKeys = new DistinctMapKeys(mapType, true);
this.keyBlockBuilder = mapType.getKeyType().createBlockBuilder(null, 128);
this.valueBlockBuilder = mapType.getValueType().createBlockBuilder(null, 128);
this.keyBlockBuilder = keyType.createBlockBuilder(null, 128);
this.valueBlockBuilder = valueType.createBlockBuilder(null, 128);

this.keyConsumer = switch (keyType) {
case VarcharType t -> {
yield (String fieldName, BlockBuilder blockBuilder) ->
t.writeSlice(blockBuilder, Varchars.truncateToLength(Slices.utf8Slice(fieldName), t));
}
case CharType t -> {
yield (String fieldName, BlockBuilder blockBuilder) ->
t.writeSlice(blockBuilder, Chars.truncateToLengthAndTrimSpaces(Slices.utf8Slice(fieldName), t));
}
default -> throw new UnsupportedOperationException("Unsupported map key type: " + keyType);
};
}

@Override
Expand All @@ -282,13 +291,13 @@ public void decode(IonReader ionReader, BlockBuilder builder)
ionReader.stepIn();
// buffer the keys and values
while (ionReader.next() != null) {
VarcharType.VARCHAR.writeSlice(keyBlockBuilder, Slices.utf8Slice(ionReader.getFieldName()));
keyConsumer.accept(ionReader.getFieldName(), keyBlockBuilder);
valueDecoder.decode(ionReader, valueBlockBuilder);
}
ValueBlock keys = keyBlockBuilder.buildValueBlock();
ValueBlock values = valueBlockBuilder.buildValueBlock();
keyBlockBuilder = keyType.createBlockBuilder(null, keys.getPositionCount());
valueBlockBuilder = valueType.createBlockBuilder(null, values.getPositionCount());
keyBlockBuilder = keyBlockBuilder.newBlockBuilderLike(null);
valueBlockBuilder = valueBlockBuilder.newBlockBuilderLike(null);

// copy the distinct key entries to the output
boolean[] distinctKeys = distinctMapKeys.selectDistinctKeys(keys);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import io.trino.spi.type.CharType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.RowType;
Expand All @@ -37,13 +38,15 @@
import io.trino.spi.type.SqlTimestamp;
import io.trino.spi.type.SqlVarbinary;
import io.trino.spi.type.TimestampType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.spi.type.VarcharType;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
Expand All @@ -60,6 +63,8 @@
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RowType.field;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -95,14 +100,27 @@ public void testSuperBasicStruct()
public void testMap()
throws IOException
{
MapType mapType = new MapType(VARCHAR, INTEGER, TYPE_OPERATORS);
MapType mapType = new MapType(VarcharType.createVarcharType(3), INTEGER, TYPE_OPERATORS);
assertValues(
RowType.rowType(field("foo", mapType)),
"{ foo: { a: 1, a: 2, b: 5 } }",
List.of(ImmutableMap.builder()
.put("a", 2)
.put("b", 5)
.buildOrThrow()));
"{ foo: { bar: 1, bar: 2, baz: 5, quxx: 8 } } { foo: { bar: 17, baz: 31, qux: 53 } }",
List.of(Map.of("bar", 2, "baz", 5, "qux", 8)),
List.of(Map.of("bar", 17, "baz", 31, "qux", 53)));

mapType = new MapType(CharType.createCharType(3), INTEGER, TYPE_OPERATORS);
assertValues(
RowType.rowType(field("foo", mapType)),
"{ foo: { bar: 1, bar: 2, baz: 5, quxx: 8 } }",
List.of(Map.of("bar", 2, "baz", 5, "qux", 8)));
}

@Test
public void testUnsupportedMapKeys()
throws IOException
{
MapType mapType = new MapType(INTEGER, INTEGER, TYPE_OPERATORS);
Assertions.assertThrows(UnsupportedOperationException.class, () ->
assertValues(RowType.rowType(field("bad_map", mapType)), "", List.of()));
}

@Test
Expand Down Expand Up @@ -257,16 +275,12 @@ public void testCaseSensitiveExtraction()
public void testStructWithNullAndMissingValues()
throws IOException
{
final List<Object> listWithNulls = new ArrayList<>();
listWithNulls.add(null);
listWithNulls.add(null);

assertValues(
RowType.rowType(
field("foo", INTEGER),
field("bar", VARCHAR)),
"{ bar: null.symbol }",
listWithNulls);
Arrays.asList(null, null));
}

@Test
Expand Down Expand Up @@ -303,10 +317,27 @@ public void testNestedStruct()
field("name", RowType.rowType(
field("first", VARCHAR),
field("last", VARCHAR)))),
"{ name: { first: Woody, last: Guthrie } }",
"{ name: { first: Woody, last: Guthrie, superfluous: ignored } }",
List.of(List.of("Woody", "Guthrie")));
}

@Test
public void testNestedStructWithDuplicateAndMissingKeys()
throws IOException
{
assertValues(
RowType.rowType(
field("name", RowType.rowType(
field("first", VARCHAR),
field("last", VARCHAR)))),
"""
{ name: { last: Godfrey, last: Guthrie } }
{ name: { first: Joan, last: Baez } }
""",
List.of(Arrays.asList(null, "Guthrie")),
List.of(List.of("Joan", "Baez")));
}

@Test
public void testStructInList()
throws IOException
Expand All @@ -324,14 +355,55 @@ public void testStructInList()
}

@Test
public void testIonIntTooLargeForLong()
public void testIntsOfVariousSizes()
throws IOException
{
Assertions.assertThrows(TrinoException.class, () -> {
assertValues(RowType.rowType(field("my_bigint", BIGINT)),
"{ my_bigint: 18446744073709551786 }",
List.of());
});
List<String> ions = List.of(
"{ ion_int: 0x7f }", // < one byte
"{ ion_int: 0x7fff }", // < two bytes
"{ ion_int: 0x7fffffff }", // < four bytes
"{ ion_int: 0x7fffffffffffffff }", // < eight bytes
"{ ion_int: 0x7fffffffffffffff1 }" // > eight bytes
);

List<Type> intTypes = List.of(TINYINT, SMALLINT, INTEGER, BIGINT);
List<Object> expected = List.of((byte) 0x7f, (short) 0x7fff, 0x7fffffff, 0x7fffffffffffffffL);
for (int i = 0; i < intTypes.size(); i++) {
RowType rowType = RowType.rowType(field("ion_int", intTypes.get(i)));
assertValues(
rowType,
ions.get(i),
List.of(expected.get(i)));

int nextIon = i + 1;
Assertions.assertThrows(TrinoException.class, () -> {
assertValues(rowType,
ions.get(nextIon),
List.of());
});
}
}

@Test
public void testFloat()
throws IOException
{
RowType rowType = RowType.rowType(field("my_double", DoubleType.DOUBLE));
assertValues(
rowType,
"{ my_double: 4444e-4 }",
List.of(.4444));
}

@Test
public void testBytes()
throws IOException
{
RowType rowType = RowType.rowType(field("blobby", VARBINARY));
assertValues(
rowType,
"{ blobby: {{ YmxvYmJ5IG1jYmxvYmZhY2U= }} }",
List.of(new SqlVarbinary("blobby mcblobface".getBytes(StandardCharsets.UTF_8))));
}

@Test
Expand Down
Loading