Skip to content

Adding Interweaved Type and Term Clauses to function definitions #13836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
13 changes: 13 additions & 0 deletions TypesEverywhere.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import scala.annotation.targetName
class TypesEverywhere{
def f1[T](x: T): T = x
def f2[T][U](x: T, y: U): (T, U) = (x, y)
def f3[T](x: T)[U <: x.type](y: U): (T, U) = (x, y)



def f4[T](x: T)(y: T) = (x,y)
@targetName("f5") def f4[T](x: T, y: T) = (x,y)
}


67 changes: 59 additions & 8 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2890,6 +2890,55 @@ object Parsers {

/* -------- PARAMETERS ------------------------------------------- */

def typeOrTermParamClause(nparams: Int, // number of parameters preceding this clause
ofClass: Boolean = false, // owner is a class
ofCaseClass: Boolean = false, // owner is a case class
prefix: Boolean = false, // clause precedes name of an extension method
givenOnly: Boolean = false, // only given parameters allowed
firstClause: Boolean = false, // clause is the first in regular list of clauses
ownerKind: ParamOwner.Value
): List[TypeDef] | List[ValDef] =
if (in.token == LPAREN)
paramClause(nparams, ofClass, ofCaseClass, prefix, givenOnly, firstClause)
else if (in.token == LBRACKET)
typeParamClause(ownerKind)
else
Nil

end typeOrTermParamClause

def typeOrTermParamClauses(
ownerKind: ParamOwner.Value,
ofClass: Boolean = false,
ofCaseClass: Boolean = false,
givenOnly: Boolean = false,
numLeadParams: Int = 0
): List[List[TypeDef] | List[ValDef]] =

def recur(firstClause: Boolean, nparams: Int): List[List[TypeDef] | List[ValDef]] =
newLineOptWhenFollowedBy(LPAREN)
newLineOptWhenFollowedBy(LBRACKET) //I have doubts this works //TODO: test this
if in.token == LPAREN then
val paramsStart = in.offset
val params = paramClause(
nparams,
ofClass = ofClass,
ofCaseClass = ofCaseClass,
givenOnly = givenOnly,
firstClause = firstClause)
val lastClause = params.nonEmpty && params.head.mods.flags.is(Implicit)
params :: (
if lastClause then Nil
else recur(firstClause = false, nparams + params.length))
else if in.token == LBRACKET then
typeParamClause(ownerKind) :: recur(firstClause, nparams)
else Nil
end recur

recur(firstClause = true, nparams = numLeadParams)
end typeOrTermParamClauses


/** ClsTypeParamClause::= ‘[’ ClsTypeParam {‘,’ ClsTypeParam} ‘]’
* ClsTypeParam ::= {Annotation} [‘+’ | ‘-’]
* id [HkTypeParamClause] TypeParamBounds
Expand Down Expand Up @@ -3348,8 +3397,9 @@ object Parsers {
val mods1 = addFlag(mods, Method)
val ident = termIdent()
var name = ident.name.asTermName
val tparams = typeParamClauseOpt(ParamOwner.Def)
val vparamss = paramClauses(numLeadParams = numLeadParams)
val paramss = typeOrTermParamClauses(ParamOwner.Def, numLeadParams = numLeadParams)
//val tparams = typeParamClauseOpt(ParamOwner.Def)
//val vparamss = paramClauses(numLeadParams = numLeadParams)
var tpt = fromWithinReturnType { typedOpt() }
if (migrateTo3) newLineOptWhenFollowedBy(LBRACE)
val rhs =
Expand All @@ -3367,7 +3417,8 @@ object Parsers {
accept(EQUALS)
expr()

val ddef = DefDef(name, joinParams(tparams, vparamss), tpt, rhs)
//val ddef = DefDef(name, joinParams(tparams, vparamss), tpt, rhs)
val ddef = DefDef(name, paramss, tpt, rhs)
if (isBackquoted(ident)) ddef.pushAttachment(Backquoted, ())
finalizeDef(ddef, mods1, start)
}
Expand Down Expand Up @@ -3396,7 +3447,7 @@ object Parsers {
argumentExprss(mkApply(Ident(nme.CONSTRUCTOR), argumentExprs()))
}

/** TypeDcl ::= id [TypeParamClause] {FunParamClause} TypeBounds [‘=’ Type]
/** TypeDcl ::= id [TypeParamClause] {FunParamClause} TypeBounds [‘=’ Type] //TODO: change to {ParamClauses} ?
*/
def typeDefOrDcl(start: Offset, mods: Modifiers): Tree = {
newLinesOpt()
Expand Down Expand Up @@ -3489,7 +3540,7 @@ object Parsers {
val tparams = typeParamClauseOpt(ParamOwner.Class)
val cmods = fromWithinClassConstr(constrModsOpt())
val vparamss = paramClauses(ofClass = true, ofCaseClass = isCaseClass)
makeConstructor(tparams, vparamss).withMods(cmods)
makeConstructor(tparams, vparamss).withMods(cmods) //TODO: Reafactor to take List[List[ValDef] | List[TypeDef]]
}

/** ConstrMods ::= {Annotation} [AccessModifier]
Expand Down Expand Up @@ -3578,7 +3629,7 @@ object Parsers {
}

/** GivenDef ::= [GivenSig] (AnnotType [‘=’ Expr] | StructuralInstance)
* GivenSig ::= [id] [DefTypeParamClause] {UsingParamClauses} ‘:’
* GivenSig ::= [id] [DefTypeParamClause] {UsingParamClauses} ‘:’ //TODO: Change to {Params}
*/
def givenDef(start: Offset, mods: Modifiers, givenMod: Mod) = atSpan(start, nameStart) {
var mods1 = addMod(mods, givenMod)
Expand All @@ -3590,7 +3641,7 @@ object Parsers {
newLineOpt()
val vparamss =
if in.token == LPAREN && in.lookahead.isIdent(nme.using)
then paramClauses(givenOnly = true)
then paramClauses(givenOnly = true)
else Nil
newLinesOpt()
val noParams = tparams.isEmpty && vparamss.isEmpty
Expand Down Expand Up @@ -3626,7 +3677,7 @@ object Parsers {
}

/** Extension ::= ‘extension’ [DefTypeParamClause] {UsingParamClause} ‘(’ DefParam ‘)’
* {UsingParamClause} ExtMethods
* {UsingParamClause} ExtMethods //TODO: Change to {Params} ?
*/
def extension(): ExtMethods =
val start = in.skipToken()
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1085,8 +1085,8 @@ trait Applications extends Compatibility {
val typedArgs = if (isNamed) typedNamedArgs(tree.args) else tree.args.mapconserve(typedType(_))
record("typedTypeApply")
typedExpr(tree.fun, PolyProto(typedArgs, pt)) match {
case _: TypeApply if !ctx.isAfterTyper =>
errorTree(tree, "illegal repeated type application")
/* case _: TypeApply if !ctx.isAfterTyper => //TODO: assess removing this is okay
errorTree(tree, "illegal repeated type application") */
case typedFn =>
typedFn.tpe.widen match {
case pt: PolyType =>
Expand Down
8 changes: 8 additions & 0 deletions tests/neg/_TypeInterweaving/ab.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

given String = ""
given Double = 0

def ab[A][B](x: A)(using B): B = summon[B]

def test =
ab[Int](0: Int) // error
2 changes: 2 additions & 0 deletions tests/neg/_TypeInterweaving/nameCollision.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
def f[T](x: T)[U](y: U) = (x,y)
def f[T](x: T, y: T) = (x,y) // error
5 changes: 5 additions & 0 deletions tests/neg/_TypeInterweaving/params.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
class Params{
def bar[T](x: T)[T]: String = ??? // error
def zoo(x: Int)[T, U](x: U): T = ??? // error
def bbb[T <: U](x: U)[U]: U = ??? // error // error
}
2 changes: 2 additions & 0 deletions tests/neg/_TypeInterweaving/unmatched1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

def f1[T (x: T)] = ??? // error
5 changes: 5 additions & 0 deletions tests/neg/_TypeInterweaving/unmatched2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@




def f2[T(x: T) = ??? // error
3 changes: 3 additions & 0 deletions tests/neg/_TypeInterweaving/unmatched3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@


def f3(x: Any[)T] = ??? // error
8 changes: 8 additions & 0 deletions tests/pos/_typeInterweaving/ba.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

given String = ""
given Double = 0

def ba[B][A](x: A)(using B): B = summon[B]

def test =
ba[String](0)
6 changes: 6 additions & 0 deletions tests/pos/_typeInterweaving/chainedParams.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

class Chain{
type Tail <: Chain
}

def f[C1 <: Chain](c1: C1)[C2 <: c1.Tail](c2: C2)[C3 <: c2.Tail](c3: C3): c3.Tail = ???
6 changes: 6 additions & 0 deletions tests/pos/_typeInterweaving/class.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

class C[T](x: T)[U](y: U){
def pair: (T,U) = (x,y)
def first = x
def second = y
}
4 changes: 4 additions & 0 deletions tests/pos/_typeInterweaving/classless.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def f1[T][U](x: T, y: U): (T, U) = (x, y)
def f2[T](x: T)[U](y: U): (T, U) = (x, y)

@main def test = f2(0)[String]("Hello")
16 changes: 16 additions & 0 deletions tests/pos/_typeInterweaving/functorCurrying.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//taken from https://dotty.epfl.ch/docs/reference/contextual/type-classes.html
//at version 3.1.1-RC1-bin-20210930-01f040b-NIGHTLY
//modified to have type currying
trait Functor[F[_]]:
def map[A][B](x: F[A], f: A => B): F[B]


given Functor[List] with
def map[A][B](x: List[A], f: A => B): List[B] =
x.map(f)

def assertTransformation[F[_]: Functor][A][B](expected: F[B], original: F[A], mapping: A => B): Unit =
assert(expected == summon[Functor[F]].map(original, mapping))

@main def test =
assertTransformation(List("a1", "b1"), List("a", "b"), elt => s"${elt}1")
16 changes: 16 additions & 0 deletions tests/pos/_typeInterweaving/functorInterweaving.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
//taken from https://dotty.epfl.ch/docs/reference/contextual/type-classes.html
//at version 3.1.1-RC1-bin-20210930-01f040b-NIGHTLY
//modified to have type interveawing
trait Functor[F[_]]:
def map[A](x: F[A])[B](f: A => B): F[B]


given Functor[List] with
def map[A](x: List[A])[B](f: A => B): List[B] =
x.map(f)

def assertTransformation[F[_]: Functor][A](original: F[A])[B](expected: F[B])(mapping: A => B): Unit =
assert(expected == summon[Functor[F]].map(original)(mapping))

@main def test =
assertTransformation(List("a", "b"))(List("a1", "b1")){elt => s"${elt}1"}
21 changes: 21 additions & 0 deletions tests/pos/_typeInterweaving/higherKindedReturn.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
def f_4[T]: [U] => T => T = ???
def f_3[T][U]: T => T = ???
def f_2[T][U](): T => T = ???
def f_1[T <: Int][U <: String](): T => T = ???
def f0[T <: Int][U <: String]: T => T = ???
def f1[T <: Int][U <: String]: [X <: Unit] => X => X = ???
def f2[T <: Int][U <: String](): [X <: Unit] => X => X = ???
def f3[T <: Int][U <: String]()[X <: Unit]: X => X = ???

@main def test = {
f_4[Int][String] //only one that works when lines 1088 to 1089 of Applications.scala are uncommented
f_3[Int][String]
f_2[Int][String]()
f_1[Int][String]()
f0[Int][String]
f1[Int][String]
f1[Int][Unit]
f1[Int][String][Unit]
f2[Int]()[Unit]
f3[Int]()[Unit]
}
3 changes: 3 additions & 0 deletions tests/pos/_typeInterweaving/nameCollision.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import scala.annotation.targetName
def f[T](x: T)[U](y: U) = (x,y)
@targetName("g") def f[T](x: T, y: T) = (x,y)
9 changes: 9 additions & 0 deletions tests/pos/_typeInterweaving/newline.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class Newline {
def multipleLines
[T]
(x: T)
[U]
(using (T,U))
(y: U)
= ???
}
25 changes: 25 additions & 0 deletions tests/pos/_typeInterweaving/overload.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

class A{
/*
def f0[T](x: Any) = ???
def f0[T](x: Int) = ???
*/

def f1[T][U](x: Any) = ???
def f1[T][U](x: Int) = ???

//f0(1)
f1(1)
f1("hello")

case class B[U](x: Int)
def b[U](x: Int) = B[U](x)

def f2[T]: [U] => Int => B[U] = [U] => (x: Int) => b[U](x)


f2(1)
f2[Any](1)
f2[Any][Any](1)

}
6 changes: 6 additions & 0 deletions tests/pos/_typeInterweaving/params.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class Params{
type U
def foo[T](x: T)[U >: x.type <: T][L <: List[U]](l: L): L = ???
def aaa(x: U): U = ???
def bbb[T <: U](x: U)[U]: U = ???
}
Loading