From 68e359dde9fc12ce29f85d9096af29b798cd33bb Mon Sep 17 00:00:00 2001 From: Kamil Kloch Date: Fri, 22 Dec 2023 14:15:24 +0100 Subject: [PATCH] Add `autoPing` for web sockets. --- .../blazecore/websocket/Http4sWSStage.scala | 21 ++++++++++++++++--- .../websocket/Http4sWSStageSpec.scala | 3 ++- .../blaze/server/BlazeServerBuilder.scala | 13 ++++++++++++ .../blaze/server/Http1ServerStage.scala | 4 ++++ .../blaze/server/ProtocolSelector.scala | 5 ++++- .../blaze/server/WebSocketSupport.scala | 5 +++++ .../blaze/server/Http1ServerStageSpec.scala | 1 + 7 files changed, 47 insertions(+), 5 deletions(-) diff --git a/blaze-core/src/main/scala/org/http4s/blazecore/websocket/Http4sWSStage.scala b/blaze-core/src/main/scala/org/http4s/blazecore/websocket/Http4sWSStage.scala index 8cf640b43..b0a7de76e 100644 --- a/blaze-core/src/main/scala/org/http4s/blazecore/websocket/Http4sWSStage.scala +++ b/blaze-core/src/main/scala/org/http4s/blazecore/websocket/Http4sWSStage.scala @@ -41,8 +41,10 @@ import org.http4s.websocket.WebSocketSeparatePipe import java.net.ProtocolException import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.ExecutionContext +import scala.concurrent.duration.FiniteDuration import scala.util.Failure import scala.util.Success +import cats.effect.syntax.all._ private[http4s] class Http4sWSStage[F[_]]( ws: WebSocket[F], @@ -50,6 +52,7 @@ private[http4s] class Http4sWSStage[F[_]]( deadSignal: SignallingRef[F, Boolean], writeSemaphore: Semaphore[F], dispatcher: Dispatcher[F], + autoPing: Option[(FiniteDuration, WebSocketFrame.Ping)], )(implicit F: Async[F]) extends TailStage[WebSocketFrame] { @@ -164,8 +167,8 @@ private[http4s] class Http4sWSStage[F[_]]( receiveSend(inputstream) } - val wsStream = - receiveSent + val wsStream = { + val s = receiveSent .evalMap(snkFun) .drain .interruptWhen(deadSignal) @@ -176,6 +179,17 @@ private[http4s] class Http4sWSStage[F[_]]( .compile .drain + autoPing match { + case None => s + case Some((delay, f)) => + snkFun(f) + .delayBy(delay) + .foreverM + .background + .use((_: F[Outcome[F, Throwable, Nothing]]) => s) + } + } + val result = F.handleErrorWith(wsStream) { case EOF => F.delay(stageShutdown()) @@ -203,6 +217,7 @@ object Http4sWSStage { sentClose: AtomicBoolean, deadSignal: SignallingRef[F, Boolean], dispatcher: Dispatcher[F], + autoPing: Option[(FiniteDuration, WebSocketFrame.Ping)], )(implicit F: Async[F]): F[Http4sWSStage[F]] = - Semaphore[F](1L).map(t => new Http4sWSStage(ws, sentClose, deadSignal, t, dispatcher)) + Semaphore[F](1L).map(t => new Http4sWSStage(ws, sentClose, deadSignal, t, dispatcher, autoPing)) } diff --git a/blaze-core/src/test/scala/org/http4s/blazecore/websocket/Http4sWSStageSpec.scala b/blaze-core/src/test/scala/org/http4s/blazecore/websocket/Http4sWSStageSpec.scala index 833b0e972..e64b8e819 100644 --- a/blaze-core/src/test/scala/org/http4s/blazecore/websocket/Http4sWSStageSpec.scala +++ b/blaze-core/src/test/scala/org/http4s/blazecore/websocket/Http4sWSStageSpec.scala @@ -82,9 +82,10 @@ class Http4sWSStageSpec extends CatsEffectSuite with DispatcherIOFixture { _.evalMap(backendInQ.offer), IO(closeHook.set(true)), ) + autoPing = None deadSignal <- SignallingRef[IO, Boolean](false) wsHead <- WSTestHead() - http4sWSStage <- Http4sWSStage[IO](ws, closeHook, deadSignal, dispatcher) + http4sWSStage <- Http4sWSStage[IO](ws, closeHook, deadSignal, dispatcher, autoPing) head = LeafBuilder(http4sWSStage).base(wsHead) _ <- IO(head.sendInboundCommand(Command.Connected)) } yield new TestWebsocketStage(outQ, head, closeHook, backendInQ) diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala index 11f7d9831..e4421a503 100644 --- a/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/BlazeServerBuilder.scala @@ -45,6 +45,7 @@ import org.http4s.server.SSLKeyStoreSupport.StoreInfo import org.http4s.server._ import org.http4s.server.websocket.WebSocketBuilder2 import org.http4s.{BuildInfo => Http4sBuildInfo} +import org.http4s.websocket.WebSocketFrame import org.log4s.getLogger import org.typelevel.vault._ import scodec.bits.ByteVector @@ -87,6 +88,7 @@ import scala.concurrent.duration._ * such as Nil disables this * @param maxConnections: The maximum number of client connections that may be active at any time. * @param maxWebSocketBufferSize: The maximum Websocket buffer length. 'None' means unbounded. + * @param webSocketAutoPing If `Some`, send the given Websocket `Ping` frame at the given interval. If `None`, do not automatically send pings */ class BlazeServerBuilder[F[_]] private ( socketAddress: InetSocketAddress, @@ -107,6 +109,7 @@ class BlazeServerBuilder[F[_]] private ( maxConnections: Int, val channelOptions: ChannelOptions, maxWebSocketBufferSize: Option[Int], + webSocketAutoPing: Option[(FiniteDuration, WebSocketFrame.Ping)], )(implicit protected val F: Async[F]) extends ServerBuilder[F] with BlazeBackendBuilder[Server] { @@ -133,6 +136,7 @@ class BlazeServerBuilder[F[_]] private ( maxConnections: Int = maxConnections, channelOptions: ChannelOptions = channelOptions, maxWebSocketBufferSize: Option[Int] = maxWebSocketBufferSize, + webSocketAutoPing: Option[(FiniteDuration, WebSocketFrame.Ping)] = webSocketAutoPing, ): Self = new BlazeServerBuilder( socketAddress, @@ -153,6 +157,7 @@ class BlazeServerBuilder[F[_]] private ( maxConnections, channelOptions, maxWebSocketBufferSize, + webSocketAutoPing, ) /** Configure HTTP parser length limits @@ -275,6 +280,11 @@ class BlazeServerBuilder[F[_]] private ( def withMaxWebSocketBufferSize(maxWebSocketBufferSize: Option[Int]): BlazeServerBuilder[F] = copy(maxWebSocketBufferSize = maxWebSocketBufferSize) + def withWebSocketAutoPing( + webSocketAutoPing: Option[(FiniteDuration, WebSocketFrame.Ping)] + ): BlazeServerBuilder[F] = + copy(webSocketAutoPing = webSocketAutoPing) + private def pipelineFactory( scheduler: TickWheelExecutor, engineConfig: Option[(SSLContext, SSLEngine => Unit)], @@ -335,6 +345,7 @@ class BlazeServerBuilder[F[_]] private ( scheduler = scheduler, dispatcher = dispatcher, maxWebSocketBufferSize = maxWebSocketBufferSize, + webSocketAutoPing = webSocketAutoPing, ) } @@ -358,6 +369,7 @@ class BlazeServerBuilder[F[_]] private ( dispatcher = dispatcher, webSocketKey = builder.webSocketKey, maxWebSocketBufferSize = maxWebSocketBufferSize, + webSocketAutoPing = webSocketAutoPing, ) } @@ -489,6 +501,7 @@ object BlazeServerBuilder { maxConnections = defaults.MaxConnections, channelOptions = ChannelOptions(Vector.empty), maxWebSocketBufferSize = None, + webSocketAutoPing = Some((42.seconds, WebSocketFrame.Ping())), ) private def defaultApp[F[_]: Applicative]: HttpApp[F] = diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/Http1ServerStage.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/Http1ServerStage.scala index 700fe8a0d..5fa43251c 100644 --- a/blaze-server/src/main/scala/org/http4s/blaze/server/Http1ServerStage.scala +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/Http1ServerStage.scala @@ -41,6 +41,7 @@ import org.http4s.headers.`Transfer-Encoding` import org.http4s.server.ServiceErrorHandler import org.http4s.util.StringWriter import org.http4s.websocket.WebSocketContext +import org.http4s.websocket.WebSocketFrame import org.typelevel.vault._ import java.nio.ByteBuffer @@ -71,6 +72,7 @@ private[http4s] object Http1ServerStage { scheduler: TickWheelExecutor, dispatcher: Dispatcher[F], maxWebSocketBufferSize: Option[Int], + webSocketAutoPing: Option[(FiniteDuration, WebSocketFrame.Ping)], )(implicit F: Async[F]): Http1ServerStage[F] = new Http1ServerStage( routes, @@ -87,6 +89,8 @@ private[http4s] object Http1ServerStage { ) with WebSocketSupport[F] { val webSocketKey = wsKey override protected def maxBufferSize: Option[Int] = maxWebSocketBufferSize + override protected def autoPing: Option[(FiniteDuration, WebSocketFrame.Ping)] = + webSocketAutoPing } } diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/ProtocolSelector.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/ProtocolSelector.scala index 5ebe21884..da8e8138f 100644 --- a/blaze-server/src/main/scala/org/http4s/blaze/server/ProtocolSelector.scala +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/ProtocolSelector.scala @@ -29,12 +29,13 @@ import org.http4s.blaze.pipeline.TailStage import org.http4s.blaze.util.TickWheelExecutor import org.http4s.server.ServiceErrorHandler import org.http4s.websocket.WebSocketContext +import org.http4s.websocket.WebSocketFrame import org.typelevel.vault._ import java.nio.ByteBuffer import javax.net.ssl.SSLEngine import scala.concurrent.ExecutionContext -import scala.concurrent.duration.Duration +import scala.concurrent.duration.{Duration, FiniteDuration} /** Facilitates the use of ALPN when using blaze http2 support */ private[http4s] object ProtocolSelector { @@ -53,6 +54,7 @@ private[http4s] object ProtocolSelector { dispatcher: Dispatcher[F], webSocketKey: Key[WebSocketContext[F]], maxWebSocketBufferSize: Option[Int], + webSocketAutoPing: Option[(FiniteDuration, WebSocketFrame.Ping)], )(implicit F: Async[F]): ALPNServerSelector = { def http2Stage(): TailStage[ByteBuffer] = { val newNode = { (streamId: Int) => @@ -100,6 +102,7 @@ private[http4s] object ProtocolSelector { scheduler, dispatcher, maxWebSocketBufferSize, + webSocketAutoPing, ) def preference(protos: Set[String]): String = diff --git a/blaze-server/src/main/scala/org/http4s/blaze/server/WebSocketSupport.scala b/blaze-server/src/main/scala/org/http4s/blaze/server/WebSocketSupport.scala index 19043b99c..48b024d87 100644 --- a/blaze-server/src/main/scala/org/http4s/blaze/server/WebSocketSupport.scala +++ b/blaze-server/src/main/scala/org/http4s/blaze/server/WebSocketSupport.scala @@ -27,12 +27,14 @@ import org.http4s.blazecore.websocket.Http4sWSStage import org.http4s.blazecore.websocket.WebSocketHandshake import org.http4s.headers._ import org.http4s.websocket.WebSocketContext +import org.http4s.websocket.WebSocketFrame import org.typelevel.vault.Key import java.nio.ByteBuffer import java.nio.charset.StandardCharsets._ import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration import scala.util.Failure import scala.util.Success @@ -45,6 +47,8 @@ private[http4s] trait WebSocketSupport[F[_]] extends Http1ServerStage[F] { protected def maxBufferSize: Option[Int] + protected def autoPing: Option[(FiniteDuration, WebSocketFrame.Ping)] + override protected def renderResponse( req: Request[F], resp: Response[F], @@ -111,6 +115,7 @@ private[http4s] trait WebSocketSupport[F[_]] extends Http1ServerStage[F] { deadSignal, writeSemaphore, dispatcher, + autoPing, ) ) // TODO: there is a constructor .prepend(new WSFrameAggregator) diff --git a/blaze-server/src/test/scala/org/http4s/blaze/server/Http1ServerStageSpec.scala b/blaze-server/src/test/scala/org/http4s/blaze/server/Http1ServerStageSpec.scala index de0a2500e..23cb6dfc8 100644 --- a/blaze-server/src/test/scala/org/http4s/blaze/server/Http1ServerStageSpec.scala +++ b/blaze-server/src/test/scala/org/http4s/blaze/server/Http1ServerStageSpec.scala @@ -108,6 +108,7 @@ class Http1ServerStageSpec extends CatsEffectSuite { tw, dispatcher(), None, + None, ) pipeline.LeafBuilder(httpStage).base(head)