From cec3240673b691b2b1472ed3108162c662f53dad Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Fri, 27 Mar 2026 11:36:34 -0700 Subject: [PATCH 01/18] add sqrt from miri + tests --- src/lib.rs | 82 +++++++++++++++++++++++++++++++++++++++++++++++++++ tests/ieee.rs | 18 +++++++++++ 2 files changed, 100 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 87de8a0..17ff340 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -484,6 +484,88 @@ pub trait Float: } } + fn sqrt(self) -> Self { + match self.category() { + // preserve zero sign + Category::Zero => self, + // propagate NaN + Category::NaN => self, + // sqrt of negative number is NaN + _ if self.is_negative() => Self::NAN, + // sqrt(∞) = ∞ + Category::Infinity => Self::INFINITY, + Category::Normal => { + // Floating point precision, excluding the integer bit + let prec = i32::try_from(Self::PRECISION).unwrap() - 1; + + // x = 2^(exp - prec) * mant + // where mant is an integer with prec+1 bits + // mant is a u128, which should be large enough for the largest prec (112 for f128) + let mut exp = self.ilogb(); + let mut mant = self.scalbn(prec - exp).to_u128(128).value; + + if exp % 2 != 0 { + // Make exponent even, so it can be divided by 2 + exp -= 1; + mant <<= 1; + } + + // Bit-by-bit (base-2 digit-by-digit) sqrt of mant. + // mant is treated here as a fixed point number with prec fractional bits. + // mant will be shifted left by one bit to have an extra fractional bit, which + // will be used to determine the rounding direction. + + // res is the truncated sqrt of mant, where one bit is added at each iteration. + let mut res = 0u128; + // rem is the remainder with the current res + // rem_i = 2^i * ((mant<<1) - res_i^2) + // starting with res = 0, rem = mant<<1 + let mut rem = mant << 1; + // s_i = 2*res_i + let mut s = 0u128; + // d is used to iterate over bits, from high to low (d_i = 2^(-i)) + let mut d = 1u128 << (prec + 1); + + // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that + // (res_i + b_j * 2^(-j))^2 <= mant<<1 + // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b: + // res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1 + // And rearranging the terms: + // b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2) + // b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i + + while d != 0 { + // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1: + // t = 2*res_i + 2^(-j) + let t = s + d; + if rem >= t { + // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem + res += d; + s += d + d; + rem -= t; + } + // Adjust rem for next iteration + rem <<= 1; + // Shift iterator + d >>= 1; + } + + // Remove extra fractional bit from result, rounding to nearest. + // If the last bit is 0, then the nearest neighbor is definitely the lower one. + // If the last bit is 1, it sounds like this may either be a tie (if there's + // infinitely many 0s after this 1), or the nearest neighbor is the upper one. + // However, since square roots are either exact or irrational, and an exact root + // would lead to the last "extra" bit being 0, we can exclude a tie in this case. + // We therefore always round up if the last bit is 1. When the last bit is 0, + // adding 1 will not do anything since the shift will discard it. + res = (res + 1) >> 1; + + // Build resulting value with res as mantissa and exp/2 as exponent + Self::from_u128(res).value.scalbn(exp / 2 - prec) + } + } + } + /// IEEE-754R isSignMinus: Returns true if and only if the current value is /// negative. /// diff --git a/tests/ieee.rs b/tests/ieee.rs index 0c356d9..f646ff4 100644 --- a/tests/ieee.rs +++ b/tests/ieee.rs @@ -795,6 +795,24 @@ fn maximum() { assert!(nan.maximum(f1).to_f64().is_nan()); } +#[test] +fn sqrt() { + assert_eq!(64_f32.sqrt(), 8_f32); + assert_eq!(64_f64.sqrt(), 8_f64); + assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY); + assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY); + assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert!((-5.0_f32).sqrt().is_nan()); + assert!((-5.0_f64).sqrt().is_nan()); + assert!(f32::NEG_INFINITY.sqrt().is_nan()); + assert!(f64::NEG_INFINITY.sqrt().is_nan()); + assert!(f32::NAN.sqrt().is_nan()); + assert!(f64::NAN.sqrt().is_nan()); +} + #[test] fn denormal() { // Test single precision From 9108da8127cfc472dd9dd81ab324068514b0dd02 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Fri, 27 Mar 2026 12:40:09 -0700 Subject: [PATCH 02/18] actually test double --- tests/ieee.rs | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/tests/ieee.rs b/tests/ieee.rs index f646ff4..3d84f0f 100644 --- a/tests/ieee.rs +++ b/tests/ieee.rs @@ -2,6 +2,7 @@ extern crate rustc_apfloat; use core::cmp::Ordering; +use std::ops::Neg; use rustc_apfloat::ieee::{BFloat, Double, Float8E4M3FN, Float8E5M2, Half, Quad, Single, X87DoubleExtended}; use rustc_apfloat::{Category, ExpInt, IEK_INF, IEK_NAN, IEK_ZERO}; use rustc_apfloat::{Float, FloatConvert, Round, Status}; @@ -797,20 +798,18 @@ fn maximum() { #[test] fn sqrt() { - assert_eq!(64_f32.sqrt(), 8_f32); - assert_eq!(64_f64.sqrt(), 8_f64); - assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY); - assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY); - assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); - assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal); - assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); - assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal); - assert!((-5.0_f32).sqrt().is_nan()); - assert!((-5.0_f64).sqrt().is_nan()); - assert!(f32::NEG_INFINITY.sqrt().is_nan()); - assert!(f64::NEG_INFINITY.sqrt().is_nan()); - assert!(f32::NAN.sqrt().is_nan()); - assert!(f64::NAN.sqrt().is_nan()); + let f1 = Double::from_f64(64.); + let f2 = Double::from_f64(8.); + let infinity = Double::INFINITY; + let nan = Double::NAN; + let negative_infinity = Double::INFINITY.neg(); + assert_eq!(f1.sqrt().to_f64(), f2.to_f64()); + assert_eq!(infinity.sqrt().to_f64(), infinity.to_f64()); + assert_eq!(Double::ZERO.sqrt().to_f64().total_cmp(&0.0), std::cmp::Ordering::Equal); + assert_eq!((-Double::ZERO).sqrt().to_f64().total_cmp(&-0.0), std::cmp::Ordering::Equal); + assert!((-Double::from_f64(5.0)).sqrt().is_nan()); + assert!(negative_infinity.sqrt().is_nan()); + assert!(nan.sqrt().is_nan()); } #[test] From 0a204b9f8c7950bb9675884ae9f3300a4d760620 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Fri, 27 Mar 2026 12:41:30 -0700 Subject: [PATCH 03/18] formatting --- tests/ieee.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/ieee.rs b/tests/ieee.rs index 3d84f0f..c1e4734 100644 --- a/tests/ieee.rs +++ b/tests/ieee.rs @@ -2,10 +2,10 @@ extern crate rustc_apfloat; use core::cmp::Ordering; -use std::ops::Neg; use rustc_apfloat::ieee::{BFloat, Double, Float8E4M3FN, Float8E5M2, Half, Quad, Single, X87DoubleExtended}; use rustc_apfloat::{Category, ExpInt, IEK_INF, IEK_NAN, IEK_ZERO}; use rustc_apfloat::{Float, FloatConvert, Round, Status}; +use std::ops::Neg; // FIXME(eddyb) maybe include this in `rustc_apfloat` itself? macro_rules! define_for_each_float_type { @@ -807,7 +807,7 @@ fn sqrt() { assert_eq!(infinity.sqrt().to_f64(), infinity.to_f64()); assert_eq!(Double::ZERO.sqrt().to_f64().total_cmp(&0.0), std::cmp::Ordering::Equal); assert_eq!((-Double::ZERO).sqrt().to_f64().total_cmp(&-0.0), std::cmp::Ordering::Equal); - assert!((-Double::from_f64(5.0)).sqrt().is_nan()); + assert!(Double::from_f64(-5.0).sqrt().is_nan()); assert!(negative_infinity.sqrt().is_nan()); assert!(nan.sqrt().is_nan()); } From b16d17fcf2a2c4a2984d57de31a93eaf993d88e4 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Fri, 27 Mar 2026 12:43:13 -0700 Subject: [PATCH 04/18] remove unnecessary variables --- tests/ieee.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/ieee.rs b/tests/ieee.rs index c1e4734..62da4ee 100644 --- a/tests/ieee.rs +++ b/tests/ieee.rs @@ -800,16 +800,13 @@ fn maximum() { fn sqrt() { let f1 = Double::from_f64(64.); let f2 = Double::from_f64(8.); - let infinity = Double::INFINITY; - let nan = Double::NAN; - let negative_infinity = Double::INFINITY.neg(); assert_eq!(f1.sqrt().to_f64(), f2.to_f64()); - assert_eq!(infinity.sqrt().to_f64(), infinity.to_f64()); + assert_eq!(Double::INFINITY.sqrt().to_f64(), f64::INFINITY); assert_eq!(Double::ZERO.sqrt().to_f64().total_cmp(&0.0), std::cmp::Ordering::Equal); assert_eq!((-Double::ZERO).sqrt().to_f64().total_cmp(&-0.0), std::cmp::Ordering::Equal); assert!(Double::from_f64(-5.0).sqrt().is_nan()); - assert!(negative_infinity.sqrt().is_nan()); - assert!(nan.sqrt().is_nan()); + assert!(Double::INFINITY.neg().sqrt().is_nan()); + assert!(Double::NAN.sqrt().is_nan()); } #[test] From 816b4ea32b8405d7bdb23bfa068c5da9c0f8fd18 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Fri, 27 Mar 2026 12:53:52 -0700 Subject: [PATCH 05/18] add check against host cpu sqrt --- tests/ieee.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/ieee.rs b/tests/ieee.rs index 62da4ee..02e3eae 100644 --- a/tests/ieee.rs +++ b/tests/ieee.rs @@ -801,6 +801,8 @@ fn sqrt() { let f1 = Double::from_f64(64.); let f2 = Double::from_f64(8.); assert_eq!(f1.sqrt().to_f64(), f2.to_f64()); + assert_eq!(f1.sqrt().to_f64(), 64_f64.sqrt()); + assert_eq!(f2.sqrt().to_f64(), 8_f64.sqrt()); assert_eq!(Double::INFINITY.sqrt().to_f64(), f64::INFINITY); assert_eq!(Double::ZERO.sqrt().to_f64().total_cmp(&0.0), std::cmp::Ordering::Equal); assert_eq!((-Double::ZERO).sqrt().to_f64().total_cmp(&-0.0), std::cmp::Ordering::Equal); From b2fe989bab69a86da3bca9e45851be59c26df2d1 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Fri, 27 Mar 2026 13:10:59 -0700 Subject: [PATCH 06/18] check all float versions --- tests/ieee.rs | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/ieee.rs b/tests/ieee.rs index 02e3eae..cfb40e3 100644 --- a/tests/ieee.rs +++ b/tests/ieee.rs @@ -5,7 +5,6 @@ use core::cmp::Ordering; use rustc_apfloat::ieee::{BFloat, Double, Float8E4M3FN, Float8E5M2, Half, Quad, Single, X87DoubleExtended}; use rustc_apfloat::{Category, ExpInt, IEK_INF, IEK_NAN, IEK_ZERO}; use rustc_apfloat::{Float, FloatConvert, Round, Status}; -use std::ops::Neg; // FIXME(eddyb) maybe include this in `rustc_apfloat` itself? macro_rules! define_for_each_float_type { @@ -798,17 +797,27 @@ fn maximum() { #[test] fn sqrt() { - let f1 = Double::from_f64(64.); - let f2 = Double::from_f64(8.); - assert_eq!(f1.sqrt().to_f64(), f2.to_f64()); - assert_eq!(f1.sqrt().to_f64(), 64_f64.sqrt()); - assert_eq!(f2.sqrt().to_f64(), 8_f64.sqrt()); - assert_eq!(Double::INFINITY.sqrt().to_f64(), f64::INFINITY); - assert_eq!(Double::ZERO.sqrt().to_f64().total_cmp(&0.0), std::cmp::Ordering::Equal); - assert_eq!((-Double::ZERO).sqrt().to_f64().total_cmp(&-0.0), std::cmp::Ordering::Equal); - assert!(Double::from_f64(-5.0).sqrt().is_nan()); - assert!(Double::INFINITY.neg().sqrt().is_nan()); - assert!(Double::NAN.sqrt().is_nan()); + for_each_float_type!(for test::()); + fn test() { + assert!(F::ZERO.sqrt().bitwise_eq(F::ZERO)); + assert!((-F::ZERO).sqrt().bitwise_eq(-F::ZERO)); + assert!(F::INFINITY.sqrt().bitwise_eq(F::INFINITY)); + assert!(F::NAN.sqrt().is_nan()); + assert!((-F::INFINITY).sqrt().is_nan()); + assert!((-F::from_u128(5).value).sqrt().is_nan()); + let one = F::from_u128(1).value; + assert!(one.sqrt().bitwise_eq(one)); + let f1 = F::from_u128(64).value; + let f2 = F::from_u128(8).value; + assert!(f1.sqrt().bitwise_eq(f2)); + } + + // Cross-check against host sqrt + // TODO: This is only f32 and f64 for now since they are stable, update to other values later. + for x in 0..1000 { + assert_eq!(Single::from_u128(x).value.sqrt().to_f32(), f32::sqrt(x as f32)); + assert_eq!(Double::from_u128(x).value.sqrt().to_f64(), f64::sqrt(x as f64)); + } } #[test] From 3fc058aedcc64b8f10e248d4c36e0dd24ee8dd52 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Fri, 27 Mar 2026 13:12:02 -0700 Subject: [PATCH 07/18] minor formatting --- tests/ieee.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ieee.rs b/tests/ieee.rs index cfb40e3..ab1efe9 100644 --- a/tests/ieee.rs +++ b/tests/ieee.rs @@ -814,7 +814,7 @@ fn sqrt() { // Cross-check against host sqrt // TODO: This is only f32 and f64 for now since they are stable, update to other values later. - for x in 0..1000 { + for x in 0..1000 { assert_eq!(Single::from_u128(x).value.sqrt().to_f32(), f32::sqrt(x as f32)); assert_eq!(Double::from_u128(x).value.sqrt().to_f64(), f64::sqrt(x as f64)); } From f08e8bea811cbbb2a80d8902a6e7167781d252c5 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Wed, 1 Apr 2026 16:14:29 -0700 Subject: [PATCH 08/18] added comment documenting sqrt behavior --- src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index 17ff340..ae64147 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -484,6 +484,8 @@ pub trait Float: } } + /// Technically this should be StatusAnd, but since sqrt is exact for all supported formats, + /// we can just round to the nearest, ignore exceptions, and return Self. fn sqrt(self) -> Self { match self.category() { // preserve zero sign From 059fc87f3b97f16d6c13f2c7a219922315f31bd2 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Wed, 1 Apr 2026 16:34:15 -0700 Subject: [PATCH 09/18] push and hope fuzz gets triggered by CI --- fuzz/ops.rs | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/fuzz/ops.rs b/fuzz/ops.rs index fb1a45d..6522324 100644 --- a/fuzz/ops.rs +++ b/fuzz/ops.rs @@ -16,6 +16,9 @@ enum OpKind { Binary(char), Ternary(Rust<&'static str>, Cxx<&'static str>), + // method-call ops, this one has the arity of 1 (so only 1 argument) + RustMethodArity1(Rust<&'static str>), + // HACK(eddyb) all other ops have floating-point inputs *and* outputs, so // the easiest way to fuzz conversions from/to other types, even if it won't // cover *all possible* inputs, is to do a round-trip through the other type. @@ -41,7 +44,7 @@ impl Type { impl OpKind { fn inputs<'a, T>(&self, all_inputs: &'a [T; 3]) -> &'a [T] { match self { - Unary(_) | Roundtrip(_) => &all_inputs[..1], + Unary(_) | RustMethodArity1(_) | Roundtrip(_) => &all_inputs[..1], Binary(_) => &all_inputs[..2], Ternary(..) => &all_inputs[..3], } @@ -59,6 +62,9 @@ const OPS: &[(&str, OpKind)] = &[ ("Rem", Binary('%')), // Ternary (`(F, F) -> F`) ops. ("MulAdd", Ternary(Rust("mul_add"), Cxx("fusedMultiplyAdd"))), + // Method-call ops. + // For now, sqrt is Rust-only, there is no C++ `APFloat` equivalent + ("Sqrt", RustMethodArity1(Rust("sqrt"))), // Roundtrip (`F -> T -> F`) ops. ("FToI128ToF", Roundtrip(Type::SInt(128))), ("FToU128ToF", Roundtrip(Type::UInt(128))), @@ -154,6 +160,9 @@ impl FuzzOp Ternary(Rust(method), _) => { format!("{}.{method}({}, {})", inputs[0], inputs[1], inputs[2]) } + RustMethodArity1(Rust(method)) => { + format!("{}.{method}()", inputs[0]) + } Roundtrip(ty) => format!( "<{ty} as num_traits::AsPrimitive::>::as_( >::as_({}))", @@ -189,6 +198,9 @@ impl FuzzOp Ternary(Rust(method), _) => { format!("{}.{method}({}).value", inputs[0], inputs[1..].join(", ")) } + RustMethodArity1(Rust(method)) => { + format!("{}.{method}()", inputs[0]) + } Roundtrip(ty @ (Type::SInt(_) | Type::UInt(_))) => { let (w, i_or_u) = match ty { Type::SInt(w) => (w, "i"), @@ -266,6 +278,16 @@ struct FuzzOp { + &all_ops_map_concat(|_tag, name, kind| { let inputs = kind.inputs(&["a.to_apf()", "b.to_apf()", "c.to_apf()"]); let expr = match kind { + RustMethodArity1(method_name) => { + if method_name.0 == "sqrt" { + // For now, sqrt is the only Rust method-call op, and it has no C++ `APFloat` equivalent + // so don't generate any C++ code for it. + return String::new(); + } else { + unreachable!() + } + } + // HACK(eddyb) `APFloat` doesn't overload `operator%`, so we have // to go through the `mod` method instead. Binary('%') => format!("((r = {}), r.mod({}), r)", inputs[0], inputs[1]), From 1deb61a9c858013cee8dafa63988ce2315141451 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Wed, 1 Apr 2026 16:41:00 -0700 Subject: [PATCH 10/18] formatting --- fuzz/ops.rs | 2 +- src/lib.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fuzz/ops.rs b/fuzz/ops.rs index 6522324..1d8f0b5 100644 --- a/fuzz/ops.rs +++ b/fuzz/ops.rs @@ -286,7 +286,7 @@ struct FuzzOp { } else { unreachable!() } - } + } // HACK(eddyb) `APFloat` doesn't overload `operator%`, so we have // to go through the `mod` method instead. diff --git a/src/lib.rs b/src/lib.rs index ae64147..d5afa3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -484,7 +484,7 @@ pub trait Float: } } - /// Technically this should be StatusAnd, but since sqrt is exact for all supported formats, + /// Technically this should be StatusAnd, but since sqrt is exact for all supported formats, /// we can just round to the nearest, ignore exceptions, and return Self. fn sqrt(self) -> Self { match self.category() { From b6404fe4d046c5fe7f2e73057f26b09a596d76eb Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Wed, 1 Apr 2026 17:28:17 -0700 Subject: [PATCH 11/18] add comment --- src/lib.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/lib.rs b/src/lib.rs index d5afa3f..31000c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -491,6 +491,13 @@ pub trait Float: // preserve zero sign Category::Zero => self, // propagate NaN + // If the input is a signalling NaN, then IEEE 754 requires the result to be converted to a quiet NaN. + // On most CPUs that means the most significant bit of the significand field is 0 for signalling NaNs and 1 for quiet NaNs. + // On most CPUs they quiet a NaN by setting that bit to a 1, RISC-V instead returns the canonical NaN with positive sign, + // the most significant significand bit set and all other significand bits cleared. + // However, Rust and LLVM allow input NaNs to be returned unmodified as well as a few other options -- see Rust's rules for NaNs. + // https://doc.rust-lang.org/std/primitive.f32.html#nan-bit-patterns + // (Thanks @programmerjake for the comment) Category::NaN => self, // sqrt of negative number is NaN _ if self.is_negative() => Self::NAN, From 64387f74b2c9173323ae17be5a7824f87502aa8b Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Wed, 1 Apr 2026 21:01:42 -0700 Subject: [PATCH 12/18] code cleanup, make sqrt function ieee only --- fuzz/ops.rs | 14 ++++---- src/ieee.rs | 89 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 91 ++------------------------------------------------- src/ppc.rs | 4 +++ tests/ieee.rs | 39 +++++++++++++++++++++- 5 files changed, 141 insertions(+), 96 deletions(-) diff --git a/fuzz/ops.rs b/fuzz/ops.rs index 1d8f0b5..3befe42 100644 --- a/fuzz/ops.rs +++ b/fuzz/ops.rs @@ -16,8 +16,8 @@ enum OpKind { Binary(char), Ternary(Rust<&'static str>, Cxx<&'static str>), - // method-call ops, this one has the arity of 1 (so only 1 argument) - RustMethodArity1(Rust<&'static str>), + // A operation that's actually a method call in Rust, but only takes one argument (unary). + RustUnary(Rust<&'static str>), // HACK(eddyb) all other ops have floating-point inputs *and* outputs, so // the easiest way to fuzz conversions from/to other types, even if it won't @@ -44,7 +44,7 @@ impl Type { impl OpKind { fn inputs<'a, T>(&self, all_inputs: &'a [T; 3]) -> &'a [T] { match self { - Unary(_) | RustMethodArity1(_) | Roundtrip(_) => &all_inputs[..1], + Unary(_) | RustUnary(_) | Roundtrip(_) => &all_inputs[..1], Binary(_) => &all_inputs[..2], Ternary(..) => &all_inputs[..3], } @@ -64,7 +64,7 @@ const OPS: &[(&str, OpKind)] = &[ ("MulAdd", Ternary(Rust("mul_add"), Cxx("fusedMultiplyAdd"))), // Method-call ops. // For now, sqrt is Rust-only, there is no C++ `APFloat` equivalent - ("Sqrt", RustMethodArity1(Rust("sqrt"))), + ("Sqrt", RustUnary(Rust("sqrt"))), // Roundtrip (`F -> T -> F`) ops. ("FToI128ToF", Roundtrip(Type::SInt(128))), ("FToU128ToF", Roundtrip(Type::UInt(128))), @@ -160,7 +160,7 @@ impl FuzzOp Ternary(Rust(method), _) => { format!("{}.{method}({}, {})", inputs[0], inputs[1], inputs[2]) } - RustMethodArity1(Rust(method)) => { + RustUnary(Rust(method)) => { format!("{}.{method}()", inputs[0]) } Roundtrip(ty) => format!( @@ -198,7 +198,7 @@ impl FuzzOp Ternary(Rust(method), _) => { format!("{}.{method}({}).value", inputs[0], inputs[1..].join(", ")) } - RustMethodArity1(Rust(method)) => { + RustUnary(Rust(method)) => { format!("{}.{method}()", inputs[0]) } Roundtrip(ty @ (Type::SInt(_) | Type::UInt(_))) => { @@ -278,7 +278,7 @@ struct FuzzOp { + &all_ops_map_concat(|_tag, name, kind| { let inputs = kind.inputs(&["a.to_apf()", "b.to_apf()", "c.to_apf()"]); let expr = match kind { - RustMethodArity1(method_name) => { + RustUnary(method_name) => { if method_name.0 == "sqrt" { // For now, sqrt is the only Rust method-call op, and it has no C++ `APFloat` equivalent // so don't generate any C++ code for it. diff --git a/src/ieee.rs b/src/ieee.rs index 9b8c5a0..eeef2ec 100644 --- a/src/ieee.rs +++ b/src/ieee.rs @@ -1563,6 +1563,95 @@ impl Float for IeeeFloat { } } + fn sqrt(self) -> Self { + match self.category() { + // preserve zero sign + Category::Zero => return self, + // propagate NaN + // If the input is a signalling NaN, then IEEE 754 requires the result to be converted to a quiet NaN. + // On most CPUs that means the most significant bit of the significand field is 0 for signalling NaNs and 1 for quiet NaNs. + // On most CPUs they quiet a NaN by setting that bit to a 1, RISC-V instead returns the canonical NaN with positive sign, + // the most significant significand bit set and all other significand bits cleared. + // However, Rust and LLVM allow input NaNs to be returned unmodified as well as a few other options -- see Rust's rules for NaNs. + // https://doc.rust-lang.org/std/primitive.f32.html#nan-bit-patterns + // (Thanks @programmerjake for the comment) + Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self).value, + // sqrt of negative number is NaN + _ if self.is_negative() => return Self::NAN, + // sqrt(inf) = inf + Category::Infinity => return Self::INFINITY, + Category::Normal => (), + } + + // Floating point precision, excluding the integer bit. + let prec = i32::try_from(Self::PRECISION).unwrap() - 1; + + // x = 2^(exp - prec) * mant + // where mant is an integer with prec+1 bits. + // mant is a u128, which is large enough for the largest prec (112 for f128). + let mut exp = self.ilogb(); + let mut mant = self.scalbn(prec - exp).to_u128(128).value; + + if exp % 2 != 0 { + // Make exponent even, so it can be divided by 2. + exp -= 1; + mant <<= 1; + } + + // Bit-by-bit (base-2 digit-by-digit) sqrt of mant. + // mant is treated here as a fixed point number with prec fractional bits. + // mant will be shifted left by one bit to have an extra fractional bit, which + // will be used to determine the rounding direction. + + // res is the truncated sqrt of mant, where one bit is added at each iteration. + let mut res = 0u128; + // rem is the remainder with the current res + // rem_i = 2^i * ((mant<<1) - res_i^2) + // starting with res = 0, rem = mant<<1 + let mut rem = mant << 1; + // s_i = 2*res_i + let mut s = 0u128; + // d is used to iterate over bits, from high to low (d_i = 2^(-i)) + let mut d = 1u128 << (prec + 1); + + // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that + // (res_i + b_j * 2^(-j))^2 <= mant<<1 + // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b: + // res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1 + // And rearranging the terms: + // b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2) + // b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i + + while d != 0 { + // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1: + // t = 2*res_i + 2^(-j) + let t = s + d; + if rem >= t { + // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem + res += d; + s += d + d; + rem -= t; + } + // Adjust rem for next iteration + rem <<= 1; + // Shift iterator + d >>= 1; + } + + // Remove extra fractional bit from result, rounding to nearest. + // If the last bit is 0, then the nearest neighbor is definitely the lower one. + // If the last bit is 1, it sounds like this may either be a tie (if there's + // infinitely many 0s after this 1), or the nearest neighbor is the upper one. + // However, since square roots are either exact or irrational, and an exact root + // would lead to the last "extra" bit being 0, we can exclude a tie in this case. + // We therefore always round up if the last bit is 1. When the last bit is 0, + // adding 1 will not do anything since the shift will discard it. + res = (res + 1) >> 1; + + // Build resulting value with res as mantissa and exp/2 as exponent. + Self::from_u128(res).value.scalbn(exp / 2 - prec) + } + fn from_bits(input: u128) -> Self { // Dispatch to semantics. S::from_bits(input) diff --git a/src/lib.rs b/src/lib.rs index 31000c0..9a4c8c7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -484,95 +484,10 @@ pub trait Float: } } - /// Technically this should be StatusAnd, but since sqrt is exact for all supported formats, - /// we can just round to the nearest, ignore exceptions, and return Self. + /// IEEE-754R sqrt: Returns the correctly rounded square root of the current value + /// Note: we currently don't support raising any exceptions from sqrt, so the result is always exact and the status is always OK. fn sqrt(self) -> Self { - match self.category() { - // preserve zero sign - Category::Zero => self, - // propagate NaN - // If the input is a signalling NaN, then IEEE 754 requires the result to be converted to a quiet NaN. - // On most CPUs that means the most significant bit of the significand field is 0 for signalling NaNs and 1 for quiet NaNs. - // On most CPUs they quiet a NaN by setting that bit to a 1, RISC-V instead returns the canonical NaN with positive sign, - // the most significant significand bit set and all other significand bits cleared. - // However, Rust and LLVM allow input NaNs to be returned unmodified as well as a few other options -- see Rust's rules for NaNs. - // https://doc.rust-lang.org/std/primitive.f32.html#nan-bit-patterns - // (Thanks @programmerjake for the comment) - Category::NaN => self, - // sqrt of negative number is NaN - _ if self.is_negative() => Self::NAN, - // sqrt(∞) = ∞ - Category::Infinity => Self::INFINITY, - Category::Normal => { - // Floating point precision, excluding the integer bit - let prec = i32::try_from(Self::PRECISION).unwrap() - 1; - - // x = 2^(exp - prec) * mant - // where mant is an integer with prec+1 bits - // mant is a u128, which should be large enough for the largest prec (112 for f128) - let mut exp = self.ilogb(); - let mut mant = self.scalbn(prec - exp).to_u128(128).value; - - if exp % 2 != 0 { - // Make exponent even, so it can be divided by 2 - exp -= 1; - mant <<= 1; - } - - // Bit-by-bit (base-2 digit-by-digit) sqrt of mant. - // mant is treated here as a fixed point number with prec fractional bits. - // mant will be shifted left by one bit to have an extra fractional bit, which - // will be used to determine the rounding direction. - - // res is the truncated sqrt of mant, where one bit is added at each iteration. - let mut res = 0u128; - // rem is the remainder with the current res - // rem_i = 2^i * ((mant<<1) - res_i^2) - // starting with res = 0, rem = mant<<1 - let mut rem = mant << 1; - // s_i = 2*res_i - let mut s = 0u128; - // d is used to iterate over bits, from high to low (d_i = 2^(-i)) - let mut d = 1u128 << (prec + 1); - - // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that - // (res_i + b_j * 2^(-j))^2 <= mant<<1 - // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b: - // res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1 - // And rearranging the terms: - // b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2) - // b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i - - while d != 0 { - // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1: - // t = 2*res_i + 2^(-j) - let t = s + d; - if rem >= t { - // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem - res += d; - s += d + d; - rem -= t; - } - // Adjust rem for next iteration - rem <<= 1; - // Shift iterator - d >>= 1; - } - - // Remove extra fractional bit from result, rounding to nearest. - // If the last bit is 0, then the nearest neighbor is definitely the lower one. - // If the last bit is 1, it sounds like this may either be a tie (if there's - // infinitely many 0s after this 1), or the nearest neighbor is the upper one. - // However, since square roots are either exact or irrational, and an exact root - // would lead to the last "extra" bit being 0, we can exclude a tie in this case. - // We therefore always round up if the last bit is 1. When the last bit is 0, - // adding 1 will not do anything since the shift will discard it. - res = (res + 1) >> 1; - - // Build resulting value with res as mantissa and exp/2 as exponent - Self::from_u128(res).value.scalbn(exp / 2 - prec) - } - } + unimplemented!() } /// IEEE-754R isSignMinus: Returns true if and only if the current value is diff --git a/src/ppc.rs b/src/ppc.rs index b32cf75..483f225 100644 --- a/src/ppc.rs +++ b/src/ppc.rs @@ -365,6 +365,10 @@ where Fallback::from(self).next_up().map(Self::from) } + fn sqrt(self) -> Self { + unimplemented!() + } + fn from_bits(input: u128) -> Self { let (a, b) = (input, input >> F::BITS); DoubleFloat(F::from_bits(a & ((1 << F::BITS) - 1)), F::from_bits(b & ((1 << F::BITS) - 1))) diff --git a/tests/ieee.rs b/tests/ieee.rs index ab1efe9..2af6e88 100644 --- a/tests/ieee.rs +++ b/tests/ieee.rs @@ -35,6 +35,43 @@ define_for_each_float_type! { rustc_apfloat::ppc::DoubleDouble, } +macro_rules! for_each_ieee_float_type { + (for<$ty_var:ident: Float> $e:expr) => {{ + { + type $ty_var = Half; + $e; + } + { + type $ty_var = Single; + $e; + } + { + type $ty_var = Double; + $e; + } + { + type $ty_var = Quad; + $e; + } + { + type $ty_var = BFloat; + $e; + } + { + type $ty_var = Float8E5M2; + $e; + } + { + type $ty_var = Float8E4M3FN; + $e; + } + { + type $ty_var = X87DoubleExtended; + $e; + } + }}; +} + trait SingleExt { fn from_f32(input: f32) -> Self; fn to_f32(self) -> f32; @@ -797,7 +834,7 @@ fn maximum() { #[test] fn sqrt() { - for_each_float_type!(for test::()); + for_each_ieee_float_type!(for test::()); fn test() { assert!(F::ZERO.sqrt().bitwise_eq(F::ZERO)); assert!((-F::ZERO).sqrt().bitwise_eq(-F::ZERO)); From 562f7efb3cf941588dca86978245465d23b333b1 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Thu, 2 Apr 2026 19:04:17 -0700 Subject: [PATCH 13/18] add rounding mode as parameter --- src/ieee.rs | 51 +++++++++++++++++++++++++++++++++++++++------------ src/lib.rs | 3 ++- src/ppc.rs | 3 ++- tests/ieee.rs | 36 ++++++++++++++++++------------------ 4 files changed, 61 insertions(+), 32 deletions(-) diff --git a/src/ieee.rs b/src/ieee.rs index eeef2ec..90b399e 100644 --- a/src/ieee.rs +++ b/src/ieee.rs @@ -1563,7 +1563,7 @@ impl Float for IeeeFloat { } } - fn sqrt(self) -> Self { + fn sqrt(self, round: Round) -> Self { match self.category() { // preserve zero sign Category::Zero => return self, @@ -1638,18 +1638,45 @@ impl Float for IeeeFloat { d >>= 1; } - // Remove extra fractional bit from result, rounding to nearest. - // If the last bit is 0, then the nearest neighbor is definitely the lower one. - // If the last bit is 1, it sounds like this may either be a tie (if there's - // infinitely many 0s after this 1), or the nearest neighbor is the upper one. - // However, since square roots are either exact or irrational, and an exact root - // would lead to the last "extra" bit being 0, we can exclude a tie in this case. - // We therefore always round up if the last bit is 1. When the last bit is 0, - // adding 1 will not do anything since the shift will discard it. - res = (res + 1) >> 1; + let mut status = Status::OK; + + // A nonzero remainder indicates that we could continue processing sqrt if we had + // more precision, potentially indefinitely. We don't because we have enough bits + // to fill our significand already, and only need the one extra bit to determine + // rounding. + if rem != 0 { + status = Status::INEXACT; + + match round { + // If the LSB is 0, we should round down and this 1 gets cut off. If the LSB + // is 1, it is either a tie (if all remaining bits would be 0) or something + // that should be rounded up. + // + // Square roots are either exact or irrational, so a `1` in the extra bit + // already implies an irrational result with more `1`s in the infinite + // precision tail that should be rounded up, which this does. We are in a + // `rem != 0` block but could technically add the `1` unconditionally, given + // that a 0 in the extra bit would imply an exact result to be rounded down + // (and the extra bit is just shifted out). + Round::NearestTiesToEven => res += 1, + // We know we have an inexact result that needs rounding up. If the round + // bit is 1, adding 1 is sufficient and adding 2 does nothing extra (the + // new LSB will get truncated). If the round bit is 0, we need to add + // two anyway to affect the significand. + Round::TowardPositive => res += 2, + // By default, shifting will round down. + Round::TowardNegative => (), + // Same as negative since the result of sqrt is positive. + Round::TowardZero => (), + Round::NearestTiesToAway => unimplemented!("unsupported rounding mode"), + }; + } + + // Remove the extra fractional bit. + res >>= 1; - // Build resulting value with res as mantissa and exp/2 as exponent. - Self::from_u128(res).value.scalbn(exp / 2 - prec) + // Build resulting value with res as mantissa and exp/2 as exponent + status.and(Self::from_u128(res).value.scalbn(exp / 2 - prec)).value } fn from_bits(input: u128) -> Self { diff --git a/src/lib.rs b/src/lib.rs index 9a4c8c7..0467bf0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -486,7 +486,8 @@ pub trait Float: /// IEEE-754R sqrt: Returns the correctly rounded square root of the current value /// Note: we currently don't support raising any exceptions from sqrt, so the result is always exact and the status is always OK. - fn sqrt(self) -> Self { + #[allow(unused_variables)] + fn sqrt(self, round: Round) -> Self { unimplemented!() } diff --git a/src/ppc.rs b/src/ppc.rs index 483f225..eb6872b 100644 --- a/src/ppc.rs +++ b/src/ppc.rs @@ -365,7 +365,8 @@ where Fallback::from(self).next_up().map(Self::from) } - fn sqrt(self) -> Self { + #[allow(unused_variables)] + fn sqrt(self, round: Round) -> Self { unimplemented!() } diff --git a/tests/ieee.rs b/tests/ieee.rs index 2af6e88..fc8ae3b 100644 --- a/tests/ieee.rs +++ b/tests/ieee.rs @@ -836,24 +836,24 @@ fn maximum() { fn sqrt() { for_each_ieee_float_type!(for test::()); fn test() { - assert!(F::ZERO.sqrt().bitwise_eq(F::ZERO)); - assert!((-F::ZERO).sqrt().bitwise_eq(-F::ZERO)); - assert!(F::INFINITY.sqrt().bitwise_eq(F::INFINITY)); - assert!(F::NAN.sqrt().is_nan()); - assert!((-F::INFINITY).sqrt().is_nan()); - assert!((-F::from_u128(5).value).sqrt().is_nan()); - let one = F::from_u128(1).value; - assert!(one.sqrt().bitwise_eq(one)); - let f1 = F::from_u128(64).value; - let f2 = F::from_u128(8).value; - assert!(f1.sqrt().bitwise_eq(f2)); - } - - // Cross-check against host sqrt - // TODO: This is only f32 and f64 for now since they are stable, update to other values later. - for x in 0..1000 { - assert_eq!(Single::from_u128(x).value.sqrt().to_f32(), f32::sqrt(x as f32)); - assert_eq!(Double::from_u128(x).value.sqrt().to_f64(), f64::sqrt(x as f64)); + for round in [ + Round::NearestTiesToEven, + Round::TowardPositive, + Round::TowardNegative, + Round::TowardZero, + ] { + assert!(F::ZERO.sqrt(round).bitwise_eq(F::ZERO)); + assert!((-F::ZERO).sqrt(round).bitwise_eq(-F::ZERO)); + assert!(F::INFINITY.sqrt(round).bitwise_eq(F::INFINITY)); + assert!(F::NAN.sqrt(round).is_nan()); + assert!((-F::INFINITY).sqrt(round).is_nan()); + assert!((-F::from_u128(5).value).sqrt(round).is_nan()); + let one = F::from_u128(1).value; + assert!(one.sqrt(round).bitwise_eq(one)); + let f1 = F::from_u128(64).value; + let f2 = F::from_u128(8).value; + assert!(f1.sqrt(round).bitwise_eq(f2)); + } } } From c13103d24b6eae5d2a64498259daf0113d89f2f8 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Thu, 2 Apr 2026 19:18:44 -0700 Subject: [PATCH 14/18] move code around to downstream.rs --- src/downstream.rs | 124 +++++++++++++++++++++++++++++++++++++++++++ src/ieee.rs | 126 +++----------------------------------------- src/lib.rs | 1 + tests/downstream.rs | 66 ++++++++++++++++++++++- tests/ieee.rs | 62 ---------------------- 5 files changed, 196 insertions(+), 183 deletions(-) create mode 100644 src/downstream.rs diff --git a/src/downstream.rs b/src/downstream.rs new file mode 100644 index 0000000..9728b3f --- /dev/null +++ b/src/downstream.rs @@ -0,0 +1,124 @@ +use crate::{ + ieee::{IeeeDefaultExceptionHandling, IeeeFloat, Semantics}, + Category, Float, Round, Status, +}; + +impl IeeeFloat { + /// This is a spec conformant implementation of the IEEE Float sqrt function + /// This is put in downstream.rs because this function hasn't been implemented in the upstream C++ version yet. + pub(crate) fn ieee_sqrt(self, round: Round) -> Self { + match self.category() { + // preserve zero sign + Category::Zero => return self, + // propagate NaN + // If the input is a signalling NaN, then IEEE 754 requires the result to be converted to a quiet NaN. + // On most CPUs that means the most significant bit of the significand field is 0 for signalling NaNs and 1 for quiet NaNs. + // On most CPUs they quiet a NaN by setting that bit to a 1, RISC-V instead returns the canonical NaN with positive sign, + // the most significant significand bit set and all other significand bits cleared. + // However, Rust and LLVM allow input NaNs to be returned unmodified as well as a few other options -- see Rust's rules for NaNs. + // https://doc.rust-lang.org/std/primitive.f32.html#nan-bit-patterns + // (Thanks @programmerjake for the comment) + Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self).value, + // sqrt of negative number is NaN + _ if self.is_negative() => return Self::NAN, + // sqrt(inf) = inf + Category::Infinity => return Self::INFINITY, + Category::Normal => (), + } + + // Floating point precision, excluding the integer bit. + let prec = i32::try_from(Self::PRECISION).unwrap() - 1; + + // x = 2^(exp - prec) * mant + // where mant is an integer with prec+1 bits. + // mant is a u128, which is large enough for the largest prec (112 for f128). + let mut exp = self.ilogb(); + let mut mant = self.scalbn(prec - exp).to_u128(128).value; + + if exp % 2 != 0 { + // Make exponent even, so it can be divided by 2. + exp -= 1; + mant <<= 1; + } + + // Bit-by-bit (base-2 digit-by-digit) sqrt of mant. + // mant is treated here as a fixed point number with prec fractional bits. + // mant will be shifted left by one bit to have an extra fractional bit, which + // will be used to determine the rounding direction. + + // res is the truncated sqrt of mant, where one bit is added at each iteration. + let mut res = 0u128; + // rem is the remainder with the current res + // rem_i = 2^i * ((mant<<1) - res_i^2) + // starting with res = 0, rem = mant<<1 + let mut rem = mant << 1; + // s_i = 2*res_i + let mut s = 0u128; + // d is used to iterate over bits, from high to low (d_i = 2^(-i)) + let mut d = 1u128 << (prec + 1); + + // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that + // (res_i + b_j * 2^(-j))^2 <= mant<<1 + // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b: + // res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1 + // And rearranging the terms: + // b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2) + // b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i + + while d != 0 { + // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1: + // t = 2*res_i + 2^(-j) + let t = s + d; + if rem >= t { + // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem + res += d; + s += d + d; + rem -= t; + } + // Adjust rem for next iteration + rem <<= 1; + // Shift iterator + d >>= 1; + } + + let mut status = Status::OK; + + // A nonzero remainder indicates that we could continue processing sqrt if we had + // more precision, potentially indefinitely. We don't because we have enough bits + // to fill our significand already, and only need the one extra bit to determine + // rounding. + if rem != 0 { + status = Status::INEXACT; + + match round { + // If the LSB is 0, we should round down and this 1 gets cut off. If the LSB + // is 1, it is either a tie (if all remaining bits would be 0) or something + // that should be rounded up. + // + // Square roots are either exact or irrational, so a `1` in the extra bit + // already implies an irrational result with more `1`s in the infinite + // precision tail that should be rounded up, which this does. We are in a + // `rem != 0` block but could technically add the `1` unconditionally, given + // that a 0 in the extra bit would imply an exact result to be rounded down + // (and the extra bit is just shifted out). + Round::NearestTiesToEven => res += 1, + // We know we have an inexact result that needs rounding up. If the round + // bit is 1, adding 1 is sufficient and adding 2 does nothing extra (the + // new LSB will get truncated). If the round bit is 0, we need to add + // two anyway to affect the significand. + Round::TowardPositive => res += 2, + // By default, shifting will round down. + Round::TowardNegative => (), + // Same as negative since the result of sqrt is positive. + Round::TowardZero => (), + Round::NearestTiesToAway => unimplemented!("unsupported rounding mode"), + }; + } + + // Remove the extra fractional bit. + res >>= 1; + + // Build resulting value with res as mantissa and exp/2 as exponent + status.and(Self::from_u128(res).value.scalbn(exp / 2 - prec)).value + } +} diff --git a/src/ieee.rs b/src/ieee.rs index 90b399e..f0203a8 100644 --- a/src/ieee.rs +++ b/src/ieee.rs @@ -824,9 +824,9 @@ impl fmt::Debug for IeeeFloat { // but it's a bit too long to keep repeating in the Rust port for all ops. // FIXME(eddyb) find a better name/organization for all of this functionality // (`IeeeDefaultExceptionHandling` doesn't have a counterpart in the C++ code). -struct IeeeDefaultExceptionHandling; +pub(crate) struct IeeeDefaultExceptionHandling; impl IeeeDefaultExceptionHandling { - fn result_from_nan(mut r: IeeeFloat) -> StatusAnd> { + pub fn result_from_nan(mut r: IeeeFloat) -> StatusAnd> { assert!(r.is_nan()); let status = if r.is_signaling() { @@ -865,7 +865,7 @@ impl IeeeDefaultExceptionHandling { status.and(r) } - fn binop_result_from_either_nan(a: IeeeFloat, b: IeeeFloat) -> StatusAnd> { + pub fn binop_result_from_either_nan(a: IeeeFloat, b: IeeeFloat) -> StatusAnd> { let r = match (a.category(), b.category()) { (Category::NaN, _) => a, (_, Category::NaN) => b, @@ -1563,122 +1563,6 @@ impl Float for IeeeFloat { } } - fn sqrt(self, round: Round) -> Self { - match self.category() { - // preserve zero sign - Category::Zero => return self, - // propagate NaN - // If the input is a signalling NaN, then IEEE 754 requires the result to be converted to a quiet NaN. - // On most CPUs that means the most significant bit of the significand field is 0 for signalling NaNs and 1 for quiet NaNs. - // On most CPUs they quiet a NaN by setting that bit to a 1, RISC-V instead returns the canonical NaN with positive sign, - // the most significant significand bit set and all other significand bits cleared. - // However, Rust and LLVM allow input NaNs to be returned unmodified as well as a few other options -- see Rust's rules for NaNs. - // https://doc.rust-lang.org/std/primitive.f32.html#nan-bit-patterns - // (Thanks @programmerjake for the comment) - Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self).value, - // sqrt of negative number is NaN - _ if self.is_negative() => return Self::NAN, - // sqrt(inf) = inf - Category::Infinity => return Self::INFINITY, - Category::Normal => (), - } - - // Floating point precision, excluding the integer bit. - let prec = i32::try_from(Self::PRECISION).unwrap() - 1; - - // x = 2^(exp - prec) * mant - // where mant is an integer with prec+1 bits. - // mant is a u128, which is large enough for the largest prec (112 for f128). - let mut exp = self.ilogb(); - let mut mant = self.scalbn(prec - exp).to_u128(128).value; - - if exp % 2 != 0 { - // Make exponent even, so it can be divided by 2. - exp -= 1; - mant <<= 1; - } - - // Bit-by-bit (base-2 digit-by-digit) sqrt of mant. - // mant is treated here as a fixed point number with prec fractional bits. - // mant will be shifted left by one bit to have an extra fractional bit, which - // will be used to determine the rounding direction. - - // res is the truncated sqrt of mant, where one bit is added at each iteration. - let mut res = 0u128; - // rem is the remainder with the current res - // rem_i = 2^i * ((mant<<1) - res_i^2) - // starting with res = 0, rem = mant<<1 - let mut rem = mant << 1; - // s_i = 2*res_i - let mut s = 0u128; - // d is used to iterate over bits, from high to low (d_i = 2^(-i)) - let mut d = 1u128 << (prec + 1); - - // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that - // (res_i + b_j * 2^(-j))^2 <= mant<<1 - // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b: - // res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1 - // And rearranging the terms: - // b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2) - // b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i - - while d != 0 { - // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1: - // t = 2*res_i + 2^(-j) - let t = s + d; - if rem >= t { - // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem - res += d; - s += d + d; - rem -= t; - } - // Adjust rem for next iteration - rem <<= 1; - // Shift iterator - d >>= 1; - } - - let mut status = Status::OK; - - // A nonzero remainder indicates that we could continue processing sqrt if we had - // more precision, potentially indefinitely. We don't because we have enough bits - // to fill our significand already, and only need the one extra bit to determine - // rounding. - if rem != 0 { - status = Status::INEXACT; - - match round { - // If the LSB is 0, we should round down and this 1 gets cut off. If the LSB - // is 1, it is either a tie (if all remaining bits would be 0) or something - // that should be rounded up. - // - // Square roots are either exact or irrational, so a `1` in the extra bit - // already implies an irrational result with more `1`s in the infinite - // precision tail that should be rounded up, which this does. We are in a - // `rem != 0` block but could technically add the `1` unconditionally, given - // that a 0 in the extra bit would imply an exact result to be rounded down - // (and the extra bit is just shifted out). - Round::NearestTiesToEven => res += 1, - // We know we have an inexact result that needs rounding up. If the round - // bit is 1, adding 1 is sufficient and adding 2 does nothing extra (the - // new LSB will get truncated). If the round bit is 0, we need to add - // two anyway to affect the significand. - Round::TowardPositive => res += 2, - // By default, shifting will round down. - Round::TowardNegative => (), - // Same as negative since the result of sqrt is positive. - Round::TowardZero => (), - Round::NearestTiesToAway => unimplemented!("unsupported rounding mode"), - }; - } - - // Remove the extra fractional bit. - res >>= 1; - - // Build resulting value with res as mantissa and exp/2 as exponent - status.and(Self::from_u128(res).value.scalbn(exp / 2 - prec)).value - } - fn from_bits(input: u128) -> Self { // Dispatch to semantics. S::from_bits(input) @@ -2008,6 +1892,10 @@ impl Float for IeeeFloat { } self.scalbn_r(-*exp, round) } + + fn sqrt(self, round: Round) -> Self { + self.ieee_sqrt(round) + } } impl FloatConvert> for IeeeFloat { diff --git a/src/lib.rs b/src/lib.rs index 0467bf0..74b3724 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -762,5 +762,6 @@ macro_rules! float_common_impls { }; } +pub mod downstream; pub mod ieee; pub mod ppc; diff --git a/tests/downstream.rs b/tests/downstream.rs index fc60a7c..fcb5a8c 100644 --- a/tests/downstream.rs +++ b/tests/downstream.rs @@ -1,7 +1,7 @@ //! Tests added to `rustc_apfloat`, that were not ported from the C++ code. -use rustc_apfloat::ieee::{Double, Single, X87DoubleExtended}; -use rustc_apfloat::Float; +use rustc_apfloat::ieee::{BFloat, Double, Float8E4M3FN, Float8E5M2, Half, Quad, Single, X87DoubleExtended}; +use rustc_apfloat::{Float, Round}; // `f32 -> i128 -> f32` previously-crashing bit-patterns (found by fuzzing). pub const FUZZ_IEEE32_ROUNDTRIP_THROUGH_I128_CASES: &[u32] = &[ @@ -408,3 +408,65 @@ fn fuzz_x87_f80_neg_with_expected_outputs() { assert_eq!((-X87DoubleExtended::from_bits(bits)).to_bits(), expected_bits); } } + +macro_rules! for_each_ieee_float_type { + (for<$ty_var:ident: Float> $e:expr) => {{ + { + type $ty_var = Half; + $e; + } + { + type $ty_var = Single; + $e; + } + { + type $ty_var = Double; + $e; + } + { + type $ty_var = Quad; + $e; + } + { + type $ty_var = BFloat; + $e; + } + { + type $ty_var = Float8E5M2; + $e; + } + { + type $ty_var = Float8E4M3FN; + $e; + } + { + type $ty_var = X87DoubleExtended; + $e; + } + }}; +} + +#[test] +fn sqrt() { + for_each_ieee_float_type!(for test::()); + fn test() { + for round in [ + Round::NearestTiesToEven, + Round::TowardPositive, + Round::TowardNegative, + Round::TowardZero, + ] { + assert!(F::ZERO.sqrt(round).bitwise_eq(F::ZERO)); + assert!((-F::ZERO).sqrt(round).bitwise_eq(-F::ZERO)); + assert!(F::INFINITY.sqrt(round).bitwise_eq(F::INFINITY)); + assert!(F::NAN.sqrt(round).is_nan()); + assert!((-F::INFINITY).sqrt(round).is_nan()); + assert!((-F::from_u128(5).value).sqrt(round).is_nan()); + let one = F::from_u128(1).value; + assert!(one.sqrt(round).bitwise_eq(one)); + let f1 = F::from_u128(64).value; + let f2 = F::from_u128(8).value; + assert!(f1.sqrt(round).bitwise_eq(f2)); + } + } +} diff --git a/tests/ieee.rs b/tests/ieee.rs index fc8ae3b..0c356d9 100644 --- a/tests/ieee.rs +++ b/tests/ieee.rs @@ -35,43 +35,6 @@ define_for_each_float_type! { rustc_apfloat::ppc::DoubleDouble, } -macro_rules! for_each_ieee_float_type { - (for<$ty_var:ident: Float> $e:expr) => {{ - { - type $ty_var = Half; - $e; - } - { - type $ty_var = Single; - $e; - } - { - type $ty_var = Double; - $e; - } - { - type $ty_var = Quad; - $e; - } - { - type $ty_var = BFloat; - $e; - } - { - type $ty_var = Float8E5M2; - $e; - } - { - type $ty_var = Float8E4M3FN; - $e; - } - { - type $ty_var = X87DoubleExtended; - $e; - } - }}; -} - trait SingleExt { fn from_f32(input: f32) -> Self; fn to_f32(self) -> f32; @@ -832,31 +795,6 @@ fn maximum() { assert!(nan.maximum(f1).to_f64().is_nan()); } -#[test] -fn sqrt() { - for_each_ieee_float_type!(for test::()); - fn test() { - for round in [ - Round::NearestTiesToEven, - Round::TowardPositive, - Round::TowardNegative, - Round::TowardZero, - ] { - assert!(F::ZERO.sqrt(round).bitwise_eq(F::ZERO)); - assert!((-F::ZERO).sqrt(round).bitwise_eq(-F::ZERO)); - assert!(F::INFINITY.sqrt(round).bitwise_eq(F::INFINITY)); - assert!(F::NAN.sqrt(round).is_nan()); - assert!((-F::INFINITY).sqrt(round).is_nan()); - assert!((-F::from_u128(5).value).sqrt(round).is_nan()); - let one = F::from_u128(1).value; - assert!(one.sqrt(round).bitwise_eq(one)); - let f1 = F::from_u128(64).value; - let f2 = F::from_u128(8).value; - assert!(f1.sqrt(round).bitwise_eq(f2)); - } - } -} - #[test] fn denormal() { // Test single precision From d0a2760be78e1921f0ed889b54339bd82b251278 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Thu, 2 Apr 2026 20:05:08 -0700 Subject: [PATCH 15/18] nan improvement --- src/downstream.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/downstream.rs b/src/downstream.rs index 9728b3f..f428b5f 100644 --- a/src/downstream.rs +++ b/src/downstream.rs @@ -18,7 +18,7 @@ impl IeeeFloat { // However, Rust and LLVM allow input NaNs to be returned unmodified as well as a few other options -- see Rust's rules for NaNs. // https://doc.rust-lang.org/std/primitive.f32.html#nan-bit-patterns // (Thanks @programmerjake for the comment) - Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self).value, + Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self).value.copy_sign(Self::NAN), // sqrt of negative number is NaN _ if self.is_negative() => return Self::NAN, // sqrt(inf) = inf From d645d5f9f41898b501477de6565cdfb2a47a0f1c Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Thu, 2 Apr 2026 20:21:11 -0700 Subject: [PATCH 16/18] gonna push this and hope it fuzzes correctly --- src/downstream.rs | 14 +- src/ieee.rs | 2 +- src/lib.rs | 2 +- src/ppc.rs | 2 +- tests/downstream.rs | 357 ++++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 358 insertions(+), 19 deletions(-) diff --git a/src/downstream.rs b/src/downstream.rs index f428b5f..adb3e88 100644 --- a/src/downstream.rs +++ b/src/downstream.rs @@ -1,15 +1,15 @@ use crate::{ ieee::{IeeeDefaultExceptionHandling, IeeeFloat, Semantics}, - Category, Float, Round, Status, + Category, Float, Round, Status, StatusAnd, }; impl IeeeFloat { /// This is a spec conformant implementation of the IEEE Float sqrt function /// This is put in downstream.rs because this function hasn't been implemented in the upstream C++ version yet. - pub(crate) fn ieee_sqrt(self, round: Round) -> Self { + pub(crate) fn ieee_sqrt(self, round: Round) -> StatusAnd { match self.category() { // preserve zero sign - Category::Zero => return self, + Category::Zero => return Status::OK.and(self), // propagate NaN // If the input is a signalling NaN, then IEEE 754 requires the result to be converted to a quiet NaN. // On most CPUs that means the most significant bit of the significand field is 0 for signalling NaNs and 1 for quiet NaNs. @@ -18,11 +18,11 @@ impl IeeeFloat { // However, Rust and LLVM allow input NaNs to be returned unmodified as well as a few other options -- see Rust's rules for NaNs. // https://doc.rust-lang.org/std/primitive.f32.html#nan-bit-patterns // (Thanks @programmerjake for the comment) - Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self).value.copy_sign(Self::NAN), + Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self.copy_sign(Self::NAN)), // sqrt of negative number is NaN - _ if self.is_negative() => return Self::NAN, + _ if self.is_negative() => return Status::INVALID_OP.and(Self::NAN), // sqrt(inf) = inf - Category::Infinity => return Self::INFINITY, + Category::Infinity => return Status::OK.and(Self::INFINITY), Category::Normal => (), } @@ -119,6 +119,6 @@ impl IeeeFloat { res >>= 1; // Build resulting value with res as mantissa and exp/2 as exponent - status.and(Self::from_u128(res).value.scalbn(exp / 2 - prec)).value + status.and(Self::from_u128(res).value.scalbn(exp / 2 - prec)) } } diff --git a/src/ieee.rs b/src/ieee.rs index f0203a8..57cf32e 100644 --- a/src/ieee.rs +++ b/src/ieee.rs @@ -1893,7 +1893,7 @@ impl Float for IeeeFloat { self.scalbn_r(-*exp, round) } - fn sqrt(self, round: Round) -> Self { + fn sqrt(self, round: Round) -> StatusAnd { self.ieee_sqrt(round) } } diff --git a/src/lib.rs b/src/lib.rs index 74b3724..fa48c52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -487,7 +487,7 @@ pub trait Float: /// IEEE-754R sqrt: Returns the correctly rounded square root of the current value /// Note: we currently don't support raising any exceptions from sqrt, so the result is always exact and the status is always OK. #[allow(unused_variables)] - fn sqrt(self, round: Round) -> Self { + fn sqrt(self, round: Round) -> StatusAnd { unimplemented!() } diff --git a/src/ppc.rs b/src/ppc.rs index eb6872b..22d0d33 100644 --- a/src/ppc.rs +++ b/src/ppc.rs @@ -366,7 +366,7 @@ where } #[allow(unused_variables)] - fn sqrt(self, round: Round) -> Self { + fn sqrt(self, round: Round) -> StatusAnd { unimplemented!() } diff --git a/tests/downstream.rs b/tests/downstream.rs index fcb5a8c..7ab88d4 100644 --- a/tests/downstream.rs +++ b/tests/downstream.rs @@ -1,7 +1,7 @@ //! Tests added to `rustc_apfloat`, that were not ported from the C++ code. use rustc_apfloat::ieee::{BFloat, Double, Float8E4M3FN, Float8E5M2, Half, Quad, Single, X87DoubleExtended}; -use rustc_apfloat::{Float, Round}; +use rustc_apfloat::{Float, Round, Status, StatusAnd}; // `f32 -> i128 -> f32` previously-crashing bit-patterns (found by fuzzing). pub const FUZZ_IEEE32_ROUNDTRIP_THROUGH_I128_CASES: &[u32] = &[ @@ -456,17 +456,356 @@ fn sqrt() { Round::TowardNegative, Round::TowardZero, ] { - assert!(F::ZERO.sqrt(round).bitwise_eq(F::ZERO)); - assert!((-F::ZERO).sqrt(round).bitwise_eq(-F::ZERO)); - assert!(F::INFINITY.sqrt(round).bitwise_eq(F::INFINITY)); - assert!(F::NAN.sqrt(round).is_nan()); - assert!((-F::INFINITY).sqrt(round).is_nan()); - assert!((-F::from_u128(5).value).sqrt(round).is_nan()); + assert!(F::ZERO.sqrt(round).value.bitwise_eq(F::ZERO)); + assert!((-F::ZERO).sqrt(round).value.bitwise_eq(-F::ZERO)); + assert!(F::INFINITY.sqrt(round).value.bitwise_eq(F::INFINITY)); + assert!(F::NAN.sqrt(round).value.is_nan()); + assert!((-F::INFINITY).sqrt(round).value.is_nan()); + assert!((-F::from_u128(5).value).sqrt(round).value.is_nan()); let one = F::from_u128(1).value; - assert!(one.sqrt(round).bitwise_eq(one)); + assert!(one.sqrt(round).value.bitwise_eq(one)); let f1 = F::from_u128(64).value; let f2 = F::from_u128(8).value; - assert!(f1.sqrt(round).bitwise_eq(f2)); + assert!(f1.sqrt(round).value.bitwise_eq(f2)); } } } + +#[test] +fn fuzz_sqrt() { + // for round in [ + // Round::NearestTiesToEven, + // Round::TowardPositive, + // Round::TowardNegative, + // Round::TowardZero, + // ] { + // println!("checking rounding mode {round:?}"); + + // for xi in 0..=u16::MAX { + // if xi % 1_000 == 0 { + // println!("{xi}/{}", u16::MAX); + // } + // let x = f16::from_bits(xi); + // let a = apfloat_sqrtf16(x, round); + // let b = hardware_sqrt::sqrtf16(x, round); + + // // x86 preserves the sign on NaN inputs, others may not (e.g. aarch64 does not). + // let eq_value = if cfg!(target_arch = "x86_64") { + // a.value.bitwise_eq(b.value) + // } else { + // a.value.bitwise_eq(b.value) || (a.value.is_nan() && b.value.is_nan()) + // }; + + // if a.status != b.status || !eq_value { + // panic!( + // "\ + // incorrect result\n\ + // xi: {xi:#010x}\n\ + // x: {x:?}\n\ + // a: {a:?} {af}\n\ + // b: {b:?} {bf}\ + // ", + // af = f32::from_bits(a.value.to_bits().try_into().unwrap()), + // bf = f32::from_bits(a.value.to_bits().try_into().unwrap()), + // ); + // } + // } + // } + + for round in [ + Round::NearestTiesToEven, + Round::TowardPositive, + Round::TowardNegative, + Round::TowardZero, + ] { + println!("checking rounding mode {round:?}"); + + for xi in 0..=u32::MAX { + if xi % 100_000_000 == 0 { + println!("{xi}/{}", u32::MAX); + } + let x = f32::from_bits(xi); + let a = apfloat_sqrt(x, round); + let b = hardware_sqrt::sqrtf32(x, round); + + // x86 preserves the sign on NaN inputs, others may not (e.g. aarch64 does not). + let eq_value = if cfg!(target_arch = "x86_64") { + a.value.bitwise_eq(b.value) + } else { + a.value.bitwise_eq(b.value) || (a.value.is_nan() && b.value.is_nan()) + }; + + if a.status != b.status || !eq_value { + panic!( + "\ + incorrect result\n\ + xi: {xi:#010x}\n\ + x: {x:?}\n\ + a: {a:?} {af}\n\ + b: {b:?} {bf}\ + ", + af = f32::from_bits(a.value.to_bits().try_into().unwrap()), + bf = f32::from_bits(a.value.to_bits().try_into().unwrap()), + ); + } + } + } +} + +#[cfg(target_arch = "aarch64")] +fn apfloat_sqrtf16(x: f16, round: Round) -> StatusAnd { + Half::from_bits(x.to_bits().into()).sqrt(round) +} + +fn apfloat_sqrt(x: f32, round: Round) -> StatusAnd { + Single::from_bits(x.to_bits().into()).sqrt(round) +} + +// SQRTSS is in baseline SSE, SQRTSD needs SSE2 +#[cfg(any(target_arch = "x86_64", all(target_arch = "x86", target_feature = "sse2")))] +mod hardware_sqrt { + use super::*; + + pub fn sqrtf32(mut x: f32, round: Round) -> StatusAnd { + let mut csr_stash = 0u32; + let mut csr = make_mxcsr_cw(round); + + unsafe { + core::arch::asm!( + // stash the current control state + "stmxcsr [{csr_stash}]", + // set the control state we want, clears flags + "ldmxcsr [{csr}]", + // run sqrt + "sqrtss {x}, {x}", + // get the new control state + "stmxcsr [{csr}]", + // restore the original control state + "ldmxcsr [{csr_stash}]", + csr_stash = in(reg) &mut csr_stash, + csr = in(reg) &mut csr, + x = inout(xmm_reg) x, + options(nostack), + ); + } + + let status = check_exceptions(csr); + status.and(Single::from_bits(x.to_bits().into())) + } + + pub fn sqrtf64(mut x: f64, round: Round) -> StatusAnd { + let mut csr_stash = 0u32; + let mut csr = make_mxcsr_cw(round); + + unsafe { + core::arch::asm!( + // stash the current control state + "stmxcsr [{csr_stash}]", + // set the control state we want, clears flags + "ldmxcsr [{csr}]", + // run sqrt + "sqrtsd {x}, {x}", + // get the new control state + "stmxcsr [{csr}]", + // restore the original control state + "ldmxcsr [{csr_stash}]", + csr_stash = in(reg) &mut csr_stash, + csr = in(reg) &mut csr, + x = inout(xmm_reg) x, + options(nostack), + ); + } + + let status = check_exceptions(csr); + status.and(Double::from_bits(x.to_bits().into())) + } + + fn make_mxcsr_cw(round: Round) -> u32 { + // Default: Clear exception flags, no DAZ, no FTZ + let mut csr = 0u32; + // Set all masks so fp status doesn't turn into SIGFPE + csr |= 0b111111 << 7; + + let rc = match round { + Round::NearestTiesToEven => 0b00, + Round::TowardPositive => 0b10, + Round::TowardNegative => 0b01, + Round::TowardZero => 0b11, + Round::NearestTiesToAway => unimplemented!("unsupported rounding on x86"), + }; + + csr |= rc << 13; + csr + } + + fn check_exceptions(csr: u32) -> Status { + let mut status = Status::OK; + + if csr & (1 << 0) != 0 { + status |= Status::INVALID_OP; + } + if csr & (1 << 1) != 0 { + // denormal flag, not part of status + } + if csr & (1 << 2) != 0 { + status |= Status::DIV_BY_ZERO; + } + if csr & (1 << 3) != 0 { + status |= Status::OVERFLOW; + } + if csr & (1 << 4) != 0 { + status |= Status::UNDERFLOW; + } + if csr & (1 << 5) != 0 { + status |= Status::INEXACT; + } + + status + } +} + +#[cfg(target_arch = "aarch64")] +mod hardware_sqrt { + use super::*; + + #[cfg(target_feature = "fp16")] + pub fn sqrtf16(mut x: f16, round: Round) -> StatusAnd { + let fpcr = make_fpcr_cw(round); + let fpsr: u64; + + unsafe { + core::arch::asm!( + // stash the current control and status state + "mrs {fpcr_stash}, fpcr", + "mrs {fpsr_stash}, fpsr", + // zero the exception flags, set desired control state + "and {tmp}, {fpsr_stash}, #{CLEAR}", + "msr fpsr, {tmp}", + "msr fpcr, {fpcr}", + // run sqrt + "fsqrt {x:h}, {x:h}", + // get the status word + "mrs {fpcr}, fpsr", + // restore the original control and exception state + "msr fpcr, {fpcr_stash}", + "msr fpsr, {fpsr_stash}", + fpsr_stash = out(reg) _, + fpcr_stash = out(reg) _, + tmp = out(reg) _, + fpcr = inout(reg) fpcr => fpsr, + x = inout(vreg) x, + CLEAR = const !0x1f_u64, // + options(nomem, nostack), + ); + } + let status = check_exceptions(fpsr); + status.and(Half::from_bits(x.to_bits().into())) + } + + pub fn sqrtf32(mut x: f32, round: Round) -> StatusAnd { + let fpcr = make_fpcr_cw(round); + let fpsr: u64; + + unsafe { + core::arch::asm!( + // stash the current control and status state + "mrs {fpcr_stash}, fpcr", + "mrs {fpsr_stash}, fpsr", + // zero the exception flags, set desired control state + "and {tmp}, {fpsr_stash}, #{CLEAR}", + "msr fpsr, {tmp}", + "msr fpcr, {fpcr}", + // run sqrt + "fsqrt {x:s}, {x:s}", + // get the status word + "mrs {fpcr}, fpsr", + // restore the original control and exception state + "msr fpcr, {fpcr_stash}", + "msr fpsr, {fpsr_stash}", + fpsr_stash = out(reg) _, + fpcr_stash = out(reg) _, + tmp = out(reg) _, + fpcr = inout(reg) fpcr => fpsr, + x = inout(vreg) x, + CLEAR = const !0x1f_u64, // + options(nomem, nostack), + ); + } + let status = check_exceptions(fpsr); + status.and(Single::from_bits(x.to_bits().into())) + } + + pub fn sqrtf64(mut x: f64, round: Round) -> StatusAnd { + let fpcr = make_fpcr_cw(round); + let fpsr: u64; + + unsafe { + core::arch::asm!( + // stash the current control and status state + "mrs {fpcr_stash}, fpcr", + "mrs {fpsr_stash}, fpsr", + // zero the exception flags, set desired control state + "and {tmp}, {fpsr_stash}, #{CLEAR}", + "msr fpsr, {tmp}", + "msr fpcr, {fpcr}", + // run sqrt + "fsqrt {x:d}, {x:d}", + // get the status word + "mrs {fpcr}, fpsr", + // restore the original control and exception state + "msr fpcr, {fpcr_stash}", + "msr fpsr, {fpsr_stash}", + fpsr_stash = out(reg) _, + fpcr_stash = out(reg) _, + tmp = out(reg) _, + fpcr = inout(reg) fpcr => fpsr, + x = inout(vreg) x, + CLEAR = const !0x1f_u64, // + options(nomem, nostack), + ); + } + let status = check_exceptions(fpsr); + status.and(Double::from_bits(x.to_bits().into())) + } + + fn make_fpcr_cw(round: Round) -> u64 { + // Default: Clear exception flags, no DAZ, no FTZ + let mut csr = 0u64; + + // Disable traps on all 5 floating point exceptions + csr |= 0b11111 << 8; + + let rc = match round { + Round::NearestTiesToEven => 0b00, + Round::TowardPositive => 0b01, + Round::TowardNegative => 0b10, + Round::TowardZero => 0b11, + Round::NearestTiesToAway => unimplemented!("unsupported rounding on aarch64"), + }; + + csr |= rc << 22; + csr + } + + fn check_exceptions(fpcr: u64) -> Status { + let mut status = Status::OK; + + if fpcr & (1 << 0) != 0 { + status |= Status::INVALID_OP; + } + if fpcr & (1 << 1) != 0 { + status |= Status::DIV_BY_ZERO; + } + if fpcr & (1 << 2) != 0 { + status |= Status::OVERFLOW; + } + if fpcr & (1 << 3) != 0 { + status |= Status::UNDERFLOW; + } + if fpcr & (1 << 4) != 0 { + status |= Status::INEXACT; + } + + status + } +} From fa1f2607f029299a13c13e74cbe244adb87b69e9 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Thu, 2 Apr 2026 20:59:14 -0700 Subject: [PATCH 17/18] add dead code check --- tests/downstream.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/downstream.rs b/tests/downstream.rs index 7ab88d4..fb67077 100644 --- a/tests/downstream.rs +++ b/tests/downstream.rs @@ -593,6 +593,7 @@ mod hardware_sqrt { status.and(Single::from_bits(x.to_bits().into())) } + #[allow(dead_code)] pub fn sqrtf64(mut x: f64, round: Round) -> StatusAnd { let mut csr_stash = 0u32; let mut csr = make_mxcsr_cw(round); From 23db47a1929f27eb41d16820f28433aaa35fe8f7 Mon Sep 17 00:00:00 2001 From: Srayan Jana Date: Thu, 2 Apr 2026 22:49:45 -0700 Subject: [PATCH 18/18] remove copy sign, hope this fuzzes correctly --- src/downstream.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/downstream.rs b/src/downstream.rs index adb3e88..f63595c 100644 --- a/src/downstream.rs +++ b/src/downstream.rs @@ -18,7 +18,7 @@ impl IeeeFloat { // However, Rust and LLVM allow input NaNs to be returned unmodified as well as a few other options -- see Rust's rules for NaNs. // https://doc.rust-lang.org/std/primitive.f32.html#nan-bit-patterns // (Thanks @programmerjake for the comment) - Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self.copy_sign(Self::NAN)), + Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self), // sqrt of negative number is NaN _ if self.is_negative() => return Status::INVALID_OP.and(Self::NAN), // sqrt(inf) = inf