Skip to content

Commit ac15b82

Browse files
committed
Auto merge of #141451 - lcnr:canonicalize-env-cache, r=<try>
[perf] next-solver canonicalization + eager-resolve kinda hacky
2 parents 95a2212 + c3eaa13 commit ac15b82

File tree

18 files changed

+289
-44
lines changed

18 files changed

+289
-44
lines changed

compiler/rustc_infer/src/infer/canonical/canonicalizer.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,14 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Canonicalizer<'cx, 'tcx> {
493493
ct
494494
}
495495
}
496+
497+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
498+
if p.flags().intersects(self.needs_canonical_flags) { p.super_fold_with(self) } else { p }
499+
}
500+
501+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
502+
if c.flags().intersects(self.needs_canonical_flags) { c.super_fold_with(self) } else { c }
503+
}
496504
}
497505

498506
impl<'cx, 'tcx> Canonicalizer<'cx, 'tcx> {

compiler/rustc_infer/src/infer/resolve.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ impl<'a, 'tcx> TypeFolder<TyCtxt<'tcx>> for OpportunisticVarResolver<'a, 'tcx> {
5555
ct.super_fold_with(self)
5656
}
5757
}
58+
59+
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
60+
if !p.has_non_region_infer() { p } else { p.super_fold_with(self) }
61+
}
62+
63+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
64+
if !c.has_non_region_infer() { c } else { c.super_fold_with(self) }
65+
}
5866
}
5967

6068
/// The opportunistic region resolver opportunistically resolves regions

compiler/rustc_middle/src/ty/context.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,17 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
179179
f(&mut *self.new_solver_evaluation_cache.lock())
180180
}
181181

182+
fn canonical_param_env_cache_get_or_insert<R>(
183+
self,
184+
param_env: ty::ParamEnv<'tcx>,
185+
f: impl FnOnce() -> ty::CanonicalParamEnvCacheEntry<Self>,
186+
from_entry: impl FnOnce(&ty::CanonicalParamEnvCacheEntry<Self>) -> R,
187+
) -> R {
188+
let mut cache = self.new_solver_canonical_param_env_cache.lock();
189+
let entry = cache.entry(param_env).or_insert_with(f);
190+
from_entry(entry)
191+
}
192+
182193
fn evaluation_is_concurrent(&self) -> bool {
183194
self.sess.threads() > 1
184195
}
@@ -1444,6 +1455,8 @@ pub struct GlobalCtxt<'tcx> {
14441455

14451456
/// Caches the results of goal evaluation in the new solver.
14461457
pub new_solver_evaluation_cache: Lock<search_graph::GlobalCache<TyCtxt<'tcx>>>,
1458+
pub new_solver_canonical_param_env_cache:
1459+
Lock<FxHashMap<ty::ParamEnv<'tcx>, ty::CanonicalParamEnvCacheEntry<TyCtxt<'tcx>>>>,
14471460

14481461
pub canonical_param_env_cache: CanonicalParamEnvCache<'tcx>,
14491462

@@ -1692,6 +1705,7 @@ impl<'tcx> TyCtxt<'tcx> {
16921705
selection_cache: Default::default(),
16931706
evaluation_cache: Default::default(),
16941707
new_solver_evaluation_cache: Default::default(),
1708+
new_solver_canonical_param_env_cache: Default::default(),
16951709
canonical_param_env_cache: Default::default(),
16961710
data_layout,
16971711
alloc_map: interpret::AllocMap::new(),

compiler/rustc_middle/src/ty/erase_regions.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,4 +86,12 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for RegionEraserVisitor<'tcx> {
8686
p
8787
}
8888
}
89+
90+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
91+
if c.has_type_flags(TypeFlags::HAS_BINDER_VARS | TypeFlags::HAS_FREE_REGIONS) {
92+
c.super_fold_with(self)
93+
} else {
94+
c
95+
}
96+
}
8997
}

compiler/rustc_middle/src/ty/fold.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ where
177177
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
178178
if p.has_vars_bound_at_or_above(self.current_index) { p.super_fold_with(self) } else { p }
179179
}
180+
181+
fn fold_clauses(&mut self, c: ty::Clauses<'tcx>) -> ty::Clauses<'tcx> {
182+
if c.has_vars_bound_at_or_above(self.current_index) { c.super_fold_with(self) } else { c }
183+
}
180184
}
181185

