Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KTOR-7194 Deferred session fetching for public endpoints #4609

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ import io.ktor.server.response.*
import io.ktor.server.routing.*
import io.ktor.server.sessions.*
import io.ktor.server.testing.*
import kotlinx.serialization.*
import kotlin.test.*
import kotlinx.serialization.Serializable
import kotlin.test.Test
import kotlin.test.assertEquals

class SessionAuthTest {
@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.tests.auth

import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.get
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.http.HttpStatusCode
import io.ktor.server.application.install
import io.ktor.server.auth.Authentication
import io.ktor.server.auth.authenticate
import io.ktor.server.auth.session
import io.ktor.server.response.respondText
import io.ktor.server.routing.get
import io.ktor.server.routing.post
import io.ktor.server.routing.routing
import io.ktor.server.sessions.SessionStorage
import io.ktor.server.sessions.Sessions
import io.ktor.server.sessions.cookie
import io.ktor.server.sessions.defaultSessionSerializer
import io.ktor.server.sessions.serialization.KotlinxSessionSerializer
import io.ktor.server.sessions.sessions
import io.ktor.server.sessions.set
import io.ktor.server.testing.testApplication
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith

class SessionAuthJvmTest {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't this test be moved to the jvmAndPosix source-set?


@Test
fun sessionIgnoredForNonPublicEndpoints() = testApplication {
val brokenStorage = object : SessionStorage {
override suspend fun write(id: String, value: String) = Unit
override suspend fun invalidate(id: String) = error("invalidate called")
override suspend fun read(id: String): String = error("read called")
}
application {
install(Sessions) {
cookie<MySession>("S", storage = brokenStorage) {
serializer = KotlinxSessionSerializer(Json.Default)
}
deferred = true
}
install(Authentication.Companion) {
session<MySession> {
validate { it }
}
}
routing {
authenticate {
get("/authenticated") {
call.respondText("Secret info")
}
}
post("/session") {
call.sessions.set(MySession(1))
call.respondText("OK")
}
get("/public") {
call.respondText("Public info")
}
}
}
val withCookie: HttpRequestBuilder.() -> Unit = {
header("Cookie", "S=${defaultSessionSerializer<MySession>().serialize(MySession(1))}")
}

assertEquals(HttpStatusCode.Companion.OK, client.post("/session").status)
assertEquals(HttpStatusCode.Companion.OK, client.get("/public", withCookie).status)
assertFailsWith<IllegalStateException> {
client.get("/authenticated", withCookie).status
}
Comment on lines +76 to +78
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also add a successful call after the failed one? Maybe move one of the calls made above

}

@Serializable
data class MySession(val id: Int)
}
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,10 @@ public final class io/ktor/server/sessions/SessionsBuilderKt {

public final class io/ktor/server/sessions/SessionsConfig {
public fun <init> ()V
public final fun getDeferred ()Z
public final fun getProviders ()Ljava/util/List;
public final fun register (Lio/ktor/server/sessions/SessionProvider;)V
public final fun setDeferred (Z)V
}

public final class io/ktor/server/sessions/SessionsKt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ final class io.ktor.server.sessions/SessionsConfig { // io.ktor.server.sessions/
final val providers // io.ktor.server.sessions/SessionsConfig.providers|{}providers[0]
final fun <get-providers>(): kotlin.collections/List<io.ktor.server.sessions/SessionProvider<*>> // io.ktor.server.sessions/SessionsConfig.providers.<get-providers>|<get-providers>(){}[0]

final var deferred // io.ktor.server.sessions/SessionsConfig.deferred|{}deferred[0]
final fun <get-deferred>(): kotlin/Boolean // io.ktor.server.sessions/SessionsConfig.deferred.<get-deferred>|<get-deferred>(){}[0]
final fun <set-deferred>(kotlin/Boolean) // io.ktor.server.sessions/SessionsConfig.deferred.<set-deferred>|<set-deferred>(kotlin.Boolean){}[0]

final fun register(io.ktor.server.sessions/SessionProvider<*>) // io.ktor.server.sessions/SessionsConfig.register|register(io.ktor.server.sessions.SessionProvider<*>){}[0]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ public interface CurrentSession {
public fun findName(type: KClass<*>): String
}

/**
* Extends [CurrentSession] with a call to include session data in the server response.
*/
internal interface StatefulSession : CurrentSession {

/**
* Iterates over session data items and writes them to the application call.
* The session cannot be modified after this is called.
* This is called after the session data is sent to the response.
*/
suspend fun sendSessionData(call: ApplicationCall, onEach: (String) -> Unit = {})
}

/**
* Sets a session instance with the type [T].
* @throws IllegalStateException if no session provider is registered for the type [T]
Expand Down Expand Up @@ -99,11 +112,15 @@ public inline fun <reified T : Any> CurrentSession.getOrSet(name: String = findN

internal data class SessionData(
val providerData: Map<String, SessionProviderData<*>>
) : CurrentSession {
) : StatefulSession {

private var committed = false

internal fun commit() {
override suspend fun sendSessionData(call: ApplicationCall, onEach: (String) -> Unit) {
providerData.values.forEach { data ->
onEach(data.provider.name)
data.sendSessionData(call)
}
committed = true
}

Expand Down Expand Up @@ -175,7 +192,7 @@ internal data class SessionProviderData<S : Any>(
val provider: SessionProvider<S>
)

internal val SessionDataKey = AttributeKey<SessionData>("SessionKey")
internal val SessionDataKey = AttributeKey<StatefulSession>("SessionKey")

private fun ApplicationCall.reportMissingSession(): Nothing {
application.plugin(Sessions) // ensure the plugin is installed
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions

import io.ktor.server.application.ApplicationCall

/**
* Creates a lazy loading session from the given providers.
*/
internal expect fun createDeferredSession(call: ApplicationCall, providers: List<SessionProvider<*>>): StatefulSession
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package io.ktor.server.sessions

import io.ktor.server.application.*
import io.ktor.server.request.*
import io.ktor.server.response.*
import io.ktor.util.*
import io.ktor.util.logging.*

Expand All @@ -27,24 +26,25 @@ internal val LOGGER = KtorSimpleLogger("io.ktor.server.sessions.Sessions")
*/
public val Sessions: RouteScopedPlugin<SessionsConfig> = createRouteScopedPlugin("Sessions", ::SessionsConfig) {
val providers = pluginConfig.providers.toList()
val sessionSupplier: suspend (ApplicationCall, List<SessionProvider<*>>) -> StatefulSession =
if (pluginConfig.deferred) {
::createDeferredSession
} else {
::createSession
}

application.attributes.put(SessionProvidersKey, providers)

onCall { call ->
// For each call, call each provider and retrieve session data if needed.
// Capture data in the attribute's value
val providerData = providers.associateBy({ it.name }) {
it.receiveSessionData(call)
}

if (providerData.isEmpty()) {
LOGGER.trace("No sessions found for ${call.request.uri}")
if (providers.isEmpty()) {
LOGGER.trace { "No sessions found for ${call.request.uri}" }
} else {
val sessions = providerData.keys.joinToString()
LOGGER.trace("Sessions found for ${call.request.uri}: $sessions")
LOGGER.trace {
val sessions = providers.joinToString { it.name }
"Sessions found for ${call.request.uri}: $sessions"
}
}
val sessionData = SessionData(providerData)
call.attributes.put(SessionDataKey, sessionData)
call.attributes.put(SessionDataKey, sessionSupplier(call, providers))
}

// When response is being sent, call each provider to update/remove session data
Expand All @@ -58,11 +58,18 @@ public val Sessions: RouteScopedPlugin<SessionsConfig> = createRouteScopedPlugin
*/
val sessionData = call.attributes.getOrNull(SessionDataKey) ?: return@on

sessionData.providerData.values.forEach { data ->
LOGGER.trace("Sending session data for ${call.request.uri}: ${data.provider.name}")
data.sendSessionData(call)
sessionData.sendSessionData(call) { provider ->
LOGGER.trace { "Sending session data for ${call.request.uri}: $provider" }
}
}
}

sessionData.commit()
private suspend fun createSession(call: ApplicationCall, providers: List<SessionProvider<*>>): StatefulSession {
// For each call, call each provider and retrieve session data if needed.
// Capture data in the attribute's value
val providerData = providers.associateBy({ it.name }) {
it.receiveSessionData(call)
}

return SessionData(providerData)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ public class SessionsConfig {
*/
public val providers: List<SessionProvider<*>> get() = registered.toList()

/**
* When set to true, sessions will be lazily retrieved from storage.
*
* Note: this is only available for JVM in Ktor 3.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For JVM and Native?

*/
public var deferred: Boolean = false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would mark it with OptIn and look for a better explaining name

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should have a system property (like io.ktor.server.sessions.lazycreate) for this instead

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, this way we don't need to change the API 👍


/**
* Registers a session [provider].
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions

import io.ktor.server.application.ApplicationCall

internal actual fun createDeferredSession(call: ApplicationCall, providers: List<SessionProvider<*>>): StatefulSession =
TODO("Deferred session retrieval is currently only available for JVM")
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
package io.ktor.server.sessions

import io.ktor.http.*
import io.ktor.server.sessions.serialization.*
import io.ktor.util.*
import kotlinx.serialization.*
import kotlinx.serialization.json.*
import java.lang.reflect.*
import java.math.*
import java.lang.reflect.ParameterizedType
import java.lang.reflect.Type
import java.math.BigDecimal
import java.math.BigInteger
import java.util.*
import java.util.concurrent.*
import java.util.concurrent.ConcurrentHashMap
import kotlin.reflect.*
import kotlin.reflect.full.*
import kotlin.reflect.jvm.*
import kotlin.reflect.full.memberProperties
import kotlin.reflect.full.superclasses
import kotlin.reflect.jvm.javaType
import kotlin.reflect.jvm.jvmErasure

private const val TYPE_TOKEN_PARAMETER_NAME: String = "\$type"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright 2014-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package io.ktor.server.sessions

import io.ktor.server.application.ApplicationCall
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.runBlocking
import kotlin.coroutines.CoroutineContext
import kotlin.reflect.KClass

/**
* An implementation of [StatefulSession] that lazily references session providers to
* avoid unnecessary calls to session storage.
* All access to the deferred providers is done through blocking calls.
*/
internal class BlockingDeferredSessionData(
val callContext: CoroutineContext,
val providerData: Map<String, Deferred<SessionProviderData<*>>>,
) : StatefulSession {

private var committed = false

@OptIn(ExperimentalCoroutinesApi::class)
override suspend fun sendSessionData(call: ApplicationCall, onEach: (String) -> Unit) {
for (deferredProvider in providerData.values) {
// skip non-completed providers because they were not modified
if (!deferredProvider.isCompleted) continue
val data = deferredProvider.getCompleted()
onEach(data.provider.name)
data.sendSessionData(call)
}
committed = true
}

override fun findName(type: KClass<*>): String {
val entry = providerData.values.map {
it.awaitBlocking()
}.firstOrNull {
it.provider.type == type
} ?: throw IllegalArgumentException("Session data for type `$type` was not registered")

return entry.provider.name
}

override fun set(name: String, value: Any?) {
if (committed) {
throw TooLateSessionSetException()
}
val providerData = checkNotNull(providerData[name]) { "Session data for `$name` was not registered" }
setTyped(providerData.awaitBlocking(), value)
}

@Suppress("UNCHECKED_CAST")
private fun <S : Any> setTyped(data: SessionProviderData<S>, value: Any?) {
if (value != null) {
data.provider.tracker.validate(value as S)
}
data.newValue = value as S
}

override fun get(name: String): Any? {
val providerDataDeferred =
providerData[name] ?: throw IllegalStateException("Session data for `$name` was not registered")
val providerData = providerDataDeferred.awaitBlocking()
return providerData.newValue ?: providerData.oldValue
}

override fun clear(name: String) {
val providerDataDeferred =
providerData[name] ?: throw IllegalStateException("Session data for `$name` was not registered")
val providerData = providerDataDeferred.awaitBlocking()
providerData.oldValue = null
providerData.newValue = null
}

private fun Deferred<SessionProviderData<*>>.awaitBlocking() =
runBlocking(callContext) { await() }
}
Loading
Loading