Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Websocket frame validator #1825

Merged
merged 6 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions otoroshi/app/events/analytics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,45 @@ trait HealthCheckDataStore {
def push(event: JsValue)(implicit ec: ExecutionContext, env: Env): Future[Long]
}

case class WebsocketEvent(
`@type`: String = "WebsocketEvent",
`@id`: String,
`@timestamp`: DateTime,
reqId: String,
protocol: String,
to: Location,
target: Location,
frame: String,
frameSize: Int,
`@serviceId`: String,
`@service`: String,
statusCode: Option[Int] = None,
reason: Option[String] = None
) extends AnalyticEvent {
override def fromOrigin: Option[String] = None
override def fromUserAgent: Option[String] = None
def toJson(implicit _env: Env): JsValue = WebsocketEvent.writes(this, _env)
}

object WebsocketEvent {
def writes(o: WebsocketEvent, env: Env): JsValue =
Json.obj(
"@type" -> o.`@type`,
"@id" -> o.`@id`,
"@timestamp" -> o.`@timestamp`,
"reqId" -> o.reqId,
"protocol" -> o.protocol,
"to" -> Location.format.writes(o.to),
"target" -> Location.format.writes(o.target),
"@serviceId" -> o.`@serviceId`,
"@service" -> o.`@service`,
"statusCode" -> o.statusCode,
"reason" -> o.reason,
"frame" -> o.frame,
"frameSize" -> o.frameSize
)
}

sealed trait Filterable
case class ServiceDescriptorFilterable(service: ServiceDescriptor) extends Filterable
case class ApiKeyFilterable(apiKey: ApiKey) extends Filterable
Expand Down
42 changes: 41 additions & 1 deletion otoroshi/app/gateway/errors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import otoroshi.el.TargetExpressionLanguage
import otoroshi.env.Env
import otoroshi.events._
import otoroshi.gateway.Errors.{errorTemplate, messages}
import otoroshi.models.{ErrorTemplate, RemainingQuotas, ServiceDescriptor}
import otoroshi.models.{ErrorTemplate, RemainingQuotas, ServiceDescriptor, Target}
import otoroshi.next.models.NgRoute
import otoroshi.next.plugins.api.{NgPluginHttpResponse, NgTransformerErrorContext}
import otoroshi.next.proxy.NgExecutionReport
Expand Down Expand Up @@ -722,4 +722,44 @@ object Errors {
}
}
}

def craftWebsocketResponseResultSync(
frame: String,
frameSize: Int,
statusCode: Option[Int] = None,
reason: Option[String] = None,
req: RequestHeader,
sendEvent: Boolean = true,
route: NgRoute,
target: Target
)(implicit ec: ExecutionContext, env: Env): Result = {
val errorId = env.snowflakeGenerator.nextIdStr()

val finalRes = customResultSync(route.id, req, Status(statusCode.getOrElse(400)), "failed", None, emptyBody = true, errorId, modern = false)
if (sendEvent) {
WebsocketEvent(
`@id` = errorId,
reqId = env.snowflakeGenerator.nextIdStr(),
`@timestamp` = DateTime.now(),
protocol = req.version,
to = Location(
scheme = req.theProtocol,
host = req.theHost,
uri = req.relativeUri
),
target = Location(
scheme = target.scheme,
host = target.host,
uri = req.relativeUri
),
frame = frame,
frameSize = frameSize,
statusCode = statusCode,
reason = reason,
`@serviceId` = route.id,
`@service` = route.name,
).toAnalytics()(env)
}
finalRes
}
}
335 changes: 242 additions & 93 deletions otoroshi/app/gateway/websockets.scala

Large diffs are not rendered by default.

