Skip to content

Commit

Permalink
KTOR-7194 Deferred session fetching for public endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
bjhham committed Jan 15, 2025
1 parent 01abaa9 commit 38f8ec5
Show file tree
Hide file tree
Showing 13 changed files with 325 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,63 @@ public final class io/ktor/server/jetty/jakarta/JettyApplicationEngineBase$Confi
public final fun setIdleTimeout-LRDsOJo (J)V
}

public final class io/ktor/server/jetty/jakarta/JettyApplicationRequest : io/ktor/server/engine/BaseApplicationRequest {
public fun <init> (Lio/ktor/server/application/PipelineCall;Lorg/eclipse/jetty/server/Request;)V
public fun getCookies ()Lio/ktor/server/request/RequestCookies;
public fun getLocal ()Lio/ktor/http/RequestConnectionPoint;
public fun getQueryParameters ()Lio/ktor/http/Parameters;
public fun getRawQueryParameters ()Lio/ktor/http/Parameters;
}

public final class io/ktor/server/jetty/jakarta/JettyApplicationResponse : io/ktor/server/servlet/jakarta/AsyncServletApplicationResponse {
public fun <init> (Lio/ktor/server/servlet/jakarta/AsyncServletApplicationCall;Ljakarta/servlet/http/HttpServletRequest;Ljakarta/servlet/http/HttpServletResponse;Lkotlin/coroutines/CoroutineContext;Lkotlin/coroutines/CoroutineContext;Lorg/eclipse/jetty/server/Request;Lkotlin/coroutines/CoroutineContext;)V
public fun push (Lio/ktor/server/response/ResponsePushBuilder;)V
}

public final class io/ktor/server/jetty/jakarta/JettyConnectionPoint : io/ktor/http/RequestConnectionPoint {
public fun <init> (Lorg/eclipse/jetty/server/Request;)V
public fun getHost ()Ljava/lang/String;
public fun getLocalAddress ()Ljava/lang/String;
public fun getLocalHost ()Ljava/lang/String;
public fun getLocalPort ()I
public fun getMethod ()Lio/ktor/http/HttpMethod;
public fun getPort ()I
public fun getRemoteAddress ()Ljava/lang/String;
public fun getRemoteHost ()Ljava/lang/String;
public fun getRemotePort ()I
public fun getScheme ()Ljava/lang/String;
public fun getServerHost ()Ljava/lang/String;
public fun getServerPort ()I
public fun getUri ()Ljava/lang/String;
public fun getVersion ()Ljava/lang/String;
}

public final class io/ktor/server/jetty/jakarta/JettyHeaders : io/ktor/http/Headers {
public fun <init> (Lorg/eclipse/jetty/server/Request;)V
public fun contains (Ljava/lang/String;)Z
public fun contains (Ljava/lang/String;Ljava/lang/String;)Z
public fun entries ()Ljava/util/Set;
public fun forEach (Lkotlin/jvm/functions/Function2;)V
public fun get (Ljava/lang/String;)Ljava/lang/String;
public fun getAll (Ljava/lang/String;)Ljava/util/List;
public fun getCaseInsensitiveName ()Z
public fun isEmpty ()Z
public fun names ()Ljava/util/Set;
}

public final class io/ktor/server/jetty/jakarta/JettyRequestCookies : io/ktor/server/request/RequestCookies {
public fun <init> (Lio/ktor/server/jetty/jakarta/JettyApplicationRequest;Lorg/eclipse/jetty/server/Request;)V
}

public final class io/ktor/server/jetty/jakarta/JettyWebsocketConnection : org/eclipse/jetty/io/AbstractConnection, kotlinx/coroutines/CoroutineScope, org/eclipse/jetty/io/Connection$UpgradeTo {
public fun <init> (Lorg/eclipse/jetty/io/EndPoint;Lkotlin/coroutines/CoroutineContext;)V
public fun getCoroutineContext ()Lkotlin/coroutines/CoroutineContext;
public final fun getInputChannel ()Lio/ktor/utils/io/ByteChannel;
public final fun getOutputChannel ()Lio/ktor/utils/io/ByteChannel;
public fun onFillable ()V
public fun onUpgradeTo (Ljava/nio/ByteBuffer;)V
}

