Skip to content

Commit ba9806b

Browse files
committed
Add improvements to for comprehensions
- Allow `for`-comprehensions to start with aliases desugaring them into valdefs in a new block - Desugar aliases into simple valdefs, instead of patterns when they are not followed by a guard - Add an experimental language flag that enables the new desugaring method
1 parent 05bde2a commit ba9806b

File tree

8 files changed

+264
-45
lines changed

8 files changed

+264
-45
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+117-44
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import NameKinds.{UniqueName, ContextBoundParamName, ContextFunctionParamName, D
1111
import typer.{Namer, Checking}
1212
import util.{Property, SourceFile, SourcePosition, SrcPos, Chars}
1313
import config.{Feature, Config}
14+
import config.Feature.{sourceVersion, migrateTo3, enabled, betterForsEnabled}
1415
import config.SourceVersion.*
1516
import collection.mutable
1617
import reporting.*
@@ -1807,7 +1808,7 @@ object desugar {
18071808
*
18081809
* 1.
18091810
*
1810-
* for (P <- G) E ==> G.foreach (P => E)
1811+
* for (P <- G) do E ==> G.foreach (P => E)
18111812
*
18121813
* Here and in the following (P => E) is interpreted as the function (P => E)
18131814
* if P is a variable pattern and as the partial function { case P => E } otherwise.
@@ -1816,11 +1817,11 @@ object desugar {
18161817
*
18171818
* for (P <- G) yield P ==> G
18181819
*
1819-
* if P is a variable or a tuple of variables and G is not a withFilter.
1820+
* If P is a variable or a tuple of variables and G is not a withFilter.
18201821
*
18211822
* for (P <- G) yield E ==> G.map (P => E)
18221823
*
1823-
* otherwise
1824+
* Otherwise
18241825
*
18251826
* 3.
18261827
*
@@ -1830,25 +1831,48 @@ object desugar {
18301831
*
18311832
* 4.
18321833
*
1833-
* for (P <- G; E; ...) ...
1834-
* =>
1835-
* for (P <- G.filter (P => E); ...) ...
1834+
* for (P <- G; if E; ...) ...
1835+
* ==>
1836+
* for (P <- G.withFilter (P => E); ...) ...
18361837
*
18371838
* 5. For any N:
18381839
*
1839-
* for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...)
1840+
* for (P <- G; P_1 = E_1; ... P_N = E_N; rest)
18401841
* ==>
1841-
* for (TupleN(P_1, P_2, ... P_N) <-
1842-
* for (x_1 @ P_1 <- G) yield {
1843-
* val x_2 @ P_2 = E_2
1842+
* G.flatMap (P => for (P_1 = E_1; ... P_N = E_N; ...)) if rest contains (<-)
1843+
* G.map (P => for (P_1 = E_1; ... P_N = E_N; ...)) otherwise
1844+
*
1845+
* 6. For any N:
1846+
*
1847+
* for (P <- G; P_1 = E_1; ... P_N = E_N; if E; ...)
1848+
* ==>
1849+
* for (TupleN(P, P_1, ... P_N) <-
1850+
* for (x @ P <- G) yield {
1851+
* val x_1 @ P_1 = E_2
18441852
* ...
1845-
* val x_N & P_N = E_N
1846-
* TupleN(x_1, ..., x_N)
1847-
* } ...)
1853+
* val x_N @ P_N = E_N
1854+
* TupleN(x, x_1, ..., x_N)
1855+
* }; if E; ...)
18481856
*
18491857
* If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated
18501858
* and the variable constituting P_i is used instead of x_i
18511859
*
1860+
* 7. For any N:
1861+
*
1862+
* for (P_1 = E_1; ... P_N = E_N; ...)
1863+
* ==>
1864+
* {
1865+
* val x_N @ P_N = E_N
1866+
* for (...)
1867+
* }
1868+
*
1869+
* 8.
1870+
* for () yield E ==> E
1871+
*
1872+
* (Where empty for-comprehensions are excluded by the parser)
1873+
*
1874+
* If the aliases are not followed by a guard, otherwise an error.
1875+
*
18521876
* @param mapName The name to be used for maps (either map or foreach)
18531877
* @param flatMapName The name to be used for flatMaps (either flatMap or foreach)
18541878
* @param enums The enumerators in the for expression
@@ -1973,37 +1997,86 @@ object desugar {
19731997
case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals)
19741998
case _ => false
19751999

1976-
enums match {
1977-
case (gen: GenFrom) :: Nil =>
1978-
if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
1979-
&& deepEquals(gen.pat, body)
1980-
then gen.expr // avoid a redundant map with identity
1981-
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
1982-
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
1983-
val cont = makeFor(mapName, flatMapName, rest, body)
1984-
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
1985-
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
1986-
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
1987-
val pats = valeqs map { case GenAlias(pat, _) => pat }
1988-
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
1989-
val (defpat0, id0) = makeIdPat(gen.pat)
1990-
val (defpats, ids) = (pats map makeIdPat).unzip
1991-
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
1992-
val mods = defpat match
1993-
case defTree: DefTree => defTree.mods
1994-
case _ => Modifiers()
1995-
makePatDef(valeq, mods, defpat, rhs)
1996-
}
1997-
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
1998-
val allpats = gen.pat :: pats
1999-
val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
2000-
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
2001-
case (gen: GenFrom) :: test :: rest =>
2002-
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
2003-
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered)
2004-
makeFor(mapName, flatMapName, genFrom :: rest, body)
2005-
case _ =>
2006-
EmptyTree //may happen for erroneous input
2000+
if betterForsEnabled then
2001+
enums match {
2002+
case Nil => body
2003+
case (gen: GenFrom) :: Nil =>
2004+
if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
2005+
&& deepEquals(gen.pat, body)
2006+
then gen.expr // avoid a redundant map with identity
2007+
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2008+
case (gen: GenFrom) :: rest
2009+
if rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) =>
2010+
val cont = makeFor(mapName, flatMapName, rest, body)
2011+
val selectName =
2012+
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
2013+
else mapName
2014+
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
2015+
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
2016+
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
2017+
val pats = valeqs map { case GenAlias(pat, _) => pat }
2018+
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
2019+
val (defpat0, id0) = makeIdPat(gen.pat)
2020+
val (defpats, ids) = (pats map makeIdPat).unzip
2021+
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
2022+
val mods = defpat match
2023+
case defTree: DefTree => defTree.mods
2024+
case _ => Modifiers()
2025+
makePatDef(valeq, mods, defpat, rhs)
2026+
}
2027+
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
2028+
val allpats = gen.pat :: pats
2029+
val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
2030+
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
2031+
case (gen: GenFrom) :: test :: rest =>
2032+
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
2033+
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered)
2034+
makeFor(mapName, flatMapName, genFrom :: rest, body)
2035+
case GenAlias(_, _) :: _ =>
2036+
val (valeqs, rest) = enums.span(_.isInstanceOf[GenAlias])
2037+
val pats = valeqs.map { case GenAlias(pat, _) => pat }
2038+
val rhss = valeqs.map { case GenAlias(_, rhs) => rhs }
2039+
val (defpats, ids) = pats.map(makeIdPat).unzip
2040+
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
2041+
val mods = defpat match
2042+
case defTree: DefTree => defTree.mods
2043+
case _ => Modifiers()
2044+
makePatDef(valeq, mods, defpat, rhs)
2045+
}
2046+
Block(pdefs, makeFor(mapName, flatMapName, rest, body))
2047+
case _ =>
2048+
EmptyTree //may happen for erroneous input
2049+
}
2050+
else {
2051+
enums match {
2052+
case (gen: GenFrom) :: Nil =>
2053+
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2054+
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
2055+
val cont = makeFor(mapName, flatMapName, rest, body)
2056+
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
2057+
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
2058+
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
2059+
val pats = valeqs map { case GenAlias(pat, _) => pat }
2060+
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
2061+
val (defpat0, id0) = makeIdPat(gen.pat)
2062+
val (defpats, ids) = (pats map makeIdPat).unzip
2063+
val pdefs = valeqs.lazyZip(defpats).lazyZip(rhss).map { (valeq, defpat, rhs) =>
2064+
val mods = defpat match
2065+
case defTree: DefTree => defTree.mods
2066+
case _ => Modifiers()
2067+
makePatDef(valeq, mods, defpat, rhs)
2068+
}
2069+
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
2070+
val allpats = gen.pat :: pats
2071+
val vfrom1 = GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
2072+
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
2073+
case (gen: GenFrom) :: test :: rest =>
2074+
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
2075+
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
2076+
makeFor(mapName, flatMapName, genFrom :: rest, body)
2077+
case _ =>
2078+
EmptyTree //may happen for erroneous input
2079+
}
20072080
}
20082081
}
20092082