17 changes: 3 additions & 14 deletions otoroshi/app/next/controllers/plugins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,7 @@ package otoroshi.next.controllers
import akka.http.scaladsl.util.FastFuture
import otoroshi.actions.ApiAction
import otoroshi.env.Env
import otoroshi.next.plugins.api.{
NgAccessValidator,
NgBackendCall,
NgNamedPlugin,
NgPluginCategory,
NgPluginVisibility,
NgPreRouting,
NgRequestSink,
NgRequestSinkContext,
NgRequestTransformer,
NgRouteMatcher,
NgStep,
NgTunnelHandler
}
import otoroshi.next.plugins.api.{NgAccessValidator, NgBackendCall, NgNamedPlugin, NgPluginCategory, NgPluginVisibility, NgPreRouting, NgRequestSink, NgRequestSinkContext, NgRequestTransformer, NgRouteMatcher, NgStep, NgTunnelHandler, NgWebsocketPlugin}
import otoroshi.utils.syntax.implicits.BetterSyntax
import play.api.libs.json._
import play.api.mvc.{AbstractController, ControllerComponents}
Expand Down Expand Up @@ -67,10 +54,12 @@ class NgPluginsController(
case _: NgRequestSink => true
case _: NgRouteMatcher => true
case _: NgTunnelHandler => true
case p: NgWebsocketPlugin => p.onRequestFlow
case _ => false
}
val onResponse = plugin match {
case a: NgRequestTransformer => a.transformsResponse || a.transformsError
case p: NgWebsocketPlugin => p.onResponseFlow
case _: NgPreRouting => false
case _: NgAccessValidator => false
case _: NgRequestSink => false
Expand Down
130 changes: 124 additions & 6 deletions otoroshi/app/next/plugins/api.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ import akka.stream.Materializer
import akka.stream.scaladsl.{Flow, Source}
import akka.util.ByteString
import com.github.blemale.scaffeine.{Cache, Scaffeine}
import otoroshi.auth.{AuthModule, BasicAuthModule, BasicAuthModuleConfig, SessionCookieValues}
import otoroshi.env.Env
import otoroshi.models.{ApiKey, BackOfficeUser, GlobalConfig, PrivateAppsUser, ServiceDescriptor}
import otoroshi.gateway.Errors
import otoroshi.models.{ApiKey, PrivateAppsUser, Target}
import otoroshi.next.models.{NgMatchedRoute, NgPluginInstance, NgRoute, NgTarget}
import otoroshi.next.plugins.RejectStrategy
import otoroshi.next.proxy.{NgExecutionReport, NgProxyEngineError, NgReportPluginSequence, NgReportPluginSequenceItem}
import otoroshi.next.utils.JsonHelpers
import otoroshi.script.{InternalEventListener, NamedPlugin, PluginType, StartableAndStoppable}
import otoroshi.security.IdGenerator
import otoroshi.utils.TypedMap
import otoroshi.utils.http.WSCookieWithSameSite
import otoroshi.utils.syntax.implicits._
import play.api.http.HttpEntity
import play.api.http.websocket.Message
import play.api.http.websocket.{CloseMessage, Message, PingMessage, PongMessage, BinaryMessage => PlayWSBinaryMessage, TextMessage => PlayWSTextMessage}
import play.api.libs.json._
import play.api.libs.ws.{WSCookie, WSResponse}
import play.api.mvc.{AnyContent, Cookie, Request, RequestHeader, Result, Results}
import play.api.mvc.{Cookie, RequestHeader, Result, Results}

import java.security.cert.X509Certificate
import scala.concurrent.duration.DurationInt
Expand Down Expand Up @@ -216,6 +216,7 @@ object NgPluginCategory {
case object Wasm extends NgPluginCategory { def name: String = "Wasm" }
case object Classic extends NgPluginCategory { def name: String = "Classic" }
case object ServiceDiscovery extends NgPluginCategory { def name: String = "ServiceDiscovery" }
case object Websocket extends NgPluginCategory { def name: String = "Websocket" }

val all = Seq(
Classic,
Expand All @@ -233,7 +234,8 @@ object NgPluginCategory {
TrafficControl,
Transformations,
Tunnel,
Wasm
Wasm,
Websocket
)
}

Expand Down Expand Up @@ -1230,3 +1232,119 @@ class NgMergedAccessValidator(plugins: Seq[NgPluginWrapper.NgSimplePluginWrapper
next(plugins, plugins.size)
}
}

case class NgWebsocketPluginContext(
config: JsValue,
idx: Int = 0,
request: RequestHeader,
route: NgRoute,
attrs: TypedMap,
target: Target
) extends NgCachedConfigContext {
def json: JsValue = Json.obj(
"config" -> config
)
}

sealed trait WebsocketMessage[A] {
def data: A
def str()(implicit m: Materializer, ec: ExecutionContext): Future[String]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a bytes() method ?

def size()(implicit m: Materializer, ec: ExecutionContext): Future[Int]
def isBinary: Boolean
def isText: Boolean = !isBinary
}

object WebsocketMessage {
case class AkkaMessage(override val data: akka.http.scaladsl.model.ws.Message) extends WebsocketMessage[akka.http.scaladsl.model.ws.Message] {
override def str()(implicit m: Materializer, ec: ExecutionContext): Future[String] = data match {
case akka.http.scaladsl.model.ws.TextMessage.Strict(text) => text.future
case akka.http.scaladsl.model.ws.TextMessage.Streamed(source) =>
source.runFold("")((concat, str) => concat + str)
case akka.http.scaladsl.model.ws.BinaryMessage.Strict(data) => data.utf8String.future
case akka.http.scaladsl.model.ws.BinaryMessage.Streamed(source) =>
source
.runFold(ByteString.empty)((concat, str) => concat ++ str).map(_.utf8String)
case _ => "".future
}

override def size()(implicit m: Materializer, ec: ExecutionContext): Future[Int] = data match {
case akka.http.scaladsl.model.ws.TextMessage.Strict(text) => text.length.future
case akka.http.scaladsl.model.ws.TextMessage.Streamed(source) =>
source.runFold("")((concat, str) => concat + str).map(_.length)
case akka.http.scaladsl.model.ws.BinaryMessage.Strict(data) => data.size.future
case akka.http.scaladsl.model.ws.BinaryMessage.Streamed(source) =>
source
.runFold(ByteString.empty)((concat, str) => concat ++ str).map(_.size)
case _ => 0.future
}

override def isBinary: Boolean = !data.isText
}
case class PlayMessage(override val data: play.api.http.websocket.Message) extends WebsocketMessage[play.api.http.websocket.Message] {
override def str()(implicit m: Materializer, ec: ExecutionContext): Future[String] = (data match {
case PlayWSTextMessage(data) => data
case PlayWSBinaryMessage(data) => data.utf8String
case CloseMessage(_, _) => ""
case PingMessage(data) => data.utf8String
case PongMessage(data) => data.utf8String
}).future

override def size()(implicit m: Materializer, ec: ExecutionContext): Future[Int] = (data match {
case PlayWSTextMessage(data) => data.length
case PlayWSBinaryMessage(data) => data.size
case CloseMessage(_, _) => 0
case PingMessage(data) => data.size
case PongMessage(data) => data.size
}).future

override def isBinary: Boolean = data.isInstanceOf[play.api.http.websocket.BinaryMessage]
}
}

case class NgWebsocketResponse(
result: NgAccess = NgAccess.NgAllowed,
statusCode: Option[Int] = None,
reason: Option[String] = None)

object NgWebsocketResponse {
def default: Future[NgWebsocketResponse] = NgWebsocketResponse().future
def error[A](ctx: NgWebsocketPluginContext,
message: WebsocketMessage[A],
statusCode: Int,
reason: String)(implicit env: Env, ec: ExecutionContext): Future[NgWebsocketResponse] = {
implicit val m: Materializer = env.otoroshiMaterializer
(for {
frame <- message.str
size <- message.size()
} yield (frame, size))
.collect {
case (frame, frameSize) =>
NgWebsocketResponse.denied(Errors
.craftWebsocketResponseResultSync(
frame = frame,
frameSize = frameSize,
statusCode = statusCode.some,
reason = reason.some,
req = ctx.request,
route = ctx.route,
target = ctx.target
), statusCode, reason)
}
}
private def denied(result: Result, statusCode: Int, reason: String) = NgWebsocketResponse(NgAccess.NgDenied(result), statusCode.some, reason.some)
}

trait NgWebsocketPlugin extends NgNamedPlugin {
def onRequestFlow: Boolean = false
def onResponseFlow: Boolean = false

def canAccess[A](ctx: NgWebsocketPluginContext, message: WebsocketMessage[A])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe move this one in the validator plugin

(implicit env: Env, ec: ExecutionContext): Future[NgWebsocketResponse] = NgWebsocketResponse.default

def canAccessResponse[A](ctx: NgWebsocketPluginContext, message: WebsocketMessage[A])
(implicit env: Env, ec: ExecutionContext): Future[NgWebsocketResponse] = NgWebsocketResponse.default
}

trait NgWebsocketValidatorPlugin extends NgWebsocketPlugin {
def rejectStrategy(ctx: NgWebsocketPluginContext): RejectStrategy
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a transformer plugin

Loading
Loading