Skip to content

Commit 5b05b05

Browse files
Implement refine check for RPITITs
1 parent 0feab53 commit 5b05b05

File tree

29 files changed

+444
-31
lines changed

29 files changed

+444
-31
lines changed

compiler/rustc_const_eval/src/transform/validate.rs

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ impl<'tcx> MirPass<'tcx> for Validator {
5050
let param_env = match mir_phase.reveal() {
5151
Reveal::UserFacing => tcx.param_env(def_id),
5252
Reveal::All => tcx.param_env_reveal_all_normalized(def_id),
53+
Reveal::HideReturnPositionImplTraitInTrait => {
54+
unreachable!("only used during refinement checks")
55+
}
5356
};
5457

5558
let always_live_locals = always_storage_live_locals(body);

compiler/rustc_hir_analysis/messages.ftl

+4
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,10 @@ hir_analysis_return_type_notation_on_non_rpitit =
219219
.note = function returns `{$ty}`, which is not compatible with associated type return bounds
220220
.label = this function must be `async` or return `impl Trait`
221221
222+
hir_analysis_rpitit_refined = impl method signature does not match trait method signature
223+
.suggestion = replace the return type so that it matches the trait
224+
.label = return type from trait method defined here
225+
.unmatched_bound_label = this bound is stronger than that defined on the trait
222226
hir_analysis_self_in_impl_self =
223227
`Self` is not valid in the self type of an impl block
224228
.note = replace `Self` with a different type

compiler/rustc_hir_analysis/src/check/compare_impl_item.rs

+8
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ use rustc_trait_selection::traits::{
2828
use std::borrow::Cow;
2929
use std::iter;
3030

31+
mod refine;
32+
3133
/// Checks that a method from an impl conforms to the signature of
3234
/// the same method as declared in the trait.
3335
///
@@ -53,6 +55,12 @@ pub(super) fn compare_impl_method<'tcx>(
5355
impl_trait_ref,
5456
CheckImpliedWfMode::Check,
5557
)?;
58+
refine::compare_impl_trait_in_trait_predicate_entailment(
59+
tcx,
60+
impl_m,
61+
trait_m,
62+
impl_trait_ref,
63+
)?;
5664
};
5765
}
5866

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
use std::ops::ControlFlow;
2+
3+
use rustc_data_structures::fx::FxIndexMap;
4+
use rustc_hir as hir;
5+
use rustc_hir::def_id::DefId;
6+
use rustc_infer::infer::TyCtxtInferExt;
7+
use rustc_infer::traits::Obligation;
8+
use rustc_middle::traits::ObligationCause;
9+
use rustc_middle::ty::{
10+
self, Ty, TyCtxt, TypeFolder, TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitor,
11+
};
12+
use rustc_span::ErrorGuaranteed;
13+
use rustc_span::{sym, Span};
14+
use rustc_trait_selection::traits::ObligationCtxt;
15+
use rustc_type_ir::fold::TypeFoldable;
16+
17+
/// Check that an implementation does not refine an RPITIT from a trait method signature.
18+
pub(super) fn compare_impl_trait_in_trait_predicate_entailment<'tcx>(
19+
tcx: TyCtxt<'tcx>,
20+
impl_m: ty::AssocItem,
21+
trait_m: ty::AssocItem,
22+
impl_trait_ref: ty::TraitRef<'tcx>,
23+
) -> Result<(), ErrorGuaranteed> {
24+
if !tcx.impl_method_has_trait_impl_trait_tys(impl_m.def_id)
25+
|| tcx.has_attr(impl_m.def_id, sym::refine)
26+
{
27+
return Ok(());
28+
}
29+
30+
let hidden_tys = tcx.collect_return_position_impl_trait_in_trait_tys(impl_m.def_id)?;
31+
32+
let impl_def_id = impl_m.container_id(tcx);
33+
//let trait_def_id = trait_m.container_id(tcx);
34+
let trait_m_to_impl_m_substs = ty::InternalSubsts::identity_for_item(tcx, impl_m.def_id)
35+
.rebase_onto(tcx, impl_def_id, impl_trait_ref.substs);
36+
37+
let bound_trait_m_sig = tcx.fn_sig(trait_m.def_id).subst(tcx, trait_m_to_impl_m_substs);
38+
let trait_m_sig = tcx.liberate_late_bound_regions(impl_m.def_id, bound_trait_m_sig);
39+
40+
let mut visitor = ImplTraitInTraitCollector { tcx, types: FxIndexMap::default() };
41+
trait_m_sig.visit_with(&mut visitor);
42+
43+
let mut reverse_mapping = FxIndexMap::default();
44+
let mut bounds_to_prove = vec![];
45+
for (rpitit_def_id, rpitit_substs) in visitor.types {
46+
let hidden_ty = hidden_tys
47+
.get(&rpitit_def_id)
48+
.expect("expected hidden type for RPITIT")
49+
.subst_identity();
50+
reverse_mapping.insert(hidden_ty, tcx.mk_projection(rpitit_def_id, rpitit_substs));
51+
52+
let ty::Alias(ty::Opaque, opaque_ty) = *hidden_ty.kind() else {
53+
return Err(report_mismatched_rpitit_signature(
54+
tcx,
55+
trait_m_sig,
56+
trait_m.def_id,
57+
impl_m.def_id,
58+
None,
59+
));
60+
};
61+
62+
// Check that this is an opaque that comes from our impl fn
63+
if !tcx.hir().get_if_local(opaque_ty.def_id).map_or(false, |node| {
64+
matches!(
65+
node.expect_item().expect_opaque_ty().origin,
66+
hir::OpaqueTyOrigin::AsyncFn(def_id) | hir::OpaqueTyOrigin::FnReturn(def_id)
67+
if def_id == impl_m.def_id.expect_local()
68+
)
69+
}) {
70+
return Err(report_mismatched_rpitit_signature(
71+
tcx,
72+
trait_m_sig,
73+
trait_m.def_id,
74+
impl_m.def_id,
75+
None,
76+
));
77+
}
78+
79+
bounds_to_prove.extend(
80+
tcx.explicit_item_bounds(opaque_ty.def_id)
81+
.iter_instantiated_copied(tcx, opaque_ty.args),
82+
);
83+
}
84+
85+
let infcx = tcx.infer_ctxt().build();
86+
let ocx = ObligationCtxt::new(&infcx);
87+
let param_env =
88+
tcx.param_env(impl_m.def_id).with_hidden_return_position_impl_trait_in_trait_tys();
89+
90+
ocx.register_obligations(
91+
bounds_to_prove.fold_with(&mut ReverseMapper { tcx, reverse_mapping }).into_iter().map(
92+
|(pred, span)| {
93+
Obligation::new(tcx, ObligationCause::dummy_with_span(span), param_env, pred)
94+
},
95+
),
96+
);
97+
98+
let errors = ocx.select_all_or_error();
99+
if !errors.is_empty() {
100+
let span = errors.first().unwrap().obligation.cause.span;
101+
return Err(report_mismatched_rpitit_signature(
102+
tcx,
103+
trait_m_sig,
104+
trait_m.def_id,
105+
impl_m.def_id,
106+
Some(span),
107+
));
108+
}
109+
110+
Ok(())
111+
}
112+
113+
struct ImplTraitInTraitCollector<'tcx> {
114+
tcx: TyCtxt<'tcx>,
115+
types: FxIndexMap<DefId, ty::GenericArgsRef<'tcx>>,
116+
}
117+
118+
impl<'tcx> TypeVisitor<TyCtxt<'tcx>> for ImplTraitInTraitCollector<'tcx> {
119+
type BreakTy = !;
120+
121+
fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
122+
if let ty::Alias(ty::Projection, proj) = *ty.kind()
123+
&& self.tcx.is_impl_trait_in_trait(proj.def_id)
124+
{
125+
if self.types.insert(proj.def_id, proj.args).is_none() {
126+
for (pred, _) in self
127+
.tcx
128+
.explicit_item_bounds(proj.def_id)
129+
.iter_instantiated_copied(self.tcx, proj.args)
130+
{
131+
pred.visit_with(self)?;
132+
}
133+
}
134+
ControlFlow::Continue(())
135+
} else {
136+
ty.super_visit_with(self)
137+
}
138+
}
139+
}
140+
141+
struct ReverseMapper<'tcx> {
142+
tcx: TyCtxt<'tcx>,
143+
reverse_mapping: FxIndexMap<Ty<'tcx>, Ty<'tcx>>,
144+
}
145+
146+
impl<'tcx> TypeFolder<TyCtxt<'tcx>> for ReverseMapper<'tcx> {
147+
fn interner(&self) -> TyCtxt<'tcx> {
148+
self.tcx
149+
}
150+
151+
fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
152+
if let Some(ty) = self.reverse_mapping.get(&ty) { *ty } else { ty.super_fold_with(self) }
153+
}
154+
}
155+
156+
fn report_mismatched_rpitit_signature<'tcx>(
157+
tcx: TyCtxt<'tcx>,
158+
trait_m_sig: ty::FnSig<'tcx>,
159+
trait_m_def_id: DefId,
160+
impl_m_def_id: DefId,
161+
unmatched_bound: Option<Span>,
162+
) -> ErrorGuaranteed {
163+
let mapping = std::iter::zip(
164+
tcx.fn_sig(trait_m_def_id).skip_binder().bound_vars(),
165+
tcx.fn_sig(impl_m_def_id).skip_binder().bound_vars(),
166+
)
167+
.filter_map(|(impl_bv, trait_bv)| {
168+
if let ty::BoundVariableKind::Region(impl_bv) = impl_bv
169+
&& let ty::BoundVariableKind::Region(trait_bv) = trait_bv
170+
{
171+
Some((impl_bv, trait_bv))
172+
} else {
173+
None
174+
}
175+
})
176+
.collect();
177+
178+
let return_ty =
179+
trait_m_sig.output().fold_with(&mut super::RemapLateBound { tcx, mapping: &mapping });
180+
181+
let (span, impl_return_span, sugg) =
182+
match tcx.hir().get_by_def_id(impl_m_def_id.expect_local()).fn_decl().unwrap().output {
183+
hir::FnRetTy::DefaultReturn(span) => {
184+
(tcx.def_span(impl_m_def_id), span, format!("-> {return_ty} "))
185+
}
186+
hir::FnRetTy::Return(ty) => (ty.span, ty.span, format!("{return_ty}")),
187+
};
188+
let trait_return_span =
189+
tcx.hir().get_if_local(trait_m_def_id).map(|node| match node.fn_decl().unwrap().output {
190+
hir::FnRetTy::DefaultReturn(_) => tcx.def_span(trait_m_def_id),
191+
hir::FnRetTy::Return(ty) => ty.span,
192+
});
193+
194+
tcx.sess.emit_err(crate::errors::ReturnPositionImplTraitInTraitRefined {
195+
span,
196+
impl_return_span,
197+
trait_return_span,
198+
sugg,
199+
unmatched_bound,
200+
})
201+
}

