diff --git a/.gitignore b/.gitignore index bd488e4..2f19ac0 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,6 @@ Thumbs.db # Docker docker-compose.override.yml .pip-cache/ + +# Build artifacts +build/ diff --git a/_codeql_detected_source_root b/_codeql_detected_source_root new file mode 120000 index 0000000..945c9b4 --- /dev/null +++ b/_codeql_detected_source_root @@ -0,0 +1 @@ +. \ No newline at end of file diff --git a/src/config.h b/src/config.h index e431b4a..76724bd 100644 --- a/src/config.h +++ b/src/config.h @@ -2,6 +2,7 @@ #include #include // for getenv() +#include // for std::cerr // Global configuration const int PORT = 8080; @@ -59,10 +60,18 @@ inline std::string get_pg_db() { inline std::string get_pg_user() { const char* u = getenv("POSTGRES_USER"); - return u ? u : "jic"; + if (!u) { + std::cerr << "WARNING: Using default database user. Set POSTGRES_USER environment variable for production." << std::endl; + return "jic"; + } + return u; } inline std::string get_pg_password() { const char* p = getenv("POSTGRES_PASSWORD"); - return p ? p : "jic_password"; + if (!p) { + std::cerr << "WARNING: Using default database password. Set POSTGRES_PASSWORD environment variable for production." << std::endl; + return "jic_password"; + } + return p; } diff --git a/src/server.cpp b/src/server.cpp index 8d773e2..c33fb1f 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include "llama.h" @@ -32,13 +33,48 @@ struct Document { std::string text; }; -// HTTP response builder +// Helper function to get reason phrase for status code +std::string get_reason_phrase(int status_code) { + switch (status_code) { + case 200: return "OK"; + case 201: return "Created"; + case 202: return "Accepted"; + case 204: return "No Content"; + case 301: return "Moved Permanently"; + case 302: return "Found"; + case 304: return "Not Modified"; + case 400: return "Bad Request"; + case 401: return "Unauthorized"; + case 403: return "Forbidden"; + case 404: return "Not Found"; + case 405: return "Method Not Allowed"; + case 413: return "Request Entity Too Large"; + case 415: return "Unsupported Media Type"; + case 429: return "Too Many Requests"; + case 500: return "Internal Server Error"; + case 501: return "Not Implemented"; + case 502: return "Bad Gateway"; + case 503: return "Service Unavailable"; + default: return "Unknown"; + } +} + +// HTTP response builder with security headers std::string build_http_response(int status_code, const std::string& content_type, const std::string& body) { std::ostringstream response; - response << "HTTP/1.1 " << status_code << " OK\r\n"; + response << "HTTP/1.1 " << status_code << " " << get_reason_phrase(status_code) << "\r\n"; response << "Content-Type: " << content_type << "\r\n"; response << "Content-Length: " << body.length() << "\r\n"; - response << "Access-Control-Allow-Origin: *\r\n"; + + // Security headers + response << "X-Frame-Options: DENY\r\n"; + response << "X-Content-Type-Options: nosniff\r\n"; + response << "X-XSS-Protection: 1; mode=block\r\n"; + response << "Content-Security-Policy: default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'\r\n"; + response << "Referrer-Policy: strict-origin-when-cross-origin\r\n"; + + // CORS headers (restrictive configuration) + response << "Access-Control-Allow-Origin: https://example.com\r\n"; response << "Access-Control-Allow-Methods: GET, POST, OPTIONS\r\n"; response << "Access-Control-Allow-Headers: Content-Type\r\n"; response << "\r\n"; @@ -76,11 +112,24 @@ HttpRequest parse_http_request(const std::string& request) { } } - // Parse body based on Content-Length - if (req.headers.find("Content-Length") != req.headers.end()) { - int content_length = std::stoi(req.headers["Content-Length"]); - req.body.resize(content_length); - stream.read(&req.body[0], content_length); + // Parse body based on Content-Length with validation + const size_t MAX_REQUEST_SIZE = 30 * 1024 * 1024; // 30MB + try { + if (req.headers.find("Content-Length") != req.headers.end()) { + long content_length = std::stol(req.headers["Content-Length"]); + + // Validate Content-Length + if (content_length < 0 || static_cast(content_length) > MAX_REQUEST_SIZE) { + std::cerr << "Invalid Content-Length: " << content_length << std::endl; + return req; // Return request with empty body + } + + req.body.resize(static_cast(content_length)); + stream.read(&req.body[0], content_length); + } + } catch (const std::exception& e) { + std::cerr << "Error parsing Content-Length: " << e.what() << std::endl; + // Return request with empty body } return req; @@ -100,6 +149,40 @@ struct ConversationHistory { std::map conversations; std::mutex conversations_mutex; +// Connection tracking for rate limiting +struct ConnectionTracker { + std::map> connections; + std::mutex tracker_mutex; + const int MAX_CONNECTIONS_PER_IP = 100; + const int RATE_LIMIT_WINDOW_SEC = 60; + const int MAX_REQUESTS_PER_WINDOW = 60; + + bool should_allow_connection(const std::string& ip) { + std::lock_guard lock(tracker_mutex); + auto now = std::chrono::system_clock::now(); + auto& timestamps = connections[ip]; + + // Remove old timestamps outside the window + timestamps.erase( + std::remove_if(timestamps.begin(), timestamps.end(), + [&](const auto& ts) { + return std::chrono::duration_cast(now - ts).count() > RATE_LIMIT_WINDOW_SEC; + }), + timestamps.end() + ); + + // Check rate limit + if (timestamps.size() >= static_cast(MAX_REQUESTS_PER_WINDOW)) { + return false; + } + + // Record this connection + timestamps.push_back(now); + return true; + } +}; +ConnectionTracker connection_tracker; + // Load index from disk void load_index() { if (vector_index) delete vector_index; @@ -122,46 +205,67 @@ void load_index() { } } -// Handle static file serving +// Handle static file serving with path traversal protection std::string serve_static_file(const std::string& path) { - std::string file_path = path; - if (file_path == "/") { - file_path = "/index.html"; - } - - // Remove leading slash - if (!file_path.empty() && file_path[0] == '/') { - file_path = file_path.substr(1); - } - - // Prepend public/ directory - file_path = "public/" + file_path; - - // Security check - prevent directory traversal - if (file_path.find("..") != std::string::npos) { - return build_http_response(403, "text/plain", "Forbidden"); - } - - // Check if file exists - if (!fs::exists(file_path)) { + try { + std::string file_path = path; + if (file_path == "/") { + file_path = "/index.html"; + } + + // Remove leading slash + if (!file_path.empty() && file_path[0] == '/') { + file_path = file_path.substr(1); + } + + // Construct full path and validate it's within public/ + fs::path requested_path = fs::path("public") / file_path; + fs::path base_path = fs::canonical("public"); + + // Check if requested file exists before canonicalizing + if (!fs::exists(requested_path)) { + return build_http_response(404, "text/plain", "Not Found"); + } + + // Resolve to canonical path and validate it's within public/ + fs::path canonical_path = fs::canonical(requested_path); + + // Check if canonical path starts with base_path (prevent path traversal) + auto [base_end, path_end] = std::mismatch( + base_path.begin(), base_path.end(), + canonical_path.begin(), canonical_path.end() + ); + + if (base_end != base_path.end()) { + return build_http_response(403, "text/plain", "Forbidden"); + } + + // Determine content type + std::string content_type = "text/plain"; + std::string path_str = canonical_path.string(); + if (string_ends_with(path_str, ".html")) content_type = "text/html"; + else if (string_ends_with(path_str, ".css")) content_type = "text/css"; + else if (string_ends_with(path_str, ".js")) content_type = "application/javascript"; + else if (string_ends_with(path_str, ".json")) content_type = "application/json"; + else if (string_ends_with(path_str, ".pdf")) content_type = "application/pdf"; + + // Read file + std::ifstream file(canonical_path, std::ios::binary); + if (!file.is_open()) { + return build_http_response(403, "text/plain", "Forbidden"); + } + std::stringstream buffer; + buffer << file.rdbuf(); + std::string content = buffer.str(); + + return build_http_response(200, content_type, content); + } catch (const fs::filesystem_error& e) { + std::cerr << "Filesystem error in serve_static_file: " << e.what() << std::endl; return build_http_response(404, "text/plain", "Not Found"); + } catch (const std::exception& e) { + std::cerr << "Error in serve_static_file: " << e.what() << std::endl; + return build_http_response(500, "text/plain", "Internal Server Error"); } - - // Determine content type - std::string content_type = "text/plain"; - if (string_ends_with(file_path, ".html")) content_type = "text/html"; - else if (string_ends_with(file_path, ".css")) content_type = "text/css"; - else if (string_ends_with(file_path, ".js")) content_type = "application/javascript"; - else if (string_ends_with(file_path, ".json")) content_type = "application/json"; - else if (string_ends_with(file_path, ".pdf")) content_type = "application/pdf"; - - // Read file - std::ifstream file(file_path, std::ios::binary); - std::stringstream buffer; - buffer << file.rdbuf(); - std::string content = buffer.str(); - - return build_http_response(200, content_type, content); } // Handle query endpoint @@ -385,8 +489,19 @@ std::string handle_status() { return build_http_response(200, "application/json", status.dump()); } -// Handle client connection +// Handle client connection with security protections void handle_client(int client_socket) { + const size_t MAX_REQUEST_SIZE = 30 * 1024 * 1024; // 30MB + const int READ_TIMEOUT_SEC = 30; + + // Set socket timeout to prevent slowloris attacks + struct timeval tv; + tv.tv_sec = READ_TIMEOUT_SEC; + tv.tv_usec = 0; + if (setsockopt(client_socket, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0) { + std::cerr << "Failed to set socket timeout" << std::endl; + } + std::vector buffer(65536); int total_read = 0; int valread; @@ -396,6 +511,15 @@ void handle_client(int client_socket) { total_read += valread; buffer[total_read] = '\0'; + // Check size limits to prevent memory exhaustion + if (static_cast(total_read) >= MAX_REQUEST_SIZE) { + std::cerr << "Request too large" << std::endl; + std::string error_response = build_http_response(413, "text/plain", "Request Entity Too Large"); + send(client_socket, error_response.c_str(), error_response.length(), 0); + close(client_socket); + return; + } + // Check if we have complete headers std::string current(buffer.data(), total_read); size_t header_end = current.find("\r\n\r\n"); @@ -405,25 +529,56 @@ void handle_client(int client_socket) { // If POST request with body, ensure we read it all if (temp_req.method == "POST" && temp_req.headers.find("Content-Length") != temp_req.headers.end()) { - int content_length = std::stoi(temp_req.headers["Content-Length"]); - int body_start = header_end + 4; - int body_read = total_read - body_start; - - // Read remaining body if needed - while (body_read < content_length) { - valread = read(client_socket, buffer.data() + total_read, buffer.size() - total_read - 1); - if (valread <= 0) break; - total_read += valread; - body_read += valread; - buffer[total_read] = '\0'; + try { + long content_length = std::stol(temp_req.headers["Content-Length"]); + + // Validate content length + if (content_length < 0 || static_cast(content_length) > MAX_REQUEST_SIZE) { + std::cerr << "Invalid Content-Length" << std::endl; + std::string error_response = build_http_response(400, "text/plain", "Bad Request"); + send(client_socket, error_response.c_str(), error_response.length(), 0); + close(client_socket); + return; + } + + int body_start = header_end + 4; + int body_read = total_read - body_start; + + // Read remaining body if needed + while (body_read < content_length) { + // Check total size limit + if (static_cast(total_read) >= MAX_REQUEST_SIZE) { + std::cerr << "Request too large" << std::endl; + std::string error_response = build_http_response(413, "text/plain", "Request Entity Too Large"); + send(client_socket, error_response.c_str(), error_response.length(), 0); + close(client_socket); + return; + } + + valread = read(client_socket, buffer.data() + total_read, buffer.size() - total_read - 1); + if (valread <= 0) break; + total_read += valread; + body_read += valread; + buffer[total_read] = '\0'; + } + } catch (const std::exception& e) { + std::cerr << "Error parsing Content-Length: " << e.what() << std::endl; + std::string error_response = build_http_response(400, "text/plain", "Bad Request"); + send(client_socket, error_response.c_str(), error_response.length(), 0); + close(client_socket); + return; } } break; } - // Resize buffer if needed - if (total_read >= buffer.size() - 1024) { - buffer.resize(buffer.size() * 2); + // Resize buffer if needed (with safety check) + if (total_read >= static_cast(buffer.size()) - 1024) { + size_t new_size = buffer.size() * 2; + if (new_size > MAX_REQUEST_SIZE) { + new_size = MAX_REQUEST_SIZE; + } + buffer.resize(new_size); } } @@ -506,8 +661,8 @@ int main() { return 1; } - // Listen for connections - if (listen(server_fd, 3) < 0) { + // Listen for connections with increased backlog + if (listen(server_fd, 128) < 0) { std::cerr << "Listen failed" << std::endl; return 1; } @@ -516,14 +671,29 @@ int main() { // Accept connections while (true) { - int addrlen = sizeof(address); - int client_socket = accept(server_fd, (struct sockaddr *)&address, (socklen_t*)&addrlen); + struct sockaddr_in client_addr; + int addrlen = sizeof(client_addr); + int client_socket = accept(server_fd, (struct sockaddr *)&client_addr, (socklen_t*)&addrlen); if (client_socket < 0) { std::cerr << "Accept failed" << std::endl; continue; } + // Get client IP for rate limiting + char client_ip[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &(client_addr.sin_addr), client_ip, INET_ADDRSTRLEN); + std::string ip_str(client_ip); + + // Check rate limit + if (!connection_tracker.should_allow_connection(ip_str)) { + std::cerr << "Rate limit exceeded for IP: " << ip_str << std::endl; + std::string error_response = build_http_response(429, "text/plain", "Too Many Requests"); + send(client_socket, error_response.c_str(), error_response.length(), 0); + close(client_socket); + continue; + } + // Handle client in a new thread try { std::thread client_thread(handle_client, client_socket);