182186
impl<'tcx> TyCtxt<'tcx> {

compiler/rustc_middle/src/ty/predicate.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ impl<'tcx> Clause<'tcx> {
238238
}
239239
}
240240

241+
impl<'tcx> rustc_type_ir::inherent::Clauses<TyCtxt<'tcx>> for ty::Clauses<'tcx> {}
242+
241243
#[extension(pub trait ExistentialPredicateStableCmpExt<'tcx>)]
242244
impl<'tcx> ExistentialPredicate<'tcx> {
243245
/// Compares via an ordering that will not change if modules are reordered or other changes are

compiler/rustc_middle/src/ty/structural_impls.rs

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,19 @@ impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clause<'tcx> {
570570
}
571571
}
572572

573+
impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
574+
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
575+
self,
576+
folder: &mut F,
577+
) -> Result<Self, F::Error> {
578+
folder.try_fold_clauses(self)
579+
}
580+
581+
fn fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
582+
folder.fold_clauses(self)
583+
}
584+
}
585+
573586
impl<'tcx> TypeVisitable<TyCtxt<'tcx>> for ty::Predicate<'tcx> {
574587
fn visit_with<V: TypeVisitor<TyCtxt<'tcx>>>(&self, visitor: &mut V) -> V::Result {
575588
visitor.visit_predicate(*self)
@@ -615,6 +628,19 @@ impl<'tcx> TypeSuperVisitable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
615628
}
616629
}
617630

631+
impl<'tcx> TypeSuperFoldable<TyCtxt<'tcx>> for ty::Clauses<'tcx> {
632+
fn try_super_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
633+
self,
634+
folder: &mut F,
635+
) -> Result<Self, F::Error> {
636+
ty::util::try_fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
637+
}
638+
639+
fn super_fold_with<F: TypeFolder<TyCtxt<'tcx>>>(self, folder: &mut F) -> Self {
640+
ty::util::fold_list(self, folder, |tcx, v| tcx.mk_clauses(v))
641+
}
642+
}
643+
618644
impl<'tcx> TypeFoldable<TyCtxt<'tcx>> for ty::Const<'tcx> {
619645
fn try_fold_with<F: FallibleTypeFolder<TyCtxt<'tcx>>>(
620646
self,
@@ -775,7 +801,6 @@ macro_rules! list_fold {
775801
}
776802

777803
list_fold! {
778-
ty::Clauses<'tcx> : mk_clauses,
779804
&'tcx ty::List<ty::PolyExistentialPredicate<'tcx>> : mk_poly_existential_predicates,
780805
&'tcx ty::List<PlaceElem<'tcx>> : mk_place_elems,
781806
&'tcx ty::List<ty::Pattern<'tcx>> : mk_patterns,

compiler/rustc_next_trait_solver/src/canonicalizer.rs

Lines changed: 123 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,23 @@ use rustc_type_ir::data_structures::{HashMap, ensure_sufficient_stack};
44
use rustc_type_ir::inherent::*;
55
use rustc_type_ir::solve::{Goal, QueryInput};
66
use rustc_type_ir::{
7-
self as ty, Canonical, CanonicalTyVarKind, CanonicalVarKind, InferCtxtLike, Interner,
8-
TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
7+
self as ty, Canonical, CanonicalParamEnvCacheEntry, CanonicalTyVarKind, CanonicalVarKind,
8+
Flags, InferCtxtLike, Interner, TypeFlags, TypeFoldable, TypeFolder, TypeSuperFoldable,
9+
TypeVisitableExt,
910
};
1011

1112
use crate::delegate::SolverDelegate;
1213

14+
/// Does this have infer/placeholder/param, free regions or ReErased?
15+
const NEEDS_CANONICAL: TypeFlags = TypeFlags::from_bits(
16+
TypeFlags::HAS_INFER.bits()
17+
| TypeFlags::HAS_PLACEHOLDER.bits()
18+
| TypeFlags::HAS_PARAM.bits()
19+
| TypeFlags::HAS_FREE_REGIONS.bits()
20+
| TypeFlags::HAS_RE_ERASED.bits(),
21+
)
22+
.unwrap();
23+
1324
/// Whether we're canonicalizing a query input or the query response.
1425
///
1526
/// When canonicalizing an input we're in the context of the caller
@@ -79,13 +90,80 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
7990
cache: Default::default(),
8091
};
8192