compiler/src/dotty/tools/dotc/config/Feature.scala

+3
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ object Feature:
3737
val namedTuples = experimental("namedTuples")
3838
val modularity = experimental("modularity")
3939
val betterMatchTypeExtractors = experimental("betterMatchTypeExtractors")
40+
val betterFors = experimental("betterFors")
4041

4142
def experimentalAutoEnableFeatures(using Context): List[TermName] =
4243
defn.languageExperimentalFeatures
@@ -123,6 +124,8 @@ object Feature:
123124

124125
def clauseInterleavingEnabled(using Context) = enabled(clauseInterleaving)
125126

127+
def betterForsEnabled(using Context) = enabled(betterFors)
128+
126129
def genericNumberLiteralsEnabled(using Context) = enabled(genericNumberLiterals)
127130

128131
def scala2ExperimentalMacroEnabled(using Context) = enabled(scala2macros)

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ object StdNames {
433433
val asInstanceOfPM: N = "$asInstanceOf$"
434434
val assert_ : N = "assert"
435435
val assume_ : N = "assume"
436+
val betterFors: N = "betterFors"
436437
val box: N = "box"
437438
val break: N = "break"
438439
val build : N = "build"

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+17-1
Original file line numberDiff line numberDiff line change
@@ -2881,7 +2881,11 @@ object Parsers {
28812881

28822882
/** Enumerators ::= Generator {semi Enumerator | Guard}
28832883
*/
2884-
def enumerators(): List[Tree] = generator() :: enumeratorsRest()
2884+
def enumerators(): List[Tree] =
2885+
if in.featureEnabled(Feature.betterFors) then
2886+
aliasesUntilGenerator() ++ enumeratorsRest()
2887+
else
2888+
generator() :: enumeratorsRest()
28852889

28862890
def enumeratorsRest(): List[Tree] =
28872891
if (isStatSep) {
@@ -2923,6 +2927,18 @@ object Parsers {
29232927
GenFrom(pat, subExpr(), checkMode)
29242928
}
29252929

2930+
def aliasesUntilGenerator(): List[Tree] =
2931+
if in.token == CASE then generator() :: Nil
2932+
else {
2933+
val pat = pattern1()
2934+
if in.token == EQUALS then
2935+
atSpan(startOffset(pat), in.skipToken()) { GenAlias(pat, subExpr()) } :: {
2936+
if (isStatSep) in.nextToken()
2937+
aliasesUntilGenerator()
2938+
}
2939+
else generatorRest(pat, casePat = false) :: Nil
2940+
}
2941+
29262942
/** ForExpr ::= ‘for’ ‘(’ Enumerators ‘)’ {nl} [‘do‘ | ‘yield’] Expr
29272943
* | ‘for’ ‘{’ Enumerators ‘}’ {nl} [‘do‘ | ‘yield’] Expr
29282944
* | ‘for’ Enumerators (‘do‘ | ‘yield’) Expr

library/src/scala/runtime/stdLibPatches/language.scala

+7
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,13 @@ object language:
124124
*/
125125
@compileTimeOnly("`betterMatchTypeExtractors` can only be used at compile time in import statements")
126126
object betterMatchTypeExtractors
127+
128+
/** Experimental support for improvements in `for` comprehensions
129+
*
130+
* @see [[https://dotty.epfl.ch/docs/reference/experimental/better-fors]]
131+
*/
132+
@compileTimeOnly("`betterFors` can only be used at compile time in import statements")
133+
object betterFors
127134
end experimental
128135

129136
/** The deprecated object contains features that are no longer officially suypported in Scala.

tests/run/better-fors.check

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
List((1,3), (1,4), (2,3), (2,4))
2+
List((1,2,3), (1,2,4))
3+
List((1,3), (1,4), (2,3), (2,4))
4+
List((2,3), (2,4))
5+
List((2,3), (2,4))
6+
List((1,2), (2,4))
7+
List(1, 2, 3)
8+
List((2,3,6))
9+
List(6)
10+
List(3, 6)
11+
List(6)
12+
List(2)

tests/run/better-fors.scala

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import scala.language.experimental.betterFors
2+
3+
def for1 =
4+
for {
5+
a = 1
6+
b <- List(a, 2)
7+
c <- List(3, 4)
8+
} yield (b, c)
9+
10+
def for2 =
11+
for
12+
a = 1
13+
b = 2
14+
c <- List(3, 4)
15+
yield (a, b, c)
16+
17+
def for3 =
18+
for {
19+
a = 1
20+
b <- List(a, 2)
21+
c = 3
22+
d <- List(c, 4)
23+
} yield (b, d)
24+
25+
def for4 =
26+
for {
27+
a = 1
28+
b <- List(a, 2)
29+
if b > 1
30+
c <- List(3, 4)
31+
} yield (b, c)
32+
33+
def for5 =
34+
for {
35+
a = 1
36+
b <- List(a, 2)
37+
c = 3
38+
if b > 1
39+
d <- List(c, 4)
40+
} yield (b, d)
41+
42+
def for6 =
43+
for {
44+
a = 1
45+
b = 2
46+
c <- for {
47+
x <- List(a, b)
48+
y = x * 2
49+
} yield (x, y)
50+
} yield c
51+
52+
def for7 =
53+
for {
54+
a <- List(1, 2, 3)
55+
} yield a
56+
57+
def for8 =
58+
for {
59+
a <- List(1, 2)
60+
b = a + 1
61+
if b > 2
62+
c = b * 2
63+
if c < 8
64+
} yield (a, b, c)
65+
66+
def for9 =
67+
for {
68+
a <- List(1, 2)
69+
b = a * 2
70+
if b > 2
71+
} yield a + b
72+
73+
def for10 =
74+
for {
75+
a <- List(1, 2)
76+
b = a * 2
77+
} yield a + b
78+
79+
def for11 =
80+
for {
81+
a <- List(1, 2)
82+
b = a * 2
83+
if b > 2 && b % 2 == 0
84+
} yield a + b
85+
86+
def for12 =
87+
for {
88+
a <- List(1, 2)
89+
if a > 1
90+
} yield a
91+
92+
object Test extends App {
93+
println(for1)
94+
println(for2)
95+
println(for3)
96+
println(for4)
97+
println(for5)
98+
println(for6)
99+
println(for7)
100+
println(for8)
101+
println(for9)
102+
println(for10)
103+
println(for11)
104+
println(for12)
105+
}

0 commit comments

Comments
 (0)