1
1
use std:: ops:: ControlFlow ;
2
2
3
- use rustc_data_structures:: fx:: FxIndexMap ;
3
+ use rustc_data_structures:: fx:: { FxIndexMap , FxIndexSet } ;
4
4
use rustc_hir as hir;
5
5
use rustc_hir:: def_id:: DefId ;
6
+ use rustc_infer:: infer:: outlives:: env:: OutlivesEnvironment ;
6
7
use rustc_infer:: infer:: TyCtxtInferExt ;
7
8
use rustc_infer:: traits:: Obligation ;
8
- use rustc_middle:: traits:: ObligationCause ;
9
+ use rustc_middle:: traits:: { ObligationCause , Reveal } ;
9
10
use rustc_middle:: ty:: {
10
11
self , Ty , TyCtxt , TypeFolder , TypeSuperFoldable , TypeSuperVisitable , TypeVisitable , TypeVisitor ,
11
12
} ;
12
13
use rustc_span:: ErrorGuaranteed ;
13
14
use rustc_span:: { sym, Span } ;
14
- use rustc_trait_selection:: traits:: ObligationCtxt ;
15
+ use rustc_trait_selection:: traits:: outlives_bounds:: InferCtxtExt ;
16
+ use rustc_trait_selection:: traits:: { normalize_param_env_or_error, ObligationCtxt } ;
15
17
use rustc_type_ir:: fold:: TypeFoldable ;
16
18
17
19
/// Check that an implementation does not refine an RPITIT from a trait method signature.
@@ -30,24 +32,48 @@ pub(super) fn compare_impl_trait_in_trait_predicate_entailment<'tcx>(
30
32
let hidden_tys = tcx. collect_return_position_impl_trait_in_trait_tys ( impl_m. def_id ) ?;
31
33
32
34
let impl_def_id = impl_m. container_id ( tcx) ;
33
- // let trait_def_id = trait_m.container_id(tcx);
34
- let trait_m_to_impl_m_substs = ty:: InternalSubsts :: identity_for_item ( tcx, impl_m. def_id )
35
- . rebase_onto ( tcx, impl_def_id, impl_trait_ref. substs ) ;
35
+ let trait_def_id = trait_m. container_id ( tcx) ;
36
+ let trait_m_to_impl_m_args = ty:: GenericArgs :: identity_for_item ( tcx, impl_m. def_id )
37
+ . rebase_onto ( tcx, impl_def_id, impl_trait_ref. args ) ;
36
38
37
- let bound_trait_m_sig = tcx. fn_sig ( trait_m. def_id ) . subst ( tcx, trait_m_to_impl_m_substs) ;
38
- let trait_m_sig = tcx. liberate_late_bound_regions ( impl_m. def_id , bound_trait_m_sig) ;
39
+ let infcx = tcx. infer_ctxt ( ) . build ( ) ;
40
+ let ocx = ObligationCtxt :: new ( & infcx) ;
41
+
42
+ let mut hybrid_preds = tcx. predicates_of ( impl_def_id) . instantiate_identity ( tcx) . predicates ;
43
+ hybrid_preds. extend (
44
+ tcx. predicates_of ( trait_m. def_id )
45
+ . instantiate_own ( tcx, trait_m_to_impl_m_args)
46
+ . map ( |( pred, _) | pred) ,
47
+ ) ;
48
+ let normalize_cause =
49
+ ObligationCause :: misc ( tcx. def_span ( impl_m. def_id ) , impl_m. def_id . expect_local ( ) ) ;
50
+ let unnormalized_param_env = ty:: ParamEnv :: new (
51
+ tcx. mk_clauses ( & hybrid_preds) ,
52
+ Reveal :: HideReturnPositionImplTraitInTrait ,
53
+ ) ;
54
+ let param_env = normalize_param_env_or_error ( tcx, unnormalized_param_env, normalize_cause) ;
55
+
56
+ let bound_trait_m_sig = tcx. fn_sig ( trait_m. def_id ) . instantiate ( tcx, trait_m_to_impl_m_args) ;
57
+ let unnormalized_trait_m_sig =
58
+ tcx. liberate_late_bound_regions ( impl_m. def_id , bound_trait_m_sig) ;
59
+ let trait_m_sig = ocx. normalize ( & ObligationCause :: dummy ( ) , param_env, unnormalized_trait_m_sig) ;
39
60
40
61
let mut visitor = ImplTraitInTraitCollector { tcx, types : FxIndexMap :: default ( ) } ;
41
62
trait_m_sig. visit_with ( & mut visitor) ;
42
63
43
64
let mut reverse_mapping = FxIndexMap :: default ( ) ;
44
65
let mut bounds_to_prove = vec ! [ ] ;
45
- for ( rpitit_def_id, rpitit_substs) in visitor. types {
46
- let hidden_ty = hidden_tys
47
- . get ( & rpitit_def_id)
48
- . expect ( "expected hidden type for RPITIT" )
49
- . subst_identity ( ) ;
50
- reverse_mapping. insert ( hidden_ty, tcx. mk_projection ( rpitit_def_id, rpitit_substs) ) ;
66
+ for ( rpitit_def_id, rpitit_args) in visitor. types {
67
+ let hidden_ty =
68
+ hidden_tys. get ( & rpitit_def_id) . expect ( "expected hidden type for RPITIT" ) . instantiate (
69
+ tcx,
70
+ rpitit_args. rebase_onto (
71
+ tcx,
72
+ trait_def_id,
73
+ ty:: GenericArgs :: identity_for_item ( tcx, impl_def_id) ,
74
+ ) ,
75
+ ) ;
76
+ reverse_mapping. insert ( hidden_ty, Ty :: new_projection ( tcx, rpitit_def_id, rpitit_args) ) ;
51
77
52
78
let ty:: Alias ( ty:: Opaque , opaque_ty) = * hidden_ty. kind ( ) else {
53
79
return Err ( report_mismatched_rpitit_signature (
@@ -82,11 +108,6 @@ pub(super) fn compare_impl_trait_in_trait_predicate_entailment<'tcx>(
82
108
) ;
83
109
}
84
110
85
- let infcx = tcx. infer_ctxt ( ) . build ( ) ;
86
- let ocx = ObligationCtxt :: new ( & infcx) ;
87
- let param_env =
88
- tcx. param_env ( impl_m. def_id ) . with_hidden_return_position_impl_trait_in_trait_tys ( ) ;
89
-
90
111
ocx. register_obligations (
91
112
bounds_to_prove. fold_with ( & mut ReverseMapper { tcx, reverse_mapping } ) . into_iter ( ) . map (
92
113
|( pred, span) | {
@@ -107,6 +128,24 @@ pub(super) fn compare_impl_trait_in_trait_predicate_entailment<'tcx>(
107
128
) ) ;
108
129
}
109
130
131
+ let mut wf_tys = FxIndexSet :: default ( ) ;
132
+ wf_tys. extend ( unnormalized_trait_m_sig. inputs_and_output ) ;
133
+ wf_tys. extend ( trait_m_sig. inputs_and_output ) ;
134
+ let outlives_env = OutlivesEnvironment :: with_bounds (
135
+ param_env,
136
+ ocx. infcx . implied_bounds_tys ( param_env, impl_m. def_id . expect_local ( ) , wf_tys. clone ( ) ) ,
137
+ ) ;
138
+ let errors = ocx. infcx . resolve_regions ( & outlives_env) ;
139
+ if !errors. is_empty ( ) {
140
+ return Err ( report_mismatched_rpitit_signature (
141
+ tcx,
142
+ trait_m_sig,
143
+ trait_m. def_id ,
144
+ impl_m. def_id ,
145
+ None ,
146
+ ) ) ;
147
+ }
148
+
110
149
Ok ( ( ) )
111
150
}
112
151
0 commit comments