diff --git a/Cargo.lock b/Cargo.lock index 2c7d917..da585f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,7 +1,16 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "hteapot" version = "0.5.0" +dependencies = [ + "libc", +] + +[[package]] +name = "libc" +version = "0.2.172" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" diff --git a/Cargo.toml b/Cargo.toml index 9c789b1..7bedc6b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,11 @@ documentation = "https://docs.rs/hteapot/" homepage = "https://github.com/az107/HTeaPot" repository = "https://github.com/az107/HTeaPot" keywords = ["http", "server", "web", "lightweight", "rust"] -categories = ["network-programming", "web-programming", "command-line-utilities"] +categories = [ + "network-programming", + "web-programming", + "command-line-utilities", +] exclude = ["config.toml", "demo/", "README.md"] [lib] @@ -20,5 +24,9 @@ path = "src/hteapot/mod.rs" [[bin]] name = "hteapot" +[dependencies] +libc = "0.2.172" + + [package.metadata.docs.rs] -no-readme = true \ No newline at end of file +no-readme = true diff --git a/src/hteapot/mod.rs b/src/hteapot/mod.rs index c03acd3..5b1f478 100644 --- a/src/hteapot/mod.rs +++ b/src/hteapot/mod.rs @@ -1,5 +1,5 @@ // Written by Alberto Ruiz 2024-03-08 -// +// // This is the HTTP server module, it will handle the requests and responses // Also provides utilities to parse the requests and build the response @@ -17,11 +17,11 @@ //! ``` /// Submodules for HTTP functionality. -pub mod brew; // HTTP client implementation -mod methods; // HTTP method and status enums -mod request; // Request parsing and builder -mod response; // Response types and streaming -mod status; // Status code mapping +pub mod brew; // HTTP client implementation +mod methods; // HTTP method and status enums +mod request; // Request parsing and builder +mod response; // Response types and streaming +mod status; // Status code mapping // Internal types used for connection management use self::response::{EmptyHttpResponse, HttpResponseCommon, IterError}; @@ -37,6 +37,7 @@ pub use self::status::HttpStatus; use std::collections::VecDeque; use std::io::{self, Read, Write}; use std::net::{Shutdown, TcpListener, TcpStream}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Condvar, Mutex}; use std::thread; use std::time::{Duration, Instant}; @@ -76,6 +77,8 @@ pub struct Hteapot { port: u16, address: String, threads: u16, + shutdown_signal: Option>, + shutdown_hooks: Vec>, } /// Represents the state of a connection's lifecycle. @@ -95,12 +98,33 @@ struct SocketData { } impl Hteapot { + pub fn set_shutdown_signal(&mut self, signal: Arc) { + self.shutdown_signal = Some(signal); + } + + pub fn get_shutdown_signal(&self) -> Option> { + self.shutdown_signal.clone() + } + + pub fn add_shutdown_hook(&mut self, hook: F) + where + F: Fn() + Send + Sync + 'static, + { + self.shutdown_hooks.push(Arc::new(hook)); + } + + pub fn get_addr(&self) -> (String, u16) { + return (self.address.clone(), self.port); + } + // Constructor pub fn new(address: &str, port: u16) -> Self { Hteapot { port, address: address.to_string(), threads: 1, + shutdown_signal: None, + shutdown_hooks: Vec::new(), } } @@ -109,6 +133,8 @@ impl Hteapot { port, address: address.to_string(), threads: if threads == 0 { 1 } else { threads }, + shutdown_signal: None, + shutdown_hooks: Vec::new(), } } @@ -132,10 +158,15 @@ impl Hteapot { Arc::new(Mutex::new(vec![0; self.threads as usize])); let arc_action = Arc::new(action); + // Clone shutdown_signal and share the shutdown_hooks via Arc + let shutdown_signal = self.shutdown_signal.clone(); + let shutdown_hooks = Arc::new(self.shutdown_hooks.clone()); + for thread_index in 0..self.threads { let pool_clone = pool.clone(); let action_clone = arc_action.clone(); let priority_list_clone = priority_list.clone(); + let shutdown_signal_clone = shutdown_signal.clone(); thread::spawn(move || { let mut streams_to_handle = Vec::new(); @@ -143,12 +174,18 @@ impl Hteapot { { let (lock, cvar) = &*pool_clone; let mut pool = lock.lock().expect("Error locking pool"); - if streams_to_handle.is_empty() { // Store the returned guard back into pool - pool = cvar.wait_while(pool, |pool| pool.is_empty()) + pool = cvar + .wait_while(pool, |pool| pool.is_empty()) .expect("Error waiting on cvar"); } + //TODO: move this to allow process the last request + if let Some(signal) = &shutdown_signal_clone { + if !signal.load(Ordering::SeqCst) { + break; // Exit the server loop + } + } while let Some(stream) = pool.pop_back() { let socket_status = SocketStatus { @@ -185,6 +222,17 @@ impl Hteapot { } loop { + if let Some(signal) = &shutdown_signal { + if !signal.load(Ordering::SeqCst) { + let (lock, cvar) = &*pool; + let _guard = lock.lock().unwrap(); + cvar.notify_all(); + for hook in shutdown_hooks.iter() { + hook(); + } + break; + } + } let stream = match listener.accept() { Ok((stream, _)) => stream, Err(_) => continue, @@ -216,17 +264,14 @@ impl Hteapot { ) -> Option<()> { let status = socket_data.status.as_mut()?; - // Fix by miky-rola 2025-04-08 // Check if the TTL (time-to-live) for the connection has expired. - // If the connection is idle for longer than `KEEP_ALIVE_TTL` and no data is being written, - // the connection is gracefully shut down to free resources. if Instant::now().duration_since(status.ttl) > KEEP_ALIVE_TTL && !status.write { let _ = socket_data.stream.shutdown(Shutdown::Both); return None; } + // If the request is not yet complete, read data from the stream into a buffer. // This ensures that the server can handle partial or chunked requests. - if !status.request.done { let mut buffer = [0; BUFFER_SIZE]; match socket_data.stream.read(&mut buffer) { @@ -286,8 +331,7 @@ impl Hteapot { status.response = response; } - // Write the response to the client in chunks using the `peek` and `next` methods. - // This ensures that large responses are sent incrementally without blocking the server. + // Write the response to the client in chunks loop { match status.response.peek() { Ok(n) => match socket_data.stream.write(&n) { @@ -390,4 +434,4 @@ mod tests { assert!(response_str.contains("Server: HTeaPot/")); assert!(second_response_str.contains("Second Request")); } -} \ No newline at end of file +} diff --git a/src/hteapot/request.rs b/src/hteapot/request.rs index 4b8f939..fafadb3 100644 --- a/src/hteapot/request.rs +++ b/src/hteapot/request.rs @@ -190,7 +190,6 @@ impl HttpRequestBuilder { if parts.len() != 3 { return Err("Invalid method + path + version request"); } - self.request.method = HttpMethod::from_str(parts[0]); let path_parts: Vec<&str> = parts[1].split('?').collect(); self.request.path = path_parts[0].to_string(); diff --git a/src/logger.rs b/src/logger.rs index c7799ef..c18dfdb 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -1,9 +1,9 @@ +use std::fmt; use std::io::Write; -use std::sync::mpsc::{channel, Sender}; +use std::sync::Arc; +use std::sync::mpsc::{Sender, channel}; use std::thread::{self, JoinHandle}; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; -use std::fmt; -use std::sync::Arc; /// Differnt log levels #[derive(Debug, Clone, PartialEq, Eq, Ord, PartialOrd, Copy)] @@ -86,7 +86,15 @@ impl SimpleTime { // calculate millisecs from nanosecs let millis = nanos / 1_000_000; - (year, month as u32 + 1, day as u32, hour, minute, second, millis) + ( + year, + month as u32 + 1, + day as u32, + hour, + minute, + second, + millis, + ) } /// Returns a formatted timestamp string for the current system time. @@ -133,7 +141,7 @@ impl Logger { pub fn new( mut writer: W, min_level: LogLevel, - component: &str + component: &str, ) -> Logger { let (tx, rx) = channel::(); let thread = thread::spawn(move || { @@ -151,7 +159,7 @@ impl Logger { msg.timestamp, msg.level, msg.component, msg.content ); buff.push(formatted); - }, + } Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {} Err(_) => break, } @@ -238,8 +246,8 @@ impl Logger { pub fn fatal(&self, content: String) { self.log(LogLevel::FATAL, content); } - /// Log a message with TRACE level - #[allow(dead_code)] + /// Log a message with TRACE level + #[allow(dead_code)] pub fn trace(&self, content: String) { self.log(LogLevel::TRACE, content); } @@ -255,20 +263,19 @@ impl Clone for Logger { } } - #[cfg(test)] mod tests { use super::*; use std::io::stdout; - + #[test] fn test_basic() { let logs = Logger::new(stdout(), LogLevel::DEBUG, "test"); logs.info("test message".to_string()); logs.debug("debug info".to_string()); - + // Create a sub-logger with a different component let sub_logger = logs.with_component("sub-component"); sub_logger.warn("warning from sub-component".to_string()); } -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 4b41751..f328571 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,18 +38,19 @@ mod cache; mod config; pub mod hteapot; mod logger; +mod shutdown; mod utils; -use std::{fs, io, path::PathBuf}; use std::path::Path; use std::sync::Mutex; +use std::{fs, io, path::PathBuf}; use cache::Cache; use config::Config; use hteapot::{Hteapot, HttpRequest, HttpResponse, HttpStatus}; use utils::get_mime_tipe; -use logger::{Logger, LogLevel}; +use logger::{LogLevel, Logger}; use std::time::Instant; const VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -79,13 +80,13 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); fn safe_join_paths(root: &str, requested_path: &str) -> Option { let root_path = Path::new(root).canonicalize().ok()?; let requested_full_path = root_path.join(requested_path.trim_start_matches("/")); - + if !requested_full_path.exists() { return None; } - + let canonical_path = requested_full_path.canonicalize().ok()?; - + if canonical_path.starts_with(&root_path) { Some(canonical_path) } else { @@ -143,14 +144,13 @@ fn serve_file(path: &PathBuf) -> Option> { let r = fs::read(path); if r.is_ok() { Some(r.unwrap()) } else { None } } -// +// // Suggest to use .ok()? instead of manual unwrap/if is_ok for more idiomatic error handling: // fn serve_file(path: &PathBuf) -> Option> { - // fs::read(path).ok() +// fs::read(path).ok() // } -// -// - +// +// /// Main entry point of the Hteapot server. /// @@ -177,7 +177,7 @@ fn main() { println!("usage: {} ", args[0]); return; } - + // Initialize logger based on config or default to stdout let config = match args[1].as_str() { "--help" | "-h" => { @@ -223,36 +223,40 @@ fn main() { // Initialize the logger based on the config or default to stdout if the log file can't be created let logger = match config.log_file.clone() { Some(file_name) => { - let file = fs::File::create(file_name.clone()); // Attempt to create the log file - match file { // If creating the file fails, log to stdout instead - Ok(file) => Logger::new(file, LogLevel::INFO, "main"), // If successful, use the file + let file = fs::File::create(file_name.clone()); // Attempt to create the log file + match file { + // If creating the file fails, log to stdout instead + Ok(file) => Logger::new(file, LogLevel::INFO, "main"), // If successful, use the file Err(e) => { println!("Failed to create log file: {:?}. Using stdout instead.", e); - Logger::new(io::stdout(), LogLevel::INFO, "main") // Log to stdout + Logger::new(io::stdout(), LogLevel::INFO, "main") // Log to stdout } } } - None => Logger::new(io::stdout(), LogLevel::INFO, "main"), // If no log file is specified, use stdout + None => Logger::new(io::stdout(), LogLevel::INFO, "main"), // If no log file is specified, use stdout }; // Set up the cache with thread-safe locking // The Mutex ensures that only one thread can access the cache at a time, // preventing race conditions when reading and writing to the cache. - let cache: Mutex = Mutex::new(Cache::new(config.cache_ttl as u64)); // Initialize the cache with TTL + let cache: Mutex = Mutex::new(Cache::new(config.cache_ttl as u64)); // Initialize the cache with TTL // Create a new threaded HTTP server with the provided host, port, and number of threads - let server = Hteapot::new_threaded(config.host.as_str(), config.port, config.threads); + let mut server = Hteapot::new_threaded(config.host.as_str(), config.port, config.threads); + + //Configure graceful shutdown from ctrl+c + shutdown::setup_graceful_shutdown(&mut server, logger.clone()); logger.info(format!( "Server started at http://{}:{}", config.host, config.port - )); // Log that the server has started + )); // Log that the server has started // Log whether the cache is enabled based on the config setting if config.cache { logger.info("Cache Enabled".to_string()); } - + // If proxy-only mode is enabled, issue a warning that local paths won't be used if proxy_only { logger @@ -268,9 +272,9 @@ fn main() { // Start listening for HTTP requests server.listen(move |req| { // SERVER CORE: For each incoming request, we handle it in this closure - let start_time = Instant::now(); // Track request processing time - let req_method = req.method.to_str(); // Get the HTTP method (e.g., GET, POST) - let req_path = req.path.clone(); // Get the requested path + let start_time = Instant::now(); // Track request processing time + let req_method = req.method.to_str(); // Get the HTTP method (e.g., GET, POST) + let req_path = req.path.clone(); // Get the requested path // Log the incoming request method and path http_logger.info(format!("Request {} {}", req_method, req.path)); @@ -279,22 +283,21 @@ fn main() { let is_proxy = is_proxy(&config, req.clone()); if proxy_only || is_proxy.is_some() { // If proxying is enabled or this request matches a proxy rule, handle it - let (host, proxy_req) = is_proxy.unwrap(); // Get the target host and modified request + let (host, proxy_req) = is_proxy.unwrap(); // Get the target host and modified request proxy_logger.info(format!( "Proxying request {} {} to {}", req_method, req_path, host )); - // Perform the proxy request (forward the request to the target server) let res = proxy_req.brew(host.as_str()); - let elapsed = start_time.elapsed(); // Measure the time taken to process the proxy request + let elapsed = start_time.elapsed(); // Measure the time taken to process the proxy request if res.is_ok() { // If the proxy request is successful, log the time taken and return the response let response = res.unwrap(); proxy_logger.info(format!( "Proxy request processed in {:.6}ms", - elapsed.as_secs_f64() * 1000.0 // Log the time taken in milliseconds + elapsed.as_secs_f64() * 1000.0 // Log the time taken in milliseconds )); return response; } else { @@ -316,12 +319,12 @@ fn main() { // If the root path exists and is valid, try to join the index file let index_path = root_path.unwrap().join(&config.index); if index_path.exists() { - Some(index_path) // If index exists, return its path + Some(index_path) // If index exists, return its path } else { - None // If no index exists, return None + None // If no index exists, return None } } else { - None // If the root path is invalid, return None + None // If the root path is invalid, return None } } else { // For any other path, resolve it safely using the `safe_join_paths` function @@ -335,16 +338,17 @@ fn main() { // If it's a directory, check for the index file in that directory let index_path = path.join(&config.index); if index_path.exists() { - index_path // If index exists, return its path + index_path // If index exists, return its path } else { // If no index file exists, log a warning and return a 404 response - http_logger.warn(format!("Index file not found in directory: {}", req.path)); + http_logger + .warn(format!("Index file not found in directory: {}", req.path)); return HttpResponse::new(HttpStatus::NotFound, "Index not found", None); } } else { - path // If it's not a directory, just return the path + path // If it's not a directory, just return the path } - }, + } None => { // If the path is invalid or access is denied, return a 404 response http_logger.warn(format!("Path not found or access denied: {}", req.path)); @@ -359,9 +363,9 @@ fn main() { let content: Option> = if config.cache { // Lock the cache to ensure thread-safe access let mut cachee = cache.lock().expect("Error locking cache"); - let cache_start = Instant::now(); // Track cache operation time - let cache_key = req.path.clone(); // Use the request path as the cache key - let mut r = cachee.get(cache_key.clone()); // Try to get the content from cache + let cache_start = Instant::now(); // Track cache operation time + let cache_key = req.path.clone(); // Use the request path as the cache key + let mut r = cachee.get(cache_key.clone()); // Try to get the content from cache if r.is_none() { // If cache miss, read the file from disk and store it in cache cache_logger.debug(format!("cache miss for {}", cache_key)); @@ -379,10 +383,10 @@ fn main() { // Log how long the cache operation took let cache_elapsed = cache_start.elapsed(); cache_logger.debug(format!( - "Cache operation completed in {:.6}µs", + "Cache operation completed in {:.6}µs", cache_elapsed.as_micros() )); - r // Return the cached content (or None if not found) + r // Return the cached content (or None if not found) } else { // If cache is disabled, read the file from disk serve_file(&safe_path) @@ -392,7 +396,7 @@ fn main() { let elapsed = start_time.elapsed(); http_logger.info(format!( "Request processed in {:.6}ms", - elapsed.as_secs_f64() * 1000.0 // Log the time taken in milliseconds + elapsed.as_secs_f64() * 1000.0 // Log the time taken in milliseconds )); // If content was found, return it with the appropriate headers, otherwise return a 404 @@ -404,11 +408,11 @@ fn main() { "X-Content-Type-Options" => "nosniff" ); HttpResponse::new(HttpStatus::OK, c, headers) - }, + } None => { // If no content is found, return a 404 Not Found response HttpResponse::new(HttpStatus::NotFound, "Not found", None) - }, + } } }); -} \ No newline at end of file +} diff --git a/src/shutdown.rs b/src/shutdown.rs new file mode 100644 index 0000000..49f0756 --- /dev/null +++ b/src/shutdown.rs @@ -0,0 +1,185 @@ +use std::net::TcpStream; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::thread; +use std::time::Duration; + +use crate::hteapot::Hteapot; +use crate::logger::Logger; + +pub fn setup_graceful_shutdown(server: &mut Hteapot, logger: Logger) -> Arc { + let existing_signal = server.get_shutdown_signal(); + if existing_signal.is_some() { + return existing_signal.unwrap(); + } + let running = Arc::new(AtomicBool::new(true)); + let shutdown_logger = logger.with_component("shutdown"); + + #[cfg(unix)] + { + let mut ush = unix_signhandler::UnixSignHandler::new(); + let running_clone = running.clone(); + let addr = server.get_addr(); + ush.set_handler(move || { + println!("It works!"); + running_clone.store(false, Ordering::SeqCst); + let _ = TcpStream::connect(format!("{}:{}", addr.0, addr.1)); + }); + } + + #[cfg(windows)] + { + let r_win = running.clone(); + let win_logger = shutdown_logger.clone(); + + if !win_console::set_handler(r_win, win_logger.clone()) { + win_logger.error("Failed to set Windows Ctrl+C handler".to_string()); + } else { + win_logger.info("Windows Ctrl+C handler registered".to_string()); + } + } + + // Add shutdown hook for cleanup + let shutdown_logger_clone = shutdown_logger.clone(); + server.add_shutdown_hook(move || { + shutdown_logger_clone.info("Waiting for ongoing requests to complete...".to_string()); + thread::sleep(Duration::from_secs(3)); + shutdown_logger_clone.info("Exiting".to_string()); + }); + + server.set_shutdown_signal(running.clone()); + // Return the running flag so the main thread can also check it + running +} + +#[cfg(windows)] +pub mod win_console { + use std::sync::Arc; + use std::sync::Mutex; + use std::sync::atomic::{AtomicBool, Ordering}; + + // Define the external Windows API function in an unsafe extern block + unsafe extern "system" { + pub fn SetConsoleCtrlHandler( + handler: Option i32>, + add: i32, + ) -> i32; + } + + pub const CTRL_C_EVENT: u32 = 0; + + struct StaticData { + running: Option>, + logger: Option, + } + + // We need to ensure thread safety, so use a Mutex for it + static HANDLER_DATA: Mutex = Mutex::new(StaticData { + running: None, + logger: None, + }); + + pub fn set_handler(running: Arc, logger: crate::logger::Logger) -> bool { + // Store references in the mutex-protected static + let mut data = HANDLER_DATA.lock().unwrap(); + data.running = Some(running); + data.logger = Some(logger); + + unsafe extern "system" fn handler_func(ctrl_type: u32) -> i32 { + if ctrl_type == CTRL_C_EVENT { + if let Ok(data) = HANDLER_DATA.lock() { + if let (Some(r), Some(l)) = (&data.running, &data.logger) { + l.info("initiating graceful shutdown...".to_string()); + r.store(false, Ordering::SeqCst); + std::process::exit(0); + } + } + } + 0 + } + + unsafe { SetConsoleCtrlHandler(Some(handler_func), 1) != 0 } + } +} + +//Brought to you by: Overengeneering DIY™️ +#[cfg(unix)] +mod unix_signhandler { + use libc::{POLLIN, SA_RESTART, poll, pollfd, sigaction, sighandler_t}; + use std::io; + use std::sync::{Arc, RwLock}; + use std::{mem::zeroed, os::fd::RawFd, thread}; + + static mut PIPE_FD_READ: RawFd = -1; + static mut PIPE_FD_WRITE: RawFd = -1; + extern "C" fn handler(_: i32) { + let buf = [1u8]; + unsafe { + if PIPE_FD_WRITE != -1 { + let _ = libc::write(PIPE_FD_WRITE, buf.as_ptr() as *const _, 1); + } + } + } + + fn wait_for_readable(fd: RawFd) -> io::Result<()> { + let mut fds = [pollfd { + fd, + events: POLLIN, + revents: 0, + }]; + let ret = unsafe { poll(fds.as_mut_ptr(), 1, -1) }; // -1 = undefined timeout + + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + if fds[0].revents & POLLIN != 0 { + Ok(()) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + "unexpected poll result", + )) + } + } + + pub struct UnixSignHandler { + actions: Arc>>>, + } + + impl UnixSignHandler { + pub fn new() -> Self { + let ush = UnixSignHandler { + actions: Arc::new(RwLock::new(Vec::new())), + }; + unsafe { + let mut fds = [0; 2]; + if libc::pipe(fds.as_mut_ptr()) == -1 { + panic!("failed to create pipe"); + } + PIPE_FD_READ = fds[0]; + PIPE_FD_WRITE = fds[1]; + } + unsafe { + let mut action: sigaction = zeroed(); + action.sa_flags = SA_RESTART; + action.sa_sigaction = handler as sighandler_t; + sigaction(libc::SIGINT, &action, std::ptr::null_mut()); + } + + let actions_clone = ush.actions.clone(); + thread::spawn(move || { + unsafe { + let _ = wait_for_readable(PIPE_FD_READ); + } + for action in actions_clone.read().unwrap().iter() { + action(); + } + }); + return ush; + } + pub fn set_handler(&mut self, action: impl Fn() + Send + Sync + 'static) { + self.actions.write().unwrap().push(Box::new(action)); + } + } +}