diff --git a/src/lib.rs b/src/lib.rs index 6447ca36..f27e3b6e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -926,8 +926,8 @@ impl Extend for ArrayVec { let take = self.capacity() - self.len(); unsafe { let len = self.len(); - let mut ptr = self.as_mut_ptr().offset(len as isize); - let end_ptr = ptr.offset(take as isize); + let mut ptr = raw_ptr_add(self.as_mut_ptr(), len); + let end_ptr = raw_ptr_add(ptr, take); // Keep the length in a separate variable, write it back on scope // exit. To help the compiler with alias analysis and stuff. // We update the length to handle panic in the iteration of the @@ -943,8 +943,8 @@ impl Extend for ArrayVec { loop { if ptr == end_ptr { break; } if let Some(elt) = iter.next() { - ptr::write(ptr, elt); - ptr = ptr.offset(1); + raw_ptr_write(ptr, elt); + ptr = raw_ptr_add(ptr, 1); guard.data += 1; } else { break; @@ -954,6 +954,24 @@ impl Extend for ArrayVec { } } +/// Rawptr add but uses arithmetic distance for ZST +unsafe fn raw_ptr_add(ptr: *mut T, offset: usize) -> *mut T { + if mem::size_of::() == 0 { + // Special case for ZST + (ptr as usize).wrapping_add(offset) as _ + } else { + ptr.offset(offset as isize) + } +} + +unsafe fn raw_ptr_write(ptr: *mut T, value: T) { + if mem::size_of::() == 0 { + /* nothing */ + } else { + ptr::write(ptr, value) + } +} + /// Create an `ArrayVec` from an iterator. /// /// Does not extract more items than there is space for. No error diff --git a/tests/tests.rs b/tests/tests.rs index 8f07ef76..306689c3 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -640,3 +640,23 @@ fn test_newish_stable_uses_maybe_uninit() { assert!(cfg!(has_stable_maybe_uninit)); } } + +#[test] +fn test_extend_zst() { + let mut range = 0..10; + #[derive(Copy, Clone, PartialEq, Debug)] + struct Z; // Zero sized type + + let mut array: ArrayVec<[_; 5]> = range.by_ref().map(|_| Z).collect(); + assert_eq!(&array[..], &[Z; 5]); + assert_eq!(range.next(), Some(5)); + + array.extend(range.by_ref().map(|_| Z)); + assert_eq!(range.next(), Some(6)); + + let mut array: ArrayVec<[_; 10]> = (0..3).map(|_| Z).collect(); + assert_eq!(&array[..], &[Z; 3]); + array.extend((3..5).map(|_| Z)); + assert_eq!(&array[..], &[Z; 5]); + assert_eq!(array.len(), 5); +}