diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..ca8ae46 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,46 @@ +name: CI + +on: + pull_request: + branches: + - main + - master + push: + branches: + - main + - master + +permissions: + contents: read + +jobs: + build-and-test: + name: build-and-test + runs-on: ubuntu-latest + container: + image: archlinux:latest + options: --security-opt seccomp=unconfined --ulimit memlock=-1:-1 + + steps: + - name: Install dependencies + run: | + pacman -Syu --noconfirm + pacman -S --noconfirm base-devel cmake gcc liburing pkgconf git nodejs + + - name: Check out repository + uses: actions/checkout@v4 + + - name: Configure server tests + run: cmake -S server -B build/server -DCMAKE_BUILD_TYPE=Release + + - name: Build server tests + run: cmake --build build/server -j + + - name: Run tests + run: ctest --test-dir build/server --output-on-failure + + - name: Configure example + run: cmake -S example -B build/example -DCMAKE_BUILD_TYPE=Release -DENABLE_ASAN=OFF + + - name: Build example + run: cmake --build build/example -j diff --git a/README.md b/README.md index f0b219d..24b8d23 100644 --- a/README.md +++ b/README.md @@ -71,11 +71,29 @@ cd ../benchmarks DURATION=30s THREADS=4 CONNECTIONS=100 ./benchmark.sh ``` -Output (2026-04-02 16:42:45 MSK, non-ASAN Release build): - -- **GET /echo** - - **C++**: 427149 req/s - - **Tokio**: 500231 req/s -- **POST /echo** - - **C++**: 387572 req/s - - **Tokio**: 452589 req/s +Output (non-ASAN Release build): + +#### GET /echo + +| Server | Latency avg | Latency stdev | Latency max | +/- stdev | Requests/sec | Transfer/sec | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| C++ coroutine server | 132.16us | 66.78us | 4.97ms | 80.82% | 439792.17 | 29.78MB | +| Rust Tokio server | 150.12us | 119.34us | 5.53ms | 94.01% | 506996.06 | 40.61MB | + +#### POST /echo + +| Server | Latency avg | Latency stdev | Latency max | +/- stdev | Requests/sec | Transfer/sec | +| --- | ---: | ---: | ---: | ---: | ---: | ---: | +| C++ coroutine server | 149.98us | 77.14us | 5.27ms | 81.01% | 402624.34 | 29.57MB | +| Rust Tokio server | 168.98us | 138.11us | 5.80ms | 91.79% | 465844.80 | 39.98MB | + +## Pull Request Checks + +Pull requests targeting `main` or `master` run the `CI / build-and-test` +GitHub Actions check. Configure both branches in GitHub branch protection or a +repository ruleset with: + +- Require a pull request before merging +- Require status checks to pass before merging +- Required status check: `build-and-test` (shown in Actions as `CI / build-and-test`) +- Restrict direct pushes to matching branches diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index dd5a869..b741385 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -19,6 +19,7 @@ pkg_check_modules(LIBURING REQUIRED liburing) option(ENABLE_ASAN "Enable AddressSanitizer" OFF) add_library(coro_http_server src/io_uring.cpp + src/http_parser.cpp src/read_iterator.cpp src/http_error.cpp src/trie.cpp @@ -42,3 +43,31 @@ include(GNUInstallDirs) install(TARGETS coro_http_server EXPORT MyServerConfig DESTINATION ${CMAKE_INSTALL_LIBDIR}) install(EXPORT MyServerConfig FILE coro_http_serverConfig.cmake NAMESPACE coro_http_server:: DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/coro_http_server) install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + +if (CMAKE_SOURCE_DIR STREQUAL PROJECT_SOURCE_DIR) + set(CORO_HTTP_SERVER_BUILD_TESTS_DEFAULT ON) +else() + set(CORO_HTTP_SERVER_BUILD_TESTS_DEFAULT OFF) +endif() +option(CORO_HTTP_SERVER_BUILD_TESTS "Build coro_http_server tests" ${CORO_HTTP_SERVER_BUILD_TESTS_DEFAULT}) + +if (CORO_HTTP_SERVER_BUILD_TESTS) + include(CTest) + if (BUILD_TESTING) + find_package(Threads REQUIRED) + + add_executable(http_parser_test tests/http_parser_test.cpp) + target_link_libraries(http_parser_test PRIVATE coro_http_server::coro_http_server) + set_target_properties(http_parser_test PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tests + ) + add_test(NAME http_parser_test COMMAND http_parser_test) + + add_executable(http_integration_test tests/http_integration_test.cpp) + target_link_libraries(http_integration_test PRIVATE coro_http_server::coro_http_server Threads::Threads) + set_target_properties(http_integration_test PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tests + ) + add_test(NAME http_integration_test COMMAND http_integration_test) + endif() +endif() diff --git a/server/include/http_error.h b/server/include/http_error.h index faaae94..3799b44 100644 --- a/server/include/http_error.h +++ b/server/include/http_error.h @@ -1,11 +1,11 @@ #pragma once #include -#include +#include namespace HTTP { class HTTPError : public std::exception { public: - std::string_view message; + std::string message; const int status; - HTTPError(int status, std::string_view message); + HTTPError(int status, std::string message); }; } // namespace HTTP diff --git a/server/include/http_parser.h b/server/include/http_parser.h new file mode 100644 index 0000000..e8d2be0 --- /dev/null +++ b/server/include/http_parser.h @@ -0,0 +1,88 @@ +#pragma once +#include "co_future.h" +#include "read_iterator.h" +#include "request_data.h" +#include +#include +#include + +namespace HTTP { + +enum class HttpParseStatus { NeedMoreData, NeedContinue, Complete, Error }; +enum class RequestReadStatus { Complete, NeedContinue, Closed }; + +struct HttpParseResult { + HttpParseStatus status{HttpParseStatus::NeedMoreData}; + int errorStatus{0}; + std::string errorMessage; +}; + +class HttpParserState { + enum class State { + StartLine, + HeaderLine, + FixedBody, + ChunkSize, + ChunkData, + ChunkDataCrlf, + TrailerLine, + Complete, + Error + }; + + State state_{State::StartLine}; + RequestData current_; + std::string line_; + std::string pending_; + bool sawCr_{false}; + bool continueSent_{false}; + bool waitingForContinue_{false}; + bool continueCandidate_{false}; + bool chunkNeedsLf_{false}; + bool started_{false}; + bool chunked_{false}; + bool expectsContinue_{false}; + bool hostEmpty_{false}; + bool contentLengthInvalid_{false}; + bool contentLengthConflict_{false}; + bool unsupportedTransferEncoding_{false}; + bool invalidExpect_{false}; + bool hasContentLength_{false}; + size_t hostCount_{0}; + size_t headerBytes_{0}; + size_t fixedRemaining_{0}; + size_t chunkRemaining_{0}; + size_t contentLength_{0}; + int errorStatus_{0}; + std::string errorMessage_; + + void ResetForNext(); + size_t ProcessBytes(std::string_view data); + void SetError(int status, std::string message); + void CompleteCurrent(); + void ProcessLine(); + void AppendLineData(std::string_view data); + void ProcessLineByte(char ch); + void FinishHeaders(); + void TrackHeader(std::string_view name, std::string_view value); + +public: + HttpParserState(); + void Append(std::string_view data); + size_t Consume(std::string_view data); + bool Empty() const; + void MarkContinueSent(); + HttpParseResult ParseNext(RequestData &request); +}; + +class HttpRequestParser { + ReadIterator iterator_; + HttpParserState state_; + +public: + HttpRequestParser(IOUring &ring, int fd); + CoFuture ReadRequest(RequestData &request); + void MarkContinueSent(); +}; + +} // namespace HTTP diff --git a/server/include/io_uring.h b/server/include/io_uring.h index 703e446..3d33d6a 100644 --- a/server/include/io_uring.h +++ b/server/include/io_uring.h @@ -2,6 +2,7 @@ #include "co_future.h" #include #include +#include #include #include #include @@ -12,6 +13,8 @@ #include #define QUEUE_DEPTH 1024 namespace HTTP { +inline constexpr size_t kReadBufferSize = 256; + class IOUring; class IOUring { @@ -44,8 +47,8 @@ class IOUring { ~IOUring(); IOUring(); IOUring &operator=(IOUring &&rhs); - void Read(int fileDescriptor, std::array &buffer, std::function complete); - CoFuture ReadAsync(int fileDescriptor, std::array &buffer); + void Read(int fileDescriptor, std::array &buffer, std::function complete); + CoFuture ReadAsync(int fileDescriptor, std::array &buffer); void Write(int fileDescriptor, std::string_view data, size_t offset, size_t len, std::function complete); CoFuture WriteAsync(int fileDescriptor, std::string_view data, size_t offset, diff --git a/server/include/read_iterator.h b/server/include/read_iterator.h index 02397e3..7194b82 100644 --- a/server/include/read_iterator.h +++ b/server/include/read_iterator.h @@ -1,27 +1,24 @@ #pragma once #include "co_future.h" #include "io_uring.h" -#include "request_data.h" namespace HTTP { + class ReadIterator { IOUring &ring_; - std::array buffer_; + std::array buffer_; size_t length_{0}; size_t position_{0}; int fd_; + bool eof_{false}; public: ReadIterator(IOUring &ring, int fd_); - CoFuture Ensure(); + CoFuture operator++(); + char operator*() const; + explicit operator bool() const; + bool Eof() const; size_t Available() const; const char *CurrentPtr() const; void Advance(size_t n); - CoFuture operator++(); - char operator*(); - operator bool(); - CoFuture ParseVariables(RequestData &data); - CoFuture ParseHeaders(RequestData &data); - CoFuture ParseMethod(RequestData &data); - CoFuture ParseBody(RequestData &data); -}; }; +} // namespace HTTP diff --git a/server/include/request_data.h b/server/include/request_data.h index 55764c1..c797ff1 100644 --- a/server/include/request_data.h +++ b/server/include/request_data.h @@ -1,19 +1,37 @@ #pragma once +#include +#include #include +#include #include +#include #include namespace HTTP { -enum Method { GET, PUT, POST, PATCH, DELETE }; +enum Method { GET, PUT, POST, PATCH, DELETE, HEAD, OPTIONS, UNKNOWN }; + +constexpr size_t kRoutableMethodCount = 7; + struct RequestData { std::unordered_map headers; std::unordered_map params; + std::unordered_map trailers; + std::vector> rawHeaders; std::vector urlVariables; - Method method; + Method method{UNKNOWN}; + std::string methodToken; + std::string target; + std::string path; + std::string query; + std::string version; std::string body; + bool connectionClose{false}; + bool connectionKeepAlive{false}; + + std::optional Header(std::string_view name) const; }; struct ResponseData { std::unordered_map headers; std::string body; - unsigned short status; + unsigned short status{200}; }; } // namespace HTTP diff --git a/server/include/server.h b/server/include/server.h index 2674d1c..cd081f1 100644 --- a/server/include/server.h +++ b/server/include/server.h @@ -1,11 +1,13 @@ #pragma once #include "co_future.h" +#include "http_parser.h" #include "io_uring.h" -#include "read_iterator.h" #include "request_data.h" #include "trie.h" #include #include +#include +#include #include namespace HTTP { class Server { @@ -20,9 +22,11 @@ class Server { void WorkerLoop(IOUring &ring); CoFuture AcceptAndProcess(IOUring &ring); - CoFuture GetHandler(RequestData &data, ReadIterator &iter, RespondType &handler); - CoFuture WriteResponse(IOUring &ring, int connectionFD, const ResponseData &data, - bool keepAlive); + CoFuture WriteRaw(IOUring &ring, int connectionFD, std::string_view data); + CoFuture WriteResponse(IOUring &ring, int connectionFD, + const ResponseData &data, + const RequestData &request, bool keepAlive, + std::string &buffer); CoFuture Process(IOUring &ring, int connectionFD); friend class ServerBuilder; diff --git a/server/include/trie.h b/server/include/trie.h index 503c8e7..82998c8 100644 --- a/server/include/trie.h +++ b/server/include/trie.h @@ -23,17 +23,28 @@ class Trie { StringEqual> children; std::unique_ptr wildcard; - std::optional handlers[5]; + std::optional handlers[kRoutableMethodCount]; Node() = default; Node &Move(std::string_view segment); }; std::unique_ptr root_ = std::make_unique(); public: + struct RouteResult { + bool pathFound{false}; + bool methodAllowed{false}; + bool automaticOptions{false}; + RespondType handler; + std::string allow; + }; + Trie() = default; Trie(Trie &&rhs); Trie &operator=(Trie &&rhs); void AddRequest(Method type, RespondType function, std::string_view path); + std::string AllAllowedMethods() const; + RouteResult Resolve(Method method, std::string_view path, + std::vector &urlVariables) const; RespondType Match(Method method, std::string_view path, std::vector &urlVariables) const; }; diff --git a/server/src/http_error.cpp b/server/src/http_error.cpp index 19d1811..0b0803a 100644 --- a/server/src/http_error.cpp +++ b/server/src/http_error.cpp @@ -1,5 +1,5 @@ #include "http_error.h" -#include +#include using namespace HTTP; -HTTPError::HTTPError(int status, std::string_view message) - : status(status), message(message) {} +HTTPError::HTTPError(int status, std::string message) + : message(std::move(message)), status(status) {} diff --git a/server/src/http_parser.cpp b/server/src/http_parser.cpp new file mode 100644 index 0000000..99d8287 --- /dev/null +++ b/server/src/http_parser.cpp @@ -0,0 +1,1079 @@ +#include "http_parser.h" +#include "http_error.h" +#include +#include +#include +#include +#include + +namespace HTTP { + +namespace { +constexpr size_t kMaxStartLineSize = 8192; +constexpr size_t kMaxHeaderSectionSize = 64 * 1024; +constexpr size_t kMaxBodySize = 10 * 1024 * 1024; + +HttpParseResult NeedMore() { + return {HttpParseStatus::NeedMoreData, 0, ""}; +} + +HttpParseResult NeedContinue() { + return {HttpParseStatus::NeedContinue, 0, ""}; +} + +HttpParseResult Complete() { + return {HttpParseStatus::Complete, 0, ""}; +} + +HttpParseResult Error(int status, std::string message) { + return {HttpParseStatus::Error, status, std::move(message)}; +} + +constexpr unsigned char ascii_lower(unsigned char ch) { + return ch >= 'A' && ch <= 'Z' ? static_cast(ch + ('a' - 'A')) + : ch; +} + +bool iequals(std::string_view lhs, std::string_view rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (size_t i = 0; i < lhs.size(); ++i) { + if (ascii_lower(static_cast(lhs[i])) != + ascii_lower(static_cast(rhs[i]))) { + return false; + } + } + return true; +} + +bool starts_with_ci(std::string_view value, std::string_view prefix) { + return value.size() >= prefix.size() && + iequals(value.substr(0, prefix.size()), prefix); +} + +std::string_view trim_ows_view(std::string_view value) { + size_t start = 0; + while (start < value.size() && + (value[start] == ' ' || value[start] == '\t')) { + ++start; + } + size_t end = value.size(); + while (end > start && (value[end - 1] == ' ' || value[end - 1] == '\t')) { + --end; + } + return value.substr(start, end - start); +} + +std::string trim_ows(std::string_view value) { + return std::string(trim_ows_view(value)); +} + +bool ascii_is_alpha(unsigned char ch) { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z'); +} + +bool ascii_is_digit(unsigned char ch) { return ch >= '0' && ch <= '9'; } + +bool ascii_is_alnum(unsigned char ch) { + return ascii_is_alpha(ch) || ascii_is_digit(ch); +} + +bool ascii_is_xdigit(unsigned char ch) { + return ascii_is_digit(ch) || (ch >= 'a' && ch <= 'f') || + (ch >= 'A' && ch <= 'F'); +} + +bool is_tchar(unsigned char ch) { + if (ascii_is_alnum(ch)) { + return true; + } + switch (ch) { + case '!': + case '#': + case '$': + case '%': + case '&': + case '\'': + case '*': + case '+': + case '-': + case '.': + case '^': + case '_': + case '`': + case '{': + case '|': + case '}': + case '~': + return true; + default: + return false; + } +} + +bool is_token(std::string_view value) { + if (value.empty()) { + return false; + } + return std::all_of(value.begin(), value.end(), [](char ch) { + return is_tchar(static_cast(ch)); + }); +} + +bool valid_field_value(std::string_view value) { + for (unsigned char ch : value) { + if (ch == '\t') { + continue; + } + if (ch < 0x20 || ch == 0x7f) { + return false; + } + } + return true; +} + +Method method_from_token(std::string_view token) { + if (token == "GET") { + return GET; + } + if (token == "PUT") { + return PUT; + } + if (token == "POST") { + return POST; + } + if (token == "PATCH") { + return PATCH; + } + if (token == "DELETE") { + return DELETE; + } + if (token == "HEAD") { + return HEAD; + } + if (token == "OPTIONS") { + return OPTIONS; + } + return UNKNOWN; +} + +bool parse_http_version(std::string_view version, int &major, int &minor) { + if (version.size() != 8 || version.substr(0, 5) != "HTTP/" || + !ascii_is_digit(static_cast(version[5])) || + version[6] != '.' || + !ascii_is_digit(static_cast(version[7]))) { + return false; + } + major = version[5] - '0'; + minor = version[7] - '0'; + return true; +} + +std::vector header_values( + const std::vector> &headers, + std::string_view name) { + std::vector result; + for (const auto &[fieldName, value] : headers) { + if (iequals(fieldName, name)) { + result.push_back(value); + } + } + return result; +} + +std::vector split_comma_values( + const std::vector &values) { + std::vector result; + for (const std::string &value : values) { + size_t pos = 0; + while (pos <= value.size()) { + const size_t comma = value.find(',', pos); + const size_t end = comma == std::string::npos ? value.size() : comma; + result.push_back(trim_ows(std::string_view(value).substr(pos, end - pos))); + if (comma == std::string::npos) { + break; + } + pos = comma + 1; + } + } + return result; +} + +std::optional parse_decimal_size(std::string_view value) { + if (value.empty() || + !std::all_of(value.begin(), value.end(), [](char ch) { + return ascii_is_digit(static_cast(ch)); + })) { + return std::nullopt; + } + size_t parsed = 0; + const auto *first = value.data(); + const auto *last = value.data() + value.size(); + auto [ptr, ec] = std::from_chars(first, last, parsed, 10); + if (ec != std::errc() || ptr != last) { + return std::nullopt; + } + return parsed; +} + +std::optional parse_hex_size(std::string_view value) { + if (value.empty() || + !std::all_of(value.begin(), value.end(), [](char ch) { + return ascii_is_xdigit(static_cast(ch)); + })) { + return std::nullopt; + } + size_t parsed = 0; + const auto *first = value.data(); + const auto *last = value.data() + value.size(); + auto [ptr, ec] = std::from_chars(first, last, parsed, 16); + if (ec != std::errc() || ptr != last) { + return std::nullopt; + } + return parsed; +} + +HttpParseResult parse_field_line(std::string_view line, std::string_view &name, + std::string_view &value) { + if (line.empty()) { + return Error(400, "Empty header field"); + } + if (line.front() == ' ' || line.front() == '\t') { + return Error(400, "Obsolete folded headers are not supported"); + } + const size_t colon = line.find(':'); + if (colon == std::string_view::npos) { + return Error(400, "Invalid header field"); + } + name = line.substr(0, colon); + if (!is_token(name)) { + return Error(400, "Invalid header name"); + } + value = trim_ows_view(line.substr(colon + 1)); + if (!valid_field_value(value)) { + return Error(400, "Invalid header value"); + } + return Complete(); +} + +void add_header(RequestData &request, std::string_view name, + std::string_view value) { + std::string storedName(name); + std::string storedValue(value); + request.rawHeaders.emplace_back(storedName, storedValue); + auto existing = request.headers.find(storedName); + if (existing == request.headers.end()) { + request.headers.emplace(std::move(storedName), std::move(storedValue)); + } else { + existing->second += ", "; + existing->second += storedValue; + } +} + +void add_trailer(RequestData &request, std::string_view name, + std::string_view value) { + std::string storedName(name); + std::string storedValue(value); + auto existing = request.trailers.find(storedName); + if (existing == request.trailers.end()) { + request.trailers.emplace(std::move(storedName), std::move(storedValue)); + } else { + existing->second += ", "; + existing->second += storedValue; + } +} + +void parse_query_params(RequestData &request) { + request.params.reserve(4); + size_t pos = 0; + while (pos <= request.query.size()) { + const size_t amp = request.query.find('&', pos); + const size_t end = amp == std::string::npos ? request.query.size() : amp; + const std::string_view part = + std::string_view(request.query).substr(pos, end - pos); + if (!part.empty()) { + const size_t eq = part.find('='); + const std::string_view name = + eq == std::string_view::npos ? part : part.substr(0, eq); + const std::string_view value = + eq == std::string_view::npos ? std::string_view() + : part.substr(eq + 1); + if (!name.empty()) { + auto [it, inserted] = + request.params.emplace(std::string(name), std::string(value)); + if (!inserted) { + it->second.assign(value); + } + } + } + if (amp == std::string::npos) { + break; + } + pos = amp + 1; + } +} + +HttpParseResult parse_target(RequestData &request) { + if (request.target.empty()) { + return Error(400, "Empty request target"); + } + + if (request.target == "*") { + request.path = "*"; + return Complete(); + } + + if (starts_with_ci(request.target, "http://") || + starts_with_ci(request.target, "https://")) { + const size_t authorityStart = request.target.find("://") + 3; + const size_t targetStart = request.target.find_first_of("/?", authorityStart); + if (targetStart == authorityStart) { + return Error(400, "Invalid absolute request target"); + } + if (targetStart == std::string::npos) { + request.path = "/"; + return Complete(); + } + if (request.target[targetStart] == '?') { + request.path = "/"; + request.query = request.target.substr(targetStart + 1); + parse_query_params(request); + return Complete(); + } + const size_t queryStart = request.target.find('?', targetStart); + request.path = queryStart == std::string::npos + ? request.target.substr(targetStart) + : request.target.substr(targetStart, + queryStart - targetStart); + if (queryStart != std::string::npos) { + request.query = request.target.substr(queryStart + 1); + parse_query_params(request); + } + return Complete(); + } + + if (request.target.front() != '/') { + return Error(400, "Unsupported request target form"); + } + + const size_t queryStart = request.target.find('?'); + request.path = queryStart == std::string::npos + ? request.target + : request.target.substr(0, queryStart); + if (queryStart != std::string::npos) { + request.query = request.target.substr(queryStart + 1); + parse_query_params(request); + } + return Complete(); +} + +HttpParseResult parse_request_line(std::string_view line, + RequestData &request) { + const size_t firstSpace = line.find(' '); + if (firstSpace == std::string_view::npos || firstSpace == 0) { + return Error(400, "Invalid request line"); + } + const size_t secondSpace = line.find(' ', firstSpace + 1); + if (secondSpace == std::string_view::npos || secondSpace == firstSpace + 1 || + line.find(' ', secondSpace + 1) != std::string_view::npos) { + return Error(400, "Invalid request line"); + } + + request.methodToken = std::string(line.substr(0, firstSpace)); + request.target = + std::string(line.substr(firstSpace + 1, secondSpace - firstSpace - 1)); + request.version = std::string(line.substr(secondSpace + 1)); + + if (!is_token(request.methodToken)) { + return Error(400, "Invalid method token"); + } + request.method = method_from_token(request.methodToken); + + int major = 0; + int minor = 0; + if (!parse_http_version(request.version, major, minor)) { + return Error(400, "Invalid HTTP version"); + } + if (major != 1 || (minor != 0 && minor != 1)) { + return Error(505, "HTTP Version Not Supported"); + } + + return parse_target(request); +} + +HttpParseResult parse_content_length( + const std::vector> &headers, + std::optional &contentLength) { + const auto values = split_comma_values(header_values(headers, "Content-Length")); + if (values.empty()) { + return Complete(); + } + + std::optional expected; + for (const std::string &value : values) { + const auto parsed = parse_decimal_size(value); + if (!parsed) { + return Error(400, "Invalid Content-Length"); + } + if (!expected) { + expected = *parsed; + } else if (*expected != *parsed) { + return Error(400, "Conflicting Content-Length values"); + } + } + contentLength = expected; + if (contentLength && *contentLength > kMaxBodySize) { + return Error(413, "Content Too Large"); + } + return Complete(); +} + +HttpParseResult parse_transfer_encoding( + const std::vector> &headers, + bool &chunked) { + const auto values = + split_comma_values(header_values(headers, "Transfer-Encoding")); + if (values.empty()) { + return Complete(); + } + if (values.back().empty() || !iequals(values.back(), "chunked")) { + return Error(501, "Unsupported transfer encoding"); + } + for (const std::string &value : values) { + if (value.empty() || !iequals(value, "chunked")) { + return Error(501, "Unsupported transfer encoding"); + } + } + chunked = true; + return Complete(); +} + +HttpParseResult validate_host(const RequestData &request) { + const auto hosts = header_values(request.rawHeaders, "Host"); + if (request.version == "HTTP/1.1" && hosts.size() != 1) { + return Error(400, "HTTP/1.1 requires exactly one Host header"); + } + if (hosts.size() > 1) { + return Error(400, "Multiple Host headers are not allowed"); + } + if (!hosts.empty() && hosts.front().empty()) { + return Error(400, "Host header must not be empty"); + } + return Complete(); +} + +HttpParseResult validate_expect(const RequestData &request, + bool &expectsContinue) { + auto value = request.Header("Expect"); + if (!value) { + return Complete(); + } + const std::string trimmed = trim_ows(*value); + if (!iequals(trimmed, "100-continue")) { + return Error(417, "Expectation Failed"); + } + expectsContinue = true; + return Complete(); +} +} // namespace + +std::optional RequestData::Header(std::string_view name) const { + std::string result; + for (const auto &[fieldName, value] : rawHeaders) { + if (!iequals(fieldName, name)) { + continue; + } + if (!result.empty()) { + result += ", "; + } + result += value; + } + if (!result.empty()) { + return result; + } + + for (const auto &[fieldName, value] : headers) { + if (iequals(fieldName, name)) { + return value; + } + } + return std::nullopt; +} + +HttpParserState::HttpParserState() { + line_.reserve(256); + pending_.reserve(256); + current_.headers.reserve(8); + current_.rawHeaders.reserve(8); +} + +void HttpParserState::ResetForNext() { + state_ = State::StartLine; + current_ = RequestData{}; + current_.headers.reserve(8); + current_.rawHeaders.reserve(8); + line_.clear(); + sawCr_ = false; + continueSent_ = false; + waitingForContinue_ = false; + continueCandidate_ = false; + chunkNeedsLf_ = false; + started_ = false; + chunked_ = false; + expectsContinue_ = false; + hostEmpty_ = false; + contentLengthInvalid_ = false; + contentLengthConflict_ = false; + unsupportedTransferEncoding_ = false; + invalidExpect_ = false; + hasContentLength_ = false; + hostCount_ = 0; + headerBytes_ = 0; + fixedRemaining_ = 0; + chunkRemaining_ = 0; + contentLength_ = 0; + errorStatus_ = 0; + errorMessage_.clear(); +} + +void HttpParserState::SetError(int status, std::string message) { + state_ = State::Error; + errorStatus_ = status; + errorMessage_ = std::move(message); +} + +void HttpParserState::CompleteCurrent() { + state_ = State::Complete; +} + +void HttpParserState::TrackHeader(std::string_view name, + std::string_view value) { + if (iequals(name, "Host")) { + ++hostCount_; + hostEmpty_ = hostEmpty_ || value.empty(); + return; + } + + if (iequals(name, "Content-Length")) { + size_t pos = 0; + while (pos <= value.size()) { + const size_t comma = value.find(',', pos); + const size_t end = comma == std::string_view::npos ? value.size() : comma; + const std::string_view token = trim_ows_view(value.substr(pos, end - pos)); + const auto parsed = parse_decimal_size(token); + if (!parsed) { + contentLengthInvalid_ = true; + } else if (!hasContentLength_) { + hasContentLength_ = true; + contentLength_ = *parsed; + } else if (contentLength_ != *parsed) { + contentLengthConflict_ = true; + } + if (comma == std::string_view::npos) { + break; + } + pos = comma + 1; + } + return; + } + + if (iequals(name, "Transfer-Encoding")) { + size_t pos = 0; + bool sawToken = false; + bool lastWasChunked = false; + while (pos <= value.size()) { + const size_t comma = value.find(',', pos); + const size_t end = comma == std::string_view::npos ? value.size() : comma; + const std::string_view token = trim_ows_view(value.substr(pos, end - pos)); + if (token.empty()) { + unsupportedTransferEncoding_ = true; + } else { + sawToken = true; + lastWasChunked = iequals(token, "chunked"); + if (!lastWasChunked) { + unsupportedTransferEncoding_ = true; + } + } + if (comma == std::string_view::npos) { + break; + } + pos = comma + 1; + } + if (!sawToken || !lastWasChunked) { + unsupportedTransferEncoding_ = true; + } + chunked_ = chunked_ || lastWasChunked; + return; + } + + if (iequals(name, "Expect")) { + const std::string_view token = trim_ows_view(value); + if (!iequals(token, "100-continue")) { + invalidExpect_ = true; + } else { + expectsContinue_ = true; + } + return; + } + + if (iequals(name, "Connection")) { + size_t pos = 0; + while (pos <= value.size()) { + const size_t comma = value.find(',', pos); + const size_t end = comma == std::string_view::npos ? value.size() : comma; + const std::string_view token = trim_ows_view(value.substr(pos, end - pos)); + if (iequals(token, "close")) { + current_.connectionClose = true; + } else if (iequals(token, "keep-alive")) { + current_.connectionKeepAlive = true; + } + if (comma == std::string_view::npos) { + break; + } + pos = comma + 1; + } + } +} + +void HttpParserState::ProcessLine() { + if (state_ == State::StartLine) { + if (line_.empty() && !started_) { + line_.clear(); + return; + } + started_ = true; + HttpParseResult result = parse_request_line(line_, current_); + line_.clear(); + if (result.status != HttpParseStatus::Complete) { + SetError(result.errorStatus, std::move(result.errorMessage)); + return; + } + state_ = State::HeaderLine; + return; + } + + if (state_ == State::HeaderLine) { + headerBytes_ += line_.size() + 2; + if (headerBytes_ > kMaxHeaderSectionSize) { + line_.clear(); + SetError(431, "Header section too large"); + return; + } + if (line_.empty()) { + line_.clear(); + FinishHeaders(); + return; + } + std::string_view name; + std::string_view value; + HttpParseResult result = parse_field_line(line_, name, value); + if (result.status != HttpParseStatus::Complete) { + line_.clear(); + SetError(result.errorStatus, std::move(result.errorMessage)); + return; + } + TrackHeader(name, value); + add_header(current_, name, value); + line_.clear(); + return; + } + + if (state_ == State::ChunkSize) { + const size_t extension = line_.find(';'); + std::string sizeToken = + trim_ows(extension == std::string::npos + ? std::string_view(line_) + : std::string_view(line_).substr(0, extension)); + line_.clear(); + const auto chunkSize = parse_hex_size(sizeToken); + if (!chunkSize) { + SetError(400, "Invalid chunk size"); + return; + } + if (*chunkSize > kMaxBodySize || + current_.body.size() > kMaxBodySize - *chunkSize) { + SetError(413, "Content Too Large"); + return; + } + chunkRemaining_ = *chunkSize; + state_ = chunkRemaining_ == 0 ? State::TrailerLine : State::ChunkData; + return; + } + + if (state_ == State::TrailerLine) { + headerBytes_ += line_.size() + 2; + if (headerBytes_ > kMaxHeaderSectionSize) { + line_.clear(); + SetError(431, "Trailer section too large"); + return; + } + if (line_.empty()) { + line_.clear(); + CompleteCurrent(); + return; + } + std::string_view name; + std::string_view value; + HttpParseResult result = parse_field_line(line_, name, value); + if (result.status != HttpParseStatus::Complete) { + line_.clear(); + SetError(result.errorStatus, std::move(result.errorMessage)); + return; + } + add_trailer(current_, name, value); + line_.clear(); + } +} + +void HttpParserState::AppendLineData(std::string_view data) { + if (data.empty()) { + return; + } + + line_.append(data.data(), data.size()); + if (state_ == State::StartLine && line_.size() > kMaxStartLineSize) { + SetError(414, "URI Too Long"); + } else if ((state_ == State::HeaderLine || state_ == State::TrailerLine) && + headerBytes_ + line_.size() > kMaxHeaderSectionSize) { + SetError(431, state_ == State::HeaderLine ? "Header section too large" + : "Trailer section too large"); + } +} + +void HttpParserState::ProcessLineByte(char ch) { + if (sawCr_) { + sawCr_ = false; + if (ch != '\n') { + SetError(400, "Invalid line ending"); + return; + } + ProcessLine(); + return; + } + + if (ch == '\n') { + SetError(400, "Invalid line ending"); + return; + } + if (ch == '\r') { + sawCr_ = true; + return; + } + + line_.push_back(ch); + if (state_ == State::StartLine && line_.size() > kMaxStartLineSize) { + SetError(414, "URI Too Long"); + } else if ((state_ == State::HeaderLine || state_ == State::TrailerLine) && + headerBytes_ + line_.size() > kMaxHeaderSectionSize) { + SetError(state_ == State::HeaderLine ? 431 : 431, + state_ == State::HeaderLine ? "Header section too large" + : "Trailer section too large"); + } +} + +void HttpParserState::FinishHeaders() { + if (current_.version == "HTTP/1.1" && hostCount_ != 1) { + SetError(400, "HTTP/1.1 requires exactly one Host header"); + return; + } + if (hostCount_ > 1) { + SetError(400, "Multiple Host headers are not allowed"); + return; + } + if (hostEmpty_) { + SetError(400, "Host header must not be empty"); + return; + } + + if (contentLengthInvalid_) { + SetError(400, "Invalid Content-Length"); + return; + } + if (contentLengthConflict_) { + SetError(400, "Conflicting Content-Length values"); + return; + } + if (hasContentLength_ && contentLength_ > kMaxBodySize) { + SetError(413, "Content Too Large"); + return; + } + if (unsupportedTransferEncoding_) { + SetError(501, "Unsupported transfer encoding"); + return; + } + if (chunked_ && hasContentLength_) { + SetError(400, "Transfer-Encoding and Content-Length conflict"); + return; + } + if (invalidExpect_) { + SetError(417, "Expectation Failed"); + return; + } + + if (chunked_) { + state_ = State::ChunkSize; + continueCandidate_ = expectsContinue_ && !continueSent_; + return; + } + + fixedRemaining_ = hasContentLength_ ? contentLength_ : 0; + if (fixedRemaining_ == 0) { + CompleteCurrent(); + return; + } + current_.body.reserve(fixedRemaining_); + state_ = State::FixedBody; + continueCandidate_ = expectsContinue_ && !continueSent_; +} + +size_t HttpParserState::ProcessBytes(std::string_view data) { + size_t pos = 0; + while (pos < data.size()) { + if (state_ == State::Complete || state_ == State::Error || + waitingForContinue_) { + return pos; + } + + if (state_ == State::FixedBody) { + continueCandidate_ = false; + const size_t take = std::min(fixedRemaining_, data.size() - pos); + current_.body.append(data.data() + pos, take); + fixedRemaining_ -= take; + pos += take; + if (fixedRemaining_ == 0) { + CompleteCurrent(); + return pos; + } + continue; + } + + if (state_ == State::ChunkData) { + continueCandidate_ = false; + const size_t take = std::min(chunkRemaining_, data.size() - pos); + current_.body.append(data.data() + pos, take); + chunkRemaining_ -= take; + pos += take; + if (chunkRemaining_ == 0) { + state_ = State::ChunkDataCrlf; + chunkNeedsLf_ = false; + } + continue; + } + + if (state_ == State::StartLine || state_ == State::HeaderLine || + state_ == State::ChunkSize || state_ == State::TrailerLine) { + if (sawCr_) { + sawCr_ = false; + if (data[pos++] != '\n') { + SetError(400, "Invalid line ending"); + return pos; + } + ProcessLine(); + if (state_ == State::Complete || state_ == State::Error || + waitingForContinue_) { + return pos; + } + if (continueCandidate_ && pos == data.size() && + (state_ == State::FixedBody || state_ == State::ChunkSize)) { + waitingForContinue_ = true; + continueCandidate_ = false; + return pos; + } + if (continueCandidate_ && pos < data.size()) { + continueCandidate_ = false; + } + continue; + } + + const char *begin = data.data() + pos; + const size_t remaining = data.size() - pos; + const char *cr = static_cast(std::memchr(begin, '\r', remaining)); + const char *lf = static_cast(std::memchr(begin, '\n', remaining)); + const char *lineEnd = nullptr; + if (cr && lf) { + lineEnd = cr < lf ? cr : lf; + } else { + lineEnd = cr ? cr : lf; + } + + if (!lineEnd) { + AppendLineData(std::string_view(begin, remaining)); + return data.size(); + } + + const size_t prefixSize = static_cast(lineEnd - begin); + AppendLineData(std::string_view(begin, prefixSize)); + pos += prefixSize + 1; + if (state_ == State::Error) { + return pos; + } + + if (*lineEnd == '\n') { + SetError(400, "Invalid line ending"); + return pos; + } + + if (pos == data.size()) { + sawCr_ = true; + return pos; + } + + if (data[pos++] != '\n') { + SetError(400, "Invalid line ending"); + return pos; + } + + ProcessLine(); + if (state_ == State::Complete || state_ == State::Error || + waitingForContinue_) { + return pos; + } + if (continueCandidate_ && pos == data.size() && + (state_ == State::FixedBody || state_ == State::ChunkSize)) { + waitingForContinue_ = true; + continueCandidate_ = false; + return pos; + } + if (continueCandidate_ && pos < data.size()) { + continueCandidate_ = false; + } + continue; + } + + char ch = data[pos++]; + if (state_ == State::ChunkDataCrlf) { + if (!chunkNeedsLf_) { + if (ch != '\r') { + SetError(400, "Invalid chunk data terminator"); + return pos; + } + chunkNeedsLf_ = true; + continue; + } + if (ch != '\n') { + SetError(400, "Invalid chunk data terminator"); + return pos; + } + chunkNeedsLf_ = false; + state_ = State::ChunkSize; + continue; + } + + ProcessLineByte(ch); + if (state_ == State::Complete || state_ == State::Error || + waitingForContinue_) { + return pos; + } + if (continueCandidate_ && pos == data.size() && + (state_ == State::FixedBody || state_ == State::ChunkSize)) { + waitingForContinue_ = true; + continueCandidate_ = false; + return pos; + } + if (continueCandidate_ && pos < data.size()) { + continueCandidate_ = false; + } + } + return pos; +} + +void HttpParserState::Append(std::string_view data) { + const size_t consumed = ProcessBytes(data); + if (consumed < data.size()) { + pending_.append(data.substr(consumed)); + } +} + +size_t HttpParserState::Consume(std::string_view data) { + return ProcessBytes(data); +} + +bool HttpParserState::Empty() const { + return state_ == State::StartLine && line_.empty() && pending_.empty() && + !sawCr_; +} + +void HttpParserState::MarkContinueSent() { + if (waitingForContinue_) { + waitingForContinue_ = false; + } + continueSent_ = true; +} + +HttpParseResult HttpParserState::ParseNext(RequestData &request) { + if (state_ == State::Complete) { + request = std::move(current_); + ResetForNext(); + return Complete(); + } + + if (state_ == State::Error) { + return Error(errorStatus_, errorMessage_); + } + + if (waitingForContinue_) { + return NeedContinue(); + } + + if (!pending_.empty()) { + std::string data = std::move(pending_); + pending_.clear(); + const size_t consumed = ProcessBytes(data); + if (consumed < data.size()) { + pending_.append(std::string_view(data).substr(consumed)); + } + if (state_ == State::Complete) { + request = std::move(current_); + ResetForNext(); + return Complete(); + } + if (state_ == State::Error) { + return Error(errorStatus_, errorMessage_); + } + if (waitingForContinue_) { + return NeedContinue(); + } + } + + return NeedMore(); +} + +HttpRequestParser::HttpRequestParser(IOUring &ring, int fd) + : iterator_(ring, fd) {} + +void HttpRequestParser::MarkContinueSent() { state_.MarkContinueSent(); } + +CoFuture HttpRequestParser::ReadRequest( + RequestData &request) { + while (true) { + HttpParseResult parsed = state_.ParseNext(request); + if (parsed.status == HttpParseStatus::Complete) { + co_return RequestReadStatus::Complete; + } + if (parsed.status == HttpParseStatus::NeedContinue) { + co_return RequestReadStatus::NeedContinue; + } + if (parsed.status == HttpParseStatus::Error) { + throw HTTPError(parsed.errorStatus, parsed.errorMessage); + } + + if (!iterator_) { + co_await ++iterator_; + } + + if (!iterator_) { + if (iterator_.Eof() && state_.Empty()) { + co_return RequestReadStatus::Closed; + } + throw HTTPError(400, "Incomplete HTTP message"); + } + + const size_t available = iterator_.Available(); + const size_t consumed = + state_.Consume(std::string_view(iterator_.CurrentPtr(), available)); + iterator_.Advance(consumed); + } +} + +} // namespace HTTP diff --git a/server/src/io_uring.cpp b/server/src/io_uring.cpp index 0cd2e9e..6fc8821 100644 --- a/server/src/io_uring.cpp +++ b/server/src/io_uring.cpp @@ -35,7 +35,8 @@ void IOUring::AddEntries() { std::move(entry.complete), std::move(entry.writeData), entry.writeOffset, entry.writeLen}; if (entry.type == IOUring::READ) [[likely]] { - io_uring_prep_read(sqEntry, entry.fd, entry.toRead.value(), 256, 0); + io_uring_prep_read(sqEntry, entry.fd, entry.toRead.value(), + kReadBufferSize, 0); } else if (entry.type == IOUring::ACCEPT) { io_uring_prep_accept(sqEntry, entry.fd, nullptr, nullptr, 0); } else [[likely]] { @@ -89,7 +90,8 @@ void IOUring::Write(int fileDescriptor, std::string_view data, size_t offset, AddEntries(); } -void IOUring::Read(int fileDescriptor, std::array &buffer, +void IOUring::Read(int fileDescriptor, + std::array &buffer, std::function complete) { if (fileDescriptor < 0) { throw std::runtime_error("Invalid file descriptor"); @@ -102,7 +104,8 @@ void IOUring::Read(int fileDescriptor, std::array &buffer, queue_.push_back(entry); } -CoFuture IOUring::ReadAsync(int fileDescriptor, std::array &buffer) { +CoFuture IOUring::ReadAsync( + int fileDescriptor, std::array &buffer) { auto promise = std::make_shared>(); auto future = promise->GetFuture(); Read(fileDescriptor, buffer, [promise](int result) { diff --git a/server/src/read_iterator.cpp b/server/src/read_iterator.cpp index e507255..cc0d0c5 100644 --- a/server/src/read_iterator.cpp +++ b/server/src/read_iterator.cpp @@ -1,10 +1,8 @@ #include "read_iterator.h" -#include "http_error.h" #include #include -#include -#include #include + namespace HTTP { namespace { @@ -15,261 +13,62 @@ bool IsRetryableIoError(int result) { } } // namespace -ReadIterator::ReadIterator(IOUring &ring, int fd_) : ring_(ring), fd_(fd_), length_(0), position_(0) { -} - -CoFuture ReadIterator::Ensure() { - if (position_ >= length_) { - for (int attempt = 1; attempt <= kMaxIoAttempts; ++attempt) { - int result = co_await ring_.ReadAsync(fd_, buffer_); - if (result > 0) { - length_ = static_cast(result); - position_ = 0; - co_return; - } - if (result == 0) { - length_ = 0; - position_ = 0; - co_return; - } - if (!IsRetryableIoError(result) || attempt == kMaxIoAttempts) { - throw std::system_error(-result, std::generic_category(), "read failed"); - } - } - position_ = 0; - } - co_return; -} - -size_t ReadIterator::Available() const { - if (position_ >= length_) return 0; - return length_ - position_; -} - -const char *ReadIterator::CurrentPtr() const { - if (position_ >= length_) return nullptr; - return buffer_.data() + position_; -} - -void ReadIterator::Advance(size_t n) { - position_ += n; -} +ReadIterator::ReadIterator(IOUring &ring, int fd) : ring_(ring), fd_(fd) {} CoFuture ReadIterator::operator++() { - ++position_; - co_return; -} - -ReadIterator::operator bool() { - return position_ < length_ && **this != '\0'; -} - -char ReadIterator::operator*() { - if (position_ >= length_) { - return '\0'; + if (position_ < length_) { + ++position_; } - return buffer_.at(position_); -} -CoFuture ReadIterator::ParseMethod(RequestData &data) { - co_await Ensure(); - if (length_ == 0) { - throw HTTPError(400, "Invalid request"); - } - while (true) { - co_await Ensure(); - if (!*this) { - throw HTTPError(400, "Invalid request"); - } - if (**this != '\r' && **this != '\n') { - break; - } - co_await ++*this; - } - std::string methodString; - int count{0}; - while (count < 5) { - co_await Ensure(); - if (!*this) { - throw HTTPError(400, "Invalid request"); - } - if (**this == ' ') { - break; - } - methodString += **this; - count++; - co_await ++*this; - } - if (methodString == "PUT") { - data.method = PUT; - co_return; - } - if (methodString == "POST") { - data.method = POST; - co_return; - } - if (methodString == "DELETE") { - data.method = DELETE; - co_return; - } - if (methodString == "PATCH") { - data.method = PATCH; - co_return; - } - if (methodString == "GET") { - data.method = GET; - co_return; - } - throw HTTPError(400, "Invalid request"); -} - -CoFuture ReadIterator::ParseVariables(RequestData &data) { - co_await Ensure(); - if (**this != '?' && **this != ' ') { - throw HTTPError(400, "Invalid request"); - } - enum { Name, Value } current = Name; - std::string name; - std::string *value; - if (**this != '?') { - co_return; - } - co_await ++*this; - while (true) { - co_await Ensure(); - if (!*this) { - throw HTTPError(400, "Empty parameter name"); + int attempts = 0; + while (position_ >= length_ && !eof_) { + int result = co_await ring_.ReadAsync(fd_, buffer_); + if (result > 0) { + length_ = static_cast(result); + position_ = 0; + co_return; } - if (**this == ' ') { - break; + if (result == 0) { + length_ = 0; + position_ = 0; + eof_ = true; + co_return; } - if (current == Name) { - if (**this == '=') { - if (name == "") { - throw HTTPError(400, "Empty parameter name"); - } - current = Value; - data.params[name] = ""; - value = &data.params[name]; - } else { - name.push_back(**this); - } - } else { - if (**this == '&') { - name = ""; - value = nullptr; - current = Name; - } else { - value->push_back(**this); - } + ++attempts; + if (!IsRetryableIoError(result) || attempts >= kMaxIoAttempts) { + throw std::system_error(-result, std::generic_category(), "read failed"); } - co_await ++*this; } co_return; } -CoFuture ReadIterator::ParseHeaders(RequestData &data) { - enum { Name, Value } current = Name; - std::string name; - std::string *value; - char last = 'a'; - while (true) { - co_await Ensure(); - if (!*this) { - throw HTTPError(400, "Invalid message"); - } - if (**this == '\r') { - co_await ++*this; - continue; - } - if (last == **this && last == '\n') { - break; - } - if (current == Name) { - if (**this == ':') { - if (name == "") { - throw HTTPError(400, "Empty header name"); - } - current = Value; - data.headers[name] = ""; - value = &data.headers[name]; - } else { - name.push_back(**this); - } - } else { - if (**this == '\n') { - name = ""; - value = nullptr; - current = Name; - } else { - value->push_back(**this); - } - } - last = **this; - co_await ++*this; +char ReadIterator::operator*() const { + if (position_ >= length_) { + return '\0'; } - co_return; + return buffer_[position_]; } -CoFuture ReadIterator::ParseBody(RequestData &data) { - auto it = data.headers.find("Content-Length"); - if (it != data.headers.end()) { - size_t length; - try { - length = std::stoul(it->second); - } catch (const std::invalid_argument &) { - co_return; - } catch (const std::out_of_range &) { - co_return; - } - data.body.clear(); - data.body.reserve(length); - co_await Ensure(); - if (*this && (**this == '\n' || **this == '\r')) { - co_await ++*this; - co_await Ensure(); - if (*this && (**this == '\n' || **this == '\r')) { - co_await ++*this; - } - } - size_t remaining = length; - while (remaining > 0) { - co_await Ensure(); - if (!*this) break; - size_t avail = Available(); - if (avail == 0) continue; - size_t take = std::min(avail, remaining); - data.body.append(CurrentPtr(), take); - Advance(take); - remaining -= take; - } - co_return; - } - - auto it2 = data.headers.find("Transfer-Encoding"); - if (it2 != data.headers.end() && it2->second == "chunked") { - co_return; - } +ReadIterator::operator bool() const { return position_ < length_; } - if (data.method == GET || data.method == DELETE) { - co_return; - } +bool ReadIterator::Eof() const { return eof_; } - co_await Ensure(); - if (*this && (**this == '\n' || **this == '\r')) { - co_await ++*this; - co_await Ensure(); - if (*this && (**this == '\n' || **this == '\r')) { - co_await ++*this; - } +size_t ReadIterator::Available() const { + if (position_ >= length_) { + return 0; } + return length_ - position_; +} - while (true) { - co_await Ensure(); - if (!*this) break; - data.body.push_back(**this); - co_await ++*this; +const char *ReadIterator::CurrentPtr() const { + if (position_ >= length_) { + return nullptr; } - co_return; + return buffer_.data() + position_; } + +void ReadIterator::Advance(size_t n) { + position_ = std::min(length_, position_ + n); } + +} // namespace HTTP diff --git a/server/src/server.cpp b/server/src/server.cpp index d6f3a69..d6f2663 100644 --- a/server/src/server.cpp +++ b/server/src/server.cpp @@ -1,16 +1,17 @@ #include "server.h" #include "http_error.h" -#include "read_iterator.h" #include "request_data.h" #include "trie.h" +#include #include -#include +#include +#include #include +#include #include #include #include #include -#include #include #include #include @@ -20,49 +21,183 @@ namespace HTTP { namespace { -static bool iequals(std::string_view a, std::string_view b) { - if (a.size() != b.size()) +constexpr unsigned char ascii_lower(unsigned char ch) { + return ch >= 'A' && ch <= 'Z' ? static_cast(ch + ('a' - 'A')) + : ch; +} + +bool iequals(std::string_view a, std::string_view b) { + if (a.size() != b.size()) { return false; + } for (size_t i = 0; i < a.size(); ++i) { - if (std::tolower(static_cast(a[i])) != - std::tolower(static_cast(b[i]))) { + if (ascii_lower(static_cast(a[i])) != + ascii_lower(static_cast(b[i]))) { return false; } } return true; } -static std::string trim_copy(std::string_view s) { - size_t start = 0; - while (start < s.size() && - std::isspace(static_cast(s[start]))) { - ++start; +bool ShouldKeepAlive(const RequestData &request) { + if (request.version == "HTTP/1.0") { + return request.connectionKeepAlive; } - size_t end = s.size(); - while (end > start && std::isspace(static_cast(s[end - 1]))) { - --end; + if (request.connectionClose) { + return false; } - return std::string(s.substr(start, end - start)); + return request.version == "HTTP/1.1"; +} + +bool ascii_is_alpha(unsigned char ch) { + return (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z'); } -static std::optional -find_header_ci(const std::unordered_map &headers, - std::string_view key) { - for (const auto &[k, v] : headers) { - if (iequals(k, key)) - return v; +bool ascii_is_digit(unsigned char ch) { return ch >= '0' && ch <= '9'; } + +bool ascii_is_alnum(unsigned char ch) { + return ascii_is_alpha(ch) || ascii_is_digit(ch); +} + +bool is_tchar(unsigned char ch) { + if (ascii_is_alnum(ch)) { + return true; + } + switch (ch) { + case '!': + case '#': + case '$': + case '%': + case '&': + case '\'': + case '*': + case '+': + case '-': + case '.': + case '^': + case '_': + case '`': + case '{': + case '|': + case '}': + case '~': + return true; + default: + return false; } - return std::nullopt; } -static bool wants_close(const RequestData &request) { - auto v = find_header_ci(request.headers, "Connection"); - if (!v) +bool valid_header_name(std::string_view name) { + if (name.empty()) { return false; - std::string value = trim_copy(*v); - for (auto &ch : value) - ch = static_cast(std::tolower(static_cast(ch))); - return value.find("close") != std::string::npos; + } + return std::all_of(name.begin(), name.end(), [](char ch) { + return is_tchar(static_cast(ch)); + }); +} + +bool valid_header_value(std::string_view value) { + for (unsigned char ch : value) { + if (ch == '\t') { + continue; + } + if (ch < 0x20 || ch == 0x7f) { + return false; + } + } + return true; +} + +bool reserved_response_header(std::string_view name) { + return iequals(name, "Content-Length") || + iequals(name, "Transfer-Encoding") || iequals(name, "Connection") || + iequals(name, "Date"); +} + +unsigned short NormalizeStatus(unsigned short status) { + if (status < 100 || status > 599) { + return 500; + } + return status; +} + +std::string_view ReasonPhrase(unsigned short status) { + switch (status) { + case 100: + return "Continue"; + case 101: + return "Switching Protocols"; + case 200: + return "OK"; + case 201: + return "Created"; + case 202: + return "Accepted"; + case 204: + return "No Content"; + case 301: + return "Moved Permanently"; + case 302: + return "Found"; + case 304: + return "Not Modified"; + case 400: + return "Bad Request"; + case 404: + return "Not Found"; + case 405: + return "Method Not Allowed"; + case 413: + return "Content Too Large"; + case 414: + return "URI Too Long"; + case 417: + return "Expectation Failed"; + case 431: + return "Request Header Fields Too Large"; + case 500: + return "Internal Server Error"; + case 501: + return "Not Implemented"; + case 505: + return "HTTP Version Not Supported"; + default: + if (status < 200) { + return "Informational"; + } + if (status < 300) { + return "Successful"; + } + if (status < 400) { + return "Redirection"; + } + if (status < 500) { + return "Client Error"; + } + return "Server Error"; + } +} + +std::string_view HttpDate() { + thread_local std::time_t cachedSecond = 0; + thread_local char buffer[64] = "Thu, 01 Jan 1970 00:00:00 GMT"; + + const std::time_t now = std::time(nullptr); + if (now != cachedSecond) { + cachedSecond = now; + std::tm tm{}; + gmtime_r(&now, &tm); + std::strftime(buffer, sizeof(buffer), "%a, %d %b %Y %H:%M:%S GMT", &tm); + } + return buffer; +} + +bool StatusAllowsBody(unsigned short status) { + return status >= 200 && status != 204 && status != 304; +} + +bool ShouldSendBody(unsigned short status, Method method) { + return method != HEAD && StatusAllowsBody(status); } constexpr int kMaxIoAttempts = 3; @@ -70,6 +205,14 @@ constexpr int kMaxIoAttempts = 3; bool IsRetryableIoError(int result) { return result == -EINTR || result == -EAGAIN || result == -EWOULDBLOCK; } + +void AppendUnsigned(std::string &out, size_t value) { + char buffer[32]; + auto [ptr, ec] = std::to_chars(buffer, buffer + sizeof(buffer), value); + if (ec == std::errc()) { + out.append(buffer, ptr); + } +} } // namespace Server::Server(Server &&rhs) { @@ -91,7 +234,16 @@ Server::~Server() { socketFD_ = -1; } for (auto &future : workerFutures_) { - future.Get(); + try { + if (future.valid()) { + future.Get(); + } + } catch (const std::exception &error) { + std::cerr << "[Server] Worker shutdown failed: " << error.what() + << std::endl; + } catch (...) { + std::cerr << "[Server] Worker shutdown failed" << std::endl; + } } } @@ -122,6 +274,14 @@ void Server::WorkerLoop(IOUring &ring) { acceptCoro = AcceptAndProcess(ring); } } + + for (int attempts = 0; !acceptCoro.isReady() && attempts < 1000; + ++attempts) { + ring.Poll(); + } + if (acceptCoro.isReady()) { + acceptCoro.Get(); + } } catch (const std::exception &e) { std::cerr << "[WorkerLoop] Exception: " << e.what() << std::endl; } catch (...) { @@ -129,27 +289,13 @@ void Server::WorkerLoop(IOUring &ring) { } } -CoFuture Server::WriteResponse(IOUring &ring, int connectionFD, const ResponseData &data, - bool keepAlive) { - std::stringstream text; - text << "HTTP/1.1 " << data.status << ' ' - << (data.status / 100 == 2 ? "OK" : "ERROR") << "\r\n"; - auto headers = data.headers; - if (!headers.contains("Content-Length")) { - headers["Content-Length"] = std::to_string(data.body.size()); - } - headers["Connection"] = keepAlive ? "keep-alive" : "close"; - for (const auto &[name, value] : headers) { - text << name << ": " << value << "\r\n"; - } - text << "\r\n"; - text << data.body; - std::string final = text.str(); +CoFuture Server::WriteRaw(IOUring &ring, int connectionFD, + std::string_view data) { size_t sent = 0; int attempts = 0; - while (sent < final.size()) { - int result = co_await ring.WriteAsync(connectionFD, final, sent, - final.size() - sent); + while (sent < data.size()) { + int result = + co_await ring.WriteAsync(connectionFD, data, sent, data.size() - sent); if (result > 0) { sent += static_cast(result); attempts = 0; @@ -166,50 +312,105 @@ CoFuture Server::WriteResponse(IOUring &ring, int connectionFD, const Resp co_return; } +CoFuture Server::WriteResponse(IOUring &ring, int connectionFD, + const ResponseData &data, + const RequestData &request, + bool keepAlive, std::string &buffer) { + const unsigned short status = NormalizeStatus(data.status); + const bool sendBody = ShouldSendBody(status, request.method); + buffer.clear(); + buffer.reserve(160 + data.body.size()); + + buffer += "HTTP/1.1 "; + AppendUnsigned(buffer, status); + buffer += ' '; + buffer += ReasonPhrase(status); + buffer += "\r\nDate: "; + buffer += HttpDate(); + buffer += "\r\n"; + + if (StatusAllowsBody(status)) { + buffer += "Content-Length: "; + AppendUnsigned(buffer, data.body.size()); + buffer += "\r\n"; + } + + if (!keepAlive) { + buffer += "Connection: close\r\n"; + } else if (request.version == "HTTP/1.0") { + buffer += "Connection: keep-alive\r\n"; + } + + for (const auto &[name, value] : data.headers) { + if (reserved_response_header(name) || !valid_header_name(name) || + !valid_header_value(value)) { + continue; + } + buffer += name; + buffer += ": "; + buffer += value; + buffer += "\r\n"; + } + buffer += "\r\n"; + if (sendBody) { + buffer += data.body; + } + + co_await WriteRaw(ring, connectionFD, buffer); + co_return; +} + CoFuture Server::Process(IOUring &ring, int connectionFD) { - ReadIterator iterator(ring, connectionFD); + HttpRequestParser parser(ring, connectionFD); + std::string writeBuffer; + writeBuffer.reserve(256); while (true) { + RequestData request; ResponseData response; bool keepAlive = true; bool mustClose = false; try { - RequestData request; - co_await iterator.Ensure(); - if (!iterator) { - mustClose = true; - keepAlive = false; - throw HTTPError(400, ""); - } - co_await iterator.ParseMethod(request); - RespondType handler; - co_await GetHandler(request, iterator, handler); - co_await iterator.ParseVariables(request); - co_await ++iterator; - std::string protocol; while (true) { - co_await iterator.Ensure(); - if (!iterator) { - throw HTTPError(400, "Invalid request"); + RequestReadStatus readStatus = co_await parser.ReadRequest(request); + if (readStatus == RequestReadStatus::Closed) { + (void)shutdown(connectionFD, SHUT_WR); + close(connectionFD); + co_return; } - char c = *iterator; - if (c == '\n') { + if (readStatus == RequestReadStatus::Complete) { break; } - if (c != '\r') { - protocol += c; - } - co_await ++iterator; + co_await WriteRaw(ring, connectionFD, + "HTTP/1.1 100 Continue\r\n\r\n"); + parser.MarkContinueSent(); } - if (protocol != "HTTP/1.1") { - throw HTTPError(400, "Invalid request"); + + keepAlive = ShouldKeepAlive(request); + if (request.method == UNKNOWN) { + response.status = 501; + response.body = "Not Implemented"; + } else if (request.method == OPTIONS && request.path == "*") { + response.status = 200; + response.headers["Allow"] = trie_.AllAllowedMethods(); + } else { + Trie::RouteResult route = + trie_.Resolve(request.method, request.path, request.urlVariables); + if (!route.pathFound) { + response.status = 404; + response.body = "Not Found"; + } else if (request.method == OPTIONS && route.automaticOptions) { + response.status = 200; + response.headers["Allow"] = route.allow; + } else if (!route.methodAllowed) { + response.status = 405; + response.body = "Method Not Allowed"; + response.headers["Allow"] = route.allow; + } else { + response = route.handler(request); + } } - co_await ++iterator; - co_await iterator.ParseHeaders(request); - co_await iterator.ParseBody(request); - keepAlive = !wants_close(request); - response = handler(request); } catch (HTTPError &error) { response.status = error.status; response.body = error.message; @@ -234,7 +435,8 @@ CoFuture Server::Process(IOUring &ring, int connectionFD) { if (!mustClose || response.status != 400 || !response.body.empty()) { try { - co_await WriteResponse(ring, connectionFD, response, keepAlive); + co_await WriteResponse(ring, connectionFD, response, request, keepAlive, + writeBuffer); } catch (const std::system_error &) { mustClose = true; keepAlive = false; @@ -249,34 +451,6 @@ CoFuture Server::Process(IOUring &ring, int connectionFD) { co_return; } -CoFuture Server::GetHandler(RequestData &data, ReadIterator &iter, RespondType &handler) { - co_await iter.Ensure(); - if (*iter != ' ') { - throw HTTPError(400, "Invalid request"); - } - co_await ++iter; - co_await iter.Ensure(); - if (*iter != '/') { - throw HTTPError(400, "Invalid request"); - } - std::string path; - path.reserve(64); - while (true) { - co_await iter.Ensure(); - if (!iter) { - throw HTTPError(400, "Invalid request"); - } - char c = *iter; - if (c == ' ' || c == '?') { - break; - } - path.push_back(c); - co_await ++iter; - } - handler = trie_.Match(data.method, path, data.urlVariables); - co_return; -} - void ServerBuilder::SetPort(int port) { server_.port_ = port; } void ServerBuilder::SetThreads(int numThreads) { diff --git a/server/src/trie.cpp b/server/src/trie.cpp index db17164..c9779d5 100644 --- a/server/src/trie.cpp +++ b/server/src/trie.cpp @@ -1,7 +1,44 @@ #include "trie.h" #include "http_error.h" #include "request_data.h" +#include namespace HTTP { +namespace { +bool IsRoutable(Method method) { + return method >= GET && method <= OPTIONS; +} + +size_t MethodIndex(Method method) { return static_cast(method); } + +const char *MethodName(Method method) { + switch (method) { + case GET: + return "GET"; + case PUT: + return "PUT"; + case POST: + return "POST"; + case PATCH: + return "PATCH"; + case DELETE: + return "DELETE"; + case HEAD: + return "HEAD"; + case OPTIONS: + return "OPTIONS"; + case UNKNOWN: + return ""; + } + return ""; +} + +void AppendAllow(std::string &allow, Method method) { + if (!allow.empty()) { + allow += ", "; + } + allow += MethodName(method); +} +} // namespace size_t Trie::StringHash::operator()(std::string_view value) const noexcept { return std::hash{}(value); @@ -31,6 +68,9 @@ Trie::Node &Trie::Node::Move(std::string_view segment) { void Trie::AddRequest(Method method, RespondType respond, std::string_view path) { + if (!IsRoutable(method)) { + return; + } Node *current = root_.get(); size_t position = path.starts_with('/') ? 1 : 0; while (position < path.size()) { @@ -45,14 +85,91 @@ void Trie::AddRequest(Method method, RespondType respond, } position = next + 1; } - current->handlers[method] = respond; + current->handlers[MethodIndex(method)] = respond; } -RespondType Trie::Match(Method method, std::string_view path, - std::vector &urlVariables) const { - const Node *current = root_.get(); - urlVariables.clear(); +std::string Trie::AllAllowedMethods() const { + std::array methods{}; + auto visit = [&](const Node *node, const auto &self) -> void { + for (size_t i = 0; i < kRoutableMethodCount; ++i) { + if (node->handlers[i]) { + methods[i] = true; + } + } + for (const auto &[_, child] : node->children) { + self(child.get(), self); + } + if (node->wildcard) { + self(node->wildcard.get(), self); + } + }; + visit(root_.get(), visit); + + std::string allow; + if (methods[MethodIndex(GET)]) { + AppendAllow(allow, GET); + } + if (methods[MethodIndex(HEAD)] || methods[MethodIndex(GET)]) { + AppendAllow(allow, HEAD); + } + AppendAllow(allow, OPTIONS); + if (methods[MethodIndex(POST)]) { + AppendAllow(allow, POST); + } + if (methods[MethodIndex(PUT)]) { + AppendAllow(allow, PUT); + } + if (methods[MethodIndex(PATCH)]) { + AppendAllow(allow, PATCH); + } + if (methods[MethodIndex(DELETE)]) { + AppendAllow(allow, DELETE); + } + return allow; +} +Trie::RouteResult Trie::Resolve( + Method method, std::string_view path, + std::vector &urlVariables) const { + auto hasHandler = [](const Node &node, Method candidate) { + return IsRoutable(candidate) && + node.handlers[MethodIndex(candidate)].has_value(); + }; + auto hasAnyHandler = [](const Node &node) { + for (const auto &handler : node.handlers) { + if (handler) { + return true; + } + } + return false; + }; + auto buildAllow = [&](const Node &node) { + std::string allow; + if (hasHandler(node, GET)) { + AppendAllow(allow, GET); + } + if (hasHandler(node, HEAD) || hasHandler(node, GET)) { + AppendAllow(allow, HEAD); + } + AppendAllow(allow, OPTIONS); + if (hasHandler(node, POST)) { + AppendAllow(allow, POST); + } + if (hasHandler(node, PUT)) { + AppendAllow(allow, PUT); + } + if (hasHandler(node, PATCH)) { + AppendAllow(allow, PATCH); + } + if (hasHandler(node, DELETE)) { + AppendAllow(allow, DELETE); + } + return allow; + }; + + RouteResult result; + const Node *node = root_.get(); + urlVariables.clear(); size_t position = path.starts_with('/') ? 1 : 0; while (position < path.size()) { const size_t next = path.find('/', position); @@ -61,14 +178,15 @@ RespondType Trie::Match(Method method, std::string_view path, ? path.substr(position) : path.substr(position, next - position); - const auto child = current->children.find(segment); - if (child != current->children.end()) { - current = child->second.get(); - } else if (current->wildcard) { + const auto child = node->children.find(segment); + if (child != node->children.end()) { + node = child->second.get(); + } else if (node->wildcard) { urlVariables.emplace_back(segment); - current = current->wildcard.get(); + node = node->wildcard.get(); } else { - throw HTTPError(404, "Not found"); + node = nullptr; + break; } if (next == std::string_view::npos) { @@ -77,10 +195,42 @@ RespondType Trie::Match(Method method, std::string_view path, position = next + 1; } - if (!current->handlers[method]) { + if (!node || !hasAnyHandler(*node)) { + urlVariables.clear(); + return result; + } + + result.pathFound = true; + result.allow = buildAllow(*node); + if (!IsRoutable(method)) { + return result; + } + + if (hasHandler(*node, method)) { + result.methodAllowed = true; + result.handler = *node->handlers[MethodIndex(method)]; + return result; + } + if (method == HEAD && hasHandler(*node, GET)) { + result.methodAllowed = true; + result.handler = *node->handlers[MethodIndex(GET)]; + return result; + } + if (method == OPTIONS) { + result.methodAllowed = true; + result.automaticOptions = true; + return result; + } + return result; +} + +RespondType Trie::Match(Method method, std::string_view path, + std::vector &urlVariables) const { + RouteResult result = Resolve(method, path, urlVariables); + if (!result.pathFound || !result.methodAllowed || result.automaticOptions) { throw HTTPError(404, "Not found"); } - return *current->handlers[method]; + return result.handler; } Trie::Trie(Trie &&rhs) { root_ = std::move(rhs.root_); } diff --git a/server/tests/http_integration_test b/server/tests/http_integration_test new file mode 100755 index 0000000..a31ab43 Binary files /dev/null and b/server/tests/http_integration_test differ diff --git a/server/tests/http_integration_test.cpp b/server/tests/http_integration_test.cpp new file mode 100644 index 0000000..2490670 --- /dev/null +++ b/server/tests/http_integration_test.cpp @@ -0,0 +1,306 @@ +#include "server.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { +void Check(bool condition, const char *message) { + if (!condition) { + std::cerr << "FAIL: " << message << '\n'; + std::exit(1); + } +} + +int PickPort() { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + throw std::runtime_error("socket failed"); + } + sockaddr_in address{}; + address.sin_family = AF_INET; + address.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + address.sin_port = 0; + if (bind(fd, reinterpret_cast(&address), sizeof(address)) < 0) { + close(fd); + throw std::runtime_error("bind failed"); + } + socklen_t length = sizeof(address); + if (getsockname(fd, reinterpret_cast(&address), &length) < 0) { + close(fd); + throw std::runtime_error("getsockname failed"); + } + int port = ntohs(address.sin_port); + close(fd); + return port; +} + +int Connect(int port) { + for (int attempt = 0; attempt < 100; ++attempt) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { + throw std::runtime_error("socket failed"); + } + + timeval timeout{}; + timeout.tv_sec = 2; + setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)); + + sockaddr_in address{}; + address.sin_family = AF_INET; + address.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + address.sin_port = htons(port); + if (connect(fd, reinterpret_cast(&address), sizeof(address)) == + 0) { + return fd; + } + close(fd); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + throw std::runtime_error("connect failed"); +} + +void SendAll(int fd, std::string_view data) { + size_t sent = 0; + while (sent < data.size()) { + ssize_t result = send(fd, data.data() + sent, data.size() - sent, 0); + if (result <= 0) { + throw std::runtime_error("send failed"); + } + sent += static_cast(result); + } +} + +std::string ReadUntilClose(int fd) { + std::string data; + char buffer[2048]; + while (true) { + ssize_t result = recv(fd, buffer, sizeof(buffer), 0); + if (result > 0) { + data.append(buffer, static_cast(result)); + continue; + } + if (result == 0) { + return data; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + throw std::runtime_error("recv timed out"); + } + throw std::runtime_error("recv failed"); + } +} + +std::string ReadUntilContains(int fd, std::string_view needle) { + std::string data; + char buffer[2048]; + while (data.find(needle) == std::string::npos) { + ssize_t result = recv(fd, buffer, sizeof(buffer), 0); + if (result > 0) { + data.append(buffer, static_cast(result)); + continue; + } + if (result == 0) { + break; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + throw std::runtime_error("recv timed out"); + } + throw std::runtime_error("recv failed"); + } + return data; +} + +std::string RequestOnce(int port, std::string_view request) { + int fd = Connect(port); + SendAll(fd, request); + std::string response = ReadUntilClose(fd); + close(fd); + return response; +} + +size_t Count(std::string_view haystack, std::string_view needle) { + size_t count = 0; + size_t pos = 0; + while ((pos = haystack.find(needle, pos)) != std::string_view::npos) { + ++count; + pos += needle.size(); + } + return count; +} +} // namespace + +int main() { + try { + const int port = PickPort(); + HTTP::ServerBuilder builder; + builder.SetPort(port); + builder.SetThreads(1); + builder.AddRequest(HTTP::GET, "/echo", + [](const HTTP::RequestData &request) { + HTTP::ResponseData response; + response.status = 200; + auto it = request.params.find("msg"); + if (it != request.params.end()) { + response.body = it->second; + } + return response; + }); + builder.AddRequest(HTTP::POST, "/echo", + [](const HTTP::RequestData &request) { + HTTP::ResponseData response; + response.status = 200; + response.body = request.body; + if (auto trailer = request.trailers.find("Trace"); + trailer != request.trailers.end()) { + response.headers["X-Trailer"] = trailer->second; + } + return response; + }); + + auto server = builder.Build(); + auto serverFuture = server.Start(); + + { + std::string response = + RequestOnce(port, "GET /echo?msg=hello HTTP/1.1\r\nHost: localhost\r\n" + "Connection: close\r\n\r\n"); + Check(response.find("HTTP/1.1 200 OK") != std::string::npos, + "GET returns 200"); + Check(response.ends_with("hello"), "GET response body is returned"); + } + + { + std::string response = RequestOnce( + port, + "POST /echo HTTP/1.1\r\nHost: localhost\r\n" + "Transfer-Encoding: chunked\r\nConnection: close\r\n\r\n" + "5\r\nhello\r\n0\r\nTrace: yes\r\n\r\n"); + Check(response.find("HTTP/1.1 200 OK") != std::string::npos, + "chunked POST returns 200"); + Check(response.find("X-Trailer: yes") != std::string::npos, + "chunked trailer reaches handler"); + Check(response.ends_with("hello"), "chunked body is decoded"); + } + + { + int fd = Connect(port); + SendAll(fd, "GET /echo?msg=one HTTP/1.1\r\nHost: localhost\r\n\r\n" + "GET /echo?msg=two HTTP/1.1\r\nHost: localhost\r\n" + "Connection: close\r\n\r\n"); + std::string response = ReadUntilClose(fd); + close(fd); + Check(Count(response, "HTTP/1.1 200 OK") == 2, + "pipelined requests receive two responses"); + Check(response.find("one") != std::string::npos && + response.find("two") != std::string::npos, + "pipelined response bodies are present"); + } + + { + std::string response = + RequestOnce(port, "HEAD /echo?msg=head HTTP/1.1\r\nHost: localhost\r\n" + "Connection: close\r\n\r\n"); + Check(response.find("HTTP/1.1 200 OK") != std::string::npos, + "HEAD falls back to GET handler"); + Check(response.find("Content-Length: 4") != std::string::npos, + "HEAD preserves representation length"); + const size_t body = response.find("\r\n\r\n"); + Check(body != std::string::npos && body + 4 == response.size(), + "HEAD response has no body"); + } + + { + std::string response = + RequestOnce(port, "OPTIONS /echo HTTP/1.1\r\nHost: localhost\r\n" + "Connection: close\r\n\r\n"); + Check(response.find("HTTP/1.1 200 OK") != std::string::npos, + "OPTIONS is automatic"); + Check(response.find("Allow: GET, HEAD, OPTIONS, POST") != + std::string::npos, + "OPTIONS includes Allow"); + } + + { + std::string response = + RequestOnce(port, "OPTIONS * HTTP/1.1\r\nHost: localhost\r\n" + "Connection: close\r\n\r\n"); + Check(response.find("HTTP/1.1 200 OK") != std::string::npos, + "OPTIONS * is automatic"); + Check(response.find("Allow: GET, HEAD, OPTIONS, POST") != + std::string::npos, + "OPTIONS * includes server-wide Allow"); + } + + { + std::string response = + RequestOnce(port, "PUT /echo HTTP/1.1\r\nHost: localhost\r\n" + "Connection: close\r\n\r\n"); + Check(response.find("HTTP/1.1 405 Method Not Allowed") != + std::string::npos, + "unsupported method on existing path returns 405"); + Check(response.find("Allow: GET, HEAD, OPTIONS, POST") != + std::string::npos, + "405 includes Allow"); + } + + { + std::string response = + RequestOnce(port, "FOO /echo HTTP/1.1\r\nHost: localhost\r\n" + "Connection: close\r\n\r\n"); + Check(response.find("HTTP/1.1 501 Not Implemented") != std::string::npos, + "unknown valid method returns 501"); + } + + { + std::string response = + RequestOnce(port, "GET /missing HTTP/1.1\r\nHost: localhost\r\n" + "Connection: close\r\n\r\n"); + Check(response.find("HTTP/1.1 404 Not Found") != std::string::npos, + "missing path returns 404"); + } + + { + std::string response = + RequestOnce(port, "GET / HTTP/2.0\r\nHost: localhost\r\n" + "Connection: close\r\n\r\n"); + Check(response.find("HTTP/1.1 505 HTTP Version Not Supported") != + std::string::npos, + "unsupported version returns 505"); + } + + { + std::string response = + RequestOnce(port, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n"); + Check(response.find("HTTP/1.1 400 Bad Request") != std::string::npos, + "missing Host returns 400"); + } + + { + int fd = Connect(port); + SendAll(fd, "POST /echo HTTP/1.1\r\nHost: localhost\r\n" + "Expect: 100-continue\r\nContent-Length: 5\r\n" + "Connection: close\r\n\r\n"); + std::string interim = ReadUntilContains(fd, "HTTP/1.1 100 Continue"); + Check(interim.find("HTTP/1.1 100 Continue") != std::string::npos, + "100 Continue is sent before body"); + SendAll(fd, "hello"); + std::string final = ReadUntilClose(fd); + close(fd); + Check(final.find("HTTP/1.1 200 OK") != std::string::npos, + "continued request gets final 200"); + Check(final.ends_with("hello"), "continued body reaches handler"); + } + } catch (const std::exception &error) { + std::cerr << "FAIL: " << error.what() << '\n'; + return 1; + } + return 0; +} diff --git a/server/tests/http_parser_test b/server/tests/http_parser_test new file mode 100755 index 0000000..2b0c5bf Binary files /dev/null and b/server/tests/http_parser_test differ diff --git a/server/tests/http_parser_test.cpp b/server/tests/http_parser_test.cpp new file mode 100644 index 0000000..fc51997 --- /dev/null +++ b/server/tests/http_parser_test.cpp @@ -0,0 +1,146 @@ +#include "http_parser.h" +#include +#include +#include +#include + +namespace { +void Check(bool condition, const char *message) { + if (!condition) { + std::cerr << "FAIL: " << message << '\n'; + std::exit(1); + } +} + +HTTP::HttpParseResult Parse(HTTP::HttpParserState &parser, + HTTP::RequestData &request) { + return parser.ParseNext(request); +} +} // namespace + +int main() { + { + HTTP::HttpParserState parser; + HTTP::RequestData request; + parser.Append("GET /echo?msg=hello HTTP/1.1\r\nHost: example.com\r\n\r\n"); + auto result = Parse(parser, request); + Check(result.status == HTTP::HttpParseStatus::Complete, "valid GET parses"); + Check(request.method == HTTP::GET, "GET method is recognized"); + Check(request.path == "/echo", "origin path is extracted"); + Check(request.query == "msg=hello", "query is extracted"); + Check(request.params["msg"] == "hello", "query parameter is populated"); + Check(request.Header("host") && *request.Header("host") == "example.com", + "Header lookup is case-insensitive"); + } + + { + HTTP::HttpParserState parser; + HTTP::RequestData first; + HTTP::RequestData second; + parser.Append("GET /one HTTP/1.1\r\nHost: a\r\n\r\n" + "GET /two HTTP/1.1\r\nHost: a\r\n\r\n"); + Check(Parse(parser, first).status == HTTP::HttpParseStatus::Complete, + "first pipelined request parses"); + Check(first.path == "/one", "first pipelined path"); + Check(Parse(parser, second).status == HTTP::HttpParseStatus::Complete, + "second pipelined request parses from buffered bytes"); + Check(second.path == "/two", "second pipelined path"); + } + + { + HTTP::HttpParserState parser; + HTTP::RequestData request; + const std::string raw = + "POST /echo HTTP/1.1\r\nHost: a\r\nContent-Length: 5\r\n\r\nhello"; + for (char ch : raw) { + parser.Append(std::string_view(&ch, 1)); + } + auto result = Parse(parser, request); + Check(result.status == HTTP::HttpParseStatus::Complete, + "one-byte-at-a-time request parses"); + Check(request.body == "hello", "one-byte-at-a-time body is read"); + } + + { + HTTP::HttpParserState parser; + HTTP::RequestData request; + parser.Append("POST /echo HTTP/1.1\r\nHost: a\r\n" + "Content-Length: 5\r\nContent-Length: 5\r\n\r\nhello"); + auto result = Parse(parser, request); + Check(result.status == HTTP::HttpParseStatus::Complete, + "duplicate equal Content-Length is accepted"); + Check(request.body == "hello", "fixed body is read"); + } + + { + HTTP::HttpParserState parser; + HTTP::RequestData request; + parser.Append("POST /echo HTTP/1.1\r\nHost: a\r\n" + "Content-Length: 5\r\nContent-Length: 7\r\n\r\nhello"); + auto result = Parse(parser, request); + Check(result.status == HTTP::HttpParseStatus::Error && + result.errorStatus == 400, + "conflicting Content-Length is rejected"); + } + + { + HTTP::HttpParserState parser; + HTTP::RequestData request; + parser.Append("POST /echo HTTP/1.1\r\nHost: a\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + "5;ignored=true\r\nhello\r\n0\r\nTrace: yes\r\n\r\n"); + auto result = Parse(parser, request); + Check(result.status == HTTP::HttpParseStatus::Complete, + "chunked body parses"); + Check(request.body == "hello", "chunked body is decoded"); + Check(request.trailers["Trace"] == "yes", "trailer is parsed"); + } + + { + HTTP::HttpParserState parser; + HTTP::RequestData request; + parser.Append("POST /echo HTTP/1.1\r\nHost: a\r\n" + "Expect: 100-continue\r\nContent-Length: 5\r\n\r\n"); + auto result = Parse(parser, request); + Check(result.status == HTTP::HttpParseStatus::NeedContinue, + "Expect 100-continue is surfaced before body read"); + parser.MarkContinueSent(); + parser.Append("hello"); + result = Parse(parser, request); + Check(result.status == HTTP::HttpParseStatus::Complete, + "request completes after continue and body"); + Check(request.body == "hello", "continued body is read"); + } + + { + HTTP::HttpParserState parser; + HTTP::RequestData request; + parser.Append("GET / HTTP/1.1\r\n\r\n"); + auto result = Parse(parser, request); + Check(result.status == HTTP::HttpParseStatus::Error && + result.errorStatus == 400, + "HTTP/1.1 without Host is rejected"); + } + + { + HTTP::HttpParserState parser; + HTTP::RequestData request; + parser.Append("GET / HTTP/2.0\r\nHost: a\r\n\r\n"); + auto result = Parse(parser, request); + Check(result.status == HTTP::HttpParseStatus::Error && + result.errorStatus == 505, + "unsupported HTTP version returns 505 parse error"); + } + + { + HTTP::HttpParserState parser; + HTTP::RequestData request; + parser.Append("GET / HTTP/1.1\r\nHost: a\r\n folded: no\r\n\r\n"); + auto result = Parse(parser, request); + Check(result.status == HTTP::HttpParseStatus::Error && + result.errorStatus == 400, + "obsolete folded header is rejected"); + } + + return 0; +}