multifloat pow#2475

Open
wsmoses wants to merge 14 commits into main from mfpow

@wsmoses
Member

@wsmoses wsmoses commented May 11, 2026

a quick vibe code, cc @dkytezab

doesn't seem like it did a good job tbh

@codecov

codecov Bot commented May 11, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 25.60%. Comparing base (003d2cd) to head (9737b22).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2475      +/-   ##
==========================================
+ Coverage   25.47%   25.60%   +0.12%     
==========================================
  Files         220      220              
  Lines       44609    44601       -8     
==========================================
+ Hits        11366    11421      +55     
+ Misses      33243    33180      -63     


@dkytezab
Collaborator

@wsmoses could this get a speedy approve?

Member Author

@wsmoses wsmoses left a comment

so I'm super confused: why does anything other than the new op need to be changed?

@dkytezab
Collaborator

the change to the multifloat pass also refactored the log op conversion; the other two tests were broken by changes to IsResultOrOperandTypeLegal, i believe

@wsmoses
Member Author

wsmoses commented May 14, 2026

yeah, can we cut this down so it only does pow? [the original code here was just an initial vibe code which went a bit wild]

@dkytezab
Collaborator

dkytezab commented May 14, 2026

q: why not just write pow(x, y) as exp(y * log(x)), with case handling, seeing as we already have exponential, log, and multiply? it's also easier down the line, as we don't have to write different versions of pow for more limbs

@wsmoses
Member Author

wsmoses commented May 15, 2026

well, the reason is that doing so might introduce more numerical error
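
A first-order estimate that is consistent with this concern (it is not stated in the thread; $u$ denotes the unit roundoff of the working precision, an assumed symbol). If $t = y \log x$ is computed with a small relative error $\delta$, $|\delta| \lesssim u$, then

$$
\exp\bigl(t(1+\delta)\bigr) = \exp(t)\,\exp(t\delta) \approx \exp(t)\,(1 + t\delta),
$$

so the relative error of the result is roughly $|y \log x|\,u$. It grows with the magnitude of the exponent argument even when exp and log are themselves accurate to about $u$.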

Comment thread src/enzyme_ad/jax/Passes/MultiFloatConversion.cpp
Comment thread src/enzyme_ad/jax/Passes/MultiFloatConversion.cpp
// existing log/mul/exp patterns handle it without TUPLE-mode ordering
// issues.
{
IRRewriter rewriter(context);
Member Author

we really shouldn't do this; it's likely going to be more numerically unstable. At minimum, put it within a pattern rather than here

Comment thread src/enzyme_ad/jax/Passes/MultiFloatConversion.cpp Outdated
return rewriter.create<stablehlo::ConstantOp>(loc, tensorTy, attr);
};

Value zero = makeConst(0.0);
Member Author

cc @vimarsh6739 @sbrantq

how does this algorithmically compare to the multifloat.jl case

Member

@sbrantq sbrantq May 16, 2026

`constexpr int64_t kMaxUnrollExponent = 32;`: understandable that we would need a cap for IR size, whereas multifloat.jl's uncapped runtime loop is fine. Not sure if it would be a bit too small for certain kinds of applications?

nvm i think it's fine 👍
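
For a sense of how the unrolled size scales with the exponent, here is a minimal sketch that counts the multiplications a square-and-multiply loop emits. It assumes the unrolled loop has the usual binary-exponentiation shape (only its header is visible in this thread); `MulCount` and `countUnrolledMuls` are hypothetical names, not part of the PR.

```cpp
#include <cassert>
#include <cstdint>

// Number of multiplications emitted when x^k is fully unrolled.
struct MulCount {
  unsigned squarings; // base = base * base
  unsigned products;  // result = result * base
};

constexpr MulCount countUnrolledMuls(uint64_t k) {
  MulCount c{0, 0};
  bool haveResult = false; // the first set bit only copies base into result
  while (k > 0) {
    if (k & 1) {
      if (haveResult)
        ++c.products;
      haveResult = true;
    }
    k >>= 1;
    if (k > 0)
      ++c.squarings;
  }
  return c;
}

// Worst case for exponents up to 32 is k = 31: 4 squarings + 4 products.
static_assert(countUnrolledMuls(31).squarings == 4, "31 = 0b11111");
static_assert(countUnrolledMuls(31).products == 4, "31 = 0b11111");
static_assert(countUnrolledMuls(32).squarings == 5, "32 = 0b100000");
static_assert(countUnrolledMuls(32).products == 0, "32 = 0b100000");
```

Each counted multiplication is a multifloat multiply, which itself expands to several ops per limb, so the actual IR growth is a multiple of these counts.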

@wsmoses
Member Author

wsmoses commented May 19, 2026

@OscarSmith

Comment on lines +4223 to +4224
// * Constant 0.5 → stablehlo.sqrt
// (multifloat sqrt is much more accurate than exp(0.5*log(x))).

you likely want to handle the case where y*2 is an int, i.e. x^3.5 => x^3*sqrt(x)
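
A minimal scalar sketch of that fast path, using plain `double` instead of multifloat values. The helper names (`twiceAsInteger`, `powi`, `powHalfInteger`) and the `maxAbs` cap are hypothetical and only illustrate the decomposition x^(m/2) = x^(m div 2) * sqrt(x)^(m mod 2).

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <optional>

// Returns 2*y as an integer when y is an exact multiple of 0.5 within the cap.
std::optional<int64_t> twiceAsInteger(double y, int64_t maxAbs = 64) {
  double t = 2.0 * y; // scaling by 2 is exact unless it overflows
  if (!std::isfinite(t) || std::fabs(t) > static_cast<double>(maxAbs) ||
      t != std::trunc(t))
    return std::nullopt;
  return static_cast<int64_t>(t);
}

// x^k by square-and-multiply (k == 0 yields 1).
double powi(double x, uint64_t k) {
  double base = x, result = 1.0;
  while (k > 0) {
    if (k & 1) result *= base;
    k >>= 1;
    if (k > 0) base *= base;
  }
  return result;
}

// x^y for half-integer y; std::nullopt when the fast path does not apply.
std::optional<double> powHalfInteger(double x, double y) {
  std::optional<int64_t> twoY = twiceAsInteger(y);
  if (!twoY) return std::nullopt;
  int64_t m = *twoY;
  uint64_t a = static_cast<uint64_t>(m < 0 ? -m : m);
  double r = powi(x, a / 2);        // integer part of |y|
  if (a % 2) r *= std::sqrt(x);     // the leftover factor x^0.5
  return m < 0 ? 1.0 / r : r;       // negative exponent: reciprocal
}
```

For example, `powHalfInteger(4.0, 3.5)` evaluates 4^3 * sqrt(4), while `powHalfInteger(2.0, 0.3)` returns `std::nullopt` so the caller can fall back to another lowering.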

// (multifloat sqrt is much more accurate than exp(0.5*log(x))).
//
// For exponents that don't match a fast path (non-constant y, or constants
// like 2/7 that have no closed form in elementary ops) we fall back to the

For rational y=a/b, you might want to use x^(a/b) => bth_root(x^a, b), where you compute the bth root using Newton iteration with an initial value provided by the bth root of the highest limb.
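
A minimal sketch of that scheme, under a loud assumption: a `float` seed refined in `double` stands in for "seed from the highest limb, refine at full multifloat precision". The names `powi`, `nthRoot`, `powRational`, and the default of three Newton steps are hypothetical, and the sketch ignores overflow of x^a and inputs outside `float` range.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// x^k by square-and-multiply (k == 0 yields 1).
double powi(double x, uint64_t k) {
  double base = x, result = 1.0;
  while (k > 0) {
    if (k & 1) result *= base;
    k >>= 1;
    if (k > 0) base *= base;
  }
  return result;
}

// b-th root of x > 0. The low-precision seed plays the role of the b-th root
// of the highest limb; Newton steps on f(r) = r^b - x refine it.
double nthRoot(double x, uint64_t b, int newtonSteps = 3) {
  assert(x > 0.0 && b >= 1);
  double r = std::pow(static_cast<float>(x), 1.0f / static_cast<float>(b));
  for (int i = 0; i < newtonSteps; ++i) {
    // r <- r - (r^b - x) / (b * r^(b-1)), rearranged to avoid cancellation
    r = ((b - 1) * r + x / powi(r, b - 1)) / b;
  }
  return r;
}

// x^(a/b) computed as the b-th root of x^a.
double powRational(double x, uint64_t a, uint64_t b) {
  return nthRoot(powi(x, a), b);
}
```

Because Newton's iteration roughly doubles the number of correct digits per step, the number of steps needed depends on the ratio between the seed precision and the target precision, so a multifloat version with more limbs would need more steps than this two-level analogy.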

//
// For exponents that don't match a fast path (non-constant y, or constants
// like 2/7 that have no closed form in elementary ops) we fall back to the
// exp(y * log(|x|)) lowering. Restricted to ops whose lhs element type is

will Enzyme/LLVM be smart enough to constant fold log(x) when x is constant? If not, you definitely want that as a fast path.

uint64_t k = static_cast<uint64_t>(std::abs(*n));
Value base = x;
Value result;
while (k > 0) {

I think you want to do all the math in this loop in higher precision (because you will exponentially accumulate error)
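
A minimal sketch of the idea, using `float` as the working type and `double` as the "higher precision" (an analogy only; in the pass this would presumably mean extra limbs for the intermediates). `powiNarrow` and `powiWide` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Square-and-multiply entirely in float: every product rounds to 24 bits,
// and each squaring doubles the relative error already present in base.
float powiNarrow(float x, uint64_t k) {
  float base = x, result = 1.0f;
  while (k > 0) {
    if (k & 1) result *= base;
    k >>= 1;
    if (k > 0) base *= base;
  }
  return result;
}

// Same loop, but intermediates are kept in double and the result is rounded
// to float exactly once at the end.
float powiWide(float x, uint64_t k) {
  double base = x, result = 1.0;
  while (k > 0) {
    if (k & 1) result *= base;
    k >>= 1;
    if (k > 0) base *= base;
  }
  return static_cast<float>(result);
}
```

Both functions agree whenever every intermediate product is exactly representable (for example small integer powers of small integers); they can differ in the last bits otherwise, with the wide version staying within about one float rounding of the exact result.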

// like 2/7 that have no closed form in elementary ops) we fall back to the
// exp(y * log(|x|)) lowering. Restricted to ops whose lhs element type is
// the multifloat source type, and skipped inside enzyme.no_multifloat funcs.
struct PowOpLowerPattern : public OpRewritePattern<stablehlo::PowOp> {

I think you're missing a required check for 1^x (to preserve behavior like 1^Inf). Look at Julia's pow (in Base/math.jl). It has a pretty clear list of the required checks in the right order if you care about ieee semantics.
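
A minimal scalar sketch of why the check matters, assuming the target semantics match C's `pow`, where `pow(1, y)` and `pow(x, 0)` return 1 even for infinite or NaN arguments. `powNaive` and `powGuarded` are hypothetical names, and the guard shows only these two cases, not the full ordered list from Julia's implementation.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// The fallback lowering from the diff comment, on plain doubles.
double powNaive(double x, double y) {
  return std::exp(y * std::log(std::fabs(x)));
}

// Same lowering with the two "result is exactly 1" cases checked first.
double powGuarded(double x, double y) {
  if (y == 0.0 || x == 1.0)
    return 1.0; // x^0 == 1 and 1^y == 1, including NaN and Inf operands
  return powNaive(x, y);
}
```

Without the guard, `powNaive(1.0, inf)` evaluates `exp(inf * log(1)) = exp(inf * 0) = exp(NaN)`, which is NaN rather than 1.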

5 participants