Skip to content

add additional TypeFlags fast paths #141581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions compiler/rustc_infer/src/infer/canonical/canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,10 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p }
}

fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
if c.flags().intersects(self.needs_canonical_flags) { c.super_fold_with(self) } else { c }
}
}

impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_infer/src/infer/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for OpportunisticVarResolver<'a, 'tcx> {
ct.super_fold_with(self)
}
}

fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if !p.has_non_region_infer() { p } else { p.super_fold_with(self) }
}

fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
if !c.has_non_region_infer() { c } else { c.super_fold_with(self) }
}
}

/// The opportunistic region resolver opportunistically resolves regions
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_middle/src/ty/erase_regions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,12 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for RegionEraserVisitor<'tcx> {
p
}
}

fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
if c.has_type_flags(TypeFlags::HAS_BINDER_VARS | TypeFlags::HAS_FREE_REGIONS) {
c.super_fold_with(self)
} else {
c
}
}
}
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/ty/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ where
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
}

fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
if c.has_vars_bound_at_or_above(self.current_index) { c.super_fold_with(self) } else { c }
}
}

impl<'tcx> TyCtxt<'tcx> {
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_middle/src/ty/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ impl<'tcx> Clause<'tcx> {
}
}

impl<'tcx> rustc_type_ir::inherent::Clauses<TyCtxt<'tcx>> for ty::Clauses<'tcx> {}

#[extension(pub trait ExistentialPredicateStableCmpExt<'tcx>)]
impl<'tcx> ExistentialPredicate<'tcx> {
/// Compares via an ordering that will not change if modules are reordered or other changes are
Expand Down
27 changes: 26 additions & 1 deletion compiler/rustc_middle/src/ty/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,19 @@ impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clause<'tcx> {
}
}

impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
self,
folder: &mut F,
) -> Result<Self, F::Error> {
folder.try_fold_clauses(self)
}

fn fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
folder.fold_clauses(self)
}
}

impl<'tcx> TypeVisitable<TyCtxt<'tcx>> for ty::Predicate<'tcx> {
fn visit_with<V: TypeVisitor<TyCtxt<'tcx>>>(&self, visitor: &mut V) -> V::Result {
visitor.visit_predicate(*self)
Expand Down Expand Up @@ -615,6 +628,19 @@ impl<'tcx> TypeSuperVisitable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
}
}

impl<'tcx> TypeSuperFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
fn try_super_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
self,
folder: &mut F,
) -> Result<Self, F::Error> {
ty::util::try_fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
}

fn super_fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
ty::util::fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
}
}

impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Const<'tcx> {
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
self,
Expand Down Expand Up @@ -775,7 +801,6 @@ macro_rules! list_fold {
}

list_fold! {
ty::Clauses<'tcx> : mk_clauses,
&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> : mk_poly_existential_predicates,
&'tcx ty::List<PlaceElem<'tcx>> : mk_place_elems,
&'tcx ty::List<ty::Pattern<'tcx>> : mk_patterns,
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_next_trait_solver/src/canonicalizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,4 +572,15 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicaliz
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
if p.flags().intersects(NEEDS_CANONICAL) { p.super_fold_with(self) } else { p }
}

fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
match self.canonicalize_mode {
CanonicalizeMode::Input { keep_static: true }
| CanonicalizeMode::Response { max_input_universe: _ } => {}
CanonicalizeMode::Input { keep_static: false } => {
panic!("erasing 'static in env")
}
}
if c.flags().intersects(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c }
}
}
20 changes: 18 additions & 2 deletions compiler/rustc_next_trait_solver/src/resolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::delegate::SolverDelegate;
// EAGER RESOLUTION

/// Resolves ty, region, and const vars to their inferred values or their root vars.
pub struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
struct EagerResolver<'a, D, I = <D as SolverDelegate>::Interner>
where
D: SolverDelegate<Interner = I>,
I: Interner,
Expand All @@ -22,8 +22,20 @@ where
cache: DelayedMap<I::Ty, I::Ty>,
}

