Skip to content

Commit fad3636

Browse files
committed
build a search tree during trait solving
1 parent ecd65f8 commit fad3636

File tree

4 files changed

+198
-16
lines changed

4 files changed

+198
-16
lines changed

compiler/rustc_type_ir/src/search_graph/global_cache.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,14 @@ impl<X: Cx> GlobalCache<X> {
4747
evaluation_result: EvaluationResult<X>,
4848
dep_node: X::DepNodeIndex,
4949
) {
50-
let EvaluationResult { encountered_overflow, required_depth, heads, nested_goals, result } =
51-
evaluation_result;
50+
let EvaluationResult {
51+
node_id: _,
52+
encountered_overflow,
53+
required_depth,
54+
heads,
55+
nested_goals,
56+
result,
57+
} = evaluation_result;
5258
debug_assert!(heads.is_empty());
5359
let result = cx.mk_tracked(result, dep_node);
5460
let entry = self.map.entry(input).or_default();

compiler/rustc_type_ir/src/search_graph/mod.rs

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ use crate::data_structures::HashMap;
2828
mod stack;
2929
use stack::{Stack, StackDepth, StackEntry};
3030
mod global_cache;
31+
mod tree;
3132
use global_cache::CacheData;
3233
pub use global_cache::GlobalCache;
34+
use tree::SearchTree;
3335

3436
/// The search graph does not simply use `Interner` directly
3537
/// to enable its fuzzing without having to stub the rest of
@@ -436,6 +438,7 @@ impl<X: Cx> NestedGoals<X> {
436438
/// goals still on the stack.
437439
#[derive_where(Debug; X: Cx)]
438440
struct ProvisionalCacheEntry<X: Cx> {
441+
entry_node_id: tree::NodeId,
439442
/// Whether evaluating the goal encountered overflow. This is used to
440443
/// disable the cache entry except if the last goal on the stack is
441444
/// already involved in this cycle.
@@ -459,6 +462,7 @@ struct ProvisionalCacheEntry<X: Cx> {
459462
/// evaluation.
460463
#[derive_where(Debug; X: Cx)]
461464
struct EvaluationResult<X: Cx> {
465+
node_id: tree::NodeId,
462466
encountered_overflow: bool,
463467
required_depth: usize,
464468
heads: CycleHeads,
@@ -479,7 +483,8 @@ impl<X: Cx> EvaluationResult<X> {
479483
required_depth: final_entry.required_depth,
480484
heads: final_entry.heads,
481485
nested_goals: final_entry.nested_goals,
482-
// We only care about the final result.
486+
// We only care about the result and the `node_id` of the final iteration.
487+
node_id: final_entry.node_id,
483488
result,
484489
}
485490
}
@@ -497,6 +502,8 @@ pub struct SearchGraph<D: Delegate<Cx = X>, X: Cx = <D as Delegate>::Cx> {
497502
/// is only valid until the result of one of its cycle heads changes.
498503
provisional_cache: HashMap<X::Input, Vec<ProvisionalCacheEntry<X>>>,
499504

505+
tree: SearchTree<X>,
506+
500507
_marker: PhantomData<D>,
501508
}
502509

@@ -520,6 +527,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
520527
root_depth: AvailableDepth(root_depth),
521528
stack: Default::default(),
522529
provisional_cache: Default::default(),
530+
tree: Default::default(),
523531
_marker: PhantomData,
524532
}
525533
}
@@ -605,6 +613,9 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
605613
return self.handle_overflow(cx, input, inspect);
606614
};
607615

616+
let node_id =
617+
self.tree.create_node(&self.stack, input, step_kind_from_parent, available_depth);
618+
608619
// We check the provisional cache before checking the global cache. This simplifies
609620
// the implementation as we can avoid worrying about cases where both the global and
610621
// provisional cache may apply, e.g. consider the following example
@@ -613,7 +624,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
613624
// - A
614625
// - BA cycle
615626
// - CB :x:
616-
if let Some(result) = self.lookup_provisional_cache(input, step_kind_from_parent) {
627+
if let Some(result) = self.lookup_provisional_cache(node_id, input, step_kind_from_parent) {
617628
return result;
618629
}
619630

@@ -630,7 +641,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
630641
.inspect(|expected| debug!(?expected, "validate cache entry"))
631642
.map(|r| (scope, r))
632643
} else if let Some(result) =
633-
self.lookup_global_cache(cx, input, step_kind_from_parent, available_depth)
644+
self.lookup_global_cache(cx, node_id, input, step_kind_from_parent, available_depth)
634645
{
635646
return result;
636647
} else {
@@ -641,13 +652,14 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
641652
// avoid iterating over the stack in case a goal has already been computed.
642653
// This may not have an actual performance impact and we could reorder them
643654
// as it may reduce the number of `nested_goals` we need to track.
644-
if let Some(result) = self.check_cycle_on_stack(cx, input, step_kind_from_parent) {
655+
if let Some(result) = self.check_cycle_on_stack(cx, node_id, input, step_kind_from_parent) {
645656
debug_assert!(validate_cache.is_none(), "global cache and cycle on stack: {input:?}");
646657
return result;
647658
}
648659

649660
// Unfortunate, it looks like we actually have to compute this goal.
650661
self.stack.push(StackEntry {
662+
node_id,
651663
input,
652664
step_kind_from_parent,
653665
available_depth,
@@ -694,6 +706,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
694706
debug_assert!(validate_cache.is_none(), "unexpected non-root: {input:?}");
695707
let entry = self.provisional_cache.entry(input).or_default();
696708
let EvaluationResult {
709+
node_id,
697710
encountered_overflow,
698711
required_depth: _,
699712
heads,
@@ -705,8 +718,13 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
705718
step_kind_from_parent,
706719
heads.highest_cycle_head(),
707720
);
708-
let provisional_cache_entry =
709-
ProvisionalCacheEntry { encountered_overflow, heads, path_from_head, result };
721+
let provisional_cache_entry = ProvisionalCacheEntry {
722+
entry_node_id: node_id,
723+
encountered_overflow,
724+
heads,
725+
path_from_head,
726+
result,
727+
};
710728
debug!(?provisional_cache_entry);
711729
entry.push(provisional_cache_entry);
712730
} else {
@@ -780,6 +798,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
780798
self.provisional_cache.retain(|&input, entries| {
781799
entries.retain_mut(|entry| {
782800
let ProvisionalCacheEntry {
801+
entry_node_id: _,
783802
encountered_overflow: _,
784803
heads,
785804
path_from_head,
@@ -831,6 +850,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
831850

832851
fn lookup_provisional_cache(
833852
&mut self,
853+
node_id: tree::NodeId,
834854
input: X::Input,
835855
step_kind_from_parent: PathKind,
836856
) -> Option<X::Result> {
@@ -839,8 +859,13 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
839859
}
840860

841861
let entries = self.provisional_cache.get(&input)?;
842-
for &ProvisionalCacheEntry { encountered_overflow, ref heads, path_from_head, result } in
843-
entries
862+
for &ProvisionalCacheEntry {
863+
entry_node_id,
864+
encountered_overflow,
865+
ref heads,
866+
path_from_head,
867+
result,
868+
} in entries
844869
{
845870
let head = heads.highest_cycle_head();
846871
if encountered_overflow {
@@ -872,6 +897,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
872897
);
873898
debug_assert!(self.stack[head].has_been_used.is_some());
874899
debug!(?head, ?path_from_head, "provisional cache hit");
900+
self.tree.provisional_cache_hit(node_id, entry_node_id);
875901
return Some(result);
876902
}
877903
}
@@ -912,6 +938,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
912938
// A provisional cache entry is applicable if the path to
913939
// its highest cycle head is equal to the expected path.
914940
for &ProvisionalCacheEntry {
941+
entry_node_id: _,
915942
encountered_overflow,
916943
ref heads,
917944
path_from_head: head_to_provisional,
@@ -970,6 +997,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
970997
fn lookup_global_cache(
971998
&mut self,
972999
cx: X,
1000+
node_id: tree::NodeId,
9731001
input: X::Input,
9741002
step_kind_from_parent: PathKind,
9751003
available_depth: AvailableDepth,
@@ -993,13 +1021,15 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
9931021
);
9941022

9951023
debug!(?required_depth, "global cache hit");
1024+
self.tree.global_cache_hit(node_id);
9961025
Some(result)
9971026
})
9981027
}
9991028

10001029
fn check_cycle_on_stack(
10011030
&mut self,
10021031
cx: X,
1032+
node_id: tree::NodeId,
10031033
input: X::Input,
10041034
step_kind_from_parent: PathKind,
10051035
) -> Option<X::Result> {
@@ -1030,11 +1060,11 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
10301060

10311061
// Return the provisional result or, if we're in the first iteration,
10321062
// start with no constraints.
1033-
if let Some(result) = self.stack[head].provisional_result {
1034-
Some(result)
1035-
} else {
1036-
Some(D::initial_provisional_result(cx, path_kind, input))
1037-
}
1063+
let result = self.stack[head]
1064+
.provisional_result
1065+
.unwrap_or_else(|| D::initial_provisional_result(cx, path_kind, input));
1066+
self.tree.cycle_on_stack(node_id, self.stack[head].node_id, result);
1067+
Some(result)
10381068
}
10391069

10401070
/// Whether we've reached a fixpoint when evaluating a cycle head.
@@ -1077,6 +1107,15 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
10771107
let stack_entry = self.stack.pop();
10781108
encountered_overflow |= stack_entry.encountered_overflow;
10791109
debug_assert_eq!(stack_entry.input, input);
1110+
// FIXME: Cloning the cycle heads here is quite ass. We should make cycle heads
1111+
// CoW and use reference counting.
1112+
self.tree.finish_evaluate(
1113+
stack_entry.node_id,
1114+
stack_entry.provisional_result,
1115+
stack_entry.encountered_overflow,
1116+
stack_entry.heads.clone(),
1117+
result,
1118+
);
10801119

10811120
// If the current goal is not the root of a cycle, we are done.
10821121
//
@@ -1137,7 +1176,14 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
11371176
self.clear_dependent_provisional_results();
11381177

11391178
debug!(?result, "fixpoint changed provisional results");
1179+
let node_id = self.tree.create_node(
1180+
&self.stack,
1181+
stack_entry.input,
1182+
stack_entry.step_kind_from_parent,
1183+
stack_entry.available_depth,
1184+
);
11401185
self.stack.push(StackEntry {
1186+
node_id,
11411187
input,
11421188
step_kind_from_parent: stack_entry.step_kind_from_parent,
11431189
available_depth: stack_entry.available_depth,

compiler/rustc_type_ir/src/search_graph/stack.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::ops::{Index, IndexMut};
33
use derive_where::derive_where;
44
use rustc_index::IndexVec;
55

6-
use super::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind};
6+
use crate::search_graph::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind, tree};
77

88
rustc_index::newtype_index! {
99
#[orderable]
@@ -15,6 +15,8 @@ rustc_index::newtype_index! {
1515
/// when popping a child goal or completely immutable.
1616
#[derive_where(Debug; X: Cx)]
1717
pub(super) struct StackEntry<X: Cx> {
18+
pub node_id: tree::NodeId,
19+
1820
pub input: X::Input,
1921

2022
/// Whether proving this goal is a coinductive step.

0 commit comments

Comments
 (0)