Skip to content

Commit

Permalink
apply reject strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
Zwiterrion committed Feb 1, 2024
1 parent 01c527d commit f26bd7d
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 62 deletions.
50 changes: 33 additions & 17 deletions otoroshi/app/gateway/websockets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import otoroshi.events._
import otoroshi.models._
import otoroshi.next.models.NgRoute
import otoroshi.next.plugins.RejectStrategy
import otoroshi.next.plugins.api.{NgAccess, NgPluginWrapper, NgWebsocketPluginContext, NgWebsocketValidatorPlugin, WebsocketMessage}
import otoroshi.next.plugins.api.{NgAccess, NgPluginWrapper, NgWebsocketPluginContext, NgWebsocketResponse, NgWebsocketValidatorPlugin, WebsocketMessage}
import otoroshi.next.proxy.NgProxyEngineError
import otoroshi.next.proxy.NgProxyEngineError.NgResultProxyEngineError
import otoroshi.next.utils.FEither
Expand Down Expand Up @@ -797,7 +797,14 @@ class WebSocketProxyActor(
.fromSinkAndSourceMat(
Sink.foreach[akka.http.scaladsl.model.ws.Message] {
data => new WebsocketEngine()
.handleResponse(route.get, rawRequest, data)(() => out ! PoisonPill)
.handleResponse(route.get, rawRequest, data)((message: NgWebsocketResponse) => {
Option(queueRef.get()).foreach(_.complete())

// message match {
// case NgWebsocketResponse(_, Some(status), Some(reason)) => out ! CloseMessage(status, reason)
// }

})
.map(_ => data match {
case akka.http.scaladsl.model.ws.TextMessage.Strict(text) =>
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] text message from target")
Expand Down Expand Up @@ -875,7 +882,14 @@ class WebSocketProxyActor(

def receive: Receive = {
case data: play.api.http.websocket.Message => new WebsocketEngine()
.handleRequest(route.get, rawRequest, data)(() => out ! PoisonPill)
.handleRequest(route.get, rawRequest, data)((message: NgWebsocketResponse) => {
Option(queueRef.get()).foreach(_.complete())

message match {
case NgWebsocketResponse(_, Some(status), Some(reason)) => out ! CloseMessage(status, reason)
}

})
.map(_ => data match {
case msg: PlayWSBinaryMessage => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] binary message from client: ${msg.data.utf8String}")
Expand Down Expand Up @@ -919,7 +933,7 @@ class WebsocketEngine {
private def handle[A](validators: Seq[NgPluginWrapper.NgSimplePluginWrapper[NgWebsocketValidatorPlugin]],
route: NgRoute,
rawRequest: RequestHeader,
data: WebsocketMessage[A])(closeConnection: () => Unit)
data: WebsocketMessage[A])(closeConnection: NgWebsocketResponse => Unit)
(implicit env: Env, ec: ExecutionContext): FEither[NgProxyEngineError, Done] = {
val promise = Promise[Either[NgProxyEngineError, Done]]()

Expand All @@ -939,17 +953,19 @@ class WebsocketEngine {
wrapper.plugin.access(ctx, data).andThen {
case Failure(_) =>
promise.trySuccess(Left(NgResultProxyEngineError(Results.InternalServerError)))
case Success(NgAccess.NgDenied(result)) =>
println("DENIED", wrapper.plugin.rejectStrategy(ctx), wrapper.plugin.name)
wrapper.plugin.rejectStrategy(ctx) match {
case RejectStrategy.Drop => // TODO - do additional things ???
case RejectStrategy.Close => closeConnection()
}
promise.trySuccess(Left(NgResultProxyEngineError(result)))
case Success(NgAccess.NgAllowed) if plugins.size == 1 =>
promise.trySuccess(Right(Done))
case Success(NgAccess.NgAllowed) =>
next(plugins.tail)
case Success(value) => value match {
case response @ NgWebsocketResponse(NgAccess.NgDenied(result), status, reason) =>
println("DENIED", wrapper.plugin.rejectStrategy(ctx), wrapper.plugin.name, status, reason)
wrapper.plugin.rejectStrategy(ctx) match {
case RejectStrategy.Close => closeConnection(response)
case _ => // TODO - logging ??
}
promise.trySuccess(Left(NgResultProxyEngineError(result)))
case NgWebsocketResponse(NgAccess.NgAllowed, _, _) if plugins.size == 1 =>
promise.trySuccess(Right(Done))
case NgWebsocketResponse(NgAccess.NgAllowed, _, _) =>
next(plugins.tail)
}
}
}
}
Expand All @@ -960,7 +976,7 @@ class WebsocketEngine {

def handleRequest(route: NgRoute,
rawRequest: RequestHeader,
data: play.api.http.websocket.Message)(closeConnection: () => Unit)
data: play.api.http.websocket.Message)(closeConnection: NgWebsocketResponse => Unit)
(implicit env: Env, ec: ExecutionContext): FEither[NgProxyEngineError, Done] = {
val requestValidators: Seq[NgPluginWrapper.NgSimplePluginWrapper[NgWebsocketValidatorPlugin]] = getValidators(route)(_.plugin.onRequestFlow)

Expand All @@ -969,7 +985,7 @@ class WebsocketEngine {

def handleResponse(route: NgRoute,
rawRequest: RequestHeader,
data: akka.http.scaladsl.model.ws.Message)(closeConnection: () => Unit)
data: akka.http.scaladsl.model.ws.Message)(closeConnection: NgWebsocketResponse => Unit)
(implicit env: Env, ec: ExecutionContext): FEither[NgProxyEngineError, Done] = {
val responseValidators = getValidators(route)(_.plugin.onResponseFlow)

Expand Down
14 changes: 12 additions & 2 deletions otoroshi/app/next/plugins/api.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1288,15 +1288,25 @@ object WebsocketMessage {
}
}

case class NgWebsocketResponse(
result: NgAccess = NgAccess.NgAllowed,
statusCode: Option[Int] = None,
reason: Option[String] = None)

object NgWebsocketResponse {
def default: Future[NgWebsocketResponse] = NgWebsocketResponse().future
def denied(result: Result, statusCode: Int, reason: String) = NgWebsocketResponse(NgAccess.NgDenied(result), statusCode.some, reason.some)
def fdenied(result: Result, statusCode: Int, reason: String) = denied(result, statusCode, reason).future
}

trait NgWebsocketPlugin extends NgNamedPlugin {
def onRequestFlow: Boolean = true
def onResponseFlow: Boolean = true

def accessSync[A](ctx: NgWebsocketPluginContext, message: WebsocketMessage[A]): NgAccess = NgAccess.NgAllowed
def accessSync[A](ctx: NgWebsocketPluginContext, message: WebsocketMessage[A]): NgWebsocketResponse = NgWebsocketResponse()

def access[A](ctx: NgWebsocketPluginContext, message: WebsocketMessage[A])
(implicit env: Env, ec: ExecutionContext): Future[NgAccess] = accessSync(ctx, message).vfuture
(implicit env: Env, ec: ExecutionContext): Future[NgWebsocketResponse] = accessSync(ctx, message).vfuture
}

trait NgWebsocketValidatorPlugin extends NgWebsocketPlugin {
Expand Down
89 changes: 47 additions & 42 deletions otoroshi/app/next/plugins/websocket.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import otoroshi.gateway.Errors
import otoroshi.next.plugins.api._
import otoroshi.utils.JsonPathValidator
import otoroshi.utils.syntax.implicits._
import play.api.http.websocket.CloseCodes
import play.api.libs.json._
import play.api.mvc.Results

import java.nio.charset.StandardCharsets
import scala.concurrent.{ExecutionContext, Future}
import scala.util._

Expand Down Expand Up @@ -139,11 +141,11 @@ class WebsocketContentValidatorIn extends NgWebsocketValidatorPlugin {
}

override def access[A](ctx: NgWebsocketPluginContext, message: WebsocketMessage[A])
(implicit env: Env, ec: ExecutionContext): Future[NgAccess] = {
(implicit env: Env, ec: ExecutionContext): Future[NgWebsocketResponse] = {
validate(ctx, message)
.map {
case true => NgAccess.NgAllowed
case false => NgAccess.NgDenied(Errors
case true => NgWebsocketResponse()
case false => NgWebsocketResponse.denied(Errors
.craftResponseResultSync(
"forbidden",
Results.Forbidden,
Expand All @@ -152,7 +154,7 @@ class WebsocketContentValidatorIn extends NgWebsocketValidatorPlugin {
None,
attrs = ctx.attrs,
maybeRoute = ctx.route.some
))
), CloseCodes.PolicyViolated, "failed to validate message")
}
}

Expand All @@ -177,46 +179,48 @@ class WebsocketTypeValidator extends NgWebsocketValidatorPlugin {
override def onResponseFlow: Boolean = false
override def onRequestFlow: Boolean = true

override def access[A](ctx: NgWebsocketPluginContext, message: WebsocketMessage[A])(implicit env: Env, ec: ExecutionContext): Future[NgAccess] = {
override def access[A](ctx: NgWebsocketPluginContext, message: WebsocketMessage[A])(implicit env: Env, ec: ExecutionContext): Future[NgWebsocketResponse] = {
implicit val m: Materializer = env.otoroshiMaterializer

val config = ctx.cachedConfig(internalName)(WebsocketTypeValidatorConfig.format).getOrElse(WebsocketTypeValidatorConfig())

val accepted: Future[Boolean] = config.allowedFormat match {
case FrameFormat.All => true.future
case FrameFormat.Binary => message.isBinary.future
case FrameFormat.Text => message.isText.future
(config.allowedFormat match {
case FrameFormat.Binary if !message.isBinary => NgWebsocketResponse.fdenied(getResultError(ctx), CloseCodes.Unacceptable, "expected binary content")
case FrameFormat.Text if !message.isText => NgWebsocketResponse.fdenied(getResultError(ctx), CloseCodes.Unacceptable, "expected text content")
case FrameFormat.Text if message.isText => message.str()
.map(str => {
if (!StandardCharsets.UTF_8.newEncoder().canEncode(str)) {
NgWebsocketResponse.denied(getResultError(ctx), CloseCodes.InconsistentData, "non-UTF-8 data within content")
} else {
NgWebsocketResponse()
}
})
case FrameFormat.Json if message.isText => message.str()
.map(bs => Try(Json.parse(bs)))
.map {
case Success(_) => true
case _ => false
}
case _ => false.future
}

accepted.map {
case true => NgAccess.NgAllowed
case false =>
val result = Errors
.craftResponseResultSync(
"forbidden",
Results.Forbidden,
ctx.request,
None,
None,
attrs = ctx.attrs,
maybeRoute = ctx.route.some
)

NgAccess.NgDenied(result)
}
.map(bs => (Try(Json.parse(bs)), bs))
.map(res => {
res._1 match {
case Success(_) if !StandardCharsets.UTF_8.newEncoder().canEncode(res._2) => NgWebsocketResponse.denied(getResultError(ctx), CloseCodes.InconsistentData, "non-UTF-8 data within content")
case Failure(_) => NgWebsocketResponse.denied(getResultError(ctx), CloseCodes.Unacceptable, "expected json content")
case _ => NgWebsocketResponse()
}
})
case _ => NgWebsocketResponse.default
})
}

private def getResultError(ctx: NgWebsocketPluginContext)(implicit env: Env, ec: ExecutionContext) = Errors
.craftResponseResultSync(
"forbidden",
Results.Forbidden,
ctx.request,
None,
None,
attrs = ctx.attrs,
maybeRoute = ctx.route.some
)

override def rejectStrategy(ctx: NgWebsocketPluginContext): RejectStrategy = {
val config = ctx.cachedConfig(internalName)(WebsocketTypeValidatorConfig.format).getOrElse(WebsocketTypeValidatorConfig())
println("HERHE")
println(config)
config.rejectStrategy
}
}
Expand Down Expand Up @@ -268,26 +272,27 @@ class WebsocketJsonFormatValidator extends NgWebsocketValidatorPlugin {
override def onResponseFlow: Boolean = false
override def onRequestFlow: Boolean = true

override def access[A](ctx: NgWebsocketPluginContext, message: WebsocketMessage[A])(implicit env: Env, ec: ExecutionContext): Future[NgAccess] = {
override def access[A](ctx: NgWebsocketPluginContext, message: WebsocketMessage[A])(implicit env: Env, ec: ExecutionContext): Future[NgWebsocketResponse] = {
implicit val m: Materializer = env.otoroshiMaterializer

val config = ctx.cachedConfig(internalName)(WebsocketJsonFormatValidatorConfig.format).getOrElse(WebsocketJsonFormatValidatorConfig())
println(config)

message.str()
.map(data => {
val userSchema = config.schema.get
val userSchema = config.schema.getOrElse("")

val jsonSchemaFactory = JsonSchemaFactory.getInstance(VersionFlag.valueOf(config.specification))
val jsonSchemaFactory = JsonSchemaFactory.getInstance(VersionFlag.fromId(config.specification).get())

val schemaConfig = new SchemaValidatorsConfig()
schemaConfig.setPathType(PathType.JSON_POINTER)
schemaConfig.setFormatAssertionsEnabled(true)

val schema = jsonSchemaFactory.getSchema(userSchema, schemaConfig)

schema.validate(data, InputFormat.JSON).isEmpty
})
.map {
case true => NgAccess.NgAllowed
case true => NgWebsocketResponse()
case false =>
val result = Errors
.craftResponseResultSync(
Expand All @@ -300,12 +305,12 @@ class WebsocketJsonFormatValidator extends NgWebsocketValidatorPlugin {
maybeRoute = ctx.route.some
)

NgAccess.NgDenied(result)
NgWebsocketResponse.denied(result, CloseCodes.PolicyViolated, "failed to validate message")
}
}

override def rejectStrategy(ctx: NgWebsocketPluginContext): RejectStrategy = {
val config = ctx.cachedConfig(internalName)(FrameFormatValidatorConfig.format).getOrElse(FrameFormatValidatorConfig())
val config = ctx.cachedConfig(internalName)(WebsocketJsonFormatValidatorConfig.format).getOrElse(WebsocketJsonFormatValidatorConfig())
config.rejectStrategy
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,28 @@ export default {
],
},
},
schema: {
label: 'schema',
type: 'code',
props: {
editorOnly: true,
label: 'Validation schema'
},
},
specification: {
label: 'JSON specification used',
type: 'select',
props: {
defaultValue: 'https://json-schema.org/draft/2020-12/schema',
options: [
{ value: "http://json-schema.org/draft-04/schema#", label: 'V4' },
{ value: "http://json-schema.org/draft-06/schema#", label: 'V6' },
{ value: "http://json-schema.org/draft-07/schema#", label: 'V7' },
{ value: "https://json-schema.org/draft/2019-09/schema", label: 'V201909' },
{ value: "https://json-schema.org/draft/2020-12/schema", label: 'V202012' }
],
},
}
},
config_flow: ['reject_strategy'],
config_flow: ['reject_strategy', 'schema', 'specification'],
};

0 comments on commit f26bd7d

Please sign in to comment.