search graph: improve rebasing and add forced ambiguity support #143054

Open · wants to merge 1 commit into base: master
4 changes: 3 additions & 1 deletion compiler/rustc_next_trait_solver/src/solve/search_graph.rs
@@ -48,7 +48,9 @@ where
) -> QueryResult<I> {
match kind {
PathKind::Coinductive => response_no_constraints(cx, input, Certainty::Yes),
PathKind::Unknown => response_no_constraints(cx, input, Certainty::overflow(false)),
PathKind::Unknown | PathKind::ForcedAmbiguity => {
response_no_constraints(cx, input, Certainty::overflow(false))
}
// Even though we know these cycles to be unproductive, we still return
// overflow during coherence. This is both because we are not yet 100%
// confident in the implementation and because any incorrect errors would
// be unsound there.
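As a rough sketch (stand-in types and a hypothetical helper name, not the actual rustc definitions; the inductive arm lies outside the shown hunk and is inferred from the comment above), the mapping from a cycle's path kind to its initial provisional result looks like this:

// Stand-ins for the solver types; the real QueryResult and Certainty
// carry constraints and more.
enum PathKind { Inductive, Unknown, Coinductive, ForcedAmbiguity }
enum Certainty { Yes, Overflow }

fn initial_cycle_certainty(kind: PathKind) -> Certainty {
    match kind {
        // Coinductive cycles hold outright.
        PathKind::Coinductive => Certainty::Yes,
        // Unknown and forced-ambiguity cycles yield ambiguity via overflow.
        PathKind::Unknown | PathKind::ForcedAmbiguity => Certainty::Overflow,
        // Unproductive (inductive) cycles are still reported as overflow
        // during coherence, per the comment above.
        PathKind::Inductive => Certainty::Overflow,
    }
}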
166 changes: 83 additions & 83 deletions compiler/rustc_type_ir/src/search_graph/mod.rs
@@ -21,10 +21,9 @@ use std::marker::PhantomData;
use derive_where::derive_where;
#[cfg(feature = "nightly")]
use rustc_macros::{Decodable_NoContext, Encodable_NoContext, HashStable_NoContext};
use rustc_type_ir::data_structures::HashMap;
use tracing::{debug, instrument};

use crate::data_structures::HashMap;

mod stack;
use stack::{Stack, StackDepth, StackEntry};
mod global_cache;
@@ -137,6 +136,12 @@ pub enum PathKind {
Unknown,
/// A path with at least one coinductive step. Such cycles hold.
Coinductive,
/// A path which is treated as ambiguous. Once a path has this kind,
/// extending it with any further segment does not change that.
///
/// This is currently only used when fuzzing to support negative reasoning.
/// For more details, see #143054.
ForcedAmbiguity,
}

impl PathKind {
@@ -149,6 +154,9 @@ impl PathKind {
/// to `max(self, rest)`.
fn extend(self, rest: PathKind) -> PathKind {
match (self, rest) {
(PathKind::ForcedAmbiguity, _) | (_, PathKind::ForcedAmbiguity) => {
PathKind::ForcedAmbiguity
}
(PathKind::Coinductive, _) | (_, PathKind::Coinductive) => PathKind::Coinductive,
(PathKind::Unknown, _) | (_, PathKind::Unknown) => PathKind::Unknown,
(PathKind::Inductive, PathKind::Inductive) => PathKind::Inductive,
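Since later variants are stronger, `extend` behaves exactly like `max` under the derived ordering, which is what the doc comment above promises. A self-contained sketch with stand-in types (declaration order matters):

// Stand-in enum; variant order matters: later variants "win" when merging.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum PathKind { Inductive, Unknown, Coinductive, ForcedAmbiguity }

impl PathKind {
    // Equivalent to the match above: merging picks the stronger kind.
    fn extend(self, rest: PathKind) -> PathKind {
        self.max(rest)
    }
}

fn main() {
    assert_eq!(PathKind::Inductive.extend(PathKind::Unknown), PathKind::Unknown);
    assert_eq!(PathKind::Unknown.extend(PathKind::Coinductive), PathKind::Coinductive);
    // ForcedAmbiguity is absorbing: once present, it never goes away.
    assert_eq!(
        PathKind::Coinductive.extend(PathKind::ForcedAmbiguity),
        PathKind::ForcedAmbiguity,
    );
}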
@@ -187,41 +195,6 @@ impl UsageKind {
}
}

/// For each goal we track whether the paths from this goal
/// to its cycle heads are coinductive.
///
/// This is a necessary condition to rebase provisional cache
/// entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AllPathsToHeadCoinductive {
Yes,
No,
}
impl From<PathKind> for AllPathsToHeadCoinductive {
fn from(path: PathKind) -> AllPathsToHeadCoinductive {
match path {
PathKind::Coinductive => AllPathsToHeadCoinductive::Yes,
_ => AllPathsToHeadCoinductive::No,
}
}
}
impl AllPathsToHeadCoinductive {
#[must_use]
fn merge(self, other: impl Into<Self>) -> Self {
match (self, other.into()) {
(AllPathsToHeadCoinductive::Yes, AllPathsToHeadCoinductive::Yes) => {
AllPathsToHeadCoinductive::Yes
}
(AllPathsToHeadCoinductive::No, _) | (_, AllPathsToHeadCoinductive::No) => {
AllPathsToHeadCoinductive::No
}
}
}
fn and_merge(&mut self, other: impl Into<Self>) {
*self = self.merge(other);
}
}
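The deleted type boiled each head down to a single bit, with `merge` acting as a boolean AND; its replacement, the `PathsToNested` bitset below, keeps every observed path kind per head, so merging becomes a union. A toy contrast (plain `bool` and `u8` bitsets as stand-ins):

fn main() {
    // Old: one bit per head ("are all paths coinductive?"); merge = AND.
    let (all_coinductive_a, all_coinductive_b) = (true, false);
    assert_eq!(all_coinductive_a && all_coinductive_b, false); // which kinds existed is lost

    // New: a set of path kinds per head; merge = union, nothing is lost.
    const INDUCTIVE: u8 = 1 << 1;
    const COINDUCTIVE: u8 = 1 << 3;
    let (paths_a, paths_b) = (COINDUCTIVE, INDUCTIVE);
    assert_eq!(paths_a | paths_b, INDUCTIVE | COINDUCTIVE);
}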

#[derive(Debug, Clone, Copy)]
struct AvailableDepth(usize);
impl AvailableDepth {
Expand Down Expand Up @@ -261,9 +234,9 @@ impl AvailableDepth {
///
/// We also track all paths from this goal to that head. This is necessary
/// when rebasing provisional cache results.
#[derive(Clone, Debug, PartialEq, Eq, Default)]
#[derive(Clone, Debug, Default)]
struct CycleHeads {
heads: BTreeMap<StackDepth, AllPathsToHeadCoinductive>,
heads: BTreeMap<StackDepth, PathsToNested>,
}

impl CycleHeads {
@@ -283,27 +256,16 @@ impl CycleHeads {
self.heads.first_key_value().map(|(k, _)| *k)
}

fn remove_highest_cycle_head(&mut self) {
fn remove_highest_cycle_head(&mut self) -> PathsToNested {
let last = self.heads.pop_last();
debug_assert_ne!(last, None);
last.unwrap().1
}

fn insert(
&mut self,
head: StackDepth,
path_from_entry: impl Into<AllPathsToHeadCoinductive> + Copy,
) {
self.heads.entry(head).or_insert(path_from_entry.into()).and_merge(path_from_entry);
fn insert(&mut self, head: StackDepth, path_from_entry: impl Into<PathsToNested> + Copy) {
*self.heads.entry(head).or_insert(path_from_entry.into()) |= path_from_entry.into();
}
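`heads` is a `BTreeMap` keyed by stack depth, so the first and last entries are the lowest and highest cycle heads, and repeated inserts union the path sets. A small stand-alone illustration (`u32` and a `u8` bitset as stand-ins for `StackDepth` and `PathsToNested`):

use std::collections::BTreeMap;

fn main() {
    const INDUCTIVE: u8 = 1 << 1;
    const COINDUCTIVE: u8 = 1 << 3;
    let mut heads: BTreeMap<u32, u8> = BTreeMap::new();

    // First path to head 3 is inductive, a later one is coinductive.
    *heads.entry(3).or_insert(INDUCTIVE) |= INDUCTIVE;
    *heads.entry(3).or_insert(0) |= COINDUCTIVE;
    assert_eq!(heads[&3], INDUCTIVE | COINDUCTIVE);

    // The ordering makes "remove the highest cycle head" a pop_last away.
    heads.insert(1, COINDUCTIVE);
    assert_eq!(heads.pop_last(), Some((3, INDUCTIVE | COINDUCTIVE)));
}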

fn merge(&mut self, heads: &CycleHeads) {
for (&head, &path_from_entry) in heads.heads.iter() {
self.insert(head, path_from_entry);
debug_assert!(matches!(self.heads[&head], AllPathsToHeadCoinductive::Yes));
}
}

fn iter(&self) -> impl Iterator<Item = (StackDepth, AllPathsToHeadCoinductive)> + '_ {
fn iter(&self) -> impl Iterator<Item = (StackDepth, PathsToNested)> + '_ {
self.heads.iter().map(|(k, v)| (*k, *v))
}

@@ -317,13 +279,7 @@ impl CycleHeads {
Ordering::Equal => continue,
Ordering::Greater => unreachable!(),
}

let path_from_entry = match step_kind {
PathKind::Coinductive => AllPathsToHeadCoinductive::Yes,
PathKind::Unknown | PathKind::Inductive => path_from_entry,
};

self.insert(head, path_from_entry);
self.insert(head, path_from_entry.extend_with(step_kind));
}
}
}
@@ -332,13 +288,14 @@ bitflags::bitflags! {
/// Tracks how nested goals have been accessed. This is necessary to disable
/// global cache entries if computing them would otherwise result in a cycle or
/// access a provisional cache entry.
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PathsToNested: u8 {
/// The initial value when adding a goal to its own nested goals.
const EMPTY = 1 << 0;
const INDUCTIVE = 1 << 1;
const UNKNOWN = 1 << 2;
const COINDUCTIVE = 1 << 3;
const FORCED_AMBIGUITY = 1 << 4;
}
}
impl From<PathKind> for PathsToNested {
@@ -347,6 +304,7 @@ impl From<PathKind> for PathsToNested {
PathKind::Inductive => PathsToNested::INDUCTIVE,
PathKind::Unknown => PathsToNested::UNKNOWN,
PathKind::Coinductive => PathsToNested::COINDUCTIVE,
PathKind::ForcedAmbiguity => PathsToNested::FORCED_AMBIGUITY,
}
}
}
@@ -379,10 +337,45 @@ impl PathsToNested {
self.insert(PathsToNested::COINDUCTIVE);
}
}
PathKind::ForcedAmbiguity => {
if self.intersects(
PathsToNested::EMPTY
| PathsToNested::INDUCTIVE
| PathsToNested::UNKNOWN
| PathsToNested::COINDUCTIVE,
) {
self.remove(
PathsToNested::EMPTY
| PathsToNested::INDUCTIVE
| PathsToNested::UNKNOWN
| PathsToNested::COINDUCTIVE,
);
self.insert(PathsToNested::FORCED_AMBIGUITY);
}
}
}

self
}
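A self-contained sketch of the new arm (simplified: only the forced-ambiguity case is spelled out and the other arms are elided), showing that extending any path set with a forced-ambiguity step collapses it to exactly FORCED_AMBIGUITY:

use bitflags::bitflags; // the real code uses the bitflags crate as well

#[derive(Debug, Clone, Copy)]
enum PathKind { Inductive, Unknown, Coinductive, ForcedAmbiguity }

bitflags! {
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    struct PathsToNested: u8 {
        const EMPTY = 1 << 0;
        const INDUCTIVE = 1 << 1;
        const UNKNOWN = 1 << 2;
        const COINDUCTIVE = 1 << 3;
        const FORCED_AMBIGUITY = 1 << 4;
    }
}

impl PathsToNested {
    #[must_use]
    fn extend_with(mut self, path: PathKind) -> Self {
        let non_forced = PathsToNested::EMPTY
            | PathsToNested::INDUCTIVE
            | PathsToNested::UNKNOWN
            | PathsToNested::COINDUCTIVE;
        match path {
            // Collapse every non-forced path class into FORCED_AMBIGUITY.
            PathKind::ForcedAmbiguity => {
                if self.intersects(non_forced) {
                    self.remove(non_forced);
                    self.insert(PathsToNested::FORCED_AMBIGUITY);
                }
            }
            // The remaining arms are elided here; see the diff above.
            _ => unimplemented!(),
        }
        self
    }
}

fn main() {
    let paths = PathsToNested::INDUCTIVE | PathsToNested::COINDUCTIVE;
    assert_eq!(
        paths.extend_with(PathKind::ForcedAmbiguity),
        PathsToNested::FORCED_AMBIGUITY,
    );
}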

#[must_use]
fn extend_with_paths(self, path: PathsToNested) -> Self {
let mut new = PathsToNested::empty();
for p in path.iter_paths() {
new |= self.extend_with(p);
}
new
}

fn iter_paths(self) -> impl Iterator<Item = PathKind> {
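// Exhaustiveness guard: adding a variant to `PathKind` makes this
// or-pattern refutable, so this `let` stops compiling until the
// array below is updated to match.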
let (PathKind::Inductive
| PathKind::Unknown
| PathKind::Coinductive
| PathKind::ForcedAmbiguity);
[PathKind::Inductive, PathKind::Unknown, PathKind::Coinductive, PathKind::ForcedAmbiguity]
.into_iter()
.filter(move |&p| self.contains(p.into()))
}
}

/// The nested goals of each stack entry and the path from the
Expand Down Expand Up @@ -693,7 +686,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
if let Some((_scope, expected)) = validate_cache {
// Do not try to move a goal into the cache again if we're testing
// the global cache.
assert_eq!(evaluation_result.result, expected, "input={input:?}");
assert_eq!(expected, evaluation_result.result, "input={input:?}");
} else if D::inspect_is_noop(inspect) {
self.insert_global_cache(cx, input, evaluation_result, dep_node)
}
@@ -782,7 +775,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
stack_entry: &StackEntry<X>,
mut mutate_result: impl FnMut(X::Input, X::Result) -> X::Result,
) {
let head = self.stack.next_index();
let popped_head = self.stack.next_index();
#[allow(rustc::potential_query_instability)]
self.provisional_cache.retain(|&input, entries| {
entries.retain_mut(|entry| {
@@ -792,30 +785,37 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
path_from_head,
result,
} = entry;
if heads.highest_cycle_head() == head {
let ep = if heads.highest_cycle_head() == popped_head {
heads.remove_highest_cycle_head()
} else {
return true;
}

// We only try to rebase if all paths from the cache entry
// to its heads are coinductive. In this case these cycle
// kinds won't change, no matter the goals between these
// heads and the provisional cache entry.
if heads.iter().any(|(_, p)| matches!(p, AllPathsToHeadCoinductive::No)) {
return false;
}
};

// The same for nested goals of the cycle head.
if stack_entry.heads.iter().any(|(_, p)| matches!(p, AllPathsToHeadCoinductive::No))
{
return false;
// We're rebasing an entry `e` over a head `p`. This head
// has a number of its own heads `h` it depends on. We need to
// make sure that the path kind of all paths `hph` remains the
// same after rebasing.
//
// After rebasing, the cycles `hph` will go through `e`. We need
// to make sure that for all possible paths `hep`, the resulting
// path `heph` is equal to `hph`.
for (h, ph) in stack_entry.heads.iter() {
let hp =
Self::cycle_path_kind(&self.stack, stack_entry.step_kind_from_parent, h);
let hph = ph.extend_with(hp);
let he = hp.extend(*path_from_head);
let hep = ep.extend_with(he);
for hep in hep.iter_paths() {
let heph = ph.extend_with(hep);
if hph != heph {
return false;
}
}

let eph = ep.extend_with_paths(ph);
heads.insert(h, eph);
}

// Merge the cycle heads of the provisional cache entry and the
// popped head. If the popped cycle head was a root, discard all
// provisional cache entries which depend on it.
heads.merge(&stack_entry.heads);
let Some(head) = heads.opt_highest_cycle_head() else {
return false;
};
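To make the `hph`/`heph` condition concrete, here is a toy model of the per-head check (heavily simplified: stand-in types, `extend` as `max`, path sets as plain `Vec`s, the EMPTY flag ignored, and `can_rebase_head` a made-up name):

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum PathKind { Inductive, Unknown, Coinductive, ForcedAmbiguity }

// Merging path kinds picks the stronger one, as in PathKind::extend.
fn extend(a: PathKind, b: PathKind) -> PathKind { a.max(b) }

// A path set extended by one kind: extend every member.
fn extend_with(set: &[PathKind], k: PathKind) -> Vec<PathKind> {
    set.iter().map(|&p| extend(p, k)).collect()
}

// The per-head rebasing condition: every cycle through `e` must have the
// same kind as the original cycle h -> p -> h it replaces.
fn can_rebase_head(
    ph: &[PathKind],          // paths p -> h, stored in the popped entry's heads
    hp: PathKind,             // path h -> p, read off the stack
    path_from_head: PathKind, // path p -> e
    ep: &[PathKind],          // paths e -> p, from remove_highest_cycle_head
) -> bool {
    let hph = extend_with(ph, hp);
    let he = extend(hp, path_from_head); // h -> p -> e
    let hep = extend_with(ep, he);       // h -> e -> p
    hep.iter().all(|&hep| extend_with(ph, hep) == hph)
}

fn main() {
    // Purely inductive everywhere: the cycle kind cannot change, rebasing is fine.
    assert!(can_rebase_head(
        &[PathKind::Inductive], PathKind::Inductive,
        PathKind::Inductive, &[PathKind::Inductive],
    ));
    // Routing through `e` adds an unknown step: the inductive cycle
    // h -> p -> h would become unknown, so rebasing must be blocked.
    assert!(!can_rebase_head(
        &[PathKind::Inductive], PathKind::Inductive,
        PathKind::Unknown, &[PathKind::Inductive],
    ));
}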
5 changes: 4 additions & 1 deletion compiler/rustc_type_ir/src/search_graph/stack.rs
@@ -3,7 +3,7 @@ use std::ops::{Index, IndexMut};
use derive_where::derive_where;
use rustc_index::IndexVec;

use super::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind};
use crate::search_graph::{AvailableDepth, Cx, CycleHeads, NestedGoals, PathKind, UsageKind};

rustc_index::newtype_index! {
#[orderable]
@@ -79,6 +79,9 @@ impl<X: Cx> Stack<X> {
}

pub(super) fn push(&mut self, entry: StackEntry<X>) -> StackDepth {
if cfg!(debug_assertions) && self.entries.iter().any(|e| e.input == entry.input) {
panic!("pushing duplicate entry on stack: {entry:?} {:?}", self.entries);
}
self.entries.push(entry)
}
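The duplicate check is gated on `cfg!(debug_assertions)`, a compile-time constant, so release builds can optimize the whole linear scan away. A minimal stand-alone version of the pattern (`u32` entries as stand-ins for `StackEntry`):

// Stand-in for Stack::push: linear duplicate scan only in debug builds.
fn push(entries: &mut Vec<u32>, entry: u32) {
    if cfg!(debug_assertions) && entries.contains(&entry) {
        panic!("pushing duplicate entry on stack: {entry:?} {entries:?}");
    }
    entries.push(entry);
}

fn main() {
    let mut stack = Vec::new();
    push(&mut stack, 1);
    push(&mut stack, 2); // fine
    // push(&mut stack, 1); // would panic in debug builds
}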
