From b0323181b6a9c20fcd994adb80bbc76bd5e79d4c Mon Sep 17 00:00:00 2001 From: Markus Alexander Kuppe Date: Fri, 16 Aug 2024 10:43:55 +0200 Subject: [PATCH] Rewrite for set membership of the powerset of a record where the record has infinite co-domains. Simplified real-world scenario: ```tla EXTENDS Integers VARIABLE \* @type: Set({ p: (Int) }); v TypeOK == v \in SUBSET [ p: Int ] Init == v = { [p |-> 42] } Next == UNCHANGED v ``` Apalache Error: ```sh $ apalache-mc check --inv=TypeOK APARecSub.tla [...] Input error (see the manual): Found a set map over an infinite set of CellTFrom(Int). Not supported. ``` Rewrite: ```tla S \in SUBSET [a : T] ~~> \A r \in S: DOMAIN r = { "a" } /\ r.a \in T ``` Related commits, issues, PRs: * 625a1645e75a910d43759c36ad8a06291ebc55b3 * 785e26925a45b077e14cf79f31f834e0c0919639 * https://github.com/apalache-mc/apalache/issues/723 * https://github.com/apalache-mc/apalache/issues/1627 * https://github.com/apalache-mc/apalache/issues/2762 * https://github.com/apalache-mc/apalache/pull/1453 * https://github.com/apalache-mc/apalache/pull/1629 Signed-off-by: Markus Alexander Kuppe --- .../apalache/tla/pp/ExprOptimizer.scala | 25 ++++++ .../apalache/tla/pp/TestExprOptimizer.scala | 80 ++++++++++++++++++- 2 files changed, 104 insertions(+), 1 deletion(-) diff --git a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala index e85d215f85..30381dc0c1 100644 --- a/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala +++ b/tla-pp/src/main/scala/at/forsyte/apalache/tla/pp/ExprOptimizer.scala @@ -97,6 +97,31 @@ class ExprOptimizer(nameGen: UniqueNameGenerator, tracker: TransformationTracker } apply(tla.and(domEq +: fieldsEq: _*).as(b)) } + + // S ∈ SUBSET { ["a" ↦ x] : x ∈ T } + case memEx @ OperEx(TlaSetOper.in, setRec, + OperEx(TlaSetOper.powerset, + OperEx(TlaSetOper.map, OperEx(TlaFunOper.rec, fieldsAndValues @ _*), varsAndSets @ _*))) + if fieldsAndValues.length == varsAndSets.length => + val (fields, values) = TlaOper.deinterleave(fieldsAndValues) + val (vars, sets) = TlaOper.deinterleave(varsAndSets) + assert(fields.length == vars.length) + if (values.zip(vars).exists(p => p._1 != p._2)) { + memEx + } else { + val strSetT = SetT1(StrT1) + val b = BoolT1 + + val domType = getElemType(setRec) + val r = tla.name(nameGen.newName()).as(domType) + + val domEq = tla.eql(tla.dom(r).as(SetT1(domType)), tla.enumSet(fields: _*).as(strSetT)).as(b) + + val fieldsEq = fields.zip(values.zip(sets)).map { case (key, (value, set)) => + tla.in(tla.appFun(r, key).as(value.typeTag.asTlaType1()), set).as(b) + } + apply(tla.forall(r, setRec, tla.and(domEq +: fieldsEq: _*).as(b)).as(b)) + } } /** diff --git a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala index 5f13923734..a38e77fefc 100644 --- a/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala +++ b/tla-pp/src/test/scala/at/forsyte/apalache/tla/pp/TestExprOptimizer.scala @@ -87,22 +87,100 @@ class TestExprOptimizer extends AnyFunSuite with BeforeAndAfterEach { // An optimization for set membership over sets of records. Note that this is the standard form produced by Keramelizer. test("""r \in { [a |-> x, b |-> y]: x \in S, y \in T } becomes DOMAIN r = { "a", "b" } /\ r.a \in S /\ r.b \in T""") { + // ... [a |-> x, b |-> y] ... val recT = RecT1("a" -> IntT1, "b" -> IntT1) - val recSetT = SetT1(recT) + // ... x \in S, y \in T ... val record = enumFun(str("a"), name("x").as(intT), str("b"), name("y").as(intT)).as(recT) + // ... S ... val S = name("S").as(intSetT) + // ... T ... val T = name("T").as(intSetT) + // { ... } + val recSetT = SetT1(recT) val recordSet = map(record, name("x").as(intT), S, name("y").as(intT), T).as(recSetT) + // r ... val r = name("r").as(recT) + // ... \in ... val input = in(r, recordSet).as(boolT) + // ~~> + + // DOMAIN r = { "a", "b" } val strSetT = SetT1(StrT1) val domEq = eql(dom(r).as(strSetT), enumSet(str("a"), str("b")).as(strSetT)).as(boolT) + // r.a \in S val memA = in(appFun(r, str("a")).as(intT), S).as(boolT) + // r.b \in T val memB = in(appFun(r, str("b")).as(intT), T).as(boolT) + // ... /\ ... /\ ... val expected = and(domEq, memA, memB).as(boolT) val output = optimizer.apply(input) + + assert(expected == output) + } + + // An optimization for set membership of the powerset of a record where the record has infinite co-domains. + test("""S \in SUBSET [a : T] ~~> \A r \in S: DOMAIN r = { "a" } /\ r.a \in T""") { + + // ... { [a |-> x] : x \in T } ... + val recT = RecT1("a" -> IntT1) + val record = + enumFun(str("a"), name("x").as(intT)).as(recT) + val T = name("T").as(intSetT) + val recSetT = SetT1(recT) + val recordSet = map(record, name("x").as(intT), T).as(recSetT) + + // ... SUBSET ... + val powSetT = powSet(recordSet).as(recSetT) + + // S ... + val s = name("S").as(recSetT) + + // ... \in ... + val input = in(s, powSetT).as(boolT) + val output = optimizer.apply(input) + + // ~~> + + // DOMAIN r = { "a" } + val r = name("t_1").as(recT) + val strSetT = SetT1(StrT1) + val domEq = eql(dom(r).as(strSetT), enumSet(str("a")).as(strSetT)).as(boolT) + + // r.a \in T + val memA = in(appFun(r, str("a")).as(intT), T).as(boolT) + + // ... /\ ... + val conjunct = and(domEq, memA).as(boolT) + + // \A ... + val expected = forall(r, s, conjunct).as(boolT) + + assert(expected == output) + } + + test("""S \in SUBSET [a : T, b : U] ~~> \A r \in S: DOMAIN r = { "a", "b" } /\ r.a \in T /\ r.b \in U""") { + val recT = RecT1("a" -> IntT1, "b" -> IntT1) + val record = + enumFun(str("a"), name("x").as(intT), str("b"), name("y").as(intT)).as(recT) + val T = name("T").as(intSetT) + val U = name("U").as(intSetT) + val recSetT = SetT1(recT) + val recordSet = map(record, name("x").as(intT), T, name("y").as(intT), U).as(recSetT) + val powSetT = powSet(recordSet).as(recSetT) + val s = name("S").as(recSetT) + val input = in(s, powSetT).as(boolT) + val output = optimizer.apply(input) + + val r = name("t_1").as(recT) + val strSetT = SetT1(StrT1) + val domEq = eql(dom(r).as(strSetT), enumSet(str("a"), str("b")).as(strSetT)).as(boolT) + val memA = in(appFun(r, str("a")).as(intT), T).as(boolT) + val memB = in(appFun(r, str("b")).as(intT), U).as(boolT) + val conjunct = and(domEq, memA, memB).as(boolT) + val expected = forall(r, s, conjunct).as(boolT) + assert(expected == output) }