diff --git a/Sources/AsyncAlgorithms/FlatMapLatest/FlatMapLatestStateMachine.swift b/Sources/AsyncAlgorithms/FlatMapLatest/FlatMapLatestStateMachine.swift index 56bee75c..0c1aefe0 100644 --- a/Sources/AsyncAlgorithms/FlatMapLatest/FlatMapLatestStateMachine.swift +++ b/Sources/AsyncAlgorithms/FlatMapLatest/FlatMapLatestStateMachine.swift @@ -12,9 +12,10 @@ import DequeModule @available(AsyncAlgorithms 1.1, *) -struct FlatMapLatestStateMachine where Base.Element: Sendable, Inner.Element: Sendable { +struct FlatMapLatestStateMachine +where Base.Element: Sendable, Inner.Element: Sendable { typealias Element = Inner.Element - + private enum State { case initial(Base) case running( @@ -29,15 +30,15 @@ struct FlatMapLatestStateMachine Inner - + init(base: Base, transform: @escaping @Sendable (Base.Element) -> Inner) { self.state = .initial(base) self.transform = transform } - + enum NextAction { case returnElement(Element) case returnNil @@ -45,28 +46,47 @@ struct FlatMapLatestStateMachine?, previousContinuation: UnsafeContinuation?) + case startInnerTask( + Inner, + generation: Int, + previousTask: Task?, + previousContinuation: UnsafeContinuation? + ) case cancelInnerTask(Task, UnsafeContinuation?) case resumeDownstream(UnsafeContinuation, Result) case resumeOuterContinuation(UnsafeContinuation) - case cancelTasks(Task?, Task?, UnsafeContinuation?, UnsafeContinuation?) + case cancelTasks( + Task?, + Task?, + UnsafeContinuation?, + UnsafeContinuation? + ) case none } - + enum SuspendAction { case resumeOuterContinuation(UnsafeContinuation) case resumeInnerContinuation(UnsafeContinuation) case none } - + mutating func next() -> NextAction { switch state { case .initial(let base): return .startOuterTask(base) - - case .running(let outerTask, let outerCont, let innerTask, let innerCont, let downstreamCont, var buffer, let generation, let outerFinished): + + case .running( + let outerTask, + let outerCont, + let innerTask, + let innerCont, + let downstreamCont, + var buffer, + let generation, + let outerFinished + ): if let result = buffer.popFirst() { state = .running( outerTask: outerTask, @@ -88,21 +108,30 @@ struct FlatMapLatestStateMachine) -> SuspendAction { switch state { case .initial: fatalError("Should be started") - - case .running(let outerTask, let outerCont, let innerTask, let innerCont, let downstreamCont, let buffer, let generation, let outerFinished): + + case .running( + let outerTask, + let outerCont, + let innerTask, + let innerCont, + let downstreamCont, + let buffer, + let generation, + let outerFinished + ): precondition(downstreamCont == nil, "Already have downstream continuation") precondition(buffer.isEmpty, "Buffer should be empty if suspending") - + state = .running( outerTask: outerTask, outerContinuation: outerCont, @@ -113,23 +142,44 @@ struct FlatMapLatestStateMachine) { switch state { case .initial: @@ -148,15 +198,24 @@ struct FlatMapLatestStateMachine) -> SuspendAction { switch state { - case .running(let outerTask, _, let innerTask, let innerCont, let downstreamCont, let buffer, let generation, let outerFinished): + case .running( + let outerTask, + _, + let innerTask, + let innerCont, + let downstreamCont, + let buffer, + let generation, + let outerFinished + ): // If we have downstream demand, resume immediately if downstreamCont != nil { return .resumeOuterContinuation(continuation) } - + state = .running( outerTask: outerTask, outerContinuation: continuation, @@ -168,52 +227,70 @@ struct FlatMapLatestStateMachine Action { switch state { - case .running(let outerTask, _, let innerTask, let innerCont, let downstreamCont, let buffer, var generation, let outerFinished): + case .running( + let outerTask, + _, + let innerTask, + let innerCont, + let downstreamCont, + let buffer, + var generation, + let outerFinished + ): // New element from outer -> Cancel previous inner, start new inner let newInner = transform(element) generation += 1 - + state = .running( outerTask: outerTask, - outerContinuation: nil, // We just consumed the continuation by producing - innerTask: nil, // Will be set by innerTaskStarted + outerContinuation: nil, // We just consumed the continuation by producing + innerTask: nil, // Will be set by innerTaskStarted innerContinuation: nil, downstreamContinuation: downstreamCont, buffer: buffer, generation: generation, outerFinished: outerFinished ) - + return .startInnerTask(newInner, generation: generation, previousTask: innerTask, previousContinuation: innerCont) - + case .finished: return .none - + default: fatalError("Invalid state") } } - + mutating func innerTaskStarted(_ task: Task, generation: Int) { switch state { - case .running(let outerTask, let outerCont, _, let innerCont, let downstreamCont, let buffer, let currentGen, let outerFinished): + case .running( + let outerTask, + let outerCont, + _, + let innerCont, + let downstreamCont, + let buffer, + let currentGen, + let outerFinished + ): if generation != currentGen { // Stale task from previous generation, ignore it return } - + state = .running( outerTask: outerTask, outerContinuation: outerCont, @@ -224,26 +301,35 @@ struct FlatMapLatestStateMachine, generation: Int) -> SuspendAction { switch state { - case .running(let outerTask, let outerCont, let innerTask, _, let downstreamCont, let buffer, let currentGen, let outerFinished): + case .running( + let outerTask, + let outerCont, + let innerTask, + _, + let downstreamCont, + let buffer, + let currentGen, + let outerFinished + ): if generation != currentGen { // Stale generation continuation.resume(throwing: CancellationError()) return .none } - + if downstreamCont != nil { return .resumeInnerContinuation(continuation) } - + state = .running( outerTask: outerTask, outerContinuation: outerCont, @@ -255,36 +341,33 @@ struct FlatMapLatestStateMachine Action { switch state { - case .running(let outerTask, let outerCont, let innerTask, _, let downstreamCont, var buffer, let currentGen, let outerFinished): + case .running( + let outerTask, + let outerCont, + let innerTask, + _, + let downstreamCont, + var buffer, + let currentGen, + let outerFinished + ): if generation != currentGen { return .none } - - if let downstreamCont = downstreamCont { - state = .running( - outerTask: outerTask, - outerContinuation: outerCont, - innerTask: innerTask, - innerContinuation: nil, // Consumed - downstreamContinuation: nil, - buffer: buffer, - generation: currentGen, - outerFinished: outerFinished - ) - return .resumeDownstream(downstreamCont, .success(element)) - } else { + + guard let downstreamCont = downstreamCont else { buffer.append(.success(element)) state = .running( outerTask: outerTask, @@ -298,22 +381,33 @@ struct FlatMapLatestStateMachine Action { switch state { case .running(let outerTask, let outerCont, _, _, let downstreamCont, let buffer, let currentGen, let outerFinished): if generation != currentGen { return .none } - + // Inner finished. if outerFinished { // Both finished @@ -335,47 +429,46 @@ struct FlatMapLatestStateMachine Action { switch state { case .running(let outerTask, let outerCont, let innerTask, let innerCont, let downstreamCont, _, let currentGen, _): if generation != currentGen { return .none } - + state = .finished let action: Action = .cancelTasks(outerTask, innerTask, outerCont, innerCont) - - if let downstreamCont = downstreamCont { - return .resumeDownstream(downstreamCont, .failure(error)) - } else { + + guard let downstreamCont = downstreamCont else { return action } - + return .resumeDownstream(downstreamCont, .failure(error)) + default: return .none } } - + mutating func outerFinished() -> Action { switch state { case .running(let outerTask, _, let innerTask, let innerCont, let downstreamCont, let buffer, let generation, _): @@ -397,40 +490,39 @@ struct FlatMapLatestStateMachine Action { switch state { case .running(let outerTask, let outerCont, let innerTask, let innerCont, let downstreamCont, _, _, _): state = .finished let action: Action = .cancelTasks(outerTask, innerTask, outerCont, innerCont) - - if let downstreamCont = downstreamCont { - return .resumeDownstream(downstreamCont, .failure(error)) - } else { + + guard let downstreamCont = downstreamCont else { return action } - + return .resumeDownstream(downstreamCont, .failure(error)) + default: return .none } } - + mutating func cancelled() -> Action { switch state { case .running(let outerTask, let outerCont, let innerTask, let innerCont, let downstreamCont, _, _, _): state = .finished let action: Action = .cancelTasks(outerTask, innerTask, outerCont, innerCont) - + if let downstreamCont = downstreamCont { return .resumeDownstream(downstreamCont, .success(nil)) } return action - + default: state = .finished return .none diff --git a/Tests/AsyncAlgorithmsTests/TestFlatMapLatest.swift b/Tests/AsyncAlgorithmsTests/TestFlatMapLatest.swift index 8a7b9731..2cfab1cc 100644 --- a/Tests/AsyncAlgorithmsTests/TestFlatMapLatest.swift +++ b/Tests/AsyncAlgorithmsTests/TestFlatMapLatest.swift @@ -14,7 +14,7 @@ import AsyncAlgorithms @available(macOS 15.0, *) final class TestFlatMapLatest: XCTestCase { - + func test_simple_sequence() async throws { let source = [1, 2, 3].async let transformed = source.flatMapLatest { intValue in @@ -25,7 +25,7 @@ final class TestFlatMapLatest: XCTestCase { for try await element in transformed { results.append(element) } - + // With synchronous emission, we expect only the last inner sequence [3, 30] // However, depending on timing, we might see more intermediate values XCTAssertTrue(results.contains(3), "Should contain 3") @@ -38,17 +38,17 @@ final class TestFlatMapLatest: XCTestCase { // This test simulates a scenario where the inner sequence is slow. // In a naive implementation (without generation tracking), the inner task for '1' // might wake up and yield AFTER '2' has already started, causing interleaving. - + let source = [1, 2, 3].async let transformed = source.flatMapLatest { intValue -> AsyncStream in AsyncStream { continuation in Task { // Yield the value immediately continuation.yield(intValue) - + // Sleep for a bit to allow the outer sequence to move on - try? await Task.sleep(nanoseconds: 10_000_000) // 10ms - + try? await Task.sleep(nanoseconds: 10_000_000) // 10ms + // Yield a second value - this should be ignored if a new outer value has arrived continuation.yield(intValue * 10) continuation.finish() @@ -66,25 +66,23 @@ final class TestFlatMapLatest: XCTestCase { // However, without strict synchronization, we might see them. // The strict expectation for flatMapLatest is that once a new value arrives, // the old one produces NO MORE values. - - // Note: This test is probabilistic in the naive implementation. - // It might pass or fail depending on scheduling. + + // Note: This test is probabilistic in the naive implementation. + // It might pass or fail depending on scheduling. // But with a correct implementation, it should ALWAYS pass. - - var expected = [3, 30] // We only want the latest - + // We'll collect all results to see what happened var results: [Int] = [] - + for try await element in transformed { results.append(element) } - + // In the naive implementation, we might get [1, 2, 3, 10, 20, 30] or similar. // We want strictly [3, 30] (or [1, 2, 3, 30] depending on how fast the outer sequence is consumed vs produced) // Actually, if the outer sequence is consumed fast, we might see intermediate "first" values (1, 2). // But we should NEVER see "second" values (10, 20) from cancelled sequences. - + // Let's relax the check to: "Must not contain 10 or 20" XCTAssertFalse(results.contains(10), "Should not contain 10 (from cancelled sequence 1)") XCTAssertFalse(results.contains(20), "Should not contain 20 (from cancelled sequence 2)") @@ -99,59 +97,59 @@ final class TestFlatMapLatest: XCTestCase { return } continuation.yield(value) - try? await Task.sleep(nanoseconds: 5_000_000) // 5ms delay + try? await Task.sleep(nanoseconds: 5_000_000) // 5ms delay } continuation.finish() } } - + let transformed = source.flatMapLatest { intValue in return [intValue, intValue * 10].async } - + do { - for try await _ in transformed { } + for try await _ in transformed {} XCTFail("Should have thrown") } catch { XCTAssertEqual(error as? FlatMapLatestFailure, FlatMapLatestFailure()) } } - + func test_inner_throwing() async throws { let source = AsyncStream { continuation in Task { for value in [1, 2, 3] { continuation.yield(value) - try? await Task.sleep(nanoseconds: 5_000_000) // 5ms delay between outer values + try? await Task.sleep(nanoseconds: 5_000_000) // 5ms delay between outer values } continuation.finish() } } - + let transformed = source.flatMapLatest { intValue in return [intValue].async.map { try $0.throwIf(2) } } - + do { - for try await _ in transformed { } + for try await _ in transformed {} XCTFail("Should have thrown") } catch { XCTAssertEqual(error as? FlatMapLatestFailure, FlatMapLatestFailure()) } } - + func test_cancellation() async throws { let source = [1, 2, 3].async let transformed = source.flatMapLatest { intValue in return [intValue].async } - + let task = Task { - for try await _ in transformed { } + for try await _ in transformed {} } - + task.cancel() - + do { try await task.value } catch is CancellationError { @@ -160,38 +158,80 @@ final class TestFlatMapLatest: XCTestCase { XCTFail("Unexpected error: \(error)") } } - + func test_empty_outer() async throws { let source = [].async.map { $0 as Int } let transformed = source.flatMapLatest { intValue in return [intValue].async } - + var count = 0 for try await _ in transformed { count += 1 } XCTAssertEqual(count, 0) } - + func test_empty_inner() async throws { let source = [1, 2, 3].async let transformed = source.flatMapLatest { _ in return [].async.map { $0 as Int } } - + var count = 0 for try await _ in transformed { count += 1 } XCTAssertEqual(count, 0) } + func test_concurrency_crash_repro() async throws { + // This test simulates the exact race condition where an inner stream producer + // mimics the behavior of Firestore/Combine publisher wrappers that might + // yield values or finish asynchronously relative to cancellation. + + @Sendable func innerStream(for user: String) -> AsyncThrowingStream { + return AsyncThrowingStream { continuation in + let task = Task { + for i in 0...100 { + // Artificial delay to allow overlap/interleaving + // This is critical to hit the race condition window + try await Task.sleep(nanoseconds: 1_000_000) // 1ms + + continuation.yield(i) + } + continuation.finish() + } + + continuation.onTermination = { @Sendable termination in + task.cancel() + } + } + } + + let authStream = AsyncStream { continuation in + continuation.yield("A") + Task { + // Yielding again triggers the switch (and cancellation of the previous inner stream) + try? await Task.sleep(nanoseconds: 10_000_000) // 10ms + continuation.yield("B") + try? await Task.sleep(nanoseconds: 10_000_000) + continuation.finish() + } + } + + let combined = authStream.flatMapLatest { user in + innerStream(for: user) + } + + // Determine success by running without crashing + for try await _ in combined {} + } } private struct FlatMapLatestFailure: Error, Equatable {} -private extension Int { - func throwIf(_ value: Int) throws -> Int { +extension Int { + fileprivate func throwIf(_ value: Int) throws -> Int { if self == value { throw FlatMapLatestFailure() }