pub fn eager_resolve_vars<D: SolverDelegate, T: TypeFoldable<D::Interner>>(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now that we actually check for has_infer in fold_clauses, this change may be unnecessary again 🤔

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤷 idc

delegate: &D,
value: T,
) -> T {
if value.has_infer() {
let mut folder = EagerResolver::new(delegate);
value.fold_with(&mut folder)
} else {
value
}
}

impl<'a, D: SolverDelegate> EagerResolver<'a, D> {
pub fn new(delegate: &'a D) -> Self {
fn new(delegate: &'a D) -> Self {
EagerResolver { delegate, cache: Default::default() }
}
}
Expand Down Expand Up @@ -90,4 +102,8 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for EagerResolv
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
if p.has_infer() { p.super_fold_with(self) } else { p }
}

fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
if c.has_infer() { c.super_fold_with(self) } else { c }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use tracing::{debug, instrument, trace};

use crate::canonicalizer::Canonicalizer;
use crate::delegate::SolverDelegate;
use crate::resolve::EagerResolver;
use crate::resolve::eager_resolve_vars;
use crate::solve::eval_ctxt::CurrentGoalKind;
use crate::solve::{
CanonicalInput, CanonicalResponse, Certainty, EvalCtxt, ExternalConstraintsData, Goal,
Expand Down Expand Up @@ -61,8 +61,7 @@ where
// so we only canonicalize the lookup table and ignore
// duplicate entries.
let opaque_types = self.delegate.clone_opaque_types_lookup_table();
let (goal, opaque_types) =
(goal, opaque_types).fold_with(&mut EagerResolver::new(self.delegate));
let (goal, opaque_types) = eager_resolve_vars(self.delegate, (goal, opaque_types));

let mut orig_values = Default::default();
let canonical = Canonicalizer::canonicalize_input(
Expand Down Expand Up @@ -157,8 +156,8 @@ where

let external_constraints =
self.compute_external_query_constraints(certainty, normalization_nested_goals);
let (var_values, mut external_constraints) = (self.var_values, external_constraints)
.fold_with(&mut EagerResolver::new(self.delegate));
let (var_values, mut external_constraints) =
eager_resolve_vars(self.delegate, (self.var_values, external_constraints));

// Remove any trivial or duplicated region constraints once we've resolved regions
let mut unique = HashSet::default();
Expand Down Expand Up @@ -469,7 +468,7 @@ where
{
let var_values = CanonicalVarValues { var_values: delegate.cx().mk_args(var_values) };
let state = inspect::State { var_values, data };
let state = state.fold_with(&mut EagerResolver::new(delegate));
let state = eager_resolve_vars(delegate, state);
Canonicalizer::canonicalize_response(delegate, max_input_universe, &mut vec![], state)
}

Expand Down
16 changes: 16 additions & 0 deletions compiler/rustc_next_trait_solver/src/solve/eval_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,22 @@ where
}
}
}

fn visit_predicate(&mut self, p: I::Predicate) -> Self::Result {
if p.has_non_region_infer() || p.has_placeholders() {
p.super_visit_with(self)
} else {
ControlFlow::Continue(())
}
}

fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result {
if c.has_non_region_infer() || c.has_placeholders() {
c.super_visit_with(self)
} else {
ControlFlow::Continue(())
}
}
}

let mut visitor = ContainsTermOrNotNameable {
Expand Down
9 changes: 4 additions & 5 deletions compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk};
use rustc_macros::extension;
use rustc_middle::traits::ObligationCause;
use rustc_middle::traits::solve::{Certainty, Goal, GoalSource, NoSolution, QueryResult};
use rustc_middle::ty::{TyCtxt, TypeFoldable, VisitorResult, try_visit};
use rustc_middle::ty::{TyCtxt, VisitorResult, try_visit};
use rustc_middle::{bug, ty};
use rustc_next_trait_solver::resolve::EagerResolver;
use rustc_next_trait_solver::resolve::eager_resolve_vars;
use rustc_next_trait_solver::solve::inspect::{self, instantiate_canonical_state};
use rustc_next_trait_solver::solve::{GenerateProofTree, MaybeCause, SolverDelegateEvalExt as _};
use rustc_span::{DUMMY_SP, Span};
Expand Down Expand Up @@ -187,8 +187,7 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
let _ = term_hack.constrain(infcx, span, param_env);
}

