Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion fuzz/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ enum OpKind {
Binary(char),
Ternary(Rust<&'static str>, Cxx<&'static str>),

// A operation that's actually a method call in Rust, but only takes one argument (unary).
RustUnary(Rust<&'static str>),

// HACK(eddyb) all other ops have floating-point inputs *and* outputs, so
// the easiest way to fuzz conversions from/to other types, even if it won't
// cover *all possible* inputs, is to do a round-trip through the other type.
Expand All @@ -41,7 +44,7 @@ impl Type {
impl OpKind {
fn inputs<'a, T>(&self, all_inputs: &'a [T; 3]) -> &'a [T] {
match self {
Unary(_) | Roundtrip(_) => &all_inputs[..1],
Unary(_) | RustUnary(_) | Roundtrip(_) => &all_inputs[..1],
Binary(_) => &all_inputs[..2],
Ternary(..) => &all_inputs[..3],
}
Expand All @@ -59,6 +62,9 @@ const OPS: &[(&str, OpKind)] = &[
("Rem", Binary('%')),
// Ternary (`(F, F) -> F`) ops.
("MulAdd", Ternary(Rust("mul_add"), Cxx("fusedMultiplyAdd"))),
// Method-call ops.
// For now, sqrt is Rust-only, there is no C++ `APFloat` equivalent
("Sqrt", RustUnary(Rust("sqrt"))),
// Roundtrip (`F -> T -> F`) ops.
("FToI128ToF", Roundtrip(Type::SInt(128))),
("FToU128ToF", Roundtrip(Type::UInt(128))),
Expand Down Expand Up @@ -154,6 +160,9 @@ impl<HF> FuzzOp<HF>
Ternary(Rust(method), _) => {
format!("{}.{method}({}, {})", inputs[0], inputs[1], inputs[2])
}
RustUnary(Rust(method)) => {
format!("{}.{method}()", inputs[0])
}
Roundtrip(ty) => format!(
"<{ty} as num_traits::AsPrimitive::<HF>>::as_(
<HF as num_traits::AsPrimitive::<{ty}>>::as_({}))",
Expand Down Expand Up @@ -189,6 +198,9 @@ impl<F> FuzzOp<F>
Ternary(Rust(method), _) => {
format!("{}.{method}({}).value", inputs[0], inputs[1..].join(", "))
}
RustUnary(Rust(method)) => {
format!("{}.{method}()", inputs[0])
}
Roundtrip(ty @ (Type::SInt(_) | Type::UInt(_))) => {
let (w, i_or_u) = match ty {
Type::SInt(w) => (w, "i"),
Expand Down Expand Up @@ -266,6 +278,16 @@ struct FuzzOp {
+ &all_ops_map_concat(|_tag, name, kind| {
let inputs = kind.inputs(&["a.to_apf()", "b.to_apf()", "c.to_apf()"]);
let expr = match kind {
RustUnary(method_name) => {
if method_name.0 == "sqrt" {
// For now, sqrt is the only Rust method-call op, and it has no C++ `APFloat` equivalent
// so don't generate any C++ code for it.
return String::new();
} else {
unreachable!()
}
}

// HACK(eddyb) `APFloat` doesn't overload `operator%`, so we have
// to go through the `mod` method instead.
Binary('%') => format!("((r = {}), r.mod({}), r)", inputs[0], inputs[1]),
Expand Down
124 changes: 124 additions & 0 deletions src/downstream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
use crate::{
ieee::{IeeeDefaultExceptionHandling, IeeeFloat, Semantics},
Category, Float, Round, Status, StatusAnd,
};

impl<S: Semantics> IeeeFloat<S> {
/// This is a spec conformant implementation of the IEEE Float sqrt function
/// This is put in downstream.rs because this function hasn't been implemented in the upstream C++ version yet.
pub(crate) fn ieee_sqrt(self, round: Round) -> StatusAnd<Self> {
match self.category() {
// preserve zero sign
Category::Zero => return Status::OK.and(self),
// propagate NaN
// If the input is a signalling NaN, then IEEE 754 requires the result to be converted to a quiet NaN.
// On most CPUs that means the most significant bit of the significand field is 0 for signalling NaNs and 1 for quiet NaNs.
// On most CPUs they quiet a NaN by setting that bit to a 1, RISC-V instead returns the canonical NaN with positive sign,
// the most significant significand bit set and all other significand bits cleared.
// However, Rust and LLVM allow input NaNs to be returned unmodified as well as a few other options -- see Rust's rules for NaNs.
// https://doc.rust-lang.org/std/primitive.f32.html#nan-bit-patterns
// (Thanks @programmerjake for the comment)
Category::NaN => return IeeeDefaultExceptionHandling::result_from_nan(self),
// sqrt of negative number is NaN
_ if self.is_negative() => return Status::INVALID_OP.and(Self::NAN),
// sqrt(inf) = inf
Category::Infinity => return Status::OK.and(Self::INFINITY),
Category::Normal => (),
}

// Floating point precision, excluding the integer bit.
let prec = i32::try_from(Self::PRECISION).unwrap() - 1;

// x = 2^(exp - prec) * mant
// where mant is an integer with prec+1 bits.
// mant is a u128, which is large enough for the largest prec (112 for f128).
let mut exp = self.ilogb();
let mut mant = self.scalbn(prec - exp).to_u128(128).value;

if exp % 2 != 0 {
// Make exponent even, so it can be divided by 2.
exp -= 1;
mant <<= 1;
}

// Bit-by-bit (base-2 digit-by-digit) sqrt of mant.
// mant is treated here as a fixed point number with prec fractional bits.
// mant will be shifted left by one bit to have an extra fractional bit, which
// will be used to determine the rounding direction.

// res is the truncated sqrt of mant, where one bit is added at each iteration.
let mut res = 0u128;
// rem is the remainder with the current res
// rem_i = 2^i * ((mant<<1) - res_i^2)
// starting with res = 0, rem = mant<<1
let mut rem = mant << 1;
// s_i = 2*res_i
let mut s = 0u128;
// d is used to iterate over bits, from high to low (d_i = 2^(-i))
let mut d = 1u128 << (prec + 1);

// For iteration j=i+1, we need to find largest b_j = 0 or 1 such that
// (res_i + b_j * 2^(-j))^2 <= mant<<1
// Expanding (a + b)^2 = a^2 + b^2 + 2*a*b:
// res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1
// And rearranging the terms:
// b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2)
// b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i

while d != 0 {
// Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1:
// t = 2*res_i + 2^(-j)
let t = s + d;
if rem >= t {
// b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem
res += d;
s += d + d;
rem -= t;
}
// Adjust rem for next iteration
rem <<= 1;
// Shift iterator
d >>= 1;
}

let mut status = Status::OK;

// A nonzero remainder indicates that we could continue processing sqrt if we had
// more precision, potentially indefinitely. We don't because we have enough bits
// to fill our significand already, and only need the one extra bit to determine
// rounding.
if rem != 0 {
status = Status::INEXACT;

match round {
// If the LSB is 0, we should round down and this 1 gets cut off. If the LSB
// is 1, it is either a tie (if all remaining bits would be 0) or something
// that should be rounded up.
//
// Square roots are either exact or irrational, so a `1` in the extra bit
// already implies an irrational result with more `1`s in the infinite
// precision tail that should be rounded up, which this does. We are in a
// `rem != 0` block but could technically add the `1` unconditionally, given
// that a 0 in the extra bit would imply an exact result to be rounded down
// (and the extra bit is just shifted out).
Round::NearestTiesToEven => res += 1,
// We know we have an inexact result that needs rounding up. If the round
// bit is 1, adding 1 is sufficient and adding 2 does nothing extra (the
// new LSB will get truncated). If the round bit is 0, we need to add
// two anyway to affect the significand.
Round::TowardPositive => res += 2,
// By default, shifting will round down.
Round::TowardNegative => (),
// Same as negative since the result of sqrt is positive.
Round::TowardZero => (),
Round::NearestTiesToAway => unimplemented!("unsupported rounding mode"),
};
}

// Remove the extra fractional bit.
res >>= 1;

// Build resulting value with res as mantissa and exp/2 as exponent
status.and(Self::from_u128(res).value.scalbn(exp / 2 - prec))
}
}
10 changes: 7 additions & 3 deletions src/ieee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -824,9 +824,9 @@ impl<S: Semantics> fmt::Debug for IeeeFloat<S> {
// but it's a bit too long to keep repeating in the Rust port for all ops.
// FIXME(eddyb) find a better name/organization for all of this functionality
// (`IeeeDefaultExceptionHandling` doesn't have a counterpart in the C++ code).
struct IeeeDefaultExceptionHandling;
pub(crate) struct IeeeDefaultExceptionHandling;
impl IeeeDefaultExceptionHandling {
fn result_from_nan<S: Semantics>(mut r: IeeeFloat<S>) -> StatusAnd<IeeeFloat<S>> {
pub fn result_from_nan<S: Semantics>(mut r: IeeeFloat<S>) -> StatusAnd<IeeeFloat<S>> {
assert!(r.is_nan());

let status = if r.is_signaling() {
Expand Down Expand Up @@ -865,7 +865,7 @@ impl IeeeDefaultExceptionHandling {
status.and(r)
}

fn binop_result_from_either_nan<S: Semantics>(a: IeeeFloat<S>, b: IeeeFloat<S>) -> StatusAnd<IeeeFloat<S>> {
pub fn binop_result_from_either_nan<S: Semantics>(a: IeeeFloat<S>, b: IeeeFloat<S>) -> StatusAnd<IeeeFloat<S>> {
let r = match (a.category(), b.category()) {
(Category::NaN, _) => a,
(_, Category::NaN) => b,
Expand Down Expand Up @@ -1892,6 +1892,10 @@ impl<S: Semantics> Float for IeeeFloat<S> {
}
self.scalbn_r(-*exp, round)
}

fn sqrt(self, round: Round) -> StatusAnd<Self> {
self.ieee_sqrt(round)
}
}

impl<S: Semantics, T: Semantics> FloatConvert<IeeeFloat<T>> for IeeeFloat<S> {
Expand Down
8 changes: 8 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,13 @@ pub trait Float:
}
}

/// IEEE-754R sqrt: Returns the correctly rounded square root of the current value
/// Note: we currently don't support raising any exceptions from sqrt, so the result is always exact and the status is always OK.
#[allow(unused_variables)]
fn sqrt(self, round: Round) -> StatusAnd<Self> {
unimplemented!()
}

/// IEEE-754R isSignMinus: Returns true if and only if the current value is
/// negative.
///
Expand Down Expand Up @@ -755,5 +762,6 @@ macro_rules! float_common_impls {
};
}

pub mod downstream;
pub mod ieee;
pub mod ppc;
5 changes: 5 additions & 0 deletions src/ppc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ where
Fallback::from(self).next_up().map(Self::from)
}

#[allow(unused_variables)]
fn sqrt(self, round: Round) -> StatusAnd<Self> {
unimplemented!()
}

fn from_bits(input: u128) -> Self {
let (a, b) = (input, input >> F::BITS);
DoubleFloat(F::from_bits(a & ((1 << F::BITS) - 1)), F::from_bits(b & ((1 << F::BITS) - 1)))
Expand Down
Loading
Loading