Skip to content

Commit

Permalink
NRL Operation must be Recoverable
Browse files Browse the repository at this point in the history
  • Loading branch information
zuevmaxim committed Sep 22, 2021
1 parent 7a1298a commit 2d68cbf
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ import org.objectweb.asm.commons.GeneratorAdapter
import org.objectweb.asm.commons.Method
import kotlin.reflect.jvm.javaMethod

internal open class CrashEnabledVisitor(cv: ClassVisitor, testClass: Class<*>, initial: Boolean = true) :
internal open class CrashEnabledVisitor(cv: ClassVisitor, initial: Boolean = true) :
ClassVisitor(ASM_API, cv) {
private val superClassNames = testClass.superClassNames()
var shouldTransform = initial
private set
var name: String? = null
Expand All @@ -48,12 +47,6 @@ internal open class CrashEnabledVisitor(cv: ClassVisitor, testClass: Class<*>, i
) {
super.visit(version, access, name, signature, superName, interfaces)
this.name = name
if (name in superClassNames || name !== null &&
name.startsWith("org.jetbrains.kotlinx.lincheck.") &&
!name.startsWith("org.jetbrains.kotlinx.lincheck.test.")
) {
shouldTransform = false
}
}

override fun visitAnnotation(descriptor: String?, visible: Boolean): AnnotationVisitor {
Expand All @@ -69,10 +62,7 @@ internal open class CrashEnabledVisitor(cv: ClassVisitor, testClass: Class<*>, i
}
}

internal class CrashTransformer(
cv: ClassVisitor,
testClass: Class<*>
) : CrashEnabledVisitor(cv, testClass) {
internal class CrashTransformer(cv: ClassVisitor) : CrashEnabledVisitor(cv) {
override fun visitMethod(
access: Int,
name: String?,
Expand Down Expand Up @@ -193,13 +183,3 @@ internal class CrashRethrowTransformer(cv: ClassVisitor) : ClassVisitor(ASM_API,
}
}
}

private fun Class<*>.superClassNames(): List<String> {
val result = mutableListOf<String>()
var clazz: Class<*>? = this
while (clazz !== null) {
result.add(Type.getInternalName(clazz))
clazz = clazz.superclass
}
return result
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

package org.jetbrains.kotlinx.lincheck.nvm

import org.jetbrains.kotlinx.lincheck.annotations.Operation
import org.jetbrains.kotlinx.lincheck.annotations.Recoverable
import org.jetbrains.kotlinx.lincheck.execution.ExecutionScenario
import org.jetbrains.kotlinx.lincheck.verifier.Verifier
import org.jetbrains.kotlinx.lincheck.verifier.linearizability.LinearizabilityVerifier
Expand Down Expand Up @@ -76,8 +78,8 @@ private object RecoverExecutionCallback : ExecutionCallback {
internal enum class StrategyRecoveryOptions {
STRESS, MANAGED;

fun createCrashTransformer(cv: ClassVisitor, clazz: Class<*>): ClassVisitor = when (this) {
STRESS -> CrashRethrowTransformer(CrashTransformer(cv, clazz))
fun createCrashTransformer(cv: ClassVisitor): ClassVisitor = when (this) {
STRESS -> CrashRethrowTransformer(CrashTransformer(cv))
MANAGED -> CrashRethrowTransformer(cv) // add crashes in ManagedStrategyTransformer
}
}
Expand Down Expand Up @@ -111,6 +113,7 @@ interface RecoverabilityModel {
fun defaultExpectedCrashes(): Int
fun createExecutionCallback(): ExecutionCallback
fun createProbabilityModel(): ProbabilityModel
fun checkTestClass(testClass: Class<*>) {}
val awaitSystemCrashBeforeThrow: Boolean
val verifierClass: Class<out Verifier>

Expand Down Expand Up @@ -151,10 +154,24 @@ private class NRLModel(
override fun createTransformer(cv: ClassVisitor, clazz: Class<*>): ClassVisitor {
var result: ClassVisitor = RecoverabilityTransformer(cv)
if (crashes) {
result = strategyRecoveryOptions.createCrashTransformer(result, clazz)
result = strategyRecoveryOptions.createCrashTransformer(result)
}
return result
}

override fun checkTestClass(testClass: Class<*>) {
var clazz: Class<*>? = testClass
while (clazz !== null) {
clazz.declaredMethods.forEach { method ->
val isOperation = method.isAnnotationPresent(Operation::class.java)
val isRecoverable = method.isAnnotationPresent(Recoverable::class.java)
require(!isOperation || isRecoverable) {
"Every operation must have a Recovery annotation, but ${method.name} operation in ${clazz!!.name} class is not Recoverable."
}
}
clazz = clazz.superclass
}
}
}

private open class DurableModel(val strategyRecoveryOptions: StrategyRecoveryOptions) : RecoverabilityModel {
Expand All @@ -171,7 +188,7 @@ private open class DurableModel(val strategyRecoveryOptions: StrategyRecoveryOpt
override val verifierClass: Class<out Verifier> get() = DurableLinearizabilityVerifier::class.java
override fun createTransformerWrapper(cv: ClassVisitor, clazz: Class<*>) = DurableRecoverAllGenerator(cv, clazz)
override fun createTransformer(cv: ClassVisitor, clazz: Class<*>): ClassVisitor =
strategyRecoveryOptions.createCrashTransformer(DurableOperationRecoverTransformer(cv, clazz), clazz)
strategyRecoveryOptions.createCrashTransformer(DurableOperationRecoverTransformer(cv, clazz))
}

private class DetectableExecutionModel(strategyRecoveryOptions: StrategyRecoveryOptions) :
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ internal class SwitchesAndCrashesModelCheckingStrategy(
override fun createRoot(): InterleavingTreeNode = ThreadChoosingNodeWithCrashes((0 until nThreads).toList())

override fun createTransformer(cv: ClassVisitor): ClassVisitor {
val visitor = CrashEnabledVisitor(cv, testClass, recoverModel.crashes)
val visitor = CrashEnabledVisitor(cv, recoverModel.crashes)
val recoverTransformer = recoverModel.createTransformer(visitor, testClass)
val managedTransformer = CrashesManagedStrategyTransformer(
recoverTransformer, tracePointConstructors, testCfg.guarantees, testCfg.eliminateLocalObjects,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ internal open class ParallelThreadsRunner(
override fun initialize() {
executionCallback.reset(scenario, recoverModel)
super.initialize()
recoverModel.checkTestClass(testClass)
testThreadExecutions = Array(scenario.threads) { t ->
TestThreadExecutionGenerator.create(this, t, scenario.parallelExecution[t], completions[t], scenario.hasSuspendableActors(), recoverModel.createActorCrashHandlerGenerator())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelChecki
import org.jetbrains.kotlinx.lincheck.strategy.stress.StressOptions
import org.jetbrains.kotlinx.lincheck.test.checkTraceHasNoLincheckEvents
import org.junit.Test
import java.lang.IllegalStateException
import java.lang.reflect.InvocationTargetException
import kotlin.reflect.KClass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ private const val THREADS_NUMBER = 3
internal interface Counter {
fun increment(threadId: Int)
fun get(threadId: Int): Int
fun incrementBefore(p: Int) {}
fun incrementRecover(p: Int) {}
}

/**
Expand All @@ -43,9 +45,13 @@ internal class CounterTest : AbstractNVMLincheckTest(Recover.NRL, THREADS_NUMBER
private val counter = NRLCounter(THREADS_NUMBER + 2)

@Operation
@Recoverable(beforeMethod = "incrementBefore", recoverMethod = "incrementRecover")
override fun increment(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.increment(threadId)
override fun incrementBefore(p: Int) = counter.incrementBefore(p)
override fun incrementRecover(p: Int) = counter.incrementRecover(p)

@Operation
@Recoverable
override fun get(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.get(threadId)
}

Expand All @@ -65,22 +71,19 @@ internal open class NRLCounter(threadsCount: Int) : Counter {
protected val checkPointer = MutableList(threadsCount) { nonVolatile(0) }
protected val currentValue = MutableList(threadsCount) { nonVolatile(0) }

@Recoverable
override fun get(threadId: Int) = r.sumBy { it.read()!! }

@Recoverable(beforeMethod = "incrementBefore", recoverMethod = "incrementRecover")
override fun increment(threadId: Int) = incrementImpl(threadId)

protected open fun incrementImpl(p: Int) {
r[p].write(1 + currentValue[p].value, p)
checkPointer[p].value = 1
}

protected open fun incrementRecover(p: Int) {
override fun incrementRecover(p: Int) {
if (checkPointer[p].value == 0) return incrementImpl(p)
}

protected open fun incrementBefore(p: Int) {
override fun incrementBefore(p: Int) {
currentValue[p].value = r[p].read()!!
checkPointer[p].value = 0
currentValue[p].flush()
Expand All @@ -92,10 +95,14 @@ internal abstract class CounterFailingTest :
AbstractNVMLincheckFailingTest(Recover.NRL, THREADS_NUMBER, SequentialCounter::class) {
protected abstract val counter: Counter

@Recoverable(beforeMethod = "incrementBefore", recoverMethod = "incrementRecover")
@Operation
fun increment(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.increment(threadId)
fun incrementBefore(p: Int) = counter.incrementBefore(p)
fun incrementRecover(p: Int) = counter.incrementRecover(p)

@Operation
@Recoverable
fun get(@Param(gen = ThreadIdGen::class) threadId: Int) = counter.get(threadId)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,26 @@ private const val THREADS_NUMBER = 3
interface RWO<T> {
fun read(): T?
fun write(value: T, p: Int)
fun writeRecover(value: T, p: Int) {}
}

internal class ReadWriteObjectTest :
AbstractNVMLincheckTest(Recover.NRL, THREADS_NUMBER, SequentialReadWriteObject::class) {
private val rwo = NRLReadWriteObject<Pair<Int, Int>>(THREADS_NUMBER + 2)

@Recoverable
@Operation
fun read() = rwo.read()?.first

@Recoverable(recoverMethod = "writeRecover")
@Operation
fun write(
@Param(gen = ThreadIdGen::class) threadId: Int,
value: Int,
@Param(gen = OperationIdGen::class) operationId: Int
) = rwo.write(value to operationId, threadId)

fun writeRecover(threadId: Int, value: Int, operationId: Int) = rwo.writeRecover(value to operationId, threadId)
}

private val nullObject = Any()
Expand Down Expand Up @@ -85,10 +90,7 @@ internal open class NRLReadWriteObject<T>(threadsCount: Int, initial: T? = null)
// (state, value) for every thread
protected val state = MutableList(threadsCount) { nonVolatile(0 to null as T?) }

@Recoverable
override fun read(): T? = register.value

@Recoverable(recoverMethod = "writeRecover")
override fun write(value: T, p: Int) = writeImpl(value, p)

protected open fun writeImpl(value: T, p: Int) {
Expand All @@ -100,7 +102,7 @@ internal open class NRLReadWriteObject<T>(threadsCount: Int, initial: T? = null)
state[p].flush()
}

protected open fun writeRecover(value: T, p: Int) {
override fun writeRecover(value: T, p: Int) {
val (flag, current) = state[p].value
if (flag == 0 && current != value) return writeImpl(value, p)
else if (flag == 1 && current === register.value) return writeImpl(value, p)
Expand All @@ -114,14 +116,18 @@ internal abstract class ReadWriteObjectFailingTest :
protected abstract val rwo: RWO<Pair<Int, Int>>

@Operation
@Recoverable
fun read() = rwo.read()?.first

@Operation
@Recoverable(recoverMethod = "writeRecover")
fun write(
@Param(gen = ThreadIdGen::class) threadId: Int,
value: Int,
@Param(gen = OperationIdGen::class) operationId: Int
) = rwo.write(value to operationId, threadId)

fun writeRecover(threadId: Int, value: Int, operationId: Int) = rwo.writeRecover(value to operationId, threadId)
}

internal class SmallScenarioTest : ReadWriteObjectFailingTest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ private const val THREADS = 3
internal class RecoverableMutualExclusionWithPrimitivesTest : AbstractNVMLincheckTest(Recover.NRL, THREADS, SequentialCounter::class) {
private val counter = CounterWithLock(THREADS + 2, LockWithPrimitives(THREADS + 2))

@Recoverable(recoverMethod = "incRecover", beforeMethod = "incBefore")
@Operation
fun inc(@Param(gen = ThreadIdGen::class) threadId: Int): Int = counter.inc(threadId)
fun incRecover(threadId: Int) = counter.incRecover(threadId)
fun incBefore(threadId: Int) = counter.incBefore(threadId)

override fun testWithStressStrategy() {
println("${this::class.qualifiedName}:testWithStressStrategy test is ignored as no special atomic primitives available.")
Expand All @@ -61,7 +64,6 @@ internal class CounterWithLock(threads: Int, private val lock: DurableLock) {
private val cp = Array(threads) { nonVolatile(0) }
private val before = Array(threads) { nonVolatile(0) }

@Recoverable(recoverMethod = "incRecover", beforeMethod = "incBefore")
fun inc(threadId: Int) = incInternal(threadId)

private fun incInternal(threadId: Int): Int {
Expand Down Expand Up @@ -193,8 +195,11 @@ internal class LockWithPrimitives(threads: Int) : DurableLock {
internal class MutualExclusionFailingTest : AbstractNVMLincheckFailingTest(Recover.NRL, THREADS, SequentialCounter::class, false, DeadlockWithDumpFailure::class) {
private val counter = CounterWithLock(THREADS + 2, SimplestLockEver())

@Recoverable(recoverMethod = "incRecover", beforeMethod = "incBefore")
@Operation
fun inc(@Param(gen = ThreadIdGen::class) threadId: Int): Int = counter.inc(threadId)
fun incRecover(threadId: Int) = counter.incRecover(threadId)
fun incBefore(threadId: Int) = counter.incBefore(threadId)

override fun <O : Options<O, *>> O.customize() {
actorsBefore(0)
Expand Down
Loading

0 comments on commit 2d68cbf

Please sign in to comment.