Skip to content

Commit

Permalink
Collect DerivationInfos marked with @derivation during registry initi…
Browse files Browse the repository at this point in the history
…alization

This is done using the classgraph library and some plain reflection. No
scala reflection APIs are used (see javadoc of
DerivationInfoRegistry.findAndRegisterDerivations).

This is the first step towards replacing all existing reflection-based
derivation info registry initialization code, as well as the dependency
on the no-longer-maintained reflections library.
  • Loading branch information
Joscha Mennicken authored and EnguerrandPrebet committed Sep 12, 2024
1 parent 470b9bf commit 41b81fa
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 19 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ lazy val core = project
libraryDependencies += "cc.redberry" %% "rings.scaladsl" % "2.5.8",
libraryDependencies += "com.github.scopt" %% "scopt" % "4.1.0",
libraryDependencies += "com.lihaoyi" %% "fastparse" % "3.1.0",
libraryDependencies += "io.github.classgraph" % "classgraph" % "4.8.174",
libraryDependencies += "io.spray" %% "spray-json" % "1.3.6",
libraryDependencies += "org.apache.commons" % "commons-configuration2" % "2.10.1",
libraryDependencies += "org.apache.commons" % "commons-lang3" % "3.14.0",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright (c) Carnegie Mellon University, Karlsruhe Institute of Technology.
* See LICENSE.txt for the conditions of this license.
*/

