Skip to content

Commit

Permalink
Check separation of different parts of a declared type.
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Jan 24, 2025
1 parent 35cb53c commit ab4d96a
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 48 deletions.
154 changes: 137 additions & 17 deletions compiler/src/dotty/tools/dotc/cc/SepCheck.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,28 @@ import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
import CaptureSet.{Refs, emptySet, HiddenSet}
import config.Printers.capt
import StdNames.nme
import util.{SimpleIdentitySet, EqHashMap}
import util.{SimpleIdentitySet, EqHashMap, SrcPos}

object SepChecker:

/** Enumerates kinds of captures encountered so far */
enum Captures:
case None
case Explicit // one or more explicitly declared captures
case Hidden // exacttly one hidden captures
case NeedsCheck // one hidden capture and one other capture (hidden or declared)

def add(that: Captures): Captures =
if this == None then that
else if that == None then this
else if this == Explicit && that == Explicit then Explicit
else NeedsCheck
end Captures

class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
import tpd.*
import checker.*
import SepChecker.*

/** The set of capabilities that are hidden by a polymorphic result type
* of some previous definition.
Expand Down Expand Up @@ -52,21 +69,17 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:

private def hidden(using Context): Refs =
val seen: util.EqHashSet[CaptureRef] = new util.EqHashSet

def hiddenByElem(elem: CaptureRef): Refs =
if seen.add(elem) then elem match
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ recur(hcs.elems)
case ReadOnlyCapability(ref) => hiddenByElem(ref).map(_.readOnly)
case _ => emptySet
else emptySet

def recur(cs: Refs): Refs =
(emptySet /: cs): (elems, elem) =>
elems ++ hiddenByElem(elem)

if seen.add(elem) then elems ++ hiddenByElem(elem, recur)
else elems
recur(refs)
end hidden

private def containsHidden(using Context): Boolean =
refs.exists: ref =>
!hiddenByElem(ref, _ => emptySet).isEmpty

/** Deduct the footprint of `sym` and `sym*` from `refs` */
private def deductSym(sym: Symbol)(using Context) =
val ref = sym.termRef
Expand All @@ -79,6 +92,11 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
refs -- captures(dep).footprint
end extension

private def hiddenByElem(ref: CaptureRef, recur: Refs => Refs)(using Context): Refs = ref match
case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ recur(hcs.elems)
case ReadOnlyCapability(ref1) => hiddenByElem(ref1, recur).map(_.readOnly)
case _ => emptySet

/** The captures of an argument or prefix widened to the formal parameter, if
* the latter contains a cap.
*/
Expand Down Expand Up @@ -186,6 +204,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
for (arg, idx) <- indexedArgs do
if arg.needsSepCheck then
val ac = formalCaptures(arg)
checkType(arg.formalType, arg.srcPos, NoSymbol, " the argument's adapted type")
val hiddenInArg = ac.hidden.footprint
//println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
val overlap = hiddenInArg.overlapWith(footprint).deductCapturesOf(deps(arg))
Expand All @@ -212,6 +231,105 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
if !overlap.isEmpty then
sepUseError(tree, usedFootprint, overlap)

def checkType(tpt: Tree, sym: Symbol)(using Context): Unit =
checkType(tpt.nuType, tpt.srcPos, sym, "")

/** Check that all parts of type `tpe` are separated.
* @param tpe the type to check
* @param pos position for error reporting
* @param sym if `tpe` is the (result-) type of a val or def, the symbol of
* this definition, otherwise NoSymbol. If `sym` exists we
* deduct its associated direct and reach capabilities everywhere
* from the capture sets we check.
* @param what a string describing what kind of type it is
*/
def checkType(tpe: Type, pos: SrcPos, sym: Symbol, what: String)(using Context): Unit =

def checkParts(parts: List[Type]): Unit =
var footprint: Refs = emptySet
var hiddenSet: Refs = emptySet
var checked = 0
for part <- parts do

