Skip to content

Commit e672e2c

Browse files
committed
Unify type parameters on quote pattern matching type lambda
1 parent 0064c93 commit e672e2c

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

+10
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,16 @@ sealed trait GadtState {
268268
finally if !result then restore(saved)
269269
result
270270

271+
def unifySyms(params1: List[Symbol], params2: List[Symbol])(using Context) =
272+
addToConstraint(params1)
273+
addToConstraint(params2)
274+
val paramrefs1 = params1 map (gadt.tvarOrError(_))
275+
val paramrefs2 = params2 map (gadt.tvarOrError(_))
276+
for ((p1, p2) <- paramrefs1.zip(paramrefs2))
277+
do
278+
addLess(p1.origin, p2.origin)
279+
addLess(p2.origin, p1.origin)
280+
271281
// ---- Protected/internal -----------------------------------------------
272282

273283
override protected def constraint = gadt.constraint

compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala

+16-4
Original file line numberDiff line numberDiff line change
@@ -427,13 +427,25 @@ class QuoteMatcher(debug: Boolean) {
427427
notMatched
428428
case _ => matched
429429

430+
def matchTypeParams(ptparams: List[TypeDef], scparams: List[TypeDef]): optional[MatchingExprs] =
431+
// TODO-18271: Compare type bounds
432+
val ptsyms = ptparams.map(_.symbol)
433+
val scsyms = scparams.map(_.symbol)
434+
ctx.gadtState.unifySyms(ptsyms, scsyms)
435+
matched
436+
430437
def matchParamss(scparamss: List[ParamClause], ptparamss: List[ParamClause])(using Env): optional[(Env, MatchingExprs)] =
431438
(scparamss, ptparamss) match {
432439
case (scparams :: screst, ptparams :: ptrest) =>
433-
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
434-
val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
435-
val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
436-
(resEnv, mr1 &&& mrrest)
440+
(scparams, ptparams) match
441+
case (TypeDefs(scparams), TypeDefs(ptparams)) =>
442+
(summon[Env], matchTypeParams(scparams, ptparams))
443+
case (ValDefs(scparams), ValDefs(ptparams)) =>
444+
val mr1 = matchLists(scparams, ptparams)(_ =?= _)
445+
val newEnv = summon[Env] ++ scparams.map(_.symbol).zip(ptparams.map(_.symbol))
446+
val (resEnv, mrrest) = withEnv(newEnv)(matchParamss(screst, ptrest))
447+
(resEnv, mr1 &&& mrrest)
448+
case _ => notMatched
437449
case (Nil, Nil) => (summon[Env], matched)
438450
case _ => notMatched
439451
}

0 commit comments

Comments
 (0)