Skip to content

Commit 5342c6a

Browse files
committed
Add tag_for_variant query
This query allows for sharing code between `rustc_const_eval` and `rustc_transmutability`.
1 parent 9023f90 commit 5342c6a

File tree

7 files changed

+145
-93
lines changed

7 files changed

+145
-93
lines changed

compiler/rustc_const_eval/src/const_eval/eval_queries.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use either::{Left, Right};
33
use rustc_hir::def::DefKind;
44
use rustc_middle::mir::interpret::{AllocId, ErrorHandled, InterpErrorInfo};
55
use rustc_middle::mir::{self, ConstAlloc, ConstValue};
6-
use rustc_middle::query::TyCtxtAt;
6+
use rustc_middle::query::{Key, TyCtxtAt};
77
use rustc_middle::traits::Reveal;
88
use rustc_middle::ty::layout::LayoutOf;
99
use rustc_middle::ty::print::with_no_trimmed_paths;
@@ -243,6 +243,24 @@ pub(crate) fn turn_into_const_value<'tcx>(
243243
op_to_const(&ecx, &mplace.into(), /* for diagnostics */ false)
244244
}
245245

246+
/// Computes the tag (if any) for a given type and variant.
247+
#[instrument(skip(tcx), level = "debug")]
248+
pub fn tag_for_variant_provider<'tcx>(
249+
tcx: TyCtxt<'tcx>,
250+
(ty, variant_index): (Ty<'tcx>, abi::VariantIdx),
251+
) -> Option<ty::ScalarInt> {
252+
assert!(ty.is_enum());
253+
254+
let ecx = InterpCx::new(
255+
tcx,
256+
ty.default_span(tcx),
257+
ty::ParamEnv::reveal_all(),
258+
CompileTimeInterpreter::new(CanAccessMutGlobal::No, CheckAlignment::Error),
259+
);
260+
261+
ecx.tag_for_variant(ty, variant_index).unwrap().map(|(tag, _tag_field)| tag)
262+
}
263+
246264
#[instrument(skip(tcx), level = "debug")]
247265
pub fn eval_to_const_value_raw_provider<'tcx>(
248266
tcx: TyCtxt<'tcx>,

compiler/rustc_const_eval/src/interpret/discriminant.rs

+89-67
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use rustc_middle::mir;
44
use rustc_middle::ty::layout::{LayoutOf, PrimitiveExt};
5-
use rustc_middle::ty::{self, Ty};
5+
use rustc_middle::ty::{self, ScalarInt, Ty};
66
use rustc_target::abi::{self, TagEncoding};
77
use rustc_target::abi::{VariantIdx, Variants};
88

@@ -28,78 +28,27 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
2828
throw_ub!(UninhabitedEnumVariantWritten(variant_index))
2929
}
3030

