Skip to content

Add ParamClause to allow multiple type param clauses #11074

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -256,21 +256,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
end DefDefTypeTest

object DefDef extends DefDefModule:
def apply(symbol: Symbol, rhsFn: List[TypeRepr] => List[List[Term]] => Option[Term]): DefDef =
withDefaultPos(tpd.DefDef(symbol.asTerm, prefss => {
val (tparams, vparamss) = tpd.splitArgs(prefss)
yCheckedOwners(rhsFn(tparams.map(_.tpe))(vparamss), symbol).getOrElse(tpd.EmptyTree)
}))
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, tpd.joinParams(typeParams, paramss), tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): (String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term]) =
(ddef.name.toString, ddef.typeParams, ddef.termParamss, ddef.tpt, optional(ddef.rhs))
def apply(symbol: Symbol, rhsFn: List[List[Tree]] => Option[Term]): DefDef =
withDefaultPos(tpd.DefDef(symbol.asTerm, prefss =>
yCheckedOwners(rhsFn(prefss), symbol).getOrElse(tpd.EmptyTree)
))
def copy(original: Tree)(name: String, paramss: List[ParamClause], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, paramss, tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): (String, List[ParamClause], TypeTree, Option[Term]) =
(ddef.name.toString, ddef.paramss, ddef.tpt, optional(ddef.rhs))
end DefDef

given DefDefMethods: DefDefMethods with
extension (self: DefDef)
def typeParams: List[TypeDef] = self.leadingTypeParams // TODO: adapt to multiple type parameter clauses
def paramss: List[List[ValDef]] = self.termParamss
def paramss: List[ParamClause] = self.paramss
def leadingTypeParams: List[TypeDef] = self.leadingTypeParams
def trailingParamss: List[ParamClause] = self.trailingParamss
def termParamss: List[TermParamClause] = self.termParamss
def returnTpt: TypeTree = self.tpt
def rhs: Option[Term] = optional(self.rhs)
end extension
Expand Down Expand Up @@ -750,7 +751,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
tpd.Closure(meth, tss => yCheckedOwners(rhsFn(meth, tss.head.map(withDefaultPos)), meth))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
case Block((ddef @ DefDef(_, TermParamClause(params) :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
if ddef.symbol == meth.symbol =>
Some((params, body))
case _ => None
Expand Down Expand Up @@ -1481,6 +1482,59 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
end extension
end AlternativesMethods

type ParamClause = tpd.ParamClause

object ParamClause extends ParamClauseModule

given ParamClauseMethods: ParamClauseMethods with
extension (self: ParamClause)
def params: List[ValDef] | List[TypeDef] = self
end ParamClauseMethods

type TermParamClause = List[tpd.ValDef]

given TermParamClauseTypeTest: TypeTest[ParamClause, TermParamClause] with
def unapply(x: ParamClause): Option[TermParamClause & x.type] = x match
case tpd.ValDefs(_) => Some(x.asInstanceOf[TermParamClause & x.type])
case _ => None
end TermParamClauseTypeTest

object TermParamClause extends TermParamClauseModule:
def apply(params: List[ValDef]): TermParamClause =
if yCheck then
val implicitParams = params.count(_.symbol.is(dotc.core.Flags.Implicit))
assert(implicitParams == 0 || implicitParams == params.size, "Expected all or non of parameters to be implicit")
params
def unapply(x: TermParamClause): Some[List[ValDef]] = Some(x)
end TermParamClause

given TermParamClauseMethods: TermParamClauseMethods with
extension (self: TermParamClause)
def params: List[ValDef] = self
def isImplicit: Boolean =
self.nonEmpty && self.head.symbol.is(dotc.core.Flags.Implicit)
end TermParamClauseMethods

type TypeParamClause = List[tpd.TypeDef]

given TypeParamClauseTypeTest: TypeTest[ParamClause, TypeParamClause] with
def unapply(x: ParamClause): Option[TypeParamClause & x.type] = x match
case tpd.TypeDefs(_) => Some(x.asInstanceOf[TypeParamClause & x.type])
case _ => None
end TypeParamClauseTypeTest

object TypeParamClause extends TypeParamClauseModule:
def apply(params: List[TypeDef]): TypeParamClause =
if params.isEmpty then throw IllegalArgumentException("Empty type parameters")
params
def unapply(x: TypeParamClause): Some[List[TypeDef]] = Some(x)
end TypeParamClause

given TypeParamClauseMethods: TypeParamClauseMethods with
extension (self: TypeParamClause)
def params: List[TypeDef] = self
end TypeParamClauseMethods

type Selector = untpd.ImportSelector

object Selector extends SelectorModule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,14 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
}))
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term]): DefDef =
tpd.cpy.DefDef(original)(name.toTermName, tpd.joinParams(typeParams, paramss), tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
def unapply(ddef: DefDef): (String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term]) =
(ddef.name.toString, ddef.typeParams, ddef.termParamss, ddef.tpt, optional(ddef.rhs))
// def unapply(ddef: DefDef): (String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term]) =
// (ddef.name.toString, ddef.typeParams, ddef.termParamss, ddef.tpt, optional(ddef.rhs))
end DefDef

