diff --git a/compiler/rustc_transmute/Cargo.toml b/compiler/rustc_transmute/Cargo.toml
index 0250cc0ea0788..246b66d3d0307 100644
--- a/compiler/rustc_transmute/Cargo.toml
+++ b/compiler/rustc_transmute/Cargo.toml
@@ -5,7 +5,6 @@ edition = "2024"
 
 [dependencies]
 # tidy-alphabetical-start
-itertools = "0.12"
 rustc_abi = { path = "../rustc_abi", optional = true }
 rustc_data_structures = { path = "../rustc_data_structures" }
 rustc_hir = { path = "../rustc_hir", optional = true }
@@ -15,6 +14,11 @@ smallvec = "1.8.1"
 tracing = "0.1"
 # tidy-alphabetical-end
 
+[dev-dependencies]
+# tidy-alphabetical-start
+itertools = "0.12"
+# tidy-alphabetical-end
+
 [features]
 rustc = [
     "dep:rustc_abi",
diff --git a/compiler/rustc_transmute/src/layout/dfa.rs b/compiler/rustc_transmute/src/layout/dfa.rs
index d1f58157b696b..05afa28db31a9 100644
--- a/compiler/rustc_transmute/src/layout/dfa.rs
+++ b/compiler/rustc_transmute/src/layout/dfa.rs
@@ -1,5 +1,5 @@
 use std::fmt;
-use std::ops::RangeInclusive;
+use std::iter::Peekable;
 use std::sync::atomic::{AtomicU32, Ordering};
 
 use super::{Byte, Ref, Tree, Uninhabited};
@@ -211,15 +211,15 @@ where
             let b_transitions =
                 b_src.and_then(|b_src| b.transitions.get(&b_src)).unwrap_or(&empty_transitions);
 
-            let byte_transitions =
-                a_transitions.byte_transitions.union(&b_transitions.byte_transitions);
-
-            let byte_transitions = byte_transitions.map_states(|(a_dst, b_dst)| {
-                assert!(a_dst.is_some() || b_dst.is_some());
+            let byte_transitions = a_transitions.byte_transitions.union(
+                &b_transitions.byte_transitions,
+                |a_dst, b_dst| {
+                    assert!(a_dst.is_some() || b_dst.is_some());
 
-                queue.enqueue(a_dst, b_dst);
-                mapped((a_dst, b_dst))
-            });
+                    queue.enqueue(a_dst, b_dst);
+                    mapped((a_dst, b_dst))
+                },
+            );
 
             let ref_transitions =
                 a_transitions.ref_transitions.keys().chain(b_transitions.ref_transitions.keys());
@@ -245,18 +245,6 @@ where
         Self { transitions, start, accept }
     }
 
-    pub(crate) fn states_from(
-        &self,
-        state: State,
-        src_validity: RangeInclusive<u8>,
-    ) -> impl Iterator<Item = (Byte, State)> {
-        self.transitions
-            .get(&state)
-            .map(move |t| t.byte_transitions.states_from(src_validity))
-            .into_iter()
-            .flatten()
-    }
-
     pub(crate) fn get_uninit_edge_dst(&self, state: State) -> Option<State> {
         let transitions = self.transitions.get(&state)?;
         transitions.byte_transitions.get_uninit_edge_dst()
@@ -334,95 +322,31 @@ where
 use edge_set::EdgeSet;
 mod edge_set {
-    use std::cmp;
-
-    use run::*;
-    use smallvec::{SmallVec, smallvec};
+    use smallvec::SmallVec;
 
     use super::*;
 
-    mod run {
-        use std::ops::{Range, RangeInclusive};
-
-        use super::*;
-        use crate::layout::Byte;
-
-        /// A logical set of edges.
-        ///
-        /// A `Run` encodes one edge for every byte value in `start..=end`
-        /// pointing to `dst`.
-        #[derive(Eq, PartialEq, Copy, Clone, Debug)]
-        pub(super) struct Run<S> {
-            // `start` and `end` are both inclusive (ie, closed) bounds, as this
-            // is required in order to be able to store 0..=255. We provide
-            // setters and getters which operate on closed/open ranges, which
-            // are more intuitive and easier for performing offset math.
-            start: u8,
-            end: u8,
-            pub(super) dst: S,
-        }
-
-        impl<S> Run<S> {
-            pub(super) fn new(range: RangeInclusive<u8>, dst: S) -> Self {
-                Self { start: *range.start(), end: *range.end(), dst }
-            }
-
-            pub(super) fn from_inclusive_exclusive(range: Range<u16>, dst: S) -> Self {
-                Self {
-                    start: range.start.try_into().unwrap(),
-                    end: (range.end - 1).try_into().unwrap(),
-                    dst,
-                }
-            }
-
-            pub(super) fn contains(&self, idx: u16) -> bool {
-                idx >= u16::from(self.start) && idx <= u16::from(self.end)
-            }
-
-            pub(super) fn as_inclusive_exclusive(&self) -> (u16, u16) {
-                (u16::from(self.start), u16::from(self.end) + 1)
-            }
-
-            pub(super) fn as_byte(&self) -> Byte {
-                Byte::new(self.start..=self.end)
-            }
-
-            pub(super) fn map_state<SS>(self, f: impl FnOnce(S) -> SS) -> Run<SS> {
-                let Run { start, end, dst } = self;
-                Run { start, end, dst: f(dst) }
-            }
-
-            /// Produces a new `Run` whose lower bound is the greater of
-            /// `self`'s existing lower bound and `lower_bound`.
-            pub(super) fn clamp_lower(self, lower_bound: u8) -> Self {
-                let Run { start, end, dst } = self;
-                Run { start: cmp::max(start, lower_bound), end, dst }
-            }
-        }
-    }
-
-    /// The set of outbound byte edges associated with a DFA node (not including
-    /// reference edges).
+    /// The set of outbound byte edges associated with a DFA node.
     #[derive(Eq, PartialEq, Clone, Debug)]
     pub(super) struct EdgeSet<S = State> {
-        // A sequence of runs stored in ascending order. Since the graph is a
-        // DFA, these must be non-overlapping with one another.
-        runs: SmallVec<[Run<S>; 1]>,
-        // The edge labeled with the uninit byte, if any.
+        // A sequence of byte edges with contiguous byte values and a common
+        // destination is stored as a single run.
         //
-        // FIXME(@joshlf): Make `State` a `NonZero` so that this is NPO'd.
-        uninit: Option<S>,
+        // Runs are non-empty, non-overlapping, and stored in ascending order.
+        runs: SmallVec<[(Byte, S); 1]>,
     }
 
     impl<S> EdgeSet<S> {
-        pub(crate) fn new(byte: Byte, dst: S) -> Self {
-            match byte.range() {
-                Some(range) => Self { runs: smallvec![Run::new(range, dst)], uninit: None },
-                None => Self { runs: SmallVec::new(), uninit: Some(dst) },
+        pub(crate) fn new(range: Byte, dst: S) -> Self {
+            let mut this = Self { runs: SmallVec::new() };
+            if !range.is_empty() {
+                this.runs.push((range, dst));
             }
+            this
         }
 
         pub(crate) fn empty() -> Self {
-            Self { runs: SmallVec::new(), uninit: None }
+            Self { runs: SmallVec::new() }
        }
 
        #[cfg(test)]
@@ -431,43 +355,23 @@ mod edge_set {
         where
             S: Ord,
         {
             edges.sort();
-            Self {
-                runs: edges
-                    .into_iter()
-                    .map(|(byte, state)| Run::new(byte.range().unwrap(), state))
-                    .collect(),
-                uninit: None,
-            }
+            Self { runs: edges.into() }
         }
 
         pub(crate) fn iter(&self) -> impl Iterator<Item = (Byte, S)>
         where
             S: Copy,
         {
-            self.uninit
-                .map(|dst| (Byte::uninit(), dst))
-                .into_iter()
-                .chain(self.runs.iter().map(|run| (run.as_byte(), run.dst)))
-        }
-
-        pub(crate) fn states_from(
-            &self,
-            byte: RangeInclusive<u8>,
-        ) -> impl Iterator<Item = (Byte, S)>
-        where
-            S: Copy,
-        {
-            // FIXME(@joshlf): Optimize this. A manual scan over `self.runs` may
-            // permit us to more efficiently discard runs which will not be
-            // produced by this iterator.
-            self.iter().filter(move |(o, _)| Byte::new(byte.clone()).transmutable_into(&o))
+            self.runs.iter().copied()
         }
 
         pub(crate) fn get_uninit_edge_dst(&self) -> Option<S>
         where
             S: Copy,
         {
-            self.uninit
+            // Uninit is ordered last.
+            let &(range, dst) = self.runs.last()?;
+            if range.contains_uninit() { Some(dst) } else { None }
         }
 
         pub(crate) fn map_states<SS>(self, mut f: impl FnMut(S) -> SS) -> EdgeSet<SS> {
@@ -478,95 +382,106 @@ mod edge_set {
                 // allocates the correct number of elements once up-front [1].
                 //
                 // [1] https://doc.rust-lang.org/1.85.0/src/alloc/vec/spec_from_iter_nested.rs.html#47
-                runs: self.runs.into_iter().map(|run| run.map_state(&mut f)).collect(),
-                uninit: self.uninit.map(f),
+                runs: self.runs.into_iter().map(|(b, s)| (b, f(s))).collect(),
             }
         }
 
         /// Unions two edge sets together.
         ///
         /// If `u = a.union(b)`, then for each byte value, `u` will have an edge
-        /// with that byte value and with the destination `(Some(_), None)`,
-        /// `(None, Some(_))`, or `(Some(_), Some(_))` depending on whether `a`,
+        /// with that byte value and with the destination `join(Some(_), None)`,
+        /// `join(None, Some(_))`, or `join(Some(_), Some(_))` depending on whether `a`,
         /// `b`, or both have an edge with that byte value.
         ///
         /// If neither `a` nor `b` have an edge with a particular byte value,
         /// then no edge with that value will be present in `u`.
-        pub(crate) fn union(&self, other: &Self) -> EdgeSet<(Option<S>, Option<S>)>
+        pub(crate) fn union(
+            &self,
+            other: &Self,
+            mut join: impl FnMut(Option<S>, Option<S>) -> S,
+        ) -> EdgeSet<S>
         where
             S: Copy,
         {
-            let uninit = match (self.uninit, other.uninit) {
-                (None, None) => None,
-                (s, o) => Some((s, o)),
-            };
-
-            let mut runs = SmallVec::new();
-
-            // Iterate over `self.runs` and `other.runs` simultaneously,
-            // advancing `idx` as we go. At each step, we advance `idx` as far
-            // as we can without crossing a run boundary in either `self.runs`
-            // or `other.runs`.
-
-            // INVARIANT: `idx < s[0].end && idx < o[0].end`.
-            let (mut s, mut o) = (self.runs.as_slice(), other.runs.as_slice());
-            let mut idx = 0u16;
-            while let (Some((s_run, s_rest)), Some((o_run, o_rest))) =
-                (s.split_first(), o.split_first())
-            {
-                let (s_start, s_end) = s_run.as_inclusive_exclusive();
-                let (o_start, o_end) = o_run.as_inclusive_exclusive();
-
-                // Compute `end` as the end of the current run (which starts
-                // with `idx`).
-                let (end, dst) = match (s_run.contains(idx), o_run.contains(idx)) {
-                    // `idx` is in an existing run in both `s` and `o`, so `end`
-                    // is equal to the smallest of the two ends of those runs.
-                    (true, true) => (cmp::min(s_end, o_end), (Some(s_run.dst), Some(o_run.dst))),
-                    // `idx` is in an existing run in `s`, but not in any run in
-                    // `o`. `end` is either the end of the `s` run or the
-                    // beginning of the next `o` run, whichever comes first.
-                    (true, false) => (cmp::min(s_end, o_start), (Some(s_run.dst), None)),
-                    // The inverse of the previous case.
-                    (false, true) => (cmp::min(s_start, o_end), (None, Some(o_run.dst))),
-                    // `idx` is not in a run in either `s` or `o`, so advance it
-                    // to the beginning of the next run.
-                    (false, false) => {
-                        idx = cmp::min(s_start, o_start);
-                        continue;
-                    }
-                };
-
-                // FIXME(@joshlf): If this is contiguous with the previous run
-                // and has the same `dst`, just merge it into that run rather
-                // than adding a new one.
-                runs.push(Run::from_inclusive_exclusive(idx..end, dst));
-                idx = end;
-
-                if idx >= s_end {
-                    s = s_rest;
-                }
-                if idx >= o_end {
-                    o = o_rest;
-                }
-            }
-
-            // At this point, either `s` or `o` have been exhausted, so the
-            // remaining elements in the other slice are guaranteed to be
-            // non-overlapping. We can add all remaining runs to `runs` with no
-            // further processing.
-            if let Ok(idx) = u8::try_from(idx) {
-                let (slc, map) = if !s.is_empty() {
-                    let map: fn(_) -> _ = |st| (Some(st), None);
-                    (s, map)
-                } else {
-                    let map: fn(_) -> _ = |st| (None, Some(st));
-                    (o, map)
-                };
-                runs.extend(slc.iter().map(|run| run.clamp_lower(idx).map_state(map)));
-            }
-
-            EdgeSet { runs, uninit }
+            let xs = self.runs.iter().copied();
+            let ys = other.runs.iter().copied();
+            // FIXME(@joshlf): Merge contiguous runs with common destination.
+            EdgeSet { runs: union(xs, ys).map(|(range, (x, y))| (range, join(x, y))).collect() }
         }
     }
 }
+
+/// Merges two sorted sequences into one sorted sequence.
+pub(crate) fn union<S: Copy, X: Iterator<Item = (Byte, S)>, Y: Iterator<Item = (Byte, S)>>(
+    xs: X,
+    ys: Y,
+) -> UnionIter<X, Y> {
+    UnionIter { xs: xs.peekable(), ys: ys.peekable() }
+}
+
+pub(crate) struct UnionIter<X: Iterator, Y: Iterator> {
+    xs: Peekable<X>,
+    ys: Peekable<Y>,
+}
+
+// FIXME(jswrenn) we'd likely benefit from specializing try_fold here.
+impl<S: Copy, X: Iterator<Item = (Byte, S)>, Y: Iterator<Item = (Byte, S)>> Iterator
+    for UnionIter<X, Y>
+{
+    type Item = (Byte, (Option<S>, Option<S>));
+
+    fn next(&mut self) -> Option<Self::Item> {
+        use std::cmp::{self, Ordering};
+
+        let ret;
+        match (self.xs.peek_mut(), self.ys.peek_mut()) {
+            (None, None) => {
+                ret = None;
+            }
+            (Some(x), None) => {
+                ret = Some((x.0, (Some(x.1), None)));
+                self.xs.next();
+            }
+            (None, Some(y)) => {
+                ret = Some((y.0, (None, Some(y.1))));
+                self.ys.next();
+            }
+            (Some(x), Some(y)) => {
+                let start;
+                let end;
+                let dst;
+                match x.0.start.cmp(&y.0.start) {
+                    Ordering::Less => {
+                        start = x.0.start;
+                        end = cmp::min(x.0.end, y.0.start);
+                        dst = (Some(x.1), None);
+                    }
+                    Ordering::Greater => {
+                        start = y.0.start;
+                        end = cmp::min(x.0.start, y.0.end);
+                        dst = (None, Some(y.1));
+                    }
+                    Ordering::Equal => {
+                        start = x.0.start;
+                        end = cmp::min(x.0.end, y.0.end);
+                        dst = (Some(x.1), Some(y.1));
+                    }
+                }
+                ret = Some((Byte { start, end }, dst));
+                if start == x.0.start {
+                    x.0.start = end;
+                }
+                if start == y.0.start {
+                    y.0.start = end;
+                }
+                if x.0.is_empty() {
+                    self.xs.next();
+                }
+                if y.0.is_empty() {
+                    self.ys.next();
+                }
+            }
+        }
+        ret
+    }
+}
diff --git a/compiler/rustc_transmute/src/layout/mod.rs b/compiler/rustc_transmute/src/layout/mod.rs
index 4d5f630ae229e..c08bf440734e2 100644
--- a/compiler/rustc_transmute/src/layout/mod.rs
+++ b/compiler/rustc_transmute/src/layout/mod.rs
@@ -6,61 +6,61 @@ pub(crate) mod tree;
 pub(crate) use tree::Tree;
 
 pub(crate) mod dfa;
-pub(crate) use dfa::Dfa;
+pub(crate) use dfa::{Dfa, union};
 
 #[derive(Debug)]
 pub(crate) struct Uninhabited;
 
-/// A range of byte values, or the uninit byte.
+/// A range of byte values (including an uninit byte value).
 #[derive(Hash, Eq, PartialEq, Ord, PartialOrd, Clone, Copy)]
 pub(crate) struct Byte {
-    // An inclusive-inclusive range. We use this instead of `RangeInclusive`
-    // because `RangeInclusive: !Copy`.
+    // An inclusive-exclusive range. We use this instead of `Range` because `Range: !Copy`.
     //
-    // `None` means uninit.
-    //
-    // FIXME(@joshlf): Optimize this representation. Some pairs of values (where
-    // `lo > hi`) are illegal, and we could use these to represent `None`.
-    range: Option<(u8, u8)>,
+    // The uninit byte value is represented by 256.
+    pub(crate) start: u16,
+    pub(crate) end: u16,
 }
 
 impl Byte {
+    const UNINIT: u16 = 256;
+
+    #[inline]
     fn new(range: RangeInclusive<u8>) -> Self {
-        Self { range: Some((*range.start(), *range.end())) }
+        let start: u16 = (*range.start()).into();
+        let end: u16 = (*range.end()).into();
+        Byte { start, end: end + 1 }
     }
 
+    #[inline]
     fn from_val(val: u8) -> Self {
-        Self { range: Some((val, val)) }
+        let val: u16 = val.into();
+        Byte { start: val, end: val + 1 }
     }
 
-    pub(crate) fn uninit() -> Byte {
-        Byte { range: None }
+    #[inline]
+    fn uninit() -> Byte {
+        Byte { start: 0, end: Self::UNINIT + 1 }
     }
 
-    /// Returns `None` if `self` is the uninit byte.
-    pub(crate) fn range(&self) -> Option<RangeInclusive<u8>> {
-        self.range.map(|(lo, hi)| lo..=hi)
+    #[inline]
+    fn is_empty(&self) -> bool {
+        self.start == self.end
     }
 
-    /// Are any of the values in `self` transmutable into `other`?
-    ///
-    /// Note two special cases: An uninit byte is only transmutable into another
-    /// uninit byte. Any byte is transmutable into an uninit byte.
-    pub(crate) fn transmutable_into(&self, other: &Byte) -> bool {
-        match (self.range, other.range) {
-            (None, None) => true,
-            (None, Some(_)) => false,
-            (Some(_), None) => true,
-            (Some((slo, shi)), Some((olo, ohi))) => slo <= ohi && olo <= shi,
-        }
+    #[inline]
+    fn contains_uninit(&self) -> bool {
+        self.start <= Self::UNINIT && Self::UNINIT < self.end
     }
 }
 
 impl fmt::Debug for Byte {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match self.range {
-            None => write!(f, "uninit"),
-            Some((lo, hi)) => write!(f, "{lo}..={hi}"),
+        if self.start == Self::UNINIT && self.end == Self::UNINIT + 1 {
+            write!(f, "uninit")
+        } else if self.start <= Self::UNINIT && self.end == Self::UNINIT + 1 {
+            write!(f, "{}..{}|uninit", self.start, self.end - 1)
+        } else {
+            write!(f, "{}..{}", self.start, self.end)
         }
     }
 }
@@ -72,6 +72,7 @@ impl From<RangeInclusive<u8>> for Byte {
 }
 
 impl From<u8> for Byte {
+    #[inline]
     fn from(src: u8) -> Self {
         Self::from_val(src)
     }
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
index 0a19cccc2ed03..f76abe50ed343 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/mod.rs
@@ -1,14 +1,11 @@
-use std::rc::Rc;
-use std::{cmp, iter};
-
-use itertools::Either;
+use rustc_data_structures::stack::ensure_sufficient_stack;
 use tracing::{debug, instrument, trace};
 
 pub(crate) mod query_context;
 #[cfg(test)]
 mod tests;
 
-use crate::layout::{self, Byte, Def, Dfa, Ref, Tree, dfa};
+use crate::layout::{self, Def, Dfa, Ref, Tree, dfa, union};
 use crate::maybe_transmutable::query_context::QueryContext;
 use crate::{Answer, Condition, Map, Reason};
 
@@ -153,230 +150,135 @@ where
         if let Some(answer) = cache.get(&(src_state, dst_state)) {
             answer.clone()
         } else {
-            debug!(?src_state, ?dst_state);
-            debug!(src = ?self.src);
-            debug!(dst = ?self.dst);
-            debug!(
-                src_transitions_len = self.src.transitions.len(),
-                dst_transitions_len = self.dst.transitions.len()
-            );
-            let answer = if dst_state == self.dst.accept {
-                // truncation: `size_of(Src) >= size_of(Dst)`
-                //
-                // Why is truncation OK to do? Because even though the Src is bigger, all we care about
-                // is whether we have enough data for the Dst to be valid in accordance with what its
-                // type dictates.
-                // For example, in a u8 to `()` transmutation, we have enough data available from the u8
-                // to transmute it to a `()` (though in this case does `()` really need any data to
-                // begin with? It doesn't). Same thing with u8 to fieldless struct.
-                // Now then, why is something like u8 to bool not allowed? That is not because the bool
-                // is smaller in size, but rather because those 2 bits that we are re-interpreting from
-                // the u8 could introduce invalid states for the bool type.
-                //
-                // So, if it's possible to transmute to a smaller Dst by truncating, and we can guarantee
-                // that none of the actually-used data can introduce an invalid state for Dst's type, we
-                // are able to safely transmute, even with truncation.
-                Answer::Yes
-            } else if src_state == self.src.accept {
-                // extension: `size_of(Src) <= size_of(Dst)`
-                if let Some(dst_state_prime) = self.dst.get_uninit_edge_dst(dst_state) {
-                    self.answer_memo(cache, src_state, dst_state_prime)
-                } else {
-                    Answer::No(Reason::DstIsTooBig)
-                }
-            } else {
-                let src_quantifier = if self.assume.validity {
-                    // if the compiler may assume that the programmer is doing additional validity checks,
-                    // (e.g.: that `src != 3u8` when the destination type is `bool`)
-                    // then there must exist at least one transition out of `src_state` such that the transmute is viable...
-                    Quantifier::ThereExists
-                } else {
-                    // if the compiler cannot assume that the programmer is doing additional validity checks,
-                    // then for all transitions out of `src_state`, such that the transmute is viable...
-                    // then there must exist at least one transition out of `dst_state` such that the transmute is viable...
-                    Quantifier::ForAll
-                };
-
-                let c = &core::cell::RefCell::new(&mut *cache);
-                let bytes_answer = src_quantifier.apply(
-                    // for each of the byte set transitions out of the `src_state`...
-                    self.src.bytes_from(src_state).flat_map(
-                        move |(src_validity, src_state_prime)| {
-                            // ...find all matching transitions out of `dst_state`.
-
-                            let Some(src_validity) = src_validity.range() else {
-                                // NOTE: We construct an iterator here rather
-                                // than just computing the value directly (via
-                                // `self.answer_memo`) so that, if the iterator
-                                // we produce from this branch is
-                                // short-circuited, we don't waste time
-                                // computing `self.answer_memo` unnecessarily.
-                                // That will specifically happen if
-                                // `src_quantifier == Quantifier::ThereExists`,
-                                // since we emit `Answer::Yes` first (before
-                                // chaining `answer_iter`).
-                                let answer_iter = if let Some(dst_state_prime) =
-                                    self.dst.get_uninit_edge_dst(dst_state)
-                                {
-                                    Either::Left(iter::once_with(move || {
-                                        let mut c = c.borrow_mut();
-                                        self.answer_memo(&mut *c, src_state_prime, dst_state_prime)
-                                    }))
-                                } else {
-                                    Either::Right(iter::once(Answer::No(
-                                        Reason::DstIsBitIncompatible,
-                                    )))
-                                };
-
-                                // When `answer == Answer::No(...)`, there are
-                                // two cases to consider:
-                                // - If `assume.validity`, then we should
-                                //   succeed because the user is responsible for
-                                //   ensuring that the *specific* byte value
-                                //   appearing at runtime is valid for the
-                                //   destination type. When `assume.validity`,
-                                //   `src_quantifier ==
-                                //   Quantifier::ThereExists`, so adding an
-                                //   `Answer::Yes` has the effect of ensuring
-                                //   that the "there exists" is always
-                                //   satisfied.
-                                // - If `!assume.validity`, then we should fail.
-                                //   In this case, `src_quantifier ==
-                                //   Quantifier::ForAll`, so adding an
-                                //   `Answer::Yes` has no effect.
-                                return Either::Left(iter::once(Answer::Yes).chain(answer_iter));
-                            };
-
-                            #[derive(Copy, Clone, Debug)]
-                            struct Accum {
-                                // The number of matching byte edges that we
-                                // have found in the destination so far.
-                                sum: usize,
-                                found_uninit: bool,
-                            }
-
-                            let accum1 = Rc::new(std::cell::Cell::new(Accum {
-                                sum: 0,
-                                found_uninit: false,
-                            }));
-                            let accum2 = Rc::clone(&accum1);
-                            let sv = src_validity.clone();
-                            let update_accum = move |mut accum: Accum, dst_validity: Byte| {
-                                if let Some(dst_validity) = dst_validity.range() {
-                                    // Only add the part of `dst_validity` that
-                                    // overlaps with `src_validity`.
-                                    let start = cmp::max(*sv.start(), *dst_validity.start());
-                                    let end = cmp::min(*sv.end(), *dst_validity.end());
-
-                                    // We add 1 here to account for the fact
-                                    // that `end` is an inclusive bound.
-                                    accum.sum += 1 + usize::from(end.saturating_sub(start));
-                                } else {
-                                    accum.found_uninit = true;
-                                }
-                                accum
-                            };
-
-                            let answers = self
-                                .dst
-                                .states_from(dst_state, src_validity.clone())
-                                .map(move |(dst_validity, dst_state_prime)| {
-                                    let mut c = c.borrow_mut();
-                                    accum1.set(update_accum(accum1.get(), dst_validity));
-                                    let answer =
-                                        self.answer_memo(&mut *c, src_state_prime, dst_state_prime);
-                                    answer
-                                })
-                                .chain(
-                                    iter::once_with(move || {
-                                        let src_validity_len = usize::from(*src_validity.end())
-                                            - usize::from(*src_validity.start())
-                                            + 1;
-                                        let accum = accum2.get();
-
-                                        // If this condition is false, then
-                                        // there are some byte values in the
-                                        // source which have no corresponding
-                                        // transition in the destination DFA. In
-                                        // that case, we add a `No` to our list
-                                        // of answers. When
-                                        // `!self.assume.validity`, this will
-                                        // cause the query to fail.
-                                        if accum.found_uninit || accum.sum == src_validity_len {
-                                            None
-                                        } else {
-                                            Some(Answer::No(Reason::DstIsBitIncompatible))
-                                        }
-                                    })
-                                    .flatten(),
-                                );
-                            Either::Right(answers)
-                        },
-                    ),
-                );
-
-                // The below early returns reflect how this code would behave:
-                //   if self.assume.validity {
-                //       or(bytes_answer, refs_answer)
-                //   } else {
-                //       and(bytes_answer, refs_answer)
-                //   }
-                // ...if `refs_answer` was computed lazily. The below early
-                // returns can be deleted without impacting the correctness of
-                // the algorithm; only its performance.
-                debug!(?bytes_answer);
-                match bytes_answer {
-                    Answer::No(_) if !self.assume.validity => return bytes_answer,
-                    Answer::Yes if self.assume.validity => return bytes_answer,
-                    _ => {}
-                };
-
-                let refs_answer = src_quantifier.apply(
-                    // for each reference transition out of `src_state`...
-                    self.src.refs_from(src_state).map(|(src_ref, src_state_prime)| {
-                        // ...there exists a reference transition out of `dst_state`...
-                        Quantifier::ThereExists.apply(self.dst.refs_from(dst_state).map(
-                            |(dst_ref, dst_state_prime)| {
-                                if !src_ref.is_mutable() && dst_ref.is_mutable() {
-                                    Answer::No(Reason::DstIsMoreUnique)
-                                } else if !self.assume.alignment
-                                    && src_ref.min_align() < dst_ref.min_align()
-                                {
-                                    Answer::No(Reason::DstHasStricterAlignment {
-                                        src_min_align: src_ref.min_align(),
-                                        dst_min_align: dst_ref.min_align(),
-                                    })
-                                } else if dst_ref.size() > src_ref.size() {
-                                    Answer::No(Reason::DstRefIsTooBig {
-                                        src: src_ref,
-                                        dst: dst_ref,
-                                    })
-                                } else {
-                                    // ...such that `src` is transmutable into `dst`, if
-                                    // `src_ref` is transmutability into `dst_ref`.
-                                    and(
-                                        Answer::If(Condition::IfTransmutable {
-                                            src: src_ref,
-                                            dst: dst_ref,
-                                        }),
-                                        self.answer_memo(cache, src_state_prime, dst_state_prime),
-                                    )
-                                }
-                            },
-                        ))
-                    }),
-                );
-
-                if self.assume.validity {
-                    or(bytes_answer, refs_answer)
-                } else {
-                    and(bytes_answer, refs_answer)
-                }
-            };
-            if let Some(..) = cache.insert((src_state, dst_state), answer.clone()) {
-                panic!("failed to correctly cache transmutability")
-            }
-            answer
+            let answer = ensure_sufficient_stack(|| self.answer_impl(cache, src_state, dst_state));
+            if let Some(..) = cache.insert((src_state, dst_state), answer.clone()) {
+                panic!("failed to correctly cache transmutability")
+            }
+            answer
         }
     }
+
+    fn answer_impl(
+        &self,
+        cache: &mut Map<(dfa::State, dfa::State), Answer<<C as QueryContext>::Ref>>,
+        src_state: dfa::State,
+        dst_state: dfa::State,
+    ) -> Answer<<C as QueryContext>::Ref> {
+        debug!(?src_state, ?dst_state);
+        debug!(src = ?self.src);
+        debug!(dst = ?self.dst);
+        debug!(
+            src_transitions_len = self.src.transitions.len(),
+            dst_transitions_len = self.dst.transitions.len()
+        );
+        if dst_state == self.dst.accept {
+            // truncation: `size_of(Src) >= size_of(Dst)`
+            //
+            // Why is truncation OK to do? Because even though the Src is bigger, all we care about
+            // is whether we have enough data for the Dst to be valid in accordance with what its
+            // type dictates.
+            // For example, in a u8 to `()` transmutation, we have enough data available from the u8
+            // to transmute it to a `()` (though in this case does `()` really need any data to
+            // begin with? It doesn't). Same thing with u8 to fieldless struct.
+            // Now then, why is something like u8 to bool not allowed? That is not because the bool
+            // is smaller in size, but rather because some of the bit patterns we would be
+            // re-interpreting from the u8 are invalid states for the bool type.
+            //
+            // So, if it's possible to transmute to a smaller Dst by truncating, and we can guarantee
+            // that none of the actually-used data can introduce an invalid state for Dst's type, we
+            // are able to safely transmute, even with truncation.
+            Answer::Yes
+        } else if src_state == self.src.accept {
+            // extension: `size_of(Src) <= size_of(Dst)`
+            if let Some(dst_state_prime) = self.dst.get_uninit_edge_dst(dst_state) {
+                self.answer_memo(cache, src_state, dst_state_prime)
+            } else {
+                Answer::No(Reason::DstIsTooBig)
+            }
+        } else {
+            let src_quantifier = if self.assume.validity {
+                // if the compiler may assume that the programmer is doing additional validity checks,
+                // (e.g.: that `src != 3u8` when the destination type is `bool`)
+                // then there must exist at least one transition out of `src_state` such that the transmute is viable...
+                Quantifier::ThereExists
+            } else {
+                // if the compiler cannot assume that the programmer is doing additional validity checks,
+                // then for all transitions out of `src_state`, such that the transmute is viable...
+                // then there must exist at least one transition out of `dst_state` such that the transmute is viable...
+                Quantifier::ForAll
+            };
+
+            let bytes_answer = src_quantifier.apply(
+                union(self.src.bytes_from(src_state), self.dst.bytes_from(dst_state)).filter_map(
+                    |(_range, (src_state_prime, dst_state_prime))| {
+                        match (src_state_prime, dst_state_prime) {
+                            // No matching transitions in `src`. Skip.
+                            (None, _) => None,
+                            // No matching transitions in `dst`. Fail.
+                            (Some(_), None) => Some(Answer::No(Reason::DstIsBitIncompatible)),
+                            // Matching transitions. Continue with successor states.
+                            (Some(src_state_prime), Some(dst_state_prime)) => {
+                                Some(self.answer_memo(cache, src_state_prime, dst_state_prime))
+                            }
+                        }
+                    },
+                ),
+            );
+
+            // The below early returns reflect how this code would behave:
+            //   if self.assume.validity {
+            //       or(bytes_answer, refs_answer)
+            //   } else {
+            //       and(bytes_answer, refs_answer)
+            //   }
+            // ...if `refs_answer` was computed lazily. The below early
+            // returns can be deleted without impacting the correctness of
+            // the algorithm; only its performance.
+            debug!(?bytes_answer);
+            match bytes_answer {
+                Answer::No(_) if !self.assume.validity => return bytes_answer,
+                Answer::Yes if self.assume.validity => return bytes_answer,
+                _ => {}
+            };
+
+            let refs_answer = src_quantifier.apply(
+                // for each reference transition out of `src_state`...
+                self.src.refs_from(src_state).map(|(src_ref, src_state_prime)| {
+                    // ...there exists a reference transition out of `dst_state`...
+                    Quantifier::ThereExists.apply(self.dst.refs_from(dst_state).map(
+                        |(dst_ref, dst_state_prime)| {
+                            if !src_ref.is_mutable() && dst_ref.is_mutable() {
+                                Answer::No(Reason::DstIsMoreUnique)
+                            } else if !self.assume.alignment
+                                && src_ref.min_align() < dst_ref.min_align()
+                            {
+                                Answer::No(Reason::DstHasStricterAlignment {
+                                    src_min_align: src_ref.min_align(),
+                                    dst_min_align: dst_ref.min_align(),
+                                })
+                            } else if dst_ref.size() > src_ref.size() {
+                                Answer::No(Reason::DstRefIsTooBig { src: src_ref, dst: dst_ref })
+                            } else {
+                                // ...such that `src` is transmutable into `dst`, if
+                                // `src_ref` is transmutable into `dst_ref`.
+                                and(
+                                    Answer::If(Condition::IfTransmutable {
+                                        src: src_ref,
+                                        dst: dst_ref,
+                                    }),
+                                    self.answer_memo(cache, src_state_prime, dst_state_prime),
+                                )
+                            }
+                        },
+                    ))
+                }),
+            );
+
+            if self.assume.validity {
+                or(bytes_answer, refs_answer)
+            } else {
+                and(bytes_answer, refs_answer)
+            }
+        }
+    }
 }
diff --git a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
index 992fcb7cc4c81..fbb4639dbd630 100644
--- a/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
+++ b/compiler/rustc_transmute/src/maybe_transmutable/tests.rs
@@ -400,16 +400,23 @@ mod r#ref {
     fn should_permit_identity_transmutation() {
         type Tree = crate::layout::Tree<Def, [(); 1]>;
 
-        let layout = Tree::Seq(vec![Tree::byte(0x00), Tree::Ref([()])]);
+        for validity in [false, true] {
+            let layout = Tree::Seq(vec![Tree::byte(0x00), Tree::Ref([()])]);
 
-        let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
-            layout.clone(),
-            layout,
-            Assume::default(),
-            UltraMinimal::default(),
-        )
-        .answer();
-        assert_eq!(answer, Answer::If(crate::Condition::IfTransmutable { src: [()], dst: [()] }));
+            let assume = Assume { validity, ..Assume::default() };
+
+            let answer = crate::maybe_transmutable::MaybeTransmutableQuery::new(
+                layout.clone(),
+                layout,
+                assume,
+                UltraMinimal::default(),
+            )
+            .answer();
+            assert_eq!(
+                answer,
+                Answer::If(crate::Condition::IfTransmutable { src: [()], dst: [()] })
+            );
+        }
     }
 }
diff --git a/tests/ui/transmutability/unions/extension.rs b/tests/ui/transmutability/unions/extension.rs
new file mode 100644
index 0000000000000..eb4dcd4dff3df
--- /dev/null
+++ b/tests/ui/transmutability/unions/extension.rs
@@ -0,0 +1,12 @@
+#![crate_type = "lib"]
+#![feature(transmutability)]
+use std::mem::{Assume, MaybeUninit, TransmuteFrom};
+
+pub fn is_maybe_transmutable<Src, Dst>()
+    where Dst: TransmuteFrom<Src, { Assume::SAFETY.and(Assume::VALIDITY) }>
+{}
+
+fn extension() {
+    is_maybe_transmutable::<(), MaybeUninit<u8>>();
+    is_maybe_transmutable::<MaybeUninit<u8>, [u8; 2]>(); //~ ERROR `MaybeUninit<u8>` cannot be safely transmuted into `[u8; 2]`
+}
diff --git a/tests/ui/transmutability/unions/extension.stderr b/tests/ui/transmutability/unions/extension.stderr
new file mode 100644
index 0000000000000..c99e46f3d12b7
--- /dev/null
+++ b/tests/ui/transmutability/unions/extension.stderr
@@ -0,0 +1,17 @@
+error[E0277]: `MaybeUninit<u8>` cannot be safely transmuted into `[u8; 2]`
+  --> $DIR/extension.rs:11:46
+   |
+LL |     is_maybe_transmutable::<MaybeUninit<u8>, [u8; 2]>();
+   |                                              ^^^^^^^^ the size of `MaybeUninit<u8>` is smaller than the size of `[u8; 2]`
+   |
+note: required by a bound in `is_maybe_transmutable`
+  --> $DIR/extension.rs:6:16
+   |
+LL | pub fn is_maybe_transmutable<Src, Dst>()
+   |        --------------------- required by a bound in this function
+LL |     where Dst: TransmuteFrom<Src, { Assume::SAFETY.and(Assume::VALIDITY) }>
+   |                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `is_maybe_transmutable`
+
+error: aborting due to 1 previous error
+
+For more information about this error, try `rustc --explain E0277`.
diff --git a/tests/ui/transmutability/unions/init_as_uninit.rs b/tests/ui/transmutability/unions/init_as_uninit.rs
new file mode 100644
index 0000000000000..d14eca800ef5d
--- /dev/null
+++ b/tests/ui/transmutability/unions/init_as_uninit.rs
@@ -0,0 +1,26 @@
+//@ check-pass
+// Regression test for issue #140337.
+#![crate_type = "lib"]
+#![feature(transmutability)]
+#![allow(dead_code)]
+use std::mem::{Assume, MaybeUninit, TransmuteFrom};
+
+pub fn is_transmutable<Src, Dst>()
+where
+    Dst: TransmuteFrom<Src, { Assume::SAFETY }>
+{}
+
+#[derive(Copy, Clone)]
+#[repr(u8)]
+pub enum B0 { Value = 0 }
+
+#[derive(Copy, Clone)]
+#[repr(u8)]
+pub enum B1 { Value = 1 }
+
+fn main() {
+    is_transmutable::<(B0, B0), MaybeUninit<(B0, B0)>>();
+    is_transmutable::<(B0, B0), MaybeUninit<(B0, B1)>>();
+    is_transmutable::<(B0, B0), MaybeUninit<(B1, B0)>>();
+    is_transmutable::<(B0, B0), MaybeUninit<(B1, B1)>>();
+}
diff --git a/tests/ui/transmutability/unions/should_permit_intersecting_if_validity_is_assumed.rs b/tests/ui/transmutability/unions/should_permit_intersecting_if_validity_is_assumed.rs
index 359ba51543981..24c6fa2e6ac0f 100644
--- a/tests/ui/transmutability/unions/should_permit_intersecting_if_validity_is_assumed.rs
+++ b/tests/ui/transmutability/unions/should_permit_intersecting_if_validity_is_assumed.rs
@@ -34,4 +34,19 @@ fn test() {
 
     assert::is_maybe_transmutable::<A, B>();
     assert::is_maybe_transmutable::<B, A>();
+
+    #[repr(C)]
+    struct C {
+        a: Ox00,
+        b: Ox00,
+    }
+
+    #[repr(C, align(2))]
+    struct D {
+        a: Ox00,
+    }
+
+    assert::is_maybe_transmutable::<C, D>();
+    // With Assume::VALIDITY a padding byte can hold any value.
+    assert::is_maybe_transmutable::<D, C>();
 }
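
Reviewer note: the new `Byte` in `layout/mod.rs` folds the uninit byte into ordinary range arithmetic by widening to `u16` and spending the sentinel value 256 on uninit. A minimal self-contained sketch of that encoding, with `ByteRange` standing in for the crate-private `Byte`:

```rust
// Sketch of the Byte encoding from layout/mod.rs: a half-open u16 range
// in which the value 256 encodes the uninit byte.
#[derive(Copy, Clone, Debug, PartialEq)]
struct ByteRange {
    start: u16, // inclusive
    end: u16,   // exclusive
}

const UNINIT: u16 = 256;

impl ByteRange {
    // All initialized values in `lo..=hi` (mirrors `Byte::new`).
    fn new(lo: u8, hi: u8) -> Self {
        ByteRange { start: lo.into(), end: u16::from(hi) + 1 }
    }

    // Every value, initialized or not (mirrors `Byte::uninit`).
    fn uninit() -> Self {
        ByteRange { start: 0, end: UNINIT + 1 }
    }

    fn contains_uninit(&self) -> bool {
        self.start <= UNINIT && UNINIT < self.end
    }
}

fn main() {
    // 0..256 covers every initialized byte value but not uninit...
    assert!(!ByteRange::new(0, 255).contains_uninit());
    // ...while 0..257 also admits uninit; because 256 sorts last, runs that
    // accept uninit sort last, which is what `get_uninit_edge_dst` relies on.
    assert!(ByteRange::uninit().contains_uninit());
}
```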
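The `UnionIter` added to `layout/dfa.rs` is the heart of the change: it merges two sorted, non-overlapping run lists and splits runs at every boundary where the inputs disagree, so each output range records whether it came from the left input, the right input, or both. A standalone sketch of the same merge, under the assumption that plain `(u16, u16)` half-open tuples stand in for `Byte` (names here are illustrative, not the crate's API):

```rust
use std::cmp::{self, Ordering};

type Range = (u16, u16); // half-open: start..end

/// Merge two sorted, non-overlapping run lists, splitting runs wherever the
/// inputs disagree on a boundary. Mirrors `UnionIter::next` in the diff.
fn union_runs<S: Copy>(
    xs: Vec<(Range, S)>,
    ys: Vec<(Range, S)>,
) -> Vec<(Range, (Option<S>, Option<S>))> {
    let mut xs = xs.into_iter().peekable();
    let mut ys = ys.into_iter().peekable();
    let mut out = Vec::new();
    loop {
        match (xs.peek_mut(), ys.peek_mut()) {
            (None, None) => break,
            (Some(x), None) => {
                out.push((x.0, (Some(x.1), None)));
                xs.next();
            }
            (None, Some(y)) => {
                out.push((y.0, (None, Some(y.1))));
                ys.next();
            }
            (Some(x), Some(y)) => {
                // Emit the next maximal piece on which membership is constant.
                let (start, end, dst) = match x.0.0.cmp(&y.0.0) {
                    Ordering::Less => (x.0.0, cmp::min(x.0.1, y.0.0), (Some(x.1), None)),
                    Ordering::Greater => (y.0.0, cmp::min(x.0.0, y.0.1), (None, Some(y.1))),
                    Ordering::Equal => (x.0.0, cmp::min(x.0.1, y.0.1), (Some(x.1), Some(y.1))),
                };
                out.push(((start, end), dst));
                // Trim the emitted prefix off whichever runs began at `start`,
                // and discard runs that are now empty.
                if start == x.0.0 {
                    x.0.0 = end;
                }
                if start == y.0.0 {
                    y.0.0 = end;
                }
                if x.0.0 == x.0.1 {
                    xs.next();
                }
                if y.0.0 == y.0.1 {
                    ys.next();
                }
            }
        }
    }
    out
}

fn main() {
    // Overlapping inputs 0..10 -> 'a' and 5..15 -> 'b' split into three runs.
    let merged = union_runs(vec![((0, 10), 'a')], vec![((5, 15), 'b')]);
    assert_eq!(
        merged,
        vec![
            ((0, 5), (Some('a'), None)),
            ((5, 10), (Some('a'), Some('b'))),
            ((10, 15), (None, Some('b'))),
        ]
    );
}
```

This is why `answer_impl` can simply `union` the source and destination edge sets and inspect `(None, _)`, `(Some(_), None)`, and `(Some(_), Some(_))` per range, instead of the old per-byte `Accum` counting.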
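Finally, the truncation/extension comments in `answer_impl` can be made concrete with the unstable `TransmuteFrom` API that the UI tests exercise. A hypothetical nightly-only usage sketch (not part of the diff; exact bounds and behavior are assumptions about current nightly):

```rust
// Nightly-only sketch of the truncation and extension rules.
#![feature(transmutability)]
use std::mem::{MaybeUninit, TransmuteFrom};

fn is_transmutable<Src, Dst>()
where
    Dst: TransmuteFrom<Src>, // default assumptions
{
}

fn main() {
    // Truncation: a u8 always carries enough data to build a valid ().
    is_transmutable::<u8, ()>();
    // Extension: () extends into MaybeUninit<u8>, because the destination's
    // trailing byte may be uninit (the DFA's uninit edge accepts it).
    is_transmutable::<(), MaybeUninit<u8>>();
}
```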