Skip to content

Commit d31d92c

Browse files
authored
Extend #[derive(TransparentWrapper)] (#147)
* Extend #[derive(TransparentWrapper)] to allow types other than single idents in #[transparent(Type)]. * Update TransparentWrapper derive macro documentation. * Suggest type alias workaround in TransparentWrapper error message
1 parent 3e18072 commit d31d92c

File tree

3 files changed

+85
-31
lines changed

3 files changed

+85
-31
lines changed

derive/src/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,10 @@ pub fn derive_maybe_pod(
318318
///
319319
/// If the struct only contains a single field, the `Wrapped` type will
320320
/// automatically be determined. If there is more then one field in the struct,
321-
/// you need to specify the `Wrapped` type using `#[transparent(T)]`
321+
/// you need to specify the `Wrapped` type using `#[transparent(T)]`. Due to
322+
/// technical limitations, the type in the `#[transparent(Type)]` needs to be
323+
/// the exact same token sequence as the corresponding type in the struct
324+
/// definition.
322325
///
323326
/// ## Examples
324327
///

derive/src/traits.rs

Lines changed: 56 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -423,19 +423,29 @@ impl Derivable for CheckedBitPattern {
423423

424424
pub struct TransparentWrapper;
425425

426+
struct WrappedType {
427+
wrapped_type: syn::Type,
428+
/// Was the type given with a #[transparent(Type)] attribute.
429+
explicit: bool,
430+
}
431+
426432
impl TransparentWrapper {
427-
fn get_wrapper_type(
433+
fn get_wrapped_type(
428434
attributes: &[Attribute], fields: &Fields,
429-
) -> Option<TokenStream> {
430-
let transparent_param = get_simple_attr(attributes, "transparent");
431-
transparent_param.map(|ident| ident.to_token_stream()).or_else(|| {
435+
) -> Option<WrappedType> {
436+
let transparent_param =
437+
get_type_from_simple_attr(attributes, "transparent")
438+
.map(|wrapped_type| WrappedType { wrapped_type, explicit: true });
439+
transparent_param.or_else(|| {
432440
let mut types = get_field_types(&fields);
433441
let first_type = types.next();
434442
if let Some(_) = types.next() {
435443
// can't guess param type if there is more than one field
436444
return None;
437445
} else {
438-
first_type.map(|ty| ty.to_token_stream())
446+
first_type
447+
.cloned()
448+
.map(|wrapped_type| WrappedType { wrapped_type, explicit: false })
439449
}
440450
})
441451
}
@@ -445,15 +455,13 @@ impl Derivable for TransparentWrapper {
445455
fn ident(input: &DeriveInput, crate_name: &TokenStream) -> Result<syn::Path> {
446456
let fields = get_struct_fields(input)?;
447457

448-
let ty = match Self::get_wrapper_type(&input.attrs, &fields) {
449-
Some(ty) => ty,
450-
None => bail!(
451-
"\
452-
when deriving TransparentWrapper for a struct with more than one field \
453-
you need to specify the transparent field using #[transparent(T)]\
454-
"
455-
),
456-
};
458+
let WrappedType { wrapped_type: ty, .. } =
459+
match Self::get_wrapped_type(&input.attrs, &fields) {
460+
Some(ty) => ty,
461+
None => bail!("when deriving TransparentWrapper for a struct with more \
462+
than one field, you need to specify the transparent field \
463+
using #[transparent(T)]"),
464+
};
457465

458466
Ok(syn::parse_quote!(#crate_name::TransparentWrapper<#ty>))
459467
}
@@ -464,19 +472,27 @@ impl Derivable for TransparentWrapper {
464472
let (impl_generics, _ty_generics, where_clause) =
465473
input.generics.split_for_impl();
466474
let fields = get_struct_fields(input)?;
467-
let wrapped_type = match Self::get_wrapper_type(&input.attrs, &fields) {
468-
Some(wrapped_type) => wrapped_type.to_string(),
469-
None => unreachable!(), /* other code will already reject this derive */
470-
};
475+
let (wrapped_type, explicit) =
476+
match Self::get_wrapped_type(&input.attrs, &fields) {
477+
Some(WrappedType { wrapped_type, explicit }) => {
478+
(wrapped_type.to_token_stream().to_string(), explicit)
479+
}
480+
None => unreachable!(), /* other code will already reject this derive */
481+
};
471482
let mut wrapped_field_ty = None;
472483
let mut nonwrapped_field_tys = vec![];
473484
for field in fields.iter() {
474485
let field_ty = &field.ty;
475486
if field_ty.to_token_stream().to_string() == wrapped_type {
476487
if wrapped_field_ty.is_some() {
477-
bail!(
478-
"TransparentWrapper can only have one field of the wrapped type"
479-
);
488+
if explicit {
489+
bail!("TransparentWrapper must have one field of the wrapped type. \
490+
The type given in `#[transparent(Type)]` must match tokenwise \
491+
with the type in the struct definition, not just be the same type. \
492+
You may be able to use a type alias to work around this limitation.");
493+
} else {
494+
bail!("TransparentWrapper must have one field of the wrapped type");
495+
}
480496
}
481497
wrapped_field_ty = Some(field_ty);
482498
} else {
@@ -1182,21 +1198,31 @@ fn generate_enum_discriminant(input: &DeriveInput) -> Result<TokenStream> {
11821198
})
11831199
}
11841200

1185-
fn get_ident_from_stream(tokens: TokenStream) -> Option<Ident> {
1186-
match tokens.into_iter().next() {
1187-
Some(TokenTree::Group(group)) => get_ident_from_stream(group.stream()),
1188-
Some(TokenTree::Ident(ident)) => Some(ident),
1189-
_ => None,
1201+
fn get_wrapped_type_from_stream(tokens: TokenStream) -> Option<syn::Type> {
1202+
let mut tokens = tokens.into_iter().peekable();
1203+
match tokens.peek() {
1204+
Some(TokenTree::Group(group)) => {
1205+
let res = get_wrapped_type_from_stream(group.stream());
1206+
tokens.next(); // remove the peeked token tree
1207+
match tokens.next() {
1208+
// If there were more tokens, the input was invalid
1209+
Some(_) => None,
1210+
None => res,
1211+
}
1212+
}
1213+
_ => syn::parse2(tokens.collect()).ok(),
11901214
}
11911215
}
11921216

1193-
/// get a simple #[foo(bar)] attribute, returning "bar"
1194-
fn get_simple_attr(attributes: &[Attribute], attr_name: &str) -> Option<Ident> {
1217+
/// get a simple `#[foo(bar)]` attribute, returning `bar`
1218+
fn get_type_from_simple_attr(
1219+
attributes: &[Attribute], attr_name: &str,
1220+
) -> Option<syn::Type> {
11951221
for attr in attributes {
11961222
if let (AttrStyle::Outer, Meta::List(list)) = (&attr.style, &attr.meta) {
11971223
if list.path.is_ident(attr_name) {
1198-
if let Some(ident) = get_ident_from_stream(list.tokens.clone()) {
1199-
return Some(ident);
1224+
if let Some(ty) = get_wrapped_type_from_stream(list.tokens.clone()) {
1225+
return Some(ty);
12001226
}
12011227
}
12021228
}

derive/tests/basic.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,31 @@ enum CheckedBitPatternEnumNested {
312312
#[repr(transparent)]
313313
struct NewtypeWrapperTest<T>(T);
314314

315+
#[derive(Debug, Clone, PartialEq, Eq, TransparentWrapper)]
316+
#[repr(transparent)]
317+
struct AlgebraicNewtypeWrapperTest<T>(Vec<T>);
318+
319+
#[test]
320+
fn algebraic_newtype_corect() {
321+
let x: Vec<u32> = vec![1, 2, 3, 4];
322+
let y: AlgebraicNewtypeWrapperTest<u32> =
323+
AlgebraicNewtypeWrapperTest::wrap(x.clone());
324+
assert_eq!(y.0, x);
325+
}
326+
327+
#[derive(Debug, Clone, PartialEq, Eq, TransparentWrapper)]
328+
#[repr(transparent)]
329+
#[transparent(Vec<T>)]
330+
struct AlgebraicNewtypeWrapperTestWithFields<T, U>(Vec<T>, PhantomData<U>);
331+
332+
#[test]
333+
fn algebraic_newtype_fields_corect() {
334+
let x: Vec<u32> = vec![1, 2, 3, 4];
335+
let y: AlgebraicNewtypeWrapperTestWithFields<u32, f32> =
336+
AlgebraicNewtypeWrapperTestWithFields::wrap(x.clone());
337+
assert_eq!(y.0, x);
338+
}
339+
315340
#[test]
316341
fn fails_cast_contiguous() {
317342
let can_cast = CheckedBitPatternEnumWithValues::is_valid_bit_pattern(&5);

0 commit comments

Comments
 (0)