given DefDefMethods: DefDefMethods with
extension (self: DefDef)
def typeParams: List[TypeDef] = self.leadingTypeParams // TODO: adapt to multiple type parameter clauses
def paramss: List[List[ValDef]] = self.termParamss
// def paramss: List[List[ValDef]] = self.termParamss
def returnTpt: TypeTree = self.tpt
def rhs: Option[Term] = optional(self.rhs)
end extension
Expand Down Expand Up @@ -747,12 +747,7 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
val meth = dotc.core.Symbols.newSymbol(owner, nme.ANON_FUN, Synthetic | Method, tpe)
tpd.Closure(meth, tss => yCheckedOwners(rhsFn(meth, tss.head), meth))

def unapply(tree: Block): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
if ddef.symbol == meth.symbol =>
Some((params, body))
case _ => None
}
def unapply(tree: Block): Option[(List[ValDef], Term)] = ???
end Lambda

type If = tpd.If
Expand Down
27 changes: 19 additions & 8 deletions compiler/src/scala/quoted/runtime/impl/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ object Matcher {
}.transformTree(scrutinee)(Symbol.spliceOwner)
}
val names = args.map {
case Block(List(DefDef("$anonfun", _, _, _, Some(Apply(Ident(name), _)))), _) => name
case Block(List(DefDef("$anonfun", _, _, Some(Apply(Ident(name), _)))), _) => name
case arg => arg.symbol.name
}
val argTypes = args.map(x => x.tpe.widenTermRefByName)
Expand Down Expand Up @@ -302,16 +302,19 @@ object Matcher {
tpt1 =?= tpt2 &&& treeOptMatches(rhs1, rhs2)(using rhsEnv)

/* Match def */
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
case (DefDef(_, paramss1, tpt1, Some(rhs1)), DefDef(_, paramss2, tpt2, Some(rhs2))) =>
def rhsEnv =
val paramSyms: List[(Symbol, Symbol)] =
for
(clause1, clause2) <- paramss1.zip(paramss2)
(param1, param2) <- clause1.params.zip(clause2.params)
yield
param1.symbol -> param2.symbol
val oldEnv: Env = summon[Env]
val newEnv: List[(Symbol, Symbol)] =
(scrutinee.symbol -> pattern.symbol) :: typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) :::
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
oldEnv ++ newEnv

typeParams1 =?= typeParams2
&&& matchLists(paramss1, paramss2)(_ =?= _)
matchLists(paramss1, paramss2)(_ =?= _)
&&& tpt1 =?= tpt2
&&& withEnv(rhsEnv)(rhs1 =?= rhs2)

Expand Down Expand Up @@ -343,6 +346,14 @@ object Matcher {
}
end extension

extension (scrutinee: ParamClause)
/** Check that all parameters in the clauses clauses match with =?= and concatenate the results with &&& */
private def =?= (pattern: ParamClause)(using Env)(using DummyImplicit): Matching =
(scrutinee, pattern) match
case (TermParamClause(params1), TermParamClause(params2)) => matchLists(params1, params2)(_ =?= _)
case (TypeParamClause(params1), TypeParamClause(params2)) => matchLists(params1, params2)(_ =?= _)
case _ => notMatched

