diff --git a/Sources/SwiftRetrier/Core/ExponentialBackoffPolicy/ExponentialBackoffRetryPolicy.swift b/Sources/SwiftRetrier/Core/ExponentialBackoffPolicy/ExponentialBackoffRetryPolicy.swift index fe79baf..2196ddf 100644 --- a/Sources/SwiftRetrier/Core/ExponentialBackoffPolicy/ExponentialBackoffRetryPolicy.swift +++ b/Sources/SwiftRetrier/Core/ExponentialBackoffPolicy/ExponentialBackoffRetryPolicy.swift @@ -1,6 +1,6 @@ import Foundation -public struct ExponentialBackoffRetryPolicy: RetryPolicy { +public struct ExponentialBackoffRetryPolicy { public enum Jitter: Sendable { case none @@ -23,23 +23,38 @@ public struct ExponentialBackoffRetryPolicy: RetryPolicy { self.previousDelay = previousDelay } - public func exponentiationBySquaring(_ base: T, _ multiplier: T, _ exponent: T) -> T { - precondition(exponent >= 0) - if exponent == 0 { - return base + private func safeMultiply(_ lhs: UInt, _ rhs: UInt) -> UInt { + if UInt.max / rhs < lhs || UInt.max / lhs < rhs { + UInt.max + } else { + lhs * rhs + } + } + + public func exponentiationBySquaring(base: UInt, multiplier: UInt, exponent: UInt) -> UInt { + if exponent == .zero { + base } else if exponent == 1 { - return base * multiplier + safeMultiply(base, multiplier) } else if exponent.isMultiple(of: 2) { - return exponentiationBySquaring(base, multiplier * multiplier, exponent / 2) + exponentiationBySquaring( + base: base, + multiplier: safeMultiply(multiplier, multiplier), + exponent: exponent / 2 + ) } else { // exponent is odd - return exponentiationBySquaring(base * multiplier, multiplier * multiplier, (exponent - 1) / 2) + exponentiationBySquaring( + base: safeMultiply(base, multiplier), + multiplier: safeMultiply(multiplier, multiplier), + exponent: (exponent - 1) / 2 + ) } } // swiftlint:disable:next line_length // See https://stackoverflow.com/questions/24196689/how-to-get-the-power-of-some-integer-in-swift-language/39021464#39021464 - public func pow(_ base: T, _ power: T) -> T { - return exponentiationBySquaring(1, base, power) + public func pow(_ base: UInt, _ power: UInt) -> UInt { + exponentiationBySquaring(base: 1, multiplier: base, exponent: power) } public func noJitterDelay(attemptIndex: UInt) -> TimeInterval { @@ -72,6 +87,9 @@ public struct ExponentialBackoffRetryPolicy: RetryPolicy { return decorrelatedJitterDelay(attemptIndex: attemptIndex, growthFactor: growthFactor) } } +} + +extension ExponentialBackoffRetryPolicy: RetryPolicy { public func retryDelay(for attemptFailure: AttemptFailure) -> TimeInterval { min(maxDelay, uncappedDelay(attemptIndex: attemptFailure.index)) diff --git a/Tests/SwiftRetrierTests/ExponentialBackoffTest.swift b/Tests/SwiftRetrierTests/ExponentialBackoffTest.swift new file mode 100644 index 0000000..3eb5017 --- /dev/null +++ b/Tests/SwiftRetrierTests/ExponentialBackoffTest.swift @@ -0,0 +1,11 @@ +import XCTest +@testable import SwiftRetrier +@preconcurrency import Combine + +final class ExponentialBackoffTest: XCTestCase { + + func test_When_exponentationGoesUp_Then_noOverflow() { + let policy = ExponentialBackoffRetryPolicy() + _ = policy.retryDelay(for: AttemptFailure(trialStart: Date(), index: .max, error: TestError())) + } +}