Skip to content

Fix desugaring of context bounds in extensions #11892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 75 additions & 65 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,32 @@ object desugar {
* def f$default$2[T](x: Int) = x + "m"
*/
private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(using Context): Tree =
addDefaultGetters(elimContextBounds(meth, isPrimaryConstructor))
addDefaultGetters(elimContextBounds(Nil, meth, isPrimaryConstructor))

private def elimContextBounds(meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef =
private def defDef(extParamss: List[ParamClause], meth: DefDef)(using Context): Tree =
addDefaultGetters(elimContextBounds(extParamss, meth, false))

private def elimContextBounds(extParamss: List[ParamClause], meth: DefDef, isPrimaryConstructor: Boolean)(using Context): DefDef =
val DefDef(_, paramss, tpt, rhs) = meth

rhs match
case MacroTree(call) =>
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
case _ =>
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss =
elimContextBounds(extParamss, isPrimaryConstructor, true) ++
elimContextBounds(paramss, isPrimaryConstructor, false)
)
end elimContextBounds

private def elimContextBounds(paramss: List[ParamClause], isPrimaryConstructor: Boolean, ext: Boolean)(using Context): List[ParamClause] =
val evidenceParamBuf = ListBuffer[ValDef]()

def desugarContextBounds(rhs: Tree): Tree = rhs match
case ContextBounds(tbounds, cxbounds) =>
val iflag = if sourceVersion.isAtLeast(`future`) then Given else Implicit
val iflag = if ext || sourceVersion.isAtLeast(`future`) then Given else Implicit
evidenceParamBuf ++= makeImplicitParameters(
cxbounds, iflag, forPrimaryConstructor = isPrimaryConstructor)
tbounds
Expand All @@ -240,15 +257,7 @@ object desugar {
tparam => cpy.TypeDef(tparam)(rhs = desugarContextBounds(tparam.rhs))
}(identity)

rhs match
case MacroTree(call) =>
cpy.DefDef(meth)(rhs = call).withMods(meth.mods | Macro | Erased)
case _ =>
addEvidenceParams(
cpy.DefDef(meth)(
name = normalizeName(meth, tpt).asTermName,
paramss = paramssNoContextBounds),
evidenceParamBuf.toList)
addEvidenceParams(paramssNoContextBounds, evidenceParamBuf.toList)
end elimContextBounds

def addDefaultGetters(meth: DefDef)(using Context): Tree =
Expand Down Expand Up @@ -348,22 +357,22 @@ object desugar {
adaptToExpectedTpt(tree)
}

/** Add all evidence parameters in `params` as implicit parameters to `meth`.
* If the parameters of `meth` end in an implicit parameter list or using clause,
/** Add all evidence parameters in `params` as implicit parameters to `paramss`.
* If the parameters of `paramss` end in an implicit parameter list or using clause,
* evidence parameters are added in front of that list. Otherwise they are added
* as a separate parameter clause.
*/
private def addEvidenceParams(meth: DefDef, params: List[ValDef])(using Context): DefDef =
params match
case Nil =>
meth
case evidenceParams =>
val paramss1 = meth.paramss.reverse match
case ValDefs(vparams @ (vparam :: _)) :: rparamss if vparam.mods.isOneOf(GivenOrImplicit) =>
((evidenceParams ++ vparams) :: rparamss).reverse
case _ =>
meth.paramss :+ evidenceParams
cpy.DefDef(meth)(paramss = paramss1)

private def addEvidenceParams(paramss: List[ParamClause], params: List[ValDef])(using Context): List[ParamClause] =
paramss.reverse match
case ValDefs(vparams @ (vparam :: _)) :: rparamss if vparam.mods.isOneOf(GivenOrImplicit) =>
((params ++ vparams) :: rparamss).reverse
case _ =>
params match
case Nil =>
paramss
case evidenceParams =>
paramss :+ evidenceParams

