From 3765bfd8c2379b1709277b6fd90e1bd914c6cd47 Mon Sep 17 00:00:00 2001 From: Guest0x0 Date: Wed, 1 Apr 2026 13:46:07 +0800 Subject: [PATCH] make socket `bind` async --- src/http/client_test.mbt | 2 +- src/http/pkg.generated.mbti | 4 +-- src/http/server.mbt | 8 ++--- src/internal/event_loop/network.mbt | 10 ++++++ src/internal/event_loop/pkg.generated.mbti | 1 + src/internal/event_loop/thread_pool.c | 36 ++++++++++++++++++++++ src/internal/event_loop/thread_pool.mbt | 4 +++ src/socket/ffi.mbt | 4 --- src/socket/pkg.generated.mbti | 8 ++--- src/socket/reuse_addr_test.mbt | 2 +- src/socket/socket.c | 5 --- src/socket/tcp.mbt | 23 +++++++------- src/socket/udp.mbt | 21 ++++++------- 13 files changed, 84 insertions(+), 44 deletions(-) diff --git a/src/http/client_test.mbt b/src/http/client_test.mbt index bf41068e..c0c8675d 100644 --- a/src/http/client_test.mbt +++ b/src/http/client_test.mbt @@ -14,7 +14,7 @@ ///| #cfg(target="native") -fn test_server(group : @async.TaskGroup[Unit], log : &Logger) -> Int raise { +async fn test_server(group : @async.TaskGroup[Unit], log : &Logger) -> Int { let server = @http.Server(@socket.Addr::parse("127.0.0.1:0")) group.spawn_bg(no_wait=true, () => { server.run_forever((request, body, conn) => { diff --git a/src/http/pkg.generated.mbti b/src/http/pkg.generated.mbti index e0da13b8..4ff6b61b 100644 --- a/src/http/pkg.generated.mbti +++ b/src/http/pkg.generated.mbti @@ -88,13 +88,13 @@ pub struct Server { addr : @socket.Addr // private fields - fn new(@socket.Addr, dual_stack? : Bool, reuse_addr? : Bool, headers? : Map[String, String]) -> Server raise + async fn new(@socket.Addr, dual_stack? : Bool, reuse_addr? : Bool, headers? : Map[String, String]) -> Server } pub async fn Server::accept(Self) -> (ServerConnection, @socket.Addr) #deprecated pub fn Server::addr(Self) -> @socket.Addr pub fn Server::close(Self) -> Unit -pub fn Server::new(@socket.Addr, dual_stack? : Bool, reuse_addr? : Bool, headers? : Map[String, String]) -> Self raise +pub async fn Server::new(@socket.Addr, dual_stack? : Bool, reuse_addr? : Bool, headers? : Map[String, String]) -> Self pub async fn Server::run_forever(Self, async (Request, &@io.Reader, ServerConnection) -> Unit, allow_failure? : Bool, max_connections? : Int) -> Unit type ServerConnection diff --git a/src/http/server.mbt b/src/http/server.mbt index a766ca8c..d9ee02ab 100644 --- a/src/http/server.mbt +++ b/src/http/server.mbt @@ -152,12 +152,12 @@ pub struct Server { priv server : @socket.TcpServer priv headers : Map[String, String] - fn new( + async fn new( addr : @socket.Addr, dual_stack? : Bool, reuse_addr? : Bool, headers? : Map[String, String], - ) -> Server raise + ) -> Server } ///| @@ -168,12 +168,12 @@ pub struct Server { /// /// `headers` can be used to specify common headers /// shared by all responses sent by this server. -pub fn Server::new( +pub async fn Server::new( addr : @socket.Addr, dual_stack? : Bool, reuse_addr? : Bool, headers? : Map[String, String] = {}, -) -> Server raise { +) -> Server { let server = @socket.TcpServer::new(addr, dual_stack?, reuse_addr?) { addr: server.addr, server, headers } } diff --git a/src/internal/event_loop/network.mbt b/src/internal/event_loop/network.mbt index 818b294e..b989689a 100644 --- a/src/internal/event_loop/network.mbt +++ b/src/internal/event_loop/network.mbt @@ -50,3 +50,13 @@ pub async fn getaddrinfo( Ok(job.get_getaddrinfo_result()) } } + +///| +pub async fn IoHandle::bind( + handle : IoHandle, + addr : Bytes, + context~ : String, +) -> Unit { + @coroutine.check_cancellation() + ignore(perform_job_in_worker(Job::bind(handle.fd, addr), context~)) +} diff --git a/src/internal/event_loop/pkg.generated.mbti b/src/internal/event_loop/pkg.generated.mbti index 72a67e0d..5aac8b4a 100644 --- a/src/internal/event_loop/pkg.generated.mbti +++ b/src/internal/event_loop/pkg.generated.mbti @@ -73,6 +73,7 @@ pub fn Directory::close(Self) -> Int type IoHandle pub async fn IoHandle::accept(Self, Bytes, context~ : String) -> Self +pub async fn IoHandle::bind(Self, Bytes, context~ : String) -> Unit pub fn IoHandle::close(Self) -> Unit pub async fn IoHandle::connect(Self, Bytes, context~ : String) -> Unit pub fn IoHandle::detach_from_event_loop(Self) -> Int diff --git a/src/internal/event_loop/thread_pool.c b/src/internal/event_loop/thread_pool.c index 1c2de3ec..5391d970 100644 --- a/src/internal/event_loop/thread_pool.c +++ b/src/internal/event_loop/thread_pool.c @@ -50,6 +50,7 @@ #include typedef int HANDLE; +typedef int SOCKET; #endif @@ -2020,6 +2021,41 @@ struct wait_for_process_job *moonbitlang_async_make_wait_for_process_job( #endif +// ===== bind job, bind socket to specific address ===== +struct bind_job { + struct job job; + HANDLE socket; + struct sockaddr *addr; +}; + +static +void free_bind_job(void *obj) { + struct bind_job *job = (struct bind_job*)obj; + moonbit_decref(job->addr); +} + +static +void bind_job_worker(struct job *job) { + struct bind_job *bind_job = (struct bind_job*)job; + + job->ret = bind((SOCKET)bind_job->socket, bind_job->addr, Moonbit_array_length(bind_job->addr)); + + if (job->ret < 0) +#ifdef _WIN32 + job->err = GetLastError(); +#else + job->err = errno; +#endif +} + +MOONBIT_FFI_EXPORT +struct bind_job *moonbitlang_async_make_bind_job(HANDLE socket, struct sockaddr *addr) { + struct bind_job *job = MAKE_JOB(bind); + job->socket = socket; + job->addr = addr; + return job; +} + // ===== getaddrinfo job, resolve host name via `getaddrinfo` ===== #ifdef _WIN32 diff --git a/src/internal/event_loop/thread_pool.mbt b/src/internal/event_loop/thread_pool.mbt index bd3cfe41..994ace91 100644 --- a/src/internal/event_loop/thread_pool.mbt +++ b/src/internal/event_loop/thread_pool.mbt @@ -252,6 +252,10 @@ extern "C" fn Job::wait_for_process(pid : Int) -> Job = "moonbitlang_async_make_ #borrow(job) extern "C" fn Job::cancel_process_waiter(job : Job) = "moonbitlang_async_cancel_wait_for_process_job" +///| +#owned(addr) +extern "C" fn Job::bind(socket : @fd_util.Fd, addr : Bytes) -> Job = "moonbitlang_async_make_bind_job" + ///| #owned(host) extern "C" fn Job::getaddrinfo(host : @os_string.OsString) -> Job = "moonbitlang_async_make_getaddrinfo_job" diff --git a/src/socket/ffi.mbt b/src/socket/ffi.mbt index 7f2069af..42dfa1de 100644 --- a/src/socket/ffi.mbt +++ b/src/socket/ffi.mbt @@ -60,10 +60,6 @@ fn make_udp_socket(family : AddrFamily, context~ : String) -> @fd_util.Fd raise sock } -///| -#borrow(addr) -extern "C" fn bind_ffi(sock : @fd_util.Fd, addr : Addr) -> Int = "moonbitlang_async_bind" - ///| extern "C" fn set_ipv6_only(sock : @fd_util.Fd, dual_stack : Bool) -> Int = "moonbitlang_async_set_ipv6_only" diff --git a/src/socket/pkg.generated.mbti b/src/socket/pkg.generated.mbti index 1fb08cd8..9b0a1752 100644 --- a/src/socket/pkg.generated.mbti +++ b/src/socket/pkg.generated.mbti @@ -49,14 +49,14 @@ pub struct TcpServer { addr : Addr // private fields - fn new(Addr, dual_stack? : Bool, reuse_addr? : Bool) -> TcpServer raise + async fn new(Addr, dual_stack? : Bool, reuse_addr? : Bool) -> TcpServer } pub async fn TcpServer::accept(Self) -> (Tcp, Addr) #deprecated pub fn TcpServer::addr(Self) -> Addr pub fn TcpServer::close(Self) -> Unit pub fn TcpServer::fd(Self) -> Int -pub fn TcpServer::new(Addr, dual_stack? : Bool, reuse_addr? : Bool) -> Self raise +pub async fn TcpServer::new(Addr, dual_stack? : Bool, reuse_addr? : Bool) -> Self pub async fn TcpServer::run_forever(Self, async (Tcp, Addr) -> Unit, allow_failure? : Bool, max_connections? : Int) -> Unit pub struct UdpClient { @@ -77,13 +77,13 @@ pub struct UdpServer { addr : Addr // private fields - fn new(Addr, dual_stack? : Bool) -> UdpServer raise + async fn new(Addr, dual_stack? : Bool) -> UdpServer } #deprecated pub fn UdpServer::addr(Self) -> Addr pub fn UdpServer::close(Self) -> Unit pub fn UdpServer::fd(UdpClient) -> Int -pub fn UdpServer::new(Addr, dual_stack? : Bool) -> Self raise +pub async fn UdpServer::new(Addr, dual_stack? : Bool) -> Self pub async fn UdpServer::recvfrom(Self, FixedArray[Byte], offset? : Int, max_len? : Int) -> (Int, Addr) pub async fn UdpServer::sendto(Self, Bytes, Addr, offset? : Int, len? : Int) -> Unit diff --git a/src/socket/reuse_addr_test.mbt b/src/socket/reuse_addr_test.mbt index 2a71c2f5..29c0baef 100644 --- a/src/socket/reuse_addr_test.mbt +++ b/src/socket/reuse_addr_test.mbt @@ -28,8 +28,8 @@ async test "reuse addr" { conn.write("abcd") } }) - @async.sleep(100) for _ in 0..<2 { + @async.sleep(100) let conn = @socket.Tcp::connect(@socket.Addr::parse("127.0.0.1:\{port}")) defer conn.close() inspect(conn.read_all().text(), content="abcd") diff --git a/src/socket/socket.c b/src/socket/socket.c index ebc60e5d..e80bb8d3 100644 --- a/src/socket/socket.c +++ b/src/socket/socket.c @@ -91,11 +91,6 @@ HANDLE moonbitlang_async_make_udp_socket(int family) { #endif } -MOONBIT_FFI_EXPORT -int moonbitlang_async_bind(HANDLE sockfd, struct sockaddr *addr) { - return bind((SOCKET)sockfd, (struct sockaddr*)addr, Moonbit_array_length(addr)); -} - MOONBIT_FFI_EXPORT int moonbitlang_async_set_ipv6_only(HANDLE sockfd, int ipv6_only) { return setsockopt(sockfd, IPPROTO_IPV6, IPV6_V6ONLY, &ipv6_only, sizeof(int)); diff --git a/src/socket/tcp.mbt b/src/socket/tcp.mbt index c2530308..91e80add 100644 --- a/src/socket/tcp.mbt +++ b/src/socket/tcp.mbt @@ -37,7 +37,7 @@ pub struct TcpServer { addr : Addr priv io : @event_loop.IoHandle - fn new(addr : Addr, dual_stack? : Bool, reuse_addr? : Bool) -> TcpServer raise + async fn new(addr : Addr, dual_stack? : Bool, reuse_addr? : Bool) -> TcpServer } ///| @@ -63,14 +63,15 @@ pub struct TcpServer { /// This is useful for avoiding "address already in use" error while testing. /// WARNING: enabling `reuse_addr` on production is unsafe, /// packets from the previous listener of the address may be accidentally received. -pub fn TcpServer::new( +pub async fn TcpServer::new( addr : Addr, dual_stack? : Bool = true, reuse_addr? : Bool = false, -) -> TcpServer raise { +) -> TcpServer { let context = "@socket.TcpServer::new()" let family = addr.family() let sock = make_tcp_socket(family, context~) + let io = @event_loop.IoHandle::from_fd(sock, kind=Socket) try { if addr.is_ipv6() && addr.is_ipv6_wildcard() { if 0 != set_ipv6_only(sock, !dual_stack) { @@ -80,23 +81,21 @@ pub fn TcpServer::new( if reuse_addr { guard allow_reuse_addr(sock) >= 0 else { @os_error.check_errno(context) } } - if bind_ffi(sock, addr) != 0 { - @os_error.check_errno(context) - } + io.bind(addr.0, context~) if 0 != listen_ffi(sock) { @os_error.check_errno(context) } + // If `addr` specifies zero as the listen port, + // the OS will assign a random port for us, + // in this case, we need to retrieve the actual port via `getsockname`. + let addr = getsockname(sock, family, context~) + { io, addr } } catch { err => { - @fd_util.close(sock, kind=Socket, context~) + io.close() raise err } } - // If `addr` specifies zero as the listen port, - // the OS will assign a random port for us, - // in this case, we need to retrieve the actual port via `getsockname`. - let addr = getsockname(sock, family, context~) - { io: @event_loop.IoHandle::from_fd(sock, kind=Socket), addr } } ///| diff --git a/src/socket/udp.mbt b/src/socket/udp.mbt index fef375c2..0b1bee3b 100644 --- a/src/socket/udp.mbt +++ b/src/socket/udp.mbt @@ -71,7 +71,7 @@ pub struct UdpServer { addr : Addr priv io : @event_loop.IoHandle - fn new(addr : Addr, dual_stack? : Bool) -> UdpServer raise + async fn new(addr : Addr, dual_stack? : Bool) -> UdpServer } ///| @@ -98,33 +98,32 @@ pub fn UdpServer::fd(self : UdpClient) -> @fd_util.Fd { /// If the port of `addr` is zero, the server will be bound to a random port, /// assigned by the operating system. /// The actual listen address can be retrieved via `.addr()`. -pub fn UdpServer::new( +pub async fn UdpServer::new( addr : Addr, dual_stack? : Bool = true, -) -> UdpServer raise { +) -> UdpServer { let context = "@socket.UdpServer::new()" let family = addr.family() let sock = make_udp_socket(family, context~) + let io = @event_loop.IoHandle::from_fd(sock, kind=Socket) try { if addr.is_ipv6() && addr.is_ipv6_wildcard() { if 0 != set_ipv6_only(sock, !dual_stack) { @os_error.check_errno(context) } } - if bind_ffi(sock, addr) != 0 { - @os_error.check_errno(context) - } + io.bind(addr.0, context~) + // If `addr` specifies zero as the listen port, + // the OS will assign a random port for us, + // in this case, we need to retrieve the actual port via `getsockname`. + let addr = getsockname(sock, family, context~) + { io, addr } } catch { err => { @fd_util.close(sock, kind=Socket, context~) raise err } } - // If `addr` specifies zero as the listen port, - // the OS will assign a random port for us, - // in this case, we need to retrieve the actual port via `getsockname`. - let addr = getsockname(sock, family, context~) - { io: @event_loop.IoHandle::from_fd(sock, kind=Socket), addr } } ///|