31-
match dest.layout().variants {
32-
abi::Variants::Single { index } => {
33-
assert_eq!(index, variant_index);
34-
}
35-
abi::Variants::Multiple {
36-
tag_encoding: TagEncoding::Direct,
37-
tag: tag_layout,
38-
tag_field,
39-
..
40-
} => {
31+
match self.tag_for_variant(dest.layout().ty, variant_index)? {
32+
Some((tag, tag_field)) => {
4133
// No need to validate that the discriminant here because the
42-
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.
43-
44-
let discr_val = dest
45-
.layout()
46-
.ty
47-
.discriminant_for_variant(*self.tcx, variant_index)
48-
.unwrap()
49-
.val;
50-
51-
// raw discriminants for enums are isize or bigger during
52-
// their computation, but the in-memory tag is the smallest possible
53-
// representation
54-
let size = tag_layout.size(self);
55-
let tag_val = size.truncate(discr_val);
56-
34+
// `TyAndLayout::for_variant()` call earlier already checks the
35+
// variant is valid.
5736
let tag_dest = self.project_field(dest, tag_field)?;
58-
self.write_scalar(Scalar::from_uint(tag_val, size), &tag_dest)?;
37+
self.write_scalar(tag, &tag_dest)
5938
}
60-
abi::Variants::Multiple {
61-
tag_encoding:
62-
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
63-
tag: tag_layout,
64-
tag_field,
65-
..
66-
} => {
67-
// No need to validate that the discriminant here because the
68-
// `TyAndLayout::for_variant()` call earlier already checks the variant is valid.
69-
70-
if variant_index != untagged_variant {
71-
let variants_start = niche_variants.start().as_u32();
72-
let variant_index_relative = variant_index
73-
.as_u32()
74-
.checked_sub(variants_start)
75-
.expect("overflow computing relative variant idx");
76-
// We need to use machine arithmetic when taking into account `niche_start`:
77-
// tag_val = variant_index_relative + niche_start_val
78-
let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
79-
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
80-
let variant_index_relative_val =
81-
ImmTy::from_uint(variant_index_relative, tag_layout);
82-
let tag_val = self.wrapping_binary_op(
83-
mir::BinOp::Add,
84-
&variant_index_relative_val,
85-
&niche_start_val,
86-
)?;
87-
// Write result.
88-
let niche_dest = self.project_field(dest, tag_field)?;
89-
self.write_immediate(*tag_val, &niche_dest)?;
90-
} else {
91-
// The untagged variant is implicitly encoded simply by having a value that is
92-
// outside the niche variants. But what if the data stored here does not
93-
// actually encode this variant? That would be bad! So let's double-check...
94-
let actual_variant = self.read_discriminant(&dest.to_op(self)?)?;
95-
if actual_variant != variant_index {
96-
throw_ub!(InvalidNichedEnumVariantWritten { enum_ty: dest.layout().ty });
97-
}
39+
None => {
40+
// No need to write the tag here, because an untagged variant is
41+
// implicitly encoded. For `Niche`-optimized enums, it's by
42+
// simply by having a value that is outside the niche variants.
43+
// But what if the data stored here does not actually encode
44+
// this variant? That would be bad! So let's double-check...
45+
let actual_variant = self.read_discriminant(&dest.to_op(self)?)?;
46+
if actual_variant != variant_index {
47+
throw_ub!(InvalidNichedEnumVariantWritten { enum_ty: dest.layout().ty });
9848
}
49+
Ok(())
9950
}
10051
}
101-
102-
Ok(())
10352
}
10453

10554
/// Read discriminant, return the runtime value as well as the variant index.
@@ -277,4 +226,77 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
277226
};
278227
Ok(ImmTy::from_scalar(discr_value, discr_layout))
279228
}
229+
230+
/// Computes the tag value and its field number (if any) of a given variant
231+
/// of type `ty`.
232+
pub(crate) fn tag_for_variant(
233+
&self,
234+
ty: Ty<'tcx>,
235+
variant_index: VariantIdx,
236+
) -> InterpResult<'tcx, Option<(ScalarInt, usize)>> {
237+
match self.layout_of(ty)?.variants {
238+
abi::Variants::Single { index } => {
239+
assert_eq!(index, variant_index);
240+
Ok(None)
241+
}
242+
243+
abi::Variants::Multiple {
244+
tag_encoding: TagEncoding::Direct,
245+
tag: tag_layout,
246+
tag_field,
247+
..
248+
} => {
249+
// raw discriminants for enums are isize or bigger during
250+
// their computation, but the in-memory tag is the smallest possible
251+
// representation
252+
let discr = self.discriminant_for_variant(ty, variant_index)?;
253+
let discr_size = discr.layout.size;
254+
let discr_val = discr.to_scalar().to_bits(discr_size)?;
255+
let tag_size = tag_layout.size(self);
256+
let tag_val = tag_size.truncate(discr_val);
257+
let tag = ScalarInt::try_from_uint(tag_val, tag_size).unwrap();
258+
Ok(Some((tag, tag_field)))
259+
}
260+
261+
abi::Variants::Multiple {
262+
tag_encoding: TagEncoding::Niche { untagged_variant, .. },
263+
..
264+
} if untagged_variant == variant_index => {
265+
// The untagged variant is implicitly encoded simply by having a
266+
// value that is outside the niche variants.
267+
Ok(None)
268+
}
269+
270+
abi::Variants::Multiple {
271+
tag_encoding:
272+
TagEncoding::Niche { untagged_variant, ref niche_variants, niche_start },
273+
tag: tag_layout,
274+
tag_field,
275+
..
276+
} => {
277+
assert!(variant_index != untagged_variant);
278+
let variants_start = niche_variants.start().as_u32();
279+
let variant_index_relative = variant_index
280+
.as_u32()
281+
.checked_sub(variants_start)
282+
.expect("overflow computing relative variant idx");
283+
// We need to use machine arithmetic when taking into account `niche_start`:
284+
// tag_val = variant_index_relative + niche_start_val
285+
let tag_layout = self.layout_of(tag_layout.primitive().to_int_ty(*self.tcx))?;
286+
let niche_start_val = ImmTy::from_uint(niche_start, tag_layout);
287+
let variant_index_relative_val =
288+
ImmTy::from_uint(variant_index_relative, tag_layout);
289+
let tag = self
290+
.wrapping_binary_op(
291+
mir::BinOp::Add,
292+
&variant_index_relative_val,
293+
&niche_start_val,
294+
)?
295+
.to_scalar()
296+
.try_to_int()
297+
.unwrap();
298+
Ok(Some((tag, tag_field)))
299+
}
300+
}
301+
}
280302
}

