diff --git a/otoroshi/app/gateway/websockets.scala b/otoroshi/app/gateway/websockets.scala index a21dc2e4c..3542e18c7 100644 --- a/otoroshi/app/gateway/websockets.scala +++ b/otoroshi/app/gateway/websockets.scala @@ -2,12 +2,12 @@ package otoroshi.gateway import java.net.{InetAddress, InetSocketAddress} import java.util.concurrent.atomic.AtomicReference -import akka.NotUsed +import akka.{Done, NotUsed} import akka.actor.{Actor, ActorRef, PoisonPill, Props} import akka.http.scaladsl.ClientTransport import akka.http.scaladsl.model.Uri import akka.http.scaladsl.model.headers.RawHeader -import akka.http.scaladsl.model.ws.{InvalidUpgradeResponse, ValidUpgrade, WebSocketRequest} +import akka.http.scaladsl.model.ws.{BinaryMessage, InvalidUpgradeResponse, TextMessage, ValidUpgrade, WebSocketRequest} import akka.http.scaladsl.settings.ClientConnectionSettings import akka.http.scaladsl.util.FastFuture import akka.stream.scaladsl.{Flow, Keep, Sink, Source, SourceQueueWithComplete, Tcp} @@ -18,18 +18,16 @@ import otoroshi.events._ import otoroshi.models._ import org.joda.time.DateTime import otoroshi.el.TargetExpressionLanguage +import otoroshi.next.models.{NgPluginInstance, NgRoute} +import otoroshi.next.plugins.api.{NgAccess, NgPluginWrapper, NgTransformerRequestContext, NgWebsocketPlugin, NgWebsocketPluginContext} +import otoroshi.next.proxy.NgProxyEngineError +import otoroshi.next.proxy.NgProxyEngineError.NgResultProxyEngineError +import otoroshi.next.utils.FEither import otoroshi.script.Implicits._ import otoroshi.script.TransformerRequestContext import otoroshi.utils.UrlSanitizer import play.api.Logger -import play.api.http.websocket.{ - CloseMessage, - PingMessage, - PongMessage, - BinaryMessage => PlayWSBinaryMessage, - Message => PlayWSMessage, - TextMessage => PlayWSTextMessage -} +import play.api.http.websocket.{CloseMessage, PingMessage, PongMessage, BinaryMessage => PlayWSBinaryMessage, Message => PlayWSMessage, TextMessage => PlayWSTextMessage} import play.api.libs.json.{JsValue, Json} import play.api.libs.streams.ActorFlow import play.api.libs.ws.DefaultWSCookie @@ -588,6 +586,7 @@ class WebSocketHandler()(implicit env: Env) { out, httpRequest.headers.toSeq, //.filterNot(_._1 == "Cookie"), descriptor, + None, // TODO - check if we can pass the current route httpRequest.target.getOrElse(_target), env ) @@ -621,10 +620,11 @@ object WebSocketProxyActor { out: ActorRef, headers: Seq[(String, String)], descriptor: ServiceDescriptor, + route: Option[NgRoute], target: Target, env: Env ) = - Props(new WebSocketProxyActor(url, out, headers, descriptor, target, env)) + Props(new WebSocketProxyActor(url, out, headers, descriptor, route, target, env)) def wsCall(url: String, headers: Seq[(String, String)], descriptor: ServiceDescriptor, target: Target)(implicit env: Env, @@ -743,6 +743,7 @@ class WebSocketProxyActor( out: ActorRef, headers: Seq[(String, String)], descriptor: ServiceDescriptor, + route: Option[NgRoute], target: Target, env: Env ) extends Actor { @@ -751,6 +752,7 @@ class WebSocketProxyActor( implicit val ec = env.otoroshiExecutionContext implicit val mat = env.otoroshiMaterializer + implicit val e = env lazy val source = Source.queue[akka.http.scaladsl.model.ws.Message](50000, OverflowStrategy.dropTail) lazy val logger = Logger("otoroshi-websocket-handler-actor") @@ -775,7 +777,7 @@ class WebSocketProxyActor( case Failure(e) => List.empty } case (key, value) if key.toLowerCase == "host" => - Seq(akka.http.scaladsl.model.headers.Host(value)) + Seq(akka.http.scaladsl.model.headers.Host(value.split(":").head)) case (key, value) if key.toLowerCase == "user-agent" => Seq(akka.http.scaladsl.model.headers.`User-Agent`(value)) case (key, value) => @@ -864,28 +866,108 @@ class WebSocketProxyActor( // out ! PoisonPill } - def receive = { - case msg: PlayWSBinaryMessage => { - if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] binary message from client: ${msg.data.utf8String}") - Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data))) - } - case msg: PlayWSTextMessage => { - if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] text message from client: ${msg.data}") - Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.TextMessage(msg.data))) - } - case msg: PingMessage => { - if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] Ping message from client: ${msg.data}") - Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data))) - } - case msg: PongMessage => { - if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] Pong message from client: ${msg.data}") - Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data))) - } - case CloseMessage(status, reason) => { - if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] close message from client: $status : $reason") - Option(queueRef.get()).foreach(_.complete()) - // out ! PoisonPill - } - case e => logger.error(s"[WEBSOCKET] Bad message type: $e") + def receive: Receive = { + case data: play.api.http.websocket.Message => new WebsocketEngine() + .handle(route.get, data) + .map(_ => data match { + case msg: PlayWSBinaryMessage => { + if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] binary message from client: ${msg.data.utf8String}") + Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data))) + } + case msg: PlayWSTextMessage => { + if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] text message from client: ${msg.data}") + Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.TextMessage(msg.data))) + } + case msg: PingMessage => { + if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] Ping message from client: ${msg.data}") + Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data))) + } + case msg: PongMessage => { + if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] Pong message from client: ${msg.data}") + Option(queueRef.get()).foreach(_.offer(akka.http.scaladsl.model.ws.BinaryMessage(msg.data))) + } + case CloseMessage(status, reason) => { + if (logger.isDebugEnabled) logger.debug(s"[WEBSOCKET] close message from client: $status : $reason") + Option(queueRef.get()).foreach(_.complete()) + // out ! PoisonPill + } + case e => logger.error(s"[WEBSOCKET] Bad message type: $e") + }) } } + +class WebsocketEngine { + + def handle(route: NgRoute, data: play.api.http.websocket.Message) + (implicit env: Env, ec: ExecutionContext): FEither[NgProxyEngineError, Done] = { + val (requestValidators, responseValidators) = route.plugins.slots + .filter(_.enabled) + // .filter(_.matches(request)) + .map(inst => (inst, inst.getPlugin[NgWebsocketPlugin])) + .collect { case (inst, Some(plugin)) => + NgPluginWrapper.NgSimplePluginWrapper(inst, plugin) + } + .partition(_.plugin.onRequestFlow) + + val _ctx = NgWebsocketPluginContext( + snowflake = "snowflake", + request = null, + // rawRequest = rawRequest, + // otoroshiRequest = otoroshiRequest, + // apikey = attrs.get(otoroshi.plugins.Keys.ApiKeyKey), + // user = attrs.get(otoroshi.plugins.Keys.UserKey), + route = route, + config = Json.obj(), + // globalConfig = globalConfig.plugins.config, + attrs = null, + report = null, + // sequence = sequence, + // markPluginItem = markPluginItem + ) + + val promise = Promise[Either[NgProxyEngineError, Done]]() + + def next(plugins: Seq[NgPluginWrapper[NgWebsocketPlugin]]): Unit = { + plugins.headOption match { + case None => promise.trySuccess(Right(Done)) + case Some(wrapper) => { + val pluginConfig: JsValue = wrapper.plugin.defaultConfig + .map(dc => dc ++ wrapper.instance.config.raw) + .getOrElse(wrapper.instance.config.raw) + val ctx = _ctx.copy( + config = pluginConfig, + // apikey = _ctx.apikey.orElse(attrs.get(otoroshi.plugins.Keys.ApiKeyKey)), + // user = _ctx.user.orElse(attrs.get(otoroshi.plugins.Keys.UserKey)), + // idx = wrapper.instance.instanceId + ) + + wrapper.plugin.access(ctx, data match { + case PlayWSTextMessage(data) => data + case PlayWSBinaryMessage(data) => data.utf8String + case CloseMessage(statusCode, reason) => "" + case PingMessage(data) => data.utf8String + case PongMessage(data) => data.utf8String + }).andThen { + case Failure(_) => + promise.trySuccess( + Left( + NgResultProxyEngineError(Results.InternalServerError) + ) + ) + case Success(NgAccess.NgDenied(result)) => + promise.trySuccess(Left(NgResultProxyEngineError(result))) + case Success(NgAccess.NgAllowed) if plugins.size == 1 => + promise.trySuccess(Right(Done)) + case Success(NgAccess.NgAllowed) => + next(plugins.tail) + } + } + } + } + + next(requestValidators) + FEither.apply(promise.future) + } + + +} \ No newline at end of file diff --git a/otoroshi/app/next/plugins/api.scala b/otoroshi/app/next/plugins/api.scala index 6679184e9..a08257986 100644 --- a/otoroshi/app/next/plugins/api.scala +++ b/otoroshi/app/next/plugins/api.scala @@ -216,6 +216,7 @@ object NgPluginCategory { case object Wasm extends NgPluginCategory { def name: String = "Wasm" } case object Classic extends NgPluginCategory { def name: String = "Classic" } case object ServiceDiscovery extends NgPluginCategory { def name: String = "ServiceDiscovery" } + case object Websocket extends NgPluginCategory { def name: String = "Websocket" } val all = Seq( Classic, @@ -233,7 +234,8 @@ object NgPluginCategory { TrafficControl, Transformations, Tunnel, - Wasm + Wasm, + Websocket ) } @@ -1230,3 +1232,32 @@ class NgMergedAccessValidator(plugins: Seq[NgPluginWrapper.NgSimplePluginWrapper next(plugins, plugins.size) } } + +case class NgWebsocketPluginContext( + snowflake: String, + config: JsValue, + idx: Int = 0, + request: RequestHeader, + route: NgRoute, + attrs: TypedMap, + report: NgExecutionReport, + ) extends NgCachedConfigContext { + def json: JsValue = Json.obj( + "config" -> config + ) + + def wasmJson(implicit env: Env, ec: ExecutionContext): JsObject = { + (json.asObject ++ Json.obj( + "route" -> route.json + )) + } +} + +trait NgWebsocketPlugin extends NgNamedPlugin { + def onRequestFlow: Boolean = true + def onResponseFlow: Boolean = true + + def accessSync(ctx: NgWebsocketPluginContext, message: String)(implicit env: Env, ec: ExecutionContext): NgAccess = NgAccess.NgAllowed + + def access(ctx: NgWebsocketPluginContext, message: String)(implicit env: Env, ec: ExecutionContext): Future[NgAccess] = accessSync(ctx, message).vfuture +} \ No newline at end of file diff --git a/otoroshi/app/next/plugins/websocket.scala b/otoroshi/app/next/plugins/websocket.scala new file mode 100644 index 000000000..4ed01ddf5 --- /dev/null +++ b/otoroshi/app/next/plugins/websocket.scala @@ -0,0 +1,94 @@ +package otoroshi.next.plugins + +import otoroshi.env.Env +import otoroshi.gateway.Errors +import otoroshi.next.plugins.api._ +import otoroshi.utils.JsonPathValidator +import otoroshi.utils.syntax.implicits._ +import play.api.http.websocket._ +import play.api.libs.json._ +import play.api.mvc.{Result, Results} + +import scala.concurrent.{ExecutionContext, Future} +import scala.util._ + +case class FrameFormatValidatorConfig( + validator: Option[JsonPathValidator] = None +) extends NgPluginConfig { + def json: JsValue = FrameFormatValidatorConfig.format.writes(this) +} + +object FrameFormatValidatorConfig { + val default = FrameFormatValidatorConfig() + val format = new Format[FrameFormatValidatorConfig] { + override def writes(o: FrameFormatValidatorConfig): JsValue = Json.obj( + "validator" -> o.validator.map(_.json) + ) + override def reads(json: JsValue): JsResult[FrameFormatValidatorConfig] = { + Try { + FrameFormatValidatorConfig( + validator = (json \ "validator").asOpt[JsValue] + .flatMap(v => JsonPathValidator.format.reads(v).asOpt) + ) + } match { + case Failure(e) => JsError(e.getMessage) + case Success(s) => JsSuccess(s) + } + } + } +} + +class FrameFormatValidator extends NgWebsocketPlugin { + + override def multiInstance: Boolean = true + override def defaultConfigObject: Option[NgPluginConfig] = Some(FrameFormatValidatorConfig.default) + override def core: Boolean = false + override def name: String = "Websocket frame format validator" + override def description: Option[String] = "Validate the format of each frames".some + override def visibility: NgPluginVisibility = NgPluginVisibility.NgUserLand + override def categories: Seq[NgPluginCategory] = Seq(NgPluginCategory.Websocket) + override def steps: Seq[NgStep] = Seq(NgStep.ValidateAccess) + + override def onResponseFlow: Boolean = true + override def onRequestFlow: Boolean = true + + private def validate(ctx: NgWebsocketPluginContext, message: String)(implicit env: Env): Boolean = { + val config = ctx.cachedConfig(internalName)(FrameFormatValidatorConfig.format).getOrElse(FrameFormatValidatorConfig()) +// val token: JsValue = ctx.attrs +// .get(otoroshi.next.plugins.Keys.JwtInjectionKey) +// .flatMap(_.decodedToken) +// .map { token => +// Json.obj( +// "header" -> token.getHeader.fromBase64.parseJson, +// "payload" -> token.getPayload.fromBase64.parseJson +// ) +// } +// .getOrElse(JsNull) + val json = ctx.json.asObject ++ Json.obj( + "route" -> ctx.route.json, + "message" -> message +// "token" -> token + ) + config.validator.forall(validator => validator.validate(json)) + } + + override def access(ctx: NgWebsocketPluginContext, message: String)(implicit env: Env, ec: ExecutionContext): Future[NgAccess] = { + if (validate(ctx, message)) { + NgAccess.NgAllowed.vfuture + } else { + Errors + .craftResponseResult( + "forbidden", + Results.Forbidden, + ctx.request, + None, + None, + duration = 0L,// ctx.report.getDurationNow(), + overhead = 0L, //ctx.report.getOverheadInNow(), + attrs = ctx.attrs, + maybeRoute = ctx.route.some + ) + .map(r => NgAccess.NgDenied(r)) + } + } +} diff --git a/otoroshi/app/next/proxy/engine.scala b/otoroshi/app/next/proxy/engine.scala index b48937d92..206741e59 100644 --- a/otoroshi/app/next/proxy/engine.scala +++ b/otoroshi/app/next/proxy/engine.scala @@ -2699,6 +2699,7 @@ class ProxyEngine() extends RequestHandler { out, request.headers.toSeq, route.serviceDescriptor, + route.some, finalTarget, env ) diff --git a/otoroshi/javascript/src/forms/ng_plugins/FrameFormatValidator.js b/otoroshi/javascript/src/forms/ng_plugins/FrameFormatValidator.js new file mode 100644 index 000000000..af3127180 --- /dev/null +++ b/otoroshi/javascript/src/forms/ng_plugins/FrameFormatValidator.js @@ -0,0 +1,30 @@ +export default { + id: 'cp:otoroshi.next.plugins.FrameFormatValidator', + config_schema: { + validator: { + label: 'validator', + type: 'object', + format: 'form', + schema: { + path: { + label: 'path', + type: 'string', + props: { + subTitle: 'Example: $.apikey.metadata.foo', + }, + }, + value: { + type: 'code', + help: 'Example: Contains(bar)', + props: { + label: 'Value', + type: 'json', + editorOnly: true, + }, + }, + }, + flow: ['path', 'value'], + }, + }, + config_flow: ['validator'], +}; diff --git a/otoroshi/javascript/src/forms/ng_plugins/index.js b/otoroshi/javascript/src/forms/ng_plugins/index.js index 9a131ea6b..e903e54ff 100644 --- a/otoroshi/javascript/src/forms/ng_plugins/index.js +++ b/otoroshi/javascript/src/forms/ng_plugins/index.js @@ -19,6 +19,7 @@ import EurekaTarget from './EurekaTarget'; import ExternalEurekaTarget from './ExternalEurekaTarget'; import ForceHttpsTraffic from './ForceHttpsTraffic'; import ForwardedHeader from './ForwardedHeader'; +import FrameFormatValidator from './FrameFormatValidator'; import GlobalMaintenanceMode from './GlobalMaintenanceMode'; import GlobalPerIpAddressThrottling from './GlobalPerIpAddressThrottling'; import GlobalThrottling from './GlobalThrottling'; @@ -163,6 +164,7 @@ const pluginsArray = [ ExternalEurekaTarget, ForceHttpsTraffic, ForwardedHeader, + FrameFormatValidator, GlobalMaintenanceMode, GlobalPerIpAddressThrottling, GlobalThrottling,