diff --git a/stream/src/main/java/org/enginehub/linbus/stream/impl/LinNbtReader.java b/stream/src/main/java/org/enginehub/linbus/stream/impl/LinNbtReader.java index cfcf9b2..622f81b 100644 --- a/stream/src/main/java/org/enginehub/linbus/stream/impl/LinNbtReader.java +++ b/stream/src/main/java/org/enginehub/linbus/stream/impl/LinNbtReader.java @@ -30,7 +30,10 @@ import java.io.DataInputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.CharBuffer; import java.nio.charset.CharacterCodingException; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; import java.nio.charset.StandardCharsets; import java.util.ArrayDeque; import java.util.Deque; @@ -65,7 +68,7 @@ public class LinNbtReader implements LinStream { */ private static final int THREE_BYTE_SURROGATE_CONTINUATION = 0b1010_0000; - private static StringEncoding getGuaranteedStringEncoding(byte[] bytes) { + private static StringEncoding getGuaranteedStringEncoding(ByteBuffer bytes) { // The differences between the modified UTF-8 format and the standard UTF-8 format are the following: // The null byte '\u0000' is encoded in 2-byte format rather than 1-byte, so that the encoded strings never have embedded nulls. // Only the 1-byte, 2-byte, and 3-byte formats are used. @@ -75,7 +78,8 @@ private static StringEncoding getGuaranteedStringEncoding(byte[] bytes) { // So we can't use those as a definitive indicator of modified UTF-8 or not. boolean sawTwoByteNullStart = false; boolean sawThreeByteSurrogateStart = false; - for (byte b : bytes) { + for (int i = 0; i < bytes.remaining(); i++) { + byte b = bytes.get(i); if (b == TWO_BYTE_NULL_START) { sawTwoByteNullStart = true; } else if (sawTwoByteNullStart) { @@ -165,12 +169,75 @@ private enum StringEncoding { UNKNOWN, } + private static final class NormalUtf8Decoder { + private final CharsetDecoder decoder = StandardCharsets.UTF_8.newDecoder(); + // Default to some small allocation that is likely to cover most strings. + private ByteBuffer sourceBuffer = ByteBuffer.allocate(128); + private CharBuffer decodeBuffer = CharBuffer.allocate(128); + + void fill(DataInput input, int length) throws IOException { + ensureSourceBufferCapacity(length); + input.readFully(sourceBuffer.array(), 0, length); + sourceBuffer.limit(length); + } + + private void ensureSourceBufferCapacity(int requiredCapacity) { + if (sourceBuffer.capacity() < requiredCapacity) { + sourceBuffer = ByteBuffer.allocate(requiredCapacity); + } else { + sourceBuffer.clear(); + } + } + + private void ensureCharBufferCapacity(int requiredCapacity) { + if (decodeBuffer.capacity() < requiredCapacity) { + decodeBuffer = CharBuffer.allocate(requiredCapacity); + } else { + decodeBuffer.clear(); + } + } + + public String decode() throws CharacterCodingException { + int n = (int) (sourceBuffer.remaining() * decoder.averageCharsPerByte()); + ensureCharBufferCapacity(n); + + if ((n == 0) && (sourceBuffer.remaining() == 0)) + return ""; + decoder.reset(); + for (; ; ) { + CoderResult cr = sourceBuffer.hasRemaining() + ? decoder.decode(sourceBuffer, decodeBuffer, true) + : CoderResult.UNDERFLOW; + if (cr.isUnderflow()) { + cr = decoder.flush(decodeBuffer); + } + + if (cr.isUnderflow()) { + break; + } + if (cr.isOverflow()) { + // Ensure progress; n might be 0! + n += n / 2 + 1; + CharBuffer o = CharBuffer.allocate(n); + decodeBuffer.flip(); + o.put(decodeBuffer); + decodeBuffer = o; + continue; + } + cr.throwException(); + } + decodeBuffer.flip(); + return decodeBuffer.toString(); + } + } + private final DataInput input; /** * The state stack. We're currently on the one that's LAST. */ private final Deque stateStack; private StringEncoding stringEncoding; + private @Nullable NormalUtf8Decoder decoder; /** * Creates a new reader. @@ -295,42 +362,47 @@ private LinToken handleReadValue(LinTagId id) throws IOException { }; } + private NormalUtf8Decoder getNormalUtf8Decoder() { + NormalUtf8Decoder decoder = this.decoder; + if (decoder == null) { + decoder = new NormalUtf8Decoder(); + this.decoder = decoder; + } + return decoder; + } + private String readUtf() throws IOException { return switch (stringEncoding) { case MODIFIED_UTF_8 -> input.readUTF(); case NORMAL_UTF_8 -> { int length = input.readUnsignedShort(); - byte[] bytes = new byte[length]; - input.readFully(bytes); - yield decodeNormalUtf8(bytes); + NormalUtf8Decoder decoder = getNormalUtf8Decoder(); + decoder.fill(input, length); + yield decoder.decode(); } case UNKNOWN -> { int length = input.readUnsignedShort(); - byte[] bytes = new byte[length]; - input.readFully(bytes); - StringEncoding knownEncoding = getGuaranteedStringEncoding(bytes); + NormalUtf8Decoder decoder = getNormalUtf8Decoder(); + decoder.fill(input, length); + StringEncoding knownEncoding = getGuaranteedStringEncoding(decoder.sourceBuffer); yield switch (knownEncoding) { case MODIFIED_UTF_8 -> { stringEncoding = knownEncoding; - byte[] withLength = new byte[bytes.length + 2]; + byte[] withLength = new byte[length + 2]; withLength[0] = (byte) (length >> 8); withLength[1] = (byte) length; - System.arraycopy(bytes, 0, withLength, 2, bytes.length); + System.arraycopy(decoder.sourceBuffer.array(), 0, withLength, 2, length); yield new DataInputStream(new ByteArrayInputStream(withLength)).readUTF(); } case NORMAL_UTF_8 -> { stringEncoding = knownEncoding; - yield decodeNormalUtf8(bytes); + yield decoder.decode(); } // These are valid UTF-8 bytes that fit either encoding. Just read them as normal UTF-8, // but don't change the encoding. - case UNKNOWN -> decodeNormalUtf8(bytes); + case UNKNOWN -> decoder.decode(); }; } }; } - - private static String decodeNormalUtf8(byte[] bytes) throws CharacterCodingException { - return StandardCharsets.UTF_8.newDecoder().decode(ByteBuffer.wrap(bytes)).toString(); - } }