diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index b741385..6bc7649 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -63,6 +63,13 @@ if (CORO_HTTP_SERVER_BUILD_TESTS) ) add_test(NAME http_parser_test COMMAND http_parser_test) + add_executable(stack_future_test tests/stack_future_test.cpp) + target_link_libraries(stack_future_test PRIVATE coro_http_server::coro_http_server) + set_target_properties(stack_future_test PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tests + ) + add_test(NAME stack_future_test COMMAND stack_future_test) + add_executable(http_integration_test tests/http_integration_test.cpp) target_link_libraries(http_integration_test PRIVATE coro_http_server::coro_http_server Threads::Threads) set_target_properties(http_integration_test PROPERTIES diff --git a/server/include/http_parser.h b/server/include/http_parser.h index e8d2be0..3ccb11d 100644 --- a/server/include/http_parser.h +++ b/server/include/http_parser.h @@ -1,5 +1,5 @@ #pragma once -#include "co_future.h" +#include "non_owning_co_future.h" #include "read_iterator.h" #include "request_data.h" #include @@ -81,7 +81,7 @@ class HttpRequestParser { public: HttpRequestParser(IOUring &ring, int fd); - CoFuture ReadRequest(RequestData &request); + NonOwningCoFuture ReadRequest(RequestData &request); void MarkContinueSent(); }; diff --git a/server/include/io_uring.h b/server/include/io_uring.h index 3d33d6a..d8db2e5 100644 --- a/server/include/io_uring.h +++ b/server/include/io_uring.h @@ -1,5 +1,5 @@ #pragma once -#include "co_future.h" +#include "owning_co_future.h" #include #include #include @@ -48,12 +48,12 @@ class IOUring { IOUring(); IOUring &operator=(IOUring &&rhs); void Read(int fileDescriptor, std::array &buffer, std::function complete); - CoFuture ReadAsync(int fileDescriptor, std::array &buffer); + OwningCoFuture ReadAsync(int fileDescriptor, std::array &buffer); void Write(int fileDescriptor, std::string_view data, size_t offset, size_t len, std::function complete); - CoFuture WriteAsync(int fileDescriptor, std::string_view data, size_t offset, + OwningCoFuture WriteAsync(int fileDescriptor, std::string_view data, size_t offset, size_t len); void Accept(int fileDescriptor, std::function complete); - CoFuture AcceptAsync(int fileDescriptor); + OwningCoFuture AcceptAsync(int fileDescriptor); }; } diff --git a/server/include/non_owning_co_future.h b/server/include/non_owning_co_future.h new file mode 100644 index 0000000..b34bd49 --- /dev/null +++ b/server/include/non_owning_co_future.h @@ -0,0 +1,316 @@ +#pragma once + +#include "owning_co_future.h" +#include +#include +#include +#include +#include +#include +#include + +namespace HTTP { + +template +class NonOwningCoFuture; + +template +class NonOwningCoFutureCoroutinePromise { + using ControlBlock = detail::StackFutureControl; + + ControlBlock control_; + std::unique_lock producerLock_{control_.mutex}; + bool completed_{false}; + + template + friend class NonOwningCoFuture; + + template + void StoreCompletion(Fn &&store) { + if (completed_) { + throw std::runtime_error("Future already satisfied"); + } + std::forward(store)(); + completed_ = true; + } + + void PublishCompletion() noexcept { + if (!completed_) { + StoreCompletion([this] { + control_.exception = + std::make_exception_ptr(std::runtime_error("Coroutine did not complete")); + }); + } + + control_.ready.store(true, std::memory_order_release); + auto continuation = + control_.continuation.exchange(nullptr, std::memory_order_acq_rel); + producerLock_.unlock(); + + if (continuation) { + (*continuation)(); + } + } + +public: + NonOwningCoFuture get_return_object() { + return NonOwningCoFuture( + std::coroutine_handle::from_promise( + *this)); + } + + std::suspend_never initial_suspend() noexcept { return {}; } + + struct FinalAwaiter { + bool await_ready() noexcept { return false; } + void await_resume() noexcept {} + void await_suspend( + std::coroutine_handle handle) noexcept { + handle.promise().PublishCompletion(); + } + }; + + FinalAwaiter final_suspend() noexcept { return {}; } + + template + requires(!std::is_void_v) + void return_value(U &&value) { + StoreCompletion([this, value = std::forward(value)]() mutable { + control_.value = std::move(value); + }); + } + + void unhandled_exception() { + StoreCompletion([this] { control_.exception = std::current_exception(); }); + } +}; + +template <> +class NonOwningCoFutureCoroutinePromise { + using ControlBlock = detail::StackFutureControl; + + ControlBlock control_; + std::unique_lock producerLock_{control_.mutex}; + bool completed_{false}; + + template + friend class NonOwningCoFuture; + + template + void StoreCompletion(Fn &&store) { + if (completed_) { + throw std::runtime_error("Future already satisfied"); + } + std::forward(store)(); + completed_ = true; + } + + void PublishCompletion() noexcept { + if (!completed_) { + StoreCompletion([this] { + control_.exception = + std::make_exception_ptr(std::runtime_error("Coroutine did not complete")); + }); + } + + control_.ready.store(true, std::memory_order_release); + auto continuation = + control_.continuation.exchange(nullptr, std::memory_order_acq_rel); + producerLock_.unlock(); + + if (continuation) { + (*continuation)(); + } + } + +public: + NonOwningCoFuture get_return_object(); + + std::suspend_never initial_suspend() noexcept { return {}; } + + struct FinalAwaiter { + bool await_ready() noexcept { return false; } + void await_resume() noexcept {} + void await_suspend( + std::coroutine_handle handle) noexcept { + handle.promise().PublishCompletion(); + } + }; + + FinalAwaiter final_suspend() noexcept { return {}; } + + void return_void() { StoreCompletion([this] { control_.value = true; }); } + + void unhandled_exception() { + StoreCompletion([this] { control_.exception = std::current_exception(); }); + } +}; + +template +class NonOwningCoFuture { + using Promise = NonOwningCoFutureCoroutinePromise; + using Handle = std::coroutine_handle; + using ControlBlock = typename Promise::ControlBlock; + + Handle handle_{}; + + explicit NonOwningCoFuture(Handle handle) : handle_(handle) {} + + template + friend class NonOwningCoFutureCoroutinePromise; + + Promise *promise() const { + if (!handle_) { + throw std::logic_error("Invalid future"); + } + return &handle_.promise(); + } + +public: + NonOwningCoFuture() = default; + NonOwningCoFuture(const NonOwningCoFuture &) = delete; + NonOwningCoFuture &operator=(const NonOwningCoFuture &) = delete; + + NonOwningCoFuture(NonOwningCoFuture &&rhs) noexcept + : handle_(std::exchange(rhs.handle_, {})) {} + + NonOwningCoFuture &operator=(NonOwningCoFuture &&rhs) noexcept { + if (this != &rhs) { + if (handle_) { + handle_.destroy(); + } + handle_ = std::exchange(rhs.handle_, {}); + } + return *this; + } + + ~NonOwningCoFuture() { + if (handle_) { + handle_.destroy(); + } + } + + bool valid() const noexcept { return static_cast(handle_); } + + bool isReady() const noexcept { + if (!handle_) { + return false; + } + return handle_.promise().control_.ready.load(std::memory_order_acquire); + } + + T get() { + Promise *p = promise(); + ControlBlock &control = p->control_; + if (!control.ready.load(std::memory_order_acquire)) { + control.mutex.lock(); + control.mutex.unlock(); + } + + if (control.exception) { + std::rethrow_exception(control.exception); + } + if constexpr (!std::is_void_v) { + if (!control.value) { + throw std::runtime_error("Future has no value"); + } + return *control.value; + } + } + + T Get() { return get(); } + + bool await_ready() const noexcept { return isReady(); } + + bool await_suspend(std::coroutine_handle<> awaiting) { + ControlBlock &control = promise()->control_; + auto callback = std::make_shared( + [awaiting]() mutable { + if (awaiting && !awaiting.done()) { + awaiting.resume(); + } + }); + + if (control.ready.load(std::memory_order_acquire)) { + return false; + } + + std::shared_ptr expected = nullptr; + if (!control.continuation.compare_exchange_strong( + expected, callback, std::memory_order_acq_rel, + std::memory_order_acquire)) { + throw std::runtime_error("Tried to set the subscriber second time"); + } + + if (control.ready.load(std::memory_order_acquire)) { + expected = callback; + if (control.continuation.compare_exchange_strong( + expected, nullptr, std::memory_order_acq_rel, + std::memory_order_acquire)) { + return false; + } + } + return true; + } + + T await_resume() { return get(); } + + template >> + OwningCoFuture Then(std::function &&then) { + return OwningCoFuture( + [this, thenFn = std::move(then)](OwningCoFuture &next) mutable { + ControlBlock &control = promise()->control_; + auto run = [this, &next, thenFn = std::move(thenFn)]() mutable { + try { + if constexpr (std::is_void_v) { + thenFn(get()); + next.Set(); + } else { + next.Set(thenFn(get())); + } + } catch (...) { + next.SetException(std::current_exception()); + } + }; + detail::RegisterContinuation(control, std::move(run)); + }); + } + + template >, typename = void> + OwningCoFuture Then(std::function &&then) { + return OwningCoFuture( + [this, thenFn = std::move(then)](OwningCoFuture &next) mutable { + ControlBlock &control = promise()->control_; + auto run = [this, &next, thenFn = std::move(thenFn)]() mutable { + try { + get(); + if constexpr (std::is_void_v) { + thenFn(); + next.Set(); + } else { + next.Set(thenFn()); + } + } catch (...) { + next.SetException(std::current_exception()); + } + }; + detail::RegisterContinuation(control, std::move(run)); + }); + } +}; + +inline NonOwningCoFuture +NonOwningCoFutureCoroutinePromise::get_return_object() { + return NonOwningCoFuture( + std::coroutine_handle::from_promise( + *this)); +} + +} // namespace HTTP + +template +struct std::coroutine_traits, Args...> { + using promise_type = HTTP::NonOwningCoFutureCoroutinePromise; +}; diff --git a/server/include/owning_co_future.h b/server/include/owning_co_future.h new file mode 100644 index 0000000..cb07a35 --- /dev/null +++ b/server/include/owning_co_future.h @@ -0,0 +1,230 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace HTTP { +namespace detail { + +template +struct StackFutureControl { + std::mutex mutex; + std::atomic_bool ready{false}; + std::conditional_t, bool, std::optional> value; + std::exception_ptr exception; + using Callback = std::function; + std::atomic> continuation{nullptr}; +}; + +template +void RegisterContinuation(ControlBlock &control, Fn &&fn) { + auto callback = std::make_shared( + std::forward(fn)); + + if (control.ready.load(std::memory_order_acquire)) { + (*callback)(); + return; + } + + std::shared_ptr expected = nullptr; + if (!control.continuation.compare_exchange_strong( + expected, callback, std::memory_order_acq_rel, + std::memory_order_acquire)) { + throw std::runtime_error("Tried to set the subscriber second time"); + } + + if (control.ready.load(std::memory_order_acquire)) { + expected = callback; + if (control.continuation.compare_exchange_strong( + expected, nullptr, std::memory_order_acq_rel, + std::memory_order_acquire)) { + (*callback)(); + } + } +} + +} // namespace detail + +template +class OwningCoFuture { + using ControlBlock = detail::StackFutureControl; + + ControlBlock control_; + std::unique_lock producerLock_{control_.mutex}; + + template + void Complete(Fn &&store) { + if (!producerLock_.owns_lock()) { + throw std::runtime_error("Future already satisfied"); + } + + std::forward(store)(); + control_.ready.store(true, std::memory_order_release); + auto continuation = + control_.continuation.exchange(nullptr, std::memory_order_acq_rel); + producerLock_.unlock(); + + if (continuation) { + (*continuation)(); + } + } + +public: + OwningCoFuture() = default; + + template , OwningCoFuture>>> + explicit OwningCoFuture(Starter &&starter) { + std::forward(starter)(*this); + } + + OwningCoFuture(const OwningCoFuture &) = delete; + OwningCoFuture &operator=(const OwningCoFuture &) = delete; + OwningCoFuture(OwningCoFuture &&) = delete; + OwningCoFuture &operator=(OwningCoFuture &&) = delete; + + bool valid() const noexcept { return true; } + + bool isReady() const noexcept { + return control_.ready.load(std::memory_order_acquire); + } + + T get() { + if (!control_.ready.load(std::memory_order_acquire)) { + control_.mutex.lock(); + control_.mutex.unlock(); + } + + if (control_.exception) { + std::rethrow_exception(control_.exception); + } + if constexpr (!std::is_void_v) { + if (!control_.value) { + throw std::runtime_error("Future has no value"); + } + return *control_.value; + } + } + + T Get() { return get(); } + + template + std::enable_if_t, void> Set(U value) { + Complete([this, value = std::move(value)]() mutable { + control_.value = std::move(value); + }); + } + + template + std::enable_if_t, void> set(U value) { + Set(std::move(value)); + } + + template + std::enable_if_t, void> Set() { + Complete([this] { control_.value = true; }); + } + + template + std::enable_if_t, void> set() { + Set(); + } + + void SetException(std::exception_ptr exception) { + Complete([this, exception = std::move(exception)]() mutable { + control_.exception = std::move(exception); + }); + } + + void setException(std::exception_ptr exception) { + SetException(std::move(exception)); + } + + bool await_ready() const noexcept { return isReady(); } + + bool await_suspend(std::coroutine_handle<> awaiting) { + auto callback = std::make_shared( + [awaiting]() mutable { + if (awaiting && !awaiting.done()) { + awaiting.resume(); + } + }); + + if (control_.ready.load(std::memory_order_acquire)) { + return false; + } + + std::shared_ptr expected = nullptr; + if (!control_.continuation.compare_exchange_strong( + expected, callback, std::memory_order_acq_rel, + std::memory_order_acquire)) { + throw std::runtime_error("Tried to set the subscriber second time"); + } + + if (control_.ready.load(std::memory_order_acquire)) { + expected = callback; + if (control_.continuation.compare_exchange_strong( + expected, nullptr, std::memory_order_acq_rel, + std::memory_order_acquire)) { + return false; + } + } + return true; + } + + T await_resume() { return get(); } + + template >> + OwningCoFuture Then(std::function &&then) { + return OwningCoFuture( + [this, thenFn = std::move(then)](OwningCoFuture &next) mutable { + auto run = [this, &next, thenFn = std::move(thenFn)]() mutable { + try { + if constexpr (std::is_void_v) { + thenFn(get()); + next.Set(); + } else { + next.Set(thenFn(get())); + } + } catch (...) { + next.SetException(std::current_exception()); + } + }; + detail::RegisterContinuation(control_, std::move(run)); + }); + } + + template >, typename = void> + OwningCoFuture Then(std::function &&then) { + return OwningCoFuture( + [this, thenFn = std::move(then)](OwningCoFuture &next) mutable { + auto run = [this, &next, thenFn = std::move(thenFn)]() mutable { + try { + get(); + if constexpr (std::is_void_v) { + thenFn(); + next.Set(); + } else { + next.Set(thenFn()); + } + } catch (...) { + next.SetException(std::current_exception()); + } + }; + detail::RegisterContinuation(control_, std::move(run)); + }); + } +}; + +} // namespace HTTP diff --git a/server/include/read_iterator.h b/server/include/read_iterator.h index 7194b82..e72c6e5 100644 --- a/server/include/read_iterator.h +++ b/server/include/read_iterator.h @@ -1,6 +1,6 @@ #pragma once -#include "co_future.h" #include "io_uring.h" +#include "non_owning_co_future.h" namespace HTTP { class ReadIterator { @@ -13,7 +13,7 @@ class ReadIterator { public: ReadIterator(IOUring &ring, int fd_); - CoFuture operator++(); + NonOwningCoFuture operator++(); char operator*() const; explicit operator bool() const; bool Eof() const; diff --git a/server/include/server.h b/server/include/server.h index cd081f1..63e9d6d 100644 --- a/server/include/server.h +++ b/server/include/server.h @@ -2,6 +2,7 @@ #include "co_future.h" #include "http_parser.h" #include "io_uring.h" +#include "non_owning_co_future.h" #include "request_data.h" #include "trie.h" #include @@ -21,13 +22,16 @@ class Server { std::shared_ptr> serverLoop_; void WorkerLoop(IOUring &ring); - CoFuture AcceptAndProcess(IOUring &ring); - CoFuture WriteRaw(IOUring &ring, int connectionFD, std::string_view data); - CoFuture WriteResponse(IOUring &ring, int connectionFD, - const ResponseData &data, - const RequestData &request, bool keepAlive, - std::string &buffer); - CoFuture Process(IOUring &ring, int connectionFD); + NonOwningCoFuture + AcceptAndProcess(IOUring &ring, + std::vector> &connections); + NonOwningCoFuture WriteRaw(IOUring &ring, int connectionFD, + std::string_view data); + NonOwningCoFuture WriteResponse(IOUring &ring, int connectionFD, + const ResponseData &data, + const RequestData &request, + bool keepAlive, std::string &buffer); + NonOwningCoFuture Process(IOUring &ring, int connectionFD); friend class ServerBuilder; public: diff --git a/server/src/http_parser.cpp b/server/src/http_parser.cpp index 99d8287..0452534 100644 --- a/server/src/http_parser.cpp +++ b/server/src/http_parser.cpp @@ -1044,7 +1044,7 @@ HttpRequestParser::HttpRequestParser(IOUring &ring, int fd) void HttpRequestParser::MarkContinueSent() { state_.MarkContinueSent(); } -CoFuture HttpRequestParser::ReadRequest( +NonOwningCoFuture HttpRequestParser::ReadRequest( RequestData &request) { while (true) { HttpParseResult parsed = state_.ParseNext(request); diff --git a/server/src/io_uring.cpp b/server/src/io_uring.cpp index 6fc8821..d8418cf 100644 --- a/server/src/io_uring.cpp +++ b/server/src/io_uring.cpp @@ -104,14 +104,14 @@ void IOUring::Read(int fileDescriptor, queue_.push_back(entry); } -CoFuture IOUring::ReadAsync( +OwningCoFuture IOUring::ReadAsync( int fileDescriptor, std::array &buffer) { - auto promise = std::make_shared>(); - auto future = promise->GetFuture(); - Read(fileDescriptor, buffer, [promise](int result) { - promise->Set(result); - }); - return future; + return OwningCoFuture( + [this, fileDescriptor, &buffer](OwningCoFuture &future) { + Read(fileDescriptor, buffer, [&future](int result) { + future.Set(result); + }); + }); } void IOUring::Accept(int fileDescriptor, std::function complete) { @@ -126,23 +126,23 @@ void IOUring::Accept(int fileDescriptor, std::function complete) { AddEntries(); } -CoFuture IOUring::AcceptAsync(int fileDescriptor) { - auto promise = std::make_shared>(); - auto future = promise->GetFuture(); - Accept(fileDescriptor, [promise](int result) { - promise->Set(result); - }); - return future; +OwningCoFuture IOUring::AcceptAsync(int fileDescriptor) { + return OwningCoFuture( + [this, fileDescriptor](OwningCoFuture &future) { + Accept(fileDescriptor, [&future](int result) { + future.Set(result); + }); + }); } -CoFuture IOUring::WriteAsync(int fileDescriptor, std::string_view data, +OwningCoFuture IOUring::WriteAsync(int fileDescriptor, std::string_view data, size_t offset, size_t len) { - auto promise = std::make_shared>(); - auto future = promise->GetFuture(); - Write(fileDescriptor, data, offset, len, [promise](int result) { - promise->Set(result); - }); - return future; + return OwningCoFuture( + [this, fileDescriptor, data, offset, len](OwningCoFuture &future) { + Write(fileDescriptor, data, offset, len, [&future](int result) { + future.Set(result); + }); + }); } void IOUring::ProcessCalls() { diff --git a/server/src/read_iterator.cpp b/server/src/read_iterator.cpp index cc0d0c5..a187ae9 100644 --- a/server/src/read_iterator.cpp +++ b/server/src/read_iterator.cpp @@ -15,7 +15,7 @@ bool IsRetryableIoError(int result) { ReadIterator::ReadIterator(IOUring &ring, int fd) : ring_(ring), fd_(fd) {} -CoFuture ReadIterator::operator++() { +NonOwningCoFuture ReadIterator::operator++() { if (position_ < length_) { ++position_; } diff --git a/server/src/server.cpp b/server/src/server.cpp index d6f2663..d1076fe 100644 --- a/server/src/server.cpp +++ b/server/src/server.cpp @@ -247,7 +247,9 @@ Server::~Server() { } } -CoFuture Server::AcceptAndProcess(IOUring &ring) { +NonOwningCoFuture +Server::AcceptAndProcess(IOUring &ring, + std::vector> &connections) { while (!stopFlag_.load()) { int connectionFD = co_await ring.AcceptAsync(socketFD_); @@ -258,20 +260,38 @@ CoFuture Server::AcceptAndProcess(IOUring &ring) { continue; } - Process(ring, connectionFD); + connections.push_back(Process(ring, connectionFD)); } co_return; } void Server::WorkerLoop(IOUring &ring) { try { - auto acceptCoro = AcceptAndProcess(ring); + std::vector> connections; + connections.reserve(1024); + auto acceptCoro = AcceptAndProcess(ring, connections); + size_t cleanupPolls = 0; while (!stopFlag_.load()) { ring.Poll(); + + if (++cleanupPolls >= 256) { + cleanupPolls = 0; + connections.erase( + std::remove_if(connections.begin(), connections.end(), + [](NonOwningCoFuture &connection) { + if (!connection.isReady()) { + return false; + } + connection.Get(); + return true; + }), + connections.end()); + } if (acceptCoro.isReady()) { - acceptCoro = AcceptAndProcess(ring); + acceptCoro.Get(); + acceptCoro = AcceptAndProcess(ring, connections); } } @@ -282,6 +302,11 @@ void Server::WorkerLoop(IOUring &ring) { if (acceptCoro.isReady()) { acceptCoro.Get(); } + for (NonOwningCoFuture &connection : connections) { + if (connection.isReady()) { + connection.Get(); + } + } } catch (const std::exception &e) { std::cerr << "[WorkerLoop] Exception: " << e.what() << std::endl; } catch (...) { @@ -289,8 +314,8 @@ void Server::WorkerLoop(IOUring &ring) { } } -CoFuture Server::WriteRaw(IOUring &ring, int connectionFD, - std::string_view data) { +NonOwningCoFuture Server::WriteRaw(IOUring &ring, int connectionFD, + std::string_view data) { size_t sent = 0; int attempts = 0; while (sent < data.size()) { @@ -312,10 +337,11 @@ CoFuture Server::WriteRaw(IOUring &ring, int connectionFD, co_return; } -CoFuture Server::WriteResponse(IOUring &ring, int connectionFD, - const ResponseData &data, - const RequestData &request, - bool keepAlive, std::string &buffer) { +NonOwningCoFuture Server::WriteResponse(IOUring &ring, int connectionFD, + const ResponseData &data, + const RequestData &request, + bool keepAlive, + std::string &buffer) { const unsigned short status = NormalizeStatus(data.status); const bool sendBody = ShouldSendBody(status, request.method); buffer.clear(); @@ -360,7 +386,7 @@ CoFuture Server::WriteResponse(IOUring &ring, int connectionFD, co_return; } -CoFuture Server::Process(IOUring &ring, int connectionFD) { +NonOwningCoFuture Server::Process(IOUring &ring, int connectionFD) { HttpRequestParser parser(ring, connectionFD); std::string writeBuffer; writeBuffer.reserve(256); diff --git a/server/tests/stack_future_test.cpp b/server/tests/stack_future_test.cpp new file mode 100644 index 0000000..f9b80be --- /dev/null +++ b/server/tests/stack_future_test.cpp @@ -0,0 +1,209 @@ +#include "non_owning_co_future.h" +#include "owning_co_future.h" +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { +void Check(bool condition, const char *message) { + if (!condition) { + std::cerr << "FAIL: " << message << '\n'; + std::exit(1); + } +} + +HTTP::NonOwningCoFuture +AwaitOwning(std::function()> &makeFuture) { + int value = co_await makeFuture(); + co_return value + 1; +} + +HTTP::NonOwningCoFuture +AwaitOwningVoid(std::function()> &makeFuture) { + co_await makeFuture(); + co_return; +} + +HTTP::NonOwningCoFuture ReturnValue() { co_return 7; } + +HTTP::NonOwningCoFuture ThrowValue() { + throw std::runtime_error("boom"); + co_return 0; +} + +HTTP::NonOwningCoFuture AwaitOwningReference(HTTP::OwningCoFuture &future) { + int value = co_await future; + co_return value + 2; +} + +HTTP::NonOwningCoFuture +AwaitOwningVoidReference(HTTP::OwningCoFuture &future) { + co_await future; + co_return; +} +} // namespace + +int main() { + { + std::function complete; + std::function()> makeFuture = [&] { + return HTTP::OwningCoFuture( + [&](HTTP::OwningCoFuture &future) { + complete = [&future](int value) { future.Set(value); }; + }); + }; + + auto task = AwaitOwning(makeFuture); + Check(!task.isReady(), "owning int await suspends"); + complete(41); + Check(task.Get() == 42, "owning int await resumes with value"); + } + + { + std::function complete; + std::function()> makeFuture = [&] { + return HTTP::OwningCoFuture( + [&](HTTP::OwningCoFuture &future) { + complete = [&future] { future.Set(); }; + }); + }; + + auto task = AwaitOwningVoid(makeFuture); + Check(!task.isReady(), "owning void await suspends"); + complete(); + task.Get(); + Check(task.isReady(), "owning void await resumes"); + } + + { + auto task = ReturnValue(); + Check(task.Get() == 7, "non-owning coroutine returns value"); + } + + { + auto task = ThrowValue(); + bool threw = false; + try { + (void)task.Get(); + } catch (const std::runtime_error &) { + threw = true; + } + Check(threw, "non-owning coroutine propagates exception"); + } + + { + HTTP::OwningCoFuture root; + auto next = root.Then(std::function( + [](int value) { return value + 10; })); + root.Set(5); + Check(next.Get() == 15, "owning Then maps value"); + } + + { + HTTP::OwningCoFuture root; + auto next = root.Then(std::function([] { return 3; })); + root.Set(); + Check(next.Get() == 3, "owning void Then maps value"); + } + + { + std::function complete; + std::function()> makeFuture = [&] { + return HTTP::OwningCoFuture( + [&](HTTP::OwningCoFuture &future) { + complete = [&future](int value) { future.Set(value); }; + }); + }; + + auto task = AwaitOwning(makeFuture); + std::noop_coroutine_handle noop = std::noop_coroutine(); + Check(task.await_suspend(noop), "non-owning future accepts subscriber"); + bool threw = false; + try { + (void)task.await_suspend(noop); + } catch (const std::runtime_error &) { + threw = true; + } + Check(threw, "non-owning future rejects second subscriber"); + } + + { + HTTP::OwningCoFuture future; + std::noop_coroutine_handle noop = std::noop_coroutine(); + Check(future.await_suspend(noop), "owning future accepts subscriber"); + bool threw = false; + try { + (void)future.await_suspend(noop); + } catch (const std::runtime_error &) { + threw = true; + } + Check(threw, "owning future rejects second subscriber"); + } + + { + HTTP::OwningCoFuture future; + std::thread producer([&future] { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + future.Set(19); + }); + Check(future.Get() == 19, "owning int Get blocks across threads"); + producer.join(); + } + + { + HTTP::OwningCoFuture future; + std::thread producer([&future] { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + future.Set(); + }); + future.Get(); + Check(future.isReady(), "owning void Get blocks across threads"); + producer.join(); + } + + { + HTTP::OwningCoFuture future; + std::thread producer([&future] { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + future.SetException(std::make_exception_ptr(std::runtime_error("thread"))); + }); + bool threw = false; + try { + (void)future.Get(); + } catch (const std::runtime_error &) { + threw = true; + } + producer.join(); + Check(threw, "owning exception wakes Get across threads"); + } + + { + HTTP::OwningCoFuture root; + auto task = AwaitOwningReference(root); + std::thread producer([&root] { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + root.Set(40); + }); + Check(task.Get() == 42, "non-owning int Get blocks through awaited future"); + producer.join(); + } + + { + HTTP::OwningCoFuture root; + auto task = AwaitOwningVoidReference(root); + std::thread producer([&root] { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + root.Set(); + }); + task.Get(); + producer.join(); + Check(task.isReady(), "non-owning void Get blocks through awaited future"); + } + + return 0; +}