/** Report an error if `current` and `next` overlap.
* @param current the footprint or hidden set seen so far
* @param next the footprint or hidden set of the next part
* @param mapRefs a function over the capture set elements of the next part
* that returns the references of the same kind as `current`
* (i.e. the part's footprint or hidden set)
* @param prevRel a verbal description of current ("references or "hides")
* @param nextRel a verbal descriiption of next
*/
def checkSep(current: Refs, next: Refs, mapRefs: Refs => Refs, prevRel: String, nextRel: String): Unit =
val globalOverlap = current.overlapWith(next)
if !globalOverlap.isEmpty then
val (prevStr, prevRefs, overlap) = parts.iterator.take(checked)
.map: prev =>
val prevRefs = mapRefs(prev.deepCaptureSet.elems).footprint.deductSym(sym)
(i", $prev , ", prevRefs, prevRefs.overlapWith(next))
.dropWhile(_._3.isEmpty)
.nextOption
.getOrElse(("", current, globalOverlap))
report.error(
em"""Separation failure in$what type $tpe.
|One part, $part , $nextRel ${CaptureSet(next)}.
|A previous part$prevStr $prevRel ${CaptureSet(prevRefs)}.
|The two sets overlap at ${CaptureSet(overlap)}.""",
pos)

val partRefs = part.deepCaptureSet.elems
val partFootprint = partRefs.footprint.deductSym(sym)
val partHidden = partRefs.hidden.footprint.deductSym(sym) -- partFootprint

checkSep(footprint, partHidden, identity, "references", "hides")
checkSep(hiddenSet, partHidden, _.hidden, "also hides", "hides")
checkSep(hiddenSet, partFootprint, _.hidden, "hides", "references")

footprint ++= partFootprint
hiddenSet ++= partHidden
checked += 1
end for
end checkParts

object traverse extends TypeAccumulator[Captures]:

/** A stack of part lists to check. We maintain this since immediately
* checking parts when traversing the type would check innermost to oputermost.
* But we want to check outermost parts first since this prioritized errors
* that are more obvious.
*/
var toCheck: List[List[Type]] = Nil

private val seen = util.HashSet[Symbol]()

def apply(c: Captures, t: Type) =
if variance < 0 then c
else
val t1 = t.dealias
t1 match
case t @ AppliedType(tycon, args) =>
val c1 = foldOver(Captures.None, t)
if c1 == Captures.NeedsCheck then
toCheck = (tycon :: args) :: toCheck
c.add(c1)
case t @ CapturingType(parent, cs) =>
val c1 = this(c, parent)
if cs.elems.containsHidden then c1.add(Captures.Hidden)
else if !cs.elems.isEmpty then c1.add(Captures.Explicit)
else c1
case t: TypeRef if t.symbol.isAbstractOrParamType =>
if seen.contains(t.symbol) then c
else
seen += t.symbol
apply(apply(c, t.prefix), t.info.bounds.hi)
case t =>
foldOver(c, t)

if !tpe.hasAnnotation(defn.UntrackedCapturesAnnot) then
traverse(Captures.None, tpe)
traverse.toCheck.foreach(checkParts)
end checkType