let opt_impl_args =
opt_impl_args.map(|impl_args| impl_args.fold_with(&mut EagerResolver::new(infcx)));
let opt_impl_args = opt_impl_args.map(|impl_args| eager_resolve_vars(infcx, impl_args));

let goals = instantiated_goals
.into_iter()
Expand Down Expand Up @@ -392,7 +391,7 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
infcx,
depth,
orig_values,
goal: uncanonicalized_goal.fold_with(&mut EagerResolver::new(infcx)),
goal: eager_resolve_vars(infcx, uncanonicalized_goal),
result,
evaluation_kind: evaluation.kind,
normalizes_to_term_hack,
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_type_ir/src/binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,14 @@ impl<'a, I: Interner> TypeFolder<I> for ArgFolder<'a, I> {
c.super_fold_with(self)
}
}

fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
if p.has_param() { p.super_fold_with(self) } else { p }
}

fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
if c.has_param() { c.super_fold_with(self) } else { c }
}
}

impl<'a, I: Interner> ArgFolder<'a, I> {
Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_type_ir/src/fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,10 @@ pub trait TypeFolder<I: Interner>: Sized {
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
p.super_fold_with(self)
}

fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
c.super_fold_with(self)
}
}

/// This trait is implemented for every folding traversal. There is a fold
Expand Down Expand Up @@ -190,6 +194,10 @@ pub trait FallibleTypeFolder<I: Interner>: Sized {
fn try_fold_predicate(&mut self, p: I::Predicate) -> Result<I::Predicate, Self::Error> {
p.try_super_fold_with(self)
}

fn try_fold_clauses(&mut self, c: I::Clauses) -> Result<I::Clauses, Self::Error> {
c.try_super_fold_with(self)
}
}

///////////////////////////////////////////////////////////////////////////
Expand Down
12 changes: 12 additions & 0 deletions compiler/rustc_type_ir/src/inherent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,18 @@ pub trait Clause<I: Interner<Clause = Self>>:
fn instantiate_supertrait(self, cx: I, trait_ref: ty::Binder<I, ty::TraitRef<I>>) -> Self;
}

pub trait Clauses<I: Interner<Clauses = Self>>:
Copy
+ Debug
+ Hash
+ Eq
+ TypeSuperVisitable<I>
+ TypeSuperFoldable<I>
+ Flags
+ SliceLike<Item = I::Clause>
{
}

/// Common capabilities of placeholder kinds
pub trait PlaceholderLike: Copy + Debug + Hash + Eq {
fn universe(self) -> ty::UniverseIndex;
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_type_ir/src/interner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::ir_print::IrPrint;
use crate::lang_items::TraitSolverLangItem;
use crate::relate::Relate;
use crate::solve::{CanonicalInput, ExternalConstraintsData, PredefinedOpaquesData, QueryResult};
use crate::visit::{Flags, TypeSuperVisitable, TypeVisitable};
use crate::visit::{Flags, TypeVisitable};
use crate::{self as ty, search_graph};

#[cfg_attr(feature = "nightly", rustc_diagnostic_item = "type_ir_interner")]
Expand Down Expand Up @@ -146,7 +146,7 @@ pub trait Interner:
type ParamEnv: ParamEnv<Self>;
type Predicate: Predicate<Self>;
type Clause: Clause<Self>;
type Clauses: Copy + Debug + Hash + Eq + TypeSuperVisitable<Self> + Flags;
type Clauses: Clauses<Self>;

fn with_global_cache<R>(self, f: impl FnOnce(&mut search_graph::GlobalCache<Self>) -> R) -> R;

Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_type_ir/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ pub trait TypeVisitor<I: Interner>: Sized {
p.super_visit_with(self)
}

fn visit_clauses(&mut self, p: I::Clauses) -> Self::Result {
p.super_visit_with(self)
fn visit_clauses(&mut self, c: I::Clauses) -> Self::Result {
c.super_visit_with(self)
}

fn visit_error(&mut self, _guar: I::ErrorGuaranteed) -> Self::Result {
Expand Down
Loading