diff --git a/Cargo.lock b/Cargo.lock index 2c7d917..3690bc4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,7 +1,16 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "hteapot" version = "0.5.0" +dependencies = [ + "libc", +] + +[[package]] +name = "libc" +version = "0.2.171" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" diff --git a/Cargo.toml b/Cargo.toml index 495bbc9..4cef2b5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,16 +1,17 @@ [package] -edition = "2024" name = "hteapot" version = "0.5.0" -exclude = ["config.toml", "demo/", "readme.md"] -license = "MIT" -keywords = ["HTTP", "HTTP-SERVER"] +edition = "2024" +authors = ["Alb Ruiz G. "] description = "HTeaPot is a lightweight HTTP server library designed to be easy to use and extend." +license = "MIT" +readme = "README.md" +documentation = "https://docs.rs/hteapot/" homepage = "https://github.com/az107/HTeaPot" repository = "https://github.com/az107/HTeaPot" -readme = "readme.md" -categories = ["network-programming", "web-programming"] -authors = ["Alb Ruiz G. "] +keywords = ["http", "server", "web", "lightweight", "rust"] +categories = ["network-programming", "web-programming", "command-line-utilities"] +exclude = ["config.toml", "demo/"] [lib] name = "hteapot" @@ -18,3 +19,6 @@ path = "src/hteapot/mod.rs" [[bin]] name = "hteapot" + +[dependencies] +libc = "0.2" \ No newline at end of file diff --git a/README.md b/README.md index f9ee1e7..d1bcc20 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,19 @@ A high-performance, lightweight HTTP server and library built in Rust. HTeaPot is designed to deliver exceptional performance for modern web applications while maintaining a simple and intuitive API. -## Features +## 📚 Table of Contents + +- [Features](#--features) +- [Getting Started](#-getting-started) + - [Standalone Server](#standalone-server) + - [As a Library](#as-a-library) +- [Performance](#-performance) +- [Roadmap](#-roadmap) +- [Contributing](#-contributing) +- [License](#-license) +- [Acknowledgments](#-acknowledgments) + +## ✨ Features ### Exceptional Performance - **Threaded Architecture**: Powered by a custom-designed thread system that handles **70,000+ requests per second** @@ -34,25 +46,25 @@ A high-performance, lightweight HTTP server and library built in Rust. HTeaPot i - **Extensible Design**: Easily customize behavior for specific use cases - **Lightweight Footprint**: Zero dependencies and efficient resource usage -## Getting Started +## 🚀 Getting Started -### Installation +### 🔧 Installation ```bash # Install from crates.io cargo install hteapot # Or build from source -git clone https://github.com/yourusername/hteapot.git +git clone https://github.com/Az107/hteapot.git cd hteapot cargo build --release ``` -### Standalone Server +### 🖥️ Running the Server -#### Using a configuration file: +#### Option 1: With Config -Create a `config.toml` file: +1. Create a `config.toml` file: ```toml [HTEAPOT] @@ -61,27 +73,27 @@ host = "localhost" # The host address to bind to root = "public" # The root directory to serve files from ``` -Run the server: +2. Run the server: ```bash hteapot ./config.toml ``` -#### Quick serve a directory: +#### Option 2: Quick Serve ```bash hteapot -s ./public/ ``` -### As a Library +### 🦀 Using as a Library -1. Add HTeaPot to your project: +1. Add HTeaPot to your ```Cargo.toml``` project: ```bash cargo add hteapot ``` -2. Implement in your code: +2. Implement in your code: ```example``` ```rust use hteapot::{HttpStatus, HttpResponse, Hteapot, HttpRequest}; @@ -97,16 +109,16 @@ fn main() { } ``` -## Performance +## 📊 Performance HTeaPot has been benchmarked against other popular HTTP servers, consistently demonstrating excellent metrics: -| Metric | HTeaPot | Industry Average | -|---------------|---------|------------------| -| Requests/sec | 70,000+ | 30,000-50,000 | -| Error rate | <0.1% | 0.5-2% | -| Latency (p99) | 5ms | 15-30ms | -| Memory usage | Low | Moderate | +| Metric | HTeaPot | Industry Average | +|---------------|---------------|------------------------| +| Requests/sec | 70,000+ req/s | 30,000 - 50,000 req/s | +| Error rate | < 0.1% | 0.5% - 2% | +| Latency (p99) | 5ms | 15ms - 30ms | +| Memory usage | Low | Moderate | ## Roadmap @@ -122,7 +134,19 @@ HTeaPot has been benchmarked against other popular HTTP servers, consistently de ## Contributing -We welcome contributions from the community! See our [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to get involved. +We welcome contributions from the community! To get started: + +```sh +# Format the code +cargo fmt + +# Lint for warnings +cargo clippy --all-targets --all-features + +# Run tests +cargo test +``` +See [CONTRIBUTING.md](https://github.com/Az107/HTeaPot/wiki/Contributing) for more details. ## License diff --git a/src/hteapot/mod.rs b/src/hteapot/mod.rs index 588d832..ea2ea77 100644 --- a/src/hteapot/mod.rs +++ b/src/hteapot/mod.rs @@ -11,6 +11,7 @@ mod status; use self::response::EmptyHttpResponse; use self::response::HttpResponseCommon; use self::response::IterError; +use std::sync::atomic::{AtomicBool, Ordering}; pub use self::methods::HttpMethod; pub use self::request::HttpRequest; @@ -45,6 +46,8 @@ pub struct Hteapot { port: u16, address: String, threads: u16, + shutdown_signal: Option>, + shutdown_hooks: Vec>, } struct SocketStatus { @@ -62,12 +65,25 @@ struct SocketData { } impl Hteapot { + pub fn set_shutdown_signal(&mut self, signal: Arc) { + self.shutdown_signal = Some(signal); + } + + pub fn add_shutdown_hook(&mut self, hook: F) + where + F: Fn() + Send + Sync + 'static, + { + self.shutdown_hooks.push(Arc::new(hook)); + } + // Constructor pub fn new(address: &str, port: u16) -> Self { Hteapot { port, address: address.to_string(), threads: 1, + shutdown_signal: None, + shutdown_hooks: Vec::new(), } } @@ -76,6 +92,8 @@ impl Hteapot { port, address: address.to_string(), threads: if threads == 0 { 1 } else { threads }, + shutdown_signal: None, + shutdown_hooks: Vec::new(), } } @@ -85,11 +103,10 @@ impl Hteapot { action: impl Fn(HttpRequest) -> Box + Send + Sync + 'static, ) { let addr = format!("{}:{}", self.address, self.port); - let listener = TcpListener::bind(addr); - let listener = match listener { + let listener = match TcpListener::bind(addr) { Ok(listener) => listener, Err(e) => { - eprintln!("Error L: {}", e); + eprintln!("Error binding to address: {}", e); return; } }; @@ -100,10 +117,16 @@ impl Hteapot { Arc::new(Mutex::new(vec![0; self.threads as usize])); let arc_action = Arc::new(action); + // Clone shutdown_signal and share the shutdown_hooks via Arc + let shutdown_signal = self.shutdown_signal.clone(); + let shutdown_hooks = Arc::new(self.shutdown_hooks.clone()); + for thread_index in 0..self.threads { let pool_clone = pool.clone(); let action_clone = arc_action.clone(); let priority_list_clone = priority_list.clone(); + let shutdown_signal_clone = shutdown_signal.clone(); + let shutdown_hooks_clone = shutdown_hooks.clone(); thread::spawn(move || { let mut streams_to_handle = Vec::new(); @@ -113,11 +136,20 @@ impl Hteapot { let mut pool = lock.lock().expect("Error locking pool"); if streams_to_handle.is_empty() { - pool = cvar - .wait_while(pool, |pool| pool.is_empty()) + // Store the returned guard back into pool + pool = cvar.wait_while(pool, |pool| pool.is_empty()) .expect("Error waiting on cvar"); } + if let Some(signal) = &shutdown_signal_clone { + if !signal.load(Ordering::SeqCst) { + for hook in shutdown_hooks_clone.iter() { + hook(); + } + break; // Exit the server loop + } + } + while let Some(stream) = pool.pop_back() { let socket_status = SocketStatus { ttl: Instant::now(), @@ -153,15 +185,19 @@ impl Hteapot { } loop { - let stream = listener.accept(); - if stream.is_err() { + let stream = match listener.accept() { + Ok((stream, _)) => stream, + Err(_) => continue, + }; + + if stream.set_nonblocking(true).is_err() { + eprintln!("Error setting non-blocking mode on stream"); + continue; + } + if stream.set_nodelay(true).is_err() { + eprintln!("Error setting no delay on stream"); continue; } - let (stream, _) = stream.unwrap(); - stream - .set_nonblocking(true) - .expect("Error setting non-blocking"); - stream.set_nodelay(true).expect("Error setting no delay"); { let (lock, cvar) = &*pool; @@ -180,27 +216,20 @@ impl Hteapot { ) -> Option<()> { let status = socket_data.status.as_mut()?; - // Fix by miky-rola 2025-04-08 // Check if the TTL (time-to-live) for the connection has expired. - // If the connection is idle for longer than `KEEP_ALIVE_TTL` and no data is being written, - // the connection is gracefully shut down to free resources. if Instant::now().duration_since(status.ttl) > KEEP_ALIVE_TTL && !status.write { let _ = socket_data.stream.shutdown(Shutdown::Both); return None; } - // If the request is not yet complete, read data from the stream into a buffer. // This ensures that the server can handle partial or chunked requests. + if !status.request.done { let mut buffer = [0; BUFFER_SIZE]; match socket_data.stream.read(&mut buffer) { Err(e) => match e.kind() { - io::ErrorKind::WouldBlock => { - return Some(()); - } - io::ErrorKind::ConnectionReset => { - return None; - } + io::ErrorKind::WouldBlock => return Some(()), + io::ErrorKind::ConnectionReset => return None, _ => { eprintln!("Read error: {:?}", e); return None; @@ -211,23 +240,27 @@ impl Hteapot { return None; } status.ttl = Instant::now(); - let _ = status.request.append(buffer[..m].to_vec()); + let r = status.request.append(buffer[..m].to_vec()); + if r.is_err() { + // Early return response if not valid request is sended + let error_msg = r.err().unwrap(); + let response = + HttpResponse::new(HttpStatus::BadRequest, error_msg, None).to_bytes(); + let _ = socket_data.stream.write(&response); + let _ = socket_data.stream.flush(); + let _ = socket_data.stream.shutdown(Shutdown::Both); + return None; + } } } } - let request = status.request.get(); - if request.is_none() { - return Some(()); - } - let request = request.unwrap(); - + let request = status.request.get()?; let keep_alive = request .headers - .get("Connection") - .map(|v| v == "keep-alive") + .get("connection") //all headers are turn lowercase in the builder + .map(|v| v.to_lowercase() == "keep-alive") .unwrap_or(false); - if !status.write { let mut response = action(request); if keep_alive { @@ -250,8 +283,7 @@ impl Hteapot { status.response = response; } - // Write the response to the client in chunks using the `peek` and `next` methods. - // This ensures that large responses are sent incrementally without blocking the server. + // Write the response to the client in chunks loop { match status.response.peek() { Ok(n) => match socket_data.stream.write(&n) { @@ -259,9 +291,7 @@ impl Hteapot { status.ttl = Instant::now(); let _ = status.response.next(); } - Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { - return Some(()); - } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return Some(()), Err(e) => { eprintln!("Write error: {:?}", e); return None; @@ -289,68 +319,71 @@ impl Hteapot { } #[cfg(test)] -#[test] -fn test_http_response_maker() { - let mut response = HttpResponse::new(HttpStatus::IAmATeapot, "Hello, World!", None); - let response = String::from_utf8(response.to_bytes()).unwrap(); - let expected_response = format!( - "HTTP/1.1 418 I'm a teapot\r\nContent-Length: 13\r\nServer: HTeaPot/{}\r\n\r\nHello, World!\r\n", - VERSION - ); - let expected_response_list = expected_response.split("\r\n"); - for item in expected_response_list.into_iter() { - assert!(response.contains(item)); +mod tests { + use super::*; + + #[test] + fn test_http_response_maker() { + let mut response = HttpResponse::new(HttpStatus::IAmATeapot, "Hello, World!", None); + let response = String::from_utf8(response.to_bytes()).unwrap(); + let expected_response = format!( + "HTTP/1.1 418 I'm a teapot\r\nContent-Length: 13\r\nServer: HTeaPot/{}\r\n\r\nHello, World!\r\n", + VERSION + ); + let expected_response_list = expected_response.split("\r\n"); + for item in expected_response_list { + assert!(response.contains(item)); + } } -} -#[cfg(test)] -#[test] -fn test_keep_alive_connection() { - let mut response = HttpResponse::new( - HttpStatus::OK, - "Keep-Alive Test", - headers! { - "Connection" => "keep-alive", - "Content-Length" => "15" - }, - ); - - response.base().headers.insert( - "Keep-Alive".to_string(), - format!("timeout={}", KEEP_ALIVE_TTL.as_secs()), - ); - - let response_bytes = response.to_bytes(); - let response_str = String::from_utf8(response_bytes.clone()).unwrap(); - - assert!(response_str.contains("HTTP/1.1 200 OK")); - assert!(response_str.contains("Content-Length: 15")); - assert!(response_str.contains("Connection: keep-alive")); - assert!(response_str.contains("Keep-Alive: timeout=10")); - assert!(response_str.contains("Server: HTeaPot/")); - assert!(response_str.contains("Keep-Alive Test")); - - let mut second_response = HttpResponse::new( - HttpStatus::OK, - "Second Request", - headers! { - "Connection" => "keep-alive", - "Content-Length" => "14" // Length for "Second Request" - }, - ); - - second_response.base().headers.insert( - "Keep-Alive".to_string(), - format!("timeout={}", KEEP_ALIVE_TTL.as_secs()), - ); - - let second_response_bytes = second_response.to_bytes(); - let second_response_str = String::from_utf8(second_response_bytes.clone()).unwrap(); - - assert!(second_response_str.contains("HTTP/1.1 200 OK")); - assert!(second_response_str.contains("Content-Length: 14")); - assert!(second_response_str.contains("Connection: keep-alive")); - assert!(second_response_str.contains("Keep-Alive: timeout=10")); - assert!(second_response_str.contains("Server: HTeaPot/")); - assert!(second_response_str.contains("Second Request")); -} + #[test] + fn test_keep_alive_connection() { + let mut response = HttpResponse::new( + HttpStatus::OK, + "Keep-Alive Test", + headers! { + "Connection" => "keep-alive", + "Content-Length" => "15" + }, + ); + + response.base().headers.insert( + "Keep-Alive".to_string(), + format!("timeout={}", KEEP_ALIVE_TTL.as_secs()), + ); + + let response_bytes = response.to_bytes(); + let response_str = String::from_utf8(response_bytes.clone()).unwrap(); + + assert!(response_str.contains("HTTP/1.1 200 OK")); + assert!(response_str.contains("Content-Length: 15")); + assert!(response_str.contains("Connection: keep-alive")); + assert!(response_str.contains("Keep-Alive: timeout=10")); + assert!(response_str.contains("Server: HTeaPot/")); + assert!(response_str.contains("Keep-Alive Test")); + + let mut second_response = HttpResponse::new( + HttpStatus::OK, + "Second Request", + headers! { + "Connection" => "keep-alive", + "Content-Length" => "14" // Length for "Second Request" + }, + ); + + second_response.base().headers.insert( + "Keep-Alive".to_string(), + format!("timeout={}", KEEP_ALIVE_TTL.as_secs()), + ); + + let second_response_bytes = second_response.to_bytes(); + let second_response_str = String::from_utf8(second_response_bytes.clone()).unwrap(); + + assert!(second_response_str.contains("HTTP/1.1 200 OK")); + assert!(second_response_str.contains("Content-Length: 14")); + assert!(response_str.contains("Connection: keep-alive")); + assert!(response_str.contains("Keep-Alive: timeout=10")); + assert!(response_str.contains("Server: HTeaPot/")); + assert!(second_response_str.contains("Second Request")); + } +} \ No newline at end of file diff --git a/src/hteapot/request.rs b/src/hteapot/request.rs index f00262f..4374ea8 100644 --- a/src/hteapot/request.rs +++ b/src/hteapot/request.rs @@ -1,6 +1,15 @@ +// Written by Alberto Ruiz 2025-01-01 +// This module handles the request struct and a builder for it +// This implementation has some issues and fixes are required for security +// Refactor is recomended, but for now can work with the fixes + use super::HttpMethod; -use std::{collections::HashMap, net::TcpStream, str}; +use std::{cmp::min, collections::HashMap, net::TcpStream, str}; + +const MAX_HEADER_SIZE: usize = 1024 * 16; +const MAX_HEADER_COUNT: usize = 100; +#[derive(Debug)] pub struct HttpRequest { pub method: HttpMethod, pub path: String, @@ -24,7 +33,7 @@ impl HttpRequest { pub fn default() -> Self { HttpRequest { - method: HttpMethod::GET, + method: HttpMethod::Other(String::new()), path: String::new(), args: HashMap::new(), headers: HashMap::new(), @@ -44,39 +53,6 @@ impl HttpRequest { }; } - // pub fn body(&mut self) -> Option> { - // if self.has_body() { - // let mut stream = self.stream.as_ref().unwrap(); - // let content_length = self.headers.get("Content-Length")?; - // let content_length: usize = content_length.parse().unwrap(); - // if content_length > self.body.len() { - // let _ = stream.flush(); - // let mut total_read = 0; - // self.body.resize(content_length, 0); - // while total_read < content_length { - // match stream.read(&mut self.body[total_read..]) { - // Ok(0) => { - // break; - // } - // Ok(n) => { - // total_read += n; - // } - // Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { - // continue; - // } - // Err(_e) => { - // break; - // } - // } - // } - // } - - // Some(self.body.clone()) - // } else { - // None - // } - // } - pub fn set_stream(&mut self, stream: TcpStream) { self.stream = Some(stream); } @@ -96,6 +72,8 @@ impl HttpRequest { pub struct HttpRequestBuilder { request: HttpRequest, buffer: Vec, + header_done: bool, + header_size: usize, body_size: usize, pub done: bool, } @@ -111,6 +89,8 @@ impl HttpRequestBuilder { body: Vec::new(), stream: None, }, + header_size: 0, + header_done: false, body_size: 0, buffer: Vec::new(), done: false, @@ -125,21 +105,63 @@ impl HttpRequestBuilder { } } - pub fn append(&mut self, buffer: Vec) -> Option { - self.buffer.extend(buffer); - self.buffer.retain(|&b| b != 0); + fn read_body_len(&mut self) -> Option<()> { + let body_left = self.body_size.saturating_sub(self.request.body.len()); + let to_take = min(body_left, self.buffer.len()); + let to_append = &self.buffer[..to_take]; + self.request.body.extend_from_slice(to_append); + self.buffer.drain(..to_take); + + if body_left > 0 { + return None; + } else { + self.done = true; + return Some(()); + } + } + + fn _read_body_chunk(&mut self) -> Option<()> { + //TODO: this will support chunked body in the future + todo!() + } + + fn read_body(&mut self) -> Option<()> { + return self.read_body_len(); + } + + pub fn append(&mut self, chunk: Vec) -> Result<(), &'static str> { + if !self.header_done && self.buffer.len() > MAX_HEADER_SIZE { + return Err("Entity Too large"); + } + let chunk_size = chunk.len(); + self.buffer.extend(chunk); + if self.header_done { + self.read_body(); + return Ok(()); + } else { + self.header_size += chunk_size; + if self.header_size > MAX_HEADER_SIZE { + return Err("Entity Too large"); + } + } while let Some(pos) = self.buffer.windows(2).position(|w| w == b"\r\n") { let line = self.buffer.drain(..pos).collect::>(); self.buffer.drain(..2); - let line_str = String::from_utf8_lossy(&line); + let line_str = match str::from_utf8(line.as_slice()) { + Ok(v) => v.to_string(), + Err(_e) => return Err("No utf-8"), + }; if self.request.path.is_empty() { let parts: Vec<&str> = line_str.split_whitespace().collect(); if parts.len() < 2 { - return None; + return Ok(()); } + if parts.len() != 3 { + return Err("Invalid method + path + version request"); + } self.request.method = HttpMethod::from_str(parts[0]); let path_parts: Vec<&str> = parts[1].split('?').collect(); self.request.path = path_parts[0].to_string(); @@ -158,21 +180,41 @@ impl HttpRequestBuilder { .collect(); } } else if !line_str.is_empty() { - if let Some((key, value)) = line_str.split_once(": ") { - if key.to_lowercase() == "content-length" { + if let Some((key, value)) = line_str.split_once(":") { + //Check the number of headers, if the actual headers exceed that number + //drop the connection + if self.request.headers.len() > MAX_HEADER_COUNT { + return Err("Header number exceed allowed"); + } + let key = key.trim().to_lowercase(); + let value = value.trim(); + if key == "content-length" { + if self.request.headers.get("content-length").is_some() + || self + .request + .headers + .get("transfer-encoding") + .map(|te| te == "chunked") + .unwrap_or(false) + { + return Err("Duplicated content-length"); + } self.body_size = value.parse().unwrap_or(0); } self.request .headers .insert(key.to_string(), value.to_string()); } + } else { + self.header_done = true; + self.read_body(); + return Ok(()); } } - self.request.body.append(&mut self.buffer.clone()); - if self.request.body.len() == self.body_size { - self.done = true; - return Some(self.request.clone()); - } - None + Ok(()) } } + +#[cfg(test)] +#[test] +fn basic_request() {} diff --git a/src/logger.rs b/src/logger.rs index f61b19f..26eab6c 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -103,6 +103,16 @@ impl SimpleTime { } } +impl Clone for Logger { + fn clone(&self) -> Self { + Logger { + core: Arc::clone(&self.core), + component: Arc::clone(&self.component), + min_level: self.min_level, + } + } +} + // Log message with metadata struct LogMessage { timestamp: String, diff --git a/src/main.rs b/src/main.rs index 7aa096a..fb7599e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ mod utils; use std::{fs, io, path::PathBuf}; use std::path::Path; use std::sync::Mutex; +mod shutdown; use cache::Cache; use config::Config; @@ -131,11 +132,12 @@ fn main() { }; let cache: Mutex = Mutex::new(Cache::new(config.cache_ttl as u64)); - let server = Hteapot::new_threaded(config.host.as_str(), config.port, config.threads); + let mut server = Hteapot::new_threaded(config.host.as_str(), config.port, config.threads); logger.info(format!( "Server started at http://{}:{}", config.host, config.port )); + let _running = shutdown::setup_graceful_shutdown(&mut server, logger.clone()); if config.cache { logger.info("Cache Enabled".to_string()); } diff --git a/src/shutdown.rs b/src/shutdown.rs new file mode 100644 index 0000000..c78230a --- /dev/null +++ b/src/shutdown.rs @@ -0,0 +1,175 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::thread; +use std::time::Duration; + +use crate::hteapot::Hteapot; +use crate::logger::Logger; + +pub fn setup_graceful_shutdown(server: &mut Hteapot, logger: Logger) -> Arc { + let running = Arc::new(AtomicBool::new(true)); + let r_server = running.clone(); + let shutdown_logger = logger.with_component("shutdown"); + + #[cfg(unix)] + { + let r_unix = running.clone(); + let unix_logger = shutdown_logger.clone(); + unix_signal::register_signal_handler(r_unix, unix_logger); + } + + #[cfg(windows)] + { + let r_win = running.clone(); + let win_logger = shutdown_logger.clone(); + + if !win_console::set_handler(r_win, win_logger.clone()) { + win_logger.error("Failed to set Windows Ctrl+C handler".to_string()); + } else { + win_logger.info("Windows Ctrl+C handler registered".to_string()); + } + } + + let r_enter = running.clone(); + let enter_logger = shutdown_logger.clone(); + + thread::spawn(move || { + println!(" Ctrl+C to shutdown the server gracefully..."); + let mut buffer = String::new(); + let _ = std::io::stdin().read_line(&mut buffer); + enter_logger.info("Enter pressed, initiating graceful shutdown...".to_string()); + r_enter.store(false, Ordering::SeqCst); + }); + + // Set up server with shutdown signal + server.set_shutdown_signal(r_server); + + // Add shutdown hook for cleanup + let shutdown_logger_clone = shutdown_logger.clone(); + server.add_shutdown_hook(move || { + shutdown_logger_clone.info("Waiting for ongoing requests to complete...".to_string()); + + thread::sleep(Duration::from_secs(3)); + + shutdown_logger_clone.info("Server shutdown complete.".to_string()); + + std::process::exit(0); + }); + + // Return the running flag so the main thread can also check it + running +} + +#[cfg(unix)] +pub mod unix_signal { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::ptr; + use std::mem; + + use libc::{c_int, sigaction, sighandler_t, sigset_t}; + use libc::{SA_RESTART, SIGINT}; + + use crate::logger::Logger; + + // Thread-safe flag to indicate signal received + static mut SIGNAL_RECEIVED: bool = false; + + // Signal handler function - minimal to avoid UB + extern "C" fn handle_signal(_: c_int) { + unsafe { + SIGNAL_RECEIVED = true; + } + } + + pub fn register_signal_handler(running: Arc, logger: Logger) { + unsafe { + // Set up the sigaction struct + let mut sigact: sigaction = mem::zeroed(); + sigact.sa_sigaction = handle_signal as sighandler_t; + sigact.sa_flags = SA_RESTART; + + // Empty the signal mask + sigemptyset(&mut sigact.sa_mask); + + // Register our signal handler for SIGINT + if sigaction(SIGINT, &sigact, ptr::null_mut()) < 0 { + logger.error("Failed to set SIGINT handler".to_string()); + return; + } else { + logger.info("SIGINT handler registered".to_string()); + } + } + + // Start a monitoring thread that periodically checks the signal flag + // and updates the running atomic + let monitor_logger = logger.clone(); + std::thread::spawn(move || { + while running.load(Ordering::SeqCst) { + unsafe { + if SIGNAL_RECEIVED { + monitor_logger.info("SIGINT received, initiating graceful shutdown...".to_string()); + running.store(false, Ordering::SeqCst); + break; + } + } + std::thread::sleep(std::time::Duration::from_millis(50)); + } + }); + } + + // Helper function to create an empty signal set + unsafe fn sigemptyset(set: *mut sigset_t) { + ptr::write_bytes(set, 0, 1); + } +} + +#[cfg(windows)] +pub mod win_console { + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + use std::sync::Mutex; + + // Define the external Windows API function in an unsafe extern block + unsafe extern "system" { + pub fn SetConsoleCtrlHandler( + handler: Option i32>, + add: i32, + ) -> i32; + } + + pub const CTRL_C_EVENT: u32 = 0; + + struct StaticData { + running: Option>, + logger: Option, + } + + // We need to ensure thread safety, so use a Mutex for it + static HANDLER_DATA: Mutex = Mutex::new(StaticData { + running: None, + logger: None, + }); + + pub fn set_handler(running: Arc, logger: crate::logger::Logger) -> bool { + // Store references in the mutex-protected static + let mut data = HANDLER_DATA.lock().unwrap(); + data.running = Some(running); + data.logger = Some(logger); + + unsafe extern "system" fn handler_func(ctrl_type: u32) -> i32 { + if ctrl_type == CTRL_C_EVENT { + if let Ok(data) = HANDLER_DATA.lock() { + if let (Some(r), Some(l)) = (&data.running, &data.logger) { + l.info("initiating graceful shutdown...".to_string()); + r.store(false, Ordering::SeqCst); + std::process::exit(0); + } + } + } + 0 + } + + unsafe { SetConsoleCtrlHandler(Some(handler_func), 1) != 0 } + } +} \ No newline at end of file