private def collectMethodTypes(tp: Type): List[TermLambda] = tp match
case tp: MethodType => tp :: collectMethodTypes(tp.resType)
case tp: PolyType => collectMethodTypes(tp.resType)
Expand All @@ -231,7 +349,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
(formal, arg) <- mt.paramInfos.zip(args)
dep <- formal.captureSet.elems.toList
do
val referred = dep match
val referred = dep.stripReach match
case dep: TermParamRef =>
argMap(dep.binder)(dep.paramNum) :: Nil
case dep: ThisType if dep.cls == fn.symbol.owner =>
Expand Down Expand Up @@ -269,11 +387,13 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
defsShadow = saved
case tree: ValOrDefDef =>
traverseChildren(tree)
if previousDefs.nonEmpty && !tree.symbol.isOneOf(TermParamOrAccessor) then
capt.println(i"sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}")
defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
resultType(tree.symbol) = tree.tpt.nuType
previousDefs.head += tree
if !tree.symbol.isOneOf(TermParamOrAccessor) then
checkType(tree.tpt, tree.symbol)
if previousDefs.nonEmpty then
capt.println(i"sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}")
defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
resultType(tree.symbol) = tree.tpt.nuType
previousDefs.head += tree
case _ =>
traverseChildren(tree)
end SepChecker
Expand Down
8 changes: 1 addition & 7 deletions tests/neg-custom-args/captures/capt-depfun.check
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt-depfun.scala:10:43 ----------------------------------
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error
| ^^^^^^^
| Found: Str^{} ->{ac, y, z} Str^{y, z}
| Required: Str^{y, z} ->{fresh} Str^{y, z}
|
| longer explanation available when compiling with `-explain`
-- Error: tests/neg-custom-args/captures/capt-depfun.scala:10:24 -------------------------------------------------------
10 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
| Separation failure: Str^{y, z} => Str^{y, z} captures a root element hiding {ac, y, z}
| and also refers to {y, z}.
| The two sets overlap at {y, z}
2 changes: 1 addition & 1 deletion tests/neg-custom-args/captures/capt-depfun.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ class Str
def f(y: Cap, z: Cap) =
def g(): C @retains(y, z) = ???
val ac: ((x: Cap) => Str @retains(x) => Str @retains(x)) = ???
val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error sepcheck
val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error
35 changes: 25 additions & 10 deletions tests/neg-custom-args/captures/reaches2.check
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
-- Error: tests/neg-custom-args/captures/reaches2.scala:8:10 -----------------------------------------------------------
8 | ps.map((x, y) => compose1(x, y)) // error // error
| ^
|reference ps* is not included in the allowed capture set {}
|of an enclosing function literal with expected type ((box A ->{ps*} A, box A ->{ps*} A)) -> box (x$0: A^?) ->? A^?
-- Error: tests/neg-custom-args/captures/reaches2.scala:8:13 -----------------------------------------------------------
8 | ps.map((x, y) => compose1(x, y)) // error // error
| ^
|reference ps* is not included in the allowed capture set {}
|of an enclosing function literal with expected type ((box A ->{ps*} A, box A ->{ps*} A)) -> box (x$0: A^?) ->? A^?
-- Error: tests/neg-custom-args/captures/reaches2.scala:10:10 ----------------------------------------------------------
10 | ps.map((x, y) => compose1(x, y)) // error // error // error
| ^
|reference ps* is not included in the allowed capture set {}
|of an enclosing function literal with expected type ((box A ->{ps*} A, box A ->{ps*} A)) -> box (x$0: A^?) ->? A^?
-- Error: tests/neg-custom-args/captures/reaches2.scala:10:13 ----------------------------------------------------------
10 | ps.map((x, y) => compose1(x, y)) // error // error // error
| ^
|reference ps* is not included in the allowed capture set {}
|of an enclosing function literal with expected type ((box A ->{ps*} A, box A ->{ps*} A)) -> box (x$0: A^?) ->? A^?
-- Error: tests/neg-custom-args/captures/reaches2.scala:10:31 ----------------------------------------------------------
10 | ps.map((x, y) => compose1(x, y)) // error // error // error
| ^
| Separation failure: argument of type (x$0: A) ->{y} box A^?
| to method compose1: [A, B, C](f: A => B, g: B => C): A ->{f, g} C
| corresponds to capture-polymorphic formal parameter g of type box A^? => box A^?
| and captures {ps*}, but this capability is also passed separately
| in the first argument with type (x$0: A) ->{x} box A^?.
|
| Capture set of first argument : {x}
| Hidden set of current argument : {y}
| Footprint of first argument : {x, ps*}
| Hidden footprint of current argument : {y, ps*}
| Declared footprint of current argument: {}
| Undeclared overlap of footprints : {ps*}
4 changes: 3 additions & 1 deletion tests/neg-custom-args/captures/reaches2.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import language.`3.8` // sepchecks on

class List[+A]:
def map[B](f: A -> B): List[B] = ???

def compose1[A, B, C](f: A => B, g: B => C): A ->{f, g} C =
z => g(f(z))

def mapCompose[A](ps: List[(A => A, A => A)]): List[A ->{ps*} A] =
ps.map((x, y) => compose1(x, y)) // error // error
ps.map((x, y) => compose1(x, y)) // error // error // error

