Skip to content

Commit

Permalink
websocket plugin to filter content
Browse files Browse the repository at this point in the history
  • Loading branch information
Zwiterrion committed Jan 5, 2024
1 parent df87c97 commit 2b4ea20
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 36 deletions.
152 changes: 117 additions & 35 deletions otoroshi/app/gateway/websockets.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package otoroshi.gateway

import java.net.{InetAddress, InetSocketAddress}
import java.util.concurrent.atomic.AtomicReference
import akka.NotUsed
import akka.{Done, NotUsed}
import akka.actor.{Actor, ActorRef, PoisonPill, Props}
import akka.http.scaladsl.ClientTransport
import akka.http.scaladsl.model.Uri
import akka.http.scaladsl.model.headers.RawHeader
import akka.http.scaladsl.model.ws.{InvalidUpgradeResponse, ValidUpgrade, WebSocketRequest}
import akka.http.scaladsl.model.ws.{BinaryMessage, InvalidUpgradeResponse, TextMessage, ValidUpgrade, WebSocketRequest}
import akka.http.scaladsl.settings.ClientConnectionSettings
import akka.http.scaladsl.util.FastFuture
import akka.stream.scaladsl.{Flow, Keep, Sink, Source, SourceQueueWithComplete, Tcp}
Expand All @@ -18,18 +18,16 @@ import otoroshi.events._
import otoroshi.models._
import org.joda.time.DateTime
import otoroshi.el.TargetExpressionLanguage
import otoroshi.next.models.{NgPluginInstance, NgRoute}
import otoroshi.next.plugins.api.{NgAccess, NgPluginWrapper, NgTransformerRequestContext, NgWebsocketPlugin, NgWebsocketPluginContext}
import otoroshi.next.proxy.NgProxyEngineError
import otoroshi.next.proxy.NgProxyEngineError.NgResultProxyEngineError
import otoroshi.next.utils.FEither
import otoroshi.script.Implicits._
import otoroshi.script.TransformerRequestContext
import otoroshi.utils.UrlSanitizer
import play.api.Logger
import play.api.http.websocket.{
CloseMessage,
PingMessage,
PongMessage,
BinaryMessage => PlayWSBinaryMessage,
Message => PlayWSMessage,
TextMessage => PlayWSTextMessage
}
import play.api.http.websocket.{CloseMessage, PingMessage, PongMessage, BinaryMessage => PlayWSBinaryMessage, Message => PlayWSMessage, TextMessage => PlayWSTextMessage}
import play.api.libs.json.{JsValue, Json}
import play.api.libs.streams.ActorFlow
import play.api.libs.ws.DefaultWSCookie
Expand Down Expand Up @@ -588,6 +586,7 @@ class WebSocketHandler()(implicit env: Env) {
out,
httpRequest.headers.toSeq, //.filterNot(_._1 == "Cookie"),
descriptor,
None, // TODO - check if we can pass the current route
httpRequest.target.getOrElse(_target),
env
)
Expand Down Expand Up @@ -621,10 +620,11 @@ object WebSocketProxyActor {
out: ActorRef,
headers: Seq[(String, String)],
descriptor: ServiceDescriptor,
route: Option[NgRoute],
target: Target,
env: Env
) =
Props(new WebSocketProxyActor(url, out, headers, descriptor, target, env))
Props(new WebSocketProxyActor(url, out, headers, descriptor, route, target, env))

def wsCall(url: String, headers: Seq[(String, String)], descriptor: ServiceDescriptor, target: Target)(implicit
env: Env,
Expand Down Expand Up @@ -743,6 +743,7 @@ class WebSocketProxyActor(
out: ActorRef,
headers: Seq[(String, String)],
descriptor: ServiceDescriptor,
route: Option[NgRoute],
target: Target,
env: Env
) extends Actor {
Expand All @@ -751,6 +752,7 @@ class WebSocketProxyActor(

implicit val ec = env.otoroshiExecutionContext
implicit val mat = env.otoroshiMaterializer
implicit val e = env

lazy val source = Source.queue[akka.http.scaladsl.model.ws.Message](50000, OverflowStrategy.dropTail)
lazy val logger = Logger("otoroshi-websocket-handler-actor")
Expand All @@ -775,7 +777,7 @@ class WebSocketProxyActor(
case Failure(e) => List.empty
}
case (key, value) if key.toLowerCase == "host" =>
Seq(akka.http.scaladsl.model.headers.Host(value))
Seq(akka.http.scaladsl.model.headers.Host(value.split(":").head))
case (key, value) if key.toLowerCase == "user-agent" =>
Seq(akka.http.scaladsl.model.headers.`User-Agent`(value))
case (key, value) =>
Expand Down Expand Up @@ -864,28 +866,108 @@ class WebSocketProxyActor(
// out ! PoisonPill
}

def receive = {
case msg: PlayWSBinaryMessage => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] binary message from client: ${msg.data.utf8String}")
Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data)))
}
case msg: PlayWSTextMessage => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] text message from client: ${msg.data}")
Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.TextMessage(msg.data)))
}
case msg: PingMessage => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] Ping message from client: ${msg.data}")
Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data)))
}
case msg: PongMessage => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] Pong message from client: ${msg.data}")
Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data)))
}
case CloseMessage(status, reason) => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] close message from client: $status : $reason")
Option(queueRef.get()).foreach(_.complete())
// out ! PoisonPill
}
case e => logger.error(s"[WEBSOCKET] Bad message type: $e")
def receive: Receive = {
case data: play.api.http.websocket.Message => new WebsocketEngine()
.handle(route.get, data)
.map(_ => data match {
case msg: PlayWSBinaryMessage => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] binary message from client: ${msg.data.utf8String}")
Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data)))
}
case msg: PlayWSTextMessage => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] text message from client: ${msg.data}")
Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.TextMessage(msg.data)))
}
case msg: PingMessage => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] Ping message from client: ${msg.data}")
Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data)))
}
case msg: PongMessage => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] Pong message from client: ${msg.data}")
Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data)))
}
case CloseMessage(status, reason) => {
if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] close message from client: $status : $reason")
Option(queueRef.get()).foreach(_.complete())
// out ! PoisonPill
}
case e => logger.error(s"[WEBSOCKET] Bad message type: $e")
})
}
}

