libm/math/generic/
sqrt.rs

1/* SPDX-License-Identifier: MIT */
2/* origin: musl src/math/sqrt.c. Ported to generic Rust algorithm in 2025, TG. */
3
4//! Generic square root algorithm.
5//!
//! This routine operates around `m_u2`, a U2 (fixed point with two integral bits) mantissa
//! within the range [1, 4). A table lookup provides an initial estimate, then Goldschmidt
//! iterations at increasing integer widths are used to approach the true values.
9//!
10//! For the iterations, `r` is a U0 number that approaches `1/sqrt(m_u2)`, and `s` is a U2 number
11//! that approaches `sqrt(m_u2)`. Recall that m_u2 ∈ [1, 4).
12//!
13//! With Newton-Raphson iterations, this would be:
14//!
15//! - `w = r * r           w ~ 1 / m`
16//! - `u = 3 - m * w       u ~ 3 - m * w = 3 - m / m = 2`
17//! - `r = r * u / 2       r ~ r`
18//!
19//! (Note that the righthand column does not show anything analytically meaningful (i.e. r ~ r),
20//! since the value of performing one iteration is in reducing the error representable by `~`).
21//!
22//! Instead of Newton-Raphson iterations, Goldschmidt iterations are used to calculate
23//! `s = m * r`:
24//!
25//! - `s = m * r           s ~ m / sqrt(m)`
26//! - `u = 3 - s * r       u ~ 3 - (m / sqrt(m)) * (1 / sqrt(m)) = 3 - m / m = 2`
27//! - `r = r * u / 2       r ~ r`
28//! - `s = s * u / 2       s ~ s`
29//!
30//! The above is precise because it uses the original value `m`. There is also a faster version
31//! that performs fewer steps but does not use `m`:
32//!
33//! - `u = 3 - s * r       u ~ 3 - 1`
34//! - `r = r * u / 2       r ~ r`
35//! - `s = s * u / 2       s ~ s`
36//!
37//! Rounding errors accumulate faster with the second version, so it is only used for subsequent
38//! iterations within the same width integer. The first version is always used for the first
39//! iteration at a new width in order to avoid this accumulation.
40//!
41//! Goldschmidt has the advantage over Newton-Raphson that `sqrt(x)` and `1/sqrt(x)` are
42//! computed at the same time, i.e. there is no need to calculate `1/sqrt(x)` and invert it.
43
44use crate::support::{
45    CastFrom, CastInto, DInt, Float, FpResult, HInt, Int, IntTy, MinInt, Round, Status, cold_path,
46};
47
48#[inline]
49pub fn sqrt<F>(x: F) -> F
50where
51    F: Float + SqrtHelper,
52    F::Int: HInt,
53    F::Int: From<u8>,
54    F::Int: From<F::ISet2>,
55    F::Int: CastInto<F::ISet1>,
56    F::Int: CastInto<F::ISet2>,
57    u32: CastInto<F::Int>,
58{
59    sqrt_round(x, Round::Nearest).val
60}
61
/// Compute the square root of `x`, returning the value together with a status.
///
/// Special inputs are handled up front:
///
/// - `+0.0`, `-0.0` and positive infinity are returned unchanged via `FpResult::ok`.
/// - Any remaining input whose bit pattern is greater than `F::EXP_MASK` (NaN or negative)
///   returns `F::NAN` with [`Status::INVALID`].
/// - Subnormal inputs are scaled up by `2^SIG_BITS` and the exponent is corrected afterward.
///
/// Every other input goes through a table lookup followed by Goldschmidt iterations at up to
/// three integer widths, and the result is returned via `FpResult::ok`.
///
/// The `_round` argument is not read anywhere in the body, so the requested rounding mode does
/// not currently influence the computation.
#[inline]
pub fn sqrt_round<F>(x: F, _round: Round) -> FpResult<F>
where
    F: Float + SqrtHelper,
    F::Int: HInt,
    F::Int: From<u8>,
    F::Int: From<F::ISet2>,
    F::Int: CastInto<F::ISet1>,
    F::Int: CastInto<F::ISet2>,
    u32: CastInto<F::Int>,
{
    let zero = IntTy::<F>::ZERO;
    let one = IntTy::<F>::ONE;

    let mut ix = x.to_bits();

    // Top is the exponent and sign, which may or may not be shifted. If the float fits into a
    // `u32`, we can get by without paying shifting costs.
    let noshift = F::BITS <= u32::BITS;
    let (mut top, special_case) = if noshift {
        let exp_lsb = one << F::SIG_BITS;
        // A single unsigned comparison covers both ends of the range: the subtraction wraps to
        // a huge value when `ix < exp_lsb`, and stays at or above the bound when
        // `ix >= F::EXP_MASK`.
        let special_case = ix.wrapping_sub(exp_lsb) >= F::EXP_MASK - exp_lsb;
        (Exp::NoShift(()), special_case)
    } else {
        let top = u32::cast_from(ix >> F::SIG_BITS);
        // Same trick on the shifted value: true when `top == 0` or `top >= F::EXP_SAT`.
        let special_case = top.wrapping_sub(1) >= F::EXP_SAT - 1;
        (Exp::Shifted(top), special_case)
    };

    // Handle NaN, zero, and out of domain (<= 0)
    if special_case {
        cold_path();

        // +/-0 (shifting left by one discards the top bit before comparing with zero)
        if ix << 1 == zero {
            return FpResult::ok(x);
        }

        // Positive infinity
        if ix == F::EXP_MASK {
            return FpResult::ok(x);
        }

        // NaN or negative
        if ix > F::EXP_MASK {
            return FpResult::new(F::NAN, Status::INVALID);
        }

        // Normalize subnormals by multiplying by 1.0 << SIG_BITS (e.g. 0x1p52 for doubles).
        let scaled = x * F::from_parts(false, F::SIG_BITS + F::EXP_BIAS, zero);
        ix = scaled.to_bits();
        match top {
            Exp::Shifted(ref mut v) => {
                // Take the exponent of the scaled value, then remove the `SIG_BITS` that the
                // scaling added. NOTE(review): assumes `ex()` returns the biased exponent
                // field -- confirm against the `Float` trait.
                *v = scaled.ex();
                *v = (*v).wrapping_sub(F::SIG_BITS);
            }
            Exp::NoShift(()) => {
                // Undo the scaling in place: subtract `SIG_BITS` positioned at the exponent's
                // lowest bit.
                ix = ix.wrapping_sub((F::SIG_BITS << F::SIG_BITS).cast());
            }
        }
    }

    // Reduce arguments such that `x = 4^e * m`:
    //
    // - m_u2 ∈ [1, 4), a fixed point U2.BITS number
    // - 2^e is the exponent part of the result
    let (m_u2, exp) = match top {
        Exp::Shifted(top) => {
            // We now know `x` is positive, so `top` is just its (biased) exponent
            let mut e = top;
            // Construct a fixed point representation of the mantissa.
            let mut m_u2 = (ix | F::IMPLICIT_BIT) << F::EXP_BITS;
            // `even` is true when the lowest bit of the biased exponent is set.
            let even = (e & 1) != 0;
            if even {
                m_u2 >>= 1;
            }
            // Halve the exponent for the result. NOTE(review): adding `F::EXP_SAT >> 1` first
            // looks like re-applying the bias -- confirm `EXP_SAT >> 1` equals `EXP_BIAS`.
            e = (e.wrapping_add(F::EXP_SAT >> 1)) >> 1;
            (m_u2, Exp::Shifted(e))
        }
        Exp::NoShift(()) => {
            // Same parity test as the shifted branch, done on the in-place exponent bit.
            let even = ix & (one << F::SIG_BITS) != zero;

            // Exponent part of the return value
            let mut e_noshift = ix >> 1;
            // ey &= (F::EXP_MASK << 2) >> 2; // clear the top exponent bit (result = 1.0)
            e_noshift += (F::EXP_MASK ^ (F::SIGN_MASK >> 1)) >> 1;
            e_noshift &= F::EXP_MASK;

            // `m1` shifts the significand up by `EXP_BITS` and ORs in `F::SIGN_MASK`; `m0`
            // shifts by one bit less and clears `F::SIGN_MASK`.
            let m1 = (ix << F::EXP_BITS) | F::SIGN_MASK;
            let m0 = (ix << (F::EXP_BITS - 1)) & !F::SIGN_MASK;
            let m_u2 = if even { m0 } else { m1 };

            (m_u2, Exp::NoShift(e_noshift))
        }
    };

    // Extract the top 6 bits of the significand with the lowest bit of the exponent.
    let i = usize::cast_from(ix >> (F::SIG_BITS - 6)) & 0b1111111;

    // Start with an initial guess for `r = 1 / sqrt(m)` from the table, and shift `m` as an
    // initial value for `s = sqrt(m)`. See the module documentation for details.
    let r1_u0: F::ISet1 = F::ISet1::cast_from(RSQRT_TAB[i]) << (F::ISet1::BITS - 16);
    let s1_u2: F::ISet1 = ((m_u2) >> (F::BITS - F::ISet1::BITS)).cast();

    // Perform iterations, if any, at quarter width (used for `f128`).
    let (r1_u0, _s1_u2) = goldschmidt::<F, F::ISet1>(r1_u0, s1_u2, F::SET1_ROUNDS, false);

    // Widen values and perform iterations at half width (used for `f64` and `f128`).
    // `s` is re-derived from `m_u2` at the new width rather than widened from the previous set.
    let r2_u0: F::ISet2 = F::ISet2::from(r1_u0) << (F::ISet2::BITS - F::ISet1::BITS);
    let s2_u2: F::ISet2 = ((m_u2) >> (F::BITS - F::ISet2::BITS)).cast();
    let (r2_u0, _s2_u2) = goldschmidt::<F, F::ISet2>(r2_u0, s2_u2, F::SET2_ROUNDS, false);

    // Perform final iterations at full width (used for all float types).
    let r_u0: F::Int = F::Int::from(r2_u0) << (F::BITS - F::ISet2::BITS);
    let s_u2: F::Int = m_u2;
    let (_r_u0, s_u2) = goldschmidt::<F, F::Int>(r_u0, s_u2, F::FINAL_ROUNDS, true);

    // Shift back to mantissa position.
    let mut m = s_u2 >> (F::EXP_BITS - 2);

    // The musl source includes the following comment (with literals replaced):
    //
    // > s < sqrt(m) < s + 0x1.09p-SIG_BITS
    // > compute nearest rounded result: the nearest result to SIG_BITS bits is either s or
    // > s+0x1p-SIG_BITS, we can decide by comparing (2^SIG_BITS s + 0.5)^2 to 2^(2*SIG_BITS) m.
    //
    // Expanding this with `SIG_BITS = p` and adjusting based on the operations done to
    // `d0` and `d1`:
    //
    // - `2^(2p)m ≟ ((2^p)m + 0.5)^2`
    // - `2^(2p)m ≟ 2^(2p)m^2 + (2^p)m + 0.25`
    // - `2^(2p)m - m^2 ≟ (2^(2p) - 1)m^2 + (2^p)m + 0.25`
    // - `(1 - 2^(2p))m + m^2 ≟ (1 - 2^(2p))m^2 + (1 - 2^p)m + 0.25` (?)
    //
    // I do not follow how the rounding bit is extracted from this comparison with the below
    // operations. In any case, the algorithm is well tested.

    // The value needed to shift `m_u2` by to create `m*2^(2p)`. `2p = 2 * F::SIG_BITS`,
    // `F::BITS - 2` accounts for the offset that `m_u2` already has.
    let shift = 2 * F::SIG_BITS - (F::BITS - 2);

    // `2^(2p)m - m^2`
    let d0 = (m_u2 << shift).wrapping_sub(m.wrapping_mul(m));
    // `m - 2^(2p)m + m^2`
    let d1 = m.wrapping_sub(d0);
    // Move the top bit of `d1` down to the lowest position and add it as the rounding increment.
    m += d1 >> (F::BITS - 1);
    // Keep only the stored significand bits before the exponent is merged in.
    m &= F::SIG_MASK;

    // Merge the result exponent: shifted into position if it was held in a `u32`, or ORed in
    // directly if it was kept in place.
    match exp {
        Exp::Shifted(e) => m |= IntTy::<F>::cast_from(e) << F::SIG_BITS,
        Exp::NoShift(e) => m |= e,
    };

    let mut y = F::from_bits(m);

    // FIXME(f16): the fenv math does not work for `f16`
    if F::BITS > 16 {
        // Handle rounding and inexact. `(m + 1)^2 == 2^shift m` is exact; for all other cases, add
        // a tiny value to cause fenv effects.
        let d2 = d1.wrapping_add(m).wrapping_add(one);
        let mut tiny = if d2 == zero {
            cold_path();
            zero
        } else {
            F::IMPLICIT_BIT
        };

        // Give `tiny` a sign bit wherever `d1` and `d2` differ in theirs.
        tiny |= (d1 ^ d2) & F::SIGN_MASK;
        let t = F::from_bits(tiny);
        y = y + t;
    }

    FpResult::ok(y)
}
236
237/// Multiply at the wider integer size, returning the high half.
238fn wmulh<I: HInt>(a: I, b: I) -> I {
239    a.widen_mul(b).hi()
240}
241
242/// Perform `count` goldschmidt iterations, returning `(r_u0, s_u?)`.
243///
244/// - `r_u0` is the reciprocal `r ~ 1 / sqrt(m)`, as U0.
245/// - `s_u2` is the square root, `s ~ sqrt(m)`, as U2.
246/// - `count` is the number of iterations to perform.
247/// - `final_set` should be true if this is the last round (same-sized integer). If so, the
248///   returned `s` will be U3, for later shifting. Otherwise, the returned `s` is U2.
249///
250/// Note that performance relies on the optimizer being able to unroll these loops (reasonably
251/// trivial, `count` is a constant when called).
252#[inline]
253fn goldschmidt<F, I>(mut r_u0: I, mut s_u2: I, count: u32, final_set: bool) -> (I, I)
254where
255    F: SqrtHelper,
256    I: HInt + From<u8>,
257{
258    let three_u2 = I::from(0b11u8) << (I::BITS - 2);
259    let mut u_u0 = r_u0;
260
261    for i in 0..count {
262        // First iteration: `s = m*r` (`u_u0 = r_u0` set above)
263        // Subsequent iterations: `s=s*u/2`
264        s_u2 = wmulh(s_u2, u_u0);
265
266        // Perform `s /= 2` if:
267        //
268        // 1. This is not the first iteration (the first iteration is `s = m*r`)...
269        // 2. ... and this is not the last set of iterations
270        // 3. ... or, if this is the last set, it is not the last iteration
271        //
272        // This step is not performed for the final iteration because the shift is combined with
273        // a later shift (moving `s` into the mantissa).
274        if i > 0 && (!final_set || i + 1 < count) {
275            s_u2 <<= 1;
276        }
277
278        // u = 3 - s*r
279        let d_u2 = wmulh(s_u2, r_u0);
280        u_u0 = three_u2.wrapping_sub(d_u2);
281
282        // r = r*u/2
283        r_u0 = wmulh(r_u0, u_u0) << 1;
284    }
285
286    (r_u0, s_u2)
287}
288
/// Representation of whether we shift the exponent into a `u32`, or modify it in place to save
/// the shift operations.
///
/// `T` is the payload carried by the in-place variant. `sqrt_round` uses `()` before argument
/// reduction (the exponent still lives inside the float's bit pattern) and the float's integer
/// type afterward (holding the exponent bits of the result).
enum Exp<T> {
    /// The exponent has been shifted to a `u32` and is LSB-aligned.
    Shifted(u32),
    /// The exponent is in its natural position in integer repr.
    NoShift(T),
}
297
/// Size-specific constants related to the square root routine.
///
/// Iterations run in up to three sets: first at `ISet1`, then at `ISet2`, and finally at
/// `Self::Int`. The two reduced-width sets default to zero rounds, so an implementation only has
/// to specify `FINAL_ROUNDS` plus whichever reduced-width sets it actually uses.
pub trait SqrtHelper: Float {
    /// Integer for the first set of rounds. If unused, set to the same type as the next set.
    ///
    /// Must be losslessly convertible into `ISet2` and castable from `Self::Int`.
    type ISet1: HInt + Into<Self::ISet2> + CastFrom<Self::Int> + From<u8>;
    /// Integer for the second set of rounds. If unused, set to the same type as the next set.
    type ISet2: HInt + From<Self::ISet1> + From<u8>;