compiler/rustc_hir_analysis/src/errors.rs

+14
Original file line numberDiff line numberDiff line change
@@ -918,3 +918,17 @@ pub struct UnusedAssociatedTypeBounds {
918918
#[suggestion(code = "")]
919919
pub span: Span,
920920
}
921+
922+
#[derive(Diagnostic)]
923+
#[diag(hir_analysis_rpitit_refined)]
924+
pub(crate) struct ReturnPositionImplTraitInTraitRefined {
925+
#[primary_span]
926+
pub span: Span,
927+
#[suggestion(applicability = "maybe-incorrect", code = "{sugg}")]
928+
pub impl_return_span: Span,
929+
#[label]
930+
pub trait_return_span: Option<Span>,
931+
pub sugg: String,
932+
#[label(hir_analysis_unmatched_bound_label)]
933+
pub unmatched_bound: Option<Span>,
934+
}

compiler/rustc_middle/src/traits/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ pub enum Reveal {
6464
/// type-checking.
6565
UserFacing,
6666

67+
// Same as user-facing reveal, but do not project ("reveal") return-position
68+
// impl trait in traits. This is only used for checking that an RPITIT is not
69+
// refined by an implementation.
70+
HideReturnPositionImplTraitInTrait,
71+
6772
/// At codegen time, all monomorphic projections will succeed.
6873
/// Also, `impl Trait` is normalized to the concrete type,
6974
/// which has to be already collected by type-checking.

compiler/rustc_middle/src/ty/mod.rs

+21-3
Original file line numberDiff line numberDiff line change
@@ -1662,8 +1662,15 @@ struct ParamTag {
16621662

16631663
impl_tag! {
16641664
impl Tag for ParamTag;
1665-
ParamTag { reveal: traits::Reveal::UserFacing },
1666-
ParamTag { reveal: traits::Reveal::All },
1665+
ParamTag {
1666+
reveal: traits::Reveal::UserFacing,
1667+
},
1668+
ParamTag {
1669+
reveal: traits::Reveal::All,
1670+
},
1671+
ParamTag {
1672+
reveal: traits::Reveal::HideReturnPositionImplTraitInTrait,
1673+
},
16671674
}
16681675

16691676
impl<'tcx> fmt::Debug for ParamEnv<'tcx> {
@@ -1767,6 +1774,15 @@ impl<'tcx> ParamEnv<'tcx> {
17671774
Self::new(List::empty(), self.reveal())
17681775
}
17691776

1777+
#[inline]
1778+
pub fn with_hidden_return_position_impl_trait_in_trait_tys(self) -> Self {
1779+
Self::new(
1780+
self.caller_bounds(),
1781+
Reveal::HideReturnPositionImplTraitInTrait,
1782+
self.constness(),
1783+
)
1784+
}
1785+
17701786
/// Creates a suitable environment in which to perform trait
17711787
/// queries on the given value. When type-checking, this is simply
17721788
/// the pair of the environment plus value. But when reveal is set to
@@ -1781,7 +1797,9 @@ impl<'tcx> ParamEnv<'tcx> {
17811797
/// although the surrounding function is never reachable.
17821798
pub fn and<T: TypeVisitable<TyCtxt<'tcx>>>(self, value: T) -> ParamEnvAnd<'tcx, T> {
17831799
match self.reveal() {
1784-
Reveal::UserFacing => ParamEnvAnd { param_env: self, value },
1800+
Reveal::UserFacing | Reveal::HideReturnPositionImplTraitInTrait => {
1801+
ParamEnvAnd { param_env: self, value }
1802+
}
17851803

17861804
Reveal::All => {
17871805
if value.is_global() {

compiler/rustc_trait_selection/src/solve/assembly/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
741741
self.merge_candidates(param_env_candidates)
742742
}
743743
ty::Alias(ty::Opaque, _opaque_ty) => match goal.param_env.reveal() {
744-
Reveal::UserFacing => {
744+
Reveal::UserFacing | Reveal::HideReturnPositionImplTraitInTrait => {
745745
self.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
746746
}
747747
Reveal::All => return Err(NoSolution),

compiler/rustc_trait_selection/src/solve/opaques.rs

+2
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
8080
self.eq(goal.param_env, expected, actual)?;
8181
self.evaluate_added_goals_and_make_canonical_response(Certainty::Yes)
8282
}
83+
(Reveal::HideReturnPositionImplTraitInTrait, SolverMode::Normal) => todo!(),
84+
(Reveal::HideReturnPositionImplTraitInTrait, SolverMode::Coherence) => todo!(),
8385
}
8486
}
8587
}

compiler/rustc_trait_selection/src/traits/project.rs

+11-2
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ pub(crate) fn needs_normalization<'tcx, T: TypeVisitable<TyCtxt<'tcx>>>(
412412
reveal: Reveal,
413413
) -> bool {
414414
match reveal {
415-
Reveal::UserFacing => value.has_type_flags(
415+
Reveal::UserFacing | Reveal::HideReturnPositionImplTraitInTrait => value.has_type_flags(
416416
ty::TypeFlags::HAS_TY_PROJECTION
417417
| ty::TypeFlags::HAS_TY_INHERENT
418418
| ty::TypeFlags::HAS_CT_PROJECTION,
@@ -546,7 +546,9 @@ impl<'a, 'b, 'tcx> TypeFolder<TyCtxt<'tcx>> for AssocTypeNormalizer<'a, 'b, 'tcx
546546
ty::Opaque => {
547547
// Only normalize `impl Trait` outside of type inference, usually in codegen.
548548
match self.param_env.reveal() {
549-
Reveal::UserFacing => ty.super_fold_with(self),
549+
Reveal::UserFacing | Reveal::HideReturnPositionImplTraitInTrait => {
550+
ty.super_fold_with(self)
551+
}
550552

551553
Reveal::All => {
552554
let recursion_limit = self.interner().recursion_limit();
@@ -1699,6 +1701,13 @@ fn assemble_candidates_from_impls<'cx, 'tcx>(
16991701
obligation: &ProjectionTyObligation<'tcx>,
17001702
candidate_set: &mut ProjectionCandidateSet<'tcx>,
17011703
) {
1704+
// Don't reveal RPITIT if we are checking RPITIT refines.
1705+
if selcx.tcx().is_impl_trait_in_trait(obligation.predicate.def_id)
1706+
&& obligation.param_env.reveal() == Reveal::HideReturnPositionImplTraitInTrait
1707+
{
1708+
return;
1709+
}
1710+
17021711
// If we are resolving `<T as TraitRef<...>>::Item == Type`,
17031712
// start out by selecting the predicate `T as TraitRef<...>`:
17041713
let trait_ref = obligation.predicate.trait_ref(selcx.tcx());

compiler/rustc_trait_selection/src/traits/query/normalize.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ impl<'cx, 'tcx> FallibleTypeFolder<TyCtxt<'tcx>> for QueryNormalizer<'cx, 'tcx>
225225
ty::Opaque => {
226226
// Only normalize `impl Trait` outside of type inference, usually in codegen.
227227
match self.param_env.reveal() {
228-
Reveal::UserFacing => ty.try_super_fold_with(self)?,
228+
Reveal::UserFacing | Reveal::HideReturnPositionImplTraitInTrait => {
229+
ty.try_super_fold_with(self)?
230+
}
229231

230232
Reveal::All => {
231233
let args = data.args.try_fold_with(self)?;

0 commit comments

Comments
 (0)