From 392d3378033c2474b897baa29eed4e9acdadaae6 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 24 Apr 2026 16:23:24 +0200 Subject: [PATCH 1/8] make `Microfloat` abstract, add `@microfloat` macro, rename types, avoid private Base names, fix F8E8M0, remove BitPackingExt --- Project.toml | 14 +- README.md | 31 ++-- docs/src/microfloats.md | 2 - ext/BitPackingExt.jl | 8 - src/Microfloat.jl | 166 ++++++++++++++---- src/Microfloats.jl | 107 +++++------ src/conversion.jl | 78 +++++--- src/macros.jl | 63 +++++++ src/ops.jl | 4 +- src/random.jl | 5 +- src/show.jl | 8 - src/variants.jl | 13 ++ src/variants/Finite.jl | 6 - src/variants/IEEE_754_like.jl | 24 --- src/variants/MX.jl | 31 ---- test/Float8s/runtests.jl | 39 ---- test/MX/runtests.jl | 4 - test/Microfloat.jl | 151 ---------------- test/Project.toml | 4 +- test/basic.jl | 124 +++++++++++++ test/dlfp8_parity.jl | 28 +++ .../{MX/MX_compliance.jl => mx_compliance.jl} | 3 +- .../{MX/MX_properties.jl => mx_properties.jl} | 48 ++--- test/overflow.jl | 112 ++++++------ test/runtests.jl | 56 ++++-- 25 files changed, 612 insertions(+), 517 deletions(-) delete mode 100644 ext/BitPackingExt.jl create mode 100644 src/macros.jl delete mode 100644 src/show.jl create mode 100644 src/variants.jl delete mode 100644 src/variants/Finite.jl delete mode 100644 src/variants/IEEE_754_like.jl delete mode 100644 src/variants/MX.jl delete mode 100644 test/Float8s/runtests.jl delete mode 100644 test/MX/runtests.jl delete mode 100644 test/Microfloat.jl create mode 100644 test/basic.jl create mode 100644 test/dlfp8_parity.jl rename test/{MX/MX_compliance.jl => mx_compliance.jl} (98%) rename test/{MX/MX_properties.jl => mx_properties.jl} (76%) diff --git a/Project.toml b/Project.toml index fc0c02d..b1a6e87 100644 --- a/Project.toml +++ b/Project.toml @@ -1,20 +1,18 @@ name = "Microfloats" uuid = "31c70f10-a750-4521-b13c-797315ae2933" +version = "0.2.0" authors = ["Anton Oresten and contributors"] -version = "0.1.1" + +[workspace] +projects 
= ["test"] [deps] BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[weakdeps] -BitPacking = "b58c8408-13c4-4787-8733-7038ae624acf" - -[extensions] -BitPackingExt = "BitPacking" +Republic = "27243419-9dde-4721-b67c-fd63626fea7f" [compat] BFloat16s = "0.5, 0.6" -BitPacking = "0.1" Random = "1" +Republic = "2.1" julia = "1.10" diff --git a/README.md b/README.md index 7f8558f..c489854 100644 --- a/README.md +++ b/README.md @@ -11,48 +11,37 @@ Instances of a sub-8 bit floating point type are still 8 bits wide in memory; th ## Usage -Along with the types already exported by Microfloats, we can also create our own types by passing the number of sign, exponent, and mantissa bits to the `Microfloat` type constructor. For example, one can recreate the `Float8` and `Float8_4` types exported by Float8s.jl: +Define your own primitive type with the macro: ```julia using Microfloats -# IEEE_754_like variant for {Float64,Float32,Float16}-like overflowing -const MicrofloatIEEE{S,E,M} = Microfloat{S,E,M,IEEE_754_like} - -const Float8 = MicrofloatIEEE{1,3,4} -const Float8_4 = MicrofloatIEEE{1,4,3} +@microfloat MyE5M2 sign=1 exponent=5 significand=2 nonfinite=IEEE +``` -# creating a sawed-off Float16 (BFloat8?) becomes trivial: -const Float8_5 = MicrofloatIEEE{1,5,2} +Or the hand-written equivalent: -# unsigned variants: -const UFloat7 = MicrofloatIEEE{0,3,4} -const UFloat7_4 = MicrofloatIEEE{0,4,3} -const UFloat7_5 = MicrofloatIEEE{0,5,2} +```julia +primitive type MyE5M2 <: Microfloat{1,5,2} 8 end +Microfloats.non_finite_behavior(::Type{MyE5M2}) = IEEE ``` -### Microscaling (MX) +## Overflow policy -Microfloats implements the E4M3, E5M2, E2M3, E3M2, E2M1, and E8M0 types from the [Open Compute Project Microscaling Formats (MX) Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). 
These are exported as `MX_E4M3`, `MX_E5M2`, `MX_E2M3`, `MX_E3M2`, `MX_E2M1`, and `MX_E8M0`, respectively, with most of these using saturated arithmetic (no Inf or NaN), and a different encoding for the types that do have NaNs. +`SAT` saturates out-of-range values to `±floatmax(T)`. `OVF` uses the type's +sentinel (`±Inf` for IEEE, `NaN` for NanOnlyAllOnes; throws for FiniteOnly). For INT8, see `FixedPointNumbers.Q1f6`. -> [!NOTE] -> MX types may not be fully MX compliant, but efforts have been and continue to be made to adhere to the specification. See issues with the [![MX-compliance](https://img.shields.io/github/labels/MurrellGroup/Microfloats.jl/mx-compliance)](https://github.com/MurrellGroup/Microfloats.jl/labels/mx-compliance) label. - -Since Microfloats.jl only implements the primitive types, microscaling itself may be done with [Microscaling.jl](https://github.com/MurrellGroup/Microscaling.jl), which includes quantization and bitpacking. - ## Installation ```julia using Pkg -Pkg.Registry.add(url="https://github.com/MurrellGroup/MurrellGroupRegistry") Pkg.add("Microfloats") ``` ## See also -- [Microscaling.jl](https://github.com/MurrellGroup/Microscaling.jl) - [FixedPointNumbers.jl](https://github.com/JuliaMath/FixedPointNumbers.jl) - [MicroFloatingPoints.jl](https://github.com/goualard-f/MicroFloatingPoints.jl) - [DLFP8Types.jl](https://github.com/chengchingwen/DLFP8Types.jl) diff --git a/docs/src/microfloats.md b/docs/src/microfloats.md index 607dc78..3722b3d 100644 --- a/docs/src/microfloats.md +++ b/docs/src/microfloats.md @@ -15,8 +15,6 @@ Finite These types have IEEE 754-like Inf/NaN encodings, with Inf being represented as all 1s in the exponent and a significand of zero, and NaN being represented as all 1s in the exponent and a non-zero significand. 
```@docs -IEEE_754_like -Float8_E3M4 Float8_E4M3 Float8_E5M2 Float6_E2M3 diff --git a/ext/BitPackingExt.jl b/ext/BitPackingExt.jl deleted file mode 100644 index 9d2a035..0000000 --- a/ext/BitPackingExt.jl +++ /dev/null @@ -1,8 +0,0 @@ -module BitPackingExt - -using Microfloats -using BitPacking - -BitPacking.bitwidth(::Type{T}) where T<:Microfloat = Microfloats.total_bits(T) - -end diff --git a/src/Microfloat.jl b/src/Microfloat.jl index 6ede894..5ed2d97 100644 --- a/src/Microfloat.jl +++ b/src/Microfloat.jl @@ -1,53 +1,154 @@ -import Base: - reinterpret, - exponent_bits, significand_bits, - sign_mask, exponent_mask, significand_mask, - signbit, exponent, exponent_bias +@republic import Base: signbit, exponent -sign_bits(::Type{<:AbstractFloat}) = 1 -total_bits(::Type{T}) where T<:AbstractFloat = sign_bits(T) + exponent_bits(T) + significand_bits(T) +""" + Microfloat{S,E,M} <: AbstractFloat -primitive type Microfloat{S,E,M,V} <: AbstractFloat 8 end +Abstract type for within-byte floating-point numbers with `S` sign bits +(0 or 1), `E` exponent bits (≥ 1), and `M` significand bits (≥ 0), +with `S + E + M ≤ 8`. -reinterpret(::Type{Unsigned}, x::Microfloat) = reinterpret(UInt8, x) +Concrete subtypes are 8-bit `primitive type`s that must also register +a [`non_finite_behavior`](@ref). See [`@microfloat`](@ref) for the +macro-based convenience declaration. 
+""" +abstract type Microfloat{S,E,M} <: AbstractFloat end -sign_bits(::Type{<:Microfloat{S}}) where S = S -exponent_bits(::Type{<:Microfloat{<:Any,E}}) where E = E -significand_bits(::Type{<:Microfloat{<:Any,<:Any,M}}) where M = M -variant(::Type{<:Microfloat{<:Any,<:Any,<:Any,V}}) where V = V +float_bits(::Type{<:Microfloat{S,E,M}}) where {S,E,M} = (S,E,M) -sign_mask(::Type{T}) where T<:Microfloat = (0x01 << sign_bits(T) - 0x01) << (exponent_bits(T) + significand_bits(T)) -exponent_mask(::Type{T}) where T<:Microfloat = (0x01 << exponent_bits(T) - 0x01) << significand_bits(T) -significand_mask(::Type{T}) where T<:Microfloat = 0x01 << significand_bits(T) - 0x01 +sign_mask(::Type{T}) where T<:Microfloat = UInt8((0x01 << sign_bits(T) - 0x01) << (exponent_bits(T) + significand_bits(T))) +exponent_mask(::Type{T}) where T<:Microfloat = UInt8((0x01 << exponent_bits(T) - 0x01) << significand_bits(T)) +significand_mask(::Type{T}) where T<:Microfloat = UInt8(0x01 << significand_bits(T) - 0x01) -signbit(x::Microfloat) = sign_bits(typeof(x)) > 0 && reinterpret(Unsigned, x) & sign_mask(typeof(x)) == sign_mask(typeof(x)) +Base.reinterpret(::Type{Unsigned}, x::Microfloat) = reinterpret(UInt8, x) + +signbit(x::Microfloat) = sign_bits(typeof(x)) > 0 && !iszero(reinterpret(Unsigned, x) & sign_mask(typeof(x))) exponent(x::Microfloat) = Int((reinterpret(Unsigned, x) & exponent_mask(typeof(x))) >> significand_bits(typeof(x))) -exponent_bias(::Type{T}) where T<:Microfloat = 2^(exponent_bits(T) - 1) - 1 -hasinf(::Type{<:Microfloat}) = false -hasnan(::Type{<:Microfloat}) = false +function Base.show(io::IO, x::T) where T<:Microfloat + show_typeinfo = get(IOContext(io), :typeinfo, nothing) != T + show_typeinfo && print(io, repr(T), "(") + print(io, Float64(x)) + show_typeinfo && print(io, ")") + return nothing +end + +""" + NonFiniteBehavior + +Trait hierarchy describing how a [`Microfloat`](@ref) type encodes non-finite +values. 
Each concrete `Microfloat` subtype registers its behavior by defining +a [`non_finite_behavior`](@ref) method. + +Three behaviors: + +- [`IEEE`](@ref): exponent all-ones with zero significand ⇒ Inf; + all-ones exponent with nonzero significand ⇒ NaN. +- [`NanOnlyAllOnes`](@ref): no Inf. The single NaN encoding has all + exponent and significand bits set. +- [`FiniteOnly`](@ref): no Inf and no NaN — every bit pattern is finite. + Matches MX sub-byte types and `F4E2M1FN`. +""" +abstract type NonFiniteBehavior end + +"""IEEE-754-style encoding of Inf and NaN. Requires `M ≥ 1`.""" +abstract type IEEE <: NonFiniteBehavior end + +"""NaN encoded as all-ones in exponent+significand; no Inf.""" +abstract type NanOnlyAllOnes <: NonFiniteBehavior end + +"""No Inf or NaN — every bit pattern is a finite value.""" +abstract type FiniteOnly <: NonFiniteBehavior end + +hasinf(::Type{IEEE}) = true +hasinf(::Type{NanOnlyAllOnes}) = false +hasinf(::Type{FiniteOnly}) = false + +hasnan(::Type{IEEE}) = true +hasnan(::Type{NanOnlyAllOnes}) = true +hasnan(::Type{FiniteOnly}) = false -inf(::Type{T}) where T<:Microfloat = throw(DomainError(T, lazy"$T has no Inf")) -nan(::Type{T}) where T<:Microfloat = throw(DomainError(T, lazy"$T has no NaN")) +""" + non_finite_behavior(T) -> Type{<:NonFiniteBehavior} -Base.isinf(::Microfloat) = false -Base.isnan(::Microfloat) = false +Required trait method on every concrete [`Microfloat`](@ref) subtype. +Returns one of `IEEE`, `NanOnlyAllOnes`, or `FiniteOnly`. 
+""" +non_finite_behavior(::Type{T}) where T<:Microfloat = + error("$T must define `Microfloats.non_finite_behavior(::Type{$T})`") + +hasinf(::Type{T}) where T<:Microfloat = hasinf(non_finite_behavior(T)) +hasnan(::Type{T}) where T<:Microfloat = hasnan(non_finite_behavior(T)) + +# ───────────────────────── Inf / NaN / floatmax / inf / nan ────────────────────────── + +Base.isinf(x::T) where T<:Microfloat = _isinf(non_finite_behavior(T), x) +Base.isnan(x::T) where T<:Microfloat = _isnan(non_finite_behavior(T), x) + +"""Bit pattern for +Inf. Throws if the type has no Inf.""" +inf(::Type{T}) where T<:Microfloat = _inf(non_finite_behavior(T), T) + +"""Bit pattern for NaN. Throws if the type has no NaN.""" +nan(::Type{T}) where T<:Microfloat = _nan(non_finite_behavior(T), T) + +Base.floatmax(::Type{T}) where T<:Microfloat = _floatmax(non_finite_behavior(T), T) + +# IEEE +function _isinf(::Type{IEEE}, x::T) where T<:Microfloat + raw = reinterpret(Unsigned, x) + (raw & exponent_mask(T)) == exponent_mask(T) && iszero(raw & significand_mask(T)) +end +function _isnan(::Type{IEEE}, x::T) where T<:Microfloat + raw = reinterpret(Unsigned, x) + (raw & exponent_mask(T)) == exponent_mask(T) && !iszero(raw & significand_mask(T)) +end +_inf(::Type{IEEE}, ::Type{T}) where T<:Microfloat = reinterpret(T, exponent_mask(T)) +_nan(::Type{IEEE}, ::Type{T}) where T<:Microfloat = + reinterpret(T, exponent_mask(T) | (UInt8(0x01) << (significand_bits(T) - 1))) +_floatmax(::Type{IEEE}, ::Type{T}) where T<:Microfloat = + reinterpret(T, (exponent_mask(T) - (UInt8(0x01) << significand_bits(T))) | significand_mask(T)) + +# NanOnlyAllOnes +_isinf(::Type{NanOnlyAllOnes}, ::Microfloat) = false +function _isnan(::Type{NanOnlyAllOnes}, x::T) where T<:Microfloat + raw = reinterpret(Unsigned, x) + (raw & ~sign_mask(T)) == (exponent_mask(T) | significand_mask(T)) +end +_inf(::Type{NanOnlyAllOnes}, ::Type{T}) where T<:Microfloat = + throw(DomainError(T, "$T has no Inf")) +_nan(::Type{NanOnlyAllOnes}, 
::Type{T}) where T<:Microfloat = + reinterpret(T, exponent_mask(T) | significand_mask(T)) +_floatmax(::Type{NanOnlyAllOnes}, ::Type{T}) where T<:Microfloat = + reinterpret(T, (exponent_mask(T) | significand_mask(T)) - UInt8(0x01)) + +# FiniteOnly +_isinf(::Type{FiniteOnly}, ::Microfloat) = false +_isnan(::Type{FiniteOnly}, ::Microfloat) = false +_inf(::Type{FiniteOnly}, ::Type{T}) where T<:Microfloat = + throw(DomainError(T, "$T has no Inf")) +_nan(::Type{FiniteOnly}, ::Type{T}) where T<:Microfloat = + throw(DomainError(T, "$T has no NaN")) +_floatmax(::Type{FiniteOnly}, ::Type{T}) where T<:Microfloat = + reinterpret(T, exponent_mask(T) | significand_mask(T)) + +# ───────────────────────── generic Base methods ────────────────────────── Base.typemin(::Type{T}) where T<:Microfloat{0} = zero(T) Base.typemin(::Type{T}) where T<:Microfloat = hasinf(T) ? -inf(T) : -floatmax(T) Base.typemax(::Type{T}) where T<:Microfloat = hasinf(T) ? inf(T) : floatmax(T) -Base.floatmin(::Type{T}) where T<:Microfloat = exponent_bits(T) > 1 ? reinterpret(T, one(UInt8) << significand_bits(T)) : throw(DomainError(T, "$T has no normal numbers")) -Base.floatmax(::Type{T}) where T<:Microfloat = reinterpret(T, exponent_mask(T) | significand_mask(T)) +Base.floatmin(::Type{T}) where T<:Microfloat = + significand_bits(T) == 0 ? reinterpret(T, 0x00) : + reinterpret(T, UInt8(0x01) << significand_bits(T)) Base.zero(::Type{T}) where T<:Microfloat = reinterpret(T, 0x00) Base.one(::Type{T}) where T<:Microfloat = T(true) -Base.eps(x::Microfloat) = max(x-prevfloat(x), nextfloat(x)-x) +Base.eps(x::Microfloat) = max(x - prevfloat(x), nextfloat(x) - x) Base.eps(T::Type{<:Microfloat}) = eps(one(T)) Base.abs(x::T) where T<:Microfloat = reinterpret(T, reinterpret(Unsigned, x) & ~sign_mask(T)) -Base.iszero(x::T) where T<:Microfloat = abs(x) === zero(T) +Base.iszero(x::T) where T<:Microfloat = significand_bits(T) == 0 ? 
false : abs(x) === zero(T) +Base.:(-)(x::T) where T<:Microfloat{0} = throw(DomainError(x, "cannot negate unsigned $T")) Base.:(-)(x::T) where T<:Microfloat = reinterpret(T, sign_mask(T) ⊻ reinterpret(Unsigned, x)) Base.Bool(x::T) where T<:Microfloat = iszero(x) ? false : isone(x) ? true : throw(InexactError(:Bool, Bool, x)) @@ -57,7 +158,8 @@ Base.sign(x::Microfloat) = ifelse(isnan(x) | iszero(x), x, ifelse(signbit(x), -o Base.round(x::T, r::RoundingMode; kws...) where T<:Microfloat = T(round(Float32(x), r; kws...)) -Base.issubnormal(x::T) where T<:Microfloat = 0x00 < (reinterpret(Unsigned, x) & ~sign_mask(T)) <= (0x01 << significand_bits(T)) - 0x01 +Base.issubnormal(x::T) where T<:Microfloat = + 0x00 < (reinterpret(Unsigned, x) & ~sign_mask(T)) <= (UInt8(0x01) << significand_bits(T)) - 0x01 ispositive(x::T) where T<:Microfloat = iszero(reinterpret(Unsigned, x) & sign_mask(T)) @@ -81,9 +183,12 @@ function Base.prevfloat(x::T) where T<:Microfloat elseif isinf(x) return ispositive(x) ? floatmax(T) : -inf(T) elseif iszero(x) + sign_bits(T) == 0 && return x return reinterpret(T, sign_mask(T) | 0x01) elseif ispositive(x) - return reinterpret(T, reinterpret(Unsigned, x) - 0x01) + raw = reinterpret(Unsigned, x) + raw == 0x00 && sign_bits(T) == 0 && return x + return reinterpret(T, raw - 0x01) else return reinterpret(T, reinterpret(Unsigned, x) + 0x01) end @@ -96,4 +201,3 @@ Base.widen(::Type{T}) where T<:Microfloat = BFloat16 Base.promote_rule(::Type{M}, ::Type{T}) where {M<:Microfloat,T<:Union{BFloat16,Float16,Float32,Float64}} = T Base.promote_rule(::Type{M}, ::Type{T}) where {M<:Microfloat,T<:Integer} = M Base.promote_rule(::Type{M}, ::Type{M}) where {M<:Microfloat} = M -Base.promote_rule(::Type{<:Microfloat}, ::Type{<:Microfloat}) = BFloat16 diff --git a/src/Microfloats.jl b/src/Microfloats.jl index 9069656..f3574fa 100644 --- a/src/Microfloats.jl +++ b/src/Microfloats.jl @@ -1,80 +1,63 @@ module Microfloats -import BFloat16s: BFloat16 -export BFloat16 +using Republic 
-include("Microfloat.jl") -export Microfloat +@reexport import BFloat16s: BFloat16 -include("variants/Finite.jl") -export Finite +float_bits(::Type{Float64}) = (1, 11, 52) +float_bits(::Type{Float32}) = (1, 8, 23) +float_bits(::Type{Float16}) = (1, 5, 10) +float_bits(::Type{BFloat16}) = (1, 8, 7) -include("variants/IEEE_754_like.jl") -export IEEE_754_like -export Float8_E5M2 -export Float8_E4M3 -export Float8_E3M4 -export Float6_E3M2 -export Float6_E2M3 -export Float4_E2M1 +bitwidth(::Type{T}) where T<:AbstractFloat = sum(float_bits(T)) +sign_bits(::Type{T}) where T<:AbstractFloat = float_bits(T)[1] +exponent_bits(::Type{T}) where T<:AbstractFloat = float_bits(T)[2] +significand_bits(::Type{T}) where T<:AbstractFloat = float_bits(T)[3] -include("variants/MX.jl") -export MX -export MX_E5M2 -export MX_E4M3 -export MX_E3M2 -export MX_E2M3 -export MX_E2M1 -export MX_E8M0 +exponent_bias(::Type{T}) where T<:AbstractFloat = 2^(exponent_bits(T) - 1) - 1 -# something weird happens when the @generated BFloat16 -# method is put before the other includes. -# presumably precompilation is being excessively greedy -include("conversion.jl") -export OVF, SAT +@public bitwidth, sign_bits, exponent_bits, significand_bits -include("ops.jl") -include("show.jl") -include("random.jl") +include("Microfloat.jl") +export Microfloat +@public sign_mask, exponent_mask, significand_mask +@public NonFiniteBehavior, non_finite_behavior, hasinf, hasnan, inf, nan +export IEEE, NanOnlyAllOnes, FiniteOnly -""" - Microfloat{S,E,M,V} +include("conversion.jl") +export OverflowPolicy, SAT, OVF +@public default_overflow_policy -A `Microfloat` type has `S` sign bits (between 0 and 1), -`E` exponent bits (between 1 and 8), and `M` mantissa bits (between 0 and 7). -""" -Microfloat +include("macros.jl") +export @microfloat -""" - SAT -""" -SAT +# Each `@microfloat` call builds a per-type BFloat16 lookup table, +# so conversion.jl must be loaded before this point. 
+include("variants.jl") +export Float8_E5M2, Float8_E4M3, Float8_E3M4 +export Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU +export Float6_E2M3FN, Float6_E3M2FN +export Float4_E2M1FN -""" - OVF -""" -OVF +include("ops.jl") +include("random.jl") for T in ( - :Float8_E5M2, :Float8_E4M3, :Float8_E3M4, :Float6_E3M2, :Float6_E2M3, :Float4_E2M1, - :MX_E5M2, :MX_E4M3, :MX_E3M2, :MX_E2M3, :MX_E2M1, :MX_E8M0, + :Float8_E4M3FN, :Float8_E5M2, :Float8_E8M0FNU, + :Float6_E2M3FN, :Float6_E3M2FN, + :Float4_E2M1FN, ) - @eval begin - @doc """ - $($T) - - ## Properties - - Bits: `$(sign_bits($T))` sign + `$(exponent_bits($T))` exponent + `$(significand_bits($T))` significand (`$(total_bits($T))` total) - - Variant: `$(variant($T))` - - Has Inf: `$(hasinf($T))` - - Has NaN: `$(hasnan($T))` - - Max normal: `$(Float64(floatmax($T)))` - - Min normal: `$(Float64(floatmin($T)))` - - Max subnormal: `$(significand_bits($T) > 0 ? Float64(prevfloat(floatmin($T))) : "N/A")` - - Min subnormal: `$(significand_bits($T) > 0 ? Float64(nextfloat(zero($T))) : "N/A")` - """ - $T - end + @eval @doc """ + $($T) + + ## Properties + - Bits: `$(sign_bits($T))` sign + `$(exponent_bits($T))` exponent + `$(significand_bits($T))` significand (`$(bitwidth($T))` total) + - Non-finite behavior: `$(non_finite_behavior($T))` + - Has Inf: `$(hasinf($T))` + - Has NaN: `$(hasnan($T))` + - Max normal: `$(Float64(floatmax($T)))` + - Min positive: `$(Float64(floatmin($T)))` + """ $T end end diff --git a/src/conversion.jl b/src/conversion.jl index 1bd0fb5..b2e1949 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -1,5 +1,21 @@ +""" + OverflowPolicy + +Policy controlling how out-of-range finite input is mapped to a [`Microfloat`](@ref). + +- [`SAT`](@ref): saturate to `±floatmax(T)`. +- [`OVF`](@ref): overflow to the type's sentinel (`±Inf` for `IEEE`, + `NaN` for `NanOnlyAllOnes`; throws for `FiniteOnly` — no sentinel exists). 
+ +Defaults are resolved by [`default_overflow_policy`](@ref) from the type's +[`non_finite_behavior`](@ref): `IEEE` → `OVF`, otherwise `SAT`. +""" abstract type OverflowPolicy end + +"""Saturating conversion: out-of-range finite inputs clamp to `±floatmax(T)`.""" abstract type SAT <: OverflowPolicy end + +"""Sentinel overflow: out-of-range finite inputs go to `±Inf` (IEEE) or `NaN` (NanOnlyAllOnes).""" abstract type OVF <: OverflowPolicy end function rshift_round_to_even(x::UInt16, n::Int) @@ -11,14 +27,15 @@ function rshift_round_to_even(x::UInt16, n::Int) UInt16((x_32 >> n) + (up ? 1 : 0)) end -is_outside_floatmax(xb::BFloat16, ::Type{T}) where T<:Microfloat = reinterpret(Unsigned, abs(xb)) > reinterpret(Unsigned, BFloat16(floatmax(T))) +is_outside_floatmax(xb::BFloat16, ::Type{T}) where T<:Microfloat = + reinterpret(Unsigned, abs(xb)) > reinterpret(Unsigned, BFloat16(floatmax(T))) clamp_floatmax(x::T) where T<:Microfloat = signbit(x) ? -floatmax(T) : floatmax(T) clamp_inf(x::T) where T<:Microfloat = signbit(x) ? -inf(T) : inf(T) function epilogue(x::T, xb::BFloat16, ::Type{P}) where {T<:Microfloat,P<:OverflowPolicy} if P <: SAT if isnan(xb) - return nan(T) + return hasnan(T) ? nan(T) : throw(DomainError(xb, "$T has no NaN")) elseif isinf(xb) || is_outside_floatmax(xb, T) return clamp_floatmax(x) else @@ -26,18 +43,33 @@ function epilogue(x::T, xb::BFloat16, ::Type{P}) where {T<:Microfloat,P<:Overflo end elseif P <: OVF if isnan(xb) - return nan(T) + return hasnan(T) ? nan(T) : throw(DomainError(xb, "$T has no NaN")) elseif isinf(xb) || is_outside_floatmax(xb, T) - return hasinf(T) ? clamp_inf(x) : nan(T) + return hasinf(T) ? clamp_inf(x) : + hasnan(T) ? nan(T) : + throw(DomainError(xb, "$T has no overflow sentinel; use SAT")) else return x end + else + throw(ArgumentError("Unknown overflow policy $P")) end end -default_overflow_policy(::Type{T}) where T = hasnan(T) ? 
OVF : SAT +""" + default_overflow_policy(T) -> Type{<:OverflowPolicy} + +Default overflow policy for `Microfloat` type `T`. Keys on `hasinf(T)`: +IEEE types default to `OVF` (finite overflow → Inf), all others default +to `SAT` (clamp to `floatmax`). Matches PyTorch/Triton/Quartet-II practice +for FP8/FP4. +""" +default_overflow_policy(::Type{T}) where T<:Microfloat = hasinf(T) ? OVF : SAT function (::Type{T})(x::BFloat16, ::Type{P}=default_overflow_policy(T)) where {T<:Microfloat,P<:OverflowPolicy} + if sign_bits(T) == 0 && signbit(x) + throw(DomainError(x, "negative input to unsigned $T")) + end iszero(x) && return zero(T) bf16_exp = Int((reinterpret(Unsigned, x) >> 7) & 0x00ff) @@ -65,18 +97,22 @@ function (::Type{T})(x::BFloat16, ::Type{P}=default_overflow_policy(T)) where {T # Normal path in target format shift = 7 - significand_bits(T) total = shift >= 0 ? rshift_round_to_even(sig8, shift) : (sig8 << (-shift)) - t_exp_rounded = t_exp + (total >> (significand_bits(T) + 1)) - max_exp = (1 << exponent_bits(T)) - 1 - if t_exp_rounded > max_exp - if !hasinf(T) - t_exp_rounded = max_exp - total = (UInt16(1) << significand_bits(T)) | UInt16((1 << significand_bits(T)) - 1) - else - t_exp_rounded = max_exp + if total == 0 + t_raw = 0x00 + else + t_exp_rounded = t_exp + (total >> (significand_bits(T) + 1)) + max_exp = (1 << exponent_bits(T)) - 1 + if t_exp_rounded > max_exp + if !hasinf(T) + t_exp_rounded = max_exp + total = (UInt16(1) << significand_bits(T)) | UInt16((1 << significand_bits(T)) - 1) + else + t_exp_rounded = max_exp + end end + frac_field = UInt8(total) & UInt8((1 << significand_bits(T)) - 1) + t_raw = (UInt8(t_exp_rounded) << significand_bits(T)) | frac_field end - frac_field = UInt8(total) & UInt8((1 << significand_bits(T)) - 1) - t_raw = (UInt8(t_exp_rounded) << significand_bits(T)) | frac_field end t_raw |= (reinterpret(Unsigned, x) >> 15 % UInt8) << (exponent_bits(T) + significand_bits(T)) & sign_mask(T) @@ -87,7 +123,7 @@ end 
(::Type{T})(x::Number, args...) where {T<:Microfloat} = T(BFloat16(x), args...) (::Type{T})(::Type{P}) where {T<:Microfloat,P<:OverflowPolicy} = x -> T(x, P) -function to_bfloat16(x::T) where {T<:Microfloat} +function _to_bfloat16(x::T) where {T<:Microfloat} t_raw = reinterpret(UInt8, x) t_sign = (sign_bits(T) == 1) && (t_raw & (UInt8(1) << (exponent_bits(T) + significand_bits(T))) != 0) @@ -96,7 +132,6 @@ function to_bfloat16(x::T) where {T<:Microfloat} bf16_sign_bit = UInt16(t_sign ? 1 : 0) << 15 - # Check for special values first, using trait-aware detection if isinf(x) return reinterpret(BFloat16, bf16_sign_bit | 0x7f80) elseif isnan(x) @@ -108,7 +143,7 @@ function to_bfloat16(x::T) where {T<:Microfloat} M = significand_bits(T) bias = exponent_bias(T) - if t_exponent_field == 0 # Subnormal + if t_exponent_field == 0 && M > 0 # Subnormal nlz = M - 1 - floor(Int, log2(t_fraction_field)) t_significand_total = UInt16(t_fraction_field) << (nlz + 1) t_true_exponent = -nlz - bias @@ -117,7 +152,6 @@ function to_bfloat16(x::T) where {T<:Microfloat} t_true_exponent = t_exponent_field - bias end - # Common path for conversion to BFloat16 bf16_exponent_field = t_true_exponent + 127 bf16_significand_total = if M >= 7 rshift_round_to_even(t_significand_total, M - 7) @@ -146,9 +180,7 @@ function to_bfloat16(x::T) where {T<:Microfloat} end end -@generated function BFloat16(x::T) where T<:Microfloat - lookup = Tuple(to_bfloat16(reinterpret(T, i % UInt8)) for i in 0:2^total_bits(T)-1) - :($lookup[reinterpret(UInt8, x) + 0x0001]) -end +to_bfloat16(x::T) where T<:Microfloat = _to_bfloat16(x) +BFloat16(x::T) where T<:Microfloat = to_bfloat16(x) (::Type{T})(x::Microfloat) where {T<:AbstractFloat} = T(BFloat16(x)) diff --git a/src/macros.jl b/src/macros.jl new file mode 100644 index 0000000..b44860e --- /dev/null +++ b/src/macros.jl @@ -0,0 +1,63 @@ +""" + @microfloat Name sign=1 exponent=E significand=M nonfinite=Trait + +Declare an 8-bit `primitive type Name <: 
Microfloat{sign,E,M} 8 end` and +register its [`non_finite_behavior`](@ref) as `Trait`. + +All keyword arguments are passed positionally as `name=value` pairs: + +- `sign` — `0` or `1`. Default `1`. +- `exponent` — required, `≥ 1`. +- `significand` — required, `≥ 0`. +- `nonfinite` — required. One of `IEEE`, `NanOnlyAllOnes`, `FiniteOnly`. + +Hand-written equivalent (also supported): + +```julia +primitive type Name <: Microfloat{sign,exponent,significand} 8 end +non_finite_behavior(::Type{Name}) = Trait +``` +""" +macro microfloat(name, kwargs...) + mod = @__MODULE__ + S = 1 + E = nothing + M = nothing + behavior = IEEE + + for kw in kwargs + (kw isa Expr && kw.head == :(=)) || + error("@microfloat: expected keyword arguments (e.g. exponent=5), got $kw") + k, v = kw.args + if k == :sign + S = v + elseif k == :exponent + E = v + elseif k == :significand + M = v + elseif k == :nonfinite + behavior = v + else + error("@microfloat: unknown keyword `$k`") + end + end + + E === nothing && error("@microfloat: `exponent` is required") + M === nothing && error("@microfloat: `significand` is required") + behavior isa NonFiniteBehavior && error("@microfloat: `nonfinite` is required") + S in (0, 1) || error("@microfloat: `sign` must be 0 or 1, got $S") + E >= 1 || error("@microfloat: `exponent` must be >= 1, got $E") + M >= 0 || error("@microfloat: `significand` must be >= 0, got $M") + S + E + M <= 8 || error("@microfloat: `sign + exponent + significand` must be <= 8, got $(S + E + M)") + + T = esc(name) + trait = esc(behavior) + N = S + E + M + quote + primitive type $T <: $mod.Microfloat{$S,$E,$M} 8 end + $mod.non_finite_behavior(::Type{$T}) = $trait + let lookup = Tuple($mod._to_bfloat16(reinterpret($T, i % UInt8)) for i in 0:$(2^N - 1)) + $mod.to_bfloat16(x::$T) = lookup[reinterpret(UInt8, x) + 0x0001] + end + end +end diff --git a/src/ops.jl b/src/ops.jl index 077c154..ff7689d 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -7,8 +7,10 @@ end import Base: (+), (-), (*), (/), 
(\), (^) +# Same-type binary ops. Cross-microfloat ops are intentionally unsupported — +# callers must explicitly cast to a wider type first for op in (:+, :-, :*, :/, :\, :^) - @eval ($op)(a::Microfloat, b::Microfloat) = promote_type(typeof(a), typeof(b))(($op)(Float32(a), Float32(b))) + @eval ($op)(a::T, b::T) where T<:Microfloat = T(($op)(Float32(a), Float32(b))) end (^)(a::T, b::Integer) where T<:Microfloat = T(Float32(a)^b) diff --git a/src/random.jl b/src/random.jl index c57d901..7f0e761 100644 --- a/src/random.jl +++ b/src/random.jl @@ -6,8 +6,9 @@ function Base.rand(rng::Random.AbstractRNG, ::Random.SamplerTrivial{Random.Close end # Standard normal sampling for signed Microfloats -function Base.randn(rng::Random.AbstractRNG, ::Type{T}) where {S,E,M,V,T<:Microfloat{S,E,M,V}} - S == 0 && throw(ArgumentError("randn is undefined for unsigned microfloats (no sign bit)")) +Base.randn(::Random.AbstractRNG, ::Type{T}) where T<:Microfloat = + throw(ArgumentError("randn is undefined for unsigned microfloats (must have 1 sign bit)")) +function Base.randn(rng::Random.AbstractRNG, ::Type{T}) where T<:Microfloat{1} z = randn(rng, Float32) b = Float32(floatmax(T)) return T(clamp(z, -b, b)) diff --git a/src/show.jl b/src/show.jl deleted file mode 100644 index d25922f..0000000 --- a/src/show.jl +++ /dev/null @@ -1,8 +0,0 @@ -function Base.show(io::IO, x::T) where T <: Microfloat - show_typeinfo = get(IOContext(io), :typeinfo, nothing) != T - type = repr(T) - show_typeinfo && print(io, type, "(") - print(io, Float64(x)) - show_typeinfo && print(io, ")") - return nothing -end diff --git a/src/variants.jl b/src/variants.jl new file mode 100644 index 0000000..c9b2ae9 --- /dev/null +++ b/src/variants.jl @@ -0,0 +1,13 @@ +# IEEE-like +@microfloat Float8_E5M2 exponent=5 significand=2 +@microfloat Float8_E4M3 exponent=4 significand=3 +@microfloat Float8_E3M4 exponent=3 significand=4 + +# NanOnlyAllOnes (FN-suffixed) +@microfloat Float8_E4M3FN exponent=4 significand=3 
nonfinite=NanOnlyAllOnes +@microfloat Float8_E8M0FNU sign=0 exponent=8 significand=0 nonfinite=NanOnlyAllOnes + +# FiniteOnly +@microfloat Float6_E2M3FN exponent=2 significand=3 nonfinite=FiniteOnly +@microfloat Float6_E3M2FN exponent=3 significand=2 nonfinite=FiniteOnly +@microfloat Float4_E2M1FN exponent=2 significand=1 nonfinite=FiniteOnly diff --git a/src/variants/Finite.jl b/src/variants/Finite.jl deleted file mode 100644 index 2b1f44a..0000000 --- a/src/variants/Finite.jl +++ /dev/null @@ -1,6 +0,0 @@ -""" - Finite - -A variant of the `Microfloat` type that supports finite values. -""" -abstract type Finite end diff --git a/src/variants/IEEE_754_like.jl b/src/variants/IEEE_754_like.jl deleted file mode 100644 index 93d7b94..0000000 --- a/src/variants/IEEE_754_like.jl +++ /dev/null @@ -1,24 +0,0 @@ -""" - IEEE_754_like -""" -abstract type IEEE_754_like end - -const IEEEFloat{S,E,M} = Microfloat{S,E,M,IEEE_754_like} - -hasinf(::Type{<:IEEEFloat}) = true -hasnan(::Type{<:IEEEFloat}) = true - -inf(::Type{T}) where T<:IEEEFloat = reinterpret(T, exponent_mask(T)) -nan(::Type{T}) where T<:IEEEFloat = reinterpret(T, exponent_mask(T) | 0x01 << (significand_bits(T) - 1)) - -Base.isinf(x::T) where T<:IEEEFloat = reinterpret(Unsigned, x) & exponent_mask(T) == exponent_mask(T) && iszero(reinterpret(Unsigned, x) & significand_mask(T)) -Base.isnan(x::T) where T<:IEEEFloat = reinterpret(Unsigned, x) & exponent_mask(T) == exponent_mask(T) && !iszero(reinterpret(Unsigned, x) & significand_mask(T)) - -Base.floatmax(::Type{T}) where T<:IEEEFloat = reinterpret(T, exponent_mask(T) - 0x01 << significand_bits(T) | significand_mask(T)) - -const Float8_E3M4 = IEEEFloat{1,3,4} -const Float8_E4M3 = IEEEFloat{1,4,3} -const Float8_E5M2 = IEEEFloat{1,5,2} -const Float6_E2M3 = IEEEFloat{1,2,3} -const Float6_E3M2 = IEEEFloat{1,3,2} -const Float4_E2M1 = IEEEFloat{1,2,1} diff --git a/src/variants/MX.jl b/src/variants/MX.jl deleted file mode 100644 index 9c853ed..0000000 --- 
a/src/variants/MX.jl +++ /dev/null @@ -1,31 +0,0 @@ -""" - MX -""" -abstract type MX end - -const MXFloat{S,E,M} = Microfloat{S,E,M,MX} - -const MX_E5M2 = MXFloat{1,5,2} -const MX_E4M3 = MXFloat{1,4,3} -const MX_E3M2 = MXFloat{1,3,2} -const MX_E2M3 = MXFloat{1,2,3} -const MX_E2M1 = MXFloat{1,2,1} -const MX_E8M0 = MXFloat{0,8,0} - -hasinf(::Type{MX_E5M2}) = true -hasnan(::Type{MX_E5M2}) = true -inf(::Type{T}) where T<:MX_E5M2 = reinterpret(T, exponent_mask(T)) -nan(::Type{T}) where T<:MX_E5M2 = reinterpret(T, exponent_mask(T) | 0x01 << (significand_bits(T) - 1)) -Base.isinf(x::T) where T<:MX_E5M2 = hasinf(T) && reinterpret(Unsigned, x) & exponent_mask(T) == exponent_mask(T) && iszero(reinterpret(Unsigned, x) & significand_mask(T)) -Base.isnan(x::T) where T<:MX_E5M2 = reinterpret(Unsigned, x) & exponent_mask(T) == exponent_mask(T) && !iszero(reinterpret(Unsigned, x) & significand_mask(T)) -Base.floatmax(::Type{T}) where T<:MX_E5M2 = reinterpret(T, exponent_mask(T) - 0x01 << significand_bits(T) | significand_mask(T)) - -hasnan(::Type{MX_E4M3}) = true -nan(::Type{T}) where T<:MX_E4M3 = reinterpret(T, exponent_mask(T) | significand_mask(T)) -Base.isnan(x::T) where T<:MX_E4M3 = reinterpret(Unsigned, x) & ~sign_mask(T) == (exponent_mask(T) | significand_mask(T)) -Base.floatmax(::Type{T}) where T<:MX_E4M3 = reinterpret(T, exponent_mask(T) | (significand_mask(T) - 0x01)) - -hasnan(::Type{MX_E8M0}) = true -nan(::Type{T}) where T<:MX_E8M0 = reinterpret(T, 0xff) -Base.isnan(x::T) where T<:MX_E8M0 = reinterpret(Unsigned, x) == 0xff -Base.floatmax(::Type{T}) where T<:MX_E8M0 = reinterpret(T, 0xfe) diff --git a/test/Float8s/runtests.jl b/test/Float8s/runtests.jl deleted file mode 100644 index 10c59b5..0000000 --- a/test/Float8s/runtests.jl +++ /dev/null @@ -1,39 +0,0 @@ -# check parity with the Float8s.jl package - -using Float8s: Float8, Float8_4 - -@testset "Float8s.jl parity" begin - - @testset "E3M4" begin - - @testset for i in 0x00:0xfe - @test 
Float8_E3M4(Float32(reinterpret(Float8, i))) ≡ reinterpret(Float8_E3M4, i) - - @test Float8(Float32(reinterpret(Float8_E3M4, i))) ≡ reinterpret(Float8, i) - - @test Float32(reinterpret(Float8, i)) ≡ - Float32(reinterpret(Float8_E3M4, i)) - - @test Float32(Float8(Float32(reinterpret(Float8, i)))) ≡ - Float32(Float8_E3M4(Float32(reinterpret(Float8_E3M4, i)))) - end - - end - - @testset "E4M3" begin - - @testset for i in 0x00:0xfe - @test Float8_E4M3(Float32(reinterpret(Float8_4, i))) ≡ reinterpret(Float8_E4M3, i) - - @test Float8_4(Float32(reinterpret(Float8_E4M3, i))) ≡ reinterpret(Float8_4, i) - - @test Float32(reinterpret(Float8_4, i)) ≡ - Float32(reinterpret(Float8_E4M3, i)) - - @test Float32(Float8_4(Float32(reinterpret(Float8_4, i)))) ≡ - Float32(Float8_E4M3(Float32(reinterpret(Float8_E4M3, i)))) - end - - end - -end diff --git a/test/MX/runtests.jl b/test/MX/runtests.jl deleted file mode 100644 index 9a9aca6..0000000 --- a/test/MX/runtests.jl +++ /dev/null @@ -1,4 +0,0 @@ -@testset "MX" begin - include("MX_compliance.jl") - include("MX_properties.jl") -end \ No newline at end of file diff --git a/test/Microfloat.jl b/test/Microfloat.jl deleted file mode 100644 index 79d0c61..0000000 --- a/test/Microfloat.jl +++ /dev/null @@ -1,151 +0,0 @@ -using Test -using Microfloats -using Random - -const TYPES = [ - Microfloat{0, 3, 4, IEEE_754_like}, - Microfloat{0, 4, 3, IEEE_754_like}, - Microfloat{0, 3, 3, IEEE_754_like}, - Microfloat{0, 4, 2, IEEE_754_like}, - Microfloat{0, 5, 1, IEEE_754_like}, - Microfloat{0, 3, 2, IEEE_754_like}, - Microfloat{0, 2, 3, IEEE_754_like}, - Microfloat{0, 2, 2, IEEE_754_like}, - Microfloat{0, 3, 1, IEEE_754_like}, - Microfloat{0, 1, 3, IEEE_754_like}, - Microfloat{0, 2, 1, IEEE_754_like}, - Microfloat{1, 3, 4, IEEE_754_like}, - Microfloat{1, 4, 3, IEEE_754_like}, - Microfloat{1, 3, 3, IEEE_754_like}, - Microfloat{1, 4, 2, IEEE_754_like}, - Microfloat{1, 5, 1, IEEE_754_like}, - Microfloat{1, 3, 2, IEEE_754_like}, - Microfloat{1, 2, 3, 
IEEE_754_like}, - Microfloat{1, 2, 2, IEEE_754_like}, - Microfloat{1, 3, 1, IEEE_754_like}, - Microfloat{1, 1, 3, IEEE_754_like}, - Microfloat{1, 2, 1, IEEE_754_like}, -] - -@testset "Microfloat" begin - - @testset for T in TYPES - @test hash(one(T)) == hash(1) - - @test prevfloat(eps(T)) < eps(T) - @test nextfloat(eps(T)) > eps(T) - @test nextfloat(zero(T)) > zero(T) - @test isfinite(prevfloat(T(Inf))) - - if Base.exponent_bits(T) > 1 - @test floatmin(T) == reinterpret(T, 0x01 << Base.significand_bits(T)) - else - @test_throws DomainError floatmin(T) - end - @test floatmax(T) == prevfloat(T(Inf)) - - @test typemax(T) == T(Inf) - - @test sign(T(Inf)) == 1.0 - @test sign(T(1.0)) == 1.0 - @test sign(T(0.0)) == 0.0 - @test isnan(sign(T(NaN))) - - if T <: Microfloat{1} - @test typemin(T) == T(-Inf) - - @test sign(T(-0.0)) == -0.0 - @test sign(T(-1.0)) == -1.0 - @test sign(T(-Inf)) == -1.0 - else - @test typemin(T) == zero(T) - end - - @test precision(T) == Base.significand_bits(T) + 1 - end -end - -@testset "IEEE microfloats: subnormals and rounding" begin - @testset for T in TYPES - @test !issubnormal(zero(T)) - @test issubnormal(nextfloat(zero(T))) - - if Base.exponent_bits(T) > 1 - @test issubnormal(prevfloat(floatmin(T))) - @test !issubnormal(floatmin(T)) - end - - min_sub_u = 0x01 - min_sub = reinterpret(T, min_sub_u) - - # Real values - min_sub_val = BFloat16(2.0)^(1 - Microfloats.exponent_bias(T) - Base.significand_bits(T)) - half = min_sub_val/2 - just_below_half = prevfloat(half) - just_above_half = nextfloat(half) - just_below = prevfloat(min_sub_val) - just_above = nextfloat(min_sub_val) - - # Exact min subnormal - @test BFloat16(min_sub) == min_sub_val - - # Values well below half of min subnormal should round to +0 - @test T(half/4) == zero(T) - - # Exactly half rounds to even -> zero; below half also zero - @test T(half) == zero(T) - @test T(just_below_half) == zero(T) - - # Values just above half of min subnormal should round to min subnormal - @test 
T(just_above_half) == min_sub - - # Values just below min subnormal remain min subnormal after rounding up from BFloat16 - @test T(just_below) == min_sub - - # Values just above min subnormal quantize to min subnormal or the next representable - # depending on spacing; at least should be >= min_sub - @test BFloat16(T(just_above)) >= min_sub_val - end -end - -@testset "IEEE microfloats: monotonic Float32 mapping (canonical encodings)" begin - @testset for T in TYPES - vals = Tuple{UInt8,Float32,Any}[] - for u in UInt8(0):UInt8(0xff) - x = reinterpret(T, u) - isnan(x) && continue - # Only include canonical encodings: padding bits outside fields are zero - used_mask = Microfloats.sign_mask(T) | Base.exponent_mask(T) | Base.significand_mask(T) - ((u & ~used_mask) != 0x00) && continue - push!(vals, (u, Float32(x), x)) - end - sort!(vals, by = t -> t[2]) - for i in 1:length(vals)-1 - a = vals[i]; b = vals[i+1] - if a[2] == b[2] - # duplicate comes only from signed zeros - @test iszero(a[3]) && iszero(b[3]) - else - @test a[2] < b[2] - end - end - end -end - -@testset "IEEE microfloats: rand and randn" begin - rng = MersenneTwister(123) - @testset for T in TYPES - @testset "$T rand()" begin - xs = rand(rng, T, 1000) - @test all(x -> isfinite(x), xs) - @test any(x -> x != zero(T), xs) # likely non-degenerate - end - if Microfloats.sign_bits(T) == 1 - @testset "$T randn()" begin - xs = randn(rng, T, 1000) - @test all(isfinite, xs) - @test any(x -> x != zero(T), xs) - end - end - end -end diff --git a/test/Project.toml b/test/Project.toml index e56b2ae..ca75f0d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,5 @@ [deps] -BitPacking = "b58c8408-13c4-4787-8733-7038ae624acf" -Float8s = "81dfefd7-55b0-40c6-a251-db853704e186" +DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c" +Microfloats = "31c70f10-a750-4521-b13c-797315ae2933" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/basic.jl 
b/test/basic.jl new file mode 100644 index 0000000..77866b2 --- /dev/null +++ b/test/basic.jl @@ -0,0 +1,124 @@ +@testset "Microfloat generic properties" begin + @testset "$T" for T in TYPES + @test hash(one(T)) == hash(1) + @test precision(T) == Microfloats.significand_bits(T) + 1 + @test floatmax(T) > zero(T) + @test isfinite(floatmax(T)) + + @test signbit(zero(T)) == false + + if hasinf(T) + @test typemax(T) == inf(T) + @test floatmax(T) == prevfloat(inf(T)) + else + @test typemax(T) == floatmax(T) + end + + if sign_bits(T) == 1 + @test typemin(T) == (hasinf(T) ? -inf(T) : -floatmax(T)) + else + @test typemin(T) == zero(T) + end + + # nextfloat/prevfloat consistency + @test nextfloat(zero(T)) > zero(T) + sign_bits(T) == 1 && @test prevfloat(zero(T)) < zero(T) + + # round-trip: every canonical bit pattern goes out to BFloat16 and back + used_mask = Microfloats.sign_mask(T) | Microfloats.exponent_mask(T) | Microfloats.significand_mask(T) + @testset "round-trip" for u in 0x00:0xff + (u & ~used_mask) != 0x00 && continue + x = reinterpret(T, u) + isnan(x) && continue + @test T(BFloat16(x)) ≡ x + @test T(Float32(x)) ≡ x + end + end +end + +@testset "NonFiniteBehavior trait" begin + @test hasinf(IEEE) && hasnan(IEEE) + @test !hasinf(NanOnlyAllOnes) && hasnan(NanOnlyAllOnes) + @test !hasinf(FiniteOnly) && !hasnan(FiniteOnly) + + @test non_finite_behavior(Float8_E5M2) === IEEE + @test non_finite_behavior(Float8_E4M3FN) === NanOnlyAllOnes + @test non_finite_behavior(Float8_E8M0FNU) === NanOnlyAllOnes + @test non_finite_behavior(Float4_E2M1FN) === FiniteOnly + + # Forgetting `non_finite_behavior` on a custom type errors loudly. 
+ primitive type _BadFloat <: Microfloats.Microfloat{1,2,1} 8 end + @test_throws ErrorException Microfloats.non_finite_behavior(_BadFloat) +end + +@testset "IEEE types: Inf and NaN encodings" begin + @testset "$T" for T in (Float8_E3M4, Float8_E5M2) + @test isinf(inf(T)) + @test !isnan(inf(T)) + @test isnan(nan(T)) + @test !isinf(nan(T)) + @test isinf(T(Inf)) + @test isinf(T(-Inf)) + @test isnan(T(NaN)) + end +end + +@testset "NanOnlyAllOnes types: no Inf" begin + @testset "$T" for T in (Float8_E4M3FN, Float8_E8M0FNU) + for u in 0x00:0xff + @test !isinf(reinterpret(T, u)) + end + @test_throws DomainError inf(T) + @test isnan(nan(T)) + end +end + +@testset "FiniteOnly types: no Inf or NaN" begin + @testset "$T" for T in (Float4_E2M1FN, Float6_E2M3FN, Float6_E3M2FN) + for u in 0x00:0xff + x = reinterpret(T, u) + @test !isinf(x) + @test !isnan(x) + end + @test_throws DomainError inf(T) + @test_throws DomainError nan(T) + end +end + +@testset "Unsigned microfloats" begin + @test_throws DomainError Float8_E8M0FNU(-1.0) + @test_throws DomainError Float8_E8M0FNU(-0.0) + @test_throws DomainError -one(Float8_E8M0FNU) + @test_throws ArgumentError randn(Float8_E8M0FNU) + + # Positive round-trip through BF16 is lossless for powers of 2 (the only E8M0 values) + @test Float32(Float8_E8M0FNU(1.0)) == 1.0 + @test Float32(Float8_E8M0FNU(2.0)) == 2.0 + @test Float32(Float8_E8M0FNU(0.5)) == 0.5 + @test Float32(reinterpret(Float8_E8M0FNU, 0x00)) == 2f0^-127 + @test !iszero(reinterpret(Float8_E8M0FNU, 0x00)) +end + +@testset "rand / randn" begin + rng = MersenneTwister(123) + @testset "$T rand" for T in TYPES + xs = rand(rng, T, 1000) + @test all(isfinite, xs) + @test any(x -> x != zero(T), xs) + end + @testset "$T randn" for T in SIGNED_TYPES + xs = randn(rng, T, 1000) + @test all(isfinite, xs) + @test any(x -> x != zero(T), xs) + end +end + +@testset "Cross-microfloat arithmetic is unsupported" begin + a = Float8_E4M3FN(1.0) + b = Float8_E5M2(1.0) + # No cross-microfloat promote_rule 
→ Julia's promotion machinery errors. + @test_throws ErrorException a + b + @test_throws ErrorException a * b + # Same-type still works. + @test a + a == Float8_E4M3FN(2.0) +end diff --git a/test/dlfp8_parity.jl b/test/dlfp8_parity.jl new file mode 100644 index 0000000..8818e29 --- /dev/null +++ b/test/dlfp8_parity.jl @@ -0,0 +1,28 @@ +import DLFP8Types + +@testset "DLFP8Types.jl parity" begin + # DLFP8Types.Float8_E4M3FN and .Float8_E5M2 use the same bit layout and + # non-finite semantics as ours. Every bit pattern must agree on the + # Float32 value (NaN encodings and signed zeros are allowed to differ + # between the two packages — the semantic float value is what matters). + @testset "$(M_T)" for (M_T, D_T) in ( + (Microfloats.Float8_E4M3FN, DLFP8Types.Float8_E4M3FN), + (Microfloats.Float8_E5M2, DLFP8Types.Float8_E5M2), + ) + for i in 0x00:0xff + mx = reinterpret(M_T, i) + dx = reinterpret(D_T, i) + @test Float32(mx) ≡ Float32(dx) + @test isinf(mx) == isinf(dx) + @test isnan(mx) == isnan(dx) + @test iszero(mx) == iszero(dx) + end + # Float32 → narrow goes to the same semantic value. + # DLFP8Types uses strict spec semantics (overflow → NaN for E4M3FN, + # overflow → Inf for E5M2), which is our `OVF` policy. 
+ for f in Float32[0.0, -0.0, 1.0, -1.0, 3.5, -3.5, 448.0, 1000.0, + 1.5f-5, 2.0f-7, NaN32, Inf32, -Inf32] + @test Float32(M_T(f, OVF)) ≡ Float32(D_T(f)) + end + end +end diff --git a/test/MX/MX_compliance.jl b/test/mx_compliance.jl similarity index 98% rename from test/MX/MX_compliance.jl rename to test/mx_compliance.jl index db97c03..bbfb2e3 100644 --- a/test/MX/MX_compliance.jl +++ b/test/mx_compliance.jl @@ -166,7 +166,8 @@ @test isnan(reinterpret(MX_E8M0, 0b11111111)) - @test iszero(reinterpret(MX_E8M0, 0b00000000)) + @test !iszero(reinterpret(MX_E8M0, 0b00000000)) + @test Float32(reinterpret(MX_E8M0, 0b00000000)) == 2f0^-127 end end diff --git a/test/MX/MX_properties.jl b/test/mx_properties.jl similarity index 76% rename from test/MX/MX_properties.jl rename to test/mx_properties.jl index 8c513e7..f00c939 100644 --- a/test/MX/MX_properties.jl +++ b/test/mx_properties.jl @@ -1,7 +1,7 @@ @testset "MX: no Infs" begin for T in (MX_E4M3, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0) @testset "$T no isinf()" begin - for u in UInt8(0):UInt8(0xff) + for u in 0x00:0xff @test !isinf(reinterpret(T, u)) end end @@ -13,12 +13,12 @@ end @testset "E4M3" begin T = MX_E4M3 em = UInt8(Microfloats.exponent_mask(T)) - mm = UInt8(Base.significand_mask(T)) + mm = UInt8(Microfloats.significand_mask(T)) sm = UInt8(Microfloats.sign_mask(T)) - nm = Base.significand_bits(T) + nm = Microfloats.significand_bits(T) maxm = UInt8((UInt16(1) << nm) - 1) - for s in (UInt8(0), sm) - for mv in UInt8(0):maxm + for s in (0x00, sm) + for mv in 0x00:maxm m = mv & mm x = reinterpret(T, (s & sm) | em | m) if m == mm @@ -33,13 +33,13 @@ end # E3M2/E2M3/E2M1: exp=all-ones are finite; no NaN sentinel for T in (MX_E3M2, MX_E2M3, MX_E2M1) @testset "$T exp=all-ones finite" begin - em = UInt8(Base.exponent_mask(T)) + em = UInt8(Microfloats.exponent_mask(T)) sm = UInt8(Microfloats.sign_mask(T)) - nm = Base.significand_bits(T) - mm = UInt8(Base.significand_mask(T)) + nm = Microfloats.significand_bits(T) + mm = 
UInt8(Microfloats.significand_mask(T)) maxm = UInt8((UInt16(1) << nm) - 1) - for s in (UInt8(0), sm) - for mv in UInt8(0):maxm + for s in (0x00, sm) + for mv in 0x00:maxm m = mv & mm x = reinterpret(T, (s & sm) | em | m) @test isfinite(x) @@ -61,7 +61,7 @@ end @testset "MX: round-trip via Float32 preserves bits (canonical encodings)" begin for T in (MX_E4M3, MX_E5M2, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0) @testset "$T" begin - used_mask = Microfloats.sign_mask(T) | Base.exponent_mask(T) | Base.significand_mask(T) + used_mask = Microfloats.sign_mask(T) | Microfloats.exponent_mask(T) | Microfloats.significand_mask(T) for u in 0x00:0xff (u & ~used_mask) != 0x00 && continue x = reinterpret(T, u) @@ -76,24 +76,25 @@ end for T in (MX_E5M2, MX_E4M3, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0) @testset "$T" begin fmax = floatmax(T) - # +Inf/-Inf map to ±floatmax (unsigned maps both to +floatmax) + # +Inf maps to +floatmax under SAT. @test T(Inf32, SAT) == fmax + # Negative input: signed types saturate to -fmax; unsigned throws. if Microfloats.sign_bits(T) == 0 - @test T(-Inf32, SAT) == fmax + @test_throws DomainError T(-Inf32, SAT) else @test T(-Inf32, SAT) == -fmax end - # NaN maps to sentinel for E4M3/E8M0, else saturates to floatmax - if T <: Union{MX_E4M3, MX_E5M2, MX_E8M0} + # NaN input: types with NaN encoding return NaN; FiniteOnly throws. + if hasnan(T) @test isnan(T(NaN32)) else @test_throws DomainError T(NaN32) end - # Values just beyond floatmax saturate + # Values just beyond floatmax saturate (positive side). 
big = nextfloat(Float32(fmax)) @test T(big, SAT) == fmax if Microfloats.sign_bits(T) == 0 - @test T(-big, SAT) == fmax + @test_throws DomainError T(-big, SAT) else @test T(-big, SAT) == -fmax end @@ -106,7 +107,7 @@ end @testset "$T subnormal min value" begin if Microfloats.significand_bits(T) > 0 x = reinterpret(T, 0x01) - expected = Float32(2.0)^(1 - Base.exponent_bias(T) - Base.significand_bits(T)) + expected = Float32(2.0)^(1 - Microfloats.exponent_bias(T) - Microfloats.significand_bits(T)) @test Float32(x) == expected end end @@ -123,8 +124,9 @@ end end end end - @testset "E8M0 zero" begin - @test iszero(reinterpret(MX_E8M0, 0x00)) + @testset "E8M0 minimum (no zero)" begin + @test !iszero(reinterpret(MX_E8M0, 0x00)) + @test Float32(reinterpret(MX_E8M0, 0x00)) == 2f0^-127 end end @@ -140,7 +142,7 @@ end end @testset "E4M3 NaN equality" begin T = MX_E4M3 - x = reinterpret(T, Base.exponent_mask(T) | Base.significand_mask(T)) + x = reinterpret(T, Microfloats.exponent_mask(T) | Microfloats.significand_mask(T)) @test isnan(x) @test !(x == x) end @@ -156,11 +158,11 @@ end for T in (MX_E4M3, MX_E5M2, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0) @testset "$T" begin vals = Tuple{UInt8,Float32,Any}[] - for u in UInt8(0):UInt8(0xff) + for u in 0x00:0xff x = reinterpret(T, u) isnan(x) && continue # Only include canonical encodings: padding bits outside fields are zero - used_mask = UInt8(Microfloats.sign_mask(T) | Base.exponent_mask(T) | Base.significand_mask(T)) + used_mask = UInt8(Microfloats.sign_mask(T) | Microfloats.exponent_mask(T) | Microfloats.significand_mask(T)) (u & ~used_mask) != 0x00 && continue push!(vals, (u, Float32(x), x)) end diff --git a/test/overflow.jl b/test/overflow.jl index dd71cc7..513e265 100644 --- a/test/overflow.jl +++ b/test/overflow.jl @@ -1,70 +1,70 @@ -@testset "Overflow" begin +@testset "Overflow: IEEE types" begin + # Default: OVF → Inf for finite overflow, NaN input → NaN + @testset "$T" for T in (Float8_E3M4, Float8_E5M2) + @test 
default_overflow_policy(T) === OVF - @testset "Has Inf+NaN" begin - @testset for T in ( - Float8_E5M2, Float8_E4M3, Float8_E3M4, Float6_E3M2, Float6_E2M3, Float4_E2M1, - MX_E5M2, - ) - @test T(NaN, SAT) |> isnan - @test T(NaN, OVF) |> isnan + @test isnan(T(NaN)) + @test T(+Inf) == +inf(T) + @test T(-Inf) == -inf(T) - @test T(+Inf, SAT) == +floatmax(T) - @test T(-Inf, SAT) == -floatmax(T) - @test T(+Inf, OVF) == +Inf - @test T(-Inf, OVF) == -Inf - - greater_than_floatmax = nextfloat(BFloat16(floatmax(T))) - @test T(+greater_than_floatmax, SAT) == +floatmax(T) - @test T(-greater_than_floatmax, SAT) == -floatmax(T) - @test T(+greater_than_floatmax, OVF) == +Inf - @test T(-greater_than_floatmax, OVF) == -Inf - end + big = nextfloat(BFloat16(floatmax(T))) + @test T(+big) == +inf(T) # default OVF + @test T(-big) == -inf(T) + @test T(+big, SAT) == +floatmax(T) + @test T(-big, SAT) == -floatmax(T) end +end - @testset "Has NaN" begin - @testset for T in ( - MX_E4M3, MX_E8M0, - ) - @test T(NaN, SAT) |> isnan - @test T(NaN, OVF) |> isnan +@testset "Overflow: NanOnlyAllOnes types" begin + # Default: SAT → floatmax for finite overflow (matches PyTorch/Triton/Quartet) + @testset "$T" for T in (Float8_E4M3FN,) + @test default_overflow_policy(T) === SAT - @test T(+Inf, SAT) == +floatmax(T) - @test T(-Inf, SAT) == -floatmax(T) - @test T(+Inf, OVF) |> isnan - @test T(-Inf, OVF) |> isnan + @test isnan(T(NaN)) + @test T(+Inf) == +floatmax(T) + @test T(-Inf) == -floatmax(T) - greater_than_floatmax = nextfloat(BFloat16(floatmax(T))) - @test T(+greater_than_floatmax, SAT) == +floatmax(T) - @test T(-greater_than_floatmax, SAT) == -floatmax(T) - @test T(+greater_than_floatmax, OVF) |> isnan - @test T(-greater_than_floatmax, OVF) |> isnan - end + big = nextfloat(BFloat16(floatmax(T))) + @test T(+big) == +floatmax(T) # default SAT + @test T(-big) == -floatmax(T) + @test isnan(T(+big, OVF)) + @test isnan(T(-big, OVF)) end - @testset "Finite" begin - @testset for T in ( - MX_E3M2, MX_E2M3, 
MX_E2M1, - ) + # Unsigned NanOnlyAllOnes (E8M0FNU): negative input throws, large positive saturates + T = Float8_E8M0FNU + @test default_overflow_policy(T) === SAT + @test T(+Inf) == floatmax(T) + @test_throws DomainError T(-Inf) + big = nextfloat(BFloat16(floatmax(T))) + @test T(big) == floatmax(T) + @test isnan(T(big, OVF)) + @test isnan(T(NaN)) +end - @test_throws DomainError T(NaN, SAT) - @test_throws DomainError T(NaN, OVF) +@testset "Overflow: FiniteOnly types" begin + # Default: SAT. NaN input always throws (no sentinel). OVF also throws on overflow. + @testset "$T" for T in (Float4_E2M1FN, Float6_E2M3FN, Float6_E3M2FN) + @test default_overflow_policy(T) === SAT - @test T(+Inf, SAT) == +floatmax(T) - @test T(-Inf, SAT) == -floatmax(T) - @test_throws DomainError T(+Inf, OVF) - @test_throws DomainError T(-Inf, OVF) + @test_throws DomainError T(NaN) + @test_throws DomainError T(NaN, OVF) + @test_throws DomainError T(NaN, SAT) - greater_than_floatmax = nextfloat(BFloat16(floatmax(T))) - @test T(+greater_than_floatmax, SAT) == +floatmax(T) - @test T(-greater_than_floatmax, SAT) == -floatmax(T) - @test_throws DomainError T(+greater_than_floatmax, OVF) - @test_throws DomainError T(-greater_than_floatmax, OVF) - end + @test T(+Inf) == +floatmax(T) + @test T(-Inf) == -floatmax(T) + @test_throws DomainError T(+Inf, OVF) + @test_throws DomainError T(-Inf, OVF) - @test MX_E2M1(6, SAT) == 6 - @test MX_E2M1(6, OVF) == 6 - @test MX_E2M1(7, SAT) == 6 - @test_throws DomainError MX_E2M1(7, OVF) + big = nextfloat(BFloat16(floatmax(T))) + @test T(+big) == +floatmax(T) + @test T(-big) == -floatmax(T) + @test_throws DomainError T(+big, OVF) + @test_throws DomainError T(-big, OVF) end -end \ No newline at end of file + # Specific: Float4_E2M1FN (NVFP4 value type) saturates at ±6 + @test Float32(Float4_E2M1FN(6.0)) == 6.0 + @test Float32(Float4_E2M1FN(7.0)) == 6.0 # SAT + @test_throws DomainError Float4_E2M1FN(7.0, OVF) +end diff --git a/test/runtests.jl b/test/runtests.jl index 
401d1a1..ca940cc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,24 +1,52 @@ using Microfloats +using Microfloats: non_finite_behavior, hasinf, hasnan, inf, nan, + sign_bits, bitwidth, default_overflow_policy using Test +using Random -using BitPacking +# NaN-aware equality for tests. Shadows `≡` (Base's alias for `===`) so +# round-trip tests can treat distinct NaN encodings as equivalent. Non-NaN +# values compare with `==`, so signed zeros are NOT distinguished. +≡(a, b) = isnan(a) || isnan(b) ? isnan(a) && isnan(b) : a == b -a ≡ b = isnan(a) || isnan(b) ? true : a == b +# variants used only for tests +@microfloat UFloat7_E3M4 sign=0 exponent=3 significand=4 +@microfloat UFloat7_E4M3 sign=0 exponent=4 significand=3 +@microfloat UFloat7_E5M2 sign=0 exponent=5 significand=2 +@microfloat UFloat7_E4M3FN sign=0 exponent=4 significand=3 nonfinite=NanOnlyAllOnes +@microfloat UFloat5_E2M3 sign=0 exponent=2 significand=3 nonfinite=FiniteOnly +@microfloat UFloat5_E3M2 sign=0 exponent=3 significand=2 nonfinite=FiniteOnly +@microfloat UFloat3_E2M1 sign=0 exponent=2 significand=1 nonfinite=FiniteOnly -@testset "Microfloats" begin +const SIGNED_TYPES = ( + Float8_E3M4, Float8_E4M3, Float8_E5M2, + Float8_E4M3FN, + Float6_E2M3FN, Float6_E3M2FN, + Float4_E2M1FN, +) - include("Microfloat.jl") - include("overflow.jl") +const UNSIGNED_TYPES = ( + Float8_E8M0FNU, + UFloat7_E3M4, UFloat7_E4M3, UFloat7_E5M2, + UFloat7_E4M3FN, + UFloat5_E2M3, UFloat5_E3M2, + UFloat3_E2M1, +) - include("Float8s/runtests.jl") - include("MX/runtests.jl") +const TYPES = (SIGNED_TYPES..., UNSIGNED_TYPES...) 
- @testset "BitPackingExt" begin - x = randn(Float4_E2M1, 16) - y = bitpacked(x) - @test y isa BitPackedArray - @test x == y - @test x == bitunpacked(y) - end +# OCP Microscaling Formats v1.0 aliases +const MX_E5M2 = Float8_E5M2 +const MX_E4M3 = Float8_E4M3FN +const MX_E3M2 = Float6_E3M2FN +const MX_E2M3 = Float6_E2M3FN +const MX_E2M1 = Float4_E2M1FN +const MX_E8M0 = Float8_E8M0FNU +@testset "Microfloats" begin + include("basic.jl") + include("overflow.jl") + include("mx_compliance.jl") + include("mx_properties.jl") + include("dlfp8_parity.jl") end From 5f110fe8b1ba6b0b2e90a2653a09650c80dfd44b Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 24 Apr 2026 16:39:17 +0200 Subject: [PATCH 2/8] fix sign of zero getting lost in conversion --- src/conversion.jl | 2 +- test/basic.jl | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/conversion.jl b/src/conversion.jl index b2e1949..03ed7e7 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -70,7 +70,7 @@ function (::Type{T})(x::BFloat16, ::Type{P}=default_overflow_policy(T)) where {T if sign_bits(T) == 0 && signbit(x) throw(DomainError(x, "negative input to unsigned $T")) end - iszero(x) && return zero(T) + iszero(x) && return signbit(x) ? 
-zero(T) : zero(T) bf16_exp = Int((reinterpret(Unsigned, x) >> 7) & 0x00ff) bf16_frac = UInt16(reinterpret(Unsigned, x) & 0x007f) diff --git a/test/basic.jl b/test/basic.jl index 77866b2..58396d7 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -85,6 +85,19 @@ end end end +@testset "Signed zero preservation" begin + for T in SIGNED_TYPES + @testset "$T" begin + nz = T(-0.0) + @test iszero(nz) + @test signbit(nz) + @test Float32(nz) === -0.0f0 + @test nz == zero(T) + @test signbit(nz) != signbit(zero(T)) + end + end +end + @testset "Unsigned microfloats" begin @test_throws DomainError Float8_E8M0FNU(-1.0) @test_throws DomainError Float8_E8M0FNU(-0.0) From 3329cba5132eb1e32676127a6e2c05a436b05308 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 24 Apr 2026 21:09:39 +0200 Subject: [PATCH 3/8] rework documentation --- README.md | 28 +++++++++++-- docs/make.jl | 4 +- docs/src/assets/icon.svg | 37 ----------------- docs/src/conversion.md | 82 ------------------------------------ docs/src/index.md | 2 +- docs/src/microfloat.md | 41 ++++++++++++++++++ docs/src/microfloats.md | 39 ----------------- docs/src/predefined.md | 25 +++++++++++ src/Microfloat.jl | 33 +++------------ src/Microfloats.jl | 52 ++++++----------------- src/conversion.jl | 90 ++++++++++++++++++++-------------------- src/macro.jl | 85 +++++++++++++++++++++++++++++++++++++ src/macros.jl | 63 ---------------------------- src/utils.jl | 39 +++++++++++++++++ src/variants.jl | 25 +++++++++++ test/dlfp8_parity.jl | 7 ++-- test/mx_properties.jl | 67 ++++++++++++++++++------------ test/overflow.jl | 80 +++++++++++++++++++---------------- test/runtests.jl | 10 ++++- 19 files changed, 401 insertions(+), 408 deletions(-) delete mode 100644 docs/src/assets/icon.svg delete mode 100644 docs/src/conversion.md create mode 100644 docs/src/microfloat.md delete mode 100644 docs/src/microfloats.md create mode 100644 docs/src/predefined.md create mode 100644 src/macro.jl delete mode 100644 src/macros.jl create mode 
100644 src/utils.jl diff --git a/README.md b/README.md index c489854..1cceddc 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Microfloats +# Microfloats [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/Microfloats.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/Microfloats.jl/dev/) @@ -28,10 +28,30 @@ Microfloats.non_finite_behavior(::Type{MyE5M2}) = IEEE ## Overflow policy -`SAT` saturates out-of-range values to `±floatmax(T)`. `OVF` uses the type's -sentinel (`±Inf` for IEEE, `NaN` for NanOnlyAllOnes; throws for FiniteOnly). +Each `Microfloat` type carries its overflow policy as a compile-time trait, +baked in at declaration via `@microfloat`. There is no runtime override — +to convert with different semantics, declare a second type with the same +bit layout and the alternate policy, and `reinterpret` between them. -For INT8, see `FixedPointNumbers.Q1f6`. +- `SAT` saturates out-of-range finite values to `±floatmax(T)`. +- `OVF` maps them to the type's sentinel: `±Inf` for `IEEE`, `NaN` for + `NanOnlyAllOnes` (unavailable for `FiniteOnly`). + +Default rule: `OVF` if the type has any non-finite sentinel, else `SAT` +(forced for `FiniteOnly`). This matches cutile-python / OCP strict +semantics. 
The shipped policies: + +| Type | NonFiniteBehavior | Overflow | +| ------------------------------------------------- | ----------------- | -------- | +| `Float8_E5M2`, `Float8_E4M3`, `Float8_E3M4` | `IEEE` | `OVF` | +| `Float8_E4M3FN`, `Float8_E8M0FNU` | `NanOnlyAllOnes` | `OVF` | +| `Float6_E2M3FN`, `Float6_E3M2FN`, `Float4_E2M1FN` | `FiniteOnly` | `SAT` | + +For PyTorch/Triton-style saturating E4M3FN, declare a twin: + +```julia +@microfloat Float8_E4M3FN_SAT exponent=4 significand=3 nonfinite=NanOnlyAllOnes overflow=SAT +``` ## Installation diff --git a/docs/make.jl b/docs/make.jl index e2ba713..8dd00e4 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -14,8 +14,8 @@ makedocs(; ), pages=[ "Home" => "index.md", - "Microfloats" => "microfloats.md", - "Conversion" => "conversion.md", + "The Microfloat type" => "microfloat.md", + "Predefined types" => "predefined.md", ], ) diff --git a/docs/src/assets/icon.svg b/docs/src/assets/icon.svg deleted file mode 100644 index 903f936..0000000 --- a/docs/src/assets/icon.svg +++ /dev/null @@ -1,37 +0,0 @@ - - - - - - Microfloats.jl mark - - Two base circles (red left, purple right), a symmetric blue sine wave above them, - and a small (micro) green circle floating above the wave’s center crest. - - - - - - - - - - - - - \ No newline at end of file diff --git a/docs/src/conversion.md b/docs/src/conversion.md deleted file mode 100644 index 02d6bcb..0000000 --- a/docs/src/conversion.md +++ /dev/null @@ -1,82 +0,0 @@ - -# Conversion - -## BFloat16 - -Conversion to and from `Microfloat` uses `BFloat16` as an intermediate type, -since BFloat16 has 1 sign bit, 8 exponent bits, and 7 significand (mantissa) bits, -and is therefore able to represent all `Microfloat` types. - -## Rounding - -Converting from larger types will round to the nearest even value, i.e. -the value whose bit representation ends in 0. 
- -## Overflow policies - -When converting from a wider type to a `Microfloat`, one may want certain behaviors -in regard to Inf and NaN handling. - -```@raw html -
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Source ValueDestination Value
Has Inf+NaNHas NaNFinite
SATOVFSATOVFSATOVF
NaNNaNNaNNaNNaNErrorError
±Inf±floatmax±Inf±floatmaxNaN±floatmaxError
>|floatmax|±floatmax±Inf±floatmaxNaN±floatmaxError
-
-
-``` - -```@docs -OVF -SAT -``` \ No newline at end of file diff --git a/docs/src/index.md b/docs/src/index.md index fcd54c1..9d4f465 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -7,7 +7,7 @@ CurrentModule = Microfloats ## Contents ```@contents -Pages = ["index.md", "microfloats.md", "conversion.md"] +Pages = ["index.md", "microfloat.md", "predefined.md"] ``` ## Index diff --git a/docs/src/microfloat.md b/docs/src/microfloat.md new file mode 100644 index 0000000..daf779c --- /dev/null +++ b/docs/src/microfloat.md @@ -0,0 +1,41 @@ +# The Microfloat type + +```@docs +Microfloat +``` + +## Defining a new Microfloat + +```@docs +@microfloat +``` + +## API + +### Non-Finite Behavior + +```@docs +Microfloats.hasinf +Microfloats.hasnan +Microfloats.non_finite_behavior +Microfloats.IEEE +Microfloats.NanOnlyAllOnes +Microfloats.FiniteOnly +``` + +### Overflow policies + +```@docs +Microfloats.overflow_policy +Microfloats.OVF +Microfloats.SAT +``` + +### Reflection + +```@docs +Microfloats.bitwidth +Microfloats.sign_bits +Microfloats.exponent_bits +Microfloats.significand_bits +``` diff --git a/docs/src/microfloats.md b/docs/src/microfloats.md deleted file mode 100644 index 3722b3d..0000000 --- a/docs/src/microfloats.md +++ /dev/null @@ -1,39 +0,0 @@ -# Microfloat - -```@docs -Microfloat -``` - -## Finite - -```@docs -Finite -``` - -## IEEE 754-like - -These types have IEEE 754-like Inf/NaN encodings, with Inf being represented as all 1s in the exponent and a significand of zero, and NaN being represented as all 1s in the exponent and a non-zero significand. 
- -```@docs -Float8_E4M3 -Float8_E5M2 -Float6_E2M3 -Float6_E3M2 -Float4_E2M1 -``` - -## Microscaling (MX) - -Types from [Open Compute Project Microscaling Formats (MX) Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf), with `MX_E5M2` adhering to the IEEE 754-like encoding of Inf/NaN, -whereas MX_E4M3 and MX_E8M0 have no Inf, and only one representation of NaN (excluding the sign bit), -and the finite types MX_E3M2, MX_E2M3, and MX_E2M1 which have no Inf or NaNs. - -```@docs -MX -MX_E5M2 -MX_E4M3 -MX_E3M2 -MX_E2M3 -MX_E2M1 -MX_E8M0 -``` diff --git a/docs/src/predefined.md b/docs/src/predefined.md new file mode 100644 index 0000000..2b982de --- /dev/null +++ b/docs/src/predefined.md @@ -0,0 +1,25 @@ +## Predefined types + +Microfloats defines and exports a set of common types. + +### IEEE-like + +These types have IEEE 754-like Inf/NaN encodings, with Inf being represented as all 1s in the exponent and a significand of zero, and NaN being represented as all 1s in the exponent and a non-zero significand. + +```@docs +Float8_E5M2 +Float8_E4M3 +Float8_E3M4 +``` + +### Finite + +These types have no Inf encoding, with alternate or no NaN encodings at all. + +```@docs +Float8_E4M3FN +Float8_E8M0FNU +Float6_E3M2FN +Float6_E2M3FN +Float4_E2M1FN +``` diff --git a/src/Microfloat.jl b/src/Microfloat.jl index 5ed2d97..594a985 100644 --- a/src/Microfloat.jl +++ b/src/Microfloat.jl @@ -3,13 +3,9 @@ """ Microfloat{S,E,M} <: AbstractFloat -Abstract type for within-byte floating-point numbers with `S` sign bits -(0 or 1), `E` exponent bits (≥ 1), and `M` significand bits (≥ 0), -with `S + E + M ≤ 8`. +Abstract type for floating-point numbers that fit within a single byte. -Concrete subtypes are 8-bit `primitive type`s that must also register -a [`non_finite_behavior`](@ref). See [`@microfloat`](@ref) for the -macro-based convenience declaration. +See [`@microfloat`](@ref) for type declaration. 
""" abstract type Microfloat{S,E,M} <: AbstractFloat end @@ -32,25 +28,9 @@ function Base.show(io::IO, x::T) where T<:Microfloat return nothing end -""" - NonFiniteBehavior - -Trait hierarchy describing how a [`Microfloat`](@ref) type encodes non-finite -values. Each concrete `Microfloat` subtype registers its behavior by defining -a [`non_finite_behavior`](@ref) method. - -Three behaviors: - -- [`IEEE`](@ref): exponent all-ones with zero significand ⇒ Inf; - all-ones exponent with nonzero significand ⇒ NaN. -- [`NanOnlyAllOnes`](@ref): no Inf. The single NaN encoding has all - exponent and significand bits set. -- [`FiniteOnly`](@ref): no Inf and no NaN — every bit pattern is finite. - Matches MX sub-byte types and `F4E2M1FN`. -""" abstract type NonFiniteBehavior end -"""IEEE-754-style encoding of Inf and NaN. Requires `M ≥ 1`.""" +"""IEEE-754-style encoding of Inf and NaN.""" abstract type IEEE <: NonFiniteBehavior end """NaN encoded as all-ones in exponent+significand; no Inf.""" @@ -68,10 +48,9 @@ hasnan(::Type{NanOnlyAllOnes}) = true hasnan(::Type{FiniteOnly}) = false """ - non_finite_behavior(T) -> Type{<:NonFiniteBehavior} + non_finite_behavior(::Type{<:Microfloat}) -> Type{<:NonFiniteBehavior} -Required trait method on every concrete [`Microfloat`](@ref) subtype. -Returns one of `IEEE`, `NanOnlyAllOnes`, or `FiniteOnly`. +Return `IEEE`, `NanOnlyAllOnes`, or `FiniteOnly` based on the assigned trait. """ non_finite_behavior(::Type{T}) where T<:Microfloat = error("$T must define `Microfloats.non_finite_behavior(::Type{$T})`") @@ -84,10 +63,8 @@ hasnan(::Type{T}) where T<:Microfloat = hasnan(non_finite_behavior(T)) Base.isinf(x::T) where T<:Microfloat = _isinf(non_finite_behavior(T), x) Base.isnan(x::T) where T<:Microfloat = _isnan(non_finite_behavior(T), x) -"""Bit pattern for +Inf. Throws if the type has no Inf.""" inf(::Type{T}) where T<:Microfloat = _inf(non_finite_behavior(T), T) -"""Bit pattern for NaN. 
Throws if the type has no NaN.""" nan(::Type{T}) where T<:Microfloat = _nan(non_finite_behavior(T), T) Base.floatmax(::Type{T}) where T<:Microfloat = _floatmax(non_finite_behavior(T), T) diff --git a/src/Microfloats.jl b/src/Microfloats.jl index f3574fa..82b09e9 100644 --- a/src/Microfloats.jl +++ b/src/Microfloats.jl @@ -2,62 +2,34 @@ module Microfloats using Republic -@reexport import BFloat16s: BFloat16 +import BFloat16s: BFloat16 -float_bits(::Type{Float64}) = (1, 11, 52) -float_bits(::Type{Float32}) = (1, 8, 23) -float_bits(::Type{Float16}) = (1, 5, 10) -float_bits(::Type{BFloat16}) = (1, 8, 7) - -bitwidth(::Type{T}) where T<:AbstractFloat = sum(float_bits(T)) -sign_bits(::Type{T}) where T<:AbstractFloat = float_bits(T)[1] -exponent_bits(::Type{T}) where T<:AbstractFloat = float_bits(T)[2] -significand_bits(::Type{T}) where T<:AbstractFloat = float_bits(T)[3] - -exponent_bias(::Type{T}) where T<:AbstractFloat = 2^(exponent_bits(T) - 1) - 1 - -@public bitwidth, sign_bits, exponent_bits, significand_bits +include("utils.jl") +@public bitwidth +@public sign_bits, exponent_bits, significand_bits include("Microfloat.jl") export Microfloat -@public sign_mask, exponent_mask, significand_mask -@public NonFiniteBehavior, non_finite_behavior, hasinf, hasnan, inf, nan -export IEEE, NanOnlyAllOnes, FiniteOnly +@public hasinf, hasnan +@public non_finite_behavior +@public IEEE, NanOnlyAllOnes, FiniteOnly include("conversion.jl") -export OverflowPolicy, SAT, OVF -@public default_overflow_policy +@public overflow_policy +@public SAT, OVF -include("macros.jl") -export @microfloat +include("macro.jl") +@public @microfloat # Each `@microfloat` call builds a per-type BFloat16 lookup table, # so conversion.jl must be loaded before this point. 
include("variants.jl") export Float8_E5M2, Float8_E4M3, Float8_E3M4 -export Float8_E4M3FN, Float8_E5M2, Float8_E8M0FNU +export Float8_E4M3FN, Float8_E8M0FNU export Float6_E2M3FN, Float6_E3M2FN export Float4_E2M1FN include("ops.jl") include("random.jl") -for T in ( - :Float8_E4M3FN, :Float8_E5M2, :Float8_E8M0FNU, - :Float6_E2M3FN, :Float6_E3M2FN, - :Float4_E2M1FN, -) - @eval @doc """ - $($T) - - ## Properties - - Bits: `$(sign_bits($T))` sign + `$(exponent_bits($T))` exponent + `$(significand_bits($T))` significand (`$(bitwidth($T))` total) - - Non-finite behavior: `$(non_finite_behavior($T))` - - Has Inf: `$(hasinf($T))` - - Has NaN: `$(hasnan($T))` - - Max normal: `$(Float64(floatmax($T)))` - - Min positive: `$(Float64(floatmin($T)))` - """ $T -end - end diff --git a/src/conversion.jl b/src/conversion.jl index 03ed7e7..8ba0d87 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -1,22 +1,37 @@ +abstract type OverflowPolicy end + """ - OverflowPolicy + OVF -Policy controlling how out-of-range finite input is mapped to a [`Microfloat`](@ref). +Sentinel overflow: out-of-range finite inputs go to `±Inf` (IEEE) or `NaN` (NanOnlyAllOnes). -- [`SAT`](@ref): saturate to `±floatmax(T)`. -- [`OVF`](@ref): overflow to the type's sentinel (`±Inf` for `IEEE`, - `NaN` for `NanOnlyAllOnes`; throws for `FiniteOnly` — no sentinel exists). +| Input Condition | `T` has Inf+NaN | `T` has NaN | `T` is finite | +| ---------------------- | --------------- | ------------- | ------------- | +| `isnan(x)` | NaN | NaN | Error | +| `abs(x) > floatmax(T)` | ±Inf | NaN | Error | +""" +abstract type OVF <: OverflowPolicy end -Defaults are resolved by [`default_overflow_policy`](@ref) from the type's -[`non_finite_behavior`](@ref): `IEEE` → `OVF`, otherwise `SAT`. """ -abstract type OverflowPolicy end + SAT + +Saturating conversion: out-of-range finite inputs clamp to `±floatmax(T)`. 
-"""Saturating conversion: out-of-range finite inputs clamp to `±floatmax(T)`.""" +| Input Condition | `T` has Inf+NaN | `T` has NaN | `T` is finite | +| ---------------------- | --------------- | ------------- | ------------- | +| `isnan(x)` | NaN | NaN | Error | +| `abs(x) > floatmax(T)` | ±floatmax | ±floatmax | ±floatmax | +""" abstract type SAT <: OverflowPolicy end -"""Sentinel overflow: out-of-range finite inputs go to `±Inf` (IEEE) or `NaN` (NanOnlyAllOnes).""" -abstract type OVF <: OverflowPolicy end +""" + overflow_policy(T) -> Type{<:OverflowPolicy} + +Required trait method on every concrete [`Microfloat`](@ref) subtype. +Returns [`OVF`](@ref) or [`SAT`](@ref). Registered by [`@microfloat`](@ref). +""" +overflow_policy(::Type{T}) where T<:Microfloat = + error("$T must define `Microfloats.overflow_policy(::Type{$T})`") function rshift_round_to_even(x::UInt16, n::Int) n <= 0 && return x >> n @@ -32,41 +47,29 @@ is_outside_floatmax(xb::BFloat16, ::Type{T}) where T<:Microfloat = clamp_floatmax(x::T) where T<:Microfloat = signbit(x) ? -floatmax(T) : floatmax(T) clamp_inf(x::T) where T<:Microfloat = signbit(x) ? -inf(T) : inf(T) -function epilogue(x::T, xb::BFloat16, ::Type{P}) where {T<:Microfloat,P<:OverflowPolicy} - if P <: SAT - if isnan(xb) - return hasnan(T) ? nan(T) : throw(DomainError(xb, "$T has no NaN")) - elseif isinf(xb) || is_outside_floatmax(xb, T) - return clamp_floatmax(x) - else - return x - end - elseif P <: OVF - if isnan(xb) - return hasnan(T) ? nan(T) : throw(DomainError(xb, "$T has no NaN")) - elseif isinf(xb) || is_outside_floatmax(xb, T) - return hasinf(T) ? clamp_inf(x) : - hasnan(T) ? nan(T) : - throw(DomainError(xb, "$T has no overflow sentinel; use SAT")) - else - return x - end +function _finalize(x::T, xb::BFloat16, ::Type{SAT}) where T<:Microfloat + if isnan(xb) + return hasnan(T) ? 
nan(T) : throw(DomainError(xb, "$T has no NaN")) + elseif isinf(xb) || is_outside_floatmax(xb, T) + return clamp_floatmax(x) else - throw(ArgumentError("Unknown overflow policy $P")) + return x end end -""" - default_overflow_policy(T) -> Type{<:OverflowPolicy} - -Default overflow policy for `Microfloat` type `T`. Keys on `hasinf(T)`: -IEEE types default to `OVF` (finite overflow → Inf), all others default -to `SAT` (clamp to `floatmax`). Matches PyTorch/Triton/Quartet-II practice -for FP8/FP4. -""" -default_overflow_policy(::Type{T}) where T<:Microfloat = hasinf(T) ? OVF : SAT +function _finalize(x::T, xb::BFloat16, ::Type{OVF}) where T<:Microfloat + if isnan(xb) + return hasnan(T) ? nan(T) : throw(DomainError(xb, "$T has no NaN")) + elseif isinf(xb) || is_outside_floatmax(xb, T) + return hasinf(T) ? clamp_inf(x) : + hasnan(T) ? nan(T) : + throw(DomainError(xb, "$T has no overflow sentinel; declare the type with overflow=SAT")) + else + return x + end +end -function (::Type{T})(x::BFloat16, ::Type{P}=default_overflow_policy(T)) where {T<:Microfloat,P<:OverflowPolicy} +function (::Type{T})(x::BFloat16) where T<:Microfloat if sign_bits(T) == 0 && signbit(x) throw(DomainError(x, "negative input to unsigned $T")) end @@ -117,11 +120,10 @@ function (::Type{T})(x::BFloat16, ::Type{P}=default_overflow_policy(T)) where {T t_raw |= (reinterpret(Unsigned, x) >> 15 % UInt8) << (exponent_bits(T) + significand_bits(T)) & sign_mask(T) - return epilogue(reinterpret(T, t_raw), x, P) + return _finalize(reinterpret(T, t_raw), x, overflow_policy(T)) end -(::Type{T})(x::Number, args...) where {T<:Microfloat} = T(BFloat16(x), args...) 
-(::Type{T})(::Type{P}) where {T<:Microfloat,P<:OverflowPolicy} = x -> T(x, P) +(::Type{T})(x::Number) where {T<:Microfloat} = T(BFloat16(x)) function _to_bfloat16(x::T) where {T<:Microfloat} t_raw = reinterpret(UInt8, x) diff --git a/src/macro.jl b/src/macro.jl new file mode 100644 index 0000000..6255e42 --- /dev/null +++ b/src/macro.jl @@ -0,0 +1,85 @@ +""" + @microfloat name [kwargs...] + +Define a new type + +Default policy rule: `OVF` when the type has any non-finite sentinel +(`IEEE` or `NanOnlyAllOnes`), else `SAT` (forced for `FiniteOnly` since no +sentinel encoding exists). + +## Keyword arguments + +- `sign`: Number of sign bits: `1` (default) or `0` +- `exponent`: Number of exponent bits: ≥ `1` +- `significand`: Number of significand / mantissa bits: ≥ `0` +- `nonfinite`: [`IEEE`](@ref) (default), [`NanOnlyAllOnes`](@ref), or [`FiniteOnly`](@ref). +- `overflow`: Overflow handling during conversion from other types: [`SAT`](@ref) or [`OVF`](@ref). Default: `OVF` if the type has any + non-finite values (`IEEE`, `NanOnlyAllOnes`), otherwise `SAT` (`FiniteOnly`). + +## Examples + +Converting from larger types rounds to the nearest even value, i.e. +the value whose bit representation ends in `0`. +""" +macro microfloat(name, kwargs...) + mod = @__MODULE__ + S = 1 + E = nothing + M = nothing + nonfinite = nothing + overflow = nothing + + for kw in kwargs + (kw isa Expr && kw.head == :(=)) || + error("@microfloat: expected keyword arguments (e.g. 
exponent=5), got $kw") + k, v = kw.args + if k == :sign + S = Int(v) + elseif k == :exponent + E = Int(v) + elseif k == :significand + M = Int(v) + elseif k == :nonfinite + nonfinite = v + elseif k == :overflow + overflow = v + else + error("@microfloat: unknown keyword `$k`") + end + end + + E === nothing && error("@microfloat: `exponent` is required") + M === nothing && error("@microfloat: `significand` is required") + S in (0, 1) || error("@microfloat: `sign` must be 0 or 1, got $S") + E >= 1 || error("@microfloat: `exponent` must be >= 1, got $E") + M >= 0 || error("@microfloat: `significand` must be >= 0, got $M") + S + E + M <= 8 || error("@microfloat: `sign + exponent + significand` must be <= 8, got $(S + E + M)") + + nonfinite_expr = nonfinite === nothing ? :($IEEE) : esc(nonfinite) + overflow_expr = overflow === nothing ? + :($hasinf($nonfinite_expr) || $hasnan($nonfinite_expr) ? $OVF : $SAT) : + esc(overflow) + + T = esc(name) + N = S + E + M + + quote + Base.@__doc__ primitive type $T <: $Microfloat{$S,$E,$M} 8 end + $_validate_microfloat($T, $nonfinite_expr, $overflow_expr) + $mod.non_finite_behavior(::Type{$T}) = $nonfinite_expr + $mod.overflow_policy(::Type{$T}) = $overflow_expr + let lookup = Tuple($_to_bfloat16(reinterpret($T, i % UInt8)) for i in 0:$(2^N - 1)) + $mod.to_bfloat16(x::$T) = lookup[reinterpret(UInt8, x) + 0x0001] + end + end +end + +function _validate_microfloat(T, nonfinite, overflow) + (nonfinite isa Type && nonfinite <: NonFiniteBehavior) || + throw(ArgumentError("@microfloat($T): `nonfinite` must be IEEE, NanOnlyAllOnes, or FiniteOnly, got $nonfinite")) + (overflow isa Type && overflow <: OverflowPolicy) || + throw(ArgumentError("@microfloat($T): `overflow` must be SAT or OVF, got $overflow")) + nonfinite === FiniteOnly && overflow === OVF && + throw(ArgumentError("@microfloat($T): `overflow=OVF` invalid for `nonfinite=FiniteOnly` (no sentinel encoding)")) + return nothing +end diff --git a/src/macros.jl b/src/macros.jl deleted 
file mode 100644 index b44860e..0000000 --- a/src/macros.jl +++ /dev/null @@ -1,63 +0,0 @@ -""" - @microfloat Name sign=1 exponent=E significand=M nonfinite=Trait - -Declare an 8-bit `primitive type Name <: Microfloat{sign,E,M} 8 end` and -register its [`non_finite_behavior`](@ref) as `Trait`. - -All keyword arguments are passed positionally as `name=value` pairs: - -- `sign` — `0` or `1`. Default `1`. -- `exponent` — required, `≥ 1`. -- `significand` — required, `≥ 0`. -- `nonfinite` — required. One of `IEEE`, `NanOnlyAllOnes`, `FiniteOnly`. - -Hand-written equivalent (also supported): - -```julia -primitive type Name <: Microfloat{sign,exponent,significand} 8 end -non_finite_behavior(::Type{Name}) = Trait -``` -""" -macro microfloat(name, kwargs...) - mod = @__MODULE__ - S = 1 - E = nothing - M = nothing - behavior = IEEE - - for kw in kwargs - (kw isa Expr && kw.head == :(=)) || - error("@microfloat: expected keyword arguments (e.g. exponent=5), got $kw") - k, v = kw.args - if k == :sign - S = v - elseif k == :exponent - E = v - elseif k == :significand - M = v - elseif k == :nonfinite - behavior = v - else - error("@microfloat: unknown keyword `$k`") - end - end - - E === nothing && error("@microfloat: `exponent` is required") - M === nothing && error("@microfloat: `significand` is required") - behavior isa NonFiniteBehavior && error("@microfloat: `nonfinite` is required") - S in (0, 1) || error("@microfloat: `sign` must be 0 or 1, got $S") - E >= 1 || error("@microfloat: `exponent` must be >= 1, got $E") - M >= 0 || error("@microfloat: `significand` must be >= 0, got $M") - S + E + M <= 8 || error("@microfloat: `sign + exponent + significand` must be <= 8, got $(S + E + M)") - - T = esc(name) - trait = esc(behavior) - N = S + E + M - quote - primitive type $T <: $mod.Microfloat{$S,$E,$M} 8 end - $mod.non_finite_behavior(::Type{$T}) = $trait - let lookup = Tuple($mod._to_bfloat16(reinterpret($T, i % UInt8)) for i in 0:$(2^N - 1)) - $mod.to_bfloat16(x::$T) = 
lookup[reinterpret(UInt8, x) + 0x0001] - end - end -end diff --git a/src/utils.jl b/src/utils.jl new file mode 100644 index 0000000..ceb6713 --- /dev/null +++ b/src/utils.jl @@ -0,0 +1,39 @@ +float_bits(::Type{Float64}) = (1, 11, 52) +float_bits(::Type{Float32}) = (1, 8, 23) +float_bits(::Type{Float16}) = (1, 5, 10) +float_bits(::Type{BFloat16}) = (1, 8, 7) + +""" + bitwidth(::Type{<:AbstractFloat}) + +Returns the number of utilized bits. + +```jldoctest +julia> Microfloats.bitwidth(Float4_E2M1FN) +4 +``` +""" +bitwidth(::Type{T}) where T<:AbstractFloat = sum(float_bits(T)) + +""" + sign_bits(::Type{<:AbstractFloat}) + +Return the number of sign bits (between 0 or 1). +""" +sign_bits(::Type{T}) where T<:AbstractFloat = float_bits(T)[1] + +""" + exponent_bits(::Type{<:AbstractFloat}) + +Return the number of exponent bits. +""" +exponent_bits(::Type{T}) where T<:AbstractFloat = float_bits(T)[2] + +""" + significand_bits(::Type{<:AbstractFloat}) + +Return the number of significand / mantissa bits. +""" +significand_bits(::Type{T}) where T<:AbstractFloat = float_bits(T)[3] + +exponent_bias(::Type{T}) where T<:AbstractFloat = 2^(exponent_bits(T) - 1) - 1 diff --git a/src/variants.jl b/src/variants.jl index c9b2ae9..70b9ec6 100644 --- a/src/variants.jl +++ b/src/variants.jl @@ -1,3 +1,9 @@ +# All shipped types default to `overflow=OVF` if they have any non-finite +# sentinel (IEEE or NanOnlyAllOnes), else forced `SAT` (FiniteOnly has none). +# Rule: `hasinf(T) || hasnan(T) ? OVF : SAT`. Matches cutile-python's +# `_convert_nonfinite` exactly. For PyTorch/Triton-style saturating E4M3FN, +# declare a twin with `overflow=SAT`. 
+ # IEEE-like @microfloat Float8_E5M2 exponent=5 significand=2 @microfloat Float8_E4M3 exponent=4 significand=3 @@ -11,3 +17,22 @@ @microfloat Float6_E2M3FN exponent=2 significand=3 nonfinite=FiniteOnly @microfloat Float6_E3M2FN exponent=3 significand=2 nonfinite=FiniteOnly @microfloat Float4_E2M1FN exponent=2 significand=1 nonfinite=FiniteOnly + +for T in ( + :Float8_E5M2, :Float8_E4M3, :Float8_E3M4, + :Float8_E4M3FN, :Float8_E8M0FNU, + :Float6_E2M3FN, :Float6_E3M2FN, + :Float4_E2M1FN, +) + @eval @doc """ + $($T) + + ## Properties + - Bits: `$(sign_bits($T))` sign + `$(exponent_bits($T))` exponent + `$(significand_bits($T))` significand (`$(bitwidth($T))` total) + - Non-finite behavior: `$(non_finite_behavior($T))` + - Has Inf: `$(hasinf($T))` + - Has NaN: `$(hasnan($T))` + - Max normal: `$(Float64(floatmax($T)))` + - Min positive: `$(Float64(floatmin($T)))` + """ $T +end diff --git a/test/dlfp8_parity.jl b/test/dlfp8_parity.jl index 8818e29..9cd67ca 100644 --- a/test/dlfp8_parity.jl +++ b/test/dlfp8_parity.jl @@ -5,6 +5,8 @@ import DLFP8Types # non-finite semantics as ours. Every bit pattern must agree on the # Float32 value (NaN encodings and signed zeros are allowed to differ # between the two packages — the semantic float value is what matters). + # DLFP8Types uses strict OCP semantics (overflow → NaN for E4M3FN, + # overflow → Inf for E5M2), which matches our OVF defaults directly. @testset "$(M_T)" for (M_T, D_T) in ( (Microfloats.Float8_E4M3FN, DLFP8Types.Float8_E4M3FN), (Microfloats.Float8_E5M2, DLFP8Types.Float8_E5M2), @@ -17,12 +19,9 @@ import DLFP8Types @test isnan(mx) == isnan(dx) @test iszero(mx) == iszero(dx) end - # Float32 → narrow goes to the same semantic value. - # DLFP8Types uses strict spec semantics (overflow → NaN for E4M3FN, - # overflow → Inf for E5M2), which is our `OVF` policy. 
for f in Float32[0.0, -0.0, 1.0, -1.0, 3.5, -3.5, 448.0, 1000.0, 1.5f-5, 2.0f-7, NaN32, Inf32, -Inf32] - @test Float32(M_T(f, OVF)) ≡ Float32(D_T(f)) + @test Float32(M_T(f)) ≡ Float32(D_T(f)) end end end diff --git a/test/mx_properties.jl b/test/mx_properties.jl index f00c939..a4deb54 100644 --- a/test/mx_properties.jl +++ b/test/mx_properties.jl @@ -72,33 +72,46 @@ end end end -@testset "MX: saturation and NaN/Inf mapping from Float32" begin - for T in (MX_E5M2, MX_E4M3, MX_E3M2, MX_E2M3, MX_E2M1, MX_E8M0) - @testset "$T" begin - fmax = floatmax(T) - # +Inf maps to +floatmax under SAT. - @test T(Inf32, SAT) == fmax - # Negative input: signed types saturate to -fmax; unsigned throws. - if Microfloats.sign_bits(T) == 0 - @test_throws DomainError T(-Inf32, SAT) - else - @test T(-Inf32, SAT) == -fmax - end - # NaN input: types with NaN encoding return NaN; FiniteOnly throws. - if hasnan(T) - @test isnan(T(NaN32)) - else - @test_throws DomainError T(NaN32) - end - # Values just beyond floatmax saturate (positive side). - big = nextfloat(Float32(fmax)) - @test T(big, SAT) == fmax - if Microfloats.sign_bits(T) == 0 - @test_throws DomainError T(-big, SAT) - else - @test T(-big, SAT) == -fmax - end - end +@testset "MX: default overflow mapping from Float32" begin + # Each MX type's default policy is baked in. We test the shipped + # behavior per type; alternate semantics require a twin type. 
+ @testset "E5M2 (IEEE, OVF)" begin + T = MX_E5M2 + @test T(+Inf32) == inf(T) + @test T(-Inf32) == -inf(T) + @test isnan(T(NaN32)) + big = nextfloat(BFloat16(floatmax(T))) + @test T(+big) == inf(T) + @test T(-big) == -inf(T) + end + + @testset "E4M3 (NanOnlyAllOnes, OVF)" begin + T = MX_E4M3 + @test isnan(T(+Inf32)) + @test isnan(T(-Inf32)) + @test isnan(T(NaN32)) + big = nextfloat(BFloat16(floatmax(T))) + @test isnan(T(+big)) + @test isnan(T(-big)) + end + + @testset "E8M0 (NanOnlyAllOnes, OVF)" begin + T = MX_E8M0 + @test isnan(T(+Inf32)) + @test_throws DomainError T(-Inf32) + @test isnan(T(NaN32)) + big = nextfloat(BFloat16(floatmax(T))) + @test isnan(T(big)) + end + + @testset "FiniteOnly $T" for T in (MX_E3M2, MX_E2M3, MX_E2M1) + fmax = floatmax(T) + @test T(+Inf32) == +fmax + @test T(-Inf32) == -fmax + @test_throws DomainError T(NaN32) + big = nextfloat(BFloat16(fmax)) + @test T(+big) == +fmax + @test T(-big) == -fmax end end diff --git a/test/overflow.jl b/test/overflow.jl index 513e265..324b4a3 100644 --- a/test/overflow.jl +++ b/test/overflow.jl @@ -1,70 +1,78 @@ -@testset "Overflow: IEEE types" begin - # Default: OVF → Inf for finite overflow, NaN input → NaN +@testset "Overflow: IEEE types (OVF)" begin + # overflow=OVF default: finite overflow → ±Inf, NaN input → NaN. 
@testset "$T" for T in (Float8_E3M4, Float8_E5M2) - @test default_overflow_policy(T) === OVF + @test overflow_policy(T) === OVF @test isnan(T(NaN)) @test T(+Inf) == +inf(T) @test T(-Inf) == -inf(T) big = nextfloat(BFloat16(floatmax(T))) - @test T(+big) == +inf(T) # default OVF + @test T(+big) == +inf(T) @test T(-big) == -inf(T) - @test T(+big, SAT) == +floatmax(T) - @test T(-big, SAT) == -floatmax(T) end end -@testset "Overflow: NanOnlyAllOnes types" begin - # Default: SAT → floatmax for finite overflow (matches PyTorch/Triton/Quartet) - @testset "$T" for T in (Float8_E4M3FN,) - @test default_overflow_policy(T) === SAT +@testset "Overflow: NanOnlyAllOnes types (OVF)" begin + # Default for NanOnlyAllOnes: OVF. Overflow → NaN. Matches cutile-python + # and the OCP strict reading. PyTorch-style saturation requires a + # twin type declared with `overflow=SAT`. + @testset "Float8_E4M3FN" begin + T = Float8_E4M3FN + @test overflow_policy(T) === OVF @test isnan(T(NaN)) - @test T(+Inf) == +floatmax(T) - @test T(-Inf) == -floatmax(T) + @test isnan(T(+Inf)) + @test isnan(T(-Inf)) big = nextfloat(BFloat16(floatmax(T))) - @test T(+big) == +floatmax(T) # default SAT - @test T(-big) == -floatmax(T) - @test isnan(T(+big, OVF)) - @test isnan(T(-big, OVF)) + @test isnan(T(+big)) + @test isnan(T(-big)) end - # Unsigned NanOnlyAllOnes (E8M0FNU): negative input throws, large positive saturates - T = Float8_E8M0FNU - @test default_overflow_policy(T) === SAT - @test T(+Inf) == floatmax(T) - @test_throws DomainError T(-Inf) - big = nextfloat(BFloat16(floatmax(T))) - @test T(big) == floatmax(T) - @test isnan(T(big, OVF)) - @test isnan(T(NaN)) + @testset "Float8_E8M0FNU" begin + # Unsigned NanOnlyAllOnes scale type; negative input throws regardless. 
+ T = Float8_E8M0FNU + @test overflow_policy(T) === OVF + + @test isnan(T(NaN)) + @test isnan(T(+Inf)) + @test_throws DomainError T(-Inf) + + big = nextfloat(BFloat16(floatmax(T))) + @test isnan(T(big)) + end end -@testset "Overflow: FiniteOnly types" begin - # Default: SAT. NaN input always throws (no sentinel). OVF also throws on overflow. +@testset "Overflow: FiniteOnly types (SAT)" begin + # overflow=SAT forced — no sentinel encoding exists. NaN input throws. @testset "$T" for T in (Float4_E2M1FN, Float6_E2M3FN, Float6_E3M2FN) - @test default_overflow_policy(T) === SAT + @test overflow_policy(T) === SAT @test_throws DomainError T(NaN) - @test_throws DomainError T(NaN, OVF) - @test_throws DomainError T(NaN, SAT) @test T(+Inf) == +floatmax(T) @test T(-Inf) == -floatmax(T) - @test_throws DomainError T(+Inf, OVF) - @test_throws DomainError T(-Inf, OVF) big = nextfloat(BFloat16(floatmax(T))) @test T(+big) == +floatmax(T) @test T(-big) == -floatmax(T) - @test_throws DomainError T(+big, OVF) - @test_throws DomainError T(-big, OVF) end # Specific: Float4_E2M1FN (NVFP4 value type) saturates at ±6 @test Float32(Float4_E2M1FN(6.0)) == 6.0 - @test Float32(Float4_E2M1FN(7.0)) == 6.0 # SAT - @test_throws DomainError Float4_E2M1FN(7.0, OVF) + @test Float32(Float4_E2M1FN(7.0)) == 6.0 +end + +@testset "Alternate policy via twin type + reinterpret" begin + # _E4M3FN_SAT is declared at top of runtests.jl: same bit layout as + # Float8_E4M3FN but with overflow=SAT (PyTorch/Triton convention). + # Shows the documented escape hatch for the non-default policy. + @test overflow_policy(_E4M3FN_SAT) === SAT + big = nextfloat(BFloat16(floatmax(_E4M3FN_SAT))) + @test _E4M3FN_SAT(big) == floatmax(_E4M3FN_SAT) # SAT: overflow → floatmax + @test isnan(Float8_E4M3FN(big)) # OVF: overflow → NaN + # Bit layout is identical, so reinterpret is a free relabel. 
+ x = Float8_E4M3FN(1.0) + @test reinterpret(UInt8, reinterpret(_E4M3FN_SAT, x)) == reinterpret(UInt8, x) end diff --git a/test/runtests.jl b/test/runtests.jl index ca940cc..6186188 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,9 @@ using Microfloats using Microfloats: non_finite_behavior, hasinf, hasnan, inf, nan, - sign_bits, bitwidth, default_overflow_policy + sign_bits, bitwidth, overflow_policy, + IEEE, NanOnlyAllOnes, FiniteOnly, + OverflowPolicy, SAT, OVF, + @microfloat using Test using Random @@ -18,6 +21,11 @@ using Random @microfloat UFloat5_E3M2 sign=0 exponent=3 significand=2 nonfinite=FiniteOnly @microfloat UFloat3_E2M1 sign=0 exponent=2 significand=1 nonfinite=FiniteOnly +# Twin of Float8_E4M3FN with the alternate (PyTorch/Triton) overflow policy. +# Demonstrates the documented "reinterpret between twin types" escape hatch; +# used in overflow.jl. +@microfloat _E4M3FN_SAT exponent=4 significand=3 nonfinite=NanOnlyAllOnes overflow=SAT + const SIGNED_TYPES = ( Float8_E3M4, Float8_E4M3, Float8_E5M2, Float8_E4M3FN, From e35598b581b985a2e928868e64826977dcdd15d9 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 24 Apr 2026 21:20:49 +0200 Subject: [PATCH 4/8] update --- README.md | 40 ++++------------------------------------ src/Microfloats.jl | 4 +--- src/conversion.jl | 22 ++++++++++++++++++++-- 3 files changed, 25 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 1cceddc..ab6bba9 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ [![Build Status](https://github.com/MurrellGroup/Microfloats.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/Microfloats.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/MurrellGroup/Microfloats.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/Microfloats.jl) -Microfloats is a Julia package that implements types and arithmetic (through wider intermediates) for sub-8 bit floating points, 
supporting arbitrary combinations of sign, exponent, and mantissa (significand) bits. +Microfloats is a Julia package that implements types and arithmetic (through wider intermediates) for sub-byte floating points, supporting arbitrary combinations of sign, exponent, and significand (mantissa) bits. -Instances of a sub-8 bit floating point type are still 8 bits wide in memory; the goal of `Microfloat` is to serve as a base for arithmetic operations and method dispatch, lending downstream packages a good abstraction for doing bitpacking and hardware acceleration. +Instances of a sub-8 bit floating point type are still 8 bits wide in memory; Microfloats serves as a base and reference for arithmetic operations and method dispatch, lending downstream packages a good abstraction for bitpacking and hardware acceleration. ## Usage @@ -16,42 +16,10 @@ Define your own primitive type with the macro: ```julia using Microfloats -@microfloat MyE5M2 sign=1 exponent=5 significand=2 nonfinite=IEEE +@microfloat MyE5M2 sign=1 exponent=5 significand=2 nonfinite=Microfloats.IEEE ``` -Or the hand-written equivalent: - -```julia -primitive type MyE5M2 <: Microfloat{1,5,2} 8 end -Microfloats.non_finite_behavior(::Type{MyE5M2}) = IEEE -``` - -## Overflow policy - -Each `Microfloat` type carries its overflow policy as a compile-time trait, -baked in at declaration via `@microfloat`. There is no runtime override — -to convert with different semantics, declare a second type with the same -bit layout and the alternate policy, and `reinterpret` between them. - -- `SAT` saturates out-of-range finite values to `±floatmax(T)`. -- `OVF` maps them to the type's sentinel: `±Inf` for `IEEE`, `NaN` for - `NanOnlyAllOnes` (unavailable for `FiniteOnly`). - -Default rule: `OVF` if the type has any non-finite sentinel, else `SAT` -(forced for `FiniteOnly`). This matches cutile-python / OCP strict -semantics. 
The shipped policies: - -| Type | NonFiniteBehavior | Overflow | -| ------------------------------------------------- | ----------------- | -------- | -| `Float8_E5M2`, `Float8_E4M3`, `Float8_E3M4` | `IEEE` | `OVF` | -| `Float8_E4M3FN`, `Float8_E8M0FNU` | `NanOnlyAllOnes` | `OVF` | -| `Float6_E2M3FN`, `Float6_E3M2FN`, `Float4_E2M1FN` | `FiniteOnly` | `SAT` | - -For PyTorch/Triton-style saturating E4M3FN, declare a twin: - -```julia -@microfloat Float8_E4M3FN_SAT exponent=4 significand=3 nonfinite=NanOnlyAllOnes overflow=SAT -``` +or see the documentation for a list of predefined types. ## Installation diff --git a/src/Microfloats.jl b/src/Microfloats.jl index 82b09e9..afdc908 100644 --- a/src/Microfloats.jl +++ b/src/Microfloats.jl @@ -19,10 +19,8 @@ include("conversion.jl") @public SAT, OVF include("macro.jl") -@public @microfloat +export @microfloat -# Each `@microfloat` call builds a per-type BFloat16 lookup table, -# so conversion.jl must be loaded before this point. include("variants.jl") export Float8_E5M2, Float8_E4M3, Float8_E3M4 export Float8_E4M3FN, Float8_E8M0FNU diff --git a/src/conversion.jl b/src/conversion.jl index 8ba0d87..90e0c78 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -3,12 +3,21 @@ abstract type OverflowPolicy end """ OVF -Sentinel overflow: out-of-range finite inputs go to `±Inf` (IEEE) or `NaN` (NanOnlyAllOnes). +Overflow conversion: out-of-range finite inputs go to `±Inf` (IEEE) or `NaN` (NanOnlyAllOnes). 
| Input Condition | `T` has Inf+NaN | `T` has NaN | `T` is finite | | ---------------------- | --------------- | ------------- | ------------- | | `isnan(x)` | NaN | NaN | Error | | `abs(x) > floatmax(T)` | ±Inf | NaN | Error | + +## Examples + +```jldoctest +julia> @microfloat OverflowingFloat8 exponent=4 significand=3 overflow=Microfloats.OVF + +julia> OverflowingFloat8(10000) +OverflowingFloat8(Inf) +``` """ abstract type OVF <: OverflowPolicy end @@ -20,7 +29,16 @@ Saturating conversion: out-of-range finite inputs clamp to `±floatmax(T)`. | Input Condition | `T` has Inf+NaN | `T` has NaN | `T` is finite | | ---------------------- | --------------- | ------------- | ------------- | | `isnan(x)` | NaN | NaN | Error | -| `abs(x) > floatmax(T)` | ±floatmax | ±floatmax | ±floatmax | +| `abs(x) > floatmax(T)` | ±floatmax | ±floatmax | ±floatmax | + +## Examples + +```jldoctest +julia> @microfloat SaturatingFloat8 exponent=4 significand=3 overflow=Microfloats.SAT + +julia> SaturatingFloat8(10000) +SaturatingFloat8(240.0) +``` """ abstract type SAT <: OverflowPolicy end From 796a0ea038513b08413f5cfc610cacae13301e40 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 24 Apr 2026 21:26:21 +0200 Subject: [PATCH 5/8] add BFloat16s to test env --- test/Project.toml | 1 + test/runtests.jl | 2 ++ 2 files changed, 3 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index ca75f0d..d05f338 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c" Microfloats = "31c70f10-a750-4521-b13c-797315ae2933" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/test/runtests.jl b/test/runtests.jl index 6186188..3cbcbdf 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,8 @@ using Microfloats: non_finite_behavior, hasinf, hasnan, inf, nan, using Test using Random +using BFloat16s: BFloat16 + # NaN-aware bit-identity for tests. 
Overrides `===` for Microfloat pairs so # round-trip tests can treat distinct NaN encodings as equivalent without # losing the bit-identity semantics for signed zeros. From a4649d5692826efb887e0bdf8e918e28766b6604 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 24 Apr 2026 21:53:21 +0200 Subject: [PATCH 6/8] add hasinf and hasnan docstrings --- src/Microfloat.jl | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/Microfloat.jl b/src/Microfloat.jl index 594a985..79334be 100644 --- a/src/Microfloat.jl +++ b/src/Microfloat.jl @@ -55,7 +55,18 @@ Return `IEEE`, `NanOnlyAllOnes`, or `FiniteOnly` based on the assigned trait. non_finite_behavior(::Type{T}) where T<:Microfloat = error("$T must define `Microfloats.non_finite_behavior(::Type{$T})`") +""" + hasinf(::Type{<:Microfloat}) -> Bool + +Return `true` if the type can represent Inf, otherwise `false`. +""" hasinf(::Type{T}) where T<:Microfloat = hasinf(non_finite_behavior(T)) + +""" + hasnan(::Type{<:Microfloat}) -> Bool + +Return `true` if the type can represent NaN, otherwise `false`. 
+""" hasnan(::Type{T}) where T<:Microfloat = hasnan(non_finite_behavior(T)) # ───────────────────────── Inf / NaN / floatmax / inf / nan ────────────────────────── @@ -64,7 +75,6 @@ Base.isinf(x::T) where T<:Microfloat = _isinf(non_finite_behavior(T), x) Base.isnan(x::T) where T<:Microfloat = _isnan(non_finite_behavior(T), x) inf(::Type{T}) where T<:Microfloat = _inf(non_finite_behavior(T), T) - nan(::Type{T}) where T<:Microfloat = _nan(non_finite_behavior(T), T) Base.floatmax(::Type{T}) where T<:Microfloat = _floatmax(non_finite_behavior(T), T) From c928c834e6e831b2a1d462aa6ead457cbe81c977 Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Fri, 24 Apr 2026 23:52:28 +0200 Subject: [PATCH 7/8] add tests --- src/Microfloat.jl | 12 ++++++++++-- src/conversion.jl | 5 ++++- src/macro.jl | 2 +- test/basic.jl | 26 ++++++++++++++++++++++++++ 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/src/Microfloat.jl b/src/Microfloat.jl index 79334be..8a45bcf 100644 --- a/src/Microfloat.jl +++ b/src/Microfloat.jl @@ -1,4 +1,4 @@ -@republic import Base: signbit, exponent +import Base: signbit, exponent """ Microfloat{S,E,M} <: AbstractFloat @@ -18,7 +18,15 @@ significand_mask(::Type{T}) where T<:Microfloat = UInt8(0x01 << significand_bits Base.reinterpret(::Type{Unsigned}, x::Microfloat) = reinterpret(UInt8, x) signbit(x::Microfloat) = sign_bits(typeof(x)) > 0 && !iszero(reinterpret(Unsigned, x) & sign_mask(typeof(x))) -exponent(x::Microfloat) = Int((reinterpret(Unsigned, x) & exponent_mask(typeof(x))) >> significand_bits(typeof(x))) +function exponent(x::T) where T<:Microfloat + (isnan(x) || isinf(x)) && throw(DomainError(x, "Cannot be NaN or Inf.")) + iszero(x) && throw(DomainError(x, "Cannot be ±0.0.")) + raw = reinterpret(Unsigned, x) + biased = Int((raw & exponent_mask(T)) >> significand_bits(T)) + biased == 0 || return biased - exponent_bias(T) + sig = raw & significand_mask(T) + return 8 - leading_zeros(sig) - exponent_bias(T) - significand_bits(T) +end function 
Base.show(io::IO, x::T) where T<:Microfloat show_typeinfo = get(IOContext(io), :typeinfo, nothing) != T diff --git a/src/conversion.jl b/src/conversion.jl index 90e0c78..94156f0 100644 --- a/src/conversion.jl +++ b/src/conversion.jl @@ -200,7 +200,10 @@ function _to_bfloat16(x::T) where {T<:Microfloat} end end -to_bfloat16(x::T) where T<:Microfloat = _to_bfloat16(x) +# `@microfloat` adds a new method to `to_bfloat16` +function to_bfloat16 end + +# user can add specialized conversions to `BFloat16` itself BFloat16(x::T) where T<:Microfloat = to_bfloat16(x) (::Type{T})(x::Microfloat) where {T<:AbstractFloat} = T(BFloat16(x)) diff --git a/src/macro.jl b/src/macro.jl index 6255e42..6631c62 100644 --- a/src/macro.jl +++ b/src/macro.jl @@ -44,7 +44,7 @@ macro microfloat(name, kwargs...) elseif k == :overflow overflow = v else - error("@microfloat: unknown keyword `$k`") + throw(ArgumentError("@microfloat: unknown keyword `$k`")) end end diff --git a/test/basic.jl b/test/basic.jl index 58396d7..1dddc4a 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -4,6 +4,7 @@ @test precision(T) == Microfloats.significand_bits(T) + 1 @test floatmax(T) > zero(T) @test isfinite(floatmax(T)) + @test Microfloats.overflow_policy(T) <: Union{Microfloats.OVF, Microfloats.SAT} @test signbit(zero(T)) == false @@ -34,6 +35,13 @@ @test T(Float32(x)) ≡ x end end + + @test widen(Float8_E4M3) == BFloat16 + @test string(Float8_E4M3(1.0)) == "Float8_E4M3(1.0)" +end + +@testset "@microfloat" begin + @test_throws "abc" @eval @microfloat Name abc=1 end @testset "NonFiniteBehavior trait" begin @@ -126,6 +134,24 @@ end end end +@testset "eps / round / issubnormal" begin + @test eps(Float8_E4M3(1.0)) == Float8_E4M3(0.125) + @test eps(Float8_E4M3(2.0)) == Float8_E4M3(0.25) + @test eps(Float8_E4M3(0.5)) == Float8_E4M3(0.0625) + + @test round(Float8_E4M3(1.5), RoundDown) === Float8_E4M3(1.0) + @test round(Float8_E4M3(1.5), RoundUp) === Float8_E4M3(2.0) + @test round(Float8_E4M3(2.5), RoundNearest) === 
Float8_E4M3(2.0) + @test round(Float8_E4M3(0.5), RoundNearest) === Float8_E4M3(0.0) + + @test !issubnormal(zero(Float8_E4M3)) + @test !issubnormal(one(Float8_E4M3)) + @test issubnormal(reinterpret(Float8_E4M3, 0x01)) + @test issubnormal(reinterpret(Float8_E4M3, 0x07)) + @test !issubnormal(reinterpret(Float8_E4M3, 0x08)) + @test issubnormal(-reinterpret(Float8_E4M3, 0x01)) +end + @testset "Cross-microfloat arithmetic is unsupported" begin a = Float8_E4M3FN(1.0) b = Float8_E5M2(1.0) From 92a02e8fd51b5778b7b4bd6105b972bb34e2211b Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sat, 25 Apr 2026 00:01:54 +0200 Subject: [PATCH 8/8] add more tests --- test/basic.jl | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/test/basic.jl b/test/basic.jl index 1dddc4a..da167a5 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -59,6 +59,15 @@ end @test_throws ErrorException Microfloats.non_finite_behavior(_BadFloat) end +@testset "overflow_policy trait" begin + @test Microfloats.overflow_policy(Float8_E4M3FN) === OVF + @test Microfloats.overflow_policy(_E4M3FN_SAT) === SAT + + # Forgetting `overflow_policy` on a custom type errors loudly. 
+ primitive type _BadFloatOvf <: Microfloats.Microfloat{1,2,1} 8 end + @test_throws ErrorException Microfloats.overflow_policy(_BadFloatOvf) +end + @testset "IEEE types: Inf and NaN encodings" begin @testset "$T" for T in (Float8_E3M4, Float8_E5M2) @test isinf(inf(T)) @@ -152,6 +161,45 @@ end @test issubnormal(-reinterpret(Float8_E4M3, 0x01)) end +@testset "exponent" begin + # normals: matches Base.exponent on the round-tripped value + @testset "$T normals" for T in (Float8_E3M4, Float8_E4M3, Float8_E5M2, Float8_E4M3FN) + for u in 0x01:0xff + x = reinterpret(T, u) + (isnan(x) || isinf(x) || iszero(x) || issubnormal(x)) && continue + @test exponent(x) == exponent(Float64(x)) + end + end + + # subnormals: leading-1 position determines the unbiased exponent + @test exponent(reinterpret(Float8_E4M3, 0x01)) == -9 # 0.001 × 2^-6 + @test exponent(reinterpret(Float8_E4M3, 0x02)) == -8 # 0.010 × 2^-6 + @test exponent(reinterpret(Float8_E4M3, 0x04)) == -7 # 0.100 × 2^-6 + @test exponent(reinterpret(Float8_E5M2, 0x01)) == -16 # 0.01 × 2^-14 + @test exponent(reinterpret(Float8_E5M2, 0x02)) == -15 # 0.10 × 2^-14 + + # DomainError for zero / Inf / NaN — matches Base.exponent semantics + @test_throws DomainError exponent(zero(Float8_E4M3)) + @test_throws DomainError exponent(-zero(Float8_E4M3)) + @test_throws DomainError exponent(inf(Float8_E5M2)) + @test_throws DomainError exponent(-inf(Float8_E5M2)) + @test_throws DomainError exponent(nan(Float8_E4M3)) +end + +@testset "sign_bits / exponent_bits / significand_bits / bitwidth (Base floats)" begin + using Microfloats: sign_bits, exponent_bits, significand_bits, bitwidth + + @test (sign_bits(Float64), exponent_bits(Float64), significand_bits(Float64)) == (1, 11, 52) + @test (sign_bits(Float32), exponent_bits(Float32), significand_bits(Float32)) == (1, 8, 23) + @test (sign_bits(Float16), exponent_bits(Float16), significand_bits(Float16)) == (1, 5, 10) + @test (sign_bits(BFloat16), exponent_bits(BFloat16), 
significand_bits(BFloat16)) == (1, 8, 7) + + @test bitwidth(Float64) == 64 + @test bitwidth(Float32) == 32 + @test bitwidth(Float16) == 16 + @test bitwidth(BFloat16) == 16 +end + @testset "Cross-microfloat arithmetic is unsupported" begin a = Float8_E4M3FN(1.0) b = Float8_E5M2(1.0)