diff --git a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/BinaryPartialEndpoint.java b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/BinaryPartialEndpoint.java index 6ea938d43f..552ce29297 100644 --- a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/BinaryPartialEndpoint.java +++ b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/BinaryPartialEndpoint.java @@ -35,52 +35,56 @@ public final class BinaryPartialEndpoint extends Endpoint { @Override public void onOpen(final Session session, EndpointConfig config) { - session.addMessageHandler(new MessageHandler.Partial() { + session.addMessageHandler(new BinaryPartialMessageHandler(session)); - private ByteArrayOutputStream buffer; + } - @Override - public void onMessage(byte[] bytes, boolean last) { - if (last) { - if (buffer == null) { - onRequest(bytes); - } else { - try { - buffer(bytes); - byte[] tmp = buffer.toByteArray(); - onRequest(tmp); - } finally { - buffer = null; - } - } - } else { - buffer(bytes); - } - } + private static class BinaryPartialMessageHandler implements MessageHandler.Partial { - private void onRequest(final byte[] bytes) { - // Just return the received bytes for the test - DefaultServer.getWorker().execute(new Runnable() { - @Override - public void run() { - try { - session.getBasicRemote().sendBinary( - ByteBuffer.wrap(bytes)); - } catch (IOException e) { - throw new IllegalStateException(e); - } - } - }); - } + private final Session session; + private ByteArrayOutputStream buffer; - private void buffer(byte[] data) { + BinaryPartialMessageHandler(Session session) { + this.session = session; + } + + @Override + public void onMessage(byte[] bytes, boolean last) { + if (last) { if (buffer == null) { - buffer = new ByteArrayOutputStream(8096); + onRequest(bytes); + } else { + try { + buffer(bytes); + byte[] tmp = buffer.toByteArray(); + onRequest(tmp); + } finally { + buffer = null; + } } - buffer.write(data, 0, data.length); + } else { + buffer(bytes); } + } - }); + private void onRequest(final byte[] bytes) { + // Just return the received bytes for the test + DefaultServer.getWorker().execute(() -> { + try { + session.getBasicRemote().sendBinary( + ByteBuffer.wrap(bytes)); + } catch (IOException e) { + throw new IllegalStateException(e); + } + }); + } + + private void buffer(byte[] data) { + if (buffer == null) { + buffer = new ByteArrayOutputStream(8096); + } + buffer.write(data, 0, data.length); + } } } diff --git a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/ProgramaticErrorEndpoint.java b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/ProgramaticErrorEndpoint.java index 96a21c5c88..6c14efe545 100644 --- a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/ProgramaticErrorEndpoint.java +++ b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/ProgramaticErrorEndpoint.java @@ -48,17 +48,13 @@ public static String getMessage() { @Override public void onOpen(Session session, EndpointConfig config) { - session.addMessageHandler(new MessageHandler.Whole() { + session.addMessageHandler((MessageHandler.Whole) message -> { - @Override - public void onMessage(String message) { - - QUEUE.add(message); - if (message.equals("app-error")) { - throw new RuntimeException("an error"); - } else if (message.equals("io-error")) { - throw new RuntimeException(new IOException()); - } + QUEUE.add(message); + if (message.equals("app-error")) { + throw new RuntimeException("an error"); + } else if (message.equals("io-error")) { + throw new RuntimeException(new IOException()); } }); } diff --git a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/TestMessagesReceivedInOrder.java b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/TestMessagesReceivedInOrder.java index 2cdcdb4054..f360aef439 100644 --- a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/TestMessagesReceivedInOrder.java +++ b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/TestMessagesReceivedInOrder.java @@ -109,45 +109,7 @@ public void testMessagesReceivedInOrder() throws Exception { final CountDownLatch done = new CountDownLatch(1); final AtomicReference error = new AtomicReference<>(); ContainerProvider.getWebSocketContainer() - .connectToServer(new Endpoint() { - @Override - public void onOpen(final Session session, EndpointConfig endpointConfig) { - - try { - RemoteEndpoint.Basic rem = session.getBasicRemote(); - List messages = new ArrayList<>(); - for (int i = 0; i < MESSAGES; i++) { - byte[] data = new byte[2048]; - (new Random()).nextBytes(data); - String crc = md5(data); - rem.sendBinary(ByteBuffer.wrap(data)); - messages.add(crc); - } - - List received = EchoSocket.receivedEchos.getIoFuture().get(); - StringBuilder sb = new StringBuilder(); - boolean fail = false; - for (int i = 0; i < messages.size(); i++) { - if (received.size() <= i) { - fail = true; - sb.append(i + ": should be " + messages.get(i) + " but is empty."); - } else { - if (!messages.get(i).equals(received.get(i))) { - fail = true; - sb.append(i + ": should be " + messages.get(i) + " but is " + received.get(i) + " (but found at " + received.indexOf(messages.get(i)) + ")."); - } - } - } - if(fail) { - error.set(sb.toString()); - } - done.countDown(); - - } catch (Throwable t) { - t.printStackTrace(); - } - } - }, clientEndpointConfig, new URI(DefaultServer.getDefaultServerURL() + "/webSocket") + .connectToServer(new MessageOrderValidatorEndpoint(error, done), clientEndpointConfig, new URI(DefaultServer.getDefaultServerURL() + "/webSocket") ); assertTrue(done.await(30, TimeUnit.SECONDS)); if(error.get() != null) { @@ -186,4 +148,52 @@ private static String md5(byte[] buffer) { throw new InternalError("MD5 not supported on this platform"); } } + + private static class MessageOrderValidatorEndpoint extends Endpoint { + private final AtomicReference error; + private final CountDownLatch done; + + MessageOrderValidatorEndpoint(AtomicReference error, CountDownLatch done) { + this.error = error; + this.done = done; + } + + @Override + public void onOpen(final Session session, EndpointConfig endpointConfig) { + + try { + RemoteEndpoint.Basic rem = session.getBasicRemote(); + List messages = new ArrayList<>(); + for (int i = 0; i < MESSAGES; i++) { + byte[] data = new byte[2048]; + (new Random()).nextBytes(data); + String crc = md5(data); + rem.sendBinary(ByteBuffer.wrap(data)); + messages.add(crc); + } + + List received = EchoSocket.receivedEchos.getIoFuture().get(); + StringBuilder sb = new StringBuilder(); + boolean fail = false; + for (int i = 0; i < messages.size(); i++) { + if (received.size() <= i) { + fail = true; + sb.append(i + ": should be " + messages.get(i) + " but is empty."); + } else { + if (!messages.get(i).equals(received.get(i))) { + fail = true; + sb.append(i + ": should be " + messages.get(i) + " but is " + received.get(i) + " (but found at " + received.indexOf(messages.get(i)) + ")."); + } + } + } + if(fail) { + error.set(sb.toString()); + } + done.countDown(); + + } catch (Throwable t) { + t.printStackTrace(); + } + } + } }