class WebsocketEngine {

def handle(route: NgRoute, data: play.api.http.websocket.Message)
(implicit env: Env, ec: ExecutionContext): FEither[NgProxyEngineError, Done] = {
val (requestValidators, responseValidators) = route.plugins.slots
.filter(_.enabled)
// .filter(_.matches(request))
.map(inst => (inst, inst.getPlugin[NgWebsocketPlugin]))
.collect { case (inst, Some(plugin)) =>
NgPluginWrapper.NgSimplePluginWrapper(inst, plugin)
}
.partition(_.plugin.onRequestFlow)

val _ctx = NgWebsocketPluginContext(
snowflake = "snowflake",
request = null,
// rawRequest = rawRequest,
// otoroshiRequest = otoroshiRequest,
// apikey = attrs.get(otoroshi.plugins.Keys.ApiKeyKey),
// user = attrs.get(otoroshi.plugins.Keys.UserKey),
route = route,
config = Json.obj(),
// globalConfig = globalConfig.plugins.config,
attrs = null,
report = null,
// sequence = sequence,
// markPluginItem = markPluginItem
)

val promise = Promise[Either[NgProxyEngineError, Done]]()

def next(plugins: Seq[NgPluginWrapper[NgWebsocketPlugin]]): Unit = {
plugins.headOption match {
case None => promise.trySuccess(Right(Done))
case Some(wrapper) => {
val pluginConfig: JsValue = wrapper.plugin.defaultConfig
.map(dc => dc ++ wrapper.instance.config.raw)
.getOrElse(wrapper.instance.config.raw)
val ctx = _ctx.copy(
config = pluginConfig,
// apikey = _ctx.apikey.orElse(attrs.get(otoroshi.plugins.Keys.ApiKeyKey)),
// user = _ctx.user.orElse(attrs.get(otoroshi.plugins.Keys.UserKey)),
// idx = wrapper.instance.instanceId
)

wrapper.plugin.access(ctx, data match {
case PlayWSTextMessage(data) => data
case PlayWSBinaryMessage(data) => data.utf8String
case CloseMessage(statusCode, reason) => ""
case PingMessage(data) => data.utf8String
case PongMessage(data) => data.utf8String
}).andThen {
case Failure(_) =>
promise.trySuccess(
Left(
NgResultProxyEngineError(Results.InternalServerError)
)
)
case Success(NgAccess.NgDenied(result)) =>
promise.trySuccess(Left(NgResultProxyEngineError(result)))
case Success(NgAccess.NgAllowed) if plugins.size == 1 =>
promise.trySuccess(Right(Done))
case Success(NgAccess.NgAllowed) =>
next(plugins.tail)
}
}
}
}

next(requestValidators)
FEither.apply(promise.future)
}


}
33 changes: 32 additions & 1 deletion otoroshi/app/next/plugins/api.scala
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ object NgPluginCategory {
case object Wasm extends NgPluginCategory { def name: String = "Wasm" }
case object Classic extends NgPluginCategory { def name: String = "Classic" }
case object ServiceDiscovery extends NgPluginCategory { def name: String = "ServiceDiscovery" }
case object Websocket extends NgPluginCategory { def name: String = "Websocket" }

val all = Seq(
Classic,
Expand All @@ -233,7 +234,8 @@ object NgPluginCategory {
TrafficControl,
Transformations,
Tunnel,
Wasm
Wasm,
Websocket
)
}

Expand Down Expand Up @@ -1230,3 +1232,32 @@ class NgMergedAccessValidator(plugins: Seq[NgPluginWrapper.NgSimplePluginWrapper
next(plugins, plugins.size)
}
}

case class NgWebsocketPluginContext(
snowflake: String,
config: JsValue,
idx: Int = 0,
request: RequestHeader,
route: NgRoute,
attrs: TypedMap,
report: NgExecutionReport,
) extends NgCachedConfigContext {
def json: JsValue = Json.obj(
"config" -> config
)

def wasmJson(implicit env: Env, ec: ExecutionContext): JsObject = {
(json.asObject ++ Json.obj(
"route" -> route.json
))
}
}

trait NgWebsocketPlugin extends NgNamedPlugin {
def onRequestFlow: Boolean = true
def onResponseFlow: Boolean = true

def accessSync(ctx: NgWebsocketPluginContext, message: String)(implicit env: Env, ec: ExecutionContext): NgAccess = NgAccess.NgAllowed

def access(ctx: NgWebsocketPluginContext, message: String)(implicit env: Env, ec: ExecutionContext): Future[NgAccess] = accessSync(ctx, message).vfuture
}
94 changes: 94 additions & 0 deletions otoroshi/app/next/plugins/websocket.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package otoroshi.next.plugins

import otoroshi.env.Env
import otoroshi.gateway.Errors
import otoroshi.next.plugins.api._
import otoroshi.utils.JsonPathValidator
import otoroshi.utils.syntax.implicits._
import play.api.http.websocket._
import play.api.libs.json._
import play.api.mvc.{Result, Results}

import scala.concurrent.{ExecutionContext, Future}
import scala.util._

case class FrameFormatValidatorConfig(
validator: Option[JsonPathValidator] = None
) extends NgPluginConfig {
def json: JsValue = FrameFormatValidatorConfig.format.writes(this)
}

object FrameFormatValidatorConfig {
val default = FrameFormatValidatorConfig()
val format = new Format[FrameFormatValidatorConfig] {
override def writes(o: FrameFormatValidatorConfig): JsValue = Json.obj(
"validator" -> o.validator.map(_.json)
)
override def reads(json: JsValue): JsResult[FrameFormatValidatorConfig] = {
Try {
FrameFormatValidatorConfig(
validator = (json \ "validator").asOpt[JsValue]
.flatMap(v => JsonPathValidator.format.reads(v).asOpt)
)
} match {
case Failure(e) => JsError(e.getMessage)
case Success(s) => JsSuccess(s)
}
}
}
}

class FrameFormatValidator extends NgWebsocketPlugin {

override def multiInstance: Boolean = true
override def defaultConfigObject: Option[NgPluginConfig] = Some(FrameFormatValidatorConfig.default)
override def core: Boolean = false
override def name: String = "Websocket frame format validator"
override def description: Option[String] = "Validate the format of each frames".some
override def visibility: NgPluginVisibility = NgPluginVisibility.NgUserLand
override def categories: Seq[NgPluginCategory] = Seq(NgPluginCategory.Websocket)
override def steps: Seq[NgStep] = Seq(NgStep.ValidateAccess)

override def onResponseFlow: Boolean = true
override def onRequestFlow: Boolean = true

private def validate(ctx: NgWebsocketPluginContext, message: String)(implicit env: Env): Boolean = {
val config = ctx.cachedConfig(internalName)(FrameFormatValidatorConfig.format).getOrElse(FrameFormatValidatorConfig())
// val token: JsValue = ctx.attrs
// .get(otoroshi.next.plugins.Keys.JwtInjectionKey)
// .flatMap(_.decodedToken)
// .map { token =>
// Json.obj(
// "header" -> token.getHeader.fromBase64.parseJson,
// "payload" -> token.getPayload.fromBase64.parseJson
// )
// }
// .getOrElse(JsNull)
val json = ctx.json.asObject ++ Json.obj(
"route" -> ctx.route.json,
"message" -> message
// "token" -> token
)
config.validator.forall(validator => validator.validate(json))
}

override def access(ctx: NgWebsocketPluginContext, message: String)(implicit env: Env, ec: ExecutionContext): Future[NgAccess] = {
if (validate(ctx, message)) {
NgAccess.NgAllowed.vfuture
} else {
Errors
.craftResponseResult(
"forbidden",
Results.Forbidden,
ctx.request,
None,
None,
duration = 0L,// ctx.report.getDurationNow(),
overhead = 0L, //ctx.report.getOverheadInNow(),
attrs = ctx.attrs,
maybeRoute = ctx.route.some
)
.map(r => NgAccess.NgDenied(r))
}
}
}
1 change: 1 addition & 0 deletions otoroshi/app/next/proxy/engine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2699,6 +2699,7 @@ class ProxyEngine() extends RequestHandler {
out,
request.headers.toSeq,
route.serviceDescriptor,
route.some,
finalTarget,
env
)
Expand Down
Loading

0 comments on commit 2b4ea20

Please sign in to comment.