[TorchToTosa] Add lowering for aten.atan #4517
Lallapallooza
left a comment
Hi, thanks for the patch. A few comments:
| // Offline-generated piecewise minimax fit for the reduced domain [0, 1]. | ||
| // The split points were grid-searched, each interval was fit with an odd | ||
| // degree-9 polynomial x * Q(x^2), and the final float32-rounded tables were | ||
| // selected to minimize the max absolute error. The resulting max absolute | ||
| // error over [0, 1] is about 7.42e-8. | ||
| static constexpr float kAtanPieceSplit0 = 0.546f; | ||
| static constexpr float kAtanPieceSplit1 = 0.792f; | ||
| static constexpr float kHalfPi = 1.57079632679f; | ||
| static constexpr std::array<float, 5> kAtanLowCoefficients = { | ||
| 0.99999887f, -0.33325633f, 0.19846700f, -0.13003899f, 0.06094434f}; | ||
| static constexpr std::array<float, 5> kAtanMidCoefficients = { | ||
| 0.99966830f, -0.32924002f, 0.17935193f, -0.08747230f, 0.02348170f}; | ||
| static constexpr std::array<float, 5> kAtanHighCoefficients = { | ||
| 0.99757016f, -0.31611255f, 0.14805676f, -0.05377132f, 0.00965516f}; |
Could we include the source for these numbers in the tree or attach it to the PR? Right now the code gives the result, but not the path that led to it.
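Short of checking in the generator itself, one way to make these numbers auditable is a small standalone program that re-measures the stated error bound. The sketch below is not the author's fitting script. It assumes the tables hold the coefficients of Q in ascending powers of x^2, that evaluation uses Horner's scheme in float32, and that an input equal to a split point falls into the upper piece; none of this is confirmed by the diff shown here. The helper names `evalOddPoly`, `atanReduced`, and `maxAbsErrorOnUnitInterval` are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cmath>

// Constants copied from the patch under review.
constexpr float kAtanPieceSplit0 = 0.546f;
constexpr float kAtanPieceSplit1 = 0.792f;
constexpr std::array<float, 5> kAtanLowCoefficients = {
    0.99999887f, -0.33325633f, 0.19846700f, -0.13003899f, 0.06094434f};
constexpr std::array<float, 5> kAtanMidCoefficients = {
    0.99966830f, -0.32924002f, 0.17935193f, -0.08747230f, 0.02348170f};
constexpr std::array<float, 5> kAtanHighCoefficients = {
    0.99757016f, -0.31611255f, 0.14805676f, -0.05377132f, 0.00965516f};

// Hypothetical helper: evaluates x * Q(x^2) with Horner's scheme in float32.
// Assumption: c[0] is the constant term of Q, c[4] the highest-order term.
inline float evalOddPoly(const std::array<float, 5> &c, float x) {
  float x2 = x * x;
  float q = c[4];
  for (int i = 3; i >= 0; --i)
    q = q * x2 + c[i];
  return x * q;
}

// Hypothetical helper: piecewise approximation on the reduced domain [0, 1].
// Assumption: an input equal to a split point uses the upper piece.
inline float atanReduced(float x) {
  assert(x >= 0.0f && x <= 1.0f);
  if (x < kAtanPieceSplit0)
    return evalOddPoly(kAtanLowCoefficients, x);
  if (x < kAtanPieceSplit1)
    return evalOddPoly(kAtanMidCoefficients, x);
  return evalOddPoly(kAtanHighCoefficients, x);
}

// Samples [0, 1] uniformly and returns the largest |approx - std::atan|,
// using double-precision std::atan as the reference.
inline double maxAbsErrorOnUnitInterval(int samples) {
  double worst = 0.0;
  for (int i = 0; i <= samples; ++i) {
    float x = static_cast<float>(i) / static_cast<float>(samples);
    double err = std::fabs(static_cast<double>(atanReduced(x)) -
                           std::atan(static_cast<double>(x)));
    if (err > worst)
      worst = err;
  }
  return worst;
}
```

Comparing the value returned by `maxAbsErrorOnUnitInterval(1 << 20)` against the 7.42e-8 figure in the comment would also show whether that bound refers to exact-arithmetic evaluation or to float32 evaluation.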
| // Step 5. Restore the original sign. | ||
| Value isNonNegative = tosa::GreaterEqualOp::create(rewriter, op->getLoc(), | ||
| boolType, self, zero); | ||
| Value negMagnitude = | ||
| tosa::SubOp::create(rewriter, op->getLoc(), resultType, zero, magnitude); | ||
| rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, resultType, isNonNegative, | ||
| magnitude, negMagnitude); |
Could we preserve zero lanes explicitly here?
The current self >= 0 check is true for -0.0, and magnitude is computed from abs(self), so atan(-0.0) ends up as +0.0. Passing self through unchanged on zero lanes, and only choosing between magnitude and -magnitude on nonzero lanes, would make the behavior match PyTorch.
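To make the signed-zero point concrete, here is a scalar model of the two select strategies. It is only a sketch in plain C++, not TOSA IR: `restoreSignCurrent` mirrors the select(self >= 0, magnitude, 0 - magnitude) pattern from the snippet above, and `restoreSignZeroPreserving` is a hypothetical name for the suggested variant.

```cpp
#include <cassert>
#include <cmath>

// Scalar model of the sign-restore step as written in the patch:
// select(self >= 0, magnitude, 0 - magnitude).
// For self == -0.0f the comparison is true, so the sign of zero is lost.
inline float restoreSignCurrent(float self, float magnitude) {
  return self >= 0.0f ? magnitude : 0.0f - magnitude;
}

// Scalar model of the suggested variant (hypothetical name): zero lanes
// return `self` unchanged, so -0.0f stays -0.0f; nonzero lanes pick between
// magnitude and -magnitude as before.
inline float restoreSignZeroPreserving(float self, float magnitude) {
  assert(!(magnitude < 0.0f)); // magnitude comes from atan(|self|)
  if (self == 0.0f)
    return self;
  return self > 0.0f ? magnitude : -magnitude;
}
```

With `self = -0.0f` and `magnitude = 0.0f`, `std::signbit` is false for the first function's result and true for the second's, which is the difference an e2e test on -0.0 would observe.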
| // Offline-generated piecewise minimax fit for the reduced domain [0, 1]. | ||
| // The split points were grid-searched, each interval was fit with an odd | ||
| // degree-9 polynomial x * Q(x^2), and the final float32-rounded tables were | ||
| // selected to minimize the max absolute error. The resulting max absolute | ||
| // error over [0, 1] is about 7.42e-8. | ||
| static constexpr float kAtanPieceSplit0 = 0.546f; | ||
| static constexpr float kAtanPieceSplit1 = 0.792f; | ||
| static constexpr float kHalfPi = 1.57079632679f; | ||
| static constexpr std::array<float, 5> kAtanLowCoefficients = { | ||
| 0.99999887f, -0.33325633f, 0.19846700f, -0.13003899f, 0.06094434f}; | ||
| static constexpr std::array<float, 5> kAtanMidCoefficients = { | ||
| 0.99966830f, -0.32924002f, 0.17935193f, -0.08747230f, 0.02348170f}; | ||
| static constexpr std::array<float, 5> kAtanHighCoefficients = { | ||
| 0.99757016f, -0.31611255f, 0.14805676f, -0.05377132f, 0.00965516f}; |
Could we clarify whether this approximation is meant to be f32 only for now, or whether you have checked the lower-precision cases too?
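If lower-precision element types are in scope, one thing worth checking is what happens to the tables once the coefficients are stored at that precision. The sketch below emulates bfloat16 rounding in plain C++ (the helper name `roundToBf16` is hypothetical, and NaN and overflow handling are left out). Whether the lowering computes in the element type or upcasts to f32 internally is not visible in this diff, so this is an illustration under that open assumption.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: rounds a finite float32 to the nearest bfloat16 value
// (round-to-nearest, ties-to-even) and returns it widened back to float32.
// bfloat16 keeps the upper 16 bits of the float32 encoding.
inline float roundToBf16(float v) {
  std::uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  std::uint32_t lsb = (bits >> 16) & 1u; // lowest bit that survives
  bits += 0x7FFFu + lsb;                 // round half to even
  bits &= 0xFFFF0000u;                   // drop the discarded mantissa bits
  std::memcpy(&v, &bits, sizeof(bits));
  return v;
}
```

For example, the leading low-piece coefficient 0.99999887f rounds to exactly 1.0f under this emulation, so the float32-tuned offsets in the tables would not survive a straight conversion to bf16.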
| // Materialize a scalar float constant that matches `like` and can be used in | ||
| // broadcastable elementwise TOSA ops with it. | ||
| auto createConst = [&](Value like, float value) -> FailureOr<Value> { |
Could we either move this into shared utility code or inline the few uses?
| // error over [0, 1] is about 7.42e-8. | ||
| static constexpr float kAtanPieceSplit0 = 0.546f; | ||
| static constexpr float kAtanPieceSplit1 = 0.792f; | ||
| static constexpr float kHalfPi = 1.57079632679f; |
Could we compute it here? For example:
constexpr float kHalfPi = llvm::numbers::pi / 2;
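A quick standalone check suggests the computed form yields the same float32 value as the literal in the patch. In this sketch `kPi` is a stand-in for `llvm::numbers::pi` so that it compiles without LLVM headers; the other names are hypothetical as well.

```cpp
// Stand-in for llvm::numbers::pi (a double) so the snippet needs no LLVM
// headers.
constexpr double kPi = 3.14159265358979323846;

// Computed form, with the narrowing to float made explicit.
constexpr float kHalfPiComputed = static_cast<float>(kPi / 2.0);

// Literal form as it appears in the patch.
constexpr float kHalfPiLiteral = 1.57079632679f;

// Both round to the same float32 value, so the change would be
// behavior-preserving.
static_assert(kHalfPiComputed == kHalfPiLiteral,
              "computed and literal half-pi differ");
```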
| "ElementwiseAtan2TensorIntModule_basic", | ||
| "ElementwiseAtan2TensorIntStaticModule_basic", | ||
| "ElementwiseAtanTensorFloatModule_basic", | ||
| "ElementwiseAtanTensorIntModule_basic", |
Could we add a few deterministic atan e2e cases with special values? A small hand-picked set covering negative values, +0.0, -0.0, and inputs near both split points would be enough.
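As a rough illustration of what such a set could contain, the sketch below enumerates candidate inputs in C++; the values would then be transcribed into the e2e test module. The helper name `makeAtanSpecialInputs` is hypothetical. Including the reciprocals assumes the lowering reduces |x| > 1 to 1/|x| before applying the piecewise fit, which the reduced domain [0, 1] and the kHalfPi constant suggest but this diff does not show.

```cpp
#include <cmath>
#include <vector>

// Split points copied from the patch under review.
constexpr float kAtanPieceSplit0 = 0.546f;
constexpr float kAtanPieceSplit1 = 0.792f;

// Hypothetical helper: builds a deterministic input set covering signed
// zeros, +/-1, and values at and one ulp either side of each split point,
// in both signs. Reciprocals are added on the assumption that inputs with
// |x| > 1 are mapped to 1/|x| by the range reduction.
inline std::vector<float> makeAtanSpecialInputs() {
  std::vector<float> inputs = {0.0f, -0.0f, 1.0f, -1.0f};
  for (float split : {kAtanPieceSplit0, kAtanPieceSplit1}) {
    for (float v : {std::nextafter(split, 0.0f), split,
                    std::nextafter(split, 1.0f)}) {
      inputs.push_back(v);
      inputs.push_back(-v);
      inputs.push_back(1.0f / v);
      inputs.push_back(-1.0f / v);
    }
  }
  return inputs;
}
```

The order is fixed, so a reference result computed once (for example with PyTorch's own atan) could be compared lane by lane, including a sign-bit check on the two zero entries.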
This PR adds TorchToTosa lowering support for aten.atan.

Changes:
- aten.atan in TorchToTosa.cpp
- test/Conversion/TorchToTosa/basic.mlir

This enables conversion of models using aten.atan through the TorchToTosa path.

Testing:
- test/Conversion/TorchToTosa/basic.mlir