compiler/rustc_const_eval/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ rustc_fluent_macro::fluent_messages! { "../messages.ftl" }
4040

4141
pub fn provide(providers: &mut Providers) {
4242
const_eval::provide(providers);
43+
providers.tag_for_variant = const_eval::tag_for_variant_provider;
4344
providers.eval_to_const_value_raw = const_eval::eval_to_const_value_raw_provider;
4445
providers.eval_to_allocation_raw = const_eval::eval_to_allocation_raw_provider;
4546
providers.eval_static_initializer = const_eval::eval_static_initializer_provider;

compiler/rustc_middle/src/query/erase.rs

+1
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ trivial! {
234234
Option<rustc_middle::middle::stability::DeprecationEntry>,
235235
Option<rustc_middle::ty::Destructor>,
236236
Option<rustc_middle::ty::ImplTraitInTraitData>,
237+
Option<rustc_middle::ty::ScalarInt>,
237238
Option<rustc_span::def_id::CrateNum>,
238239
Option<rustc_span::def_id::DefId>,
239240
Option<rustc_span::def_id::LocalDefId>,

compiler/rustc_middle/src/query/keys.rs

+9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use rustc_query_system::query::DefIdCacheSelector;
1313
use rustc_query_system::query::{DefaultCacheSelector, SingleCacheSelector, VecCacheSelector};
1414
use rustc_span::symbol::{Ident, Symbol};
1515
use rustc_span::{Span, DUMMY_SP};
16+
use rustc_target::abi;
1617

1718
/// Placeholder for `CrateNum`'s "local" counterpart
1819
#[derive(Copy, Clone, Debug)]
@@ -502,6 +503,14 @@ impl<'tcx> Key for (DefId, Ty<'tcx>, GenericArgsRef<'tcx>, ty::ParamEnv<'tcx>) {
502503
}
503504
}
504505

506+
impl<'tcx> Key for (Ty<'tcx>, abi::VariantIdx) {
507+
type CacheSelector = DefaultCacheSelector<Self>;
508+
509+
fn default_span(&self, _tcx: TyCtxt<'_>) -> Span {
510+
DUMMY_SP
511+
}
512+
}
513+
505514
impl<'tcx> Key for (ty::Predicate<'tcx>, traits::WellFormedLoc) {
506515
type CacheSelector = DefaultCacheSelector<Self>;
507516

compiler/rustc_middle/src/query/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,13 @@ rustc_queries! {
10421042
}
10431043
}
10441044