82-
let value = value.fold_with(&mut canonicalizer);
93+
let value = if value.has_type_flags(NEEDS_CANONICAL) {
94+
value.fold_with(&mut canonicalizer)
95+
} else {
96+
value
97+
};
8398
assert!(!value.has_infer(), "unexpected infer in {value:?}");
8499
assert!(!value.has_placeholders(), "unexpected placeholders in {value:?}");
85100
let (max_universe, variables) = canonicalizer.finalize();
86101
Canonical { max_universe, variables, value }
87102
}
88103

104+
fn canonicalize_param_env(
105+
delegate: &'a D,
106+
variables: &'a mut Vec<I::GenericArg>,
107+
param_env: I::ParamEnv,
108+
) -> (I::ParamEnv, HashMap<I::GenericArg, usize>, Vec<CanonicalVarKind<I>>) {
109+
if !param_env.has_type_flags(NEEDS_CANONICAL) {
110+
return (param_env, Default::default(), Vec::new());
111+
}
112+
113+
if !param_env.has_non_region_infer() {
114+
delegate.cx().canonical_param_env_cache_get_or_insert(
115+
param_env,
116+
|| {
117+
let mut variables = Vec::new();
118+
let mut env_canonicalizer = Canonicalizer {
119+
delegate,
120+
canonicalize_mode: CanonicalizeMode::Input { keep_static: true },
121+
122+
variables: &mut variables,
123+
variable_lookup_table: Default::default(),
124+
var_kinds: Vec::new(),
125+
binder_index: ty::INNERMOST,
126+
127+
cache: Default::default(),
128+
};
129+
let param_env = param_env.fold_with(&mut env_canonicalizer);
130+
debug_assert_eq!(env_canonicalizer.binder_index, ty::INNERMOST);
131+
CanonicalParamEnvCacheEntry {
132+
param_env,
133+
variable_lookup_table: env_canonicalizer.variable_lookup_table,
134+
var_kinds: env_canonicalizer.var_kinds,
135+
variables,
136+
}
137+
},
138+
|&CanonicalParamEnvCacheEntry {
139+
param_env,
140+
variables: ref cache_variables,
141+
ref variable_lookup_table,
142+
ref var_kinds,
143+
}| {
144+
debug_assert!(variables.is_empty());
145+
variables.extend(cache_variables.iter().copied());
146+
(param_env, variable_lookup_table.clone(), var_kinds.clone())
147+
},
148+
)
149+
} else {
150+
let mut env_canonicalizer = Canonicalizer {
151+
delegate,
152+
canonicalize_mode: CanonicalizeMode::Input { keep_static: true },
153+
154+
variables,
155+
variable_lookup_table: Default::default(),
156+
var_kinds: Vec::new(),
157+
binder_index: ty::INNERMOST,
158+
159+
cache: Default::default(),
160+
};
161+
let param_env = param_env.fold_with(&mut env_canonicalizer);
162+
debug_assert_eq!(env_canonicalizer.binder_index, ty::INNERMOST);
163+
(param_env, env_canonicalizer.variable_lookup_table, env_canonicalizer.var_kinds)
164+
}
165+
}
166+
89167
/// When canonicalizing query inputs, we keep `'static` in the `param_env`
90168
/// but erase it everywhere else. We generally don't want to depend on region
91169
/// identity, so while it should not matter whether `'static` is kept in the
@@ -100,30 +178,17 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
100178
input: QueryInput<I, P>,
101179
) -> ty::Canonical<I, QueryInput<I, P>> {
102180
// First canonicalize the `param_env` while keeping `'static`
103-
let mut env_canonicalizer = Canonicalizer {
104-
delegate,
105-
canonicalize_mode: CanonicalizeMode::Input { keep_static: true },
106-
107-
variables,
108-
variable_lookup_table: Default::default(),
109-
var_kinds: Vec::new(),
110-
binder_index: ty::INNERMOST,
111-
112-
cache: Default::default(),
113-
};
114-
let param_env = input.goal.param_env.fold_with(&mut env_canonicalizer);
115-
debug_assert_eq!(env_canonicalizer.binder_index, ty::INNERMOST);
181+
let (param_env, variable_lookup_table, var_kinds) =
182+
Canonicalizer::canonicalize_param_env(delegate, variables, input.goal.param_env);
116183
// Then canonicalize the rest of the input without keeping `'static`
117184
// while *mostly* reusing the canonicalizer from above.
118185
let mut rest_canonicalizer = Canonicalizer {
119186
delegate,
120187
canonicalize_mode: CanonicalizeMode::Input { keep_static: false },
121188

122-
variables: env_canonicalizer.variables,
123-
// We're able to reuse the `variable_lookup_table` as whether or not
124-
// it already contains an entry for `'static` does not matter.
125-
variable_lookup_table: env_canonicalizer.variable_lookup_table,
126-
var_kinds: env_canonicalizer.var_kinds,
189+
variables,
190+
variable_lookup_table,
191+
var_kinds,
127192
binder_index: ty::INNERMOST,
128193

129194
// We do not reuse the cache as it may contain entries whose canonicalized
@@ -134,10 +199,22 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
134199
cache: Default::default(),
135200
};
136201

137-
let predicate = input.goal.predicate.fold_with(&mut rest_canonicalizer);
202+
let predicate = input.goal.predicate;
203+
let predicate = if predicate.has_type_flags(NEEDS_CANONICAL) {
204+
predicate.fold_with(&mut rest_canonicalizer)
205+
} else {
206+
predicate
207+
};
138208
let goal = Goal { param_env, predicate };
209+
210+
let predefined_opaques_in_body = input.predefined_opaques_in_body;
139211
let predefined_opaques_in_body =
140-
input.predefined_opaques_in_body.fold_with(&mut rest_canonicalizer);
212+
if input.predefined_opaques_in_body.has_type_flags(NEEDS_CANONICAL) {
213+
predefined_opaques_in_body.fold_with(&mut rest_canonicalizer)
214+
} else {
215+
predefined_opaques_in_body
216+
};
217+
141218
let value = QueryInput { goal, predefined_opaques_in_body };
142219

143220
assert!(!value.has_infer(), "unexpected infer in {value:?}");
@@ -387,7 +464,11 @@ impl<'a, D: SolverDelegate<Interner = I>, I: Interner> Canonicalizer<'a, D, I> {
387464
| ty::Alias(_, _)
388465
| ty::Bound(_, _)
389466
| ty::Error(_) => {
390-
return ensure_sufficient_stack(|| t.super_fold_with(self));
467+
return if t.has_type_flags(NEEDS_CANONICAL) {
468+
ensure_sufficient_stack(|| t.super_fold_with(self))
469+
} else {
470+
t
471+
};
391472
}
392473
};
393474

@@ -522,11 +603,28 @@ impl<D: SolverDelegate<Interner = I>, I: Interner> TypeFolder<I> for Canonicaliz
522603
| ty::ConstKind::Unevaluated(_)
523604
| ty::ConstKind::Value(_)
524605
| ty::ConstKind::Error(_)
525-
| ty::ConstKind::Expr(_) => return c.super_fold_with(self),
606+
| ty::ConstKind::Expr(_) => {
607+
return if c.has_type_flags(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c };
608+
}
526609
};
527610

528611
let var = self.get_or_insert_bound_var(c, kind);
529612

530613
Const::new_anon_bound(self.cx(), self.binder_index, var)
531614
}
615+
616+
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
617+
if p.flags().intersects(NEEDS_CANONICAL) { p.super_fold_with(self) } else { p }
618+
}
619+
620+
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
621+
match self.canonicalize_mode {
622+
CanonicalizeMode::Input { keep_static: true }
623+
| CanonicalizeMode::Response { max_input_universe: _ } => {}
624+
CanonicalizeMode::Input { keep_static: false } => {
625+
panic!("erasing 'static in env")
626+
}
627+
}
628+
if c.flags().intersects(NEEDS_CANONICAL) { c.super_fold_with(self) } else { c }
629+
}
532630
}

0 commit comments

Comments
 (0)