Skip to content
This repository was archived by the owner on May 16, 2022. It is now read-only.

Use ref for mutable auth0 token #2

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
val catsVersion = "1.4.0"
val circeVersion = "0.10.0"
val fs2Version = "1.0.0-RC1"
val fs2Version = "1.0.0"
val http4sVersion = "0.19.0-M2"

val jwtCirceVersion = "0.18.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ package com.ovoenergy.http4s.client.middleware.auth0
import java.net.ConnectException

import cats.data.{EitherT, Kleisli}
import cats.effect.IO
import cats.effect.Sync
import cats.effect.concurrent.Ref
import cats.implicits._
import com.ovoenergy.http4s.client.middleware.auth0.Client.AuthZeroToken
import com.ovoenergy.http4s.client.middleware.auth0.Client.Error._
import org.http4s.circe._
import org.http4s.client.{Client => Http4sClient, DisposableResponse}
import org.http4s.client.{DisposableResponse, Client => Http4sClient}
import com.ovoenergy.http4s.client.middleware.auth0.TokenResponse._
import org.http4s._

import scala.util.Try

/**
* HTTP4s Client middleware that transparently provides Auth0 authentication
*
Expand All @@ -21,66 +22,71 @@ import scala.util.Try
* the token was not present or had become invalid.
* @todo Clean up case-logic for retry into something neater
*/
class Client(val config: Config, val client: Http4sClient[IO]) {
private implicit val authZeroErrorBodyEntityEncoder: EntityEncoder[IO, ErrorBody] = jsonEncoderOf
class Client[F[_]: Sync] private (val config: Config, val client: Http4sClient[F], currentToken: Ref[F, Option[AuthZeroToken]]) {
private implicit val authZeroErrorBodyEntityEncoder: EntityEncoder[F, ErrorBody] = jsonEncoderOf

import Client._

def open(req: Request[IO]): IO[DisposableResponse[IO]] = {
retryRequest(req, currentToken, 1).flatMap({
case Right((response, token)) =>
IO {
currentToken = Some(token)
response
}
case Left(err) =>
IO(currentToken = None)
.map(_ => errorResponse(err))
})
@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements"))
def open(req: Request[F]): F[DisposableResponse[F]] = {
val newResponse: F[(DisposableResponse[F], Option[AuthZeroToken])] = for {
authToken <- currentToken.get
retry <- retryRequest(req, authToken, 1)
newResponseAndToken = getResultAndToken(retry)
} yield newResponseAndToken
newResponse.flatMap(r => currentToken.update(_ => r._2))
newResponse.map(r => r._1)
}


private def getResultAndToken(retryResult: Result[(DisposableResponse[F], AuthZeroToken)])
: (DisposableResponse[F], Option[AuthZeroToken]) = retryResult match {
case Right((response, token)) =>
(response, Some[AuthZeroToken](token))
case Left(err) =>
(errorResponse(err), Option.empty[AuthZeroToken])
}

@SuppressWarnings(Array("org.wartremover.warts.Recursion"))
private def retryRequest(req: Request[IO], maybeToken: Option[AuthZeroToken], retries: Int): IO[Result[ResponseAndToken]] = {
val result: IO[Result[ResponseAndToken]] = (for {
private def retryRequest(req: Request[F], maybeToken: Option[AuthZeroToken], retries: Int): F[Result[ResponseAndToken[F]]] = {
val result: F[Result[ResponseAndToken[F]]] = (for {
token <- EitherT(eitherToken(maybeToken))
result <- EitherT(performRequest(req, token))
} yield result).value

result.flatMap({
case Left(_) if retries > 0 => retryRequest(req, None, retries - 1)
case Left(err) => IO.pure(err.asLeft[ResponseAndToken])
case result@Right(_) => IO.pure(result)
case Left(err) => Sync[F].pure(err.asLeft[ResponseAndToken[F]])
case result@Right(_) => Sync[F].pure(result)
})
}

private def performRequest(req: Request[IO], token: AuthZeroToken): IO[Result[ResponseAndToken]] = {
private def performRequest(req: Request[F], token: AuthZeroToken): F[Result[ResponseAndToken[F]]] = {
client.open(enhanceRequest(req, token)).flatMap(disposableResponse => {
disposableResponse.response.status match {
case Status.Unauthorized => requestNotAuthorized(disposableResponse)
case Status.NotFound => requestNotAuthorized(disposableResponse)
case _ => IO.pure((disposableResponse, token).asRight[Error])
case _ => Sync[F].pure((disposableResponse, token).asRight[Error])
}
})
}

@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements"))
private def requestNotAuthorized(disposableResponse: DisposableResponse[IO]): IO[Result[ResponseAndToken]] = {
IO {
val _ = Try(disposableResponse.dispose.unsafeRunSync()) // TODO: log if this throws?
NotAuthorized().asLeft[ResponseAndToken]
}
}
private def requestNotAuthorized(disposableResponse: DisposableResponse[F]): F[Result[ResponseAndToken[F]]] =
disposableResponse.dispose.attempt.map{ _ => NotAuthorized().asLeft[ResponseAndToken[F]]} // swallow any exception and return NotAuthorised

private def eitherToken(maybeToken: Option[AuthZeroToken]): IO[Result[AuthZeroToken]] =
maybeToken.map(token => IO.pure(token.asRight[Error])).getOrElse(generateToken())
private def eitherToken(maybeToken: Option[AuthZeroToken]): F[Result[AuthZeroToken]] =
maybeToken.map(token => Sync[F].pure(token.asRight[Error])).getOrElse(generateToken())

private def generateToken(): IO[Result[AuthZeroToken]] = {
private def generateToken(): F[Result[AuthZeroToken]] = {
val request = TokenRequest(config.audience, config.id, config.secret)
implicit val tokenRequestEncoder: EntityEncoder[F, TokenRequest] = jsonEncoderOf
implicit val customEntityDecoder: EntityDecoder[F, TokenResponse] = jsonOf[F, TokenResponse]

val uri: Uri = config.uri / "oauth" / "token"

client
.expect[TokenResponse](Request[IO](method = Method.POST, uri = uri).withEntity(request))
.expect[TokenResponse](Request[F](method = Method.POST, uri = uri).withEntity(request))
.map(_.accessToken.asRight[Error])
.handleError {
case e: ConnectException => AuthZeroUnavailable(e).asLeft[AuthZeroToken]
Expand All @@ -89,39 +95,37 @@ class Client(val config: Config, val client: Http4sClient[IO]) {
}
}

private def enhanceRequest(req: Request[IO], token: AuthZeroToken): Request[IO] = req.putHeaders(Header("Authorization", s"Bearer $token"))
private def enhanceRequest(req: Request[F], token: AuthZeroToken): Request[F] = req.putHeaders(Header("Authorization", s"Bearer $token"))

private def errorResponse(err: Error): DisposableResponse[IO] = {
private def errorResponse(err: Error): DisposableResponse[F] = {
val status = err match {
case NotAuthorized() => Status.Unauthorized
case AuthZeroUnavailable(_) => Status.RequestTimeout
}

val entityResponse = Response[IO](status = status).withEntity(ErrorBody(err.msg))
val entityResponse = Response[F](status = status).withEntity(ErrorBody(err.msg))

DisposableResponse(entityResponse, nullOpDispose)
}

private val nullOpDispose = IO.pure(())

private var currentToken: Option[AuthZeroToken] = None
private val nullOpDispose = Sync[F].pure(())
}

object Client {

type ResponseAndToken = (DisposableResponse[IO], AuthZeroToken)
type ResponseAndToken[F[_]] = (DisposableResponse[F], AuthZeroToken)

type AuthZeroToken = String

@SuppressWarnings(Array("org.wartremover.warts.Nothing"))
def apply(config: Config)(client: Http4sClient[IO]): Http4sClient[IO] = {
val authClient = new Client(config, client)
def apply[F[_]: Sync](config: Config)(client: Http4sClient[F], clientToken: ClientToken[F]): F[Http4sClient[F]] = {
val authClient = new Client(config, client, clientToken.token)

def authenticatedOpen(req: Request[IO]): IO[DisposableResponse[IO]] = {
def authenticatedOpen(req: Request[F]): F[DisposableResponse[F]] = {
authClient.open(req)
}

client.copy(open = Kleisli(authenticatedOpen))
Sync[F].pure(client.copy(open = Kleisli(authenticatedOpen)))
}

sealed trait Error extends Product with Serializable {
Expand All @@ -141,5 +145,11 @@ object Client {
}

}
}

final class ClientToken[F[_]](val token: Ref[F, Option[AuthZeroToken]])
object ClientToken {
def apply[F[_]: Sync]: F[ClientToken[F]] = for {
t <- Ref.of[F, Option[AuthZeroToken]](None)
} yield new ClientToken[F](t)
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
package com.ovoenergy.http4s.client.middleware.auth0

import cats.effect.IO
import io.circe.generic.extras.Configuration
import io.circe.generic.extras.semiauto._
import io.circe._
import org.http4s.EntityEncoder
import org.http4s.circe._

final case class ErrorBody(message: String)

@SuppressWarnings(Array("org.wartremover.warts.PublicInference"))
object ErrorBody {
implicit val customConfig: Configuration = Configuration.default.withSnakeCaseMemberNames.withDefaults
implicit val customEncoder: Encoder[ErrorBody] = deriveEncoder[ErrorBody]

implicit val customEntityEncoder: EntityEncoder[IO, ErrorBody] = jsonEncoderOf[IO, ErrorBody]
}

Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package com.ovoenergy.http4s.client.middleware.auth0

import cats.effect.IO
import io.circe.generic.extras.Configuration
import io.circe.generic.extras.semiauto._
import io.circe._
import org.http4s.EntityEncoder
import org.http4s.circe._

final case class TokenRequest(audience: String,
clientId: String,
Expand All @@ -17,5 +14,4 @@ object TokenRequest {
val DEFAULT_GRANT_TYPE = "client_credentials"
implicit val customConfig: Configuration = Configuration.default.withSnakeCaseMemberNames.withDefaults
implicit val customEncoder: Encoder[TokenRequest] = deriveEncoder[TokenRequest]
implicit val customEntityEncoder: EntityEncoder[IO, TokenRequest] = jsonEncoderOf
}
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
package com.ovoenergy.http4s.client.middleware.auth0

import cats.effect.IO
import io.circe.generic.extras.Configuration
import io.circe.generic.extras.semiauto._
import io.circe._
import org.http4s.EntityDecoder
import org.http4s.circe.jsonOf

final case class TokenResponse(accessToken: String)

@SuppressWarnings(Array("org.wartremover.warts.PublicInference"))
object TokenResponse {
implicit val customConfig: Configuration = Configuration.default.withSnakeCaseMemberNames.withDefaults
implicit val customDecoder: Decoder[TokenResponse] = deriveDecoder[TokenResponse]
implicit val customEntityDecoder: EntityDecoder[IO, TokenResponse] = jsonOf[IO, TokenResponse]
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.ovoenergy.http4s.server.middleware.auth0

import cats.effect.IO
import cats.syntax.either._
import com.ovoenergy.http4s.server.middleware.auth0.Authenticator.Error._
import org.http4s._
Expand All @@ -13,7 +12,7 @@ import scala.util.Try
*
* @param config Configuration
*/
class Authenticator(val config: Config) {
class Authenticator[F[_]](val config: Config) {

import Authenticator._

Expand All @@ -22,7 +21,7 @@ class Authenticator(val config: Config) {
* @param request The HTTP request to authenticate
* @return Either the answer to whether the request was authentic or possibly an error message
*/
def authenticate(request: Request[IO]): Result[AuthenticatedStatus] = {
def authenticate(request: Request[F]): Result[AuthenticatedStatus] = {
request.headers.get(Authorization) match {
case None => AuthorizationHeaderNotFound().asLeft[AuthenticatedStatus]
case Some(authorization) => validate(authorization)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.ovoenergy.http4s.server.middleware.auth0

import cats.Applicative
import cats.data.{Kleisli, OptionT}
import cats.effect.IO
import org.http4s._

/**
Expand All @@ -11,18 +11,18 @@ object Service {
import Authenticator._

@SuppressWarnings(Array("org.wartremover.warts.Nothing","org.wartremover.warts.Any"))
def apply(service: HttpRoutes[IO], config: Config): HttpRoutes[IO] = {
val authenticator: Authenticator = new Authenticator(config)
def apply[F[_]: Applicative](service: HttpRoutes[F], config: Config): HttpRoutes[F] = {
val authenticator: Authenticator[F] = new Authenticator(config)

Kleisli { req =>
authenticator.authenticate(req) match {
case Right(Authenticated) =>
service.run(req)
case Right(NotAuthenticated) =>
OptionT.pure(Response[IO](status = config.unAuthorizedStatus))
OptionT.pure[F](Response[F](status = config.unAuthorizedStatus))
case Left(_) =>
// TODO: logging would make debugging auth errors much easier
OptionT.pure(Response[IO](status = config.unAuthorizedStatus))
OptionT.pure[F](Response[F](status = config.unAuthorizedStatus))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks

class ClientSpec extends
WordSpec with Matchers with EitherValues with GeneratorDrivenPropertyChecks with BeforeAndAfterAll with BeforeAndAfterEach {
import ClientSpec._

implicit val cs: ContextShift[IO] = IO.contextShift(global)
"Client" when {

Expand All @@ -31,7 +33,7 @@ class ClientSpec extends
.withHeader(authorizationHeader, equalTo(bearerToken(token)))
.willReturn(aResponse().withBody(resourceBody)))

testWithClient(config) { client =>
testWithClient(config, clientToken) { client =>
val request: Request[IO] = Request(method = Method.GET, uri = resourceUri)

val result = client.expect[String](request).attempt.unsafeRunSync()
Expand All @@ -51,7 +53,7 @@ class ClientSpec extends
.withHeader(authorizationHeader, equalTo(bearerToken(token)))
.willReturn(aResponse().withBody(resourceBody)))

testWithClient(config) { client =>
testWithClient(config, clientToken) { client =>
val request: Request[IO] = Request(method = Method.GET, uri = resourceUri)

val firstResult = client.expect[String](request).attempt.unsafeRunSync()
Expand Down Expand Up @@ -93,7 +95,7 @@ class ClientSpec extends
.withHeader(authorizationHeader, equalTo(bearerToken(token)))
.willReturn(aResponse().withBody(resourceBody)))

testWithClient(config) { client =>
testWithClient(config, clientToken) { client =>
val request: Request[IO] = Request(method = Method.GET, uri = resourceUri)

val result = client.expect[String](request).attempt.unsafeRunSync()
Expand Down Expand Up @@ -130,7 +132,7 @@ class ClientSpec extends
.withHeader(authorizationHeader, equalTo(bearerToken(token)))
.willReturn(aResponse().withBody(resourceBody)))

testWithClient(config) { client =>
testWithClient(config, clientToken) { client =>
val request: Request[IO] = Request(method = Method.GET, uri = resourceUri)

val result = client.expect[String](request).attempt.unsafeRunSync()
Expand All @@ -148,7 +150,7 @@ class ClientSpec extends
.withHeader(authorizationHeader, equalTo(bearerToken(token)))
.willReturn(aResponse().withBody(resourceBody)))

testWithClient(config) { client =>
testWithClient(config, clientToken) { client =>
val request: Request[IO] = Request(method = Method.GET, uri = resourceUri)

val result = client.expect[String](request).attempt.unsafeRunSync()
Expand All @@ -171,7 +173,7 @@ class ClientSpec extends
.withHeader(authorizationHeader, equalTo(bearerToken(token)))
.willReturn(aResponse().withBody(resourceBody)))

testWithClient(config) { client =>
testWithClient(config, clientToken) { client =>
val request: Request[IO] = Request(method = Method.GET, uri = resourceUri)

val result = client.expect[String](request).attempt.unsafeRunSync()
Expand Down Expand Up @@ -199,11 +201,10 @@ class ClientSpec extends
private val resourceBody = "Hello World"

private def defaultConfig() = Config(baseUri, "audience", "client-identity", "client-secret")
private def testWithClient(config: Config = defaultConfig())(test: Http4sClient[IO] => Unit): Unit = {
private def testWithClient(config: Config, clientToken: ClientToken[IO])(test: Http4sClient[IO] => Unit): Unit = {
val testResult = for {
httpClient <- BlazeClientBuilder[IO](global).stream
client = Client(config)(httpClient)
_ = test(client)
_ = Client[IO](config)(httpClient, clientToken).map(test)
} yield ()
testResult.compile.drain.unsafeRunSync()
}
Expand Down Expand Up @@ -246,3 +247,7 @@ class ClientSpec extends
override def afterAll(): Unit = wireMockServer.stop()
}

object ClientSpec {
private val clientToken: ClientToken[IO] = ClientToken[IO].unsafeRunSync
}