diff --git a/include/exec/when_any.hpp b/include/exec/when_any.hpp index 5151d1158..94d8e6918 100644 --- a/include/exec/when_any.hpp +++ b/include/exec/when_any.hpp @@ -83,8 +83,25 @@ namespace experimental::execution STDEXEC_IMMOVABLE(__opstate_base); + struct __forward_stop_request + { + constexpr void operator()() const noexcept + { + // Temporarily increment the count to avoid concurrent/recursive arrivals to + // pull the rug under our feet. Relaxed memory order is fine here. + __op_->__count_.fetch_add(1, __std::memory_order_relaxed); + + __op_->__stop_source_.request_stop(); + + // Arrive in order to decrement the count again and complete if needed. + __op_->__arrive(); + } + + __opstate_base* __op_; + }; + using __on_stop = - stop_callback_for_t&>, __forward_stop_request<>>; + stop_callback_for_t&>, __forward_stop_request>; inplace_stop_source __stop_source_{}; std::optional<__on_stop> __on_stop_{}; @@ -127,6 +144,11 @@ namespace experimental::execution // stop pending operations __stop_source_.request_stop(); } + __arrive(); + } + + constexpr void __arrive() noexcept + { // make __result_ emplacement visible when __count_ goes from one to zero // This relies on the fact that each sender will call notify() at most once if (__count_.fetch_sub(1, __std::memory_order_acq_rel) == 1) @@ -203,7 +225,7 @@ namespace experimental::execution void start() & noexcept { this->__on_stop_.emplace(get_stop_token(get_env(this->__rcvr_)), - __forward_stop_request{this->__stop_source_}); + typename __op_base_t::__forward_stop_request{this}); if (this->__stop_source_.stop_requested()) { STDEXEC::set_stopped(static_cast<_Receiver&&>(this->__rcvr_));