From 6152a04c63cde2ec93ed5357a8d07b02cda8b30a Mon Sep 17 00:00:00 2001 From: Haoxiang Fei Date: Sat, 28 Feb 2026 22:32:50 +0800 Subject: [PATCH 1/2] feat: add HTTP/2 support with transparent ALPN negotiation - Add HPACK header compression (RFC 7541) with Huffman coding - Add HTTP/2 framing, flow control, settings, and stream management - Add H2ClientConnection and H2ServerConnection with background read loops - Integrate HTTP/2 into existing HTTP client/server via ALPN ("h2"/"http/1.1") - Add ALPN support for both OpenSSL and Windows SChannel - Add HTTP/2 trailer support and stream accessors - Fix use-after-free in TLS read loop after connection close --- src/hpack/dynamic_table.mbt | 120 ++++++++ src/hpack/hpack.mbt | 404 +++++++++++++++++++++++++ src/hpack/hpack_wbtest.mbt | 344 +++++++++++++++++++++ src/hpack/huffman.mbt | 464 ++++++++++++++++++++++++++++ src/hpack/moon.pkg | 3 + src/hpack/pkg.generated.mbti | 20 ++ src/hpack/static_table.mbt | 117 ++++++++ src/http/client.mbt | 177 ++++++++++- src/http/client_test.mbt | 18 +- src/http/moon.pkg | 3 + src/http/pkg.generated.mbti | 3 +- src/http/proxy_test.mbt | 10 +- src/http/request_test.mbt | 12 +- src/http/server.mbt | 238 +++++++++++++-- src/http/unimplemented.mbt | 3 + src/http2/client.mbt | 126 ++++++++ src/http2/connection.mbt | 566 +++++++++++++++++++++++++++++++++++ src/http2/flow_control.mbt | 51 ++++ src/http2/frame.mbt | 118 ++++++++ src/http2/moon.pkg | 22 ++ src/http2/pkg.generated.mbti | 131 ++++++++ src/http2/server.mbt | 141 +++++++++ src/http2/settings.mbt | 112 +++++++ src/http2/stream.mbt | 229 ++++++++++++++ src/http2/types.mbt | 198 ++++++++++++ src/http2/unimplemented.mbt | 25 ++ src/tls/moon.pkg | 1 + src/tls/openssl.c | 50 +++- src/tls/openssl.mbt | 85 ++++++ src/tls/pkg.generated.mbti | 9 +- src/tls/schannel.c | 101 ++++++- src/tls/schannel.mbt | 81 ++++- src/tls/tls.mbt | 3 +- src/tls/tls_test.mbt | 62 ++++ 34 files changed, 3976 insertions(+), 71 deletions(-) create mode 100644 src/hpack/dynamic_table.mbt create mode 100644 src/hpack/hpack.mbt create mode 100644 src/hpack/hpack_wbtest.mbt create mode 100644 src/hpack/huffman.mbt create mode 100644 src/hpack/moon.pkg create mode 100644 src/hpack/pkg.generated.mbti create mode 100644 src/hpack/static_table.mbt create mode 100644 src/http2/client.mbt create mode 100644 src/http2/connection.mbt create mode 100644 src/http2/flow_control.mbt create mode 100644 src/http2/frame.mbt create mode 100644 src/http2/moon.pkg create mode 100644 src/http2/pkg.generated.mbti create mode 100644 src/http2/server.mbt create mode 100644 src/http2/settings.mbt create mode 100644 src/http2/stream.mbt create mode 100644 src/http2/types.mbt create mode 100644 src/http2/unimplemented.mbt diff --git a/src/hpack/dynamic_table.mbt b/src/hpack/dynamic_table.mbt new file mode 100644 index 00000000..9bfba224 --- /dev/null +++ b/src/hpack/dynamic_table.mbt @@ -0,0 +1,120 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// HPACK dynamic table implemented as a ring buffer. +/// New entries are added at the front (head), old entries are evicted from the back. +priv struct DynamicTable { + mut entries : Array[(String, String)] + mut head : Int + mut len : Int + mut size : Int + mut max_size : Int +} + +///| +/// Calculate the size of a header entry per RFC 7541 Section 4.1. +/// Size = 32 + length of name + length of value. +fn entry_size(name : String, value : String) -> Int { + 32 + name.length() + value.length() +} + +///| +/// Create a new dynamic table with the given maximum size. +fn DynamicTable::new(max_size? : Int = 4096) -> DynamicTable { + let capacity = 64 + { + entries: Array::make(capacity, ("", "")), + head: 0, + len: 0, + size: 0, + max_size, + } +} + +///| +/// Evict entries from the back of the table until size <= max_size. +fn DynamicTable::evict(self : DynamicTable) -> Unit { + while self.size > self.max_size && self.len > 0 { + // The oldest entry is at index (head + len - 1) % capacity + let capacity = self.entries.length() + let tail = (self.head + self.len - 1) % capacity + let (name, value) = self.entries[tail] + self.size = self.size - entry_size(name, value) + self.entries[tail] = ("", "") + self.len = self.len - 1 + } +} + +///| +/// Add a new entry at the front of the dynamic table. +/// Evicts entries from the back if the table would exceed max_size. +fn DynamicTable::add( + self : DynamicTable, + name : String, + value : String, +) -> Unit { + let es = entry_size(name, value) + // If the entry itself is larger than max_size, clear the table + if es > self.max_size { + self.len = 0 + self.size = 0 + self.head = 0 + return + } + // Evict until there's room + self.size = self.size + es + self.evict() + // Grow the backing array if needed + if self.len >= self.entries.length() { + let old_cap = self.entries.length() + let new_cap = old_cap * 2 + let new_entries : Array[(String, String)] = Array::make(new_cap, ("", "")) + for i = 0; i < self.len; i = i + 1 { + new_entries[i] = self.entries[(self.head + i) % old_cap] + } + self.entries = new_entries + self.head = 0 + } + // Insert at front: move head back by 1 + let capacity = self.entries.length() + self.head = (self.head - 1 + capacity) % capacity + self.entries[self.head] = (name, value) + self.len = self.len + 1 +} + +///| +/// Get an entry from the dynamic table by 0-based index. +/// Index 0 is the newest entry. +fn DynamicTable::get(self : DynamicTable, index : Int) -> (String, String)? { + if index < 0 || index >= self.len { + return None + } + let capacity = self.entries.length() + let actual = (self.head + index) % capacity + Some(self.entries[actual]) +} + +///| +/// Set a new maximum size for the dynamic table, evicting entries as needed. +fn DynamicTable::set_max_size(self : DynamicTable, new_max : Int) -> Unit { + self.max_size = new_max + self.evict() +} + +///| +/// Return the number of entries in the dynamic table. +fn DynamicTable::length(self : DynamicTable) -> Int { + self.len +} diff --git a/src/hpack/hpack.mbt b/src/hpack/hpack.mbt new file mode 100644 index 00000000..87753c40 --- /dev/null +++ b/src/hpack/hpack.mbt @@ -0,0 +1,404 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// HPACK encoder. +struct Encoder { + dynamic_table : DynamicTable + huffman : Bool +} + +///| +/// HPACK decoder. +struct Decoder { + dynamic_table : DynamicTable +} + +///| +/// Headers that should use literal without indexing (sensitive). +let sensitive_headers : Array[String] = [ + "cookie", "authorization", "set-cookie", +] + +///| +/// Check if a header name is sensitive. +fn is_sensitive(name : String) -> Bool { + for i = 0; i < sensitive_headers.length(); i = i + 1 { + if sensitive_headers[i] == name { + return true + } + } + false +} + +// --- Integer encoding/decoding (RFC 7541 Section 5.1) --- + +///| +/// Encode an integer with the given prefix bits. +/// prefix_byte contains the non-prefix bits already set in the high bits. +fn encode_integer( + buf : @buffer.Buffer, + value : Int, + prefix_bits : Int, + prefix_byte : Byte, +) -> Unit { + let max_prefix = (1 << prefix_bits) - 1 + if value < max_prefix { + buf.write_byte((prefix_byte.to_int() | value).to_byte()) + } else { + buf.write_byte((prefix_byte.to_int() | max_prefix).to_byte()) + let mut v = value - max_prefix + while v >= 128 { + buf.write_byte((v % 128 + 128).to_byte()) + v = v / 128 + } + buf.write_byte(v.to_byte()) + } +} + +///| +/// Decode an integer with the given prefix bits from data at offset. +/// Returns (value, new_offset). +fn decode_integer( + data : Bytes, + offset : Int, + prefix_bits : Int, +) -> (Int, Int) raise { + guard offset < data.length() else { + raise Failure::Failure("decode_integer: unexpected end of data") + } + let max_prefix = (1 << prefix_bits) - 1 + let mut value = data[offset].to_int() & max_prefix + let mut pos = offset + 1 + if value < max_prefix { + return (value, pos) + } + let mut m = 0 + for { + guard pos < data.length() else { + raise Failure::Failure( + "decode_integer: unexpected end of data in continuation", + ) + } + let b = data[pos].to_int() + pos = pos + 1 + value = value + ((b & 127) << m) + m = m + 7 + if (b & 128) == 0 { + break + } + } + (value, pos) +} + +// --- String encoding/decoding (RFC 7541 Section 5.2) --- + +///| +/// Encode a string literal. If huffman is true, use Huffman encoding. +fn encode_string(buf : @buffer.Buffer, s : String, huffman : Bool) -> Unit { + let raw = string_to_bytes(s) + if huffman { + let encoded = huffman_encode(raw) + encode_integer(buf, encoded.length(), 7, b'\x80') + buf.write_bytes(encoded) + } else { + encode_integer(buf, raw.length(), 7, b'\x00') + buf.write_bytes(raw) + } +} + +///| +/// Decode a string literal from data at offset. +/// Returns (string, new_offset). +fn decode_string(data : Bytes, offset : Int) -> (String, Int) raise { + guard offset < data.length() else { + raise Failure::Failure("decode_string: unexpected end of data") + } + let is_huffman = (data[offset].to_int() & 0x80) != 0 + let (str_len, pos) = decode_integer(data, offset, 7) + guard pos + str_len <= data.length() else { + raise Failure::Failure("decode_string: string length exceeds data") + } + let raw = bytes_slice(data, pos, str_len) + let s = if is_huffman { + let decoded = huffman_decode(raw) + bytes_to_string(decoded) + } else { + bytes_to_string(raw) + } + (s, pos + str_len) +} + +// --- Helper functions for string <-> bytes conversion --- + +///| +/// Convert a String to Bytes (UTF-8 encoding). +fn string_to_bytes(s : String) -> Bytes { + let buf = @buffer.new() + for i = 0; i < s.length(); i = i + 1 { + let c = s[i] + let code = c.to_int() + if code < 0x80 { + buf.write_byte(code.to_byte()) + } else if code < 0x800 { + buf.write_byte((0xC0 | (code >> 6)).to_byte()) + buf.write_byte((0x80 | (code & 0x3F)).to_byte()) + } else if code < 0x10000 { + buf.write_byte((0xE0 | (code >> 12)).to_byte()) + buf.write_byte((0x80 | ((code >> 6) & 0x3F)).to_byte()) + buf.write_byte((0x80 | (code & 0x3F)).to_byte()) + } else { + buf.write_byte((0xF0 | (code >> 18)).to_byte()) + buf.write_byte((0x80 | ((code >> 12) & 0x3F)).to_byte()) + buf.write_byte((0x80 | ((code >> 6) & 0x3F)).to_byte()) + buf.write_byte((0x80 | (code & 0x3F)).to_byte()) + } + } + buf.contents() +} + +///| +/// Convert Bytes (UTF-8) to a String. +fn bytes_to_string(b : Bytes) -> String { + let arr : Array[Char] = [] + let mut i = 0 + while i < b.length() { + let byte0 = b[i].to_int() + if byte0 < 0x80 { + arr.push(byte0.unsafe_to_char()) + i = i + 1 + } else if (byte0 & 0xE0) == 0xC0 { + let code = ((byte0 & 0x1F) << 6) | (b[i + 1].to_int() & 0x3F) + arr.push(code.unsafe_to_char()) + i = i + 2 + } else if (byte0 & 0xF0) == 0xE0 { + let code = ((byte0 & 0x0F) << 12) | + ((b[i + 1].to_int() & 0x3F) << 6) | + (b[i + 2].to_int() & 0x3F) + arr.push(code.unsafe_to_char()) + i = i + 3 + } else { + let code = ((byte0 & 0x07) << 18) | + ((b[i + 1].to_int() & 0x3F) << 12) | + ((b[i + 2].to_int() & 0x3F) << 6) | + (b[i + 3].to_int() & 0x3F) + arr.push(code.unsafe_to_char()) + i = i + 4 + } + } + String::from_array(arr) +} + +///| +/// Extract a slice of bytes. +fn bytes_slice(data : Bytes, offset : Int, length : Int) -> Bytes { + data[offset:offset + length].to_bytes() +} + +// --- Encoder --- + +///| +/// Create a new HPACK encoder. +pub fn Encoder::new( + max_table_size? : Int = 4096, + huffman? : Bool = true, +) -> Encoder { + { dynamic_table: DynamicTable::new(max_size=max_table_size), huffman } +} + +///| +/// Look up a header in the static and dynamic tables. +/// Returns (index, has_value_match). +fn Encoder::find_header( + self : Encoder, + name : String, + value : String, +) -> (Int, Bool) { + // Check static table first for exact (name, value) match + match static_table_by_name_value.get(name) { + Some(value_map) => + match value_map.get(value) { + Some(index) => return (index, true) + None => () + } + None => () + } + // Check dynamic table for exact match + for i = 0; i < self.dynamic_table.length(); i = i + 1 { + match self.dynamic_table.get(i) { + Some((n, v)) => + if n == name && v == value { + return (static_table.length() + 1 + i, true) + } + None => () + } + } + // Check static table for name-only match + match static_table_by_name.get(name) { + Some(index) => return (index, false) + None => () + } + // Check dynamic table for name-only match + for i = 0; i < self.dynamic_table.length(); i = i + 1 { + match self.dynamic_table.get(i) { + Some((n, _)) => + if n == name { + return (static_table.length() + 1 + i, false) + } + None => () + } + } + (0, false) +} + +///| +/// Encode a list of headers into an HPACK header block. +pub fn Encoder::encode( + self : Encoder, + headers : Array[(String, String)], +) -> Bytes { + let buf = @buffer.new() + for i = 0; i < headers.length(); i = i + 1 { + let (name, value) = headers[i] + let (index, value_match) = self.find_header(name, value) + if index > 0 && value_match { + // Indexed header field representation (Section 6.1) + encode_integer(buf, index, 7, b'\x80') + } else if is_sensitive(name) { + // Literal header field without indexing (Section 6.2.2) + if index > 0 { + encode_integer(buf, index, 4, b'\x00') + } else { + buf.write_byte(b'\x00') + encode_string(buf, name, self.huffman) + } + encode_string(buf, value, self.huffman) + } else { + // Literal header field with incremental indexing (Section 6.2.1) + if index > 0 { + encode_integer(buf, index, 6, b'\x40') + } else { + buf.write_byte(b'\x40') + encode_string(buf, name, self.huffman) + } + encode_string(buf, value, self.huffman) + self.dynamic_table.add(name, value) + } + } + buf.contents() +} + +// --- Decoder --- + +///| +/// Create a new HPACK decoder. +pub fn Decoder::new(max_table_size? : Int = 4096) -> Decoder { + { dynamic_table: DynamicTable::new(max_size=max_table_size) } +} + +///| +/// Look up a header by HPACK index (1-based). +fn Decoder::lookup(self : Decoder, index : Int) -> (String, String) raise { + if index < 1 { + raise Failure::Failure("hpack: invalid index 0") + } + if index <= static_table.length() { + return static_table[index - 1] + } + let dyn_index = index - static_table.length() - 1 + match self.dynamic_table.get(dyn_index) { + Some(entry) => entry + None => + raise Failure::Failure( + "hpack: dynamic table index out of range: \{index}", + ) + } +} + +///| +/// Decode an HPACK header block into a list of headers. +pub fn Decoder::decode( + self : Decoder, + data : Bytes, +) -> Array[(String, String)] raise { + let headers : Array[(String, String)] = [] + let mut pos = 0 + while pos < data.length() { + let byte0 = data[pos].to_int() + if (byte0 & 0x80) != 0 { + // Indexed header field representation (Section 6.1) + let (index, new_pos) = decode_integer(data, pos, 7) + pos = new_pos + if index == 0 { + raise Failure::Failure("hpack: indexed header with index 0") + } + let entry = self.lookup(index) + headers.push(entry) + } else if (byte0 & 0xC0) == 0x40 { + // Literal header field with incremental indexing (Section 6.2.1) + let (index, new_pos) = decode_integer(data, pos, 6) + pos = new_pos + let name = if index > 0 { + let (n, _) = self.lookup(index) + n + } else { + let (n, new_pos2) = decode_string(data, pos) + pos = new_pos2 + n + } + let (value, new_pos3) = decode_string(data, pos) + pos = new_pos3 + self.dynamic_table.add(name, value) + headers.push((name, value)) + } else if (byte0 & 0xF0) == 0x00 { + // Literal header field without indexing (Section 6.2.2) + let (index, new_pos) = decode_integer(data, pos, 4) + pos = new_pos + let name = if index > 0 { + let (n, _) = self.lookup(index) + n + } else { + let (n, new_pos2) = decode_string(data, pos) + pos = new_pos2 + n + } + let (value, new_pos3) = decode_string(data, pos) + pos = new_pos3 + headers.push((name, value)) + } else if (byte0 & 0xF0) == 0x10 { + // Literal header field never indexed (Section 6.2.3) + let (index, new_pos) = decode_integer(data, pos, 4) + pos = new_pos + let name = if index > 0 { + let (n, _) = self.lookup(index) + n + } else { + let (n, new_pos2) = decode_string(data, pos) + pos = new_pos2 + n + } + let (value, new_pos3) = decode_string(data, pos) + pos = new_pos3 + headers.push((name, value)) + } else if (byte0 & 0xE0) == 0x20 { + // Dynamic table size update (Section 6.3) + let (new_size, new_pos) = decode_integer(data, pos, 5) + pos = new_pos + self.dynamic_table.set_max_size(new_size) + } else { + raise Failure::Failure("hpack: unknown header field representation") + } + } + headers +} diff --git a/src/hpack/hpack_wbtest.mbt b/src/hpack/hpack_wbtest.mbt new file mode 100644 index 00000000..301f89bc --- /dev/null +++ b/src/hpack/hpack_wbtest.mbt @@ -0,0 +1,344 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// --- Integer encoding/decoding tests --- + +///| +test "integer encode/decode round-trip small value" { + let buf = @buffer.new() + encode_integer(buf, 10, 5, b'\x00') + let data = buf.contents() + let (value, offset) = decode_integer(data, 0, 5) + inspect(value, content="10") + inspect(offset, content="1") +} + +///| +test "integer encode/decode round-trip large value" { + let buf = @buffer.new() + encode_integer(buf, 1337, 5, b'\x00') + let data = buf.contents() + let (value, offset) = decode_integer(data, 0, 5) + inspect(value, content="1337") + inspect(offset, content="3") +} + +///| +test "RFC 7541 C.1.1 - integer 10 with 5-bit prefix" { + let buf = @buffer.new() + encode_integer(buf, 10, 5, b'\x00') + let data = buf.contents() + // Value 10 fits in 5 bits (max_prefix = 31), so it's a single byte: 0x0a + inspect(data.length(), content="1") + inspect(data[0].to_int(), content="10") +} + +///| +test "RFC 7541 C.1.2 - integer 1337 with 5-bit prefix" { + let buf = @buffer.new() + encode_integer(buf, 1337, 5, b'\x00') + let data = buf.contents() + // 1337 with 5-bit prefix: 0x1f (31), 0x9a (154), 0x0a (10) + // 1337 - 31 = 1306, 1306 = 128*10 + 26, so: 0x1f, (26+128)=0x9a, 0x0a + inspect(data.length(), content="3") + inspect(data[0].to_int(), content="31") + inspect(data[1].to_int(), content="154") + inspect(data[2].to_int(), content="10") +} + +///| +test "integer encode/decode with prefix byte bits set" { + let buf = @buffer.new() + // Encode with top bit set (like indexed header 0x80) + encode_integer(buf, 5, 7, b'\x80') + let data = buf.contents() + // The byte should be 0x80 | 5 = 0x85 + inspect(data[0].to_int(), content="133") + let (value, _offset) = decode_integer(data, 0, 7) + inspect(value, content="5") +} + +// --- Huffman encoding/decoding tests --- + +///| +test "huffman encode/decode round-trip for 'www.example.com'" { + let input = string_to_bytes("www.example.com") + let encoded = huffman_encode(input) + let decoded = huffman_decode(encoded) + let result = bytes_to_string(decoded) + inspect(result, content="www.example.com") +} + +///| +test "huffman encode/decode round-trip for empty string" { + let input = string_to_bytes("") + let encoded = huffman_encode(input) + inspect(encoded.length(), content="0") + let decoded = huffman_decode(encoded) + inspect(decoded.length(), content="0") +} + +///| +test "huffman encode/decode round-trip for 'no-cache'" { + let input = string_to_bytes("no-cache") + let encoded = huffman_encode(input) + let decoded = huffman_decode(encoded) + let result = bytes_to_string(decoded) + inspect(result, content="no-cache") +} + +///| +test "huffman encode/decode round-trip for 'custom-key'" { + let input = string_to_bytes("custom-key") + let encoded = huffman_encode(input) + let decoded = huffman_decode(encoded) + let result = bytes_to_string(decoded) + inspect(result, content="custom-key") +} + +///| +test "huffman table key values" { + // 'w' (119) = (0x78, 7) + let (code_w, bits_w) = huffman_table[119] + inspect(code_w, content="120") + inspect(bits_w, content="7") + // 'o' (111) = (0x7, 5) + let (code_o, bits_o) = huffman_table[111] + inspect(code_o, content="7") + inspect(bits_o, content="5") + // '0' (48) = (0x0, 5) + let (code_0, bits_0) = huffman_table[48] + inspect(code_0, content="0") + inspect(bits_0, content="5") + // '1' (49) = (0x1, 5) + let (code_1, bits_1) = huffman_table[49] + inspect(code_1, content="1") + inspect(bits_1, content="5") + // EOS (256) = (0x3fffffff, 30) + let (code_eos, bits_eos) = huffman_table[256] + inspect(code_eos, content="1073741823") + inspect(bits_eos, content="30") +} + +// --- Dynamic table tests --- + +///| +test "dynamic table add and get" { + let dt = DynamicTable::new() + dt.add("custom-header", "custom-value") + inspect(dt.length(), content="1") + match dt.get(0) { + Some((name, value)) => { + inspect(name, content="custom-header") + inspect(value, content="custom-value") + } + None => inspect("should not be None", content="found entry") + } +} + +///| +test "dynamic table ordering - newest first" { + let dt = DynamicTable::new() + dt.add("first", "value1") + dt.add("second", "value2") + dt.add("third", "value3") + inspect(dt.length(), content="3") + // Index 0 = newest + match dt.get(0) { + Some((name, _)) => inspect(name, content="third") + None => inspect("unexpected", content="found") + } + match dt.get(1) { + Some((name, _)) => inspect(name, content="second") + None => inspect("unexpected", content="found") + } + match dt.get(2) { + Some((name, _)) => inspect(name, content="first") + None => inspect("unexpected", content="found") + } +} + +///| +test "dynamic table eviction" { + // Entry size = 32 + name.length() + value.length() + // "a" + "b" = 32 + 1 + 1 = 34 bytes + let dt = DynamicTable::new(max_size=70) + dt.add("a", "b") // 34 bytes, total = 34 + dt.add("c", "d") // 34 bytes, total = 68 + inspect(dt.length(), content="2") + dt.add("e", "f") // 34 bytes, would be 102, evict oldest until fits + // After eviction: should have 2 entries (evicted "a"/"b") + inspect(dt.length(), content="2") + match dt.get(0) { + Some((name, _)) => inspect(name, content="e") + None => inspect("unexpected", content="found") + } + match dt.get(1) { + Some((name, _)) => inspect(name, content="c") + None => inspect("unexpected", content="found") + } +} + +///| +test "dynamic table set_max_size eviction" { + let dt = DynamicTable::new() + dt.add("header1", "value1") + dt.add("header2", "value2") + inspect(dt.length(), content="2") + dt.set_max_size(0) + inspect(dt.length(), content="0") +} + +///| +test "dynamic table entry too large" { + let dt = DynamicTable::new(max_size=10) + dt.add("a", "b") + inspect(dt.length(), content="0") +} + +// --- Full HPACK encode/decode round-trip tests --- + +///| +test "hpack encode/decode simple headers" { + let encoder = Encoder::new(huffman=false) + let decoder = Decoder::new() + let headers : Array[(String, String)] = [ + (":method", "GET"), + (":path", "/"), + (":scheme", "https"), + ("host", "example.com"), + ] + let encoded = encoder.encode(headers) + let decoded = decoder.decode(encoded) + inspect(decoded.length(), content="4") + let (n0, v0) = decoded[0] + inspect(n0, content=":method") + inspect(v0, content="GET") + let (n1, v1) = decoded[1] + inspect(n1, content=":path") + inspect(v1, content="/") + let (n2, v2) = decoded[2] + inspect(n2, content=":scheme") + inspect(v2, content="https") + let (n3, v3) = decoded[3] + inspect(n3, content="host") + inspect(v3, content="example.com") +} + +///| +test "hpack encode/decode with huffman encoding" { + let encoder = Encoder::new(huffman=true) + let decoder = Decoder::new() + let headers : Array[(String, String)] = [ + (":method", "GET"), + (":path", "/index.html"), + ("custom-key", "custom-value"), + ] + let encoded = encoder.encode(headers) + let decoded = decoder.decode(encoded) + inspect(decoded.length(), content="3") + let (n0, v0) = decoded[0] + inspect(n0, content=":method") + inspect(v0, content="GET") + let (n1, v1) = decoded[1] + inspect(n1, content=":path") + inspect(v1, content="/index.html") + let (n2, v2) = decoded[2] + inspect(n2, content="custom-key") + inspect(v2, content="custom-value") +} + +///| +test "hpack encode/decode sensitive headers" { + let encoder = Encoder::new(huffman=false) + let decoder = Decoder::new() + let headers : Array[(String, String)] = [ + ("authorization", "Bearer token123"), + ("cookie", "session=abc"), + ] + let encoded = encoder.encode(headers) + let decoded = decoder.decode(encoded) + inspect(decoded.length(), content="2") + let (n0, v0) = decoded[0] + inspect(n0, content="authorization") + inspect(v0, content="Bearer token123") + let (n1, v1) = decoded[1] + inspect(n1, content="cookie") + inspect(v1, content="session=abc") + // Sensitive headers should NOT be added to the dynamic table + inspect(encoder.dynamic_table.length(), content="0") +} + +///| +test "hpack encode/decode multiple requests with dynamic table" { + let encoder = Encoder::new(huffman=false) + let decoder = Decoder::new() + // First request + let headers1 : Array[(String, String)] = [ + (":method", "GET"), + (":path", "/"), + ("custom-header", "value1"), + ] + let encoded1 = encoder.encode(headers1) + let decoded1 = decoder.decode(encoded1) + inspect(decoded1.length(), content="3") + // Second request - "custom-header" should be findable in dynamic table + let headers2 : Array[(String, String)] = [ + (":method", "GET"), + (":path", "/other"), + ("custom-header", "value1"), + ] + let encoded2 = encoder.encode(headers2) + let decoded2 = decoder.decode(encoded2) + inspect(decoded2.length(), content="3") + let (n2, v2) = decoded2[2] + inspect(n2, content="custom-header") + inspect(v2, content="value1") +} + +///| +test "static table lookups" { + match static_table_by_name.get(":method") { + Some(idx) => inspect(idx, content="2") + None => inspect("not found", content="found") + } + match static_table_by_name.get(":path") { + Some(idx) => inspect(idx, content="4") + None => inspect("not found", content="found") + } + match static_table_by_name_value.get(":method") { + Some(value_map) => + match value_map.get("GET") { + Some(idx) => inspect(idx, content="2") + None => inspect("not found", content="found") + } + None => inspect("not found", content="found") + } + match static_table_by_name_value.get(":status") { + Some(value_map) => + match value_map.get("200") { + Some(idx) => inspect(idx, content="8") + None => inspect("not found", content="found") + } + None => inspect("not found", content="found") + } +} + +///| +test "string to bytes and back round-trip" { + let original = "Hello, World!" + let bytes = string_to_bytes(original) + let result = bytes_to_string(bytes) + inspect(result, content="Hello, World!") +} diff --git a/src/hpack/huffman.mbt b/src/hpack/huffman.mbt new file mode 100644 index 00000000..c8fc4e98 --- /dev/null +++ b/src/hpack/huffman.mbt @@ -0,0 +1,464 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// HPACK Huffman encoding table per RFC 7541 Appendix B. +/// 257 entries: symbols 0-255 plus EOS (256). +/// Each entry is (code, bit_length). +let huffman_table : Array[(Int, Int)] = [ + // 0-3 + (0x1ff8, 13), + (0x7fffd8, 23), + (0xfffffe2, 28), + (0xfffffe3, 28), + // 4-7 + (0xfffffe4, 28), + (0xfffffe5, 28), + (0xfffffe6, 28), + (0xfffffe7, 28), + // 8-11 + (0xfffffe8, 28), + (0xffffea, 24), + (0x3ffffffc, 30), + (0xfffffe9, 28), + // 12-15 + (0xfffffea, 28), + (0x3ffffffd, 30), + (0xfffffeb, 28), + (0xfffffec, 28), + // 16-19 + (0xfffffed, 28), + (0xfffffee, 28), + (0xfffffef, 28), + (0xffffff0, 28), + // 20-23 + (0xffffff1, 28), + (0xffffff2, 28), + (0x3ffffffe, 30), + (0xffffff3, 28), + // 24-27 + (0xffffff4, 28), + (0xffffff5, 28), + (0xffffff6, 28), + (0xffffff7, 28), + // 28-31 + (0xffffff8, 28), + (0xffffff9, 28), + (0xffffffa, 28), + (0xffffffb, 28), + // 32 ' ' - 35 '#' + (0x14, 6), + (0x3f8, 10), + (0x3f9, 10), + (0xffa, 12), + // 36 '$' - 39 '\'' + (0x1ff9, 13), + (0x15, 6), + (0xf8, 8), + (0x7fa, 11), + // 40 '(' - 43 '+' + (0x3fa, 10), + (0x3fb, 10), + (0xf9, 8), + (0x7fb, 11), + // 44 ',' - 47 '/' + (0xfa, 8), + (0x16, 6), + (0x17, 6), + (0x18, 6), + // 48 '0' - 51 '3' + (0x0, 5), + (0x1, 5), + (0x2, 5), + (0x19, 6), + // 52 '4' - 55 '7' + (0x1a, 6), + (0x1b, 6), + (0x1c, 6), + (0x1d, 6), + // 56 '8' - 59 ';' + (0x1e, 6), + (0x1f, 6), + (0x5c, 7), + (0xfb, 8), + // 60 '<' - 63 '?' + (0x7ffc, 15), + (0x20, 6), + (0xffb, 12), + (0x3fc, 10), + // 64 '@' - 67 'C' + (0x1ffa, 13), + (0x21, 6), + (0x5d, 7), + (0x5e, 7), + // 68 'D' - 71 'G' + (0x5f, 7), + (0x60, 7), + (0x61, 7), + (0x62, 7), + // 72 'H' - 75 'K' + (0x63, 7), + (0x64, 7), + (0x65, 7), + (0x66, 7), + // 76 'L' - 79 'O' + (0x67, 7), + (0x68, 7), + (0x69, 7), + (0x6a, 7), + // 80 'P' - 83 'S' + (0x6b, 7), + (0x6c, 7), + (0x6d, 7), + (0x6e, 7), + // 84 'T' - 87 'W' + (0x6f, 7), + (0x70, 7), + (0x71, 7), + (0x72, 7), + // 88 'X' - 91 '[' + (0xfc, 8), + (0x73, 7), + (0xfd, 8), + (0x1ffb, 13), + // 92 '\\' - 95 '_' + (0x7fff0, 19), + (0x1ffc, 13), + (0x3ffc, 14), + (0x22, 6), + // 96 '`' - 99 'c' + (0x7ffd, 15), + (0x3, 5), + (0x23, 6), + (0x4, 5), + // 100 'd' - 103 'g' + (0x24, 6), + (0x5, 5), + (0x25, 6), + (0x26, 6), + // 104 'h' - 107 'k' + (0x27, 6), + (0x6, 5), + (0x74, 7), + (0x75, 7), + // 108 'l' - 111 'o' + (0x28, 6), + (0x29, 6), + (0x2a, 6), + (0x7, 5), + // 112 'p' - 115 's' + (0x2b, 6), + (0x76, 7), + (0x2c, 6), + (0x8, 5), + // 116 't' - 119 'w' + (0x9, 5), + (0x2d, 6), + (0x77, 7), + (0x78, 7), + // 120 'x' - 123 '{' + (0x79, 7), + (0x7a, 7), + (0x7b, 7), + (0x7ffe, 15), + // 124 '|' - 127 + (0x7fc, 11), + (0x3ffd, 14), + (0x1ffd, 13), + (0xffffffc, 28), + // 128-131 + (0xfffe6, 20), + (0x3fffd2, 22), + (0xfffe7, 20), + (0xfffe8, 20), + // 132-135 + (0x3fffd3, 22), + (0x3fffd4, 22), + (0x3fffd5, 22), + (0x7fffd9, 23), + // 136-139 + (0x3fffd6, 22), + (0x7fffda, 23), + (0x7fffdb, 23), + (0x7fffdc, 23), + // 140-143 + (0x7fffdd, 23), + (0x7fffde, 23), + (0xffffeb, 24), + (0x7fffdf, 23), + // 144-147 + (0xffffec, 24), + (0xffffed, 24), + (0x3fffd7, 22), + (0x7fffe0, 23), + // 148-151 + (0xffffee, 24), + (0x7fffe1, 23), + (0x7fffe2, 23), + (0x7fffe3, 23), + // 152-155 + (0x7fffe4, 23), + (0x1fffdc, 21), + (0x3fffd8, 22), + (0x7fffe5, 23), + // 156-159 + (0x3fffd9, 22), + (0x7fffe6, 23), + (0x7fffe7, 23), + (0xffffef, 24), + // 160-163 + (0x3fffda, 22), + (0x1fffdd, 21), + (0xfffe9, 20), + (0x3fffdb, 22), + // 164-167 + (0x3fffdc, 22), + (0x7fffe8, 23), + (0x7fffe9, 23), + (0x1fffde, 21), + // 168-171 + (0x7fffea, 23), + (0x3fffdd, 22), + (0x3fffde, 22), + (0xfffff0, 24), + // 172-175 + (0x1fffdf, 21), + (0x3fffdf, 22), + (0x7fffeb, 23), + (0x7fffec, 23), + // 176-179 + (0x1fffe0, 21), + (0x1fffe1, 21), + (0x3fffe0, 22), + (0x1fffe2, 21), + // 180-183 + (0x7fffed, 23), + (0x3fffe1, 22), + (0x7fffee, 23), + (0x7fffef, 23), + // 184-187 + (0xfffea, 20), + (0x3fffe2, 22), + (0x3fffe3, 22), + (0x3fffe4, 22), + // 188-191 + (0x7ffff0, 23), + (0x3fffe5, 22), + (0x3fffe6, 22), + (0x7ffff1, 23), + // 192-195 + (0x3ffffe0, 26), + (0x3ffffe1, 26), + (0xfffeb, 20), + (0x7fff1, 19), + // 196-199 + (0x3fffe7, 22), + (0x7ffff2, 23), + (0x3fffe8, 22), + (0x1ffffec, 25), + // 200-203 + (0x3ffffe2, 26), + (0x3ffffe3, 26), + (0x3ffffe4, 26), + (0x7ffffde, 27), + // 204-207 + (0x7ffffdf, 27), + (0x3ffffe5, 26), + (0xfffff1, 24), + (0x1ffffed, 25), + // 208-211 + (0x7fff2, 19), + (0x1fffe3, 21), + (0x3ffffe6, 26), + (0x7ffffe0, 27), + // 212-215 + (0x7ffffe1, 27), + (0x3ffffe7, 26), + (0x7ffffe2, 27), + (0xfffff2, 24), + // 216-219 + (0x1fffe4, 21), + (0x1fffe5, 21), + (0x3ffffe8, 26), + (0x3ffffe9, 26), + // 220-223 + (0xffffffd, 28), + (0x7ffffe3, 27), + (0x7ffffe4, 27), + (0x7ffffe5, 27), + // 224-227 + (0xfffec, 20), + (0xfffff3, 24), + (0xfffed, 20), + (0x1fffe6, 21), + // 228-231 + (0x3fffe9, 22), + (0x1fffe7, 21), + (0x1fffe8, 21), + (0x7ffff3, 23), + // 232-235 + (0x3fffea, 22), + (0x3fffeb, 22), + (0x1ffffee, 25), + (0x1ffffef, 25), + // 236-239 + (0xfffff4, 24), + (0xfffff5, 24), + (0x3ffffea, 26), + (0x7ffff4, 23), + // 240-243 + (0x3ffffeb, 26), + (0x7ffffe6, 27), + (0x3ffffec, 26), + (0x3ffffed, 26), + // 244-247 + (0x7ffffe7, 27), + (0x7ffffe8, 27), + (0x7ffffe9, 27), + (0x7ffffea, 27), + // 248-251 + (0x7ffffeb, 27), + (0xffffffe, 28), + (0x7ffffec, 27), + (0x7ffffed, 27), + // 252-255 + (0x7ffffee, 27), + (0x7ffffef, 27), + (0x7fffff0, 27), + (0x3ffffee, 26), + // 256 EOS + (0x3fffffff, 30), +] + +///| +/// Huffman decode tree node. +priv enum HuffmanNode { + Empty + Leaf(Int) + Internal(HuffmanNode, HuffmanNode) +} + +///| +/// Build the Huffman decode tree from the encoding table. +fn build_huffman_tree() -> HuffmanNode { + fn insert( + node : HuffmanNode, + code : Int, + bits : Int, + symbol : Int, + ) -> HuffmanNode { + if bits == 0 { + return Leaf(symbol) + } + let bit = (code >> (bits - 1)) & 1 + let (left, right) = match node { + Empty => (Empty, Empty) + Internal(left, right) => (left, right) + Leaf(_) => abort("huffman tree: inserting into leaf node") + } + if bit == 0 { + Internal(insert(left, code, bits - 1, symbol), right) + } else { + Internal(left, insert(right, code, bits - 1, symbol)) + } + } + let mut root : HuffmanNode = Empty + for i = 0; i < huffman_table.length(); i = i + 1 { + let (code, bit_length) = huffman_table[i] + root = insert(root, code, bit_length, i) + } + root +} + +///| +/// Lazily built Huffman decode tree. +let huffman_decode_root : HuffmanNode = build_huffman_tree() + +///| +/// Encode bytes using HPACK Huffman coding. +/// Writes bits MSB first. Pads the final byte with 1-bits (EOS prefix). +fn huffman_encode(src : Bytes) -> Bytes { + let buf = @buffer.new() + let mut current_byte : Int = 0 + let mut bits_left : Int = 8 + for i = 0; i < src.length(); i = i + 1 { + let sym = src[i].to_int() + let (code, code_len) = huffman_table[sym] + let mut remaining = code_len + while remaining > 0 { + if remaining >= bits_left { + current_byte = current_byte | + ((code >> (remaining - bits_left)) & ((1 << bits_left) - 1)) + remaining = remaining - bits_left + buf.write_byte(current_byte.to_byte()) + current_byte = 0 + bits_left = 8 + } else { + current_byte = current_byte | + ((code & ((1 << remaining) - 1)) << (bits_left - remaining)) + bits_left = bits_left - remaining + remaining = 0 + } + } + } + if bits_left < 8 { + current_byte = current_byte | ((1 << bits_left) - 1) + buf.write_byte(current_byte.to_byte()) + } + buf.contents() +} + +///| +/// Decode HPACK Huffman-encoded bytes back to original. +fn huffman_decode(src : Bytes) -> Bytes raise { + let buf = @buffer.new() + let mut node = huffman_decode_root + let mut bits_remaining = 0 + let mut current_byte = 0 + let mut byte_index = 0 + let mut padding_bits = 0 + while byte_index < src.length() || bits_remaining > 0 { + if bits_remaining == 0 { + if byte_index >= src.length() { + break + } + current_byte = src[byte_index].to_int() + byte_index = byte_index + 1 + bits_remaining = 8 + } + let bit = (current_byte >> (bits_remaining - 1)) & 1 + bits_remaining = bits_remaining - 1 + node = match node { + Empty => abort("huffman_decode: unexpected empty node during traversal") + Leaf(_) => abort("huffman_decode: unexpected leaf during traversal") + Internal(left, right) => if bit == 0 { left } else { right } + } + match node { + Leaf(sym) => { + if sym == 256 { + raise Failure::Failure("huffman_decode: EOS symbol encountered") + } + buf.write_byte(sym.to_byte()) + padding_bits = 0 + node = huffman_decode_root + } + Internal(_, _) => padding_bits = padding_bits + 1 + Empty => abort("huffman_decode: reached empty node") + } + } + if padding_bits > 7 { + raise Failure::Failure("huffman_decode: padding exceeds 7 bits") + } + buf.contents() +} diff --git a/src/hpack/moon.pkg b/src/hpack/moon.pkg new file mode 100644 index 00000000..2412b4aa --- /dev/null +++ b/src/hpack/moon.pkg @@ -0,0 +1,3 @@ +import { + "moonbitlang/core/buffer", +} diff --git a/src/hpack/pkg.generated.mbti b/src/hpack/pkg.generated.mbti new file mode 100644 index 00000000..bd156c34 --- /dev/null +++ b/src/hpack/pkg.generated.mbti @@ -0,0 +1,20 @@ +// Generated using `moon info`, DON'T EDIT IT +package "moonbitlang/async/hpack" + +// Values + +// Errors + +// Types and methods +type Decoder +pub fn Decoder::decode(Self, Bytes) -> Array[(String, String)] raise +pub fn Decoder::new(max_table_size? : Int) -> Self + +type Encoder +pub fn Encoder::encode(Self, Array[(String, String)]) -> Bytes +pub fn Encoder::new(max_table_size? : Int, huffman? : Bool) -> Self + +// Type aliases + +// Traits + diff --git a/src/hpack/static_table.mbt b/src/hpack/static_table.mbt new file mode 100644 index 00000000..3ec23f44 --- /dev/null +++ b/src/hpack/static_table.mbt @@ -0,0 +1,117 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// HPACK static table per RFC 7541 Appendix A. +/// 61 entries, 0-indexed in the array but HPACK uses 1-based indexing. +let static_table : Array[(String, String)] = [ + (":authority", ""), + (":method", "GET"), + (":method", "POST"), + (":path", "/"), + (":path", "/index.html"), + (":scheme", "http"), + (":scheme", "https"), + (":status", "200"), + (":status", "204"), + (":status", "206"), + (":status", "304"), + (":status", "400"), + (":status", "404"), + (":status", "500"), + ("accept-charset", ""), + ("accept-encoding", "gzip, deflate"), + ("accept-language", ""), + ("accept-ranges", ""), + ("accept", ""), + ("access-control-allow-origin", ""), + ("age", ""), + ("allow", ""), + ("authorization", ""), + ("cache-control", ""), + ("content-disposition", ""), + ("content-encoding", ""), + ("content-language", ""), + ("content-length", ""), + ("content-location", ""), + ("content-range", ""), + ("content-type", ""), + ("cookie", ""), + ("date", ""), + ("etag", ""), + ("expect", ""), + ("expires", ""), + ("from", ""), + ("host", ""), + ("if-match", ""), + ("if-modified-since", ""), + ("if-none-match", ""), + ("if-range", ""), + ("if-unmodified-since", ""), + ("last-modified", ""), + ("link", ""), + ("location", ""), + ("max-forwards", ""), + ("proxy-authenticate", ""), + ("proxy-authorization", ""), + ("range", ""), + ("referer", ""), + ("refresh", ""), + ("retry-after", ""), + ("server", ""), + ("set-cookie", ""), + ("strict-transport-security", ""), + ("transfer-encoding", ""), + ("user-agent", ""), + ("vary", ""), + ("via", ""), + ("www-authenticate", ""), +] + +///| +/// Reverse lookup: name -> (value -> 1-based index). +let static_table_by_name_value : Map[String, Map[String, Int]] = { + let m : Map[String, Map[String, Int]] = {} + for i = 0; i < static_table.length(); i = i + 1 { + let (name, value) = static_table[i] + let index = i + 1 + match m.get(name) { + Some(inner) => + // Only store the first occurrence for each (name, value) pair + if not(inner.contains(value)) { + inner[value] = index + } + None => { + let inner : Map[String, Int] = {} + inner[value] = index + m[name] = inner + } + } + } + m +} + +///| +/// Reverse lookup: name -> first 1-based index with that name. +let static_table_by_name : Map[String, Int] = { + let m : Map[String, Int] = {} + for i = 0; i < static_table.length(); i = i + 1 { + let (name, _) = static_table[i] + let index = i + 1 + if not(m.contains(name)) { + m[name] = index + } + } + m +} diff --git a/src/http/client.mbt b/src/http/client.mbt index 22fc3133..6b716a52 100644 --- a/src/http/client.mbt +++ b/src/http/client.mbt @@ -18,6 +18,17 @@ priv enum ClientTransport { Proxy(Client) } +///| +/// HTTP/2 client state, created when ALPN negotiation selects h2. +priv struct H2ClientState { + conn : @http2.H2ClientConnection + read_loop_coro : @coroutine.Coroutine + host : String + scheme : String + persistent_headers : Map[String, String] + mut stream : @http2.H2Stream? +} + ///| /// Simple HTTP client which connect to a remote host via TCP struct Client { @@ -25,6 +36,7 @@ struct Client { transport : ClientTransport tls : @tls.Tls? sender : Sender + h2 : H2ClientState? } ///| @@ -86,7 +98,7 @@ pub async fn Client::connect( (Http, Plain(conn)) => (None, Reader::new(conn), Sender::new(conn, headers~)) (Https, Plain(conn)) => { - let tls = @tls.Tls::client(conn, host~) catch { + let tls = @tls.Tls::client(conn, host~, alpn=["h2", "http/1.1"]) catch { err => { transport.close() raise err @@ -97,7 +109,7 @@ pub async fn Client::connect( (Http, Proxy(proxy)) => (None, Reader::new(proxy), Sender::new(proxy, headers~)) (Https, Proxy(proxy)) => { - let tls = @tls.Tls::client(proxy, host~) catch { + let tls = @tls.Tls::client(proxy, host~, alpn=["h2", "http/1.1"]) catch { err => { transport.close() raise err @@ -106,7 +118,42 @@ pub async fn Client::connect( (Some(tls), Reader::new(tls), Sender::new(tls, headers~)) } } - { transport, tls, reader, sender } + // Check if HTTP/2 was negotiated via ALPN + let h2 : H2ClientState? = match tls { + Some(t) => + if t.selected_alpn() is Some("h2") { + let h2_conn = @http2.H2ClientConnection::new(t, t) catch { + err => { + t.close() + transport.close() + raise err + } + } + let read_loop_coro = @coroutine.spawn(async fn() { + h2_conn.run_read_loop() catch { + _ => () + } + }) + let h2_headers : Map[String, String] = {} + for k, v in headers { + if k != "Host" { + h2_headers[k] = v + } + } + Some({ + conn: h2_conn, + read_loop_coro, + host, + scheme: "https", + persistent_headers: h2_headers, + stream: None, + }) + } else { + None + } + None => None + } + { transport, tls, reader, sender, h2 } } ///| @@ -158,6 +205,9 @@ fn ClientTransport::close(self : ClientTransport) -> Unit { /// This function is idempotent: it is safe to call `.close()` multiple times, /// only the first `.close()` call takes effect. pub fn Client::close(self : Client) -> Unit { + if self.h2 is Some(h2) { + h2.read_loop_coro.cancel() + } if self.tls is Some(tls) { tls.close() } @@ -166,12 +216,26 @@ pub fn Client::close(self : Client) -> Unit { ///| pub impl @io.Reader for Client with _direct_read(self, buf, offset~, max_len~) { - self.reader._direct_read(buf, offset~, max_len~) + if self.h2 is Some(h2) { + match h2.stream { + Some(stream) => stream._direct_read(buf, offset~, max_len~) + None => 0 + } + } else { + self.reader._direct_read(buf, offset~, max_len~) + } } ///| pub impl @io.Reader for Client with _get_internal_buffer(self) { - self.reader._get_internal_buffer() + if self.h2 is Some(h2) { + match h2.stream { + Some(stream) => stream._get_internal_buffer() + None => self.reader._get_internal_buffer() + } + } else { + self.reader._get_internal_buffer() + } } ///| @@ -182,20 +246,44 @@ pub impl @io.Reader for Client with _get_internal_buffer(self) { /// Writing to `@http.Client` MAY be buffered, /// call `flush` manually to ensure data is delivered to the remote peer. pub impl @io.Writer for Client with write_once(self, buf, offset~, len~) { - guard not(self.sender.mode is SendingHeader) - self.sender.write_once(buf, offset~, len~) + if self.h2 is Some(h2) { + guard h2.stream is Some(stream) + let data = buf[offset:offset + len].to_bytes() + stream.send_data(data, end_stream=false) + len + } else { + guard not(self.sender.mode is SendingHeader) + self.sender.write_once(buf, offset~, len~) + } } ///| pub impl @io.Writer for Client with write_reader(self, reader) { - guard not(self.sender.mode is SendingHeader) - self.sender.write_reader(reader) + if self.h2 is Some(_) { + // For H2, read from reader and write via self.write which calls write_once + let buf : FixedArray[Byte] = FixedArray::make(16384, b'\x00') + for { + let n = reader._direct_read(buf, offset=0, max_len=buf.length()) + if n == 0 { + break + } + self.write(buf.unsafe_reinterpret_as_bytes()[:n]) + } + } else { + guard not(self.sender.mode is SendingHeader) + self.sender.write_reader(reader) + } } ///| /// Flush buffered data in the request body being sent, if any. pub async fn Client::flush(self : Client) -> Unit { - self.sender.flush() + if self.h2 is Some(_) { + // H2 sends DATA frames immediately, no buffering + () + } else { + self.sender.flush() + } } ///| @@ -209,9 +297,29 @@ pub async fn Client::flush(self : Client) -> Unit { /// If the body of the last response is still not consumed, /// it will be discarded. pub async fn Client::end_request(self : Client) -> Response { - self.sender.end_body() - self.reader.skip_body() - self.reader.read_response() + if self.h2 is Some(h2) { + guard h2.stream is Some(stream) + // Send END_STREAM (empty DATA frame with END_STREAM flag) + stream.send_data(Bytes::make(0, b'\x00'), end_stream=true) + // Wait for response headers from the server + let headers = stream.wait_headers() + let code = match headers.get(":status") { + Some(status) => @strconv.parse_int(status, base=10) catch { _ => 200 } + None => 200 + } + // Build response headers excluding pseudo-headers + let response_headers : Map[String, String] = {} + for k, v in headers { + if k.length() > 0 && k[0] != ':' { + response_headers[k] = v + } + } + { code, reason: "", headers: response_headers } + } else { + self.sender.end_body() + self.reader.skip_body() + self.reader.read_response() + } } ///| @@ -237,14 +345,50 @@ pub async fn Client::request( path : String, extra_headers? : Map[String, String] = {}, ) -> Unit { - self.sender.send_request(meth, path, extra_headers~) + if self.h2 is Some(h2) { + let meth_str = match meth { + Get => "GET" + Head => "HEAD" + Post => "POST" + Put => "PUT" + Delete => "DELETE" + Connect => "CONNECT" + Options => "OPTIONS" + Trace => "TRACE" + Patch => "PATCH" + } + let all_headers : Map[String, String] = {} + for k, v in h2.persistent_headers { + all_headers[k] = v + } + for k, v in extra_headers { + all_headers[k] = v + } + let stream = h2.conn.send_request( + meth_str, + path, + scheme=h2.scheme, + headers=all_headers, + authority=h2.host, + ) + h2.stream = Some(stream) + } else { + self.sender.send_request(meth, path, extra_headers~) + } } ///| /// Skip the body of the response currently being produced, /// so that the next request can be made. pub async fn Client::skip_response_body(self : Client) -> Unit { - self.reader.skip_body() + if self.h2 is Some(_) { + // Drain the H2 stream's remaining data + while self.drop(1024) == 1024 { + + } + } else { + self.reader.skip_body() + } } ///| @@ -298,6 +442,9 @@ pub async fn Client::post( /// (i.e. not in the middle of sending a request). /// Unread data from the body of the last response will be discarded. pub async fn Client::enter_passthrough_mode(self : Client) -> Unit { + guard self.h2 is None else { + abort("enter_passthrough_mode is not supported for HTTP/2") + } guard self.sender.mode is SendingHeader self.reader.enter_passthrough_mode() self.sender.enter_passthrough_mode() diff --git a/src/http/client_test.mbt b/src/http/client_test.mbt index eccf28c1..80ee96a0 100644 --- a/src/http/client_test.mbt +++ b/src/http/client_test.mbt @@ -176,12 +176,14 @@ async test "request streaming" { ///| async test "multiple request" { - let client = @http.Client::new("https://www.moonbitlang.com") - defer client.close() - let response1 = client.get("/") - inspect(response1.code, content="200") - assert_true(client.read_all().text().has_prefix("")) - let response2 = client.get("/download") - inspect(response2.code, content="200") - assert_true(client.read_all().text().has_prefix("")) + @async.with_timeout(30000, () => { + let client = @http.Client::new("https://www.moonbitlang.com") + defer client.close() + let response1 = client.get("/") + inspect(response1.code, content="200") + assert_true(client.read_all().text().has_prefix("")) + let response2 = client.get("/download") + inspect(response2.code, content="200") + assert_true(client.read_all().text().has_prefix("")) + }) } diff --git a/src/http/moon.pkg b/src/http/moon.pkg index d733ccb4..70e255ab 100644 --- a/src/http/moon.pkg +++ b/src/http/moon.pkg @@ -5,8 +5,11 @@ import { "moonbitlang/async/io", "moonbitlang/async/socket", "moonbitlang/async/tls", + "moonbitlang/async/http2", "moonbitlang/async/internal/io_buffer", + "moonbitlang/async/internal/coroutine", "moonbitlang/async/js_async", + "moonbitlang/async", } import { diff --git a/src/http/pkg.generated.mbti b/src/http/pkg.generated.mbti index 3ebffb06..6e1e478f 100644 --- a/src/http/pkg.generated.mbti +++ b/src/http/pkg.generated.mbti @@ -4,6 +4,7 @@ package "moonbitlang/async/http" import { "moonbitlang/async/io", "moonbitlang/async/socket", + "moonbitlang/async/tls", } // Values @@ -115,7 +116,7 @@ type Server pub async fn Server::accept(Self) -> (ServerConnection, @socket.Addr) pub fn Server::addr(Self) -> @socket.Addr pub fn Server::close(Self) -> Unit -pub fn Server::new(@socket.Addr, dual_stack? : Bool, reuse_addr? : Bool, headers? : Map[String, String]) -> Self raise +pub fn Server::new(@socket.Addr, dual_stack? : Bool, reuse_addr? : Bool, headers? : Map[String, String], tls_certificate_file? : String, tls_certificate_type? : @tls.X509FileType, tls_private_key_file? : String, tls_private_key_type? : @tls.X509FileType) -> Self raise pub async fn Server::run_forever(Self, async (Request, &@io.Reader, ServerConnection) -> Unit, allow_failure? : Bool, max_connections? : Int) -> Unit type ServerConnection diff --git a/src/http/proxy_test.mbt b/src/http/proxy_test.mbt index eeffef35..194ec00e 100644 --- a/src/http/proxy_test.mbt +++ b/src/http/proxy_test.mbt @@ -144,10 +144,12 @@ async test "proxied http request" { handle_connect(log, req, body, conn) }) }) - let (_, result) = @http.get( - "http://www.example.org", - proxy=@http.Client::new("http://localhost:\{port}"), - ) + let (_, result) = @async.with_timeout(30000, () => { + @http.get( + "http://www.example.org", + proxy=@http.Client::new("http://localhost:\{port}"), + ) + }) assert_true(result.text().has_prefix("")) }) json_inspect(log, content=[ diff --git a/src/http/request_test.mbt b/src/http/request_test.mbt index b5dff04e..bf2bb5d8 100644 --- a/src/http/request_test.mbt +++ b/src/http/request_test.mbt @@ -14,14 +14,18 @@ ///| async test "https request" { - let (_, result) = @http.get("https://www.moonbitlang.com") - assert_true(result.text().has_prefix("")) + @async.with_timeout(30000, () => { + let (_, result) = @http.get("https://www.moonbitlang.com") + assert_true(result.text().has_prefix("")) + }) } ///| async test "http request" { - let (_, result) = @http.get("http://www.example.org") - assert_true(result.text().has_prefix("")) + @async.with_timeout(30000, () => { + let (_, result) = @http.get("http://www.example.org") + assert_true(result.text().has_prefix("")) + }) } ///| diff --git a/src/http/server.mbt b/src/http/server.mbt index e6b8abf6..bf4142a4 100644 --- a/src/http/server.mbt +++ b/src/http/server.mbt @@ -12,6 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +///| +priv struct H2ServerState { + server_conn : @http2.H2ServerConnection + stream : @http2.H2Stream + mut response_sent : Bool +} + ///| /// A single HTTP server connection struct ServerConnection { @@ -19,6 +26,7 @@ struct ServerConnection { conn : @socket.Tcp sender : Sender mut closed : Bool + h2 : H2ServerState? } ///| @@ -39,6 +47,7 @@ pub fn ServerConnection::new( conn, sender: Sender::new(conn, headers~), closed: false, + h2: None, } } @@ -49,7 +58,10 @@ pub fn ServerConnection::new( pub fn ServerConnection::close(self : ServerConnection) -> Unit { if !self.closed { self.closed = true - self.conn.close() + // For H2, don't close the TCP connection (shared across streams) + if self.h2 is None { + self.conn.close() + } } } @@ -66,12 +78,20 @@ pub impl @io.Reader for ServerConnection with _direct_read( offset~, max_len~, ) { - self.reader._direct_read(buf, offset~, max_len~) + if self.h2 is Some(h2) { + h2.stream._direct_read(buf, offset~, max_len~) + } else { + self.reader._direct_read(buf, offset~, max_len~) + } } ///| pub impl @io.Reader for ServerConnection with _get_internal_buffer(self) { - self.reader._get_internal_buffer() + if self.h2 is Some(h2) { + h2.stream._get_internal_buffer() + } else { + self.reader._get_internal_buffer() + } } ///| @@ -83,8 +103,12 @@ pub impl @io.Reader for ServerConnection with _get_internal_buffer(self) { /// the body of the request can be obtained by using `ServerConnection` /// as a `@io.Reader`. pub async fn ServerConnection::read_request(self : ServerConnection) -> Request { - self.reader.skip_body() - self.reader.read_request() + if self.h2 is Some(h2) { + build_h2_request(h2.stream.wait_headers()) + } else { + self.reader.skip_body() + self.reader.read_request() + } } ///| @@ -92,7 +116,13 @@ pub async fn ServerConnection::read_request(self : ServerConnection) -> Request pub async fn ServerConnection::skip_request_body( self : ServerConnection, ) -> Unit { - self.reader.skip_body() + if self.h2 is Some(_) { + while self.drop(1024) == 1024 { + + } + } else { + self.reader.skip_body() + } } ///| @@ -108,27 +138,53 @@ pub impl @io.Writer for ServerConnection with write_once( offset~, len~, ) { - guard not(self.sender.mode is SendingHeader) - self.sender.write_once(buf, offset~, len~) + if self.h2 is Some(h2) { + let data = buf[offset:offset + len].to_bytes() + h2.stream.send_data(data, end_stream=false) + len + } else { + guard not(self.sender.mode is SendingHeader) + self.sender.write_once(buf, offset~, len~) + } } ///| pub impl @io.Writer for ServerConnection with write_reader(self, reader) { - guard not(self.sender.mode is SendingHeader) - self.sender.write_reader(reader) + if self.h2 is Some(_) { + let buf : FixedArray[Byte] = FixedArray::make(16384, b'\x00') + for { + let n = reader._direct_read(buf, offset=0, max_len=buf.length()) + if n == 0 { + break + } + self.write(buf.unsafe_reinterpret_as_bytes()[:n]) + } + } else { + guard not(self.sender.mode is SendingHeader) + self.sender.write_reader(reader) + } } ///| /// Flush buffered data in the response body being sent, if any. pub async fn ServerConnection::flush(self : ServerConnection) -> Unit { - self.sender.flush() + if self.h2 is Some(_) { + () // H2 sends DATA frames immediately + } else { + self.sender.flush() + } } ///| /// End the body of the response currently being sent. /// Should be called immediately after response body is fully sent. pub async fn ServerConnection::end_response(self : ServerConnection) -> Unit { - self.sender.end_body() + if self.h2 is Some(h2) { + // Send empty DATA frame with END_STREAM + h2.stream.send_data(Bytes::make(0, b'\x00'), end_stream=true) + } else { + self.sender.end_body() + } } ///| @@ -142,7 +198,13 @@ pub async fn ServerConnection::send_response( reason : String, extra_headers? : Map[String, String] = {}, ) -> Unit { - self.sender.send_response(code, reason, extra_headers~) + if self.h2 is Some(h2) { + ignore(reason) + h2.server_conn.send_response(h2.stream, code, headers=extra_headers) + h2.response_sent = true + } else { + self.sender.send_response(code, reason, extra_headers~) + } } ///| @@ -185,11 +247,21 @@ pub async fn run_server( ) } +///| +/// TLS configuration for a HTTP server +priv struct TlsConfig { + certificate_file : String + certificate_type : @tls.X509FileType + private_key_file : String + private_key_type : @tls.X509FileType +} + ///| /// A HTTP server struct Server { server : @socket.TcpServer headers : Map[String, String] + tls : TlsConfig? } ///| @@ -205,8 +277,30 @@ pub fn Server::new( dual_stack? : Bool, reuse_addr? : Bool, headers? : Map[String, String] = {}, + tls_certificate_file? : String, + tls_certificate_type? : @tls.X509FileType = PEM, + tls_private_key_file? : String, + tls_private_key_type? : @tls.X509FileType = PEM, ) -> Server raise { - { server: @socket.TcpServer::new(addr, dual_stack?, reuse_addr?), headers } + let tls : TlsConfig? = match tls_certificate_file { + Some(cert) => + match tls_private_key_file { + Some(key) => + Some({ + certificate_file: cert, + certificate_type: tls_certificate_type, + private_key_file: key, + private_key_type: tls_private_key_type, + }) + None => None + } + None => None + } + { + server: @socket.TcpServer::new(addr, dual_stack?, reuse_addr?), + headers, + tls, + } } ///| @@ -225,6 +319,7 @@ pub fn Server::addr(self : Server) -> @socket.Addr { /// Return the new HTTP connection and the address of peer. pub async fn Server::accept(self : Server) -> (ServerConnection, @socket.Addr) { let (conn, addr) = self.server.accept() + // TLS + H2 negotiation is handled in run_forever; accept returns H1 only (ServerConnection::new(conn, headers=self.headers), addr) } @@ -265,16 +360,81 @@ pub async fn Server::run_forever( allow_failure? : Bool, max_connections? : Int, ) -> Unit { - self.server.run_forever(allow_failure?, max_connections?, (conn, _) => { - let conn = ServerConnection::new(conn, headers=self.headers) - defer conn.close() - for { - let request = conn.read_request() - f(request, conn, conn) - if conn.closed { - break - } else if conn.sender.mode is (WaitingBody | SendingBody) { - conn.end_response() + self.server.run_forever(allow_failure?, max_connections?, (tcp_conn, _) => { + match self.tls { + Some(tls_config) => { + let tls = @tls.Tls::server( + tcp_conn, + private_key_file=tls_config.private_key_file, + private_key_type=tls_config.private_key_type, + certificate_file=tls_config.certificate_file, + certificate_type=tls_config.certificate_type, + alpn=["h2", "http/1.1"], + ) + defer tls.close() + if tls.selected_alpn() is Some("h2") { + // HTTP/2 mode + let h2_conn = @http2.H2ServerConnection::new(tls, tls) + @async.with_task_group(async fn(group) { + group.spawn_bg(no_wait=true, allow_failure=true, async fn() { + h2_conn.run_read_loop() + }) + group.add_defer(async fn() { h2_conn.close() }) + for { + let stream = h2_conn.accept_stream() catch { _ => break } + let headers = stream.wait_headers() catch { _ => continue } + let request = build_h2_request(headers) + let sc : ServerConnection = { + reader: Reader::new(tls), + conn: tcp_conn, + sender: Sender::new(tls, headers=self.headers), + closed: false, + h2: Some({ server_conn: h2_conn, stream, response_sent: false }), + } + group.spawn_bg(allow_failure?, async fn() { + f(request, sc, sc) + if not(sc.closed) { + if sc.h2 is Some(h2) && h2.response_sent { + sc.end_response() + } + } + }) + } + }) + } else { + // HTTP/1.1 over TLS + let conn : ServerConnection = { + reader: Reader::new(tls), + conn: tcp_conn, + sender: Sender::new(tls, headers=self.headers), + closed: false, + h2: None, + } + defer conn.close() + for { + let request = conn.read_request() + f(request, conn, conn) + if conn.closed { + break + } else if conn.sender.mode is (WaitingBody | SendingBody) { + conn.end_response() + } + } + } + } + None => { + // Plain HTTP/1.1 (no TLS) + let conn = ServerConnection::new(tcp_conn, headers=self.headers) + defer conn.close() + for { + let request = conn.read_request() + f(request, conn, conn) + if conn.closed { + break + } else if conn.sender.mode is (WaitingBody | SendingBody) { + conn.end_response() + } + } } } }) @@ -294,7 +454,37 @@ pub async fn Server::run_forever( /// (i.e. not in the middle of sending a response). /// Unread data from the body of the last request will be discarded. pub async fn ServerConnection::enter_passthrough_mode(self : Self) -> Unit { + guard self.h2 is None else { + abort("enter_passthrough_mode is not supported for HTTP/2") + } guard self.sender.mode is SendingHeader self.reader.enter_passthrough_mode() self.sender.enter_passthrough_mode() } + +///| +fn build_h2_request(headers : Map[String, String]) -> Request { + let meth = match headers.get(":method") { + Some("GET") => Get + Some("HEAD") => Head + Some("POST") => Post + Some("PUT") => Put + Some("DELETE") => Delete + Some("CONNECT") => Connect + Some("OPTIONS") => Options + Some("TRACE") => Trace + Some("PATCH") => Patch + _ => Get + } + let path = match headers.get(":path") { + Some(p) => p + None => "/" + } + let request_headers : Map[String, String] = {} + for k, v in headers { + if k.length() > 0 && k[0] != ':' { + request_headers[k] = v + } + } + { meth, path, headers: request_headers } +} diff --git a/src/http/unimplemented.mbt b/src/http/unimplemented.mbt index 7cb12f3e..7c59c10e 100644 --- a/src/http/unimplemented.mbt +++ b/src/http/unimplemented.mbt @@ -21,6 +21,9 @@ pub let unimplemented : Unit = { ignore(@socket.unimplemented) ignore(@tls.unimplemented) ignore(@io_buffer.new) + ignore(@http2.unimplemented) + ignore(@coroutine.is_being_cancelled) + ignore(@async.sleep) } ///| diff --git a/src/http2/client.mbt b/src/http2/client.mbt new file mode 100644 index 00000000..c48c5b23 --- /dev/null +++ b/src/http2/client.mbt @@ -0,0 +1,126 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// HTTP/2 client connection. +/// Wraps an H2Connection with client-specific behavior. +pub struct H2ClientConnection { + conn : H2Connection +} + +///| +/// Create a new HTTP/2 client connection. +/// Sends the client preface (magic + SETTINGS), +/// reads the server's SETTINGS, and sends SETTINGS ACK. +pub async fn[R : @io.Reader, W : @io.Writer] H2ClientConnection::new( + r : R, + w : W, +) -> H2ClientConnection { + let conn = H2Connection::new(r, w, is_client=true) + conn.local_settings.enable_push = false + conn.writer.write(client_preface) + conn.send_settings() + let frame = read_frame(conn.reader) + guard frame.header.frame_type is Settings else { + raise H2Error( + ProtocolError, + "expected SETTINGS frame, got \{frame.header.frame_type}", + ) + } + conn.remote_settings.apply(frame.payload) + conn.send_settings_ack() + let window_diff = conn.remote_settings.initial_window_size - + default_initial_window_size + if window_diff != 0 { + conn.send_window.update(window_diff) + } + H2ClientConnection::{ conn, } +} + +///| +/// Start the background read loop for this client connection. +/// Must be called after new() to begin processing incoming frames. +/// Typically called inside a task group's spawn_bg. +pub async fn H2ClientConnection::run_read_loop( + self : H2ClientConnection, +) -> Unit { + self.conn.read_loop() +} + +///| +/// Send a request on a new stream. +/// `meth` should be a method string like "GET", "POST", etc. +/// Returns the stream that can be used to read the response. +pub async fn H2ClientConnection::send_request( + self : H2ClientConnection, + meth : String, + path : String, + scheme? : String = "https", + headers? : Map[String, String] = {}, + authority? : String = "", +) -> H2Stream { + let stream = self.conn.alloc_stream() + let header_list : Array[(String, String)] = [] + header_list.push((":method", meth)) + header_list.push((":path", path)) + header_list.push((":scheme", scheme)) + if authority.length() > 0 { + header_list.push((":authority", authority)) + } + for name, value in headers { + header_list.push((name, value)) + } + self.conn.send_headers(stream.id, header_list) + stream +} + +///| +/// Send a request with body on a new stream. +pub async fn H2ClientConnection::send_request_with_body( + self : H2ClientConnection, + meth : String, + path : String, + body : Bytes, + scheme? : String = "https", + headers? : Map[String, String] = {}, + authority? : String = "", +) -> H2Stream { + let stream = self.conn.alloc_stream() + let header_list : Array[(String, String)] = [] + header_list.push((":method", meth)) + header_list.push((":path", path)) + header_list.push((":scheme", scheme)) + if authority.length() > 0 { + header_list.push((":authority", authority)) + } + for name, value in headers { + header_list.push((name, value)) + } + let end_stream = body.length() == 0 + self.conn.send_headers(stream.id, header_list, end_stream~) + if body.length() > 0 { + stream.send_data(body, end_stream=true) + } + stream +} + +///| +/// Close the HTTP/2 client connection gracefully. +pub async fn H2ClientConnection::close(self : H2ClientConnection) -> Unit { + if not(self.conn.goaway_sent) { + self.conn.send_goaway(NoError) + } + self.conn.closed = true + self.conn.incoming_streams.close(clear=true) +} diff --git a/src/http2/connection.mbt b/src/http2/connection.mbt new file mode 100644 index 00000000..f512d74d --- /dev/null +++ b/src/http2/connection.mbt @@ -0,0 +1,566 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// HTTP/2 connection. +struct H2Connection { + reader : &@io.Reader + writer : &@io.Writer + write_lock : @semaphore.Semaphore + encoder : @hpack.Encoder + decoder : @hpack.Decoder + streams : Map[Int, H2Stream] + mut next_stream_id : Int + send_window : FlowWindow + mut recv_window : Int + local_settings : H2Settings + remote_settings : H2Settings + incoming_streams : @aqueue.Queue[H2Stream] + mut goaway_sent : Bool + mut last_stream_id : Int + mut closed : Bool + mut continuation_stream_id : Int + continuation_buf : @buffer.Buffer +} + +///| +/// Create a new H2Connection (base constructor). +fn H2Connection::new( + reader : &@io.Reader, + writer : &@io.Writer, + is_client~ : Bool, +) -> H2Connection { + { + reader, + writer, + write_lock: @semaphore.Semaphore::new(1), + encoder: @hpack.Encoder::new(), + decoder: @hpack.Decoder::new(), + streams: {}, + next_stream_id: if is_client { + 1 + } else { + 2 + }, + send_window: FlowWindow::new(default_initial_window_size), + recv_window: default_initial_window_size, + local_settings: H2Settings::default(), + remote_settings: H2Settings::default(), + incoming_streams: @aqueue.Queue::new(kind=Blocking(256)), + goaway_sent: false, + last_stream_id: 0, + closed: false, + continuation_stream_id: 0, + continuation_buf: @buffer.new(), + } +} + +///| +/// Write a frame while holding the write lock. +async fn H2Connection::locked_write_frame( + self : H2Connection, + frame_type : FrameType, + flags : Byte, + stream_id : Int, + payload : Bytes, +) -> Unit { + self.write_lock.acquire() + defer self.write_lock.release() + write_frame(self.writer, frame_type, flags, stream_id, payload) +} + +///| +/// Send a SETTINGS frame. +async fn H2Connection::send_settings(self : H2Connection) -> Unit { + let payload = self.local_settings.encode() + self.locked_write_frame(Settings, b'\x00', 0, payload) +} + +///| +/// Send a SETTINGS ACK frame. +async fn H2Connection::send_settings_ack(self : H2Connection) -> Unit { + self.locked_write_frame(Settings, flag_ack, 0, Bytes::make(0, b'\x00')) +} + +///| +/// Send a PING frame. +async fn H2Connection::send_ping( + self : H2Connection, + data : Bytes, + ack? : Bool = false, +) -> Unit { + let flags = if ack { flag_ack } else { b'\x00' } + self.locked_write_frame(Ping, flags, 0, data) +} + +///| +/// Send a GOAWAY frame. +async fn H2Connection::send_goaway( + self : H2Connection, + error_code : H2ErrorCode, + debug_data? : String = "", +) -> Unit { + self.goaway_sent = true + let buf = @buffer.new() + let last_id = self.last_stream_id + buf.write_int_be(last_id & 0x7fffffff) + buf.write_int_be(error_code.to_int()) + for i = 0; i < debug_data.length(); i = i + 1 { + buf.write_byte(debug_data[i].to_int().to_byte()) + } + self.locked_write_frame(GoAway, b'\x00', 0, buf.to_bytes()) +} + +///| +/// Send a RST_STREAM frame. +async fn H2Connection::send_rst_stream( + self : H2Connection, + stream_id : Int, + error_code : H2ErrorCode, +) -> Unit { + let payload = encode_u32(error_code.to_int()) + self.locked_write_frame(RstStream, b'\x00', stream_id, payload) +} + +///| +/// Send a WINDOW_UPDATE frame. +async fn H2Connection::send_window_update( + self : H2Connection, + stream_id : Int, + increment : Int, +) -> Unit { + let payload = encode_u32(increment & 0x7fffffff) + self.locked_write_frame(WindowUpdate, b'\x00', stream_id, payload) +} + +///| +/// Send HEADERS frame(s) for a stream. +/// HPACK encodes the headers, splits across HEADERS+CONTINUATION +/// if exceeding max_frame_size. Holds write_lock for entire sequence. +async fn H2Connection::send_headers( + self : H2Connection, + stream_id : Int, + headers : Array[(String, String)], + end_stream? : Bool = false, +) -> Unit { + let encoded = self.encoder.encode(headers) + let max_size = self.remote_settings.max_frame_size + self.write_lock.acquire() + defer self.write_lock.release() + if encoded.length() <= max_size { + let mut flags = flag_end_headers + if end_stream { + flags = flags.to_int().lor(flag_end_stream.to_int()).to_byte() + } + write_frame(self.writer, Headers, flags, stream_id, encoded) + } else { + let first_chunk = encoded[:max_size].to_bytes() + let mut es_flags : Byte = b'\x00' + if end_stream { + es_flags = flag_end_stream + } + write_frame(self.writer, Headers, es_flags, stream_id, first_chunk) + let mut offset = max_size + while offset < encoded.length() { + let remaining = encoded.length() - offset + let chunk_size = @cmp.minimum(remaining, max_size) + let is_last = offset + chunk_size >= encoded.length() + let cont_flags = if is_last { flag_end_headers } else { b'\x00' } + write_frame_from_slice( + self.writer, + Continuation, + cont_flags, + stream_id, + encoded, + offset, + chunk_size, + ) + offset = offset + chunk_size + } + } +} + +///| +/// Process a received HEADERS/CONTINUATION payload through HPACK. +fn H2Connection::decode_headers( + self : H2Connection, + payload : Bytes, +) -> Array[(String, String)] raise { + self.decoder.decode(payload) +} + +///| +/// Allocate a new stream (client-side). +fn H2Connection::alloc_stream(self : H2Connection) -> H2Stream { + let id = self.next_stream_id + self.next_stream_id = self.next_stream_id + 2 + let stream = H2Stream::new( + id, + Open, + self.remote_settings.initial_window_size, + self.local_settings.initial_window_size, + self, + ) + self.streams[id] = stream + if id > self.last_stream_id { + self.last_stream_id = id + } + stream +} + +///| +/// Get or create a stream for server-side incoming streams. +fn H2Connection::get_or_create_stream( + self : H2Connection, + stream_id : Int, +) -> H2Stream { + match self.streams.get(stream_id) { + Some(stream) => stream + None => { + let stream = H2Stream::new( + stream_id, + Open, + self.remote_settings.initial_window_size, + self.local_settings.initial_window_size, + self, + ) + self.streams[stream_id] = stream + if stream_id > self.last_stream_id { + self.last_stream_id = stream_id + } + stream + } + } +} + +///| +/// Background frame reading loop. +/// Reads frames and dispatches them to the appropriate stream queues. +async fn H2Connection::read_loop(self : H2Connection) -> Unit { + defer self.streams.clear() + defer self.incoming_streams.close(clear=true) + while not(self.closed) { + let frame = read_frame(self.reader) catch { + _ => { + self.closed = true + for _, stream in self.streams { + if stream.error is None { + stream.error = Some(NoError) + } + stream.inbound_data.close() + stream.headers_ready.broadcast() + stream.trailers_ready.broadcast() + } + self.incoming_streams.close(clear=true) + break + } + } + self.dispatch_frame(frame) + } +} + +///| +/// Dispatch a single received frame. +async fn H2Connection::dispatch_frame( + self : H2Connection, + frame : Frame, +) -> Unit { + let ft = frame.header.frame_type + let sid = frame.header.stream_id + let flags = frame.header.flags + let payload = frame.payload + // Handle CONTINUATION accumulation + if self.continuation_stream_id != 0 { + if not(ft is Continuation) || sid != self.continuation_stream_id { + self.send_goaway(ProtocolError, debug_data="expected CONTINUATION") catch { + _ => () + } + self.closed = true + return + } + self.continuation_buf.write_bytes(payload) + if has_flag(flags, flag_end_headers) { + let full_payload = self.continuation_buf.to_bytes() + self.continuation_buf.reset() + let csid = self.continuation_stream_id + self.continuation_stream_id = 0 + self.process_headers(csid, full_payload, has_flag(flags, flag_end_stream)) + } + return + } + match ft { + Data => self.handle_data(sid, flags, payload) + Headers => self.handle_headers(sid, flags, payload) + Priority => () // RFC 9113 deprecates priority. Accept and ignore. + RstStream => self.handle_rst_stream(sid, payload) + Settings => self.handle_settings(sid, flags, payload) + PushPromise => + self.send_goaway(ProtocolError, debug_data="PUSH_PROMISE not supported") catch { + _ => () + } + Ping => self.handle_ping(sid, flags, payload) + GoAway => () + // Could extract last_stream_id and error_code from payload + WindowUpdate => self.handle_window_update(sid, payload) + Continuation => + self.send_goaway(ProtocolError, debug_data="unexpected CONTINUATION") catch { + _ => () + } + Unknown(_) => () // Unknown frame types MUST be ignored + } +} + +///| +async fn H2Connection::handle_data( + self : H2Connection, + sid : Int, + flags : Byte, + payload : Bytes, +) -> Unit { + if sid == 0 { + self.send_goaway(ProtocolError, debug_data="DATA on stream 0") catch { + _ => () + } + return + } + match self.streams.get(sid) { + Some(stream) => { + stream.inbound_data.put(payload) catch { + _ => () + } + if has_flag(flags, flag_end_stream) { + stream.state = match stream.state { + Open => HalfClosedRemote + HalfClosedLocal => Closed + _ => stream.state + } + stream.inbound_data.put(Bytes::make(0, b'\x00')) catch { + _ => () + } + } + self.recv_window = self.recv_window - payload.length() + if self.recv_window < default_initial_window_size / 2 { + let increment = default_initial_window_size - self.recv_window + self.recv_window = self.recv_window + increment + self.send_window_update(0, increment) catch { + _ => () + } + } + } + None => self.send_rst_stream(sid, StreamClosed) catch { _ => () } + } +} + +///| +async fn H2Connection::handle_headers( + self : H2Connection, + sid : Int, + flags : Byte, + payload : Bytes, +) -> Unit { + if sid == 0 { + self.send_goaway(ProtocolError, debug_data="HEADERS on stream 0") catch { + _ => () + } + return + } + let mut header_payload = payload + let mut offset = 0 + if has_flag(flags, flag_padded) { + let pad_length = payload[0].to_int() + offset = 1 + header_payload = payload[offset:payload.length() - pad_length].to_bytes() + offset = 0 + } + if has_flag(flags, flag_priority) { + offset = offset + 5 + header_payload = header_payload[offset:].to_bytes() + } else if offset > 0 { + header_payload = header_payload[offset:].to_bytes() + } + if has_flag(flags, flag_end_headers) { + self.process_headers(sid, header_payload, has_flag(flags, flag_end_stream)) + } else { + self.continuation_stream_id = sid + self.continuation_buf.reset() + self.continuation_buf.write_bytes(header_payload) + } +} + +///| +async fn H2Connection::handle_rst_stream( + self : H2Connection, + sid : Int, + payload : Bytes, +) -> Unit { + if payload.length() != 4 { + self.send_goaway(FrameSizeError, debug_data="RST_STREAM bad size") catch { + _ => () + } + return + } + let error_code = H2ErrorCode::from_int(decode_u32(payload, 0)) + match self.streams.get(sid) { + Some(stream) => { + stream.error = Some(error_code) + stream.state = Closed + stream.inbound_data.close() + stream.headers_ready.broadcast() + stream.trailers_ready.broadcast() + } + None => () + } +} + +///| +async fn H2Connection::handle_settings( + self : H2Connection, + sid : Int, + flags : Byte, + payload : Bytes, +) -> Unit { + if sid != 0 { + self.send_goaway(ProtocolError, debug_data="SETTINGS on non-zero stream") catch { + _ => () + } + return + } + if has_flag(flags, flag_ack) { + return + } + self.remote_settings.apply(payload) + self.send_settings_ack() catch { + _ => () + } +} + +///| +async fn H2Connection::handle_ping( + self : H2Connection, + sid : Int, + flags : Byte, + payload : Bytes, +) -> Unit { + if sid != 0 { + self.send_goaway(ProtocolError, debug_data="PING on non-zero stream") catch { + _ => () + } + return + } + if payload.length() != 8 { + self.send_goaway(FrameSizeError, debug_data="PING bad size") catch { + _ => () + } + return + } + if not(has_flag(flags, flag_ack)) { + self.send_ping(payload, ack=true) catch { + _ => () + } + } +} + +///| +async fn H2Connection::handle_window_update( + self : H2Connection, + sid : Int, + payload : Bytes, +) -> Unit { + if payload.length() != 4 { + self.send_goaway(FrameSizeError, debug_data="WINDOW_UPDATE bad size") catch { + _ => () + } + return + } + let increment = decode_u31(payload, 0) + if increment == 0 { + if sid == 0 { + self.send_goaway(ProtocolError, debug_data="WINDOW_UPDATE increment 0") catch { + _ => () + } + } else { + self.send_rst_stream(sid, ProtocolError) catch { + _ => () + } + } + return + } + if sid == 0 { + self.send_window.update(increment) + } else { + match self.streams.get(sid) { + Some(stream) => stream.send_window.update(increment) + None => () + } + } +} + +///| +/// Process decoded headers for a stream. +async fn H2Connection::process_headers( + self : H2Connection, + stream_id : Int, + payload : Bytes, + end_stream : Bool, +) -> Unit { + let headers = self.decode_headers(payload) catch { + _ => { + self.send_rst_stream(stream_id, CompressionError) catch { + _ => () + } + return + } + } + let header_map : Map[String, String] = {} + for i = 0; i < headers.length(); i = i + 1 { + let (name, value) = headers[i] + header_map[name] = value + } + let stream = self.get_or_create_stream(stream_id) + let is_new = stream.headers is None + if is_new { + // Initial headers + stream.headers = Some(header_map) + stream.headers_ready.broadcast() + } else { + // Trailing headers (HEADERS after DATA) + stream.trailers = Some(header_map) + stream.trailers_ready.broadcast() + } + if end_stream { + stream.state = match stream.state { + Open => HalfClosedRemote + HalfClosedLocal => Closed + _ => stream.state + } + if is_new { + // Headers-only request (no DATA), signal end of data + stream.inbound_data.put(Bytes::make(0, b'\x00')) catch { + _ => () + } + } else { + // Trailers signal end of data too + stream.inbound_data.put(Bytes::make(0, b'\x00')) catch { + _ => () + } + } + // Wake trailers_ready in case someone is waiting and no trailers came + stream.trailers_ready.broadcast() + } + if is_new && stream_id % 2 == 1 { + self.incoming_streams.put(stream) catch { + _ => () + } + } +} diff --git a/src/http2/flow_control.mbt b/src/http2/flow_control.mbt new file mode 100644 index 00000000..953fab03 --- /dev/null +++ b/src/http2/flow_control.mbt @@ -0,0 +1,51 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// Flow control window for HTTP/2. +/// Maps flow control directly to coroutine suspension — no busy waiting. +struct FlowWindow { + mut size : Int + cond : @cond_var.Cond +} + +///| +/// Create a new flow control window. +fn FlowWindow::new(initial_size : Int) -> FlowWindow { + { size: initial_size, cond: @cond_var.Cond::new() } +} + +///| +/// Consume bytes from the flow window. +/// If the window is exhausted (size <= 0), suspends the coroutine +/// until a WINDOW_UPDATE arrives. +/// Returns the number of bytes actually consumed (min of n and available). +async fn FlowWindow::consume(self : FlowWindow, n : Int) -> Int { + while self.size <= 0 { + self.cond.wait() + } + let consumed = @cmp.minimum(n, self.size) + self.size = self.size - consumed + consumed +} + +///| +/// Update the flow window by incrementing its size. +/// Wakes any blocked senders. +fn FlowWindow::update(self : FlowWindow, increment : Int) -> Unit { + self.size = self.size + increment + if self.size > 0 { + self.cond.broadcast() + } +} diff --git a/src/http2/frame.mbt b/src/http2/frame.mbt new file mode 100644 index 00000000..9f176020 --- /dev/null +++ b/src/http2/frame.mbt @@ -0,0 +1,118 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// Read a single HTTP/2 frame from a reader. +/// Returns the frame header and payload. +async fn read_frame(reader : &@io.Reader) -> Frame { + let header_bytes = reader.read_exactly(9) + let length = (header_bytes[0].to_int() << 16) | + (header_bytes[1].to_int() << 8) | + header_bytes[2].to_int() + let frame_type = FrameType::from_byte(header_bytes[3]) + let flags = header_bytes[4] + let stream_id = ((header_bytes[5].to_int() & 0x7f) << 24) | + (header_bytes[6].to_int() << 16) | + (header_bytes[7].to_int() << 8) | + header_bytes[8].to_int() + let payload = if length > 0 { + reader.read_exactly(length) + } else { + Bytes::make(0, b'\x00') + } + { header: { length, frame_type, flags, stream_id }, payload } +} + +///| +/// Write a frame to a writer. +async fn write_frame( + writer : &@io.Writer, + frame_type : FrameType, + flags : Byte, + stream_id : Int, + payload : Bytes, +) -> Unit { + let length = payload.length() + let header = FixedArray::make(9, b'\x00') + header[0] = ((length >> 16) & 0xff).to_byte() + header[1] = ((length >> 8) & 0xff).to_byte() + header[2] = length.land(0xff).to_byte() + header[3] = frame_type.to_byte() + header[4] = flags + header[5] = ((stream_id >> 24) & 0x7f).to_byte() + header[6] = ((stream_id >> 16) & 0xff).to_byte() + header[7] = ((stream_id >> 8) & 0xff).to_byte() + header[8] = stream_id.land(0xff).to_byte() + writer.write(header.unsafe_reinterpret_as_bytes()) + if length > 0 { + writer.write(payload) + } +} + +///| +/// Write a frame header followed by data from a slice of bytes. +/// Avoids copying for DATA frames. +async fn write_frame_from_slice( + writer : &@io.Writer, + frame_type : FrameType, + flags : Byte, + stream_id : Int, + data : Bytes, + offset : Int, + length : Int, +) -> Unit { + let header = FixedArray::make(9, b'\x00') + header[0] = ((length >> 16) & 0xff).to_byte() + header[1] = ((length >> 8) & 0xff).to_byte() + header[2] = length.land(0xff).to_byte() + header[3] = frame_type.to_byte() + header[4] = flags + header[5] = ((stream_id >> 24) & 0x7f).to_byte() + header[6] = ((stream_id >> 16) & 0xff).to_byte() + header[7] = ((stream_id >> 8) & 0xff).to_byte() + header[8] = stream_id.land(0xff).to_byte() + writer.write(header.unsafe_reinterpret_as_bytes()) + if length > 0 { + writer.write(data[offset:offset + length]) + } +} + +///| +/// Encode a 32-bit unsigned integer into 4 bytes (big-endian). +fn encode_u32(value : Int) -> Bytes { + let buf = FixedArray::make(4, b'\x00') + buf[0] = ((value >> 24) & 0xff).to_byte() + buf[1] = ((value >> 16) & 0xff).to_byte() + buf[2] = ((value >> 8) & 0xff).to_byte() + buf[3] = value.land(0xff).to_byte() + buf.unsafe_reinterpret_as_bytes() +} + +///| +/// Decode a 32-bit unsigned integer from 4 bytes (big-endian). +fn decode_u32(data : Bytes, offset : Int) -> Int { + (data[offset].to_int() << 24) | + (data[offset + 1].to_int() << 16) | + (data[offset + 2].to_int() << 8) | + data[offset + 3].to_int() +} + +///| +/// Decode a 31-bit unsigned integer from 4 bytes (big-endian), masking top bit. +fn decode_u31(data : Bytes, offset : Int) -> Int { + ((data[offset].to_int() & 0x7f) << 24) | + (data[offset + 1].to_int() << 16) | + (data[offset + 2].to_int() << 8) | + data[offset + 3].to_int() +} diff --git a/src/http2/moon.pkg b/src/http2/moon.pkg new file mode 100644 index 00000000..3cd1cfa8 --- /dev/null +++ b/src/http2/moon.pkg @@ -0,0 +1,22 @@ +import { + "moonbitlang/core/buffer", + "moonbitlang/core/cmp", + "moonbitlang/async/io", + "moonbitlang/async/hpack", + "moonbitlang/async/semaphore", + "moonbitlang/async/cond_var", + "moonbitlang/async/aqueue", +} + +options( + targets: { + "client.mbt": [ "native" ], + "connection.mbt": [ "native" ], + "flow_control.mbt": [ "native" ], + "frame.mbt": [ "native" ], + "server.mbt": [ "native" ], + "settings.mbt": [ "native" ], + "stream.mbt": [ "native" ], + "types.mbt": [ "native" ], + }, +) diff --git a/src/http2/pkg.generated.mbti b/src/http2/pkg.generated.mbti new file mode 100644 index 00000000..c1b45eb0 --- /dev/null +++ b/src/http2/pkg.generated.mbti @@ -0,0 +1,131 @@ +// Generated using `moon info`, DON'T EDIT IT +package "moonbitlang/async/http2" + +import { + "moonbitlang/async/aqueue", + "moonbitlang/async/cond_var", + "moonbitlang/async/io", +} + +// Values + +// Errors +pub suberror H2Error { + H2Error(H2ErrorCode, String) +} +pub impl Show for H2Error +pub impl ToJson for H2Error + +// Types and methods +type FlowWindow + +pub(all) struct Frame { + header : FrameHeader + payload : Bytes +} +pub impl Show for Frame + +pub(all) struct FrameHeader { + length : Int + frame_type : FrameType + flags : Byte + stream_id : Int +} +pub impl Show for FrameHeader + +pub(all) enum FrameType { + Data + Headers + Priority + RstStream + Settings + PushPromise + Ping + GoAway + WindowUpdate + Continuation + Unknown(Byte) +} +pub impl Eq for FrameType +pub impl Show for FrameType + +pub struct H2ClientConnection { + conn : H2Connection +} +pub async fn H2ClientConnection::close(Self) -> Unit +pub async fn[R : @io.Reader, W : @io.Writer] H2ClientConnection::new(R, W) -> Self +pub async fn H2ClientConnection::run_read_loop(Self) -> Unit +pub async fn H2ClientConnection::send_request(Self, String, String, scheme? : String, headers? : Map[String, String], authority? : String) -> H2Stream +pub async fn H2ClientConnection::send_request_with_body(Self, String, String, Bytes, scheme? : String, headers? : Map[String, String], authority? : String) -> H2Stream + +type H2Connection + +pub(all) enum H2ErrorCode { + NoError + ProtocolError + InternalError + FlowControlError + SettingsTimeout + StreamClosed + FrameSizeError + RefusedStream + Cancel + CompressionError + ConnectError + EnhanceYourCalm + InadequateSecurity + Http11Required +} +pub impl Eq for H2ErrorCode +pub impl Show for H2ErrorCode +pub impl ToJson for H2ErrorCode + +pub struct H2ServerConnection { + conn : H2Connection +} +pub async fn H2ServerConnection::accept_stream(Self) -> H2Stream +pub async fn H2ServerConnection::close(Self) -> Unit +pub async fn[R : @io.Reader, W : @io.Writer] H2ServerConnection::new(R, W) -> Self +pub async fn H2ServerConnection::run_read_loop(Self) -> Unit +pub async fn H2ServerConnection::send_data(Self, H2Stream, Bytes, end_stream? : Bool) -> Unit +pub async fn H2ServerConnection::send_response(Self, H2Stream, Int, headers? : Map[String, String], end_stream? : Bool) -> Unit +pub async fn H2ServerConnection::send_trailers(Self, H2Stream, Map[String, String]) -> Unit + +pub struct H2Stream { + id : Int + mut state : StreamState + send_window : FlowWindow + mut recv_window : Int + inbound_data : @aqueue.Queue[Bytes] + mut headers : Map[String, String]? + headers_ready : @cond_var.Cond + mut trailers : Map[String, String]? + trailers_ready : @cond_var.Cond + mut error : H2ErrorCode? + conn : H2Connection + read_buf : @io.ReaderBuffer +} +pub fn H2Stream::get_error(Self) -> H2ErrorCode? +pub async fn H2Stream::send_data(Self, Bytes, end_stream? : Bool) -> Unit +pub async fn H2Stream::send_trailers(Self, Map[String, String]) -> Unit +pub fn H2Stream::stream_id(Self) -> Int +pub async fn H2Stream::wait_headers(Self) -> Map[String, String] +pub async fn H2Stream::wait_trailers(Self) -> Map[String, String]? +pub impl @io.Reader for H2Stream + +pub(all) enum StreamState { + Idle + ReservedLocal + ReservedRemote + Open + HalfClosedLocal + HalfClosedRemote + Closed +} +pub impl Eq for StreamState +pub impl Show for StreamState + +// Type aliases + +// Traits + diff --git a/src/http2/server.mbt b/src/http2/server.mbt new file mode 100644 index 00000000..3eaceabe --- /dev/null +++ b/src/http2/server.mbt @@ -0,0 +1,141 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// HTTP/2 server connection. +/// Wraps an H2Connection with server-specific behavior. +pub struct H2ServerConnection { + conn : H2Connection +} + +///| +/// Create a new HTTP/2 server connection. +/// Reads and verifies the client preface magic, +/// reads client SETTINGS, sends server SETTINGS, +/// sends SETTINGS ACK, and spawns the background read loop. +pub async fn[R : @io.Reader, W : @io.Writer] H2ServerConnection::new( + r : R, + w : W, +) -> H2ServerConnection { + let conn = H2Connection::new(r, w, is_client=false) + // Disable server push + conn.local_settings.enable_push = false + // Read and verify client preface magic (24 bytes) + let magic = conn.reader.read_exactly(client_preface.length()) + if magic != client_preface { + raise H2Error(ProtocolError, "invalid client preface") + } + // Read client's SETTINGS + let frame = read_frame(conn.reader) + guard frame.header.frame_type is Settings else { + raise H2Error( + ProtocolError, + "expected SETTINGS frame after preface, got \{frame.header.frame_type}", + ) + } + conn.remote_settings.apply(frame.payload) + // Send server SETTINGS + conn.send_settings() + // Send SETTINGS ACK for client's settings + conn.send_settings_ack() + // Update send window based on remote settings + let window_diff = conn.remote_settings.initial_window_size - + default_initial_window_size + if window_diff != 0 { + conn.send_window.update(window_diff) + } + H2ServerConnection::{ conn, } +} + +///| +/// Start the background read loop for this server connection. +pub async fn H2ServerConnection::run_read_loop( + self : H2ServerConnection, +) -> Unit { + self.conn.read_loop() +} + +///| +/// Accept the next incoming stream from a client. +/// Blocks until a new stream with headers arrives. +pub async fn H2ServerConnection::accept_stream( + self : H2ServerConnection, +) -> H2Stream { + self.conn.incoming_streams.get() +} + +///| +/// Send response headers on a stream. +pub async fn H2ServerConnection::send_response( + self : H2ServerConnection, + stream : H2Stream, + status : Int, + headers? : Map[String, String] = {}, + end_stream? : Bool = false, +) -> Unit { + let header_list : Array[(String, String)] = [] + header_list.push((":status", status.to_string())) + for name, value in headers { + header_list.push((name, value)) + } + self.conn.send_headers(stream.id, header_list, end_stream~) + if end_stream { + stream.state = match stream.state { + Open => HalfClosedLocal + HalfClosedRemote => Closed + _ => stream.state + } + } +} + +///| +/// Send response data on a stream. +pub async fn H2ServerConnection::send_data( + self : H2ServerConnection, + stream : H2Stream, + data : Bytes, + end_stream? : Bool = false, +) -> Unit { + ignore(self) + stream.send_data(data, end_stream~) + if end_stream { + stream.state = match stream.state { + Open => HalfClosedLocal + HalfClosedRemote => Closed + _ => stream.state + } + } +} + +///| +/// Send trailing headers on a stream with END_STREAM. +/// Must be called after all DATA frames have been sent. +pub async fn H2ServerConnection::send_trailers( + self : H2ServerConnection, + stream : H2Stream, + trailers : Map[String, String], +) -> Unit { + ignore(self) + stream.send_trailers(trailers) +} + +///| +/// Close the HTTP/2 server connection gracefully. +pub async fn H2ServerConnection::close(self : H2ServerConnection) -> Unit { + if not(self.conn.goaway_sent) { + self.conn.send_goaway(NoError) + } + self.conn.closed = true + self.conn.incoming_streams.close(clear=true) +} diff --git a/src/http2/settings.mbt b/src/http2/settings.mbt new file mode 100644 index 00000000..af55bf4a --- /dev/null +++ b/src/http2/settings.mbt @@ -0,0 +1,112 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// HTTP/2 settings per RFC 7540 Section 6.5. +priv struct H2Settings { + mut header_table_size : Int + mut enable_push : Bool + mut max_concurrent_streams : Int + mut initial_window_size : Int + mut max_frame_size : Int + mut max_header_list_size : Int +} + +// Settings identifiers + +///| +let settings_header_table_size : Int = 0x1 + +///| +let settings_enable_push : Int = 0x2 + +///| +let settings_max_concurrent_streams : Int = 0x3 + +///| +let settings_initial_window_size : Int = 0x4 + +///| +let settings_max_frame_size : Int = 0x5 + +///| +let settings_max_header_list_size : Int = 0x6 + +///| +/// Create default HTTP/2 settings. +fn H2Settings::default() -> H2Settings { + { + header_table_size: default_header_table_size, + enable_push: true, + max_concurrent_streams: 100, + initial_window_size: default_initial_window_size, + max_frame_size: default_max_frame_size, + max_header_list_size: default_max_header_list_size, + } +} + +///| +/// Encode settings into payload bytes for a SETTINGS frame. +fn H2Settings::encode(self : H2Settings) -> Bytes { + let buf = @buffer.new() + fn write_setting(id : Int, value : Int) { + buf.write_int16_be(Int16::from_int(id)) + buf.write_int_be(value) + } + write_setting(settings_header_table_size, self.header_table_size) + write_setting(settings_enable_push, if self.enable_push { 1 } else { 0 }) + write_setting(settings_max_concurrent_streams, self.max_concurrent_streams) + write_setting(settings_initial_window_size, self.initial_window_size) + write_setting(settings_max_frame_size, self.max_frame_size) + write_setting(settings_max_header_list_size, self.max_header_list_size) + buf.to_bytes() +} + +///| +/// Apply settings from a SETTINGS frame payload. +fn H2Settings::apply(self : H2Settings, payload : Bytes) -> Unit raise { + if payload.length() % 6 != 0 { + raise H2Error(FrameSizeError, "SETTINGS payload not a multiple of 6") + } + let mut pos = 0 + while pos + 6 <= payload.length() { + let id = (payload[pos].to_int() << 8) | payload[pos + 1].to_int() + let value = (payload[pos + 2].to_int() << 24) | + (payload[pos + 3].to_int() << 16) | + (payload[pos + 4].to_int() << 8) | + payload[pos + 5].to_int() + match id { + _ if id == settings_header_table_size => self.header_table_size = value + _ if id == settings_enable_push => self.enable_push = value != 0 + _ if id == settings_max_concurrent_streams => + self.max_concurrent_streams = value + _ if id == settings_initial_window_size => { + if value > 0x7fffffff { + raise H2Error(FlowControlError, "INITIAL_WINDOW_SIZE exceeds maximum") + } + self.initial_window_size = value + } + _ if id == settings_max_frame_size => { + if value < 16384 || value > 16777215 { + raise H2Error(ProtocolError, "MAX_FRAME_SIZE out of range") + } + self.max_frame_size = value + } + _ if id == settings_max_header_list_size => + self.max_header_list_size = value + _ => () // Unknown settings identifiers MUST be ignored (RFC 7540 Section 6.5.2) + } + pos = pos + 6 + } +} diff --git a/src/http2/stream.mbt b/src/http2/stream.mbt new file mode 100644 index 00000000..2b0793f0 --- /dev/null +++ b/src/http2/stream.mbt @@ -0,0 +1,229 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// A single HTTP/2 stream. +pub struct H2Stream { + id : Int + mut state : StreamState + send_window : FlowWindow + mut recv_window : Int + inbound_data : @aqueue.Queue[Bytes] + mut headers : Map[String, String]? + headers_ready : @cond_var.Cond + mut trailers : Map[String, String]? + trailers_ready : @cond_var.Cond + mut error : H2ErrorCode? + conn : H2Connection + read_buf : @io.ReaderBuffer +} + +///| +/// Create a new stream. +fn H2Stream::new( + id : Int, + state : StreamState, + initial_send_window : Int, + initial_recv_window : Int, + conn : H2Connection, +) -> H2Stream { + { + id, + state, + send_window: FlowWindow::new(initial_send_window), + recv_window: initial_recv_window, + inbound_data: @aqueue.Queue::new(kind=Blocking(64)), + headers: None, + headers_ready: @cond_var.Cond::new(), + trailers: None, + trailers_ready: @cond_var.Cond::new(), + error: None, + conn, + read_buf: @io.ReaderBuffer::new(), + } +} + +///| +/// Get the stream ID. +pub fn H2Stream::stream_id(self : H2Stream) -> Int { + self.id +} + +///| +/// Get the stream error, if any. +pub fn H2Stream::get_error(self : H2Stream) -> H2ErrorCode? { + self.error +} + +///| +/// Wait for response headers to be available. +pub async fn H2Stream::wait_headers(self : H2Stream) -> Map[String, String] { + while self.headers is None { + if self.error is Some(code) { + raise H2Error(code, "stream \{self.id} error") + } + self.headers_ready.wait() + } + // Check for error again after waking + if self.error is Some(code) { + raise H2Error(code, "stream \{self.id} error") + } + self.headers.unwrap() +} + +///| +/// Wait for trailers to be available. +/// Trailers are HEADERS frames received after DATA frames (with END_STREAM). +/// Returns `None` if the stream ended without trailers. +pub async fn H2Stream::wait_trailers(self : H2Stream) -> Map[String, String]? { + // Wait until the remote side is done sending (HalfClosedRemote or Closed) + while self.trailers is None { + if self.error is Some(code) { + raise H2Error(code, "stream \{self.id} error") + } + if self.state is (HalfClosedRemote | Closed) { + return None + } + self.trailers_ready.wait() + } + if self.error is Some(code) { + raise H2Error(code, "stream \{self.id} error") + } + self.trailers +} + +///| +/// Send trailing headers on this stream with END_STREAM. +/// Must be called after all DATA frames have been sent. +pub async fn H2Stream::send_trailers( + self : H2Stream, + trailers : Map[String, String], +) -> Unit { + let header_list : Array[(String, String)] = [] + for name, value in trailers { + header_list.push((name, value)) + } + self.conn.send_headers(self.id, header_list, end_stream=true) + self.state = match self.state { + Open => HalfClosedLocal + HalfClosedRemote => Closed + _ => self.state + } +} + +///| +/// Read from the stream's inbound data queue. +/// Implements the Reader interface for consuming DATA payloads. +/// Sends WINDOW_UPDATE when recv window drops below threshold. +pub impl @io.Reader for H2Stream with _direct_read(self, buf, offset~, max_len~) { + // Check for errors, but allow draining buffered data first + if self.error is Some(code) { + // If there's still buffered data, serve it before erroring + match self.inbound_data.try_get() { + Some(data) if data.length() > 0 => { + let n = @cmp.minimum(data.length(), max_len) + buf.blit_from_bytes(offset, data, 0, n) + if n < data.length() { + self.inbound_data.put(data[n:].to_bytes()) catch { + _ => () + } + } + return n + } + _ => () + } + // NoError means the stream was gracefully reset — treat as EOF + if code is NoError { + return 0 + } + raise H2Error(code, "stream \{self.id} error") + } + let data = if self.state is (HalfClosedRemote | Closed) { + match self.inbound_data.try_get() { + Some(data) => data + None => return 0 + } + } else { + self.inbound_data.get() catch { + _ => return 0 + } + } + if data.length() == 0 { + return 0 + } + let n = @cmp.minimum(data.length(), max_len) + buf.blit_from_bytes(offset, data, 0, n) + // If we didn't consume all the data, put remainder back + if n < data.length() { + self.inbound_data.put(data[n:].to_bytes()) catch { + _ => () + } + } + // Send WINDOW_UPDATE if recv window drops below half of initial + self.recv_window = self.recv_window - n + let threshold = default_initial_window_size / 2 + if self.recv_window < threshold { + let increment = default_initial_window_size - self.recv_window + self.recv_window = self.recv_window + increment + // Send stream-level WINDOW_UPDATE + self.conn.send_window_update(self.id, increment) catch { + _ => () + } + } + n +} + +///| +pub impl @io.Reader for H2Stream with _get_internal_buffer(self) { + self.read_buf +} + +///| +/// Send data on this stream, respecting flow control. +/// Fragments into max_frame_size chunks. +pub async fn H2Stream::send_data( + self : H2Stream, + data : Bytes, + end_stream? : Bool = false, +) -> Unit { + let mut offset = 0 + let remaining = data.length() + while offset < remaining { + let chunk_size = @cmp.minimum( + remaining - offset, + self.conn.remote_settings.max_frame_size, + ) + // Consume from both stream-level and connection-level flow windows + let allowed_stream = self.send_window.consume(chunk_size) + let allowed_conn = self.conn.send_window.consume(allowed_stream) + let send_size = allowed_conn + let is_last = offset + send_size >= remaining && end_stream + let flags = if is_last { flag_end_stream } else { b'\x00' } + self.conn.locked_write_frame( + Data, + flags, + self.id, + data[offset:offset + send_size].to_bytes(), + ) + offset = offset + send_size + } + if remaining == 0 && end_stream { + self.conn.locked_write_frame( + Data, + flag_end_stream, + self.id, + Bytes::make(0, b'\x00'), + ) + } +} diff --git a/src/http2/types.mbt b/src/http2/types.mbt new file mode 100644 index 00000000..72f0b7cb --- /dev/null +++ b/src/http2/types.mbt @@ -0,0 +1,198 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// HTTP/2 frame types per RFC 7540 Section 6. +pub(all) enum FrameType { + Data + Headers + Priority + RstStream + Settings + PushPromise + Ping + GoAway + WindowUpdate + Continuation + Unknown(Byte) +} derive(Show, Eq) + +///| +fn FrameType::from_byte(b : Byte) -> FrameType { + match b.to_int() { + 0x0 => Data + 0x1 => Headers + 0x2 => Priority + 0x3 => RstStream + 0x4 => Settings + 0x5 => PushPromise + 0x6 => Ping + 0x7 => GoAway + 0x8 => WindowUpdate + 0x9 => Continuation + _ => Unknown(b) + } +} + +///| +fn FrameType::to_byte(self : FrameType) -> Byte { + match self { + Data => b'\x00' + Headers => b'\x01' + Priority => b'\x02' + RstStream => b'\x03' + Settings => b'\x04' + PushPromise => b'\x05' + Ping => b'\x06' + GoAway => b'\x07' + WindowUpdate => b'\x08' + Continuation => b'\x09' + Unknown(b) => b + } +} + +///| +/// HTTP/2 error codes per RFC 7540 Section 7. +pub(all) enum H2ErrorCode { + NoError + ProtocolError + InternalError + FlowControlError + SettingsTimeout + StreamClosed + FrameSizeError + RefusedStream + Cancel + CompressionError + ConnectError + EnhanceYourCalm + InadequateSecurity + Http11Required +} derive(Show, Eq, ToJson) + +///| +fn H2ErrorCode::from_int(code : Int) -> H2ErrorCode { + match code { + 0x0 => NoError + 0x1 => ProtocolError + 0x2 => InternalError + 0x3 => FlowControlError + 0x4 => SettingsTimeout + 0x5 => StreamClosed + 0x6 => FrameSizeError + 0x7 => RefusedStream + 0x8 => Cancel + 0x9 => CompressionError + 0xa => ConnectError + 0xb => EnhanceYourCalm + 0xc => InadequateSecurity + 0xd => Http11Required + _ => ProtocolError + } +} + +///| +fn H2ErrorCode::to_int(self : H2ErrorCode) -> Int { + match self { + NoError => 0x0 + ProtocolError => 0x1 + InternalError => 0x2 + FlowControlError => 0x3 + SettingsTimeout => 0x4 + StreamClosed => 0x5 + FrameSizeError => 0x6 + RefusedStream => 0x7 + Cancel => 0x8 + CompressionError => 0x9 + ConnectError => 0xa + EnhanceYourCalm => 0xb + InadequateSecurity => 0xc + Http11Required => 0xd + } +} + +///| +/// HTTP/2 stream states per RFC 7540 Section 5.1. +pub(all) enum StreamState { + Idle + ReservedLocal + ReservedRemote + Open + HalfClosedLocal + HalfClosedRemote + Closed +} derive(Show, Eq) + +///| +/// HTTP/2 frame header (9 bytes on the wire). +pub(all) struct FrameHeader { + length : Int + frame_type : FrameType + flags : Byte + stream_id : Int +} derive(Show) + +///| +/// HTTP/2 frame = header + payload. +pub(all) struct Frame { + header : FrameHeader + payload : Bytes +} derive(Show) + +// Frame flag constants + +///| +let flag_end_stream : Byte = b'\x01' + +///| +let flag_ack : Byte = b'\x01' + +///| +let flag_end_headers : Byte = b'\x04' + +///| +let flag_padded : Byte = b'\x08' + +///| +let flag_priority : Byte = b'\x20' + +///| +/// HTTP/2 client connection preface magic string. +let client_preface : Bytes = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + +// Default settings constants + +///| +let default_header_table_size : Int = 4096 + +///| +let default_initial_window_size : Int = 65535 + +///| +let default_max_frame_size : Int = 16384 + +///| +let default_max_header_list_size : Int = 0x7fffffff + +///| +/// HTTP/2 connection/stream error. +pub suberror H2Error { + H2Error(H2ErrorCode, String) +} derive(Show, ToJson) + +///| +/// Check if a flag is set. +fn has_flag(flags : Byte, flag : Byte) -> Bool { + (flags.to_int() & flag.to_int()) != 0 +} diff --git a/src/http2/unimplemented.mbt b/src/http2/unimplemented.mbt new file mode 100644 index 00000000..494c35e3 --- /dev/null +++ b/src/http2/unimplemented.mbt @@ -0,0 +1,25 @@ +// Copyright 2025 International Digital Economy Academy +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +///| +/// This package currently does not support JavaScript or WebAssembly backends +#cfg(not(target="native")) +#coverage.skip +pub let unimplemented : Unit = { + ignore(@io.pipe) + ignore(@hpack.Encoder::encode) + ignore(@semaphore.Semaphore::release) + ignore(@cond_var.Cond::new) + ignore(@aqueue.Kind::Unbounded) +} diff --git a/src/tls/moon.pkg b/src/tls/moon.pkg index 75b7fdb5..30af0dec 100644 --- a/src/tls/moon.pkg +++ b/src/tls/moon.pkg @@ -1,5 +1,6 @@ import { "moonbitlang/core/encoding/utf8", + "moonbitlang/core/buffer", "moonbitlang/core/cmp", "moonbitlang/async/io", "moonbitlang/async/semaphore", diff --git a/src/tls/openssl.c b/src/tls/openssl.c index 2a6d2a02..62b34b2a 100644 --- a/src/tls/openssl.c +++ b/src/tls/openssl.c @@ -72,7 +72,10 @@ typedef struct SSL_METHOD SSL_METHOD; IMPORT_FUNC(unsigned long, ERR_get_error, (void))\ IMPORT_FUNC(char *, ERR_error_string, (unsigned long e, char *buf))\ IMPORT_FUNC(int, RAND_bytes, (unsigned char *buf, int num))\ - IMPORT_FUNC(unsigned char *, SHA1, (const unsigned char *d, size_t n, unsigned char *md)) + IMPORT_FUNC(unsigned char *, SHA1, (const unsigned char *d, size_t n, unsigned char *md))\ + IMPORT_FUNC(int, SSL_set_alpn_protos, (SSL *ssl, const unsigned char *protos, unsigned int protos_len))\ + IMPORT_FUNC(void, SSL_get0_alpn_selected, (const SSL *ssl, const unsigned char **data, unsigned int *len))\ + IMPORT_FUNC(void, SSL_CTX_set_alpn_select_cb, (SSL_CTX *ctx, int (*cb)(SSL *ssl, const unsigned char **out, unsigned char *outlen, const unsigned char *in, unsigned int inlen, void *arg), void *arg)) #define IMPORT_FUNC(ret, name, params) static ret (*name) params; IMPORTED_OPEN_SSL_FUNCTIONS @@ -269,4 +272,49 @@ void moonbitlang_async_tls_SHA1( SHA1(src, len, dst); } +int moonbitlang_async_tls_ssl_set_alpn_protos(SSL *ssl, const char *protos, int len) { + return SSL_set_alpn_protos(ssl, (const unsigned char *)protos, (unsigned int)len); +} + +int moonbitlang_async_tls_ssl_get_alpn_selected(SSL *ssl, char *out_buf) { + const unsigned char *data = 0; + unsigned int len = 0; + SSL_get0_alpn_selected(ssl, &data, &len); + if (data && len > 0) { + memcpy(out_buf, data, len); + } + return (int)len; +} + +static int alpn_select_cb(SSL *ssl, const unsigned char **out, unsigned char *outlen, + const unsigned char *in, unsigned int inlen, void *arg) { + // arg points to the protocol list (wire format) that the server supports + // We iterate through client's list (in) and try to match with server's list (arg) + const unsigned char *server_protos = (const unsigned char *)arg; + // We stored server proto list length in the first 4 bytes + unsigned int server_protos_len = *(unsigned int *)server_protos; + server_protos += 4; + + const unsigned char *client = in; + while (client < in + inlen) { + unsigned char client_len = *client++; + const unsigned char *server = server_protos; + while (server < server_protos + server_protos_len) { + unsigned char server_len = *server++; + if (client_len == server_len && memcmp(client, server, client_len) == 0) { + *out = client; + *outlen = client_len; + return 0; // SSL_TLSEXT_ERR_OK + } + server += server_len; + } + client += client_len; + } + return 3; // SSL_TLSEXT_ERR_NOACK +} + +void moonbitlang_async_tls_ssl_ctx_set_alpn_select(SSL_CTX *ctx, void *protos) { + SSL_CTX_set_alpn_select_cb(ctx, alpn_select_cb, protos); +} + #endif diff --git a/src/tls/openssl.mbt b/src/tls/openssl.mbt index 54da696c..a582a1a0 100644 --- a/src/tls/openssl.mbt +++ b/src/tls/openssl.mbt @@ -138,6 +138,25 @@ extern "C" fn SSL::free(self : SSL) = "moonbitlang_async_tls_ssl_free" #cfg(not(platform="windows")) extern "C" fn SSL::get_error(self : SSL, ret : Int) -> Int = "moonbitlang_async_tls_ssl_get_error" +///| +#cfg(not(platform="windows")) +#borrow(protos) +extern "C" fn SSL::set_alpn_protos( + self : SSL, + protos : Bytes, + len : Int, +) -> Int = "moonbitlang_async_tls_ssl_set_alpn_protos" + +///| +#cfg(not(platform="windows")) +#borrow(buf) +extern "C" fn SSL::get_alpn_selected(self : SSL, buf : FixedArray[Byte]) -> Int = "moonbitlang_async_tls_ssl_get_alpn_selected" + +///| +#cfg(not(platform="windows")) +#owned(protos) +extern "C" fn SSL_CTX::set_alpn_select_cb(ctx : SSL_CTX, protos : Bytes) = "moonbitlang_async_tls_ssl_ctx_set_alpn_select" + ///| #cfg(not(platform="windows")) const SSL_ERROR_SSL = 1 @@ -171,6 +190,31 @@ fn err_get_error() -> String { @bytes_util.ascii_to_string(buf[:len]) } +///| +#cfg(not(platform="windows")) +fn encode_alpn_protos(protocols : Array[String]) -> Bytes { + let buf = @buffer.new() + for proto in protocols { + let bytes = @utf8.encode(proto) + buf.write_byte(bytes.length().to_byte()) + buf.write_bytes(bytes) + } + buf.contents() +} + +// The server-side ALPN callback (`alpn_select_cb` in openssl.c) reads the +// first 4 bytes as a little-endian uint32 length, then the wire bytes follow. + +///| +#cfg(not(platform="windows")) +fn encode_alpn_protos_for_cb(protocols : Array[String]) -> Bytes { + let wire = encode_alpn_protos(protocols) + let buf = @buffer.new() + buf.write_int_le(wire.length()) + buf.write_bytes(wire) + buf.contents() +} + ///| #cfg(not(platform="windows")) #external @@ -284,10 +328,22 @@ async fn Tls::handle_error(self : Tls, err : Int) -> Unit { SSL_ERROR_WANT_READ => { self.transport.flush_write() self.transport.read_more() + // If tls.close() was called while we were suspended in read_more(), + // ssl has been freed — raise an error to prevent use-after-free. + // Normal TCP EOF from the peer does not set self.closed; that path + // is handled by the existing `transport.state is Closed` check + // inside _direct_read after ssl.read() returns. + if self.closed { + raise ConnectionClosed + } } SSL_ERROR_WANT_WRITE => { self.transport.flush_write() self.transport.write_buf.enlarge_to(1) + // Same protection for the write path. + if self.closed { + raise ConnectionClosed + } } SSL_ERROR_SSL | SSL_ERROR_SYSCALL => raise TlsError(err_get_error()) SSL_ERROR_ZERO_RETURN => raise ConnectionClosed @@ -314,6 +370,7 @@ pub async fn[R : @io.Reader, W : @io.Writer] Tls::client_from_pair( verify? : Bool = true, host? : String, sni? : Bool = true, + alpn? : Array[String], ) -> Tls { let self = Tls::from_pair(client_ctx, r, w, host~) try { @@ -329,6 +386,12 @@ pub async fn[R : @io.Reader, W : @io.Writer] Tls::client_from_pair( } } } + if alpn is Some(protocols) { + let wire = encode_alpn_protos(protocols) + if self.ssl.set_alpn_protos(wire, wire.length()) != 0 { + raise TlsError("failed to set ALPN protocols") + } + } while self.ssl.connect() is ret && ret <= 0 { self.handle_error(self.ssl.get_error(ret)) } @@ -362,9 +425,14 @@ pub async fn[R : @io.Reader, W : @io.Writer] Tls::server_from_pair( private_key_type~ : X509FileType, certificate_file~ : String, certificate_type~ : X509FileType, + alpn? : Array[String], ) -> Tls { let private_key_file = @utf8.encode(private_key_file) let certificate_file = @utf8.encode(certificate_file) + if alpn is Some(protocols) { + let encoded = encode_alpn_protos_for_cb(protocols) + SSL_CTX::set_alpn_select_cb(server_ctx, encoded) + } let self = Tls::from_pair(server_ctx, r, w, host=None) try { if self.ssl.use_certificate_file(certificate_file, certificate_type) != 1 { @@ -405,6 +473,7 @@ pub async fn[Inner : @io.Reader + @io.Writer] Tls::server( private_key_type~ : X509FileType, certificate_file~ : String, certificate_type~ : X509FileType, + alpn? : Array[String], ) -> Tls { Tls::server_from_pair( inner, @@ -413,12 +482,16 @@ pub async fn[Inner : @io.Reader + @io.Writer] Tls::server( private_key_type~, certificate_file~, certificate_type~, + alpn?, ) } ///| #cfg(not(platform="windows")) pub impl @io.Reader for Tls with _direct_read(self, buf, offset~, max_len~) { + if self.closed { + return 0 + } let n = while self.ssl.read(buf, offset, max_len) is ret { if ret > 0 { break ret @@ -521,6 +594,18 @@ pub async fn Tls::shutdown(self : Tls) -> Unit { } } +///| +#cfg(not(platform="windows")) +pub fn Tls::selected_alpn(self : Tls) -> String? { + let buf = FixedArray::make(256, b'\x00') + let len = self.ssl.get_alpn_selected(buf) + if len > 0 { + Some(@bytes_util.ascii_to_string(buf.unsafe_reinterpret_as_bytes()[:len])) + } else { + None + } +} + ///| // mute unused warning #cfg(not(platform="windows")) diff --git a/src/tls/pkg.generated.mbti b/src/tls/pkg.generated.mbti index c56ca669..fe1d71f5 100644 --- a/src/tls/pkg.generated.mbti +++ b/src/tls/pkg.generated.mbti @@ -24,11 +24,12 @@ pub impl ToJson for TlsError // Types and methods #alias(TLS, deprecated) type Tls -pub async fn[Inner : @io.Reader + @io.Writer] Tls::client(Inner, verify? : Bool, host? : String, sni? : Bool) -> Self -pub async fn[R : @io.Reader, W : @io.Writer] Tls::client_from_pair(R, W, verify? : Bool, host? : String, sni? : Bool) -> Self +pub async fn[Inner : @io.Reader + @io.Writer] Tls::client(Inner, verify? : Bool, host? : String, sni? : Bool, alpn? : Array[String]) -> Self +pub async fn[R : @io.Reader, W : @io.Writer] Tls::client_from_pair(R, W, verify? : Bool, host? : String, sni? : Bool, alpn? : Array[String]) -> Self pub fn Tls::close(Self) -> Unit -pub async fn[Inner : @io.Reader + @io.Writer] Tls::server(Inner, private_key_file~ : String, private_key_type~ : X509FileType, certificate_file~ : String, certificate_type~ : X509FileType) -> Self -pub async fn[R : @io.Reader, W : @io.Writer] Tls::server_from_pair(R, W, private_key_file~ : String, private_key_type~ : X509FileType, certificate_file~ : String, certificate_type~ : X509FileType) -> Self +pub fn Tls::selected_alpn(Self) -> String? +pub async fn[Inner : @io.Reader + @io.Writer] Tls::server(Inner, private_key_file~ : String, private_key_type~ : X509FileType, certificate_file~ : String, certificate_type~ : X509FileType, alpn? : Array[String]) -> Self +pub async fn[R : @io.Reader, W : @io.Writer] Tls::server_from_pair(R, W, private_key_file~ : String, private_key_type~ : X509FileType, certificate_file~ : String, certificate_type~ : X509FileType, alpn? : Array[String]) -> Self pub async fn Tls::shutdown(Self) -> Unit pub impl @io.Reader for Tls pub impl @io.Writer for Tls diff --git a/src/tls/schannel.c b/src/tls/schannel.c index 37b43bc1..c1dbe663 100644 --- a/src/tls/schannel.c +++ b/src/tls/schannel.c @@ -47,6 +47,8 @@ struct Context { int32_t bytes_to_write; int32_t msg_trailer; SecPkgContext_StreamSizes stream_sizes; + char *alpn_protos; // ALPN wire-format bytes (length-prefixed list) + int32_t alpn_protos_len; }; MOONBIT_FFI_EXPORT @@ -59,6 +61,8 @@ void moonbitlang_async_schannel_free(struct Context *ctx) { case Uninitialized: break; } + if (ctx->alpn_protos) + free(ctx->alpn_protos); free(ctx); } @@ -68,6 +72,8 @@ struct Context *moonbitlang_async_schannel_new() { result->state = Uninitialized; result->context.dwUpper = 0; result->context.dwLower = 0; + result->alpn_protos = NULL; + result->alpn_protos_len = 0; return result; } @@ -198,6 +204,50 @@ int32_t moonbitlang_async_schannel_init_server(struct Context *ctx) { } } +MOONBIT_FFI_EXPORT +void moonbitlang_async_schannel_set_alpn( + struct Context *ctx, + const char *wire, + int32_t len +) { + ctx->alpn_protos = (char*)malloc(len); + memcpy(ctx->alpn_protos, wire, len); + ctx->alpn_protos_len = len; +} + +// Allocates and returns a SEC_APPLICATION_PROTOCOLS blob wrapping `wire` in ALPN wire format. +// Caller must free() the returned pointer. +static void *build_sec_app_protos(const char *wire, int32_t wire_len, DWORD *out_size) { + DWORD list_size = (DWORD)(sizeof(SEC_APPLICATION_PROTOCOL_NEGOTIATION_EXT) + + sizeof(unsigned short) + + (DWORD)wire_len); + DWORD total = sizeof(DWORD) + list_size; + char *buf = (char*)malloc(total); + *(DWORD*)buf = list_size; + SEC_APPLICATION_PROTOCOL_LIST *lst = (SEC_APPLICATION_PROTOCOL_LIST*)(buf + sizeof(DWORD)); + lst->ProtoNegoExt = SecApplicationProtocolNegotiationExt_ALPN; + lst->ProtocolListSize = (unsigned short)wire_len; + memcpy(lst->ProtocolList, wire, wire_len); + *out_size = total; + return buf; +} + +MOONBIT_FFI_EXPORT +int32_t moonbitlang_async_schannel_get_alpn_selected( + struct Context *ctx, + char *out_buf +) { + SecPkgContext_ApplicationProtocol app_proto; + SECURITY_STATUS ret = QueryContextAttributes( + &ctx->context, SECPKG_ATTR_APPLICATION_PROTOCOL, &app_proto); + if (ret != SEC_E_OK + || app_proto.ProtoNegoStatus != SecApplicationProtocolNegotiationStatus_Success + || app_proto.ProtoNegoExt != SecApplicationProtocolNegotiationExt_ALPN) + return 0; + memcpy(out_buf, app_proto.ProtocolId, app_proto.ProtocolIdSize); + return (int32_t)app_proto.ProtocolIdSize; +} + enum TlsState { Completed = 0, WantRead = 1, @@ -219,7 +269,13 @@ int32_t moonbitlang_async_schannel_connect( int32_t out_buffer_len ) { SecBufferDesc input_desc, output_desc; - SecBuffer input[2], output[1]; + SecBuffer input[3], output[1]; + + // Build ALPN extension buffer if protocols were provided. + void *sec_app_protos = NULL; + DWORD sec_app_protos_size = 0; + if (ctx->alpn_protos) + sec_app_protos = build_sec_app_protos(ctx->alpn_protos, ctx->alpn_protos_len, &sec_app_protos_size); if (ctx->state == ContextInitialized) { input[0].BufferType = SECBUFFER_TOKEN; @@ -233,6 +289,22 @@ int32_t moonbitlang_async_schannel_connect( input_desc.ulVersion = SECBUFFER_VERSION; input_desc.cBuffers = 2; input_desc.pBuffers = input; + + if (sec_app_protos) { + input[2].BufferType = SECBUFFER_APPLICATION_PROTOCOLS; + input[2].cbBuffer = sec_app_protos_size; + input[2].pvBuffer = sec_app_protos; + input_desc.cBuffers = 3; + } + } else if (sec_app_protos) { + // First call: pInput is normally NULL, but we need it for ALPN. + input[0].BufferType = SECBUFFER_APPLICATION_PROTOCOLS; + input[0].cbBuffer = sec_app_protos_size; + input[0].pvBuffer = sec_app_protos; + + input_desc.ulVersion = SECBUFFER_VERSION; + input_desc.cBuffers = 1; + input_desc.pBuffers = input; } output[0].BufferType = SECBUFFER_TOKEN; @@ -245,6 +317,10 @@ int32_t moonbitlang_async_schannel_connect( ctx->bytes_read = ctx->bytes_to_write = 0; + // Determine pInput: NULL on first call unless ALPN is set. + SecBufferDesc *p_input = + (ctx->state == ContextInitialized || sec_app_protos) ? &input_desc : NULL; + int32_t ret = InitializeSecurityContextW( &ctx->handle, ctx->state == ContextInitialized ? &ctx->context : NULL, @@ -252,7 +328,7 @@ int32_t moonbitlang_async_schannel_connect( ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY, // `fContextReq` 0, // `Reserved1` 0, // `TargetDataRep`, unused by schannel - ctx->state == ContextInitialized ? &input_desc : NULL, // `pInput` + p_input, // `pInput` 0, // `Reserved2` &ctx->context, // `phNewContext` &output_desc, // `pOutput` @@ -260,8 +336,11 @@ int32_t moonbitlang_async_schannel_connect( NULL // `ptsExpiry` ); + if (sec_app_protos) + free(sec_app_protos); + ctx->bytes_read = in_buffer_len; - if (input[1].BufferType == SECBUFFER_EXTRA) + if (ctx->state == ContextInitialized && input[1].BufferType == SECBUFFER_EXTRA) ctx->bytes_read -= input[1].cbBuffer; switch (ret) { @@ -304,7 +383,7 @@ int32_t moonbitlang_async_schannel_accept( int32_t out_buffer_len ) { SecBufferDesc input_desc, output_desc; - SecBuffer input[2], output[1]; + SecBuffer input[3], output[1]; input[0].BufferType = SECBUFFER_TOKEN; input[0].cbBuffer = in_buffer_len; @@ -318,6 +397,17 @@ int32_t moonbitlang_async_schannel_accept( input_desc.cBuffers = 2; input_desc.pBuffers = input; + // Build ALPN extension buffer if protocols were provided. + void *sec_app_protos = NULL; + DWORD sec_app_protos_size = 0; + if (ctx->alpn_protos) { + sec_app_protos = build_sec_app_protos(ctx->alpn_protos, ctx->alpn_protos_len, &sec_app_protos_size); + input[2].BufferType = SECBUFFER_APPLICATION_PROTOCOLS; + input[2].cbBuffer = sec_app_protos_size; + input[2].pvBuffer = sec_app_protos; + input_desc.cBuffers = 3; + } + output[0].BufferType = SECBUFFER_TOKEN; output[0].cbBuffer = out_buffer_len; output[0].pvBuffer = out_buffer + out_buffer_offset; @@ -340,6 +430,9 @@ int32_t moonbitlang_async_schannel_accept( NULL // `ptsExpiry` ); + if (sec_app_protos) + free(sec_app_protos); + ctx->bytes_read = in_buffer_len; if (input[1].BufferType == SECBUFFER_EXTRA) ctx->bytes_read -= input[1].cbBuffer; diff --git a/src/tls/schannel.mbt b/src/tls/schannel.mbt index c5e6ecea..9246ad0e 100644 --- a/src/tls/schannel.mbt +++ b/src/tls/schannel.mbt @@ -160,6 +160,18 @@ async fn Tls::connect(self : Tls) -> Unit { } } +///| +#cfg(platform="windows") +fn encode_alpn_protos(protocols : Array[String]) -> Bytes { + let buf = @buffer.new() + for proto in protocols { + let bytes = @utf8.encode(proto) + buf.write_byte(bytes.length().to_byte()) + buf.write_bytes(bytes) + } + buf.contents() +} + ///| #cfg(platform="windows") pub async fn[R : @io.Reader, W : @io.Writer] Tls::client_from_pair( @@ -168,6 +180,7 @@ pub async fn[R : @io.Reader, W : @io.Writer] Tls::client_from_pair( verify? : Bool = true, host? : String, sni? : Bool = true, + alpn? : Array[String], ) -> Tls { ignore(sni) let context = Schannel::new() @@ -175,6 +188,10 @@ pub async fn[R : @io.Reader, W : @io.Writer] Tls::client_from_pair( context.free() raise TlsError(@os_error.errno_to_string(err)) } + if alpn is Some(protocols) { + let wire = encode_alpn_protos(protocols) + context.set_alpn(wire, wire.length()) + } let transport = Transport::new(reader, writer) let tls = { context, @@ -234,12 +251,17 @@ async fn Tls::accept(self : Tls) -> Unit { pub async fn[R : @io.Reader, W : @io.Writer] Tls::server_from_pair( reader : R, writer : W, + alpn? : Array[String], ) -> Tls { let context = Schannel::new() if context.init_server() is err && err != 0 { context.free() raise TlsError(@os_error.errno_to_string(err)) } + if alpn is Some(protocols) { + let wire = encode_alpn_protos(protocols) + context.set_alpn(wire, wire.length()) + } let transport = Transport::new(reader, writer) let tls = { context, @@ -261,6 +283,27 @@ pub async fn[R : @io.Reader, W : @io.Writer] Tls::server_from_pair( tls } +///| +/// WARNING: this API is currently for testing only. +/// On Windows, `private_key_file`, `private_key_type`, `certificate_file`, +/// and `certificate_type` are ignored; SChannel uses the Windows certificate store. +#internal(internal, "do not use, for internal testing only") +#cfg(platform="windows") +pub async fn[Inner : @io.Reader + @io.Writer] Tls::server( + inner : Inner, + private_key_file~ : String, + private_key_type~ : X509FileType, + certificate_file~ : String, + certificate_type~ : X509FileType, + alpn? : Array[String], +) -> Tls { + ignore(private_key_file) + ignore(private_key_type) + ignore(certificate_file) + ignore(certificate_type) + Tls::server_from_pair(inner, inner, alpn?) +} + ///| #cfg(platform="windows") pub impl @io.Reader for Tls with _get_internal_buffer(self) { @@ -289,6 +332,9 @@ pub impl @io.Reader for Tls with _direct_read(self, buf, offset~, max_len~) { } let read_buf = self.transport.reader._get_internal_buffer().repr() while self.curr_msg_remaining == 0 { + if self.closed { + return 0 + } read_buf.enlarge_to(self.context.record_overhead() + 1) let ret = self.context.read( read_buf.buf, @@ -372,6 +418,23 @@ pub impl @io.Writer for Tls with write_once(self, buf, offset~, len~) { } } +///| +#cfg(platform="windows") +#borrow(ch, wire) +extern "C" fn Schannel::set_alpn( + ch : Schannel, + wire : Bytes, + len : Int, +) -> Unit = "moonbitlang_async_schannel_set_alpn" + +///| +#cfg(platform="windows") +#borrow(ch, buf) +extern "C" fn Schannel::get_alpn_selected( + ch : Schannel, + buf : FixedArray[Byte], +) -> Int = "moonbitlang_async_schannel_get_alpn_selected" + ///| #cfg(platform="windows") #borrow(ch) @@ -398,10 +461,18 @@ pub async fn Tls::shutdown(self : Tls) -> Unit { } ///| -// mute unused warning #cfg(platform="windows") -let _unused : Unit = { - ignore(@c_buffer.Buffer::strlen) - ignore(@bytes_util.ascii_to_string) - ignore(@utf8.encode("")) +pub fn Tls::selected_alpn(self : Tls) -> String? { + let buf = FixedArray::make(256, b'\x00') + let len = self.context.get_alpn_selected(buf) + if len > 0 { + Some(@bytes_util.ascii_to_string(buf.unsafe_reinterpret_as_bytes()[:len])) + } else { + None + } } + +///| +// mute unused warning +#cfg(platform="windows") +let _unused : Unit = ignore(@c_buffer.Buffer::strlen) diff --git a/src/tls/tls.mbt b/src/tls/tls.mbt index 8a401af6..a79c9372 100644 --- a/src/tls/tls.mbt +++ b/src/tls/tls.mbt @@ -37,6 +37,7 @@ pub async fn[Inner : @io.Reader + @io.Writer] Tls::client( verify? : Bool = true, host? : String, sni? : Bool = true, + alpn? : Array[String], ) -> Tls { - Tls::client_from_pair(inner, inner, verify~, host?, sni~) + Tls::client_from_pair(inner, inner, verify~, host?, sni~, alpn?) } diff --git a/src/tls/tls_test.mbt b/src/tls/tls_test.mbt index 8ae2d1f0..654d82d4 100644 --- a/src/tls/tls_test.mbt +++ b/src/tls/tls_test.mbt @@ -31,6 +31,26 @@ async fn test_server(r : &@io.Reader, w : &@io.Writer) -> @tls.Tls { @tls.Tls::server_from_pair(r, w) } +///| +#cfg(not(platform="windows")) +async fn test_server_alpn(r : &@io.Reader, w : &@io.Writer) -> @tls.Tls { + @tls.Tls::server_from_pair( + r, + w, + private_key_file="test_keys/key.pem", + private_key_type=PEM, + certificate_file="test_keys/cert.pem", + certificate_type=PEM, + alpn=["h2", "http/1.1"], + ) +} + +///| +#cfg(platform="windows") +async fn test_server_alpn(r : &@io.Reader, w : &@io.Writer) -> @tls.Tls { + @tls.Tls::server_from_pair(r, w, alpn=["h2", "http/1.1"]) +} + ///| async test "one way" { let log = StringBuilder::new() @@ -203,3 +223,45 @@ async test "`read` already closed" { }) }) } + +///| +async test "ALPN negotiation" { + @async.with_task_group(root => { + let (client_read_from_server, server_write_to_client) = @pipe.pipe() + let (server_read_from_client, client_write_to_server) = @pipe.pipe() + // server + root.spawn_bg(() => { + defer { + server_read_from_client.close() + server_write_to_client.close() + } + let server = test_server_alpn( + server_read_from_client, server_write_to_client, + ) + defer server.close() + inspect(server.selected_alpn(), content="Some(\"h2\")") + while server.read_some() is Some(_) { + + } nobreak { + server.shutdown() + } + }) + // client + root.spawn_bg(() => { + defer { + client_read_from_server.close() + client_write_to_server.close() + } + let client = @tls.Tls::client_from_pair( + client_read_from_server, + client_write_to_server, + verify=false, + alpn=["h2", "http/1.1"], + ) + defer client.close() + inspect(client.selected_alpn(), content="Some(\"h2\")") + client.shutdown() + let _ = client.read_all() + }) + }) +} From 43326323f615e58c6017fd737aad6e02762fa709 Mon Sep 17 00:00:00 2001 From: Haoxiang Fei Date: Sat, 28 Feb 2026 22:33:01 +0800 Subject: [PATCH 2/2] fix: process IOCP completions for closed sockets to prevent coroutine hang When a socket was closed via IoHandle::close(), the fd was removed from evloop.fds before closesocket() was called. The subsequent IOCP completion for cancelled I/O was silently dropped by poll() because the fd was no longer in the fds map, leaving coroutines in protect_from_cancel(suspend()) waiting forever. --- src/internal/event_loop/event_loop.mbt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/internal/event_loop/event_loop.mbt b/src/internal/event_loop/event_loop.mbt index 42588953..b274ebb8 100644 --- a/src/internal/event_loop/event_loop.mbt +++ b/src/internal/event_loop/event_loop.mbt @@ -265,7 +265,12 @@ fn EventLoop::poll(self : Self) -> Unit raise { let fd = event.fd() if @fd_util.fd_is_valid(fd) { // completed IO operation on a file descriptor - guard self.fds.get(fd) is Some(_) else { } + // Note: we intentionally do NOT guard on self.fds.get(fd) here. + // When a socket is closed (IoHandle::close), the fd is removed from + // self.fds before closesocket() is called. closesocket() posts IOCP + // completions for any pending I/O. We must still process these + // completions to wake coroutines waiting in protect_from_cancel(suspend()), + // otherwise they hang forever. let result = event.io_result() let job_id = result.job_id() if self.jobs.get(job_id) is Some(coro) {