Skip to content

Use an IndexVec to cache queries with index-like key #103808

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion compiler/rustc_hir/src/hir_id.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::def_id::{DefId, LocalDefId, CRATE_DEF_ID};
use crate::def_id::{DefId, DefIndex, LocalDefId, CRATE_DEF_ID};
use rustc_data_structures::stable_hasher::{HashStable, StableHasher, ToStableHashKey};
use rustc_span::{def_id::DefPathHash, HashStableContext};
use std::fmt;
Expand All @@ -22,6 +22,18 @@ impl OwnerId {
}
}

impl rustc_index::vec::Idx for OwnerId {
#[inline]
fn new(idx: usize) -> Self {
OwnerId { def_id: LocalDefId { local_def_index: DefIndex::from_usize(idx) } }
}

#[inline]
fn index(self) -> usize {
self.def_id.local_def_index.as_usize()
}
}

impl<CTX: HashStableContext> HashStable<CTX> for OwnerId {
#[inline]
fn hash_stable(&self, hcx: &mut CTX, hasher: &mut StableHasher) {
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_index/src/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ pub trait Idx: Copy + 'static + Eq + PartialEq + Debug + Hash {

fn index(self) -> usize;

#[inline]
fn increment_by(&mut self, amount: usize) {
*self = self.plus(amount);
}

#[inline]
fn plus(self, amount: usize) -> Self {
Self::new(self.index() + amount)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
//! Defines the set of legal keys that can be used in queries.

use crate::infer::canonical::Canonical;
use crate::mir;
use crate::traits;
use crate::ty::fast_reject::SimplifiedType;
use crate::ty::subst::{GenericArg, SubstsRef};
use crate::ty::{self, layout::TyAndLayout, Ty, TyCtxt};
use rustc_hir::def_id::{CrateNum, DefId, LocalDefId, LOCAL_CRATE};
use rustc_hir::hir_id::{HirId, OwnerId};
use rustc_middle::infer::canonical::Canonical;
use rustc_middle::mir;
use rustc_middle::traits;
use rustc_middle::ty::fast_reject::SimplifiedType;
use rustc_middle::ty::subst::{GenericArg, SubstsRef};
use rustc_middle::ty::{self, layout::TyAndLayout, Ty, TyCtxt};
use rustc_query_system::query::{DefaultCacheSelector, VecCacheSelector};
use rustc_span::symbol::{Ident, Symbol};
use rustc_span::{Span, DUMMY_SP};

/// The `Key` trait controls what types can legally be used as the key
/// for a query.
pub trait Key {
pub trait Key: Sized {
type CacheSelector = DefaultCacheSelector<Self>;

/// Given an instance of this key, what crate is it referring to?
/// This is used to find the provider.
fn query_crate_is_local(&self) -> bool;
Expand Down Expand Up @@ -100,6 +103,8 @@ impl<'tcx> Key for mir::interpret::LitToConstInput<'tcx> {
}

impl Key for CrateNum {
type CacheSelector = VecCacheSelector<Self>;

#[inline(always)]
fn query_crate_is_local(&self) -> bool {
*self == LOCAL_CRATE
Expand All @@ -110,6 +115,8 @@ impl Key for CrateNum {
}

impl Key for OwnerId {
type CacheSelector = VecCacheSelector<Self>;

#[inline(always)]
fn query_crate_is_local(&self) -> bool {
true
Expand All @@ -123,6 +130,8 @@ impl Key for OwnerId {
}

impl Key for LocalDefId {
type CacheSelector = VecCacheSelector<Self>;

#[inline(always)]
fn query_crate_is_local(&self) -> bool {
true
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
use crate::ty::{self, print::describe_as_module, TyCtxt};
use rustc_span::def_id::LOCAL_CRATE;

mod keys;
pub use keys::Key;

// Each of these queries corresponds to a function pointer field in the
// `Providers` struct for requesting a value of that type, and a method
// on `tcx: TyCtxt` (and `tcx.at(span)`) for doing that request in a way
Expand Down
5 changes: 3 additions & 2 deletions compiler/rustc_middle/src/ty/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::mir::interpret::{
};
use crate::mir::interpret::{LitToConstError, LitToConstInput};
use crate::mir::mono::CodegenUnit;
use crate::query::Key;
use crate::thir;
use crate::traits::query::{
CanonicalPredicateGoal, CanonicalProjectionGoal, CanonicalTyGoal,
Expand Down Expand Up @@ -121,10 +122,10 @@ macro_rules! query_helper_param_ty {

macro_rules! query_storage {
([][$K:ty, $V:ty]) => {
<DefaultCacheSelector as CacheSelector<$K, $V>>::Cache
<<$K as Key>::CacheSelector as CacheSelector<'tcx, $V>>::Cache
};
([(arena_cache) $($rest:tt)*][$K:ty, $V:ty]) => {
<ArenaCacheSelector<'tcx> as CacheSelector<$K, $V>>::Cache
<<$K as Key>::CacheSelector as CacheSelector<'tcx, $V>>::ArenaCache
};
([$other:tt $($modifiers:tt)*][$($args:tt)*]) => {
query_storage!([$($modifiers)*][$($args)*])
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_query_impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ use rustc_query_system::query::*;
#[cfg(parallel_compiler)]
pub use rustc_query_system::query::{deadlock, QueryContext};

mod keys;
use keys::Key;
use rustc_middle::query::Key;

pub use rustc_query_system::query::QueryConfig;
pub(crate) use rustc_query_system::query::{QueryDescription, QueryVTable};
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_query_impl/src/plumbing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
//! generate the actual methods on tcx which find and execute the provider,
//! manage the caches, and so forth.

use crate::keys::Key;
use crate::on_disk_cache::{CacheDecoder, CacheEncoder, EncodedDepNodeIndex};
use crate::profiling_support::QueryKeyStringCache;
use crate::{on_disk_cache, Queries};
Expand All @@ -12,6 +11,7 @@ use rustc_errors::{Diagnostic, Handler};
use rustc_middle::dep_graph::{
self, DepKind, DepKindStruct, DepNode, DepNodeIndex, SerializedDepNodeIndex,
};
use rustc_middle::query::Key;
use rustc_middle::ty::tls::{self, ImplicitCtxt};
use rustc_middle::ty::{self, TyCtxt};
use rustc_query_system::dep_graph::{DepNodeParams, HasDepContext};
Expand Down
203 changes: 192 additions & 11 deletions compiler/rustc_query_system/src/query/caches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@ use rustc_data_structures::sharded::Sharded;
#[cfg(not(parallel_compiler))]
use rustc_data_structures::sync::Lock;
use rustc_data_structures::sync::WorkerLocal;
use rustc_index::vec::{Idx, IndexVec};
use std::default::Default;
use std::fmt::Debug;
use std::hash::Hash;
use std::marker::PhantomData;

pub trait CacheSelector<K, V> {
type Cache;
pub trait CacheSelector<'tcx, V> {
type Cache
where
V: Clone;
type ArenaCache;
}

pub trait QueryStorage {
Expand Down Expand Up @@ -47,10 +51,13 @@ pub trait QueryCache: QueryStorage + Sized {
fn iter(&self, f: &mut dyn FnMut(&Self::Key, &Self::Value, DepNodeIndex));
}

pub struct DefaultCacheSelector;
pub struct DefaultCacheSelector<K>(PhantomData<K>);

impl<K: Eq + Hash, V: Clone> CacheSelector<K, V> for DefaultCacheSelector {
type Cache = DefaultCache<K, V>;
impl<'tcx, K: Eq + Hash, V: 'tcx> CacheSelector<'tcx, V> for DefaultCacheSelector<K> {
type Cache = DefaultCache<K, V>
where
V: Clone;
type ArenaCache = ArenaCache<'tcx, K, V>;
}

pub struct DefaultCache<K, V> {
Expand Down Expand Up @@ -134,12 +141,6 @@ where
}
}

pub struct ArenaCacheSelector<'tcx>(PhantomData<&'tcx ()>);

impl<'tcx, K: Eq + Hash, V: 'tcx> CacheSelector<K, V> for ArenaCacheSelector<'tcx> {
type Cache = ArenaCache<'tcx, K, V>;
}

pub struct ArenaCache<'tcx, K, V> {
arena: WorkerLocal<TypedArena<(V, DepNodeIndex)>>,
#[cfg(parallel_compiler)]
Expand Down Expand Up @@ -224,3 +225,183 @@ where
}
}
}

pub struct VecCacheSelector<K>(PhantomData<K>);

impl<'tcx, K: Idx, V: 'tcx> CacheSelector<'tcx, V> for VecCacheSelector<K> {
type Cache = VecCache<K, V>
where
V: Clone;
type ArenaCache = VecArenaCache<'tcx, K, V>;
}

pub struct VecCache<K: Idx, V> {
#[cfg(parallel_compiler)]
cache: Sharded<IndexVec<K, Option<(V, DepNodeIndex)>>>,
#[cfg(not(parallel_compiler))]
cache: Lock<IndexVec<K, Option<(V, DepNodeIndex)>>>,
}

impl<K: Idx, V> Default for VecCache<K, V> {
fn default() -> Self {
VecCache { cache: Default::default() }
}
}

impl<K: Eq + Idx, V: Clone + Debug> QueryStorage for VecCache<K, V> {
type Value = V;
type Stored = V;

#[inline]
fn store_nocache(&self, value: Self::Value) -> Self::Stored {
// We have no dedicated storage
value
}
}

impl<K, V> QueryCache for VecCache<K, V>
where
K: Eq + Idx + Clone + Debug,
V: Clone + Debug,
{
type Key = K;

#[inline(always)]
fn lookup<R, OnHit>(&self, key: &K, on_hit: OnHit) -> Result<R, ()>
where
OnHit: FnOnce(&V, DepNodeIndex) -> R,
{
#[cfg(parallel_compiler)]
let lock = self.cache.get_shard_by_hash(key.index() as u64).lock();
#[cfg(not(parallel_compiler))]
let lock = self.cache.lock();
if let Some(Some(value)) = lock.get(*key) {
let hit_result = on_hit(&value.0, value.1);
Ok(hit_result)
} else {
Err(())
}
}

#[inline]
fn complete(&self, key: K, value: V, index: DepNodeIndex) -> Self::Stored {
#[cfg(parallel_compiler)]
let mut lock = self.cache.get_shard_by_hash(key.index() as u64).lock();
#[cfg(not(parallel_compiler))]
let mut lock = self.cache.lock();
lock.insert(key, (value.clone(), index));
value
}

fn iter(&self, f: &mut dyn FnMut(&Self::Key, &Self::Value, DepNodeIndex)) {
#[cfg(parallel_compiler)]
{
let shards = self.cache.lock_shards();
for shard in shards.iter() {
for (k, v) in shard.iter_enumerated() {
if let Some(v) = v {
f(&k, &v.0, v.1);
}
}
}
}
#[cfg(not(parallel_compiler))]
{
let map = self.cache.lock();
for (k, v) in map.iter_enumerated() {
if let Some(v) = v {
f(&k, &v.0, v.1);
}
}
}
}
}

pub struct VecArenaCache<'tcx, K: Idx, V> {
arena: WorkerLocal<TypedArena<(V, DepNodeIndex)>>,
#[cfg(parallel_compiler)]
cache: Sharded<IndexVec<K, Option<&'tcx (V, DepNodeIndex)>>>,
#[cfg(not(parallel_compiler))]
cache: Lock<IndexVec<K, Option<&'tcx (V, DepNodeIndex)>>>,
}

impl<'tcx, K: Idx, V> Default for VecArenaCache<'tcx, K, V> {
fn default() -> Self {
VecArenaCache {
arena: WorkerLocal::new(|_| TypedArena::default()),
cache: Default::default(),
}
}
}

impl<'tcx, K: Eq + Idx, V: Debug + 'tcx> QueryStorage for VecArenaCache<'tcx, K, V> {
type Value = V;
type Stored = &'tcx V;

#[inline]
fn store_nocache(&self, value: Self::Value) -> Self::Stored {
let value = self.arena.alloc((value, DepNodeIndex::INVALID));
let value = unsafe { &*(&value.0 as *const _) };
&value
}
}

impl<'tcx, K, V: 'tcx> QueryCache for VecArenaCache<'tcx, K, V>
where
K: Eq + Idx + Clone + Debug,
V: Debug,
{
type Key = K;

#[inline(always)]
fn lookup<R, OnHit>(&self, key: &K, on_hit: OnHit) -> Result<R, ()>
where
OnHit: FnOnce(&&'tcx V, DepNodeIndex) -> R,
{
#[cfg(parallel_compiler)]
let lock = self.cache.get_shard_by_hash(key.index() as u64).lock();
#[cfg(not(parallel_compiler))]
let lock = self.cache.lock();
if let Some(Some(value)) = lock.get(*key) {
let hit_result = on_hit(&&value.0, value.1);
Ok(hit_result)
} else {
Err(())
}
}

#[inline]
fn complete(&self, key: K, value: V, index: DepNodeIndex) -> Self::Stored {
let value = self.arena.alloc((value, index));
let value = unsafe { &*(value as *const _) };
#[cfg(parallel_compiler)]
let mut lock = self.cache.get_shard_by_hash(key.index() as u64).lock();
#[cfg(not(parallel_compiler))]
let mut lock = self.cache.lock();
lock.insert(key, value);
&value.0
}

fn iter(&self, f: &mut dyn FnMut(&Self::Key, &Self::Value, DepNodeIndex)) {
#[cfg(parallel_compiler)]
{
let shards = self.cache.lock_shards();
for shard in shards.iter() {
for (k, v) in shard.iter_enumerated() {
if let Some(v) = v {
f(&k, &v.0, v.1);
}
}
}
}
#[cfg(not(parallel_compiler))]
{
let map = self.cache.lock();
for (k, v) in map.iter_enumerated() {
if let Some(v) = v {
f(&k, &v.0, v.1);
}
}
}
}
}
2 changes: 1 addition & 1 deletion compiler/rustc_query_system/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub use self::job::{print_query_stack, QueryInfo, QueryJob, QueryJobId, QueryJob

mod caches;
pub use self::caches::{
ArenaCacheSelector, CacheSelector, DefaultCacheSelector, QueryCache, QueryStorage,
CacheSelector, DefaultCacheSelector, QueryCache, QueryStorage, VecCacheSelector,
};

mod config;
Expand Down