Skip to content

[TorchToTosa] Add lowering for aten.atan#4517

Open
ahagro01 wants to merge 4 commits intollvm:mainfrom
ahagro01:add-atan-lowering
Open

[TorchToTosa] Add lowering for aten.atan#4517
ahagro01 wants to merge 4 commits intollvm:mainfrom
ahagro01:add-atan-lowering

Conversation

@ahagro01
Copy link
Copy Markdown

This PR adds TorchToTosa lowering support for aten.atan.

Changes:

  • add lowering for aten.atan in TorchToTosa.cpp
  • add tests in test/Conversion/TorchToTosa/basic.mlir

This enables conversion of models using aten.atan through the TorchToTosa path.

Testing:

  • added regression tests in test/Conversion/TorchToTosa/basic.mlir

Copy link
Copy Markdown
Member

@Lallapallooza Lallapallooza left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, thanks for the patch, few comments

Comment on lines +9770 to +9783
// Offline-generated piecewise minimax fit for the reduced domain [0, 1].
// The split points were grid-searched, each interval was fit with an odd
// degree-9 polynomial x * Q(x^2), and the final float32-rounded tables were
// selected to minimize the max absolute error. The resulting max absolute
// error over [0, 1] is about 7.42e-8.
static constexpr float kAtanPieceSplit0 = 0.546f;
static constexpr float kAtanPieceSplit1 = 0.792f;
static constexpr float kHalfPi = 1.57079632679f;
static constexpr std::array<float, 5> kAtanLowCoefficients = {
0.99999887f, -0.33325633f, 0.19846700f, -0.13003899f, 0.06094434f};
static constexpr std::array<float, 5> kAtanMidCoefficients = {
0.99966830f, -0.32924002f, 0.17935193f, -0.08747230f, 0.02348170f};
static constexpr std::array<float, 5> kAtanHighCoefficients = {
0.99757016f, -0.31611255f, 0.14805676f, -0.05377132f, 0.00965516f};
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we include the source for these numbers in the tree or attach it to the PR? Right now the code gives the result, but not the path that led to it.

Comment on lines +9919 to +9926
// Step 5. Restore the original sign.
Value isNonNegative = tosa::GreaterEqualOp::create(rewriter, op->getLoc(),
boolType, self, zero);
Value negMagnitude =
tosa::SubOp::create(rewriter, op->getLoc(), resultType, zero, magnitude);

rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, resultType, isNonNegative,
magnitude, negMagnitude);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we preserve zero lanes explicitly here?

The current self >= 0 check runs after abs(self), so atan(-0.0) ends up as +0.0. Passing self through unchanged on zero lanes, and only choosing between magnitude and -magnitude for nonzero lanes, would make the behavior match PyTorch.

Comment on lines +9770 to +9783
// Offline-generated piecewise minimax fit for the reduced domain [0, 1].
// The split points were grid-searched, each interval was fit with an odd
// degree-9 polynomial x * Q(x^2), and the final float32-rounded tables were
// selected to minimize the max absolute error. The resulting max absolute
// error over [0, 1] is about 7.42e-8.
static constexpr float kAtanPieceSplit0 = 0.546f;
static constexpr float kAtanPieceSplit1 = 0.792f;
static constexpr float kHalfPi = 1.57079632679f;
static constexpr std::array<float, 5> kAtanLowCoefficients = {
0.99999887f, -0.33325633f, 0.19846700f, -0.13003899f, 0.06094434f};
static constexpr std::array<float, 5> kAtanMidCoefficients = {
0.99966830f, -0.32924002f, 0.17935193f, -0.08747230f, 0.02348170f};
static constexpr std::array<float, 5> kAtanHighCoefficients = {
0.99757016f, -0.31611255f, 0.14805676f, -0.05377132f, 0.00965516f};
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we clarify whether this approximation is meant to be f32 only for now, or whether you have checked the lower-precision cases too?


// Materialize a scalar float constant that matches `like` and can be used in
// broadcastable elementwise TOSA ops with it.
auto createConst = [&](Value like, float value) -> FailureOr<Value> {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we either move this into shared utility code or inline the few uses?

// error over [0, 1] is about 7.42e-8.
static constexpr float kAtanPieceSplit0 = 0.546f;
static constexpr float kAtanPieceSplit1 = 0.792f;
static constexpr float kHalfPi = 1.57079632679f;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we compute it here? e.g.

constexpr float kHalfPi = llvm::numbers::pi / 2

"ElementwiseAtan2TensorIntModule_basic",
"ElementwiseAtan2TensorIntStaticModule_basic",
"ElementwiseAtanTensorFloatModule_basic",
"ElementwiseAtanTensorIntModule_basic",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we add a few deterministic atan e2e cases with special values? A small hand-picked set covering negative values, zero, +0.0, -0.0, near both split point.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants