@@ -2,7 +2,7 @@ package dotty.tools
22package  dotc 
33package  typer 
44
5- import  dotty . tools . dotc . ast .{ Trees ,  untpd ,  tpd ,  TreeTypeMap } 
5+ import  ast ._ 
66import  Trees ._ 
77import  core ._ 
88import  Flags ._ 
@@ -14,16 +14,16 @@ import StdNames._
1414import  transform .SymUtils ._ 
1515import  Contexts .Context 
1616import  Names .{Name , TermName }
17- import  NameKinds .{InlineAccessorName , InlineScrutineeName ,  InlineBinderName }
17+ import  NameKinds .{InlineAccessorName , InlineBinderName ,  InlineScrutineeName }
1818import  ProtoTypes .selectionProto 
1919import  SymDenotations .SymDenotation 
2020import  Inferencing .fullyDefinedType 
2121import  config .Printers .inlining 
2222import  ErrorReporting .errorTree 
23+ 
2324import  collection .mutable 
2425import  reporting .trace 
2526import  util .Positions .Position 
26- import  ast .TreeInfo 
2727
2828object  Inliner  {
2929  import  tpd ._ 
@@ -692,7 +692,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
692692          bindingsBuf +=  ValDef (sym, constToLiteral(rhs))
693693        }
694694
695-         def  searchImplicit (sym : TermSymbol , tpt : Tree ) =  {
695+         def  searchImplicit (sym : TermSymbol , tpt : Tree )( implicit   ctx :  Context )  =  {
696696          val  evTyper  =  new  Typer 
697697          val  evidence  =  evTyper.inferImplicitArg(tpt.tpe, tpt.pos)(ctx.fresh.setTyper(evTyper))
698698          evidence.tpe match  {
@@ -707,48 +707,70 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
707707          }
708708        }
709709
710-         pat match  {
711-           case  Typed (pat1, tpt) => 
712-             val  getBoundVars  =  new  TreeAccumulator [List [TypeSymbol ]] {
713-               def  apply (syms : List [TypeSymbol ], t : Tree )(implicit  ctx : Context ) =  {
714-                 val  syms1  =  t match  {
715-                   case  t : Bind  if  t.symbol.isType &&  t.name !=  tpnme.WILDCARD  => 
716-                     t.symbol.asType ::  syms
717-                   case  _ => 
718-                     syms
719-                 }
720-                 foldOver(syms1, t)
710+         import  java .{lang  =>  jl }
711+         def  getBoundVarsMap (pat : Tree , tpt : Tree ) =  {
712+           //  UnApply nodes with pattern bound variables translate to something like this
713+           //    UnApply[t @ t](pats)(implicits): T[t]
714+           //  Need to traverse any binds in type arguments of the UnAppyl to get the set of
715+           //  all instantiable type variables. Test case is pos/inline-caseclass.scala.
716+           val  allTpts  =  tpt ::  (pat match  {
717+             case  UnApply (TypeApply (_, tpts), _, _) =>  tpts
718+             case  _ =>  Nil 
719+           })
720+ 
721+           val  getBinds  =  new  TreeAccumulator [Set [TypeSymbol ]] {
722+             def  apply (syms : Set [TypeSymbol ], t : tpd.Tree )(implicit  ctx : Context ):  Set [TypeSymbol ] =  {
723+               val  syms1  =  t match  {
724+                 case  t : Bind  if  t.symbol.isType &&  t.name !=  tpnme.WILDCARD  => 
725+                   syms +  t.symbol.asType
726+                 case  _ =>  syms
721727              }
728+               foldOver(syms1, t)
722729            }
723-             var  boundVars  =  getBoundVars(Nil , tpt)
724-             //  UnApply nodes with pattern bound variables translate to something like this
725-             //    UnApply[t @ t](pats)(implicits): T[t]
726-             //  Need to traverse any binds in type arguments of the UnAppyl to get the set of
727-             //  all instantiable type variables. Test case is pos/inline-caseclass.scala.
728-             pat1 match  {
729-               case  UnApply (TypeApply (_, tpts), _, _) => 
730-                 for  (tpt <-  tpts) boundVars =  getBoundVars(boundVars, tpt)
731-               case  _ => 
732-             }
733-             for  (bv <-  boundVars) {
734-               val  TypeBounds (lo, hi) =  bv.info.bounds
735-               ctx.gadt.addBound(bv, lo, isUpper =  false )
736-               ctx.gadt.addBound(bv, hi, isUpper =  true )
730+           }
731+           val  binds  =  allTpts.foldLeft[Set [TypeSymbol ]](Set .empty)(getBinds.apply(_, _))
732+           val  getBoundVars  =  new  TypeAccumulator [util.SimpleIdentityMap [TypeSymbol , jl.Boolean ]] {
733+             def  apply (syms : util.SimpleIdentityMap [TypeSymbol , jl.Boolean ], t : Type ) =  {
734+               val  syms1  =  t match  {
735+                  case  tr : TypeRef  if  tr.symbol.is(Case ) &&  binds.contains(tr.symbol.asType) => 
736+                   syms.updated[jl.Boolean ](tr.typeSymbol.asType, variance >=  0 )
737+                 case  _ => 
738+                   syms
739+               }
740+               foldOver(syms1, t)
737741            }
742+           }
743+           getBoundVars(util.SimpleIdentityMap .Empty , tpt.tpe)
744+         }
745+ 
746+         def  registerAsGadtSyms (map : util.SimpleIdentityMap [TypeSymbol , jl.Boolean ])(implicit  ctx : Context ):  Unit  = 
747+           map.foreachBinding { case  (sym, _) => 
748+             val  TypeBounds (lo, hi) =  sym.info.bounds
749+             ctx.gadt.addBound(sym, lo, isUpper =  false )
750+             ctx.gadt.addBound(sym, hi, isUpper =  true )
751+           }
752+ 
753+         def  addTypeBindings (map : util.SimpleIdentityMap [TypeSymbol , jl.Boolean ])(implicit  ctx : Context ):  Unit  = 
754+           map.foreachBinding { case  (sym, fromBelow) => 
755+             sym.info =  TypeAlias (ctx.gadt.approximation(sym, fromBelow =  fromBelow))
756+             bindingsBuf +=  TypeDef (sym)
757+           }
758+ 
759+         pat match  {
760+           case  Typed (pat1, tpt) => 
761+             val  boundVarsMap  =  getBoundVarsMap(pat1, tpt)
762+             registerAsGadtSyms(boundVarsMap)
738763            scrut <:<  tpt.tpe &&  {
739-               for  (bv <-  boundVars) {
740-                 bv.info =  TypeAlias (ctx.gadt.bounds(bv).lo)
741-                   //  FIXME: This is very crude. We should approximate with lower or higher bound depending
742-                   //  on variance, and we should also take care of recursive bounds. Basically what
743-                   //  ConstraintHandler#approximation does. However, this only works for constrained paramrefs
744-                   //  not GADT-bound variables. Hopefully we will get some way to improve this when we
745-                   //  re-implement GADTs in terms of constraints.
746-                 bindingsBuf +=  TypeDef (bv)
747-               }
764+               addTypeBindings(boundVarsMap)
748765              reducePattern(bindingsBuf, scrut, pat1)
749766            }
750767          case  pat @  Bind (name : TermName , Typed (_, tpt)) if  isImplicit => 
751-             searchImplicit(pat.symbol.asTerm, tpt)
768+             val  boundVarsMap  =  getBoundVarsMap(tpt, tpt)
769+             registerAsGadtSyms(boundVarsMap)
770+             searchImplicit(pat.symbol.asTerm, tpt) &&  {
771+               addTypeBindings(boundVarsMap)
772+               true 
773+             }
752774          case  pat @  Bind (name : TermName , body) => 
753775            reducePattern(bindingsBuf, scrut, body) &&  {
754776              if  (name !=  nme.WILDCARD ) newBinding(pat.symbol.asTerm, ref(scrut))
@@ -906,7 +928,23 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
906928              matchBindingsBuf +=  binding
907929              rhsCtx.enter(binding.symbol)
908930            }
909-             typedExpr(rhs, pt)(rhsCtx)
931+             val  rrhs  =  rhs match  {
932+               case  Block (stats, t) if  t.pos.isSynthetic => 
933+                 t match  {
934+                   case  Typed (expr, _) => 
935+                     untpd.Block (stats, expr)
936+                   case  TypeApply (Select (expr, n), _) if  n ==  defn.Any_asInstanceOf .name => 
937+                     untpd.Block (stats, expr)
938+                   case  _ => 
939+                     rhs
940+                 }
941+               case  _ =>  rhs
942+             }
943+ //             println(i"""pre=$rhs
944+ //                        |post=$rrhs""".stripMargin)
945+             //  put caseBindings in a block to let typedExpr see them
946+             val  Block (_, res) =  typedExpr(untpd.Block (caseBindings, rrhs), pt)(rhsCtx)
947+             res
910948          case  None  => 
911949            def  guardStr (guard : untpd.Tree ) =  if  (guard.isEmpty) " " else  i "  if  $guard" 
912950            def  patStr (cdef : untpd.CaseDef ) =  i " case  ${cdef.pat}${guardStr(cdef.guard)}" 
0 commit comments