|
| 1 | +package x |
| 2 | + |
| 3 | +import scala.annotation._ |
| 4 | +import scala.quoted._ |
| 5 | + |
| 6 | +trait CB[+T] |
| 7 | + |
| 8 | +object CBM: |
| 9 | + def pure[T](t:T):CB[T] = ??? |
| 10 | + def map[A,B](fa:CB[A])(f: A=>B):CB[B] = ??? |
| 11 | + def flatMap[A,B](fa:CB[A])(f: A=>CB[B]):CB[B] = ??? |
| 12 | + def spawn[A](op: =>CB[A]): CB[A] = ??? |
| 13 | + |
| 14 | + |
| 15 | +@compileTimeOnly("await should be inside async block") |
| 16 | +def await[T](f: CB[T]): T = ??? |
| 17 | + |
| 18 | + |
| 19 | +trait CpsExpr[T:Type](prev: Seq[Expr[?]]): |
| 20 | + |
| 21 | + def fLast(using Quotes): Expr[CB[T]] |
| 22 | + def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] |
| 23 | + def append[A:Type](chunk: CpsExpr[A])(using Quotes): CpsExpr[A] |
| 24 | + def syncOrigin(using Quotes): Option[Expr[T]] |
| 25 | + def map[A:Type](f: Expr[T => A])(using Quotes): CpsExpr[A] = |
| 26 | + MappedCpsExpr[T,A](Seq(),this,f) |
| 27 | + def flatMap[A:Type](f: Expr[T => CB[A]])(using Quotes): CpsExpr[A] = |
| 28 | + FlatMappedCpsExpr[T,A](Seq(),this,f) |
| 29 | + |
| 30 | + def transformed(using Quotes): Expr[CB[T]] = |
| 31 | + import quotes.reflect._ |
| 32 | + Block(prev.toList.map(_.asTerm), fLast.asTerm).asExprOf[CB[T]] |
| 33 | + |
| 34 | + |
| 35 | +case class GenericSyncCpsExpr[T:Type](prev: Seq[Expr[?]],last: Expr[T]) extends CpsExpr[T](prev): |
| 36 | + |
| 37 | + override def fLast(using Quotes): Expr[CB[T]] = |
| 38 | + '{ CBM.pure(${last}:T) } |
| 39 | + |
| 40 | + override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] = |
| 41 | + copy(prev = exprs ++: prev) |
| 42 | + |
| 43 | + override def syncOrigin(using Quotes): Option[Expr[T]] = |
| 44 | + import quotes.reflect._ |
| 45 | + Some(Block(prev.toList.map(_.asTerm), last.asTerm).asExprOf[T]) |
| 46 | + |
| 47 | + override def append[A:Type](e: CpsExpr[A])(using Quotes) = |
| 48 | + e.prependExprs(Seq(last)).prependExprs(prev) |
| 49 | + |
| 50 | + override def map[A:Type](f: Expr[T => A])(using Quotes): CpsExpr[A] = |
| 51 | + copy(last = '{ $f($last) }) |
| 52 | + |
| 53 | + override def flatMap[A:Type](f: Expr[T => CB[A]])(using Quotes): CpsExpr[A] = |
| 54 | + GenericAsyncCpsExpr[A](prev, '{ CBM.flatMap(CBM.pure($last))($f) } ) |
| 55 | + |
| 56 | + |
| 57 | +abstract class AsyncCpsExpr[T:Type]( |
| 58 | + prev: Seq[Expr[?]] |
| 59 | + ) extends CpsExpr[T](prev): |
| 60 | + |
| 61 | + override def append[A:Type](e: CpsExpr[A])(using Quotes): CpsExpr[A] = |
| 62 | + flatMap( '{ (x:T) => ${e.transformed} }) |
| 63 | + |
| 64 | + override def syncOrigin(using Quotes): Option[Expr[T]] = None |
| 65 | + |
| 66 | + |
| 67 | + |
| 68 | +case class GenericAsyncCpsExpr[T:Type]( |
| 69 | + prev: Seq[Expr[?]], |
| 70 | + fLastExpr: Expr[CB[T]] |
| 71 | + ) extends AsyncCpsExpr[T](prev): |
| 72 | + |
| 73 | + override def fLast(using Quotes): Expr[CB[T]] = fLastExpr |
| 74 | + |
| 75 | + override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] = |
| 76 | + copy(prev = exprs ++: prev) |
| 77 | + |
| 78 | + override def map[A:Type](f: Expr[T => A])(using Quotes): CpsExpr[A] = |
| 79 | + MappedCpsExpr(Seq(),this,f) |
| 80 | + |
| 81 | + override def flatMap[A:Type](f: Expr[T => CB[A]])(using Quotes): CpsExpr[A] = |
| 82 | + FlatMappedCpsExpr(Seq(),this,f) |
| 83 | + |
| 84 | + |
| 85 | + |
| 86 | +case class MappedCpsExpr[S:Type, T:Type]( |
| 87 | + prev: Seq[Expr[?]], |
| 88 | + point: CpsExpr[S], |
| 89 | + mapping: Expr[S=>T] |
| 90 | + ) extends AsyncCpsExpr[T](prev): |
| 91 | + |
| 92 | + override def fLast(using Quotes): Expr[CB[T]] = |
| 93 | + '{ CBM.map(${point.transformed})($mapping) } |
| 94 | + |
| 95 | + override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] = |
| 96 | + copy(prev = exprs ++: prev) |
| 97 | + |
| 98 | + |
| 99 | + |
| 100 | +case class FlatMappedCpsExpr[S:Type, T:Type]( |
| 101 | + prev: Seq[Expr[?]], |
| 102 | + point: CpsExpr[S], |
| 103 | + mapping: Expr[S => CB[T]] |
| 104 | + ) extends AsyncCpsExpr[T](prev): |
| 105 | + |
| 106 | + override def fLast(using Quotes): Expr[CB[T]] = |
| 107 | + '{ CBM.flatMap(${point.transformed})($mapping) } |
| 108 | + |
| 109 | + override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] = |
| 110 | + copy(prev = exprs ++: prev) |
| 111 | + |
| 112 | + |
| 113 | +class ValRhsFlatMappedCpsExpr[T:Type, V:Type](using thisQuotes: Quotes) |
| 114 | + ( |
| 115 | + prev: Seq[Expr[?]], |
| 116 | + oldValDef: quotes.reflect.ValDef, |
| 117 | + cpsRhs: CpsExpr[V], |
| 118 | + next: CpsExpr[T] |
| 119 | + ) |
| 120 | + extends AsyncCpsExpr[T](prev) { |
| 121 | + |
| 122 | + override def fLast(using Quotes):Expr[CB[T]] = |
| 123 | + import quotes.reflect._ |
| 124 | + next.syncOrigin match |
| 125 | + case Some(nextOrigin) => |
| 126 | + // owner of this block is incorrect |
| 127 | + '{ |
| 128 | + CBM.map(${cpsRhs.transformed})((vx:V) => |
| 129 | + ${buildAppendBlockExpr('vx, nextOrigin)}) |
| 130 | + } |
| 131 | + case None => |
| 132 | + '{ |
| 133 | + CBM.flatMap(${cpsRhs.transformed})((v:V)=> |
| 134 | + ${buildAppendBlockExpr('v, next.transformed)}) |
| 135 | + } |
| 136 | + |
| 137 | + |
| 138 | + override def prependExprs(exprs: Seq[Expr[?]]): CpsExpr[T] = |
| 139 | + ValRhsFlatMappedCpsExpr(using thisQuotes)(exprs ++: prev,oldValDef,cpsRhs,next) |
| 140 | + |
| 141 | + override def append[A:quoted.Type](e: CpsExpr[A])(using Quotes) = |
| 142 | + ValRhsFlatMappedCpsExpr(using thisQuotes)(prev,oldValDef,cpsRhs,next.append(e)) |
| 143 | + |
| 144 | + |
| 145 | + private def buildAppendBlock(using Quotes)(rhs:quotes.reflect.Term, |
| 146 | + exprTerm:quotes.reflect.Term): quotes.reflect.Term = |
| 147 | + import quotes.reflect._ |
| 148 | + import scala.quoted.Expr |
| 149 | + |
| 150 | + val castedOldValDef = oldValDef.asInstanceOf[quotes.reflect.ValDef] |
| 151 | + val valDef = ValDef(castedOldValDef.symbol, Some(rhs.changeOwner(castedOldValDef.symbol))) |
| 152 | + exprTerm match |
| 153 | + case Block(stats,last) => |
| 154 | + Block(valDef::stats, last) |
| 155 | + case other => |
| 156 | + Block(valDef::Nil,other) |
| 157 | + |
| 158 | + private def buildAppendBlockExpr[A:Type](using Quotes)(rhs: Expr[V], expr:Expr[A]):Expr[A] = |
| 159 | + import quotes.reflect._ |
| 160 | + buildAppendBlock(rhs.asTerm,expr.asTerm).asExprOf[A] |
| 161 | + |
| 162 | +} |
| 163 | + |
| 164 | + |
| 165 | +object CpsExpr: |
| 166 | + |
| 167 | + def sync[T:Type](f: Expr[T]): CpsExpr[T] = |
| 168 | + GenericSyncCpsExpr[T](Seq(), f) |
| 169 | + |
| 170 | + def async[T:Type](f: Expr[CB[T]]): CpsExpr[T] = |
| 171 | + GenericAsyncCpsExpr[T](Seq(), f) |
| 172 | + |
| 173 | + |
| 174 | +object Async: |
| 175 | + |
| 176 | + transparent inline def transform[T](inline expr: T) = ${ |
| 177 | + Async.transformImpl[T]('expr) |
| 178 | + } |
| 179 | + |
| 180 | + def transformImpl[T:Type](f: Expr[T])(using Quotes): Expr[CB[T]] = |
| 181 | + import quotes.reflect._ |
| 182 | + // println(s"before transformed: ${f.show}") |
| 183 | + val cpsExpr = rootTransform[T](f) |
| 184 | + val r = '{ CBM.spawn(${cpsExpr.transformed}) } |
| 185 | + // println(s"transformed value: ${r.show}") |
| 186 | + r |
| 187 | + |
| 188 | + def rootTransform[T:Type](f: Expr[T])(using Quotes): CpsExpr[T] = { |
| 189 | + import quotes.reflect._ |
| 190 | + f match |
| 191 | + case '{ while ($cond) { $repeat } } => |
| 192 | + val cpsRepeat = rootTransform(repeat.asExprOf[Unit]) |
| 193 | + CpsExpr.async('{ |
| 194 | + def _whilefun():CB[Unit] = |
| 195 | + if ($cond) { |
| 196 | + ${cpsRepeat.flatMap('{(x:Unit) => _whilefun()}).transformed} |
| 197 | + } else { |
| 198 | + CBM.pure(()) |
| 199 | + } |
| 200 | + _whilefun() |
| 201 | + }.asExprOf[CB[T]]) |
| 202 | + case _ => |
| 203 | + val fTree = f.asTerm |
| 204 | + fTree match { |
| 205 | + case fun@Apply(fun1@TypeApply(obj2,targs2), args1) => |
| 206 | + if (obj2.symbol.name == "await") { |
| 207 | + val awaitArg = args1.head |
| 208 | + CpsExpr.async(awaitArg.asExprOf[CB[T]]) |
| 209 | + } else { |
| 210 | + ??? |
| 211 | + } |
| 212 | + case Assign(left,right) => |
| 213 | + left match |
| 214 | + case id@Ident(x) => |
| 215 | + right.tpe.widen.asType match |
| 216 | + case '[r] => |
| 217 | + val cpsRight = rootTransform(right.asExprOf[r]) |
| 218 | + CpsExpr.async( |
| 219 | + cpsRight.map[T]( |
| 220 | + '{ (x:r) => ${Assign(left,'x.asTerm).asExprOf[T] } |
| 221 | + }).transformed ) |
| 222 | + case _ => ??? |
| 223 | + case Block(prevs,last) => |
| 224 | + val rPrevs = prevs.map{ p => |
| 225 | + p match |
| 226 | + case v@ValDef(vName,vtt,optRhs) => |
| 227 | + optRhs.get.tpe.widen.asType match |
| 228 | + case '[l] => |
| 229 | + val cpsRight = rootTransform(optRhs.get.asExprOf[l]) |
| 230 | + ValRhsFlatMappedCpsExpr(using quotes)(Seq(), v, cpsRight, CpsExpr.sync('{})) |
| 231 | + case t: Term => |
| 232 | + // TODO: rootTransform |
| 233 | + t.asExpr match |
| 234 | + case '{ $p: tp } => |
| 235 | + rootTransform(p) |
| 236 | + case other => |
| 237 | + printf(other.show) |
| 238 | + throw RuntimeException(s"can't handle term in block: $other") |
| 239 | + case other => |
| 240 | + printf(other.show) |
| 241 | + throw RuntimeException(s"unknown tree type in block: $other") |
| 242 | + } |
| 243 | + val rLast = rootTransform(last.asExprOf[T]) |
| 244 | + val blockResult = rPrevs.foldRight(rLast)((e,s) => e.append(s)) |
| 245 | + val retval = CpsExpr.async(blockResult.transformed) |
| 246 | + retval |
| 247 | + //BlockTransform(cpsCtx).run(prevs,last) |
| 248 | + case id@Ident(name) => |
| 249 | + CpsExpr.sync(id.asExprOf[T]) |
| 250 | + case tid@Typed(Ident(name), tp) => |
| 251 | + CpsExpr.sync(tid.asExprOf[T]) |
| 252 | + case matchTerm@Match(scrutinee, caseDefs) => |
| 253 | + val nCases = caseDefs.map{ old => |
| 254 | + CaseDef.copy(old)(old.pattern, old.guard, rootTransform(old.rhs.asExprOf[T]).transformed.asTerm) |
| 255 | + } |
| 256 | + CpsExpr.async(Match(scrutinee, nCases).asExprOf[CB[T]]) |
| 257 | + case inlinedTerm@ Inlined(call,List(),body) => |
| 258 | + rootTransform(body.asExprOf[T]) |
| 259 | + case constTerm@Literal(_)=> |
| 260 | + CpsExpr.sync(constTerm.asExprOf[T]) |
| 261 | + case _ => |
| 262 | + throw RuntimeException(s"language construction is not supported: ${fTree}") |
| 263 | + } |
| 264 | + } |
| 265 | + |
0 commit comments