37 changes: 29 additions & 8 deletions tests/neg-custom-args/captures/sepchecks2.check
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:7:10 ---------------------------------------------------------
7 | println(c) // error
| ^
| Separation failure: Illegal access to {c} which is hidden by the previous definition
| of value xs with type List[box () => Unit].
| This type hides capabilities {xs*, c}
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:10:33 --------------------------------------------------------
10 | foo((() => println(c)) :: Nil, c) // error
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:10:10 --------------------------------------------------------
10 | println(c) // error
| ^
| Separation failure: Illegal access to {c} which is hidden by the previous definition
| of value xs with type List[box () => Unit].
| This type hides capabilities {xs*, c}
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:13:33 --------------------------------------------------------
13 | foo((() => println(c)) :: Nil, c) // error
| ^
| Separation failure: argument of type (c : Object^)
| to method foo: (xs: List[box () => Unit], y: Object^): Nothing
Expand All @@ -19,3 +19,24 @@
| Hidden footprint of current argument : {c}
| Declared footprint of current argument: {}
| Undeclared overlap of footprints : {c}
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:14:10 --------------------------------------------------------
14 | val x1: (Object^, Object^) = (c, c) // error
| ^^^^^^^^^^^^^^^^^^
| Separation failure in type (box Object^, box Object^).
| One part, box Object^ , hides {c}.
| A previous part, box Object^ , also hides {c}.
| The two sets overlap at {c}.
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:15:10 --------------------------------------------------------
15 | val x2: (Object^, Object^{d}) = (d, d) // error
| ^^^^^^^^^^^^^^^^^^^^^
| Separation failure in type (box Object^, box Object^{d}).
| One part, box Object^{d} , references {d}.
| A previous part, box Object^ , hides {d}.
| The two sets overlap at {d}.
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:27:6 ---------------------------------------------------------
27 | bar((c, c)) // error
| ^^^^^^
| Separation failure in the argument's adapted type type (box Object^, box Object^).
| One part, box Object^ , hides {c}.
| A previous part, box Object^ , also hides {c}.
| The two sets overlap at {c}.
20 changes: 19 additions & 1 deletion tests/neg-custom-args/captures/sepchecks2.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
import language.future // sepchecks on


def foo(xs: List[() => Unit], y: Object^) = ???

def bar(x: (Object^, Object^)): Unit = ???

def Test(c: Object^) =
val xs: List[() => Unit] = (() => println(c)) :: Nil
println(c) // error

def Test2(c: Object^) =
def Test2(c: Object^, d: Object^): Unit =
foo((() => println(c)) :: Nil, c) // error
val x1: (Object^, Object^) = (c, c) // error
val x2: (Object^, Object^{d}) = (d, d) // error

def Test3(c: Object^, d: Object^) =
val x: (Object^, Object^) = (c, d) // ok

def Test4(c: Object^, d: Object^) =
val x: (Object^, Object^{c}) = (d, c) // ok

def Test5(c: Object^, d: Object^): Unit =
bar((c, d)) // ok

def Test6(c: Object^, d: Object^): Unit =
bar((c, c)) // error

3 changes: 2 additions & 1 deletion tests/pos-custom-args/captures/i15749a.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import caps.cap
import caps.use
import language.`3.7` // sepchecks on

class Unit
object u extends Unit
Expand All @@ -13,7 +14,7 @@ def test =
def wrapper[T](x: T): Wrapper[T] = Wrapper:
[X] => (op: T ->{cap} X) => op(x)

def strictMap[A <: Top, B <: Top](mx: Wrapper[A])(f: A ->{cap} B): Wrapper[B] =
def strictMap[A <: Top, B <: Top](mx: Wrapper[A])(f: A ->{cap, mx*} B): Wrapper[B] =
mx.value((x: A) => wrapper(f(x)))

def force[A](thunk: Unit ->{cap} A): A = thunk(u)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import scala.reflect.ClassTag
import annotation.unchecked.{uncheckedVariance, uncheckedCaptures}
import annotation.tailrec
import caps.cap
import caps.untrackedCaptures
import language.`3.7` // sepchecks on

/** A strawman architecture for new collections. It contains some
Expand Down Expand Up @@ -68,11 +69,13 @@ object CollectionStrawMan5 {
/** Base trait for strict collections */
trait Buildable[+A] extends Iterable[A] {
protected def newBuilder: Builder[A, Repr] @uncheckedVariance
override def partition(p: A => Boolean): (Repr, Repr) = {
override def partition(p: A => Boolean): (Repr, Repr) @untrackedCaptures =
// Without untrackedCaptures this fails SepChecks.checkType.
// But this is probably an error in the hiding logic.
// TODO remove @untrackedCaptures and investigate
val l, r = newBuilder
iterator.foreach(x => (if (p(x)) l else r) += x)
(l.result, r.result)
}
// one might also override other transforms here to avoid generating
// iterators if it helps efficiency.
}
Expand Down

0 comments on commit ab4d96a

Please sign in to comment.