Skip to content

Commit 8f108a7

Browse files
calebzulawskiworkingjubilee
authored andcommitted
Generically implement ToBitMaskArray
1 parent 5f49d4c commit 8f108a7

File tree

5 files changed

+141
-3
lines changed

5 files changed

+141
-3
lines changed

crates/core_simd/src/masks.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
mod mask_impl;
1414

1515
mod to_bitmask;
16-
pub use to_bitmask::ToBitMask;
16+
pub use to_bitmask::{ToBitMask, ToBitMaskArray};
17+
18+
#[cfg(feature = "generic_const_exprs")]
19+
pub use to_bitmask::bitmask_len;
1720

1821
use crate::simd::{intrinsics, LaneCount, Simd, SimdElement, SupportedLaneCount};
1922
use core::cmp::Ordering;

crates/core_simd/src/masks/bitmask.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#![allow(unused_imports)]
22
use super::MaskElement;
33
use crate::simd::intrinsics;
4-
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
4+
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask, ToBitMaskArray};
55
use core::marker::PhantomData;
66

77
/// A mask where each lane is represented by a single bit.
@@ -115,6 +115,24 @@ where
115115
unsafe { Self(intrinsics::simd_bitmask(value), PhantomData) }
116116
}
117117

118+
#[inline]
119+
#[must_use = "method returns a new array and does not mutate the original value"]
120+
pub fn to_bitmask_array<const N: usize>(self) -> [u8; N] {
121+
assert!(core::mem::size_of::<Self>() == N);
122+
123+
// Safety: converting an integer to an array of bytes of the same size is safe
124+
unsafe { core::mem::transmute_copy(&self.0) }
125+
}
126+
127+
#[inline]
128+
#[must_use = "method returns a new mask and does not mutate the original value"]
129+
pub fn from_bitmask_array<const N: usize>(bitmask: [u8; N]) -> Self {
130+
assert!(core::mem::size_of::<Self>() == N);
131+
132+
// Safety: converting an array of bytes to an integer of the same size is safe
133+
Self(unsafe { core::mem::transmute_copy(&bitmask) }, PhantomData)
134+
}
135+
118136
#[inline]
119137
pub fn to_bitmask_integer<U>(self) -> U
120138
where

crates/core_simd/src/masks/full_masks.rs

+67-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use super::MaskElement;
44
use crate::simd::intrinsics;
5-
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask};
5+
use crate::simd::{LaneCount, Simd, SupportedLaneCount, ToBitMask, ToBitMaskArray};
66

77
#[repr(transparent)]
88
pub struct Mask<T, const LANES: usize>(Simd<T, LANES>)
@@ -127,6 +127,72 @@ where
127127
unsafe { Mask(intrinsics::simd_cast(self.0)) }
128128
}
129129

130+
#[inline]
131+
#[must_use = "method returns a new array and does not mutate the original value"]
132+
pub fn to_bitmask_array<const N: usize>(self) -> [u8; N]
133+
where
134+
super::Mask<T, LANES>: ToBitMaskArray,
135+
[(); <super::Mask<T, LANES> as ToBitMaskArray>::BYTES]: Sized,
136+
{
137+
assert_eq!(<super::Mask<T, LANES> as ToBitMaskArray>::BYTES, N);
138+
139+
// Safety: N is the correct bitmask size
140+
//
141+
// The transmute below allows this function to be marked safe, since it will prevent
142+
// monomorphization errors in the case of an incorrect size.
143+
unsafe {
144+
// Compute the bitmask
145+
let bitmask: [u8; <super::Mask<T, LANES> as ToBitMaskArray>::BYTES] =
146+
intrinsics::simd_bitmask(self.0);
147+
148+
// Transmute to the return type, previously asserted to be the same size
149+
let mut bitmask: [u8; N] = core::mem::transmute_copy(&bitmask);
150+
151+
// LLVM assumes bit order should match endianness
152+
if cfg!(target_endian = "big") {
153+
for x in bitmask.as_mut() {
154+
*x = x.reverse_bits();
155+
}
156+
};
157+
158+
bitmask
159+
}
160+
}
161+
162+
#[inline]
163+
#[must_use = "method returns a new mask and does not mutate the original value"]
164+
pub fn from_bitmask_array<const N: usize>(mut bitmask: [u8; N]) -> Self
165+
where
166+
super::Mask<T, LANES>: ToBitMaskArray,
167+
[(); <super::Mask<T, LANES> as ToBitMaskArray>::BYTES]: Sized,
168+
{
169+
assert_eq!(<super::Mask<T, LANES> as ToBitMaskArray>::BYTES, N);
170+
171+
// Safety: N is the correct bitmask size
172+
//
173+
// The transmute below allows this function to be marked safe, since it will prevent
174+
// monomorphization errors in the case of an incorrect size.
175+
unsafe {
176+
// LLVM assumes bit order should match endianness
177+
if cfg!(target_endian = "big") {
178+
for x in bitmask.as_mut() {
179+
*x = x.reverse_bits();
180+
}
181+
}
182+
183+
// Transmute to the bitmask type, previously asserted to be the same size
184+
let bitmask: [u8; <super::Mask<T, LANES> as ToBitMaskArray>::BYTES] =
185+
core::mem::transmute_copy(&bitmask);
186+
187+
// Compute the regular mask
188+
Self::from_int_unchecked(intrinsics::simd_select_bitmask(
189+
bitmask,
190+
Self::splat(true).to_int(),
191+
Self::splat(false).to_int(),
192+
))
193+
}
194+
}
195+
130196
#[inline]
131197
pub(crate) fn to_bitmask_integer<U: ReverseBits>(self) -> U
132198
where

