|
| 1 | +package dotty.tools |
| 2 | +package dotc |
| 3 | +package cc |
| 4 | +import ast.tpd |
| 5 | +import collection.mutable |
| 6 | + |
| 7 | +import core.* |
| 8 | +import Symbols.*, Types.* |
| 9 | +import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.* |
| 10 | +import CaptureSet.{Refs, emptySet} |
| 11 | +import config.Printers.capt |
| 12 | +import StdNames.nme |
| 13 | + |
| 14 | +class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser: |
| 15 | + import tpd.* |
| 16 | + import checker.* |
| 17 | + |
| 18 | + extension (refs: Refs) |
| 19 | + private def footprint(using Context): Refs = |
| 20 | + def recur(elems: Refs, newElems: List[CaptureRef]): Refs = newElems match |
| 21 | + case newElem :: newElems1 => |
| 22 | + val superElems = newElem.captureSetOfInfo.elems.filter: superElem => |
| 23 | + !superElem.isMaxCapability && !elems.contains(superElem) |
| 24 | + recur(elems ++ superElems, newElems1 ++ superElems.toList) |
| 25 | + case Nil => elems |
| 26 | + val elems: Refs = refs.filter(!_.isMaxCapability) |
| 27 | + recur(elems, elems.toList) |
| 28 | + |
| 29 | + private def overlapWith(other: Refs)(using Context): Refs = |
| 30 | + val refs1 = refs |
| 31 | + val refs2 = other |
| 32 | + def common(refs1: Refs, refs2: Refs) = |
| 33 | + refs1.filter: ref => |
| 34 | + ref.isExclusive && refs2.exists(_.stripReadOnly eq ref) |
| 35 | + common(refs, other) ++ common(other, refs) |
| 36 | + |
| 37 | + private def hidden(refs: Refs)(using Context): Refs = |
| 38 | + val seen: util.EqHashSet[CaptureRef] = new util.EqHashSet |
| 39 | + |
| 40 | + def hiddenByElem(elem: CaptureRef): Refs = |
| 41 | + if seen.add(elem) then elem match |
| 42 | + case Fresh.Cap(hcs) => hcs.elems.filter(!_.isRootCapability) ++ recur(hcs.elems) |
| 43 | + case ReadOnlyCapability(ref) => hiddenByElem(ref).map(_.readOnly) |
| 44 | + case _ => emptySet |
| 45 | + else emptySet |
| 46 | + |
| 47 | + def recur(cs: Refs): Refs = |
| 48 | + (emptySet /: cs): (elems, elem) => |
| 49 | + elems ++ hiddenByElem(elem) |
| 50 | + |
| 51 | + recur(refs) |
| 52 | + end hidden |
| 53 | + |
| 54 | + /** The captures of an argument or prefix widened to the formal parameter, if |
| 55 | + * the latter contains a cap. |
| 56 | + */ |
| 57 | + private def formalCaptures(arg: Tree)(using Context): Refs = |
| 58 | + val argType = arg.formalType.orElse(arg.nuType) |
| 59 | + (if arg.nuType.hasUseAnnot then argType.deepCaptureSet else argType.captureSet) |
| 60 | + .elems |
| 61 | + |
| 62 | + /** The captures of an argument of prefix. No widening takes place */ |
| 63 | + private def actualCaptures(arg: Tree)(using Context): Refs = |
| 64 | + val argType = arg.nuType |
| 65 | + (if argType.hasUseAnnot then argType.deepCaptureSet else argType.captureSet) |
| 66 | + .elems |
| 67 | + |
| 68 | + private def sepError(fn: Tree, args: List[Tree], argIdx: Int, |
| 69 | + overlap: Refs, hiddenInArg: Refs, footprints: List[(Refs, Int)], |
| 70 | + deps: collection.Map[Tree, List[Tree]])(using Context): Unit = |
| 71 | + val arg = args(argIdx) |
| 72 | + def paramName(mt: Type, idx: Int): Option[Name] = mt match |
| 73 | + case mt @ MethodType(pnames) => |
| 74 | + if idx < pnames.length then Some(pnames(idx)) else paramName(mt.resType, idx - pnames.length) |
| 75 | + case mt: PolyType => paramName(mt.resType, idx) |
| 76 | + case _ => None |
| 77 | + def formalName = paramName(fn.nuType.widen, argIdx) match |
| 78 | + case Some(pname) => i"$pname " |
| 79 | + case _ => "" |
| 80 | + def whatStr = if overlap.size == 1 then "this capability is" else "these capabilities are" |
| 81 | + def funStr = |
| 82 | + if fn.symbol.exists then i"${fn.symbol}: ${fn.symbol.info}" |
| 83 | + else i"a function of type ${fn.nuType.widen}" |
| 84 | + val clashIdx = footprints |
| 85 | + .collect: |
| 86 | + case (fp, idx) if !hiddenInArg.overlapWith(fp).isEmpty => idx |
| 87 | + .head |
| 88 | + def whereStr = clashIdx match |
| 89 | + case 0 => "function prefix" |
| 90 | + case 1 => "first argument " |
| 91 | + case 2 => "second argument" |
| 92 | + case 3 => "third argument " |
| 93 | + case n => s"${n}th argument " |
| 94 | + def clashTree = |
| 95 | + if clashIdx == 0 then methPart(fn).asInstanceOf[Select].qualifier |
| 96 | + else args(clashIdx - 1) |
| 97 | + def clashType = clashTree.nuType |
| 98 | + def clashCaptures = actualCaptures(clashTree) |
| 99 | + def hiddenCaptures = hidden(formalCaptures(arg)) |
| 100 | + def clashFootprint = clashCaptures.footprint |
| 101 | + def hiddenFootprint = hiddenCaptures.footprint |
| 102 | + def declaredFootprint = deps(arg).map(actualCaptures(_)).foldLeft(emptySet)(_ ++ _).footprint |
| 103 | + def footprintOverlap = hiddenFootprint.overlapWith(clashFootprint) -- declaredFootprint |
| 104 | + report.error( |
| 105 | + em"""Separation failure: argument of type ${arg.nuType} |
| 106 | + |to $funStr |
| 107 | + |corresponds to capture-polymorphic formal parameter ${formalName}of type ${arg.formalType} |
| 108 | + |and captures ${CaptureSet(overlap)}, but $whatStr also passed separately |
| 109 | + |in the ${whereStr.trim} with type $clashType. |
| 110 | + | |
| 111 | + | Capture set of $whereStr : ${CaptureSet(clashCaptures)} |
| 112 | + | Hidden set of current argument : ${CaptureSet(hiddenCaptures)} |
| 113 | + | Footprint of $whereStr : ${CaptureSet(clashFootprint)} |
| 114 | + | Hidden footprint of current argument : ${CaptureSet(hiddenFootprint)} |
| 115 | + | Declared footprint of current argument: ${CaptureSet(declaredFootprint)} |
| 116 | + | Undeclared overlap of footprints : ${CaptureSet(footprintOverlap)}""", |
| 117 | + arg.srcPos) |
| 118 | + end sepError |
| 119 | + |
| 120 | + private def checkApply(fn: Tree, args: List[Tree], deps: collection.Map[Tree, List[Tree]])(using Context): Unit = |
| 121 | + val fnCaptures = methPart(fn) match |
| 122 | + case Select(qual, _) => qual.nuType.captureSet |
| 123 | + case _ => CaptureSet.empty |
| 124 | + capt.println(i"check separate $fn($args), fnCaptures = $fnCaptures, argCaptures = ${args.map(arg => CaptureSet(formalCaptures(arg)))}, deps = ${deps.toList}") |
| 125 | + var footprint = fnCaptures.elems.footprint |
| 126 | + val footprints = mutable.ListBuffer[(Refs, Int)]((footprint, 0)) |
| 127 | + val indexedArgs = args.zipWithIndex |
| 128 | + |
| 129 | + def subtractDeps(elems: Refs, arg: Tree): Refs = |
| 130 | + deps(arg).foldLeft(elems): (elems, dep) => |
| 131 | + elems -- actualCaptures(dep).footprint |
| 132 | + |
| 133 | + for (arg, idx) <- indexedArgs do |
| 134 | + if !arg.needsSepCheck then |
| 135 | + footprint = footprint ++ subtractDeps(actualCaptures(arg).footprint, arg) |
| 136 | + footprints += ((footprint, idx + 1)) |
| 137 | + for (arg, idx) <- indexedArgs do |
| 138 | + if arg.needsSepCheck then |
| 139 | + val ac = formalCaptures(arg) |
| 140 | + val hiddenInArg = hidden(ac).footprint |
| 141 | + //println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg") |
| 142 | + val overlap = subtractDeps(hiddenInArg.overlapWith(footprint), arg) |
| 143 | + if !overlap.isEmpty then |
| 144 | + sepError(fn, args, idx, overlap, hiddenInArg, footprints.toList, deps) |
| 145 | + footprint ++= actualCaptures(arg).footprint |
| 146 | + footprints += ((footprint, idx + 1)) |
| 147 | + end checkApply |
| 148 | + |
| 149 | + private def collectMethodTypes(tp: Type): List[TermLambda] = tp match |
| 150 | + case tp: MethodType => tp :: collectMethodTypes(tp.resType) |
| 151 | + case tp: PolyType => collectMethodTypes(tp.resType) |
| 152 | + case _ => Nil |
| 153 | + |
| 154 | + private def dependencies(fn: Tree, argss: List[List[Tree]])(using Context): collection.Map[Tree, List[Tree]] = |
| 155 | + val mtpe = |
| 156 | + if fn.symbol.exists then fn.symbol.info |
| 157 | + else fn.tpe.widen // happens for PolyFunction applies |
| 158 | + val mtps = collectMethodTypes(mtpe) |
| 159 | + assert(mtps.hasSameLengthAs(argss), i"diff for $fn: ${fn.symbol} /// $mtps /// $argss") |
| 160 | + val mtpsWithArgs = mtps.zip(argss) |
| 161 | + val argMap = mtpsWithArgs.toMap |
| 162 | + val deps = mutable.HashMap[Tree, List[Tree]]().withDefaultValue(Nil) |
| 163 | + for |
| 164 | + (mt, args) <- mtpsWithArgs |
| 165 | + (formal, arg) <- mt.paramInfos.zip(args) |
| 166 | + dep <- formal.captureSet.elems.toList |
| 167 | + do |
| 168 | + val referred = dep match |
| 169 | + case dep: TermParamRef => |
| 170 | + argMap(dep.binder)(dep.paramNum) :: Nil |
| 171 | + case dep: ThisType if dep.cls == fn.symbol.owner => |
| 172 | + val Select(qual, _) = fn: @unchecked |
| 173 | + qual :: Nil |
| 174 | + case _ => |
| 175 | + Nil |
| 176 | + deps(arg) ++= referred |
| 177 | + deps |
| 178 | + |
| 179 | + private def traverseApply(tree: Tree, argss: List[List[Tree]])(using Context): Unit = tree match |
| 180 | + case Apply(fn, args) => traverseApply(fn, args :: argss) |
| 181 | + case TypeApply(fn, args) => traverseApply(fn, argss) // skip type arguments |
| 182 | + case _ => |
| 183 | + if argss.nestedExists(_.needsSepCheck) then |
| 184 | + checkApply(tree, argss.flatten, dependencies(tree, argss)) |
| 185 | + |
| 186 | + def traverse(tree: Tree)(using Context): Unit = |
| 187 | + tree match |
| 188 | + case tree: GenericApply => |
| 189 | + if tree.symbol != defn.Caps_unsafeAssumeSeparate then |
| 190 | + tree.tpe match |
| 191 | + case _: MethodOrPoly => |
| 192 | + case _ => traverseApply(tree, Nil) |
| 193 | + traverseChildren(tree) |
| 194 | + case _ => |
| 195 | + traverseChildren(tree) |
| 196 | +end SepChecker |
| 197 | + |
| 198 | + |
| 199 | + |
| 200 | + |
| 201 | + |
| 202 | + |
0 commit comments