    /// Number of rounds at `ISet1`. Defaults to 0 (set unused).
    const SET1_ROUNDS: u32 = 0;
    /// Number of rounds at `ISet2`. Defaults to 0 (set unused).
    const SET2_ROUNDS: u32 = 0;
    /// Number of rounds at `Self::Int`. Has no default and must be provided.
    const FINAL_ROUNDS: u32;
}
312
/// `f16`: no reduced-width rounds (the `SET1_ROUNDS`/`SET2_ROUNDS` defaults of 0 apply); two
/// rounds at `Self::Int`.
#[cfg(f16_enabled)]
impl SqrtHelper for f16 {
    type ISet1 = u16; // unused
    type ISet2 = u16; // unused

    const FINAL_ROUNDS: u32 = 2;
}
320
/// `f32`: no reduced-width rounds (the `SET1_ROUNDS`/`SET2_ROUNDS` defaults of 0 apply); three
/// rounds at `Self::Int`.
impl SqrtHelper for f32 {
    type ISet1 = u32; // unused
    type ISet2 = u32; // unused

    const FINAL_ROUNDS: u32 = 3;
}
327
/// `f64`: the first set is skipped (`SET1_ROUNDS` keeps its default of 0); two rounds at `u32`
/// followed by two rounds at `Self::Int`.
impl SqrtHelper for f64 {
    type ISet1 = u32; // unused
    type ISet2 = u32;