/** The implicit evidence parameters of `meth`, as generated by `desugar.defDef` */
private def evidenceParams(meth: DefDef)(using Context): List[ValDef] =
Expand Down Expand Up @@ -487,9 +496,8 @@ object desugar {
case ddef: DefDef if ddef.name.isConstructorName =>
decompose(
defDef(
addEvidenceParams(
cpy.DefDef(ddef)(paramss = joinParams(constrTparams, ddef.paramss)),
evidenceParams(constr1).map(toDefParam(_, keepAnnotations = false, keepDefault = false)))))
cpy.DefDef(ddef)(paramss = addEvidenceParams(joinParams(constrTparams, ddef.paramss),
evidenceParams(constr1).map(toDefParam(_, keepAnnotations = false, keepDefault = false))))))
case stat =>
stat
}
Expand Down Expand Up @@ -899,43 +907,45 @@ object desugar {
/** Transform extension construct to list of extension methods */
def extMethods(ext: ExtMethods)(using Context): Tree = flatTree {
for mdef <- ext.methods yield
defDef(
cpy.DefDef(mdef)(
name = normalizeName(mdef, ext).asTermName,
paramss =
if mdef.name.isRightAssocOperatorName then
val (typaramss, paramss) = mdef.paramss.span(isTypeParamClause) // first extract type parameters

paramss match
case params :: paramss1 => // `params` must have a single parameter and without `given` flag

def badRightAssoc(problem: String) =
report.error(i"right-associative extension method $problem", mdef.srcPos)
ext.paramss ++ mdef.paramss

params match
case ValDefs(vparam :: Nil) =>
if !vparam.mods.is(Given) then
// we merge the extension parameters with the method parameters,
// swapping the operator arguments:
// e.g.
// extension [A](using B)(c: C)(using D)
// def %:[E](f: F)(g: G)(using H): Res = ???
// will be encoded as
// def %:[A](using B)[E](f: F)(c: C)(using D)(g: G)(using H): Res = ???
val (leadingUsing, otherExtParamss) = ext.paramss.span(isUsingOrTypeParamClause)
leadingUsing ::: typaramss ::: params :: otherExtParamss ::: paramss1
else
badRightAssoc("cannot start with using clause")
case _ =>
badRightAssoc("must start with a single parameter")
case _ =>
// no value parameters, so not an infix operator.
ext.paramss ++ mdef.paramss
else
ext.paramss ++ mdef.paramss
).withMods(mdef.mods | ExtensionMethod)
)
def ret(ess: List[ParamClause], mss: List[ParamClause]) =
defDef(
ess,
cpy.DefDef(mdef)(
name = normalizeName(mdef, ext).asTermName,
paramss = mss
).withMods(mdef.mods | ExtensionMethod)
)
if mdef.name.isRightAssocOperatorName then
val (typaramss, paramss) = mdef.paramss.span(isTypeParamClause) // first extract type parameters

paramss match
case params :: paramss1 => // `params` must have a single parameter and without `given` flag

def badRightAssoc(problem: String) =
report.error(i"right-associative extension method $problem", mdef.srcPos)
ret(ext.paramss, mdef.paramss)

params match
case ValDefs(vparam :: Nil) =>
if !vparam.mods.is(Given) then
// we merge the extension parameters with the method parameters,
// swapping the operator arguments:
// e.g.
// extension [A](using B)(c: C)(using D)
// def %:[E](f: F)(g: G)(using H): Res = ???
// will be encoded as
// def %:[A](using B)[E](f: F)(c: C)(using D)(g: G)(using H): Res = ???
val (leadingUsing, otherExtParamss) = ext.paramss.span(isUsingOrTypeParamClause)
ret(leadingUsing ::: typaramss ::: params :: Nil, otherExtParamss ::: paramss1)
else
badRightAssoc("cannot start with using clause")
case _ =>
badRightAssoc("must start with a single parameter")
case _ =>
// no value parameters, so not an infix operator.
ret(ext.paramss, mdef.paramss)
else
ret(ext.paramss, mdef.paramss)
}

/** Transforms
Expand Down
28 changes: 12 additions & 16 deletions tests/neg/i10901.check
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
-- [E008] Not Found Error: tests/neg/i10901.scala:45:38 ----------------------------------------------------------------
45 | val pos1: Point2D[Int,Double] = x º y // error
| ^^^
| value º is not a member of object BugExp4Point2D.IntT.
| An extension method was tried, but could not be fully constructed:
|value º is not a member of object BugExp4Point2D.IntT.
|An extension method was tried, but could not be fully constructed:
|
| º(x) failed with
| º(x) failed with
|
| Ambiguous overload. The overloaded alternatives of method º in object dsl with types
| [T1, T2]
| (x: BugExp4Point2D.ColumnType[T1])
| (y: BugExp4Point2D.ColumnType[T2])
| (implicit evidence$7: Numeric[T1], evidence$8: Numeric[T2]): BugExp4Point2D.Point2D[T1, T2]
| [T1, T2]
| (x: T1)
| (y: BugExp4Point2D.ColumnType[T2])
| (implicit evidence$5: Numeric[T1], evidence$6: Numeric[T2]): BugExp4Point2D.Point2D[T1, T2]
| both match arguments ((x : BugExp4Point2D.IntT.type))
| Ambiguous overload. The overloaded alternatives of method º in object dsl with types
| [T1, T2]
| (x: BugExp4Point2D.ColumnType[T1])
| (y: BugExp4Point2D.ColumnType[T2])(using x$3: Numeric[T1], x$4: Numeric[T2]): BugExp4Point2D.Point2D[T1, T2]
| [T1, T2]
| (x: T1)(y: BugExp4Point2D.ColumnType[T2])(using x$3: Numeric[T1], x$4: Numeric[T2]): BugExp4Point2D.Point2D[T1, T2]
| both match arguments ((x : BugExp4Point2D.IntT.type))
-- [E008] Not Found Error: tests/neg/i10901.scala:48:38 ----------------------------------------------------------------
48 | val pos4: Point2D[Int,Double] = x º 201.1 // error
| ^^^
Expand All @@ -26,9 +23,8 @@
|
| Ambiguous overload. The overloaded alternatives of method º in object dsl with types
| [T1, T2]
| (x: BugExp4Point2D.ColumnType[T1])
| (y: T2)(implicit evidence$9: Numeric[T1], evidence$10: Numeric[T2]): BugExp4Point2D.Point2D[T1, T2]
| [T1, T2](x: T1)(y: T2)(implicit evidence$3: Numeric[T1], evidence$4: Numeric[T2]): BugExp4Point2D.Point2D[T1, T2]
| (x: BugExp4Point2D.ColumnType[T1])(y: T2)(using x$3: Numeric[T1], x$4: Numeric[T2]): BugExp4Point2D.Point2D[T1, T2]
| [T1, T2](x: T1)(y: T2)(using x$3: Numeric[T1], x$4: Numeric[T2]): BugExp4Point2D.Point2D[T1, T2]
| both match arguments ((x : BugExp4Point2D.IntT.type))
-- [E008] Not Found Error: tests/neg/i10901.scala:62:16 ----------------------------------------------------------------
62 | val y = "abc".foo // error
Expand Down
12 changes: 6 additions & 6 deletions tests/neg/i10901.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,27 @@ object BugExp4Point2D {
object dsl {


extension [T1:Numeric, T2:Numeric](x: T1)
extension [T1, T2](x: T1)

// N - N
@targetName("point2DConstant")
def º(y: T2): Point2D[T1,T2] = ???
def º(y: T2)(using Numeric[T1], Numeric[T2]): Point2D[T1,T2] = ???


// N - C
@targetName("point2DConstantData")
def º(y: ColumnType[T2]): Point2D[T1,T2] = ???
def º(y: ColumnType[T2])(using Numeric[T1], Numeric[T2]): Point2D[T1,T2] = ???



extension [T1:Numeric, T2:Numeric](x: ColumnType[T1])
extension [T1, T2](x: ColumnType[T1])
// C - C
@targetName("point2DData")
def º(y: ColumnType[T2]): Point2D[T1,T2] = ???
def º(y: ColumnType[T2])(using Numeric[T1], Numeric[T2]): Point2D[T1,T2] = ???

// C - N
@targetName("point2DDataConstant")
def º(y: T2): Point2D[T1,T2] = ???
def º(y: T2)(using Numeric[T1], Numeric[T2]): Point2D[T1,T2] = ???


}
Expand Down
2 changes: 1 addition & 1 deletion tests/neg/missing-implicit1.check
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
-- Error: tests/neg/missing-implicit1.scala:23:42 ----------------------------------------------------------------------
23 | List(1, 2, 3).traverse(x => Option(x)) // error
| ^
|no implicit argument of type testObjectInstance.Zip[Option] was found for an implicit parameter of method traverse in trait Traverse
|no implicit argument of type testObjectInstance.Zip[Option] was found for parameter x$3 of method traverse in trait Traverse
|
|The following import might fix the problem:
|
Expand Down
2 changes: 1 addition & 1 deletion tests/neg/missing-implicit1.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
object testObjectInstance:
trait Zip[F[_]]
trait Traverse[F[_]] {
extension [A, B, G[_] : Zip](fa: F[A]) def traverse(f: A => G[B]): G[F[B]]
extension [A, B, G[_]](fa: F[A]) def traverse(f: A => G[B])(using Zip[G]): G[F[B]]
}

object instances {
Expand Down
6 changes: 3 additions & 3 deletions tests/neg/missing-implicit4.check
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
-- Error: tests/neg/missing-implicit4.scala:20:42 ----------------------------------------------------------------------
20 | List(1, 2, 3).traverse(x => Option(x)) // error
| ^
|no implicit argument of type Zip[Option] was found for an implicit parameter of method traverse in trait Traverse
| no implicit argument of type Zip[Option] was found for parameter x$3 of method traverse in trait Traverse
|
|The following import might fix the problem:
| The following import might fix the problem:
|
| import instances.zipOption
| import instances.zipOption
|
2 changes: 1 addition & 1 deletion tests/neg/missing-implicit4.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
def testLocalInstance =
trait Zip[F[_]]
trait Traverse[F[_]] {
extension [A, B, G[_] : Zip](fa: F[A]) def traverse(f: A => G[B]): G[F[B]]
extension [A, B, G[_]](fa: F[A]) def traverse(f: A => G[B])(using Zip[G]): G[F[B]]
}

object instances {
Expand Down
6 changes: 3 additions & 3 deletions tests/pos/i11358.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ object Test:
def test7 = +++(IArray(1, 2))[Int](IArray(2, 3))
def test8 = +++(IArray(1, 2))[Int](List(2, 3))

extension [A: reflect.ClassTag](arr: IArray[A])
def +++[B >: A: reflect.ClassTag](suffix: IArray[B]): IArray[B] = ???
def +++[B >: A: reflect.ClassTag](suffix: IterableOnce[B]): IArray[B] = ???
extension [A](arr: IArray[A])
def +++[B >: A](suffix: IArray[B])(using reflect.ClassTag[A], reflect.ClassTag[B]): IArray[B] = ???
def +++[B >: A](suffix: IterableOnce[B])(using reflect.ClassTag[A], reflect.ClassTag[B]): IArray[B] = ???
10 changes: 10 additions & 0 deletions tests/pos/i11586.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
type Conv[T] = [X] =>> X => T

trait SemiGroup[T]:
extension [U: Conv[T]](x: U)
def combine(y: T): T
extension (x: T)
def combine[U: Conv[T]](y: U): T

trait Q[T, R: SemiGroup] extends SemiGroup[T]:
def res(x: R, y: R) = x.combine(y)