@@ -26,7 +26,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
2626 /** Handlers to synthesize implicits for special types */
2727 type SpecialHandler = (Type , Span ) => Context ?=> TreeWithErrors
2828 private type SpecialHandlers = List [(ClassSymbol , SpecialHandler )]
29-
29+
3030 val synthesizedClassTag : SpecialHandler = (formal, span) =>
3131 formal.argInfos match
3232 case arg :: Nil =>
@@ -240,11 +240,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
240240 * <parent> {
241241 * MirroredMonoType = <monoType>
242242 * MirroredType = <mirroredType>
243- * MirroredLabel = <label> }
243+ * MirroredLabel = <label>
244244 * }
245245 */
246- private def mirrorCore (parentClass : ClassSymbol , monoType : Type , mirroredType : Type , label : Name , formal : Type )(using Context ) =
247- formal & parentClass.typeRef
246+ private def mirrorCore (parentClass : ClassSymbol , monoType : Type , mirroredType : Type , label : Name )(using Context ) =
247+ parentClass.typeRef
248248 .refinedWith(tpnme.MirroredMonoType , TypeAlias (monoType))
249249 .refinedWith(tpnme.MirroredType , TypeAlias (mirroredType))
250250 .refinedWith(tpnme.MirroredLabel , TypeAlias (ConstantType (Constant (label.toString))))
@@ -269,6 +269,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
269269 then report.error(
270270 em " $name mismatch, expected: $expected, found: $actual. " , ctx.source.atSpan(span))
271271
272+ extension (formal : Type )
273+ /** `tp := op; tp <:< formal; formal & tp` */
274+ private def constrained_& (op : Context ?=> Type )(using Context ): Type =
275+ val tp = op
276+ tp <:< formal
277+ formal & tp
278+
272279 private def mkMirroredMonoType (mirroredType : HKTypeLambda )(using Context ): Type =
273280 val monoMap = new TypeMap :
274281 def apply (t : Type ) = t match
@@ -313,22 +320,23 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
313320 val elemsLabels = TypeOps .nestedPairs(elemLabels)
314321 checkRefinement(formal, tpnme.MirroredElemTypes , elemsType, span)
315322 checkRefinement(formal, tpnme.MirroredElemLabels , elemsLabels, span)
316- val mirrorType =
317- mirrorCore(defn.Mirror_ProductClass , monoType, mirroredType, cls.name, formal )
323+ val mirrorType = formal.constrained_& {
324+ mirrorCore(defn.Mirror_ProductClass , monoType, mirroredType, cls.name)
318325 .refinedWith(tpnme.MirroredElemTypes , TypeAlias (elemsType))
319326 .refinedWith(tpnme.MirroredElemLabels , TypeAlias (elemsLabels))
327+ }
320328 val mirrorRef =
321329 if (genAnonyousMirror(cls)) anonymousMirror(monoType, ExtendsProductMirror , span)
322330 else companionPath(mirroredType, span)
323331 withNoErrors(mirrorRef.cast(mirrorType))
324332 end makeProductMirror
325333
326- def getError (cls : Symbol ): String =
334+ def getError (cls : Symbol ): String =
327335 val reason = if ! cls.isGenericProduct then
328336 i " because ${cls.whyNotGenericProduct}"
329337 else if ! canAccessCtor(cls) then
330338 i " because the constructor of $cls is innaccessible from the calling scope. "
331- else
339+ else
332340 " "
333341 i " $cls is not a generic product $reason"
334342 end getError
@@ -341,11 +349,15 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
341349 val module = mirroredType.termSymbol
342350 val modulePath = pathFor(mirroredType).withSpan(span)
343351 if module.info.classSymbol.is(Scala2x ) then
344- val mirrorType = mirrorCore(defn.Mirror_SingletonProxyClass , mirroredType, mirroredType, module.name, formal)
352+ val mirrorType = formal.constrained_& {
353+ mirrorCore(defn.Mirror_SingletonProxyClass , mirroredType, mirroredType, module.name)
354+ }
345355 val mirrorRef = New (defn.Mirror_SingletonProxyClass .typeRef, modulePath :: Nil )
346356 withNoErrors(mirrorRef.cast(mirrorType))
347357 else
348- val mirrorType = mirrorCore(defn.Mirror_SingletonClass , mirroredType, mirroredType, module.name, formal)
358+ val mirrorType = formal.constrained_& {
359+ mirrorCore(defn.Mirror_SingletonClass , mirroredType, mirroredType, module.name)
360+ }
349361 withNoErrors(modulePath.cast(mirrorType))
350362 else
351363 val cls = mirroredType.classSymbol
@@ -419,16 +431,19 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
419431 (mirroredType, elems)
420432
421433 val mirrorType =
422- mirrorCore(defn.Mirror_SumClass , monoType, mirroredType, cls.name, formal)
434+ val labels = TypeOps .nestedPairs(elemLabels)
435+ formal.constrained_& {
436+ mirrorCore(defn.Mirror_SumClass , monoType, mirroredType, cls.name)
423437 .refinedWith(tpnme.MirroredElemTypes , TypeAlias (elemsType))
424- .refinedWith(tpnme.MirroredElemLabels , TypeAlias (TypeOps .nestedPairs(elemLabels)))
438+ .refinedWith(tpnme.MirroredElemLabels , TypeAlias (labels))
439+ }
425440 val mirrorRef =
426441 if useCompanion then companionPath(mirroredType, span)
427442 else anonymousMirror(monoType, ExtendsSumMirror , span)
428443 withNoErrors(mirrorRef.cast(mirrorType))
429444 else if ! clsIsGenericSum then
430445 (EmptyTree , List (i " $cls is not a generic sum because ${cls.whyNotGenericSum(declScope)}" ))
431- else
446+ else
432447 EmptyTreeNoError
433448 end sumMirror
434449
@@ -595,7 +610,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
595610 tp.baseType(cls)
596611 val base = baseWithRefinements(formal)
597612 val result =
598- if (base <:< formal.widenExpr)
613+ if (base <:< formal.widenExpr)
599614 // With the subtype test we enforce that the searched type `formal` is of the right form
600615 handler(base, span)
601616 else EmptyTreeNoError
@@ -609,19 +624,19 @@ end Synthesizer
609624
610625object Synthesizer :
611626
612- /** Tuple used to store the synthesis result with a list of errors. */
627+ /** Tuple used to store the synthesis result with a list of errors. */
613628 type TreeWithErrors = (Tree , List [String ])
614629 private def withNoErrors (tree : Tree ): TreeWithErrors = (tree, List .empty)
615630
616631 private val EmptyTreeNoError : TreeWithErrors = withNoErrors(EmptyTree )
617632
618633 private def orElse (treeWithErrors1 : TreeWithErrors , treeWithErrors2 : => TreeWithErrors ): TreeWithErrors = treeWithErrors1 match
619- case (tree, errors) if tree eq genericEmptyTree =>
634+ case (tree, errors) if tree eq genericEmptyTree =>
620635 val (tree2, errors2) = treeWithErrors2
621636 (tree2, errors ::: errors2)
622637 case _ => treeWithErrors1
623638
624- private def clearErrorsIfNotEmpty (treeWithErrors : TreeWithErrors ) = treeWithErrors match
639+ private def clearErrorsIfNotEmpty (treeWithErrors : TreeWithErrors ) = treeWithErrors match
625640 case (tree, _) if tree eq genericEmptyTree => treeWithErrors
626641 case (tree, _) => withNoErrors(tree)
627642
0 commit comments