diff --git a/fuzz/ops.rs b/fuzz/ops.rs index fb1a45d..3befe42 100644 --- a/fuzz/ops.rs +++ b/fuzz/ops.rs @@ -16,6 +16,9 @@ enum OpKind { Binary(char), Ternary(Rust<&'static str>, Cxx<&'static str>), + // A operation that's actually a method call in Rust, but only takes one argument (unary). + RustUnary(Rust<&'static str>), + // HACK(eddyb) all other ops have floating-point inputs *and* outputs, so // the easiest way to fuzz conversions from/to other types, even if it won't // cover *all possible* inputs, is to do a round-trip through the other type. @@ -41,7 +44,7 @@ impl Type { impl OpKind { fn inputs<'a, T>(&self, all_inputs: &'a [T; 3]) -> &'a [T] { match self { - Unary(_) | Roundtrip(_) => &all_inputs[..1], + Unary(_) | RustUnary(_) | Roundtrip(_) => &all_inputs[..1], Binary(_) => &all_inputs[..2], Ternary(..) => &all_inputs[..3], } @@ -59,6 +62,9 @@ const OPS: &[(&str, OpKind)] = &[ ("Rem", Binary('%')), // Ternary (`(F, F) -> F`) ops. ("MulAdd", Ternary(Rust("mul_add"), Cxx("fusedMultiplyAdd"))), + // Method-call ops. + // For now, sqrt is Rust-only, there is no C++ `APFloat` equivalent + ("Sqrt", RustUnary(Rust("sqrt"))), // Roundtrip (`F -> T -> F`) ops. ("FToI128ToF", Roundtrip(Type::SInt(128))), ("FToU128ToF", Roundtrip(Type::UInt(128))), @@ -154,6 +160,9 @@ impl FuzzOp Ternary(Rust(method), _) => { format!("{}.{method}({}, {})", inputs[0], inputs[1], inputs[2]) } + RustUnary(Rust(method)) => { + format!("{}.{method}()", inputs[0]) + } Roundtrip(ty) => format!( "<{ty} as num_traits::AsPrimitive::>::as_( >::as_({}))", @@ -189,6 +198,9 @@ impl FuzzOp Ternary(Rust(method), _) => { format!("{}.{method}({}).value", inputs[0], inputs[1..].join(", ")) } + RustUnary(Rust(method)) => { + format!("{}.{method}()", inputs[0]) + } Roundtrip(ty @ (Type::SInt(_) | Type::UInt(_))) => { let (w, i_or_u) = match ty { Type::SInt(w) => (w, "i"), @@ -266,6 +278,16 @@ struct FuzzOp { + &all_ops_map_concat(|_tag, name, kind| { let inputs = kind.inputs(&["a.to_apf()", "b.to_apf()", "c.to_apf()"]); let expr = match kind { + RustUnary(method_name) => { + if method_name.0 == "sqrt" { + // For now, sqrt is the only Rust method-call op, and it has no C++ `APFloat` equivalent + // so don't generate any C++ code for it. + return String::new(); + } else { + unreachable!() + } + } + // HACK(eddyb) `APFloat` doesn't overload `operator%`, so we have // to go through the `mod` method instead. Binary('%') => format!("((r = {}), r.mod({}), r)", inputs[0], inputs[1]), diff --git a/src/downstream.rs b/src/downstream.rs new file mode 100644 index 0000000..f63595c --- /dev/null +++ b/src/downstream.rs @@ -0,0 +1,124 @@ +use crate::{ + ieee::{IeeeDefaultExceptionHandling, IeeeFloat, Semantics}, + Category, Float, Round, Status, StatusAnd, +}; + +impl IeeeFloat { + /// This is a spec conformant implementation of the IEEE Float sqrt function + /// This is put in downstream.rs because this function hasn't been implemented in the upstream C++ version yet. + pub(crate) fn ieee_sqrt(self, round: Round) -> StatusAnd { + match self.category() { + // preserve zero sign + Category::Zero => return Status::OK.and(self), + // propagate NaN + // If the input is a signalling NaN, then IEEE 754 requires the result to be converted to a quiet NaN. + // On most CPUs that means the most significant bit of the significand field is 0 for signalling NaNs and 1 for quiet NaNs. + // On most CPUs they quiet a NaN by setting that bit to a 1, RISC-V instead returns the canonical NaN with positive sign, + // the most significant significand bit set and all other significand bits cleared. + // However, Rust and LLVM allow input NaNs to be returned unmodified as well as a few other options -- see Rust's rules for NaNs. + // https://doc.rust-lang.org/std/primitive.f32.html#nan-bit-patterns + // (Thanks @programmerjake for the comment) + Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self), + // sqrt of negative number is NaN + _ if self.is_negative() => return Status::INVALID_OP.and(Self::NAN), + // sqrt(inf) = inf + Category::Infinity => return Status::OK.and(Self::INFINITY), + Category::Normal => (), + } + + // Floating point precision, excluding the integer bit. + let prec = i32::try_from(Self::PRECISION).unwrap() - 1; + + // x = 2^(exp - prec) * mant + // where mant is an integer with prec+1 bits. + // mant is a u128, which is large enough for the largest prec (112 for f128). + let mut exp = self.ilogb(); + let mut mant = self.scalbn(prec - exp).to_u128(128).value; + + if exp % 2 != 0 { + // Make exponent even, so it can be divided by 2. + exp -= 1; + mant <<= 1; + } + + // Bit-by-bit (base-2 digit-by-digit) sqrt of mant. + // mant is treated here as a fixed point number with prec fractional bits. + // mant will be shifted left by one bit to have an extra fractional bit, which + // will be used to determine the rounding direction. + + // res is the truncated sqrt of mant, where one bit is added at each iteration. + let mut res = 0u128; + // rem is the remainder with the current res + // rem_i = 2^i * ((mant<<1) - res_i^2) + // starting with res = 0, rem = mant<<1 + let mut rem = mant << 1; + // s_i = 2*res_i + let mut s = 0u128; + // d is used to iterate over bits, from high to low (d_i = 2^(-i)) + let mut d = 1u128 << (prec + 1); + + // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that + // (res_i + b_j * 2^(-j))^2 <= mant<<1 + // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b: + // res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1 + // And rearranging the terms: + // b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2) + // b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i + + while d != 0 { + // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1: + // t = 2*res_i + 2^(-j) + let t = s + d; + if rem >= t { + // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem + res += d; + s += d + d; + rem -= t; + } + // Adjust rem for next iteration + rem <<= 1; + // Shift iterator + d >>= 1; + } + + let mut status = Status::OK; + + // A nonzero remainder indicates that we could continue processing sqrt if we had + // more precision, potentially indefinitely. We don't because we have enough bits + // to fill our significand already, and only need the one extra bit to determine + // rounding. + if rem != 0 { + status = Status::INEXACT; + + match round { + // If the LSB is 0, we should round down and this 1 gets cut off. If the LSB + // is 1, it is either a tie (if all remaining bits would be 0) or something + // that should be rounded up. + // + // Square roots are either exact or irrational, so a `1` in the extra bit + // already implies an irrational result with more `1`s in the infinite + // precision tail that should be rounded up, which this does. We are in a + // `rem != 0` block but could technically add the `1` unconditionally, given + // that a 0 in the extra bit would imply an exact result to be rounded down + // (and the extra bit is just shifted out). + Round::NearestTiesToEven => res += 1, + // We know we have an inexact result that needs rounding up. If the round + // bit is 1, adding 1 is sufficient and adding 2 does nothing extra (the + // new LSB will get truncated). If the round bit is 0, we need to add + // two anyway to affect the significand. + Round::TowardPositive => res += 2, + // By default, shifting will round down. + Round::TowardNegative => (), + // Same as negative since the result of sqrt is positive. + Round::TowardZero => (), + Round::NearestTiesToAway => unimplemented!("unsupported rounding mode"), + }; + } + + // Remove the extra fractional bit. + res >>= 1; + + // Build resulting value with res as mantissa and exp/2 as exponent + status.and(Self::from_u128(res).value.scalbn(exp / 2 - prec)) + } +} diff --git a/src/ieee.rs b/src/ieee.rs index 9b8c5a0..57cf32e 100644 --- a/src/ieee.rs +++ b/src/ieee.rs @@ -824,9 +824,9 @@ impl fmt::Debug for IeeeFloat { // but it's a bit too long to keep repeating in the Rust port for all ops. // FIXME(eddyb) find a better name/organization for all of this functionality // (`IeeeDefaultExceptionHandling` doesn't have a counterpart in the C++ code). -struct IeeeDefaultExceptionHandling; +pub(crate) struct IeeeDefaultExceptionHandling; impl IeeeDefaultExceptionHandling { - fn result_from_nan(mut r: IeeeFloat) -> StatusAnd> { + pub fn result_from_nan(mut r: IeeeFloat) -> StatusAnd> { assert!(r.is_nan()); let status = if r.is_signaling() { @@ -865,7 +865,7 @@ impl IeeeDefaultExceptionHandling { status.and(r) } - fn binop_result_from_either_nan(a: IeeeFloat, b: IeeeFloat) -> StatusAnd> { + pub fn binop_result_from_either_nan(a: IeeeFloat, b: IeeeFloat) -> StatusAnd> { let r = match (a.category(), b.category()) { (Category::NaN, _) => a, (_, Category::NaN) => b, @@ -1892,6 +1892,10 @@ impl Float for IeeeFloat { } self.scalbn_r(-*exp, round) } + + fn sqrt(self, round: Round) -> StatusAnd { + self.ieee_sqrt(round) + } } impl FloatConvert> for IeeeFloat { diff --git a/src/lib.rs b/src/lib.rs index 87de8a0..fa48c52 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -484,6 +484,13 @@ pub trait Float: } } + /// IEEE-754R sqrt: Returns the correctly rounded square root of the current value + /// Note: we currently don't support raising any exceptions from sqrt, so the result is always exact and the status is always OK. + #[allow(unused_variables)] + fn sqrt(self, round: Round) -> StatusAnd { + unimplemented!() + } + /// IEEE-754R isSignMinus: Returns true if and only if the current value is /// negative. /// @@ -755,5 +762,6 @@ macro_rules! float_common_impls { }; } +pub mod downstream; pub mod ieee; pub mod ppc; diff --git a/src/ppc.rs b/src/ppc.rs index b32cf75..22d0d33 100644 --- a/src/ppc.rs +++ b/src/ppc.rs @@ -365,6 +365,11 @@ where Fallback::from(self).next_up().map(Self::from) } + #[allow(unused_variables)] + fn sqrt(self, round: Round) -> StatusAnd { + unimplemented!() + } + fn from_bits(input: u128) -> Self { let (a, b) = (input, input >> F::BITS); DoubleFloat(F::from_bits(a & ((1 << F::BITS) - 1)), F::from_bits(b & ((1 << F::BITS) - 1))) diff --git a/tests/downstream.rs b/tests/downstream.rs index fc60a7c..fb67077 100644 --- a/tests/downstream.rs +++ b/tests/downstream.rs @@ -1,7 +1,7 @@ //! Tests added to `rustc_apfloat`, that were not ported from the C++ code. -use rustc_apfloat::ieee::{Double, Single, X87DoubleExtended}; -use rustc_apfloat::Float; +use rustc_apfloat::ieee::{BFloat, Double, Float8E4M3FN, Float8E5M2, Half, Quad, Single, X87DoubleExtended}; +use rustc_apfloat::{Float, Round, Status, StatusAnd}; // `f32 -> i128 -> f32` previously-crashing bit-patterns (found by fuzzing). pub const FUZZ_IEEE32_ROUNDTRIP_THROUGH_I128_CASES: &[u32] = &[ @@ -408,3 +408,405 @@ fn fuzz_x87_f80_neg_with_expected_outputs() { assert_eq!((-X87DoubleExtended::from_bits(bits)).to_bits(), expected_bits); } } + +macro_rules! for_each_ieee_float_type { + (for<$ty_var:ident: Float> $e:expr) => {{ + { + type $ty_var = Half; + $e; + } + { + type $ty_var = Single; + $e; + } + { + type $ty_var = Double; + $e; + } + { + type $ty_var = Quad; + $e; + } + { + type $ty_var = BFloat; + $e; + } + { + type $ty_var = Float8E5M2; + $e; + } + { + type $ty_var = Float8E4M3FN; + $e; + } + { + type $ty_var = X87DoubleExtended; + $e; + } + }}; +} + +#[test] +fn sqrt() { + for_each_ieee_float_type!(for test::()); + fn test() { + for round in [ + Round::NearestTiesToEven, + Round::TowardPositive, + Round::TowardNegative, + Round::TowardZero, + ] { + assert!(F::ZERO.sqrt(round).value.bitwise_eq(F::ZERO)); + assert!((-F::ZERO).sqrt(round).value.bitwise_eq(-F::ZERO)); + assert!(F::INFINITY.sqrt(round).value.bitwise_eq(F::INFINITY)); + assert!(F::NAN.sqrt(round).value.is_nan()); + assert!((-F::INFINITY).sqrt(round).value.is_nan()); + assert!((-F::from_u128(5).value).sqrt(round).value.is_nan()); + let one = F::from_u128(1).value; + assert!(one.sqrt(round).value.bitwise_eq(one)); + let f1 = F::from_u128(64).value; + let f2 = F::from_u128(8).value; + assert!(f1.sqrt(round).value.bitwise_eq(f2)); + } + } +} + +#[test] +fn fuzz_sqrt() { + // for round in [ + // Round::NearestTiesToEven, + // Round::TowardPositive, + // Round::TowardNegative, + // Round::TowardZero, + // ] { + // println!("checking rounding mode {round:?}"); + + // for xi in 0..=u16::MAX { + // if xi % 1_000 == 0 { + // println!("{xi}/{}", u16::MAX); + // } + // let x = f16::from_bits(xi); + // let a = apfloat_sqrtf16(x, round); + // let b = hardware_sqrt::sqrtf16(x, round); + + // // x86 preserves the sign on NaN inputs, others may not (e.g. aarch64 does not). + // let eq_value = if cfg!(target_arch = "x86_64") { + // a.value.bitwise_eq(b.value) + // } else { + // a.value.bitwise_eq(b.value) || (a.value.is_nan() && b.value.is_nan()) + // }; + + // if a.status != b.status || !eq_value { + // panic!( + // "\ + // incorrect result\n\ + // xi: {xi:#010x}\n\ + // x: {x:?}\n\ + // a: {a:?} {af}\n\ + // b: {b:?} {bf}\ + // ", + // af = f32::from_bits(a.value.to_bits().try_into().unwrap()), + // bf = f32::from_bits(a.value.to_bits().try_into().unwrap()), + // ); + // } + // } + // } + + for round in [ + Round::NearestTiesToEven, + Round::TowardPositive, + Round::TowardNegative, + Round::TowardZero, + ] { + println!("checking rounding mode {round:?}"); + + for xi in 0..=u32::MAX { + if xi % 100_000_000 == 0 { + println!("{xi}/{}", u32::MAX); + } + let x = f32::from_bits(xi); + let a = apfloat_sqrt(x, round); + let b = hardware_sqrt::sqrtf32(x, round); + + // x86 preserves the sign on NaN inputs, others may not (e.g. aarch64 does not). + let eq_value = if cfg!(target_arch = "x86_64") { + a.value.bitwise_eq(b.value) + } else { + a.value.bitwise_eq(b.value) || (a.value.is_nan() && b.value.is_nan()) + }; + + if a.status != b.status || !eq_value { + panic!( + "\ + incorrect result\n\ + xi: {xi:#010x}\n\ + x: {x:?}\n\ + a: {a:?} {af}\n\ + b: {b:?} {bf}\ + ", + af = f32::from_bits(a.value.to_bits().try_into().unwrap()), + bf = f32::from_bits(a.value.to_bits().try_into().unwrap()), + ); + } + } + } +} + +#[cfg(target_arch = "aarch64")] +fn apfloat_sqrtf16(x: f16, round: Round) -> StatusAnd { + Half::from_bits(x.to_bits().into()).sqrt(round) +} + +fn apfloat_sqrt(x: f32, round: Round) -> StatusAnd { + Single::from_bits(x.to_bits().into()).sqrt(round) +} + +// SQRTSS is in baseline SSE, SQRTSD needs SSE2 +#[cfg(any(target_arch = "x86_64", all(target_arch = "x86", target_feature = "sse2")))] +mod hardware_sqrt { + use super::*; + + pub fn sqrtf32(mut x: f32, round: Round) -> StatusAnd { + let mut csr_stash = 0u32; + let mut csr = make_mxcsr_cw(round); + + unsafe { + core::arch::asm!( + // stash the current control state + "stmxcsr [{csr_stash}]", + // set the control state we want, clears flags + "ldmxcsr [{csr}]", + // run sqrt + "sqrtss {x}, {x}", + // get the new control state + "stmxcsr [{csr}]", + // restore the original control state + "ldmxcsr [{csr_stash}]", + csr_stash = in(reg) &mut csr_stash, + csr = in(reg) &mut csr, + x = inout(xmm_reg) x, + options(nostack), + ); + } + + let status = check_exceptions(csr); + status.and(Single::from_bits(x.to_bits().into())) + } + + #[allow(dead_code)] + pub fn sqrtf64(mut x: f64, round: Round) -> StatusAnd { + let mut csr_stash = 0u32; + let mut csr = make_mxcsr_cw(round); + + unsafe { + core::arch::asm!( + // stash the current control state + "stmxcsr [{csr_stash}]", + // set the control state we want, clears flags + "ldmxcsr [{csr}]", + // run sqrt + "sqrtsd {x}, {x}", + // get the new control state + "stmxcsr [{csr}]", + // restore the original control state + "ldmxcsr [{csr_stash}]", + csr_stash = in(reg) &mut csr_stash, + csr = in(reg) &mut csr, + x = inout(xmm_reg) x, + options(nostack), + ); + } + + let status = check_exceptions(csr); + status.and(Double::from_bits(x.to_bits().into())) + } + + fn make_mxcsr_cw(round: Round) -> u32 { + // Default: Clear exception flags, no DAZ, no FTZ + let mut csr = 0u32; + // Set all masks so fp status doesn't turn into SIGFPE + csr |= 0b111111 << 7; + + let rc = match round { + Round::NearestTiesToEven => 0b00, + Round::TowardPositive => 0b10, + Round::TowardNegative => 0b01, + Round::TowardZero => 0b11, + Round::NearestTiesToAway => unimplemented!("unsupported rounding on x86"), + }; + + csr |= rc << 13; + csr + } + + fn check_exceptions(csr: u32) -> Status { + let mut status = Status::OK; + + if csr & (1 << 0) != 0 { + status |= Status::INVALID_OP; + } + if csr & (1 << 1) != 0 { + // denormal flag, not part of status + } + if csr & (1 << 2) != 0 { + status |= Status::DIV_BY_ZERO; + } + if csr & (1 << 3) != 0 { + status |= Status::OVERFLOW; + } + if csr & (1 << 4) != 0 { + status |= Status::UNDERFLOW; + } + if csr & (1 << 5) != 0 { + status |= Status::INEXACT; + } + + status + } +} + +#[cfg(target_arch = "aarch64")] +mod hardware_sqrt { + use super::*; + + #[cfg(target_feature = "fp16")] + pub fn sqrtf16(mut x: f16, round: Round) -> StatusAnd { + let fpcr = make_fpcr_cw(round); + let fpsr: u64; + + unsafe { + core::arch::asm!( + // stash the current control and status state + "mrs {fpcr_stash}, fpcr", + "mrs {fpsr_stash}, fpsr", + // zero the exception flags, set desired control state + "and {tmp}, {fpsr_stash}, #{CLEAR}", + "msr fpsr, {tmp}", + "msr fpcr, {fpcr}", + // run sqrt + "fsqrt {x:h}, {x:h}", + // get the status word + "mrs {fpcr}, fpsr", + // restore the original control and exception state + "msr fpcr, {fpcr_stash}", + "msr fpsr, {fpsr_stash}", + fpsr_stash = out(reg) _, + fpcr_stash = out(reg) _, + tmp = out(reg) _, + fpcr = inout(reg) fpcr => fpsr, + x = inout(vreg) x, + CLEAR = const !0x1f_u64, // + options(nomem, nostack), + ); + } + let status = check_exceptions(fpsr); + status.and(Half::from_bits(x.to_bits().into())) + } + + pub fn sqrtf32(mut x: f32, round: Round) -> StatusAnd { + let fpcr = make_fpcr_cw(round); + let fpsr: u64; + + unsafe { + core::arch::asm!( + // stash the current control and status state + "mrs {fpcr_stash}, fpcr", + "mrs {fpsr_stash}, fpsr", + // zero the exception flags, set desired control state + "and {tmp}, {fpsr_stash}, #{CLEAR}", + "msr fpsr, {tmp}", + "msr fpcr, {fpcr}", + // run sqrt + "fsqrt {x:s}, {x:s}", + // get the status word + "mrs {fpcr}, fpsr", + // restore the original control and exception state + "msr fpcr, {fpcr_stash}", + "msr fpsr, {fpsr_stash}", + fpsr_stash = out(reg) _, + fpcr_stash = out(reg) _, + tmp = out(reg) _, + fpcr = inout(reg) fpcr => fpsr, + x = inout(vreg) x, + CLEAR = const !0x1f_u64, // + options(nomem, nostack), + ); + } + let status = check_exceptions(fpsr); + status.and(Single::from_bits(x.to_bits().into())) + } + + pub fn sqrtf64(mut x: f64, round: Round) -> StatusAnd { + let fpcr = make_fpcr_cw(round); + let fpsr: u64; + + unsafe { + core::arch::asm!( + // stash the current control and status state + "mrs {fpcr_stash}, fpcr", + "mrs {fpsr_stash}, fpsr", + // zero the exception flags, set desired control state + "and {tmp}, {fpsr_stash}, #{CLEAR}", + "msr fpsr, {tmp}", + "msr fpcr, {fpcr}", + // run sqrt + "fsqrt {x:d}, {x:d}", + // get the status word + "mrs {fpcr}, fpsr", + // restore the original control and exception state + "msr fpcr, {fpcr_stash}", + "msr fpsr, {fpsr_stash}", + fpsr_stash = out(reg) _, + fpcr_stash = out(reg) _, + tmp = out(reg) _, + fpcr = inout(reg) fpcr => fpsr, + x = inout(vreg) x, + CLEAR = const !0x1f_u64, // + options(nomem, nostack), + ); + } + let status = check_exceptions(fpsr); + status.and(Double::from_bits(x.to_bits().into())) + } + + fn make_fpcr_cw(round: Round) -> u64 { + // Default: Clear exception flags, no DAZ, no FTZ + let mut csr = 0u64; + + // Disable traps on all 5 floating point exceptions + csr |= 0b11111 << 8; + + let rc = match round { + Round::NearestTiesToEven => 0b00, + Round::TowardPositive => 0b01, + Round::TowardNegative => 0b10, + Round::TowardZero => 0b11, + Round::NearestTiesToAway => unimplemented!("unsupported rounding on aarch64"), + }; + + csr |= rc << 22; + csr + } + + fn check_exceptions(fpcr: u64) -> Status { + let mut status = Status::OK; + + if fpcr & (1 << 0) != 0 { + status |= Status::INVALID_OP; + } + if fpcr & (1 << 1) != 0 { + status |= Status::DIV_BY_ZERO; + } + if fpcr & (1 << 2) != 0 { + status |= Status::OVERFLOW; + } + if fpcr & (1 << 3) != 0 { + status |= Status::UNDERFLOW; + } + if fpcr & (1 << 4) != 0 { + status |= Status::INEXACT; + } + + status + } +}