1045+
/// Computes the tag (if any) for a given type and variant.
1046+
query tag_for_variant(
1047+
key: (Ty<'tcx>, abi::VariantIdx)
1048+
) -> Option<ty::ScalarInt> {
1049+
desc { "computing variant tag for enum" }
1050+
}
1051+
10451052
/// Evaluates a constant and returns the computed allocation.
10461053
///
10471054
/// **Do not use this** directly, use the `eval_to_const_value` or `eval_to_valtree` instead.

compiler/rustc_transmute/src/layout/tree.rs

+19-25
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,10 @@ pub(crate) mod rustc {
174174
use crate::layout::rustc::{Def, Ref};
175175

176176
use rustc_middle::ty::layout::LayoutError;
177-
use rustc_middle::ty::util::Discr;
178177
use rustc_middle::ty::AdtDef;
179178
use rustc_middle::ty::GenericArgsRef;
180179
use rustc_middle::ty::ParamEnv;
180+
use rustc_middle::ty::ScalarInt;
181181
use rustc_middle::ty::VariantDef;
182182
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
183183
use rustc_span::ErrorGuaranteed;
@@ -331,14 +331,15 @@ pub(crate) mod rustc {
331331
trace!(?adt_def, "treeifying enum");
332332
let mut tree = Tree::uninhabited();
333333

334-
for (idx, discr) in adt_def.discriminants(tcx) {
334+
for (idx, variant) in adt_def.variants().iter_enumerated() {
335+
let tag = tcx.tag_for_variant((ty, idx));
335336
tree = tree.or(Self::from_repr_c_variant(
336337
ty,
337338
*adt_def,
338339
args_ref,
339340
&layout_summary,
340-
Some(discr),
341-
adt_def.variant(idx),
341+
tag,
342+
variant,
342343
tcx,
343344
)?);
344345
}
@@ -393,7 +394,7 @@ pub(crate) mod rustc {
393394
adt_def: AdtDef<'tcx>,
394395
args_ref: GenericArgsRef<'tcx>,
395396
layout_summary: &LayoutSummary,
396-
discr: Option<Discr<'tcx>>,
397+
tag: Option<ScalarInt>,
397398
variant_def: &'tcx VariantDef,
398399
tcx: TyCtxt<'tcx>,
399400
) -> Result<Self, Err> {
@@ -403,9 +404,6 @@ pub(crate) mod rustc {
403404
let min_align = repr.align.unwrap_or(Align::ONE);
404405
let max_align = repr.pack.unwrap_or(Align::MAX);
405406

406-
let clamp =
407-
|align: Align| align.clamp(min_align, max_align).bytes().try_into().unwrap();
408-
409407
let variant_span = trace_span!(
410408
"treeifying variant",
411409
min_align = ?min_align,
@@ -419,17 +417,12 @@ pub(crate) mod rustc {
419417
)
420418
.unwrap();
421419

422-
// The layout of the variant is prefixed by the discriminant, if any.
423-
if let Some(discr) = discr {
424-
trace!(?discr, "treeifying discriminant");
425-
let discr_layout = alloc::Layout::from_size_align(
426-
layout_summary.discriminant_size,
427-
clamp(layout_summary.discriminant_align),
428-
)
429-
.unwrap();
430-
trace!(?discr_layout, "computed discriminant layout");
431-
variant_layout = variant_layout.extend(discr_layout).unwrap().0;
432-
tree = tree.then(Self::from_discr(discr, tcx, layout_summary.discriminant_size));
420+
// The layout of the variant is prefixed by the tag, if any.
421+
if let Some(tag) = tag {
422+
let tag_layout =
423+
alloc::Layout::from_size_align(tag.size().bytes_usize(), 1).unwrap();
424+
tree = tree.then(Self::from_tag(tag, tcx));
425+
variant_layout = variant_layout.extend(tag_layout).unwrap().0;
433426
}
434427

435428
// Next come fields.
@@ -469,18 +462,19 @@ pub(crate) mod rustc {
469462
Ok(tree)
470463
}
471464

472-
pub fn from_discr(discr: Discr<'tcx>, tcx: TyCtxt<'tcx>, size: usize) -> Self {
465+
pub fn from_tag(tag: ScalarInt, tcx: TyCtxt<'tcx>) -> Self {
473466
use rustc_target::abi::Endian;
474-
467+
let size = tag.size();
468+
let bits = tag.to_bits(size).unwrap();
475469
let bytes: [u8; 16];
476470
let bytes = match tcx.data_layout.endian {
477471
Endian::Little => {
478-
bytes = discr.val.to_le_bytes();
479-
&bytes[..size]
472+
bytes = bits.to_le_bytes();
473+
&bytes[..size.bytes_usize()]
480474
}
481475
Endian::Big => {
482-
bytes = discr.val.to_be_bytes();
483-
&bytes[bytes.len() - size..]
476+
bytes = bits.to_be_bytes();
477+
&bytes[bytes.len() - size.bytes_usize()..]
484478
}
485479
};
486480
Self::Seq(bytes.iter().map(|&b| Self::from_bits(b)).collect())

0 commit comments

Comments
 (0)