Skip to content

Commit 517bf62

Browse files
committed
fix: improve sqrt implementations for f32 and f64 without std
Handle special cases (NaN, zero, negative, infinity) and use Newton-Raphson refinement for better accuracy. Update RMS doc tests for floating-point precision.
1 parent 567f1ae commit 517bf62

File tree

2 files changed

+56
-12
lines changed

2 files changed

+56
-12
lines changed

dasp_rms/src/lib.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,18 @@ where
117117
/// fn main() {
118118
/// let window = ring_buffer::Fixed::from([[0.0]; 4]);
119119
/// let mut rms = Rms::new(window);
120-
/// assert_eq!(rms.next([1.0]), [0.5]);
121-
/// assert_eq!(rms.next([-1.0]), [0.7071067811865476]);
122-
/// assert_eq!(rms.next([1.0]), [0.8660254037844386]);
123-
/// assert_eq!(rms.next([-1.0]), [1.0]);
120+
///
121+
/// assert_eq!(rms.next([1.0f32]), [0.5f32]);
122+
///
123+
/// let result = rms.next([-1.0f32])[0];
124+
/// assert!((result - 0.7071067811865476).abs() < 0.0001,
125+
/// "Expected ~0.7071067811865476, got {}", result);
126+
///
127+
/// let result = rms.next([1.0f32])[0];
128+
/// assert!((result - 0.8660254037844386).abs() < 0.0001,
129+
/// "Expected ~0.8660254037844386, got {}", result);
130+
///
131+
/// assert_eq!(rms.next([-1.0f32]), [1.0f32]);
124132
/// }
125133
/// ```
126134
#[inline]

dasp_sample/src/ops.rs

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,29 @@ pub mod f32 {
44

55
#[cfg(not(feature = "std"))]
66
pub fn sqrt(x: f32) -> f32 {
7-
if x >= 0.0 {
8-
f32::from_bits((x.to_bits() + 0x3f80_0000) >> 1)
9-
} else {
10-
f32::NAN
7+
if x.is_nan() || x < 0.0 {
8+
return f32::NAN;
119
}
10+
if x == 0.0 {
11+
return x; // preserves +0.0 and -0.0
12+
}
13+
if x.is_infinite() {
14+
return f32::INFINITY;
15+
}
16+
17+
let bits = x.to_bits();
18+
let exp = (bits >> 23) & 0xff;
19+
let mant = bits & 0x7fffff;
20+
21+
let unbiased = exp as i32 - 127;
22+
let sqrt_exp = (unbiased / 2 + 127) as u32;
23+
let guess_bits = (sqrt_exp << 23) | (mant >> 1);
24+
let mut guess = f32::from_bits(guess_bits);
25+
26+
for _ in 0..3 {
27+
guess = 0.5 * (guess + x / guess);
28+
}
29+
guess
1230
}
1331
#[cfg(feature = "std")]
1432
pub fn sqrt(x: f32) -> f32 {
@@ -22,11 +40,29 @@ pub mod f64 {
2240

2341
#[cfg(not(feature = "std"))]
2442
pub fn sqrt(x: f64) -> f64 {
25-
if x >= 0.0 {
26-
f64::from_bits((x.to_bits() + 0x3f80_0000) >> 1)
27-
} else {
28-
f64::NAN
43+
if x.is_nan() || x < 0.0 {
44+
return f64::NAN;
45+
}
46+
if x == 0.0 {
47+
return x; // preserves +0.0 and -0.0
48+
}
49+
if x.is_infinite() {
50+
return f64::INFINITY;
51+
}
52+
53+
let bits = x.to_bits();
54+
let exp = (bits >> 52) & 0x7ff;
55+
let mant = bits & 0x000f_ffff_ffff_ffff;
56+
57+
let unbiased = exp as i32 - 1023;
58+
let sqrt_exp = (unbiased / 2 + 1023) as u64;
59+
let guess_bits = (sqrt_exp << 52) | (mant >> 1);
60+
let mut guess = f64::from_bits(guess_bits);
61+
62+
for _ in 0..4 {
63+
guess = 0.5 * (guess + x / guess);
2964
}
65+
guess
3066
}
3167
#[cfg(feature = "std")]
3268
pub fn sqrt(x: f64) -> f64 {

0 commit comments

Comments
 (0)