package org.keymaerax.core.btactics.annotations;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface Derivation {
}
Original file line number Diff line number Diff line change
Expand Up @@ -1564,9 +1564,8 @@ object Ax extends Logging {
Sequent(immutable.IndexedSeq("[a_;]p_(||)".asFormula), immutable.IndexedSeq("[a_;]q_(||)".asFormula)),
useAt(box, PosInExpr(1 :: Nil))(-1) & useAt(box, PosInExpr(1 :: Nil))(1) &
notL(-1) & notR(1) &
// @todo use [[DerivedAxioms.mondrule]]
by(
ProvableInfo("<> monotone"),
Ax.mondrule,
USubst(
SubstitutionPair(UnitPredicational("p_", AnyArg), Not(UnitPredicational("q_", AnyArg))) ::
SubstitutionPair(UnitPredicational("q_", AnyArg), Not(UnitPredicational("p_", AnyArg))) :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@

package org.keymaerax.btactics

import io.github.classgraph.ClassGraph
import org.keymaerax.Logging
import org.keymaerax.bellerophon._
import org.keymaerax.btactics.macros._
import org.keymaerax.core._
import org.keymaerax.core.btactics.annotations.Derivation
import org.keymaerax.infrastruct._
import org.reflections.Reflections
import org.reflections.scanners.Scanners

import java.lang.reflect.{Field, InvocationTargetException}
import scala.annotation.{nowarn, tailrec}
import scala.collection.mutable
import scala.jdk.CollectionConverters.CollectionHasAsScala
Expand Down Expand Up @@ -182,6 +185,65 @@ object DerivationInfoRegistry extends Logging {
)
}

/** Evaluate the field of a Scala object and register it with the global [[DerivationInfo]] object. */
private def registerDerivationFromField(clazz: Class[_], instance: AnyRef, field: Field): Unit = {
// Val fields are private but have public getter functions of the same name.
val getter = clazz.getMethod(field.getName)

// Exceptions that occur during reflection are wrapped in an InvocationTargetException.
// We unwrap and rethrow so from the outside it looks like we just evaluated the derivation normally.
val valueAny =
try { getter.invoke(instance) }
catch { case e: InvocationTargetException => throw e.getCause }

val value = valueAny.asInstanceOf[DerivationInfo]
DerivationInfo.register(value)
}

/**
* Find and register all fields of a Scala object which are marked with the [[Derivation]] annotation.
*
* @return
* A list of fields that failed to initialize and register, along with the cause.
*/
private def registerDerivationsFromObject(clazz: Class[_]): Seq[(Field, Throwable)] = {
// An object's instance can be located through its public static final MODULE$ field.
// TODO Fail gracefully if this is not a Scala object
val instance = clazz.getField("MODULE$").get(null)

clazz
.getDeclaredFields
.toSeq
.filter(_.isAnnotationPresent(classOf[Derivation]))
.flatMap(field =>
try {
registerDerivationFromField(clazz, instance, field)
None
} catch { case t: Throwable => Some(field -> t) }
)
}

/**
* Find and register all fields marked with the [[Derivation]] annotation globally. The fields must be Scala object
* fields, and they must be of type [[DerivationInfo]].
*
* This function is implemented without using the Scala 2 reflection API. Hopefully this will make it easier to
* upgrade to Scala 3 later.
*
* @return
* A list of fields that failed to initialize and register, along with the cause.
*/
private def registerDerivationsGlobally(): Seq[(Field, Throwable)] = {
import scala.jdk.CollectionConverters._
new ClassGraph()
.enableAllInfo()
.scan()
.getClassesWithFieldAnnotation(classOf[Derivation])
.asScala
.toSeq
.flatMap(info => registerDerivationsFromObject(info.loadClass()))
}

/** Has the global DerivationInfo list been initialized? */
def isInit: Boolean = DerivationInfo._allInfo.nonEmpty

Expand All @@ -202,6 +264,17 @@ object DerivationInfoRegistry extends Logging {
// Remember that initialization is in progress,
DerivationInfo._initStatus = DerivationInfo.InitInProgress
if (!initLibrary) return // Advanced use - user takes over in-progress initialization

val derivationErrors = registerDerivationsGlobally()
if (derivationErrors.nonEmpty) {
println()
println("Failed to initialize derivations:")
derivationErrors.foreach { case (field, _) =>
println(s"- (in ${field.getDeclaringClass.getName}) ${field.getName}")
}
println()
}

// Initialize derived axioms and rules, which automatically initializes their AxiomInfo and RuleInfo too
// To allow working with restricted functionality: continue initialization despite potential errors in
// deriving axioms, throw exception at end of initialization
Expand Down Expand Up @@ -247,6 +320,7 @@ object DerivationInfoRegistry extends Logging {
s"@Tactic init failed: Following DerivationInfo never implemented as @Tactic: " + unimplemented.mkString(", "),
)
DerivationInfo._initStatus = DerivationInfo.InitComplete
if (derivationErrors.nonEmpty) { throw derivationErrors.head._2 }
deriveErrors match {
case Left(t) => throw t
case _ => // nothing to do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ object DerivationInfoAugmentors {

implicit class ProvableInfoAugmentor(val pi: ProvableInfo) {
val derivedAxiomDB: LemmaDB = LemmaDBFactory.lemmaDB
def derivedAxiomOrRule(name: String): ProvableSig = {
val lemmaName = DerivationInfo(name) match {
def derivedAxiomOrRule(info: DerivationInfo): ProvableSig = {
val lemmaName = info match {
case si: StorableInfo => si.storedName
case _ => throw new IllegalArgumentException(s"Axiom or rule $name is not storable")
case _ => throw new IllegalArgumentException(s"Axiom or rule ${info.canonicalName} is not storable")
}
require(
derivedAxiomDB.contains(lemmaName),
Expand All @@ -40,7 +40,7 @@ object DerivationInfoAugmentors {
.get(lemmaName)
.getOrElse(
throw new IllegalArgumentException(
"Lemma " + lemmaName + " for derived axiom/rule " + name + " should have been added already"
s"Lemma $lemmaName for derived axiom/rule ${info.canonicalName} should have been added already"
)
)
.fact
Expand All @@ -54,8 +54,8 @@ object DerivationInfoAugmentors {
val provable = pi match {
case cai: CoreAxiomInfo => ProvableSig.axioms(cai.canonicalName)
case cari: AxiomaticRuleInfo => ProvableSig.rules(cari.canonicalName)
case dai: DerivedAxiomInfo => derivedAxiomOrRule(dai.canonicalName)
case dari: DerivedRuleInfo => derivedAxiomOrRule(dari.canonicalName)
case dai: DerivedAxiomInfo => derivedAxiomOrRule(dai)
case dari: DerivedRuleInfo => derivedAxiomOrRule(dari)
}
pi.theProvable = Some(provable)
provable
Expand All @@ -69,7 +69,7 @@ object DerivationInfoAugmentors {
case Some(formula) => formula.asInstanceOf[Formula]
case None =>
val formula = pi match {
case dai: DerivedAxiomInfo => derivedAxiomOrRule(dai.canonicalName).conclusion.succ.head
case dai: DerivedAxiomInfo => derivedAxiomOrRule(dai).conclusion.succ.head
case _: CoreAxiomInfo => ProvableSig.axiom.get(pi.canonicalName) match {
case Some(fml) => fml
case None => throw AxiomNotFoundException("No formula for core axiom " + pi.canonicalName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,27 +128,34 @@ object DerivationInfo {
}
}

private def addInfoWithChecks(info: DerivationInfo): Unit = {
_allInfo.get(info.codeName) match {
case None => addInfo(info)
case Some(knownInfo) if knownInfo == info => // Nothing to do here
case Some(knownInfo) => throw new IllegalArgumentException(
s"Duplicate name registration attempt: ${info.codeName} already registered as $knownInfo of ${knownInfo.getClass.getSimpleName}"
)
}
}

def register(info: DerivationInfo): Unit = {
requireInitInProgress(info)
addInfoWithChecks(info)
}

// Hack: derivedAxiom function expects its own derivedaxiominfo to be present during evaluation so that
// it can look up a stored name rather than computing it. The actual solution is a simple refactor but it touches lots
// of code so just delay [[value == derivedAxiom(...)]] execution till after info/
def registerR[T, R <: DerivationInfo](value: => T, p: R): R = {
// Note: We don't require DerivationInfo initialization to be "in progress" for registerR because it used for axioms rather than tactics.
if (!_allInfo.contains(p.codeName)) addInfo(p)
else if (_allInfo(p.codeName) != p) throw new IllegalArgumentException(
"Duplicate name registration attempt: axiom " + p.codeName + " already registered as " + _allInfo(p.codeName) +
" of " + _allInfo(p.codeName).getClass.getSimpleName
)
addInfoWithChecks(p)
val _ = value
p
}

def registerL[T, R <: DerivationInfo](value: => T, p: R): T = {
requireInitInProgress(p)
if (!_allInfo.contains(p.codeName)) addInfo(p)
else if (_allInfo(p.codeName) != p) throw new IllegalArgumentException(
"Duplicate name registration attempt: tactic " + p.codeName + " already registered as " + _allInfo(p.codeName) +
" of " + _allInfo(p.codeName).getClass.getSimpleName
)
addInfoWithChecks(p)
val _ = value
value
}
Expand Down

0 comments on commit 41b81fa

Please sign in to comment.