/** Does the scrutenne symbol match the pattern symbol? It matches if:
* - They are the same symbol
* - The scrutinee has is in the environment and they are equivalent
Expand Down Expand Up @@ -382,7 +393,7 @@ object Matcher {
def unapply(args: List[Term]): Option[List[Ident]] =
args.foldRight(Option(List.empty[Ident])) {
case (id: Ident, Some(acc)) => Some(id :: acc)
case (Block(List(DefDef("$anonfun", Nil, List(params), Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc))
case (Block(List(DefDef("$anonfun", TermParamClause(params) :: Nil, Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc))
if params.zip(args).forall(_.symbol == _.symbol) =>
Some(id :: acc)
case _ => None
Expand Down
13 changes: 11 additions & 2 deletions compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ object Extractors {
this += ", " ++= bindings += ", " += expansion += ")"
case ValDef(name, tpt, rhs) =>
this += "ValDef(\"" += name += "\", " += tpt += ", " += rhs += ")"
case DefDef(name, typeParams, paramss, returnTpt, rhs) =>
this += "DefDef(\"" += name += "\", " ++= typeParams += ", " +++= paramss += ", " += returnTpt += ", " += rhs += ")"
case DefDef(name, paramsClauses, returnTpt, rhs) =>
this += "DefDef(\"" += name += "\", " ++= paramsClauses += ", " += returnTpt += ", " += rhs += ")"
case TypeDef(name, rhs) =>
this += "TypeDef(\"" += name += "\", " += rhs += ")"
case ClassDef(name, constr, parents, derived, self, body) =>
Expand Down Expand Up @@ -256,6 +256,11 @@ object Extractors {
else if x.isTypeDef then this += "IsTypeDefSymbol(<" += x.fullName += ">)"
else { assert(x.isNoSymbol); this += "NoSymbol()" }

def visitParamClause(x: ParamClause): this.type =
x match
case TermParamClause(params) => this += "TermParamClause(" ++= params += ")"
case TypeParamClause(params) => this += "TypeParamClause(" ++= params += ")"

def +=(x: Boolean): this.type = { sb.append(x); this }
def +=(x: Byte): this.type = { sb.append(x); this }
def +=(x: Short): this.type = { sb.append(x); this }
Expand Down Expand Up @@ -301,6 +306,10 @@ object Extractors {
def +=(x: Symbol): self.type = { visitSymbol(x); buff }
}

private implicit class ParamClauseOps(buff: self.type) {
def ++=(x: List[ParamClause]): self.type = { visitList(x, visitParamClause); buff }
}

private def visitOption[U](opt: Option[U], visit: U => this.type): this.type = opt match {
case Some(x) =>
this += "Some("
Expand Down
29 changes: 15 additions & 14 deletions compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ object SourceCode {
this += "."
printSelectors(selectors)

case cdef @ ClassDef(name, DefDef(_, targs, argss, _, _), parents, derived, self, stats) =>
case cdef @ ClassDef(name, DefDef(_, paramss, _, _), parents, derived, self, stats) =>
printDefAnnotations(cdef)

val flags = cdef.symbol.flags
Expand All @@ -155,12 +155,13 @@ object SourceCode {
else if (flags.is(Flags.Abstract)) this += highlightKeyword("abstract class ") += highlightTypeDef(name)
else this += highlightKeyword("class ") += highlightTypeDef(name)

val typeParams = stats.collect { case targ: TypeDef => targ }.filter(_.symbol.isTypeParam).zip(targs)
if (!flags.is(Flags.Module)) {
printTargsDefs(typeParams)
val it = argss.iterator
while (it.hasNext)
printArgsDefs(it.next())
for paramClause <- paramss do
paramClause match
case TermParamClause(params) =>
printArgsDefs(params)
case TypeParamClause(params) =>
printTargsDefs(stats.collect { case targ: TypeDef => targ }.filter(_.symbol.isTypeParam).zip(params))
}

val parents1 = parents.filter {
Expand Down Expand Up @@ -212,8 +213,8 @@ object SourceCode {
// Currently the compiler does not allow overriding some of the methods generated for case classes
d.symbol.flags.is(Flags.Synthetic) &&
(d match {
case DefDef("apply" | "unapply" | "writeReplace", _, _, _, _) if d.symbol.owner.flags.is(Flags.Module) => true
case DefDef(n, _, _, _, _) if d.symbol.owner.flags.is(Flags.Case) =>
case DefDef("apply" | "unapply" | "writeReplace", _, _, _) if d.symbol.owner.flags.is(Flags.Module) => true
case DefDef(n, _, _, _) if d.symbol.owner.flags.is(Flags.Case) =>
n == "copy" ||
n.matches("copy\\$default\\$[1-9][0-9]*") || // default parameters for the copy method
n.matches("_[1-9][0-9]*") || // Getters from Product
Expand Down Expand Up @@ -301,7 +302,7 @@ object SourceCode {
printTree(body)
}

case ddef @ DefDef(name, targs, argss, tpt, rhs) =>
case ddef @ DefDef(name, paramss, tpt, rhs) =>
printDefAnnotations(ddef)

val isConstructor = name == "<init>"
Expand All @@ -316,10 +317,10 @@ object SourceCode {

val name1: String = if (isConstructor) "this" else splicedName(ddef.symbol).getOrElse(name)
this += highlightKeyword("def ") += highlightValDef(name1)
printTargsDefs(targs.zip(targs))
val it = argss.iterator
while (it.hasNext)
printArgsDefs(it.next())
for clause <- paramss do
clause match
case TermParamClause(params) => printArgsDefs(params)
case TypeParamClause(params) => printTargsDefs(params.zip(params))
if (!isConstructor) {
this += ": "
printTypeTree(tpt)
Expand Down Expand Up @@ -1251,7 +1252,7 @@ object SourceCode {

private def printDefinitionName(tree: Definition): this.type = tree match {
case ValDef(name, _, _) => this += highlightValDef(name)
case DefDef(name, _, _, _, _) => this += highlightValDef(name)
case DefDef(name, _, _, _) => this += highlightValDef(name)
case ClassDef(name, _, _, _, _, _) => this += highlightTypeDef(name.stripSuffix("$"))
case TypeDef(name, _) => this += highlightTypeDef(name)
}
Expand Down
Loading