Conversation
Codecov Report
✅ All modified and coverable lines are covered by tests.

@@            Coverage Diff             @@
##             main    #2475      +/-   ##
==========================================
+ Coverage   25.47%   25.60%   +0.12%
==========================================
  Files         220      220
  Lines       44609    44601       -8
==========================================
+ Hits        11366    11421      +55
+ Misses      33243    33180      -63
|
@wsmoses could this get a speedy approval?
wsmoses
left a comment
so I'm super confused: why does anything other than the new op need to be changed?
|
the change to multifloat pass also refactored log op conversion, other two tests were broken by changes to |
|
yeah, can we extract this to only do pow? [the original code here was just an initial vibe code which went a bit wild]
|
q, why not just write |
|
well, the reason is that doing so might introduce more errors
// existing log/mul/exp patterns handle it without TUPLE-mode ordering
// issues.
{
  IRRewriter rewriter(context);
we really shouldn't do this, it's likely going to be more numerically unstable. At minimum, put it within a pattern rather than here.
  return rewriter.create<stablehlo::ConstantOp>(loc, tensorTy, attr);
};

Value zero = makeConst(0.0);
how does this compare algorithmically to the multifloat.jl case?
constexpr int64_t kMaxUnrollExponent = 32; it's understandable that we need a cap to bound IR size, whereas multifloat.jl's runtime loop is fine uncapped. Not sure whether 32 would be a bit too small for certain kinds of applications?
nvm i think it's fine 👍
// * Constant 0.5 → stablehlo.sqrt
//   (multifloat sqrt is much more accurate than exp(0.5*log(x))).
you likely want to handle the case where y*2 is an int, i.e. x^3.5 => x^3*sqrt(x)
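A minimal sketch of that half-integer fast path, modelled on plain doubles rather than the pass's multifloat IR. The helper names `isHalfInteger` and `powHalfInteger` and the magnitude bound of 64 are assumptions made for illustration, not code from this PR, and the sketch assumes x >= 0 whenever the sqrt factor is needed.

```cpp
#include <cmath>
#include <cstdint>

// Hypothetical helper: true when 2*y is an integer of moderate size.
// The bound 64 is an assumption (twice an unroll cap of 32).
bool isHalfInteger(double y, int64_t &twiceY) {
  double t = 2.0 * y;
  if (!std::isfinite(t) || std::fabs(t) > 64.0 || t != std::nearbyint(t))
    return false;
  twiceY = static_cast<int64_t>(t);
  return true;
}

// x^(m/2) = x^(m div 2) * sqrt(x)^(m mod 2), with a reciprocal for m < 0.
// Assumes x >= 0 when m is odd, since sqrt of a negative value is NaN.
double powHalfInteger(double x, int64_t twiceY) {
  uint64_t m = twiceY < 0 ? 0 - static_cast<uint64_t>(twiceY)
                          : static_cast<uint64_t>(twiceY);
  uint64_t k = m / 2;  // integer part of |y|
  double result = 1.0;
  double base = x;
  while (k > 0) {      // square-and-multiply for the integer part
    if (k & 1) result *= base;
    base *= base;
    k >>= 1;
  }
  if (m & 1) result *= std::sqrt(x);  // the extra half power
  return twiceY < 0 ? 1.0 / result : result;
}
```

With this shape, x^3.5 costs two multiplies plus one sqrt and never touches exp/log.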
//   (multifloat sqrt is much more accurate than exp(0.5*log(x))).
//
// For exponents that don't match a fast path (non-constant y, or constants
// like 2/7 that have no closed form in elementary ops) we fall back to the
For rational y=a/b, you might want to use x^(a/b) => bth_root(x^a, b), where you compute the bth root using Newton iteration with an initial value provided by the bth root of the highest limb.
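A rough sketch of the rational-exponent idea, again on plain scalars. Here a `float` root stands in for "the bth root of the highest limb" and `double` stands in for the full multifloat value; that substitution, the names `ipow`, `nthRootNewton` and `powRational`, and the fixed iteration count are all assumptions for illustration. It assumes v > 0 and b >= 1.

```cpp
#include <cmath>

// Integer power by repeated squaring.
double ipow(double x, unsigned n) {
  double result = 1.0;
  while (n > 0) {
    if (n & 1) result *= x;
    x *= x;
    n >>= 1;
  }
  return result;
}

// Newton iteration for r^b = v:  r <- ((b-1)*r + v / r^(b-1)) / b.
// The seed is a low-precision root (float here, standing in for the root of
// the highest limb); each step roughly doubles the number of correct digits.
double nthRootNewton(double v, unsigned b, int iters = 3) {
  float seed = std::pow(static_cast<float>(v), 1.0f / static_cast<float>(b));
  double r = seed;
  for (int i = 0; i < iters; ++i) {
    double rPow = ipow(r, b - 1);
    r = ((b - 1) * r + v / rPow) / b;
  }
  return r;
}

// x^(a/b) computed as the bth root of x^a.
double powRational(double x, unsigned a, unsigned b) {
  return nthRootNewton(ipow(x, a), b);
}
```

One caveat with this ordering: x^a can overflow or underflow before the root is taken, so a real lowering would probably need a range check or would take the root first.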
//
// For exponents that don't match a fast path (non-constant y, or constants
// like 2/7 that have no closed form in elementary ops) we fall back to the
// exp(y * log(|x|)) lowering. Restricted to ops whose lhs element type is
will Enzyme/LLVM be smart enough to constant fold log(x) when x is constant? If not, you definitely want that as a fast path.
uint64_t k = static_cast<uint64_t>(std::abs(*n));
Value base = x;
Value result;
while (k > 0) {
I think you want to do all the math in this loop in higher precision (because you will exponentially accumulate error)
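To make the suggestion concrete, here is a small scalar model of the same square-and-multiply loop with the intermediate math done in a wider type. `float` plays the role of the result precision and `double` the wider working precision; both function names are hypothetical and neither type is the pass's actual multifloat representation.

```cpp
#include <cstdint>

// Square-and-multiply with all intermediates held in double and a single
// rounding back to float at the end.
float powBySquaringWide(float x, int64_t n) {
  // Magnitude of n, computed without overflowing on INT64_MIN.
  uint64_t k = n < 0 ? 0 - static_cast<uint64_t>(n) : static_cast<uint64_t>(n);
  double base = x;  // widened once, up front
  double result = 1.0;
  while (k > 0) {
    if (k & 1) result *= base;
    base *= base;   // each squaring doubles the relative error carried in base
    k >>= 1;
  }
  if (n < 0) result = 1.0 / result;
  return static_cast<float>(result);  // the only rounding to the narrow type
}

// Same loop with every intermediate rounded to float, for comparison.
float powBySquaringNarrow(float x, int64_t n) {
  uint64_t k = n < 0 ? 0 - static_cast<uint64_t>(n) : static_cast<uint64_t>(n);
  float base = x;
  float result = 1.0f;
  while (k > 0) {
    if (k & 1) result *= base;
    base *= base;
    k >>= 1;
  }
  return n < 0 ? 1.0f / result : result;
}
```

The narrow version rounds after every multiply, and because errors in `base` are squared along with it, the wide version is the safer default when the extra precision is affordable.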
// like 2/7 that have no closed form in elementary ops) we fall back to the
// exp(y * log(|x|)) lowering. Restricted to ops whose lhs element type is
// the multifloat source type, and skipped inside enzyme.no_multifloat funcs.
struct PowOpLowerPattern : public OpRewritePattern<stablehlo::PowOp> {
I think you're missing a required check for 1^x (to preserve behavior like 1^Inf). Look at Julia's pow (in Base/math.jl). It has a pretty clear list of the required checks in the right order if you care about ieee semantics.
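A small scalar illustration of why the 1^x check matters for an exp(y * log(|x|)) lowering. The names `naivePow` and `guardedPow` are hypothetical, and the two guards shown are only the ones IEEE 754 and C Annex F define for x == 1 and y == 0; this is not the full ordered list from Julia's pow.

```cpp
#include <cmath>

// The fallback lowering, written on plain doubles.
double naivePow(double x, double y) {
  return std::exp(y * std::log(std::fabs(x)));
}

// The same lowering with two IEEE special cases checked first.
double guardedPow(double x, double y) {
  if (y == 0.0) return 1.0;  // x^0 is 1 for every x, including NaN
  if (x == 1.0) return 1.0;  // 1^y is 1 for every y, including Inf and NaN
  return naivePow(x, y);
}
```

For x = 1 and y = Inf the unguarded path computes log(1) = 0, then Inf * 0 = NaN, then exp(NaN) = NaN, whereas `std::pow(1.0, INFINITY)` returns 1 on implementations that follow Annex F.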
a quick vibe code, cc @dkytezab
doesn't seem like it did a good job tbh