Skip to content

Commit

Permalink
Fix miscellaneous issues with WebSocket closing (#77)
Browse files Browse the repository at this point in the history
* small close fixes

* Fixes

* Skip test dependent on local client

Co-authored-by: Siemen Sikkema <[email protected]>
  • Loading branch information
tanner0101 and siemensikkema authored Oct 27, 2020
1 parent b073601 commit 2b06a70
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 97 deletions.
37 changes: 26 additions & 11 deletions Sources/WebSocketKit/WebSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ public final class WebSocket {
return channel.eventLoop
}

public private(set) var isClosed: Bool
public var isClosed: Bool {
!self.channel.isActive
}
public private(set) var closeCode: WebSocketErrorCode?

public var onClose: EventLoopFuture<Void> {
return self.channel.closeFuture
self.channel.closeFuture
}

private let channel: Channel
Expand All @@ -29,7 +31,8 @@ public final class WebSocket {
private var onPingCallback: (WebSocket) -> ()
private var frameSequence: WebSocketFrameSequence?
private let type: PeerType
private var waitingForPong = false
private var waitingForPong: Bool
private var waitingForClose: Bool
private var scheduledTimeoutTask: Scheduled<Void>?

init(channel: Channel, type: PeerType) {
Expand All @@ -39,7 +42,9 @@ public final class WebSocket {
self.onBinaryCallback = { _, _ in }
self.onPongCallback = { _ in }
self.onPingCallback = { _ in }
self.isClosed = false
self.waitingForPong = false
self.waitingForClose = false
self.scheduledTimeoutTask = nil
}

public func onText(_ callback: @escaping (WebSocket, String) -> ()) {
Expand Down Expand Up @@ -133,7 +138,11 @@ public final class WebSocket {
promise?.succeed(())
return
}
self.isClosed = true
guard !self.waitingForClose else {
promise?.succeed(())
return
}
self.waitingForClose = true
self.closeCode = code

let codeAsInt = UInt16(webSocketErrorCode: code)
Expand Down Expand Up @@ -168,17 +177,23 @@ public final class WebSocket {
func handle(incoming frame: WebSocketFrame) {
switch frame.opcode {
case .connectionClose:
if self.isClosed {
if self.waitingForClose {
// peer confirmed close, time to close channel
self.channel.close(mode: .all, promise: nil)
self.channel.close(mode: .output, promise: nil)
} else {
// peer asking for close, confirm and close channel
// peer asking for close, confirm and close output side channel
let promise = self.eventLoop.makePromise(of: Void.self)
var data = frame.data
self.close(code: data.readWebSocketErrorCode() ?? .unknown(1005),
promise: promise)
let maskingKey = frame.maskKey
if let maskingKey = maskingKey {
data.webSocketUnmask(maskingKey)
}
self.close(
code: data.readWebSocketErrorCode() ?? .unknown(1005),
promise: promise
)
promise.futureResult.whenComplete { _ in
self.channel.close(mode: .all, promise: nil)
self.channel.close(mode: .output, promise: nil)
}
}
case .ping:
Expand Down
222 changes: 136 additions & 86 deletions Tests/WebSocketKitTests/WebSocketKitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ import NIOWebSocket
final class WebSocketKitTests: XCTestCase {
func testWebSocketEcho() throws {
let promise = elg.next().makePromise(of: String.self)
let closePromise = elg.next().makePromise(of: Void.self)
WebSocket.connect(to: "ws://echo.websocket.org", on: elg) { ws in
ws.send("hello")
ws.onText { ws, string in
promise.succeed(string)
ws.close(promise: nil)
ws.close(promise: closePromise)
}
}.cascadeFailure(to: promise)
try XCTAssertEqual(promise.futureResult.wait(), "hello")
XCTAssertNoThrow(try closePromise.futureResult.wait())
}

func testWebSocketWithTLSEcho() throws {
Expand All @@ -33,34 +35,69 @@ final class WebSocketKitTests: XCTestCase {
XCTAssertThrowsError(try WebSocket.connect(host: "asdf", on: elg) { _ in }.wait())
}

func testImmediateSend() throws {
func testServerClose() throws {
let port = Int.random(in: 8000..<9000)

let promise = self.elg.next().makePromise(of: String.self)
let sendPromise = self.elg.next().makePromise(of: Void.self)
let serverClose = self.elg.next().makePromise(of: Void.self)
let clientClose = self.elg.next().makePromise(of: Void.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.onText { ws, text in
if text == "close" {
ws.close(promise: serverClose)
}
}
}.bind(host: "localhost", port: port).wait()

let server = try ServerBootstrap(group: self.elg).childChannelInitializer { channel in
let webSocket = NIOWebSocketServerUpgrader(
shouldUpgrade: { channel, req in
return channel.eventLoop.makeSucceededFuture([:])
},
upgradePipelineHandler: { channel, req in
return WebSocket.server(on: channel) { ws in
ws.send("hello")
ws.onText { ws, string in
promise.succeed(string)
ws.close(promise: nil)
}
}
WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
ws.send("close", promise: sendPromise)
ws.onClose.cascade(to: clientClose)
}.cascadeFailure(to: sendPromise)

XCTAssertNoThrow(try sendPromise.futureResult.wait())
XCTAssertNoThrow(try serverClose.futureResult.wait())
XCTAssertNoThrow(try clientClose.futureResult.wait())
try server.close(mode: .all).wait()
}

func testClientClose() throws {
let port = Int.random(in: 8000..<9000)

let sendPromise = self.elg.next().makePromise(of: Void.self)
let serverClose = self.elg.next().makePromise(of: Void.self)
let clientClose = self.elg.next().makePromise(of: Void.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.onText { ws, text in
ws.send(text)
}
ws.onClose.cascade(to: serverClose)
}.bind(host: "localhost", port: port).wait()

WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
ws.send("close", promise: sendPromise)
ws.onText { ws, text in
if text == "close" {
ws.close(promise: clientClose)
}
)
return channel.pipeline.configureHTTPServerPipeline(
withServerUpgrade: (
upgraders: [webSocket],
completionHandler: { ctx in
// complete
}
)
)
}
}.cascadeFailure(to: sendPromise)

XCTAssertNoThrow(try sendPromise.futureResult.wait())
XCTAssertNoThrow(try serverClose.futureResult.wait())
XCTAssertNoThrow(try clientClose.futureResult.wait())
try server.close(mode: .all).wait()
}

func testImmediateSend() throws {
let port = Int.random(in: 8000..<9000)

let promise = self.elg.next().makePromise(of: String.self)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.send("hello")
ws.onText { ws, string in
promise.succeed(string)
ws.close(promise: nil)
}
}.bind(host: "localhost", port: port).wait()

WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
Expand All @@ -78,28 +115,10 @@ final class WebSocketKitTests: XCTestCase {
let port = Int.random(in: 8000..<9000)

let pongPromise = self.elg.next().makePromise(of: String.self)

let server = try ServerBootstrap(group: self.elg).childChannelInitializer { channel in
let webSocket = NIOWebSocketServerUpgrader(
shouldUpgrade: { channel, req in
return channel.eventLoop.makeSucceededFuture([:])
},
upgradePipelineHandler: { channel, req in
return WebSocket.server(on: channel) { ws in
ws.onPing { ws in
ws.close(promise: nil)
}
}
}
)
return channel.pipeline.configureHTTPServerPipeline(
withServerUpgrade: (
upgraders: [webSocket],
completionHandler: { ctx in
// complete
}
)
)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.onPing { ws in
ws.close(promise: nil)
}
}.bind(host: "localhost", port: port).wait()

WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
Expand All @@ -119,25 +138,8 @@ final class WebSocketKitTests: XCTestCase {

let promise = self.elg.next().makePromise(of: WebSocketErrorCode.self)

_ = try ServerBootstrap(group: self.elg).childChannelInitializer { channel in
let webSocket = NIOWebSocketServerUpgrader(
shouldUpgrade: { channel, req in
return channel.eventLoop.makeSucceededFuture([:])
},
upgradePipelineHandler: { channel, req in
return WebSocket.server(on: channel) { ws in
ws.close(code: .normalClosure, promise: nil)
}
}
)
return channel.pipeline.configureHTTPServerPipeline(
withServerUpgrade: (
upgraders: [webSocket],
completionHandler: { ctx in
// complete
}
)
)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.close(code: .normalClosure, promise: nil)
}.bind(host: "localhost", port: port).wait()

WebSocket.connect(to: "ws://localhost:\(port)", on: self.elg) { ws in
Expand All @@ -151,33 +153,17 @@ final class WebSocketKitTests: XCTestCase {
}.cascadeFailure(to: promise)

try XCTAssertEqual(promise.futureResult.wait(), WebSocketErrorCode.normalClosure)
try server.close(mode: .all).wait()
}

func testHeadersAreSent() throws {
let port = Int.random(in: 8000..<9000)

let promise = self.elg.next().makePromise(of: String.self)

let server = try ServerBootstrap(group: self.elg).childChannelInitializer { channel in
let webSocket = NIOWebSocketServerUpgrader(
shouldUpgrade: { channel, req in
return channel.eventLoop.makeSucceededFuture([:])
},
upgradePipelineHandler: { channel, req in
promise.succeed(req.headers["Auth"].first!)
return WebSocket.server(on: channel) { ws in
ws.close(promise: nil)
}
}
)
return channel.pipeline.configureHTTPServerPipeline(
withServerUpgrade: (
upgraders: [webSocket],
completionHandler: { ctx in
// complete
}
)
)
let server = try ServerBootstrap.webSocket(on: self.elg) { req, ws in
promise.succeed(req.headers.first(name: "Auth")!)
ws.close(promise: nil)
}.bind(host: "localhost", port: port).wait()

WebSocket.connect(
Expand All @@ -191,6 +177,42 @@ final class WebSocketKitTests: XCTestCase {
try server.close(mode: .all).wait()
}

func testLocally() throws {
// swap to test websocket server against local client
try XCTSkipIf(true)

let port = Int(1337)
let shutdownPromise = self.elg.next().makePromise(of: Void.self)

let server = try! ServerBootstrap.webSocket(on: self.elg) { req, ws in
ws.send("welcome!")

ws.onClose.whenComplete {
print("ws.onClose done: \($0)")
}

ws.onText { ws, text in
switch text {
case "shutdown":
shutdownPromise.succeed(())
case "close":
ws.close().whenComplete {
print("ws.close() done \($0)")
}
default:
ws.send(text.reversed())
}
}
}.bind(host: "localhost", port: port).wait()
print("Serving at ws://localhost:\(port)")

print("Waiting for server shutdown...")
try shutdownPromise.futureResult.wait()

print("Waiting for server close...")
try server.close(mode: .all).wait()
}

var elg: EventLoopGroup!
override func setUp() {
// needs to be at least two to avoid client / server on same EL timing issues
Expand All @@ -200,3 +222,31 @@ final class WebSocketKitTests: XCTestCase {
try! self.elg.syncShutdownGracefully()
}
}

extension ServerBootstrap {
static func webSocket(
on eventLoopGroup: EventLoopGroup,
onUpgrade: @escaping (HTTPRequestHead, WebSocket) -> ()
) -> ServerBootstrap {
ServerBootstrap(group: eventLoopGroup).childChannelInitializer { channel in
let webSocket = NIOWebSocketServerUpgrader(
shouldUpgrade: { channel, req in
return channel.eventLoop.makeSucceededFuture([:])
},
upgradePipelineHandler: { channel, req in
return WebSocket.server(on: channel) { ws in
onUpgrade(req, ws)
}
}
)
return channel.pipeline.configureHTTPServerPipeline(
withServerUpgrade: (
upgraders: [webSocket],
completionHandler: { ctx in
// complete
}
)
)
}
}
}

0 comments on commit 2b06a70

Please sign in to comment.