crates/core_simd/src/masks/to_bitmask.rs

+38
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,24 @@ pub unsafe trait ToBitMask: Sealed {
3131
fn from_bitmask(bitmask: Self::BitMask) -> Self;
3232
}
3333

34+
/// Converts masks to and from byte array bitmasks.
35+
///
36+
/// Each bit of the bitmask corresponds to a mask lane, starting with the LSB of the first byte.
37+
///
38+
/// # Safety
39+
/// This trait is `unsafe` and sealed, since the `BYTES` value must match the number of lanes in
40+
/// the mask.
41+
pub unsafe trait ToBitMaskArray: Sealed {
42+
/// The length of the bitmask array.
43+
const BYTES: usize;
44+
45+
/// Converts a mask to a bitmask.
46+
fn to_bitmask_array(self) -> [u8; Self::BYTES];
47+
48+
/// Converts a bitmask to a mask.
49+
fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self;
50+
}
51+
3452
macro_rules! impl_integer_intrinsic {
3553
{ $(unsafe impl ToBitMask<BitMask=$int:ty> for Mask<_, $lanes:literal>)* } => {
3654
$(
@@ -55,3 +73,23 @@ impl_integer_intrinsic! {
5573
unsafe impl ToBitMask<BitMask=u32> for Mask<_, 32>
5674
unsafe impl ToBitMask<BitMask=u64> for Mask<_, 64>
5775
}
76+
77+
/// Returns the minimum numnber of bytes in a bitmask with `lanes` lanes.
78+
pub const fn bitmask_len(lanes: usize) -> usize {
79+
(lanes + 7) / 8
80+
}
81+
82+
unsafe impl<T: MaskElement, const LANES: usize> ToBitMaskArray for Mask<T, LANES>
83+
where
84+
LaneCount<LANES>: SupportedLaneCount,
85+
{
86+
const BYTES: usize = bitmask_len(LANES);
87+
88+
fn to_bitmask_array(self) -> [u8; Self::BYTES] {
89+
self.0.to_bitmask_array()
90+
}
91+
92+
fn from_bitmask_array(bitmask: [u8; Self::BYTES]) -> Self {
93+
Mask(mask_impl::Mask::from_bitmask_array(bitmask))
94+
}
95+
}

crates/core_simd/tests/masks.rs

+13
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ macro_rules! test_mask_api {
8080
assert_eq!(bitmask, 0b1000001101001001);
8181
assert_eq!(core_simd::Mask::<$type, 16>::from_bitmask(bitmask), mask);
8282
}
83+
84+
#[test]
85+
fn roundtrip_bitmask_array_conversion() {
86+
use core_simd::ToBitMaskArray;
87+
let values = [
88+
true, false, false, true, false, false, true, false,
89+
true, true, false, false, false, false, false, true,
90+
];
91+
let mask = core_simd::Mask::<$type, 16>::from_array(values);
92+
let bitmask = mask.to_bitmask_array();
93+
assert_eq!(bitmask, [0b01001001, 0b10000011]);
94+
assert_eq!(core_simd::Mask::<$type, 16>::from_bitmask_array(bitmask), mask);
95+
}
8396
}
8497
}
8598
}

0 commit comments

Comments
 (0)