    const SET2_ROUNDS: u32 = 2;
    const FINAL_ROUNDS: u32 = 2;
}
335
/// `f128`: all three sets are used -- one round at `u32`, two rounds at `u64`, then two rounds at
/// `Self::Int`.
#[cfg(f128_enabled)]
impl SqrtHelper for f128 {
    type ISet1 = u32;
    type ISet2 = u64;

    const SET1_ROUNDS: u32 = 1;
    const SET2_ROUNDS: u32 = 2;
    const FINAL_ROUNDS: u32 = 2;
}
345
/// A U0.16 representation of `1/sqrt(x)`.
///
/// The index is a 7-bit number consisting of a single exponent bit and 6 bits of significand.
/// The exponent bit is the most significant index bit, so entries `0..64` are used when the
/// lowest exponent bit is clear and entries `64..128` when it is set; within each half the six
/// significand bits select the entry. Values decrease monotonically within each half.
#[rustfmt::skip]
static RSQRT_TAB: [u16; 128] = [
    0xb451, 0xb2f0, 0xb196, 0xb044, 0xaef9, 0xadb6, 0xac79, 0xab43,
    0xaa14, 0xa8eb, 0xa7c8, 0xa6aa, 0xa592, 0xa480, 0xa373, 0xa26b,
    0xa168, 0xa06a, 0x9f70, 0x9e7b, 0x9d8a, 0x9c9d, 0x9bb5, 0x9ad1,
    0x99f0, 0x9913, 0x983a, 0x9765, 0x9693, 0x95c4, 0x94f8, 0x9430,
    0x936b, 0x92a9, 0x91ea, 0x912e, 0x9075, 0x8fbe, 0x8f0a, 0x8e59,
    0x8daa, 0x8cfe, 0x8c54, 0x8bac, 0x8b07, 0x8a64, 0x89c4, 0x8925,
    0x8889, 0x87ee, 0x8756, 0x86c0, 0x862b, 0x8599, 0x8508, 0x8479,
    0x83ec, 0x8361, 0x82d8, 0x8250, 0x81c9, 0x8145, 0x80c2, 0x8040,
    0xff02, 0xfd0e, 0xfb25, 0xf947, 0xf773, 0xf5aa, 0xf3ea, 0xf234,
    0xf087, 0xeee3, 0xed47, 0xebb3, 0xea27, 0xe8a3, 0xe727, 0xe5b2,
    0xe443, 0xe2dc, 0xe17a, 0xe020, 0xdecb, 0xdd7d, 0xdc34, 0xdaf1,
    0xd9b3, 0xd87b, 0xd748, 0xd61a, 0xd4f1, 0xd3cd, 0xd2ad, 0xd192,
    0xd07b, 0xcf69, 0xce5b, 0xcd51, 0xcc4a, 0xcb48, 0xca4a, 0xc94f,
    0xc858, 0xc764, 0xc674, 0xc587, 0xc49d, 0xc3b7, 0xc2d4, 0xc1f4,
    0xc116, 0xc03c, 0xbf65, 0xbe90, 0xbdbe, 0xbcef, 0xbc23, 0xbb59,
    0xba91, 0xb9cc, 0xb90a, 0xb84a, 0xb78c, 0xb6d0, 0xb617, 0xb560,
];
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Check the special-value behavior that IEEE 754 specifies for `squareRoot`.
    fn spec_test<F>()
    where
        F: Float + SqrtHelper,
        F::Int: HInt + From<u8> + From<F::ISet2> + CastInto<F::ISet1> + CastInto<F::ISet2>,
        u32: CastInto<F::Int>,
    {
        // Inputs that must produce a NaN together with the invalid status.
        let invalid_inputs = [F::NEG_INFINITY, F::NEG_ONE, F::NAN, F::MIN];
        // Inputs that must come back bit-for-bit unchanged with an OK status.
        let unchanged_inputs = [F::ZERO, F::NEG_ZERO, F::INFINITY];

        for x in invalid_inputs {
            let res = sqrt_round(x, Round::Nearest);
            assert!(res.val.is_nan());
            assert!(res.status == Status::INVALID);
        }

        for x in unchanged_inputs {
            let res = sqrt_round(x, Round::Nearest);
            assert_biteq!(res.val, x);
            assert!(res.status == Status::OK);
        }
    }

    #[test]
    #[cfg(f16_enabled)]
    fn sanity_check_f16() {
        for (x, root) in [(100.0f16, 10.0f16), (4.0, 2.0)] {
            assert_biteq!(sqrt(x), root);
        }
    }

    #[test]
    #[cfg(f16_enabled)]
    fn spec_tests_f16() {
        spec_test::<f16>();
    }

    #[test]
    #[cfg(f16_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f16() {
        // Pairs of (input, expected result bits).
        let expected = [
            (f16::PI, 0x3f17_u16),
            // 10_000.0, using a hex literal for MSRV hack (Rust < 1.67 checks literal widths as
            // part of the AST, so the `cfg` is irrelevant here).
            (f16::from_bits(0x70e2), 0x5640_u16),
            (f16::from_bits(0x0000000f), 0x13bf_u16),
            (f16::INFINITY, f16::INFINITY.to_bits()),
        ];

        for (input, want_bits) in expected {
            let want = f16::from_bits(want_bits);
            assert_biteq!(sqrt(input), want, "input: {input:?} ({:#018x})", input.to_bits());
        }
    }

    #[test]
    fn sanity_check_f32() {
        for (x, root) in [(100.0f32, 10.0f32), (4.0, 2.0)] {
            assert_biteq!(sqrt(x), root);
        }
    }

    #[test]
    fn spec_tests_f32() {
        spec_test::<f32>();
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f32() {
        // Pairs of (input, expected result bits).
        let expected = [
            (f32::PI, 0x3fe2dfc5_u32),
            (10000.0f32, 0x42c80000_u32),
            (f32::from_bits(0x0000000f), 0x1b2f456f_u32),
            (f32::INFINITY, f32::INFINITY.to_bits()),
        ];

        for (input, want_bits) in expected {
            let want = f32::from_bits(want_bits);
            assert_biteq!(sqrt(input), want, "input: {input:?} ({:#018x})", input.to_bits());
        }
    }

    #[test]
    fn sanity_check_f64() {
        for (x, root) in [(100.0f64, 10.0f64), (4.0, 2.0)] {
            assert_biteq!(sqrt(x), root);
        }
    }

    #[test]
    fn spec_tests_f64() {
        spec_test::<f64>();
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f64() {
        // Pairs of (input, expected result bits).
        let expected = [
            (f64::PI, 0x3ffc5bf891b4ef6a_u64),
            (10000.0, 0x4059000000000000_u64),
            (f64::from_bits(0x0000000f), 0x1e7efbdeb14f4eda_u64),
            (f64::INFINITY, f64::INFINITY.to_bits()),
        ];

        for (input, want_bits) in expected {
            let want = f64::from_bits(want_bits);
            assert_biteq!(sqrt(input), want, "input: {input:?} ({:#018x})", input.to_bits());
        }
    }

    #[test]
    #[cfg(f128_enabled)]
    fn sanity_check_f128() {
        for (x, root) in [(100.0f128, 10.0f128), (4.0, 2.0)] {
            assert_biteq!(sqrt(x), root);
        }
    }

    #[test]
    #[cfg(f128_enabled)]
    fn spec_tests_f128() {
        spec_test::<f128>();
    }

    #[test]
    #[cfg(f128_enabled)]
    #[allow(clippy::approx_constant)]
    fn conformance_tests_f128() {
        // Pairs of (input, expected result bits).
        let expected = [
            (f128::PI, 0x3fffc5bf891b4ef6aa79c3b0520d5db9_u128),
            // 10_000.0, see `f16` for reasoning.
            (
                f128::from_bits(0x400c3880000000000000000000000000),
                0x40059000000000000000000000000000_u128,
            ),
            (
                f128::from_bits(0x0000000f),
                0x1fc9efbdeb14f4ed9b17ae807907e1e9_u128,
            ),
            (f128::INFINITY, f128::INFINITY.to_bits()),
        ];

        for (input, want_bits) in expected {
            let want = f128::from_bits(want_bits);
            assert_biteq!(sqrt(input), want, "input: {input:?} ({:#018x})", input.to_bits());
        }
    }
}