public final class io/ktor/server/jetty/jakarta/internal/JettyUpgradeImpl : io/ktor/server/servlet/jakarta/ServletUpgrade {
public static final field INSTANCE Lio/ktor/server/jetty/jakarta/internal/JettyUpgradeImpl;
public fun performUpgrade (Lio/ktor/http/content/OutgoingContent$ProtocolUpgrade;Ljakarta/servlet/http/HttpServletRequest;Ljakarta/servlet/http/HttpServletResponse;Lkotlin/coroutines/CoroutineContext;Lkotlin/coroutines/CoroutineContext;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.sessions.*
import io.ktor.server.testing.*
import kotlinx.serialization.*
import kotlin.test.*
import kotlinx.serialization.Serializable
import kotlin.test.Test
import kotlin.test.assertEquals

class SessionAuthTest {
@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.tests.auth

import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.get
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.http.HttpStatusCode
import io.ktor.server.application.install
import io.ktor.server.auth.Authentication
import io.ktor.server.auth.authenticate
import io.ktor.server.auth.session
import io.ktor.server.response.respondText
import io.ktor.server.routing.get
import io.ktor.server.routing.post
import io.ktor.server.routing.routing
import io.ktor.server.sessions.SessionStorage
import io.ktor.server.sessions.Sessions
import io.ktor.server.sessions.cookie
import io.ktor.server.sessions.defaultSessionSerializer
import io.ktor.server.sessions.serialization.KotlinxSessionSerializer
import io.ktor.server.sessions.sessions
import io.ktor.server.sessions.set
import io.ktor.server.testing.testApplication
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith

class SessionAuthJvmTest {

@Test
fun sessionIgnoredForNonPublicEndpoints() = testApplication {
val brokenStorage = object : SessionStorage {
override suspend fun write(id: String, value: String) = Unit
override suspend fun invalidate(id: String) = error("invalidate called")
override suspend fun read(id: String): String = error("read called")
}
application {
install(Sessions) {
cookie<MySession>("S", storage = brokenStorage) {
serializer = KotlinxSessionSerializer(Json.Default)
}
deferred = true
}
install(Authentication.Companion) {
session<MySession> {
validate { it }
}
}
routing {
authenticate {
get("/authenticated") {
call.respondText("Secret info")
}
}
post("/session") {
call.sessions.set(MySession(1))
call.respondText("OK")
}
get("/public") {
call.respondText("Public info")
}
}
}
val withCookie: HttpRequestBuilder.() -> Unit = {
header("Cookie", "S=${defaultSessionSerializer<MySession>().serialize(MySession(1))}")
}

assertEquals(HttpStatusCode.Companion.OK, client.post("/session").status)
assertEquals(HttpStatusCode.Companion.OK, client.get("/public", withCookie).status)
assertFailsWith<IllegalStateException> {
client.get("/authenticated", withCookie).status
}
}

@Serializable
data class MySession(val id: Int)

}
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,10 @@ public final class io/ktor/server/sessions/SessionsBuilderKt {

public final class io/ktor/server/sessions/SessionsConfig {
public fun <init> ()V
public final fun getDeferred ()Z
public final fun getProviders ()Ljava/util/List;
public final fun register (Lio/ktor/server/sessions/SessionProvider;)V
public final fun setDeferred (Z)V
}

public final class io/ktor/server/sessions/SessionsKt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ final class io.ktor.server.sessions/SessionsConfig { // io.ktor.server.sessions/
final val providers // io.ktor.server.sessions/SessionsConfig.providers|{}providers[0]
final fun <get-providers>(): kotlin.collections/List<io.ktor.server.sessions/SessionProvider<*>> // io.ktor.server.sessions/SessionsConfig.providers.<get-providers>|<get-providers>(){}[0]

final var deferred // io.ktor.server.sessions/SessionsConfig.deferred|{}deferred[0]
final fun <get-deferred>(): kotlin/Boolean // io.ktor.server.sessions/SessionsConfig.deferred.<get-deferred>|<get-deferred>(){}[0]
final fun <set-deferred>(kotlin/Boolean) // io.ktor.server.sessions/SessionsConfig.deferred.<set-deferred>|<set-deferred>(kotlin.Boolean){}[0]

final fun register(io.ktor.server.sessions/SessionProvider<*>) // io.ktor.server.sessions/SessionsConfig.register|register(io.ktor.server.sessions.SessionProvider<*>){}[0]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ public interface CurrentSession {
public fun findName(type: KClass<*>): String
}

/**
* Extends [CurrentSession] with a call to include session data in the server response.
*/
internal interface StatefulSession : CurrentSession {

/**
* Iterates over session data items and writes them to the application call.
* The session cannot be modified after this is called.
* This is called after the session data is sent to the response.
*/
suspend fun sendSessionData(call: ApplicationCall, onEach: (String) -> Unit = {})
}

/**
* Sets a session instance with the type [T].
* @throws IllegalStateException if no session provider is registered for the type [T]
Expand Down Expand Up @@ -99,11 +112,15 @@ public inline fun <reified T : Any> CurrentSession.getOrSet(name: String = findN

internal data class SessionData(
val providerData: Map<String, SessionProviderData<*>>
) : CurrentSession {
) : StatefulSession {

private var committed = false

internal fun commit() {
override suspend fun sendSessionData(call: ApplicationCall, onEach: (String) -> Unit) {
providerData.values.forEach { data ->
onEach(data.provider.name)
data.sendSessionData(call)
}
committed = true
}

Expand Down Expand Up @@ -175,7 +192,7 @@ internal data class SessionProviderData<S : Any>(
val provider: SessionProvider<S>
)

internal val SessionDataKey = AttributeKey<SessionData>("SessionKey")
internal val SessionDataKey = AttributeKey<StatefulSession>("SessionKey")

private fun ApplicationCall.reportMissingSession(): Nothing {
application.plugin(Sessions) // ensure the plugin is installed
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions

import io.ktor.server.application.ApplicationCall

/**
* Creates a lazy loading session from the given providers.
*/
internal expect fun createDeferredSession(call: ApplicationCall, providers: List<SessionProvider<*>>): StatefulSession
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package io.ktor.server.sessions

import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.util.*
import io.ktor.util.logging.*

Expand All @@ -27,24 +26,23 @@ internal val LOGGER = KtorSimpleLogger("io.ktor.server.sessions.Sessions")
*/
public val Sessions: RouteScopedPlugin<SessionsConfig> = createRouteScopedPlugin("Sessions", ::SessionsConfig) {
val providers = pluginConfig.providers.toList()
val sessionSupplier: suspend (ApplicationCall, List<SessionProvider<*>>) -> StatefulSession =
if (pluginConfig.deferred) {
::createDeferredSession
} else {
::createSession
}

application.attributes.put(SessionProvidersKey, providers)

onCall { call ->
// For each call, call each provider and retrieve session data if needed.
// Capture data in the attribute's value
val providerData = providers.associateBy({ it.name }) {
it.receiveSessionData(call)
}

if (providerData.isEmpty()) {
if (providers.isEmpty()) {
LOGGER.trace("No sessions found for ${call.request.uri}")
} else {
val sessions = providerData.keys.joinToString()
val sessions = providers.joinToString { it.name }
LOGGER.trace("Sessions found for ${call.request.uri}: $sessions")
}
val sessionData = SessionData(providerData)
call.attributes.put(SessionDataKey, sessionData)
call.attributes.put(SessionDataKey, sessionSupplier(call, providers))
}

// When response is being sent, call each provider to update/remove session data
Expand All @@ -58,11 +56,18 @@ public val Sessions: RouteScopedPlugin<SessionsConfig> = createRouteScopedPlugin
*/
val sessionData = call.attributes.getOrNull(SessionDataKey) ?: return@on

sessionData.providerData.values.forEach { data ->
LOGGER.trace("Sending session data for ${call.request.uri}: ${data.provider.name}")
data.sendSessionData(call)
sessionData.sendSessionData(call) { provider ->
LOGGER.trace("Sending session data for ${call.request.uri}: $provider")
}
}
}

sessionData.commit()
private suspend fun createSession(call: ApplicationCall, providers: List<SessionProvider<*>>): StatefulSession {
// For each call, call each provider and retrieve session data if needed.
// Capture data in the attribute's value
val providerData = providers.associateBy({ it.name }) {
it.receiveSessionData(call)
}

return SessionData(providerData)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ public class SessionsConfig {
*/
public val providers: List<SessionProvider<*>> get() = registered.toList()

/**
* When set to true, sessions will be lazily retrieved from storage.
*
* Note: this is only available for JVM in Ktor 3.0
*/
public var deferred: Boolean = false

/**
* Registers a session [provider].
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions

import io.ktor.server.application.ApplicationCall

internal actual fun createDeferredSession(call: ApplicationCall, providers: List<SessionProvider<*>>): StatefulSession =
TODO("Deferred session retrieval is currently only available for JVM")
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
package io.ktor.server.sessions

import io.ktor.http.*
import io.ktor.server.sessions.serialization.*
import io.ktor.util.*
import kotlinx.serialization.*
import kotlinx.serialization.json.*
import java.lang.reflect.*
import java.math.*
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.math.BigDecimal
import java.math.BigInteger
import java.util.*
import java.util.concurrent.*
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.*
import kotlin.reflect.full.*
import kotlin.reflect.jvm.*
import kotlin.reflect.full.memberProperties
import kotlin.reflect.full.superclasses
import kotlin.reflect.jvm.javaType
import kotlin.reflect.jvm.jvmErasure

private const val TYPE_TOKEN_PARAMETER_NAME: String = "\$type"

Expand Down
Loading

0 comments on commit 38f8ec5

Please sign in to comment.