diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index c360712999e2..b892e963ea51 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -11,6 +11,7 @@ import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, D import typer.{Namer, Checking} import util.{Property, SourceFile, SourcePosition, SrcPos, Chars} import config.{Feature, Config} +import config.Feature.{sourceVersion, migrateTo3, enabled, betterForsEnabled} import config.SourceVersion.* import collection.mutable import reporting.* @@ -1803,46 +1804,81 @@ object desugar { /** Create tree for for-comprehension `` or * `` where mapName and flatMapName are chosen * corresponding to whether this is a for-do or a for-yield. - * The creation performs the following rewrite rules: + * If betterFors are enabled, the creation performs the following rewrite rules: * - * 1. + * 1. if betterFors is enabled: * - * for (P <- G) E ==> G.foreach (P => E) + * for () do E ==> E + * or + * for () yield E ==> E * - * Here and in the following (P => E) is interpreted as the function (P => E) - * if P is a variable pattern and as the partial function { case P => E } otherwise. + * (Where empty for-comprehensions are excluded by the parser) * * 2. * - * for (P <- G) yield E ==> G.map (P => E) + * for (P <- G) do E ==> G.foreach (P => E) + * + * Here and in the following (P => E) is interpreted as the function (P => E) + * if P is a variable pattern and as the partial function { case P => E } otherwise. * * 3. * + * for (P <- G) yield P ==> G + * + * If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter. + * + * for (P <- G) yield E ==> G.map (P => E) + * + * Otherwise + * + * 4. + * * for (P_1 <- G_1; P_2 <- G_2; ...) ... * ==> * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...) * - * 4. + * 5. * - * for (P <- G; E; ...) ... - * => - * for (P <- G.filter (P => E); ...) ... + * for (P <- G; if E; ...) ... + * ==> + * for (P <- G.withFilter (P => E); ...) ... * - * 5. For any N: + * 6. For any N, if betterFors is enabled: * - * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...) + * for (P <- G; P_1 = E_1; ... P_N = E_N; P1 <- G1; ...) ... * ==> - * for (TupleN(P_1, P_2, ... P_N) <- - * for (x_1 @ P_1 <- G) yield { - * val x_2 @ P_2 = E_2 + * G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...)) + * + * 7. For any N, if betterFors is enabled: + * + * for (P <- G; P_1 = E_1; ... P_N = E_N) ... + * ==> + * G.map (P => for (P_1 = E_1; ... P_N = E_N) ...) + * + * 8. For any N: + * + * for (P <- G; P_1 = E_1; ... P_N = E_N; ...) + * ==> + * for (TupleN(P, P_1, ... P_N) <- + * for (x @ P <- G) yield { + * val x_1 @ P_1 = E_2 * ... - * val x_N & P_N = E_N - * TupleN(x_1, ..., x_N) - * } ...) + * val x_N @ P_N = E_N + * TupleN(x, x_1, ..., x_N) + * }; if E; ...) * * If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated * and the variable constituting P_i is used instead of x_i * + * 9. For any N, if betterFors is enabled: + * + * for (P_1 = E_1; ... P_N = E_N; ...) + * ==> + * { + * val x_N @ P_N = E_N + * for (...) + * } + * * @param mapName The name to be used for maps (either map or foreach) * @param flatMapName The name to be used for flatMaps (either flatMap or foreach) * @param enums The enumerators in the for expression @@ -1951,7 +1987,7 @@ object desugar { case GenCheckMode.FilterAlways => false // pattern was prefixed by `case` case GenCheckMode.FilterNow | GenCheckMode.CheckAndFilter => isVarBinding(gen.pat) || isIrrefutable(gen.pat, gen.expr) case GenCheckMode.Check => true - case GenCheckMode.Ignore => true + case GenCheckMode.Ignore | GenCheckMode.Filtered => true /** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when * matched against `rhs`. @@ -1961,12 +1997,31 @@ object desugar { Select(rhs, name) } + def deepEquals(t1: Tree, t2: Tree): Boolean = + (unsplice(t1), unsplice(t2)) match + case (Ident(n1), Ident(n2)) => n1 == n2 + case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals) + case _ => false + enums match { + case Nil if betterForsEnabled => body case (gen: GenFrom) :: Nil => - Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + if betterForsEnabled + && gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type + && deepEquals(gen.pat, body) + then gen.expr // avoid a redundant map with identity + else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => val cont = makeFor(mapName, flatMapName, rest, body) Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) + case (gen: GenFrom) :: rest + if betterForsEnabled + && rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) => // possible aliases followed by a generator or end of for + val cont = makeFor(mapName, flatMapName, rest, body) + val selectName = + if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName + else mapName + Apply(rhsSelect(gen, selectName), makeLambda(gen, cont)) case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) => val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias]) val pats = valeqs map { case GenAlias(pat, _) => pat } @@ -1985,8 +2040,20 @@ object desugar { makeFor(mapName, flatMapName, vfrom1 :: rest1, body) case (gen: GenFrom) :: test :: rest => val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) - val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore) + val genFrom = GenFrom(gen.pat, filtered, if betterForsEnabled then GenCheckMode.Filtered else GenCheckMode.Ignore) makeFor(mapName, flatMapName, genFrom :: rest, body) + case GenAlias(_, _) :: _ if betterForsEnabled => + val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias]) + val pats = valeqs.map { case GenAlias(pat, _) => pat } + val rhss = valeqs.map { case GenAlias(_, rhs) => rhs } + val (defpats, ids) = pats.map(makeIdPat).unzip + val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) => + val mods = defpat match + case defTree: DefTree => defTree.mods + case _ => Modifiers() + makePatDef(valeq, mods, defpat, rhs) + } + Block(pdefs, makeFor(mapName, flatMapName, rest, body)) case _ => EmptyTree //may happen for erroneous input } diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 81228b1588d0..60309d4d83bd 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -183,7 +183,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { /** An enum to control checking or filtering of patterns in GenFrom trees */ enum GenCheckMode { - case Ignore // neither filter nor check since filtering was done before + case Ignore // neither filter nor check since pattern is trivially irrefutable + case Filtered // neither filter nor check since filtering was done before case Check // check that pattern is irrefutable case CheckAndFilter // both check and filter (transitional period starting with 3.2) case FilterNow // filter out non-matching elements if we are not in 3.2 or later diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index 8c1021e91e38..fa82f14a81fe 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -38,6 +38,7 @@ object Feature: val modularity = experimental("modularity") val betterMatchTypeExtractors = experimental("betterMatchTypeExtractors") val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions") + val betterFors = experimental("betterFors") def experimentalAutoEnableFeatures(using Context): List[TermName] = defn.languageExperimentalFeatures @@ -67,7 +68,8 @@ object Feature: (into, "Allow into modifier on parameter types"), (namedTuples, "Allow named tuples"), (modularity, "Enable experimental modularity features"), - (betterMatchTypeExtractors, "Enable better match type extractors") + (betterMatchTypeExtractors, "Enable better match type extractors"), + (betterFors, "Enable improvements in `for` comprehensions") ) // legacy language features from Scala 2 that are no longer supported. @@ -125,6 +127,8 @@ object Feature: def clauseInterleavingEnabled(using Context) = sourceVersion.isAtLeast(`3.6`) || enabled(clauseInterleaving) + def betterForsEnabled(using Context) = enabled(betterFors) + def genericNumberLiteralsEnabled(using Context) = enabled(genericNumberLiterals) def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros) diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index 37587868da58..f4a6b5b76aa0 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -2891,7 +2891,11 @@ object Parsers { /** Enumerators ::= Generator {semi Enumerator | Guard} */ - def enumerators(): List[Tree] = generator() :: enumeratorsRest() + def enumerators(): List[Tree] = + if in.featureEnabled(Feature.betterFors) then + aliasesUntilGenerator() ++ enumeratorsRest() + else + generator() :: enumeratorsRest() def enumeratorsRest(): List[Tree] = if (isStatSep) { @@ -2933,6 +2937,18 @@ object Parsers { GenFrom(pat, subExpr(), checkMode) } + def aliasesUntilGenerator(): List[Tree] = + if in.token == CASE then generator() :: Nil + else { + val pat = pattern1() + if in.token == EQUALS then + atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, subExpr()) } :: { + if (isStatSep) in.nextToken() + aliasesUntilGenerator() + } + else generatorRest(pat, casePat = false) :: Nil + } + /** ForExpr ::= ‘for’ ‘(’ Enumerators ‘)’ {nl} [‘do‘ | ‘yield’] Expr * | ‘for’ ‘{’ Enumerators ‘}’ {nl} [‘do‘ | ‘yield’] Expr * | ‘for’ Enumerators (‘do‘ | ‘yield’) Expr diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala index 7db326350fa1..3d71c0da1481 100644 --- a/library/src/scala/runtime/stdLibPatches/language.scala +++ b/library/src/scala/runtime/stdLibPatches/language.scala @@ -133,6 +133,12 @@ object language: @compileTimeOnly("`quotedPatternsWithPolymorphicFunctions` can only be used at compile time in import statements") object quotedPatternsWithPolymorphicFunctions + /** Experimental support for improvements in `for` comprehensions + * + * @see [[https://github.com/scala/improvement-proposals/pull/79]] + */ + @compileTimeOnly("`betterFors` can only be used at compile time in import statements") + object betterFors end experimental /** The deprecated object contains features that are no longer officially suypported in Scala. diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index bf652cb0ee33..88e3f2b27a84 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -8,7 +8,8 @@ object MiMaFilters { val ForwardsBreakingChanges: Map[String, Seq[ProblemFilter]] = Map( // Additions that require a new minor version of the library Build.mimaPreviousDottyVersion -> Seq( - + ProblemFilters.exclude[MissingFieldProblem]("scala.runtime.stdLibPatches.language#experimental.betterFors"), + ProblemFilters.exclude[MissingClassProblem]("scala.runtime.stdLibPatches.language$experimental$betterFors$"), ), // Additions since last LTS diff --git a/tests/run/better-fors.check b/tests/run/better-fors.check new file mode 100644 index 000000000000..8b75db2f56ad --- /dev/null +++ b/tests/run/better-fors.check @@ -0,0 +1,12 @@ +List((1,3), (1,4), (2,3), (2,4)) +List((1,2,3), (1,2,4)) +List((1,3), (1,4), (2,3), (2,4)) +List((2,3), (2,4)) +List((2,3), (2,4)) +List((1,2), (2,4)) +List(1, 2, 3) +List((2,3,6)) +List(6) +List(3, 6) +List(6) +List(2) diff --git a/tests/run/better-fors.scala b/tests/run/better-fors.scala new file mode 100644 index 000000000000..8c0bff230632 --- /dev/null +++ b/tests/run/better-fors.scala @@ -0,0 +1,105 @@ +import scala.language.experimental.betterFors + +def for1 = + for { + a = 1 + b <- List(a, 2) + c <- List(3, 4) + } yield (b, c) + +def for2 = + for + a = 1 + b = 2 + c <- List(3, 4) + yield (a, b, c) + +def for3 = + for { + a = 1 + b <- List(a, 2) + c = 3 + d <- List(c, 4) + } yield (b, d) + +def for4 = + for { + a = 1 + b <- List(a, 2) + if b > 1 + c <- List(3, 4) + } yield (b, c) + +def for5 = + for { + a = 1 + b <- List(a, 2) + c = 3 + if b > 1 + d <- List(c, 4) + } yield (b, d) + +def for6 = + for { + a = 1 + b = 2 + c <- for { + x <- List(a, b) + y = x * 2 + } yield (x, y) + } yield c + +def for7 = + for { + a <- List(1, 2, 3) + } yield a + +def for8 = + for { + a <- List(1, 2) + b = a + 1 + if b > 2 + c = b * 2 + if c < 8 + } yield (a, b, c) + +def for9 = + for { + a <- List(1, 2) + b = a * 2 + if b > 2 + } yield a + b + +def for10 = + for { + a <- List(1, 2) + b = a * 2 + } yield a + b + +def for11 = + for { + a <- List(1, 2) + b = a * 2 + if b > 2 && b % 2 == 0 + } yield a + b + +def for12 = + for { + a <- List(1, 2) + if a > 1 + } yield a + +object Test extends App { + println(for1) + println(for2) + println(for3) + println(for4) + println(for5) + println(for6) + println(for7) + println(for8) + println(for9) + println(for10) + println(for11) + println(for12) +} diff --git a/tests/run/fors.check b/tests/run/fors.check index 50f6385e5845..7b7e8d076108 100644 --- a/tests/run/fors.check +++ b/tests/run/fors.check @@ -45,6 +45,9 @@ hello world hello/1~2 hello/3~4 /1~2 /3~4 world/1~2 world/3~4 (2,1) (4,3) +testTailrec +List((4,Symbol(a)), (5,Symbol(b)), (6,Symbol(c))) + testGivens 123 456 diff --git a/tests/run/fors.scala b/tests/run/fors.scala index 682978b5b3d8..a12d0e977157 100644 --- a/tests/run/fors.scala +++ b/tests/run/fors.scala @@ -4,6 +4,9 @@ //############################################################################ +import annotation.tailrec + +@scala.annotation.experimental object Test extends App { val xs = List(1, 2, 3) val ys = List(Symbol("a"), Symbol("b"), Symbol("c")) @@ -108,6 +111,19 @@ object Test extends App { for case (x, y) <- xs do print(s"${(y, x)} "); println() } + /////////////////// elimination of map /////////////////// + + import scala.language.experimental.betterFors + + @tailrec + def pair[B](xs: List[Int], ys: List[B], n: Int): List[(Int, B)] = + if n == 0 then xs.zip(ys) + else for (x, y) <- pair(xs.map(_ + 1), ys, n - 1) yield (x, y) + + def testTailrec() = + println("\ntestTailrec") + println(pair(xs, ys, 3)) + def testGivens(): Unit = { println("\ntestGivens") @@ -141,5 +157,6 @@ object Test extends App { testOld() testNew() testFiltering() + testTailrec() testGivens() }